Skip to main content

burn_tensor/
device.rs

1use alloc::format;
2use alloc::string::String;
3use burn_backend::{Backend, Device, DeviceId, DeviceOps};
4use burn_std::stub::RwLock;
5use burn_std::{DType, FloatDType, IntDType};
6
7#[cfg(target_has_atomic = "ptr")]
8use alloc::sync::Arc;
9
10#[cfg(not(target_has_atomic = "ptr"))]
11use portable_atomic_util::Arc;
12use thiserror::Error;
13
14use core::any::TypeId;
15
16#[cfg(feature = "std")]
17pub use std::collections::HashMap;
18#[cfg(feature = "std")]
19use std::sync::LazyLock;
20
21#[cfg(not(feature = "std"))]
22pub use hashbrown::HashMap;
23#[cfg(not(feature = "std"))]
24use spin::Lazy as LazyLock;
25
26/// Policy controlling default device behavior.
27///
28/// This includes default data types used for tensor creation.
29#[derive(Debug, Clone, Copy, Default)]
30pub(crate) struct DevicePolicy {
31    /// Default floating-point data type for tensor creation.
32    float_dtype: Option<FloatDType>,
33    /// Default integer data type for tensor creation.
34    int_dtype: Option<IntDType>,
35}
36
37impl DevicePolicy {
38    /// Returns the default floating-point data type used for tensor creation.
39    pub(crate) fn float_dtype(&self) -> Option<FloatDType> {
40        self.float_dtype
41    }
42
43    /// Returns the default integer data type used for tensor creation.
44    pub(crate) fn int_dtype(&self) -> Option<IntDType> {
45        self.int_dtype
46    }
47
48    /// Sets the default floating-point data type.
49    pub(crate) fn set_float_dtype(&mut self, dtype: FloatDType) {
50        self.float_dtype = Some(dtype);
51    }
52
53    /// Sets the default integer data type.
54    pub(crate) fn set_int_dtype(&mut self, dtype: IntDType) {
55        self.int_dtype = Some(dtype);
56    }
57}
58
59/// Key for the registry: physical device type + device id
60type RegistryKey = (DeviceId, TypeId);
61
62/// Global registry mapping devices to their policies.
63static REGISTRY: LazyLock<RwLock<HashMap<RegistryKey, Arc<DevicePolicy>>>> =
64    LazyLock::new(|| RwLock::new(HashMap::new()));
65
66/// Device policy management for controlling default tensor creation behavior.
67///
68/// # Policy Semantics
69///
70/// Device policies use snapshot semantics: when you retrieve a policy with
71/// [`get_device_policy`], you get an immutable snapshot of the current configuration.
72/// Updates to the policy (via [`set_default_dtypes`], [`set_default_float_dtype`], etc.)
73/// only affect future policy retrievals, not existing references.
74///
75/// This is intended for the common case where policies are set once during
76/// initialization and then read frequently during tensor creation.
77struct DevicePolicyRegistry;
78
79impl DevicePolicyRegistry {
80    /// Get the policy for a physical device type and device id.
81    ///
82    /// If no policy exists yet, a default one is created and stored.
83    fn get<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {
84        let key = Self::key(device);
85
86        if let Some(policy) = REGISTRY.read().unwrap().get(&key) {
87            return Arc::clone(policy);
88        }
89
90        let mut map = REGISTRY.write().unwrap();
91        Arc::clone(
92            map.entry(key)
93                .or_insert_with(|| Arc::new(DevicePolicy::default())),
94        )
95    }
96
97    /// Mutate the policy for a given device.
98    fn update<D: DeviceOps>(device: &D, update_fn: impl FnOnce(&mut DevicePolicy)) {
99        let key = Self::key(device);
100        let mut map = REGISTRY.write().unwrap();
101
102        let policy = map
103            .entry(key)
104            .or_insert_with(|| Arc::new(DevicePolicy::default()));
105
106        // Update the policy
107        let policy_mut = Arc::make_mut(policy);
108        update_fn(policy_mut);
109    }
110
111    /// Returns the device registry key.
112    fn key<D: Device>(device: &D) -> RegistryKey {
113        (device.to_id(), TypeId::of::<D>())
114    }
115}
116
117/// Get the [`device`'s policy](DevicePolicy).
118///
119/// Returns an immutable snapshot of the device's current policy. If the policy
120/// is updated after retrieval, this snapshot will not reflect those changes.
121pub(crate) fn get_device_policy<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {
122    DevicePolicyRegistry::get(device)
123}
124
125/// Errors that can occur during device-related operations.
126///
127/// This covers errors related to hardware capability mismatches, such as
128/// requesting a data type not supported by the device, and configuration
129/// errors like attempting to change a policy in an invalid context.
130#[derive(Debug, Error)]
131pub enum DeviceError {
132    /// Unsupported data type by the device.
133    #[error("Device {device} does not support the requested data type {dtype:?}")]
134    UnsupportedDType {
135        /// The string representation of the device.
136        device: String,
137        /// The data type that caused the error.
138        dtype: DType,
139    },
140    // TODO: `InvalidContext` if a device policy cannot be changed after init / during training / etc.
141}
142
143impl DeviceError {
144    /// Helper to create a [`DeviceError::UnsupportedDType`] from any device.
145    pub fn unsupported_dtype<D: DeviceOps>(device: &D, dtype: DType) -> Self {
146        Self::UnsupportedDType {
147            device: format!("{device:?}"),
148            dtype,
149        }
150    }
151}
152
153fn check_dtype_support<B: Backend>(
154    device: &B::Device,
155    dtype: impl Into<DType>,
156) -> Result<(), DeviceError> {
157    let dtype = dtype.into();
158    // Default dtypes should have `DTypeUsage::general()`. Types restricted to specialized
159    // operations should not be used as default.
160    if B::supports_dtype(device, dtype) {
161        Ok(())
162    } else {
163        Err(DeviceError::unsupported_dtype(device, dtype))
164    }
165}
166
167/// Sets the default data types for the device.
168///
169/// This updates the device's default data types used for tensor creation.
170/// The policy should typically be set once during initialization and then
171/// remains global for all subsequent operations on that device.
172///
173/// # Example
174///
175/// ```rust
176/// use burn_tensor::backend::Backend;
177/// use burn_tensor::{DType, Int, Tensor, set_default_dtypes};
178///
179/// fn example<B: Backend>() {
180///     let device = B::Device::default();
181///     
182///     // Update the device policy
183///     set_default_dtypes::<B>(&device, DType::F16, DType::I32);
184///     
185///     // All float tensors created after this will use F16 by default
186///     let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
187///     // All int tensors created after this will use I32 default
188///     let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);
189/// }
190/// ```
191pub fn set_default_dtypes<B: Backend>(
192    device: &B::Device,
193    float_dtype: impl Into<FloatDType>,
194    int_dtype: impl Into<IntDType>,
195) -> Result<(), DeviceError> {
196    let float_dtype = float_dtype.into();
197    let int_dtype = int_dtype.into();
198    check_dtype_support::<B>(device, float_dtype)?;
199    check_dtype_support::<B>(device, int_dtype)?;
200
201    set_default_dtypes_unchecked(device, float_dtype, int_dtype);
202    Ok(())
203}
204
205/// Sets the default floating-point data type for the device.
206///
207/// This updates the device's default data types used for tensor creation.
208/// The policy should typically be set once during initialization and then
209/// remains global for all subsequent operations on that device.
210///
211/// # Example
212///
213/// ```rust
214/// use burn_tensor::backend::Backend;
215/// use burn_tensor::{DType, Tensor, set_default_float_dtype};
216///
217/// fn example<B: Backend>() {
218///     let device = B::Device::default();
219///     
220///     // Update the device policy
221///     set_default_float_dtype::<B>(&device, DType::F16);
222///     
223///     // All float tensors created after this will use F16 by default
224///     let tensor = Tensor::<B, 2>::zeros([2, 3], &device);
225/// }
226/// ```
227pub fn set_default_float_dtype<B: Backend>(
228    device: &B::Device,
229    dtype: impl Into<FloatDType>,
230) -> Result<(), DeviceError> {
231    let dtype = dtype.into();
232    check_dtype_support::<B>(device, dtype)?;
233
234    set_default_float_dtype_unchecked(device, dtype);
235    Ok(())
236}
237
238/// Sets the default integer data type for the device.
239///
240/// This updates the device's default data types used for tensor creation.
241/// The policy should typically be set once during initialization and then
242/// remains global for all subsequent operations on that device.
243///
244/// # Example
245///
246/// ```rust
247/// use burn_tensor::backend::Backend;
248/// use burn_tensor::{DType, Int, Tensor, set_default_int_dtype};
249///
250/// fn example<B: Backend>() {
251///     let device = B::Device::default();
252///     
253///     // Update the device policy
254///     set_default_int_dtype::<B>(&device, DType::I32);
255///     
256///     // All int tensors created after this will use I32 default
257///     let tensor = Tensor::<B, 2, Int>::zeros([2, 3], &device);
258/// }
259/// ```
260pub fn set_default_int_dtype<B: Backend>(
261    device: &B::Device,
262    dtype: impl Into<IntDType>,
263) -> Result<(), DeviceError> {
264    let dtype = dtype.into();
265    check_dtype_support::<B>(device, dtype)?;
266
267    set_default_int_dtype_unchecked(device, dtype);
268    Ok(())
269}
270
271// Unchecked versions
272fn set_default_dtypes_unchecked<D: DeviceOps>(
273    device: &D,
274    float_dtype: FloatDType,
275    int_dtype: IntDType,
276) {
277    DevicePolicyRegistry::update(device, |p| {
278        p.set_float_dtype(float_dtype);
279        p.set_int_dtype(int_dtype);
280    });
281}
282
283fn set_default_float_dtype_unchecked<D: DeviceOps>(device: &D, dtype: FloatDType) {
284    DevicePolicyRegistry::update(device, |p| {
285        p.set_float_dtype(dtype);
286    });
287}
288
289fn set_default_int_dtype_unchecked<D: DeviceOps>(device: &D, dtype: IntDType) {
290    DevicePolicyRegistry::update(device, |p| {
291        p.set_int_dtype(dtype);
292    });
293}
294
#[cfg(all(test, feature = "std"))]
mod tests {
    use serial_test::serial;

