burn_tensor/tensor/api/options.rs

1use burn_backend::{Backend, Element, tensor::Device};
2use burn_std::DType;
3
4/// Options for tensor creation.
5///
6/// This struct allows specifying the `device` and/or data type (`dtype`) when creating a tensor.
7#[derive(Debug, Clone)]
8pub struct TensorCreationOptions<B: Backend> {
9    /// Device where the tensor will be created.
10    pub device: Device<B>,
11    /// Optional data type.
12    /// If `None`, the dtype will be inferred on creation from the backend's default dtype for the tensor kind.
13    pub dtype: Option<DType>,
14}
15
16impl<B: Backend> Default for TensorCreationOptions<B> {
17    /// Returns new options with the backend's default device.
18    fn default() -> Self {
19        Self::new(Default::default())
20    }
21}
22
23impl<B: Backend> TensorCreationOptions<B> {
24    /// Create new options with a specific device.
25    ///
26    /// Data type will be inferred on creation from the backend's default dtype for the tensor kind.
27    pub fn new(device: Device<B>) -> Self {
28        Self {
29            device,
30            dtype: None,
31        }
32    }
33
34    /// Set the tensor creation data type.
35    pub fn with_dtype(mut self, dtype: DType) -> Self {
36        self.dtype = Some(dtype);
37
38        self
39    }
40
41    /// Set the tensor creation device.
42    pub fn with_device(mut self, device: Device<B>) -> Self {
43        self.device = device;
44
45        self
46    }
47
48    /// Create options with backend's default device and float dtype.
49    pub fn float() -> Self {
50        Self::default().with_dtype(<B::FloatElem as Element>::dtype())
51    }
52
53    /// Create options with backend's default device and int dtype.
54    pub fn int() -> Self {
55        Self::default().with_dtype(<B::IntElem as Element>::dtype())
56    }
57
58    /// Create options with backend's default device and bool dtype.
59    pub fn bool() -> Self {
60        Self::default().with_dtype(<B::BoolElem as Element>::dtype())
61    }
62
63    /// Returns the tensor data type, or a provided default if not set.
64    ///
65    /// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`.
66    pub fn dtype_or(&self, dtype: DType) -> DType {
67        self.dtype.unwrap_or(dtype)
68    }
69}
70
71impl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> {
72    /// Convenience conversion from a reference to a device.
73    ///
74    /// Example:
75    /// ```rust
76    /// use burn_tensor::backend::Backend;
77    /// use burn_tensor::TensorCreationOptions;
78    ///
79    /// fn example<B: Backend>(device: B::Device) {
80    ///     let options: TensorCreationOptions<B> = (&device).into();
81    /// }
82    /// ```
83    fn from(device: &Device<B>) -> Self {
84        TensorCreationOptions::new(device.clone())
85    }
86}
87
88impl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> {
89    /// Convenience conversion for a specified `(&device, dtype)` tuple.
90    fn from(args: (&Device<B>, DType)) -> Self {
91        TensorCreationOptions::new(args.0.clone()).with_dtype(args.1)
92    }
93}