Skip to main content

burn_tensor/tensor/api/
options.rs

1use burn_backend::{Backend, Element, tensor::Device};
2use burn_std::DType;
3
4use crate::get_device_policy;
5
6/// Options for tensor creation.
7///
8/// This struct allows specifying the `device` and overriding the data type when creating a tensor.
9/// When the `dtype` is not specified, the [device's default policy](crate::set_default_dtypes) is used.
10#[derive(Debug, Clone)]
11pub struct TensorCreationOptions<B: Backend> {
12    /// Device where the tensor will be created.
13    pub device: Device<B>,
14    /// Optional data type.
15    /// If `None`, the dtype will be inferred on creation from the [device policy](crate::set_default_dtypes).
16    pub dtype: Option<DType>,
17}
18
19impl<B: Backend> Default for TensorCreationOptions<B> {
20    /// Returns new options with the backend's default device.
21    fn default() -> Self {
22        Self::new(Default::default())
23    }
24}
25
26impl<B: Backend> TensorCreationOptions<B> {
27    /// Create new options with a specific device.
28    ///
29    /// Data type will follow the [device policy](crate::set_default_dtypes) on tensor creation.
30    pub fn new(device: Device<B>) -> Self {
31        Self {
32            device,
33            dtype: None,
34        }
35    }
36
37    /// Set the tensor creation data type.
38    pub fn with_dtype(mut self, dtype: DType) -> Self {
39        self.dtype = Some(dtype);
40
41        self
42    }
43
44    /// Set the tensor creation device.
45    pub fn with_device(mut self, device: Device<B>) -> Self {
46        self.device = device;
47
48        self
49    }
50
51    /// Create options with backend's default device and float dtype.
52    pub fn float() -> Self {
53        Self::default().with_dtype(<B::FloatElem as Element>::dtype())
54    }
55
56    /// Create options with backend's default device and int dtype.
57    pub fn int() -> Self {
58        Self::default().with_dtype(<B::IntElem as Element>::dtype())
59    }
60
61    /// Create options with backend's default device and bool dtype.
62    pub fn bool() -> Self {
63        Self::default().with_dtype(<B::BoolElem as Element>::dtype())
64    }
65
66    /// Returns the tensor data type, or a provided default if not set.
67    ///
68    /// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`.
69    pub fn dtype_or(&self, dtype: DType) -> DType {
70        self.dtype.unwrap_or(dtype)
71    }
72
73    /// Returns the tensor data type, or the default from the [device policy](crate::set_default_dtypes).
74    pub(crate) fn resolve_policy(&self, dtype: DType) -> DType {
75        // TODO: should rely on tensor kind, not element dtype
76        self.dtype.unwrap_or_else(|| {
77            let policy = get_device_policy(&self.device);
78            if dtype.is_float()
79                && let Some(float_dtype) = policy.float_dtype()
80            {
81                float_dtype.into()
82            } else if (dtype.is_int() || dtype.is_uint())
83                && let Some(int_dtype) = policy.int_dtype()
84            {
85                int_dtype.into()
86            } else {
87                // If policy was not explicitly set, use the fallback dtype (default backend elem type)
88                dtype
89            }
90        })
91    }
92}
93
94impl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> {
95    /// Convenience conversion from a reference to a device.
96    ///
97    /// Example:
98    /// ```rust
99    /// use burn_tensor::backend::Backend;
100    /// use burn_tensor::TensorCreationOptions;
101    ///
102    /// fn example<B: Backend>(device: B::Device) {
103    ///     let options: TensorCreationOptions<B> = (&device).into();
104    /// }
105    /// ```
106    fn from(device: &Device<B>) -> Self {
107        TensorCreationOptions::new(device.clone())
108    }
109}
110
111impl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> {
112    /// Convenience conversion for a specified `(&device, dtype)` tuple.
113    fn from(args: (&Device<B>, DType)) -> Self {
114        TensorCreationOptions::new(args.0.clone()).with_dtype(args.1)
115    }
116}