1pub use burn_std::device::*;
2use burn_std::{BoolDType, BoolStore, DType, FloatDType, IntDType};
3
4use alloc::format;
5use alloc::string::String;
6use burn_std::stub::RwLock;
7
8#[cfg(target_has_atomic = "ptr")]
9use alloc::sync::Arc;
10
11#[cfg(not(target_has_atomic = "ptr"))]
12use portable_atomic_util::Arc;
13use thiserror::Error;
14
15use core::any::TypeId;
16
17#[cfg(feature = "std")]
18pub use std::collections::HashMap;
19#[cfg(feature = "std")]
20use std::sync::{LazyLock, OnceLock};
21
22#[cfg(not(feature = "std"))]
23pub use hashbrown::HashMap;
24#[cfg(not(feature = "std"))]
25use spin::{Lazy as LazyLock, Once as OnceLock};
26
27use crate::Backend;
28
29pub trait DeviceOps: Clone + Default + PartialEq + Send + Sync + core::fmt::Debug + Device {
31 fn id(&self) -> DeviceId {
33 self.to_id()
34 }
35
36 fn inner(&self) -> &Self {
41 self
42 }
43}
44
45#[derive(Debug, Clone, Copy, PartialEq, Eq)]
55pub struct DeviceSettings {
56 pub float_dtype: FloatDType,
58 pub int_dtype: IntDType,
60 pub bool_dtype: BoolDType,
62}
63
64impl DeviceSettings {
65 fn new(
66 float_dtype: impl Into<FloatDType>,
67 int_dtype: impl Into<IntDType>,
68 bool_dtype: impl Into<BoolDType>,
69 ) -> Self {
70 Self {
71 float_dtype: float_dtype.into(),
72 int_dtype: int_dtype.into(),
73 bool_dtype: bool_dtype.into(),
74 }
75 }
76}
77
78type RegistryKey = (DeviceId, TypeId);
80
81static REGISTRY: LazyLock<RwLock<HashMap<RegistryKey, Arc<OnceLock<DeviceSettings>>>>> =
86 LazyLock::new(|| RwLock::new(HashMap::new()));
87
88struct DeviceSettingsRegistry;
89
90impl DeviceSettingsRegistry {
91 fn get_or_insert<D: DeviceOps>(
93 device: &D,
94 default_fn: impl FnOnce() -> DeviceSettings,
95 ) -> DeviceSettings {
96 let key = Self::key(device);
97 #[cfg(feature = "std")]
98 {
99 let cached = LOCAL_CACHE.with(|cache| cache.borrow().get(&key).copied());
100 if let Some(settings) = cached {
101 return settings;
102 }
103
104 let settings = {
106 let read = REGISTRY.read().unwrap();
107 read.get(&key).cloned()
108 }
109 .unwrap_or_else(|| {
110 let mut map = REGISTRY.write().unwrap();
111 Arc::clone(map.entry(key).or_default())
112 });
113
114 let settings = *settings.get_or_init(default_fn);
115
116 LOCAL_CACHE.with(|cache| {
117 cache.borrow_mut().insert(key, settings);
118 });
119
120 settings
121 }
122 #[cfg(not(feature = "std"))]
123 {
124 let settings = {
125 let read = REGISTRY.read().unwrap();
126 read.get(&key).cloned()
127 }
128 .unwrap_or_else(|| {
129 let mut map = REGISTRY.write().unwrap();
130 Arc::clone(map.entry(key).or_default())
131 });
132
133 settings.call_once(default_fn);
134 *settings.get().unwrap()
135 }
136 }
137
138 fn init<D: DeviceOps>(device: &D, settings: DeviceSettings) -> Result<(), DeviceError> {
142 let key = Self::key(device);
143 let mut map = REGISTRY.write().unwrap();
144 let cell = map.entry(key).or_insert_with(|| Arc::new(OnceLock::new()));
145
146 #[cfg(feature = "std")]
147 return cell
148 .set(settings)
149 .map_err(|_| DeviceError::already_initialized(device));
150
151 #[cfg(not(feature = "std"))]
152 if cell.get().is_some() {
153 Err(DeviceError::already_initialized(device))
154 } else {
155 cell.call_once(|| settings);
156 Ok(())
157 }
158 }
159
160 fn key<D: Device>(device: &D) -> RegistryKey {
162 (device.to_id(), TypeId::of::<D>())
163 }
164}
165
#[cfg(feature = "std")]
thread_local! {
    /// Per-thread memo of settings already resolved through the global registry, consulted
    /// before any registry lock is taken.
    static LOCAL_CACHE: core::cell::RefCell<HashMap<RegistryKey, DeviceSettings>> =
        core::cell::RefCell::default();
}
172
173pub fn get_device_settings<B: Backend>(device: &B::Device) -> DeviceSettings {
175 let default_settings = || {
176 DeviceSettings::new(
177 default_float::<B>(),
178 default_int::<B>(),
179 default_bool::<B>(device),
180 )
181 };
182 DeviceSettingsRegistry::get_or_insert(device, default_settings)
183}
184
185fn default_bool<B: Backend>(device: &B::Device) -> BoolDType {
186 let default_bool: BoolDType = <B::BoolElem as crate::Element>::dtype().into();
192 let bool_as_dtype = default_bool.into();
193 if B::supports_dtype(device, bool_as_dtype) {
194 default_bool
195 } else if !matches!(bool_as_dtype, DType::Bool(BoolStore::U8))
196 && B::supports_dtype(device, DType::Bool(BoolStore::U8))
197 {
198 BoolDType::U8
199 } else if !matches!(bool_as_dtype, DType::Bool(BoolStore::U32))
200 && B::supports_dtype(device, DType::Bool(BoolStore::U32))
201 {
202 BoolDType::U32
203 } else if !matches!(bool_as_dtype, DType::Bool(BoolStore::Native))
204 && B::supports_dtype(device, DType::Bool(BoolStore::Native))
205 {
206 BoolDType::Native
207 } else {
208 unreachable!()
209 }
210}
211
212fn default_float<B: Backend>() -> FloatDType {
213 <B::FloatElem as crate::Element>::dtype().into()
214}
215
216fn default_int<B: Backend>() -> IntDType {
217 <B::IntElem as crate::Element>::dtype().into()
218}
219
220#[derive(Debug, Error)]
226pub enum DeviceError {
227 #[error("Device {device} does not support the requested data type {dtype:?}")]
229 UnsupportedDType {
230 device: String,
232 dtype: DType,
234 },
235 #[error("Device {device} settings have already been initialized")]
237 AlreadyInitialized {
238 device: String,
240 },
241}
242
243impl DeviceError {
244 pub fn unsupported_dtype<D: DeviceOps>(device: &D, dtype: DType) -> Self {
246 Self::UnsupportedDType {
247 device: format!("{device:?}"),
248 dtype,
249 }
250 }
251
252 pub fn already_initialized<D: DeviceOps>(device: &D) -> Self {
254 Self::AlreadyInitialized {
255 device: format!("{device:?}"),
256 }
257 }
258}
259
260fn check_dtype_support<B: Backend>(
261 device: &B::Device,
262 dtype: impl Into<DType>,
263) -> Result<(), DeviceError> {
264 let dtype = dtype.into();
265 if B::supports_dtype(device, dtype) {
268 Ok(())
269 } else {
270 Err(DeviceError::unsupported_dtype(device, dtype))
271 }
272}
273
274pub fn set_default_dtypes<B: Backend>(
303 device: &B::Device,
304 float_dtype: impl Into<FloatDType>,
305 int_dtype: impl Into<IntDType>,
306) -> Result<(), DeviceError> {
307 let float_dtype = float_dtype.into();
308 let int_dtype = int_dtype.into();
309 check_dtype_support::<B>(device, float_dtype)?;
310 check_dtype_support::<B>(device, int_dtype)?;
311
312 let settings = DeviceSettings::new(float_dtype, int_dtype, default_bool::<B>(device));
313
314 initialize_unchecked(device, settings)?;
315 Ok(())
316}
317
318pub fn set_default_float_dtype<B: Backend>(
345 device: &B::Device,
346 dtype: impl Into<FloatDType>,
347) -> Result<(), DeviceError> {
348 let dtype = dtype.into();
349 check_dtype_support::<B>(device, dtype)?;
350
351 let settings = DeviceSettings::new(dtype, default_int::<B>(), default_bool::<B>(device));
352
353 initialize_unchecked(device, settings)?;
354 Ok(())
355}
356
357pub fn set_default_int_dtype<B: Backend>(
384 device: &B::Device,
385 dtype: impl Into<IntDType>,
386) -> Result<(), DeviceError> {
387 let dtype = dtype.into();
388 check_dtype_support::<B>(device, dtype)?;
389
390 let settings = DeviceSettings::new(default_float::<B>(), dtype, default_bool::<B>(device));
391
392 initialize_unchecked(device, settings)?;
393 Ok(())
394}
395
396fn initialize_unchecked<D: DeviceOps>(
398 device: &D,
399 settings: DeviceSettings,
400) -> Result<(), DeviceError> {
401 DeviceSettingsRegistry::init(device, settings)
402}
403
#[cfg(all(test, feature = "std"))]
mod tests {
    use serial_test::serial;

