1use alloc::format;
2use alloc::string::String;
3use burn_backend::{Backend, Device, DeviceId, DeviceOps};
4use burn_std::stub::RwLock;
5use burn_std::{DType, FloatDType, IntDType};
6
7#[cfg(target_has_atomic = "ptr")]
8use alloc::sync::Arc;
9
10#[cfg(not(target_has_atomic = "ptr"))]
11use portable_atomic_util::Arc;
12use thiserror::Error;
13
14use core::any::TypeId;
15
16#[cfg(feature = "std")]
17pub use std::collections::HashMap;
18#[cfg(feature = "std")]
19use std::sync::LazyLock;
20
21#[cfg(not(feature = "std"))]
22pub use hashbrown::HashMap;
23#[cfg(not(feature = "std"))]
24use spin::Lazy as LazyLock;
25
26#[derive(Debug, Clone, Copy, Default)]
30pub(crate) struct DevicePolicy {
31 float_dtype: Option<FloatDType>,
33 int_dtype: Option<IntDType>,
35}
36
37impl DevicePolicy {
38 pub(crate) fn float_dtype(&self) -> Option<FloatDType> {
40 self.float_dtype
41 }
42
43 pub(crate) fn int_dtype(&self) -> Option<IntDType> {
45 self.int_dtype
46 }
47
48 pub(crate) fn set_float_dtype(&mut self, dtype: FloatDType) {
50 self.float_dtype = Some(dtype);
51 }
52
53 pub(crate) fn set_int_dtype(&mut self, dtype: IntDType) {
55 self.int_dtype = Some(dtype);
56 }
57}
58
/// Registry key: a device's [`DeviceId`] paired with the [`TypeId`] of its concrete device type.
///
/// The `TypeId` component keeps policies of different device types apart even when they
/// produce equal `DeviceId`s.
type RegistryKey = (DeviceId, TypeId);

/// Global map from [`RegistryKey`] to the current policy snapshot for that device.
///
/// Built lazily (empty) on first access and guarded by a read/write lock.
static REGISTRY: LazyLock<RwLock<HashMap<RegistryKey, Arc<DevicePolicy>>>> =
    LazyLock::new(|| RwLock::new(HashMap::new()));
