Skip to main content

burn_backend/backend/
device.rs

1pub use burn_std::device::*;
2use burn_std::{BoolDType, BoolStore, DType, FloatDType, IntDType};
3
4use alloc::format;
5use alloc::string::String;
6use burn_std::stub::RwLock;
7
8#[cfg(target_has_atomic = "ptr")]
9use alloc::sync::Arc;
10
11#[cfg(not(target_has_atomic = "ptr"))]
12use portable_atomic_util::Arc;
13use thiserror::Error;
14
15use core::any::TypeId;
16
17#[cfg(feature = "std")]
18pub use std::collections::HashMap;
19#[cfg(feature = "std")]
20use std::sync::{LazyLock, OnceLock};
21
22#[cfg(not(feature = "std"))]
23pub use hashbrown::HashMap;
24#[cfg(not(feature = "std"))]
25use spin::{Lazy as LazyLock, Once as OnceLock};
26
27use crate::Backend;
28
29/// Device trait for all burn backend devices.
30pub trait DeviceOps: Clone + Default + PartialEq + Send + Sync + core::fmt::Debug + Device {
31    /// Returns the [device id](DeviceId).
32    fn id(&self) -> DeviceId {
33        self.to_id()
34    }
35
36    /// Returns the inner device without autodiff enabled.
37    ///
38    /// For most devices this is a no-op that returns `self`. For autodiff-enabled
39    /// devices, this returns the underlying inner device.
40    fn inner(&self) -> &Self {
41        self
42    }
43}
44
45/// Settings controlling the default data types for a specific device.
46///
47/// These settings are managed in a global registry that enforces strict initialization semantics:
48///
49/// 1. Manual Initialization: You can set these once at the start of your program using [`set_default_dtypes`].
50/// 2. Default Initialization: If an operation (like creating a tensor) occurs before manual initialization,
51///    the settings are permanently locked to their default values.
52/// 3. Immutability: Once initialized, settings cannot be changed. This ensures consistent behavior across
53///    all threads and operations.
54#[derive(Debug, Clone, Copy, PartialEq, Eq)]
55pub struct DeviceSettings {
56    /// Default floating-point data type.
57    pub float_dtype: FloatDType,
58    /// Default integer data type.
59    pub int_dtype: IntDType,
60    /// Default bool data type.
61    pub bool_dtype: BoolDType,
62}
63
64impl DeviceSettings {
65    fn new(
66        float_dtype: impl Into<FloatDType>,
67        int_dtype: impl Into<IntDType>,
68        bool_dtype: impl Into<BoolDType>,
69    ) -> Self {
70        Self {
71            float_dtype: float_dtype.into(),
72            int_dtype: int_dtype.into(),
73            bool_dtype: bool_dtype.into(),
74        }
75    }
76}
77
78/// Key for the registry: physical device type + device id
79type RegistryKey = (DeviceId, TypeId);
80
81/// Global registry mapping devices to their settings.
82///
83/// Each value is wrapped in a `OnceLock` to enforce that settings are initialized only once
84/// per device.
85static REGISTRY: LazyLock<RwLock<HashMap<RegistryKey, Arc<OnceLock<DeviceSettings>>>>> =
86    LazyLock::new(|| RwLock::new(HashMap::new()));
87
88struct DeviceSettingsRegistry;
89
90impl DeviceSettingsRegistry {
91    /// Returns the settings for the given device, inserting the default if absent.
92    fn get_or_insert<D: DeviceOps>(
93        device: &D,
94        default_fn: impl FnOnce() -> DeviceSettings,
95    ) -> DeviceSettings {
96        let key = Self::key(device);
97        #[cfg(feature = "std")]
98        {
99            let cached = LOCAL_CACHE.with(|cache| cache.borrow().get(&key).copied());
100            if let Some(settings) = cached {
101                return settings;
102            }
103
104            // Entry does not exist in cache
105            let settings = {
106                let read = REGISTRY.read().unwrap();
107                read.get(&key).cloned()
108            }
109            .unwrap_or_else(|| {
110                let mut map = REGISTRY.write().unwrap();
111                Arc::clone(map.entry(key).or_default())
112            });
113
114            let settings = *settings.get_or_init(default_fn);
115
116            LOCAL_CACHE.with(|cache| {
117                cache.borrow_mut().insert(key, settings);
118            });
119
120            settings
121        }
122        #[cfg(not(feature = "std"))]
123        {
124            let settings = {
125                let read = REGISTRY.read().unwrap();
126                read.get(&key).cloned()
127            }
128            .unwrap_or_else(|| {
129                let mut map = REGISTRY.write().unwrap();
130                Arc::clone(map.entry(key).or_default())
131            });
132
133            settings.call_once(default_fn);
134            *settings.get().unwrap()
135        }
136    }
137
138    /// Initializes the settings for the given device.
139    ///
140    /// Returns `Err` with the existing settings if already initialized.
141    fn init<D: DeviceOps>(device: &D, settings: DeviceSettings) -> Result<(), DeviceError> {
142        let key = Self::key(device);
143        let mut map = REGISTRY.write().unwrap();
144        let cell = map.entry(key).or_insert_with(|| Arc::new(OnceLock::new()));
145
146        #[cfg(feature = "std")]
147        return cell
148            .set(settings)
149            .map_err(|_| DeviceError::already_initialized(device));
150
151        #[cfg(not(feature = "std"))]
152        if cell.get().is_some() {
153            Err(DeviceError::already_initialized(device))
154        } else {
155            cell.call_once(|| settings);
156            Ok(())
157        }
158    }
159
160    /// Returns the device registry key.
161    fn key<D: Device>(device: &D) -> RegistryKey {
162        (device.to_id(), TypeId::of::<D>())
163    }
164}
165
#[cfg(feature = "std")]
thread_local! {
    /// Per-thread copy of already initialized device settings, so that reading them again
    /// from the same thread does not need the global registry lock.
    static LOCAL_CACHE: core::cell::RefCell<HashMap<RegistryKey, DeviceSettings>> =
        Default::default();
}
172
173/// Get the [`device`'s settings](DeviceSettings).
174pub fn get_device_settings<B: Backend>(device: &B::Device) -> DeviceSettings {
175    let default_settings = || {
176        DeviceSettings::new(
177            default_float::<B>(),
178            default_int::<B>(),
179            default_bool::<B>(device),
180        )
181    };
182    DeviceSettingsRegistry::get_or_insert(device, default_settings)
183}
184
185fn default_bool<B: Backend>(device: &B::Device) -> BoolDType {
186    // NOTE: this fallback logic is mostly tied to the dispatch backend since we still have associated
187    // element types. Once they're removed, we need to have some sort of `DeviceDefaults` trait that provides
188    // per-device defaults instead.
189
190    // dtype.into() handles u8/u32 conversion to Bool(..)
191    let default_bool: BoolDType = <B::BoolElem as crate::Element>::dtype().into();
192    let bool_as_dtype = default_bool.into();
193    if B::supports_dtype(device, bool_as_dtype) {
194        default_bool
195    } else if !matches!(bool_as_dtype, DType::Bool(BoolStore::U8))
196        && B::supports_dtype(device, DType::Bool(BoolStore::U8))
197    {
198        BoolDType::U8
199    } else if !matches!(bool_as_dtype, DType::Bool(BoolStore::U32))
200        && B::supports_dtype(device, DType::Bool(BoolStore::U32))
201    {
202        BoolDType::U32
203    } else if !matches!(bool_as_dtype, DType::Bool(BoolStore::Native))
204        && B::supports_dtype(device, DType::Bool(BoolStore::Native))
205    {
206        BoolDType::Native
207    } else {
208        unreachable!()
209    }
210}
211
212fn default_float<B: Backend>() -> FloatDType {
213    <B::FloatElem as crate::Element>::dtype().into()
214}
215
216fn default_int<B: Backend>() -> IntDType {
217    <B::IntElem as crate::Element>::dtype().into()
218}
219
220/// Errors that can occur during device-related operations.
221///
222/// This covers errors related to hardware capability mismatches, such as
223/// requesting a data type not supported by the device, and configuration
224/// errors like attempting to change a settings in an invalid context.
225#[derive(Debug, Error)]
226pub enum DeviceError {
227    /// Unsupported data type by the device.
228    #[error("Device {device} does not support the requested data type {dtype:?}")]
229    UnsupportedDType {
230        /// The string representation of the device.
231        device: String,
232        /// The data type that caused the error.
233        dtype: DType,
234    },
235    /// Device settings have already been initialized.
236    #[error("Device {device} settings have already been initialized")]
237    AlreadyInitialized {
238        /// The string representation of the device.
239        device: String,
240    },
241}
242
243impl DeviceError {
244    /// Helper to create a [`DeviceError::UnsupportedDType`] from any device.
245    pub fn unsupported_dtype<D: DeviceOps>(device: &D, dtype: DType) -> Self {
246        Self::UnsupportedDType {
247            device: format!("{device:?}"),
248            dtype,
249        }
250    }
251
252    /// Helper to create a [`DeviceError::AlreadyInitialized`] from any device.
253    pub fn already_initialized<D: DeviceOps>(device: &D) -> Self {
254        Self::AlreadyInitialized {
255            device: format!("{device:?}"),
256        }
257    }
258}
259
260fn check_dtype_support<B: Backend>(
261    device: &B::Device,
262    dtype: impl Into<DType>,
263) -> Result<(), DeviceError> {
264    let dtype = dtype.into();
265    // Default dtypes should have `DTypeUsage::general()`. Types restricted to specialized
266    // operations should not be used as default.
267    if B::supports_dtype(device, dtype) {
268        Ok(())
269    } else {
270        Err(DeviceError::unsupported_dtype(device, dtype))
271    }
272}
273
274/// Sets the default data types for the device.
275///
276/// This updates the device's default data types used for tensor creation.
277///
278/// Settings can only be initialized once per device. Subsequent calls for
279/// the same device return [`DeviceError::AlreadyInitialized`].
280///
281/// # Note
282///
283/// Initialization must happen before any tensor creation on the device.
284/// The first tensor operation will lock the device to its defaults, causing
285/// any subsequent initialization attempt to return [`DeviceError::AlreadyInitialized`].
286///
287/// # Example
288///
289/// ```rust, ignore
290/// fn example<B: Backend>() {
291///     let device = B::Device::default();
292///     
293///     // Update the device settings
294///     set_default_dtypes::<B>(&device, DType::F16, DType::I32);
295///     
296///     // All float tensors created after this will use F16 by default
297///     let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
298///     // All int tensors created after this will use I32 default
299///     let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);
300/// }
301/// ```
302pub fn set_default_dtypes<B: Backend>(
303    device: &B::Device,
304    float_dtype: impl Into<FloatDType>,
305    int_dtype: impl Into<IntDType>,
306) -> Result<(), DeviceError> {
307    let float_dtype = float_dtype.into();
308    let int_dtype = int_dtype.into();
309    check_dtype_support::<B>(device, float_dtype)?;
310    check_dtype_support::<B>(device, int_dtype)?;
311
312    let settings = DeviceSettings::new(float_dtype, int_dtype, default_bool::<B>(device));
313
314    initialize_unchecked(device, settings)?;
315    Ok(())
316}
317
318/// Sets the default floating-point data type for the device.
319///
320/// This updates the device's default data types used for tensor creation.
321///
322/// Settings can only be initialized once per device. Subsequent calls for
323/// the same device return [`DeviceError::AlreadyInitialized`].
324///
325/// # Note
326///
327/// Initialization must happen before any tensor creation on the device.
328/// The first tensor operation will lock the device to its defaults, causing
329/// any subsequent initialization attempt to return [`DeviceError::AlreadyInitialized`].
330///
331/// # Example
332///
333/// ```rust, ignore
334/// fn example<B: Backend>() {
335///     let device = B::Device::default();
336///     
337///     // Update the device settings
338///     set_default_float_dtype::<B>(&device, DType::F16);
339///     
340///     // All float tensors created after this will use F16 by default
341///     let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
342/// }
343/// ```
344pub fn set_default_float_dtype<B: Backend>(
345    device: &B::Device,
346    dtype: impl Into<FloatDType>,
347) -> Result<(), DeviceError> {
348    let dtype = dtype.into();
349    check_dtype_support::<B>(device, dtype)?;
350
351    let settings = DeviceSettings::new(dtype, default_int::<B>(), default_bool::<B>(device));
352
353    initialize_unchecked(device, settings)?;
354    Ok(())
355}
356
357/// Sets the default integer data type for the device.
358///
359/// This updates the device's default data types used for tensor creation.
360///
361/// Settings can only be initialized once per device. Subsequent calls for
362/// the same device return [`DeviceError::AlreadyInitialized`].
363///
364/// # Note
365///
366/// Initialization must happen before any tensor creation on the device.
367/// The first tensor operation will lock the device to its defaults, causing
368/// any subsequent initialization attempt to return [`DeviceError::AlreadyInitialized`].
369///
370/// # Example
371///
372/// ```rust, ignore
373/// fn example<B: Backend>() {
374///     let device = B::Device::default();
375///     
376///     // Update the device settings
377///     set_default_int_dtype::<B>(&device, DType::I32);
378///     
379///     // All int tensors created after this will use I32 default
380///     let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);
381/// }
382/// ```
383pub fn set_default_int_dtype<B: Backend>(
384    device: &B::Device,
385    dtype: impl Into<IntDType>,
386) -> Result<(), DeviceError> {
387    let dtype = dtype.into();
388    check_dtype_support::<B>(device, dtype)?;
389
390    let settings = DeviceSettings::new(default_float::<B>(), dtype, default_bool::<B>(device));
391
392    initialize_unchecked(device, settings)?;
393    Ok(())
394}
395
396// Unchecked dtypes
397fn initialize_unchecked<D: DeviceOps>(
398    device: &D,
399    settings: DeviceSettings,
400) -> Result<(), DeviceError> {
401    DeviceSettingsRegistry::init(device, settings)
402}
403
#[cfg(all(test, feature = "std"))]
mod tests {
    use serial_test::serial;