    use super::*;

    /// Resets all global settings state so every test starts from scratch.
    ///
    /// Besides the global registry, the calling thread's `LOCAL_CACHE` is cleared too. It
    /// memoizes resolved settings and would otherwise keep returning values from an earlier
    /// test if the harness ran both on the same thread.
    fn clear_registry() {
        REGISTRY.write().unwrap().clear();
        LOCAL_CACHE.with(|cache| cache.borrow_mut().clear());
    }

    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceA {
        index: u32,
    }

    impl Device for TestDeviceA {
        fn from_id(device_id: DeviceId) -> Self {
            Self {
                index: device_id.index_id,
            }
        }

        fn to_id(&self) -> DeviceId {
            DeviceId {
                type_id: 0,
                index_id: self.index,
            }
        }
    }

    impl DeviceOps for TestDeviceA {}

    /// Reports the same ids as `TestDeviceA` but is a distinct Rust type. It is used to check
    /// that settings are keyed by device type as well as by device id.
    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceB {
        index: u32,
    }

    impl Device for TestDeviceB {
        fn from_id(device_id: DeviceId) -> Self {
            Self {
                index: device_id.index_id,
            }
        }

        fn to_id(&self) -> DeviceId {
            DeviceId {
                type_id: 0,
                index_id: self.index,
            }
        }
    }