65
66struct DevicePolicyRegistry;
78
79impl DevicePolicyRegistry {
80 fn get<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {
84 let key = Self::key(device);
85
86 if let Some(policy) = REGISTRY.read().unwrap().get(&key) {
87 return Arc::clone(policy);
88 }
89
90 let mut map = REGISTRY.write().unwrap();
91 Arc::clone(
92 map.entry(key)
93 .or_insert_with(|| Arc::new(DevicePolicy::default())),
94 )
95 }
96
97 fn update<D: DeviceOps>(device: &D, update_fn: impl FnOnce(&mut DevicePolicy)) {
99 let key = Self::key(device);
100 let mut map = REGISTRY.write().unwrap();
101
102 let policy = map
103 .entry(key)
104 .or_insert_with(|| Arc::new(DevicePolicy::default()));
105
106 let policy_mut = Arc::make_mut(policy);
108 update_fn(policy_mut);
109 }
110
111 fn key<D: Device>(device: &D) -> RegistryKey {
113 (device.to_id(), TypeId::of::<D>())
114 }
115}
116
117pub(crate) fn get_device_policy<D: DeviceOps>(device: &D) -> Arc<DevicePolicy> {
122 DevicePolicyRegistry::get(device)
123}
124
125#[derive(Debug, Error)]
131pub enum DeviceError {
132 #[error("Device {device} does not support the requested data type {dtype:?}")]
134 UnsupportedDType {
135 device: String,
137 dtype: DType,
139 },
140 }
142
143impl DeviceError {
144 pub fn unsupported_dtype<D: DeviceOps>(device: &D, dtype: DType) -> Self {
146 Self::UnsupportedDType {
147 device: format!("{device:?}"),
148 dtype,
149 }
150 }
151}
152
153fn check_dtype_support<B: Backend>(
154 device: &B::Device,
155 dtype: impl Into<DType>,
156) -> Result<(), DeviceError> {
157 let dtype = dtype.into();
158 if B::supports_dtype(device, dtype) {
161 Ok(())
162 } else {
163 Err(DeviceError::unsupported_dtype(device, dtype))
164 }
165}
166
167pub fn set_default_dtypes<B: Backend>(
192 device: &B::Device,
193 float_dtype: impl Into<FloatDType>,
194 int_dtype: impl Into<IntDType>,
195) -> Result<(), DeviceError> {
196 let float_dtype = float_dtype.into();
197 let int_dtype = int_dtype.into();
198 check_dtype_support::<B>(device, float_dtype)?;
199 check_dtype_support::<B>(device, int_dtype)?;
200
201 set_default_dtypes_unchecked(device, float_dtype, int_dtype);
202 Ok(())
203}
204
205pub fn set_default_float_dtype<B: Backend>(
228 device: &B::Device,
229 dtype: impl Into<FloatDType>,
230) -> Result<(), DeviceError> {
231 let dtype = dtype.into();
232 check_dtype_support::<B>(device, dtype)?;
233
234 set_default_float_dtype_unchecked(device, dtype);
235 Ok(())
236}
237
238pub fn set_default_int_dtype<B: Backend>(
261 device: &B::Device,
262 dtype: impl Into<IntDType>,
263) -> Result<(), DeviceError> {
264 let dtype = dtype.into();
265 check_dtype_support::<B>(device, dtype)?;
266
267 set_default_int_dtype_unchecked(device, dtype);
268 Ok(())
269}
270
271fn set_default_dtypes_unchecked<D: DeviceOps>(
273 device: &D,
274 float_dtype: FloatDType,
275 int_dtype: IntDType,
276) {
277 DevicePolicyRegistry::update(device, |p| {
278 p.set_float_dtype(float_dtype);
279 p.set_int_dtype(int_dtype);
280 });
281}
282
283fn set_default_float_dtype_unchecked<D: DeviceOps>(device: &D, dtype: FloatDType) {
284 DevicePolicyRegistry::update(device, |p| {
285 p.set_float_dtype(dtype);
286 });
287}
288
289fn set_default_int_dtype_unchecked<D: DeviceOps>(device: &D, dtype: IntDType) {
290 DevicePolicyRegistry::update(device, |p| {
291 p.set_int_dtype(dtype);
292 });
293}
294
#[cfg(all(test, feature = "std"))]
mod tests {
    use serial_test::serial;

    use super::*;

    /// Empties the global registry so each test starts from a clean slate.
    ///
    /// Every test touching the registry is `#[serial]` because `REGISTRY` is process-global.
    fn clear_registry() {
        REGISTRY.write().unwrap().clear();
    }

    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceA {
        index: u32,
    }

    impl Device for TestDeviceA {
        fn from_id(device_id: DeviceId) -> Self {
            Self {
                index: device_id.index_id,
            }
        }

        fn to_id(&self) -> DeviceId {
            DeviceId {
                type_id: 0,
                index_id: self.index,
            }
        }

        fn device_count(_type_id: u16) -> usize {
            1
        }
    }

    impl DeviceOps for TestDeviceA {}

    /// Second device type that deliberately produces the same `DeviceId`s as
    /// `TestDeviceA`, so tests can check that policies are also keyed by the device type.
    #[derive(Clone, Debug, Default, PartialEq, new)]
    pub struct TestDeviceB {
        index: u32,
    }

    impl Device for TestDeviceB {
        fn from_id(device_id: DeviceId) -> Self {
            Self {
                index: device_id.index_id,
            }
        }

        fn to_id(&self) -> DeviceId {
            DeviceId {
                type_id: 0,
                index_id: self.index,
            }
        }

        fn device_count(_type_id: u16) -> usize {
            1
        }
    }

    impl DeviceOps for TestDeviceB {}

    #[test]
    #[serial]
    fn default_policy_is_created_and_shared() {
        clear_registry();
        let device = TestDeviceA::new(0);

        let p1 = get_device_policy(&device);
        let p2 = get_device_policy(&device);

        assert!(Arc::ptr_eq(&p1, &p2));
        assert!(p1.float_dtype().is_none());
        assert!(p1.int_dtype().is_none());
        assert!(p2.float_dtype().is_none());
        assert!(p2.int_dtype().is_none());
    }