    use super::*;

    /// Empties the global registry so each test starts without any stored policy.
    fn clear_registry() {
        let mut policies = REGISTRY.write().unwrap();
        policies.clear();
    }

    /// First stand-in device type. It reports `type_id` 0, the same as [`TestDeviceB`],
    /// so only the Rust type tells the two apart in the registry key.
    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceA {
        index: u32,
    }

    impl Device for TestDeviceA {
        fn from_id(device_id: DeviceId) -> Self {
            Self::new(device_id.index_id)
        }

        fn to_id(&self) -> DeviceId {
            let index_id = self.index;
            DeviceId {
                type_id: 0,
                index_id,
            }
        }

        fn device_count(_type_id: u16) -> usize {
            1
        }
    }

    impl DeviceOps for TestDeviceA {}

    /// Second stand-in device type, identical to [`TestDeviceA`] apart from its Rust type.
    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceB {
        index: u32,
    }

    impl Device for TestDeviceB {
        fn from_id(device_id: DeviceId) -> Self {
            Self::new(device_id.index_id)
        }

        fn to_id(&self) -> DeviceId {
            let index_id = self.index;
            DeviceId {
                type_id: 0,
                index_id,
            }
        }

        fn device_count(_type_id: u16) -> usize {
            1
        }
    }