    impl DeviceOps for TestDeviceB {}

    impl DeviceSettings {
        /// Settings applied lazily when a test device has no explicit settings.
        fn defaults() -> Self {
            DeviceSettings::new(FloatDType::F32, IntDType::I32, BoolDType::Native)
        }
    }

    fn get_test_device_settings<D: DeviceOps>(device: &D) -> DeviceSettings {
        DeviceSettingsRegistry::get_or_insert(device, DeviceSettings::defaults)
    }

    #[test]
    #[serial]
    fn default_settings_returned_when_uninitialized() {
        clear_registry();
        let device = TestDeviceA::new(0);

        let s1 = get_test_device_settings(&device);
        let s2 = get_test_device_settings(&device);

        assert_eq!(s1, s2);
        assert_eq!(s1, DeviceSettings::defaults());
    }

    #[test]
    #[serial]
    fn initialized_settings_are_returned() {
        clear_registry();
        let device = TestDeviceA::new(0);
        let settings = DeviceSettings::new(FloatDType::BF16, IntDType::I32, BoolDType::Native);

        initialize_unchecked(&device, settings).unwrap();
        let s1 = get_test_device_settings(&device);
        let s2 = get_test_device_settings(&device);

        assert_eq!(s1, s2);
        assert_eq!(s1, settings);
        assert_eq!(s2, settings);
    }

    #[test]
    #[serial]
    fn settings_are_device_id_specific() {
        clear_registry();
        let d1 = TestDeviceA::new(0);
        let d2 = TestDeviceA::new(1);
        let settings = DeviceSettings::new(FloatDType::F16, IntDType::I64, BoolDType::Native);

        initialize_unchecked(&d1, settings).unwrap();

        let s1 = get_test_device_settings(&d1);
        let s2 = get_test_device_settings(&d2);

        assert_ne!(s1, s2);
        assert_eq!(s1, settings);
        assert_eq!(s2, DeviceSettings::defaults());
    }

    #[test]
    #[serial]
    fn settings_are_device_type_specific() {
        clear_registry();
        let d1 = TestDeviceA::new(0);
        let d2 = TestDeviceB::new(0);
        let settings = DeviceSettings::new(FloatDType::F16, IntDType::I64, BoolDType::Native);

        initialize_unchecked(&d2, settings).unwrap();

        let s1 = get_test_device_settings(&d1);
        let s2 = get_test_device_settings(&d2);

        assert_ne!(s1, s2);
        assert_eq!(s1, DeviceSettings::defaults());
        assert_eq!(s2, settings);
    }

    #[test]
    #[serial]
    fn initialization_after_default_returns_error() {
        clear_registry();
        let device = TestDeviceA::new(0);
        let _before = get_test_device_settings(&device);

        let settings = DeviceSettings::new(FloatDType::BF16, IntDType::I64, BoolDType::Native);
        let result = initialize_unchecked(&device, settings);

        assert!(matches!(
            result,
            Err(DeviceError::AlreadyInitialized { .. })
        ));
    }

    #[test]
    #[serial]
    fn second_initialization_returns_error() {
        clear_registry();
        let device = TestDeviceA::new(0);
        let settings = DeviceSettings::new(FloatDType::F16, IntDType::I32, BoolDType::Native);
        initialize_unchecked(&device, settings).unwrap();

        let result = initialize_unchecked(&device, DeviceSettings::defaults());
        assert!(matches!(
            result,
            Err(DeviceError::AlreadyInitialized { .. })
        ));
    }

    #[test]
    #[serial]
    fn failed_initialization_keeps_existing_settings() {
        clear_registry();
        let device = TestDeviceA::new(0);
        let settings = DeviceSettings::new(FloatDType::F16, IntDType::I32, BoolDType::Native);
        initialize_unchecked(&device, settings).unwrap();

        let other = DeviceSettings::new(FloatDType::BF16, IntDType::I64, BoolDType::Native);
        assert!(initialize_unchecked(&device, other).is_err());

        assert_eq!(get_test_device_settings(&device), settings);
    }

    #[test]
    #[serial]
    fn initialized_settings_are_global() {
        clear_registry();

        let device = TestDeviceA::new(0);
        let settings = DeviceSettings::new(FloatDType::F16, IntDType::I32, BoolDType::Native);

        initialize_unchecked(&device, settings).unwrap();
        let settings_actual = get_test_device_settings(&device);
        assert_eq!(settings_actual, settings);

        let seen_by_new_thread =
            std::thread::spawn(move || get_test_device_settings(&TestDeviceA::new(0)))
                .join()
                .unwrap();
        assert_eq!(seen_by_new_thread, settings_actual);
    }
}