burn_tensor/tensor/api/options.rs
1use burn_backend::{
2 Backend, Element, get_device_settings,
3 tensor::{BasicOps, Device},
4};
5use burn_std::DType;
6
7/// Options for tensor creation.
8///
9/// This struct allows specifying the `device` and overriding the data type when creating a tensor.
10/// When the `dtype` is not specified, the [device's default policy](crate::set_default_dtypes) is used.
11#[derive(Debug, Clone)]
12pub struct TensorCreationOptions<B: Backend> {
13 /// Device where the tensor will be created.
14 pub device: Device<B>,
15 /// Optional data type.
16 /// If `None`, the dtype will be inferred on creation from the [device policy](crate::set_default_dtypes).
17 pub dtype: Option<DType>,
18}
19
20impl<B: Backend> Default for TensorCreationOptions<B> {
21 /// Returns new options with the backend's default device.
22 fn default() -> Self {
23 Self::new(Default::default())
24 }
25}
26
27impl<B: Backend> TensorCreationOptions<B> {
28 /// Create new options with a specific device.
29 ///
30 /// Data type will follow the [device policy](crate::set_default_dtypes) on tensor creation.
31 pub fn new(device: Device<B>) -> Self {
32 Self {
33 device,
34 dtype: None,
35 }
36 }
37
38 /// Set the tensor creation data type.
39 pub fn with_dtype(mut self, dtype: DType) -> Self {
40 self.dtype = Some(dtype);
41
42 self
43 }
44
45 /// Set the tensor creation device.
46 pub fn with_device(mut self, device: Device<B>) -> Self {
47 self.device = device;
48
49 self
50 }
51
52 /// Create options with backend's default device and float dtype.
53 pub fn float() -> Self {
54 Self::default().with_dtype(<B::FloatElem as Element>::dtype())
55 }
56
57 /// Create options with backend's default device and int dtype.
58 pub fn int() -> Self {
59 Self::default().with_dtype(<B::IntElem as Element>::dtype())
60 }
61
62 /// Create options with backend's default device and bool dtype.
63 pub fn bool() -> Self {
64 Self::default().with_dtype(<B::BoolElem as Element>::dtype())
65 }
66
67 /// Returns the tensor data type, or a provided default if not set.
68 ///
69 /// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`.
70 pub fn dtype_or(&self, dtype: DType) -> DType {
71 self.dtype.unwrap_or(dtype)
72 }
73
74 /// Returns the tensor data type, or the default from the [device settings](crate::set_default_dtypes).
75 pub(crate) fn resolve_dtype<K: BasicOps<B>>(&self) -> DType {
76 let dtype = K::Elem::dtype();
77 let kind_name = K::name();
78 // TODO: tensor kind enum?
79 self.dtype.unwrap_or_else(|| {
80 let settings = get_device_settings::<B>(&self.device);
81 if dtype.is_float() && kind_name == "Float" {
82 settings.float_dtype.into()
83 } else if (dtype.is_int() || dtype.is_uint()) && kind_name == "Int" {
84 settings.int_dtype.into()
85 } else {
86 settings.bool_dtype.into()
87 }
88 })
89 }
90}
91
92impl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> {
93 /// Convenience conversion from a reference to a device.
94 ///
95 /// Example:
96 /// ```rust
97 /// use burn_tensor::backend::Backend;
98 /// use burn_tensor::TensorCreationOptions;
99 ///
100 /// fn example<B: Backend>(device: B::Device) {
101 /// let options: TensorCreationOptions<B> = (&device).into();
102 /// }
103 /// ```
104 fn from(device: &Device<B>) -> Self {
105 TensorCreationOptions::new(device.clone())
106 }
107}
108
109impl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> {
110 /// Convenience conversion for a specified `(&device, dtype)` tuple.
111 fn from(args: (&Device<B>, DType)) -> Self {
112 TensorCreationOptions::new(args.0.clone()).with_dtype(args.1)
113 }
114}