    impl DeviceOps for TestDeviceB {}

    #[test]
    #[serial]
    fn default_policy_is_created_and_shared() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);

        let first = get_device_policy(&device);
        let second = get_device_policy(&device);

        // Both lookups return the same allocation.
        assert!(Arc::ptr_eq(&first, &second));
        // Not explicitly set
        for policy in [&first, &second] {
            assert_eq!(policy.float_dtype(), None);
            assert_eq!(policy.int_dtype(), None);
        }
    }

    #[test]
    #[serial]
    fn updated_policy_is_shared() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);

        // The device policy is meant to be set once at initialization
        set_default_dtypes_unchecked(&device, FloatDType::BF16, IntDType::I32);
        let first = get_device_policy(&device);
        let second = get_device_policy(&device);

        assert!(Arc::ptr_eq(&first, &second));
        for policy in [&first, &second] {
            assert_eq!(policy.float_dtype(), Some(FloatDType::BF16));
            assert_eq!(policy.int_dtype(), Some(IntDType::I32));
        }
    }

    #[test]
    #[serial]
    fn policy_is_device_id_specific() {
        clear_registry(); // reset registry for each test

        let configured = TestDeviceA::new(0);
        let untouched = TestDeviceA::new(1);

        set_default_float_dtype_unchecked(&configured, FloatDType::F16);

        let configured_policy = get_device_policy(&configured);
        let untouched_policy = get_device_policy(&untouched);

        assert!(!Arc::ptr_eq(&configured_policy, &untouched_policy));
        assert_eq!(configured_policy.float_dtype(), Some(FloatDType::F16));
        assert_eq!(configured_policy.int_dtype(), None);
        assert_eq!(untouched_policy.float_dtype(), None);
        assert_eq!(untouched_policy.int_dtype(), None);
    }

    #[test]
    #[serial]
    fn policy_is_device_type_specific() {
        clear_registry(); // reset registry for each test

        let device_a = TestDeviceA::new(0);
        let device_b = TestDeviceB::new(0);

        set_default_float_dtype_unchecked(&device_b, FloatDType::F16);

        let policy_a = get_device_policy(&device_a);
        let policy_b = get_device_policy(&device_b);

        assert_eq!(policy_a.float_dtype(), None);
        assert_eq!(policy_a.int_dtype(), None);
        assert_eq!(policy_b.float_dtype(), Some(FloatDType::F16));
        assert_eq!(policy_b.int_dtype(), None);
    }

    #[test]
    #[serial]
    fn updating_policy_should_not_affect_snapshot() {
        clear_registry(); // reset registry for each test

        // The device policy is meant to be set once at initialization
        let device = TestDeviceA::new(0);
        let snapshot = get_device_policy(&device);

        set_default_float_dtype_unchecked(&device, FloatDType::BF16);

        let refreshed = get_device_policy(&device);

        // The update produced a new allocation and left the earlier snapshot untouched.
        assert!(!Arc::ptr_eq(&snapshot, &refreshed));
        assert_eq!(refreshed.float_dtype(), Some(FloatDType::BF16));
        assert_eq!(snapshot.float_dtype(), None);
    }

    #[test]
    #[serial]
    fn set_default_dtypes_overwrites_fields() {
        clear_registry(); // reset registry for each test

        let device = TestDeviceA::new(0);

        set_default_dtypes_unchecked(&device, FloatDType::F16, IntDType::I64);

        let stored = get_device_policy(&device);

        assert_eq!(stored.float_dtype(), Some(FloatDType::F16));
        assert_eq!(stored.int_dtype(), Some(IntDType::I64));
    }
}
465}