    use super::*;

    /// Resets every piece of settings state the calling test can observe.
    ///
    /// The thread-local cache is cleared together with the global registry: otherwise a test
    /// that happens to run on a thread reused from a previous test would read that test's
    /// stale cached settings instead of going through the (freshly cleared) registry.
    fn clear_registry() {
        REGISTRY.write().unwrap().clear();
        LOCAL_CACHE.with(|cache| cache.borrow_mut().clear());
    }

    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceA {
        index: u32,
    }

    impl Device for TestDeviceA {
        fn from_id(device_id: DeviceId) -> Self {
            Self {
                index: device_id.index_id,
            }
        }

        fn to_id(&self) -> DeviceId {
            DeviceId {
                type_id: 0,
                index_id: self.index,
            }
        }
    }

    impl DeviceOps for TestDeviceA {}

    // Deliberately reports the same `DeviceId` as `TestDeviceA` for a given index, so that the
    // tests can check that the registry also distinguishes devices by their Rust type.
    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceB {
        index: u32,
    }

    impl Device for TestDeviceB {
        fn from_id(device_id: DeviceId) -> Self {
            Self {
                index: device_id.index_id,
            }
        }

        fn to_id(&self) -> DeviceId {
            DeviceId {
                type_id: 0,
                index_id: self.index,
            }
        }
    }

