Skip to main content

burn_tensor/tensor/api/
options.rs

1use burn_backend::{
2    Backend, Element, get_device_settings,
3    tensor::{BasicOps, Device},
4};
5use burn_std::DType;
6
7/// Options for tensor creation.
8///
9/// This struct allows specifying the `device` and overriding the data type when creating a tensor.
10/// When the `dtype` is not specified, the [device's default policy](crate::set_default_dtypes) is used.
11#[derive(Debug, Clone)]
12pub struct TensorCreationOptions<B: Backend> {
13    /// Device where the tensor will be created.
14    pub device: Device<B>,
15    /// Optional data type.
16    /// If `None`, the dtype will be inferred on creation from the [device policy](crate::set_default_dtypes).
17    pub dtype: Option<DType>,
18}
19
20impl<B: Backend> Default for TensorCreationOptions<B> {
21    /// Returns new options with the backend's default device.
22    fn default() -> Self {
23        Self::new(Default::default())
24    }
25}
26
27impl<B: Backend> TensorCreationOptions<B> {
28    /// Create new options with a specific device.
29    ///
30    /// Data type will follow the [device policy](crate::set_default_dtypes) on tensor creation.
31    pub fn new(device: Device<B>) -> Self {
32        Self {
33            device,
34            dtype: None,
35        }
36    }
37
38    /// Set the tensor creation data type.
39    pub fn with_dtype(mut self, dtype: DType) -> Self {
40        self.dtype = Some(dtype);
41
42        self
43    }
44
45    /// Set the tensor creation device.
46    pub fn with_device(mut self, device: Device<B>) -> Self {
47        self.device = device;
48
49        self
50    }
51
52    /// Create options with backend's default device and float dtype.
53    pub fn float() -> Self {
54        Self::default().with_dtype(<B::FloatElem as Element>::dtype())
55    }
56
57    /// Create options with backend's default device and int dtype.
58    pub fn int() -> Self {
59        Self::default().with_dtype(<B::IntElem as Element>::dtype())
60    }
61
62    /// Create options with backend's default device and bool dtype.
63    pub fn bool() -> Self {
64        Self::default().with_dtype(<B::BoolElem as Element>::dtype())
65    }
66
67    /// Returns the tensor data type, or a provided default if not set.
68    ///
69    /// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`.
70    pub fn dtype_or(&self, dtype: DType) -> DType {
71        self.dtype.unwrap_or(dtype)
72    }
73
74    /// Returns the tensor data type, or the default from the [device settings](crate::set_default_dtypes).
75    pub(crate) fn resolve_dtype<K: BasicOps<B>>(&self) -> DType {
76        let dtype = K::Elem::dtype();
77        let kind_name = K::name();
78        // TODO: tensor kind enum?
79        self.dtype.unwrap_or_else(|| {
80            let settings = get_device_settings::<B>(&self.device);
81            if dtype.is_float() && kind_name == "Float" {
82                settings.float_dtype.into()
83            } else if (dtype.is_int() || dtype.is_uint()) && kind_name == "Int" {
84                settings.int_dtype.into()
85            } else {
86                settings.bool_dtype.into()
87            }
88        })
89    }
90}
91
92impl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> {
93    /// Convenience conversion from a reference to a device.
94    ///
95    /// Example:
96    /// ```rust
97    /// use burn_tensor::backend::Backend;
98    /// use burn_tensor::TensorCreationOptions;
99    ///
100    /// fn example<B: Backend>(device: B::Device) {
101    ///     let options: TensorCreationOptions<B> = (&device).into();
102    /// }
103    /// ```
104    fn from(device: &Device<B>) -> Self {
105        TensorCreationOptions::new(device.clone())
106    }
107}
108
109impl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> {
110    /// Convenience conversion for a specified `(&device, dtype)` tuple.
111    fn from(args: (&Device<B>, DType)) -> Self {
112        TensorCreationOptions::new(args.0.clone()).with_dtype(args.1)
113    }
114}