    #[test]
    #[serial]
    fn updated_policy_is_shared() {
        clear_registry();
        let device = TestDeviceA::new(0);

        set_default_dtypes_unchecked(&device, FloatDType::BF16, IntDType::I32);

        let p1 = get_device_policy(&device);
        let p2 = get_device_policy(&device);

        assert!(Arc::ptr_eq(&p1, &p2));
        assert_eq!(p1.float_dtype(), Some(FloatDType::BF16));
        assert_eq!(p1.int_dtype(), Some(IntDType::I32));
        assert_eq!(p2.float_dtype(), Some(FloatDType::BF16));
        assert_eq!(p2.int_dtype(), Some(IntDType::I32));
    }

    #[test]
    #[serial]
    fn policy_is_device_id_specific() {
        clear_registry();
        let d1 = TestDeviceA::new(0);
        let d2 = TestDeviceA::new(1);

        set_default_float_dtype_unchecked(&d1, FloatDType::F16);

        let p1 = get_device_policy(&d1);
        let p2 = get_device_policy(&d2);

        assert!(!Arc::ptr_eq(&p1, &p2));
        assert_eq!(p1.float_dtype(), Some(FloatDType::F16));
        assert!(p1.int_dtype().is_none());
        assert!(p2.float_dtype().is_none());
        assert!(p2.int_dtype().is_none());
    }

    #[test]
    #[serial]
    fn policy_is_device_type_specific() {
        clear_registry();
        // Same `DeviceId` (type_id 0, index 0) but different concrete device types.
        let d1 = TestDeviceA::new(0);
        let d2 = TestDeviceB::new(0);

        set_default_float_dtype_unchecked(&d2, FloatDType::F16);

        let p1 = get_device_policy(&d1);
        let p2 = get_device_policy(&d2);

        assert!(!Arc::ptr_eq(&p1, &p2));
        assert!(p1.float_dtype().is_none());
        assert!(p1.int_dtype().is_none());
        assert_eq!(p2.float_dtype(), Some(FloatDType::F16));
        assert!(p2.int_dtype().is_none());
    }

    #[test]
    #[serial]
    fn updating_policy_should_not_affect_snapshot() {
        clear_registry();
        let device = TestDeviceA::new(0);
        let before = get_device_policy(&device);

        set_default_float_dtype_unchecked(&device, FloatDType::BF16);

        let after = get_device_policy(&device);

        assert!(!Arc::ptr_eq(&before, &after));
        assert_eq!(after.float_dtype(), Some(FloatDType::BF16));
        assert!(before.float_dtype().is_none());
    }

    #[test]
    #[serial]
    fn set_default_dtypes_overwrites_fields() {
        clear_registry();
        let device = TestDeviceA::new(0);

        set_default_dtypes_unchecked(&device, FloatDType::F16, IntDType::I64);

        let policy = get_device_policy(&device);

        assert_eq!(policy.float_dtype(), Some(FloatDType::F16));
        assert_eq!(policy.int_dtype(), Some(IntDType::I64));
    }

    #[test]
    #[serial]
    fn individual_setters_preserve_other_field() {
        clear_registry();
        let device = TestDeviceA::new(0);

        set_default_float_dtype_unchecked(&device, FloatDType::F16);
        set_default_int_dtype_unchecked(&device, IntDType::I32);

        let policy = get_device_policy(&device);

        assert_eq!(policy.float_dtype(), Some(FloatDType::F16));
        assert_eq!(policy.int_dtype(), Some(IntDType::I32));
    }

    #[test]
    fn unsupported_dtype_error_reports_device_and_dtype() {
        let err = DeviceError::unsupported_dtype(&TestDeviceA::new(1), DType::F32);

        assert!(matches!(
            &err,
            DeviceError::UnsupportedDType { device, dtype: DType::F32 }
                if device == "TestDeviceA { index: 1 }"
        ));
        assert_eq!(
            format!("{err}"),
            "Device TestDeviceA { index: 1 } does not support the requested data type F32"
        );
    }
}