    impl DeviceOps for TestDeviceB {}

    // Test defaults
    impl DeviceSettings {
        fn defaults() -> Self {
            DeviceSettings::new(FloatDType::F32, IntDType::I32, BoolDType::Native)
        }
    }

    /// Reads the settings of `device`, falling back to the test defaults.
    fn get_test_device_settings<D: DeviceOps>(device: &D) -> DeviceSettings {
        DeviceSettingsRegistry::get_or_insert(device, DeviceSettings::defaults)
    }

    #[test]
    #[serial]
    fn default_settings_returned_when_uninitialized() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);

        let s1 = get_test_device_settings(&device);
        let s2 = get_test_device_settings(&device);

        assert_eq!(s1, s2);
        assert_eq!(s1, DeviceSettings::defaults());
    }

    #[test]
    #[serial]
    fn initialized_settings_are_returned() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);
        let settings = DeviceSettings::new(FloatDType::BF16, IntDType::I32, BoolDType::Native);

        initialize_unchecked(&device, settings).unwrap();
        let s1 = get_test_device_settings(&device);
        let s2 = get_test_device_settings(&device);

        assert_eq!(s1, s2);
        assert_eq!(s1, settings);
        assert_eq!(s2, settings);
    }

    #[test]
    #[serial]
    fn settings_are_device_id_specific() {
        clear_registry(); // reset registry for each test

        let d1 = TestDeviceA::new(0);
        let d2 = TestDeviceA::new(1);
        let settings = DeviceSettings::new(FloatDType::F16, IntDType::I64, BoolDType::Native);

        initialize_unchecked(&d1, settings).unwrap();

        let s1 = get_test_device_settings(&d1);
        let s2 = get_test_device_settings(&d2);

        assert_ne!(s1, s2);
        assert_eq!(s1, settings);
        assert_eq!(s2, DeviceSettings::defaults());
    }

    #[test]
    #[serial]
    fn settings_are_device_type_specific() {
        clear_registry(); // reset registry for each test

        let d1 = TestDeviceA::new(0);
        let d2 = TestDeviceB::new(0);
        let settings = DeviceSettings::new(FloatDType::F16, IntDType::I64, BoolDType::Native);

        initialize_unchecked(&d2, settings).unwrap();

        let s1 = get_test_device_settings(&d1);
        let s2 = get_test_device_settings(&d2);

        assert_ne!(s1, s2);
        assert_eq!(s1, DeviceSettings::defaults());
        assert_eq!(s2, settings);
    }

    #[test]
    #[serial]
    fn initialization_after_default_returns_error() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);
        // Settings are set to default on first access, which forces consistency
        let _before = get_test_device_settings(&device);

        let settings = DeviceSettings::new(FloatDType::BF16, IntDType::I64, BoolDType::Native);
        let result = initialize_unchecked(&device, settings);

        assert!(matches!(
            result,
            Err(DeviceError::AlreadyInitialized { .. })
        ));
    }

    #[test]
    #[serial]
    fn second_initialization_returns_error() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);
        let settings = DeviceSettings::new(FloatDType::F16, IntDType::I32, BoolDType::Native);
        initialize_unchecked(&device, settings).unwrap();

        let result = initialize_unchecked(&device, DeviceSettings::defaults());
        assert!(matches!(
            result,
            Err(DeviceError::AlreadyInitialized { .. })
        ));
    }

    #[test]
    #[serial]
    fn initialized_settings_are_global() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);
        let settings = DeviceSettings::new(FloatDType::F16, IntDType::I32, BoolDType::Native);

        initialize_unchecked(&device, settings).unwrap();
        let settings_actual = get_test_device_settings(&device);
        assert_eq!(settings_actual, settings);

        // The other thread starts with an empty thread-local cache, so it can only see the
        // initialized settings through the global registry.
        let seen_by_new_thread =
            std::thread::spawn(move || get_test_device_settings(&TestDeviceA::new(0)))
                .join()
                .unwrap();
        assert_eq!(seen_by_new_thread, settings_actual);
    }
}