burn_tensor/tensor/api/options.rs
1use burn_backend::{Backend, Element, tensor::Device};
2use burn_std::DType;
3
4use crate::get_device_policy;
5
6/// Options for tensor creation.
7///
8/// This struct allows specifying the `device` and overriding the data type when creating a tensor.
9/// When the `dtype` is not specified, the [device's default policy](crate::set_default_dtypes) is used.
10#[derive(Debug, Clone)]
11pub struct TensorCreationOptions<B: Backend> {
12 /// Device where the tensor will be created.
13 pub device: Device<B>,
14 /// Optional data type.
15 /// If `None`, the dtype will be inferred on creation from the [device policy](crate::set_default_dtypes).
16 pub dtype: Option<DType>,
17}
18
19impl<B: Backend> Default for TensorCreationOptions<B> {
20 /// Returns new options with the backend's default device.
21 fn default() -> Self {
22 Self::new(Default::default())
23 }
24}
25
26impl<B: Backend> TensorCreationOptions<B> {
27 /// Create new options with a specific device.
28 ///
29 /// Data type will follow the [device policy](crate::set_default_dtypes) on tensor creation.
30 pub fn new(device: Device<B>) -> Self {
31 Self {
32 device,
33 dtype: None,
34 }
35 }
36
37 /// Set the tensor creation data type.
38 pub fn with_dtype(mut self, dtype: DType) -> Self {
39 self.dtype = Some(dtype);
40
41 self
42 }
43
44 /// Set the tensor creation device.
45 pub fn with_device(mut self, device: Device<B>) -> Self {
46 self.device = device;
47
48 self
49 }
50
51 /// Create options with backend's default device and float dtype.
52 pub fn float() -> Self {
53 Self::default().with_dtype(<B::FloatElem as Element>::dtype())
54 }
55
56 /// Create options with backend's default device and int dtype.
57 pub fn int() -> Self {
58 Self::default().with_dtype(<B::IntElem as Element>::dtype())
59 }
60
61 /// Create options with backend's default device and bool dtype.
62 pub fn bool() -> Self {
63 Self::default().with_dtype(<B::BoolElem as Element>::dtype())
64 }
65
66 /// Returns the tensor data type, or a provided default if not set.
67 ///
68 /// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`.
69 pub fn dtype_or(&self, dtype: DType) -> DType {
70 self.dtype.unwrap_or(dtype)
71 }
72
73 /// Returns the tensor data type, or the default from the [device policy](crate::set_default_dtypes).
74 pub(crate) fn resolve_policy(&self, dtype: DType) -> DType {
75 // TODO: should rely on tensor kind, not element dtype
76 self.dtype.unwrap_or_else(|| {
77 let policy = get_device_policy(&self.device);
78 if dtype.is_float()
79 && let Some(float_dtype) = policy.float_dtype()
80 {
81 float_dtype.into()
82 } else if (dtype.is_int() || dtype.is_uint())
83 && let Some(int_dtype) = policy.int_dtype()
84 {
85 int_dtype.into()
86 } else {
87 // If policy was not explicitly set, use the fallback dtype (default backend elem type)
88 dtype
89 }
90 })
91 }
92}
93
94impl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> {
95 /// Convenience conversion from a reference to a device.
96 ///
97 /// Example:
98 /// ```rust
99 /// use burn_tensor::backend::Backend;
100 /// use burn_tensor::TensorCreationOptions;
101 ///
102 /// fn example<B: Backend>(device: B::Device) {
103 /// let options: TensorCreationOptions<B> = (&device).into();
104 /// }
105 /// ```
106 fn from(device: &Device<B>) -> Self {
107 TensorCreationOptions::new(device.clone())
108 }
109}
110
111impl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> {
112 /// Convenience conversion for a specified `(&device, dtype)` tuple.
113 fn from(args: (&Device<B>, DType)) -> Self {
114 TensorCreationOptions::new(args.0.clone()).with_dtype(args.1)
115 }
116}