burn_tensor/tensor/api/options.rs
1use burn_backend::{Backend, Element, tensor::Device};
2use burn_std::DType;
3
4/// Options for tensor creation.
5///
6/// This struct allows specifying the `device` and/or data type (`dtype`) when creating a tensor.
7#[derive(Debug, Clone)]
8pub struct TensorCreationOptions<B: Backend> {
9 /// Device where the tensor will be created.
10 pub device: Device<B>,
11 /// Optional data type.
12 /// If `None`, the dtype will be inferred on creation from the backend's default dtype for the tensor kind.
13 pub dtype: Option<DType>,
14}
15
16impl<B: Backend> Default for TensorCreationOptions<B> {
17 /// Returns new options with the backend's default device.
18 fn default() -> Self {
19 Self::new(Default::default())
20 }
21}
22
23impl<B: Backend> TensorCreationOptions<B> {
24 /// Create new options with a specific device.
25 ///
26 /// Data type will be inferred on creation from the backend's default dtype for the tensor kind.
27 pub fn new(device: Device<B>) -> Self {
28 Self {
29 device,
30 dtype: None,
31 }
32 }
33
34 /// Set the tensor creation data type.
35 pub fn with_dtype(mut self, dtype: DType) -> Self {
36 self.dtype = Some(dtype);
37
38 self
39 }
40
41 /// Set the tensor creation device.
42 pub fn with_device(mut self, device: Device<B>) -> Self {
43 self.device = device;
44
45 self
46 }
47
48 /// Create options with backend's default device and float dtype.
49 pub fn float() -> Self {
50 Self::default().with_dtype(<B::FloatElem as Element>::dtype())
51 }
52
53 /// Create options with backend's default device and int dtype.
54 pub fn int() -> Self {
55 Self::default().with_dtype(<B::IntElem as Element>::dtype())
56 }
57
58 /// Create options with backend's default device and bool dtype.
59 pub fn bool() -> Self {
60 Self::default().with_dtype(<B::BoolElem as Element>::dtype())
61 }
62
63 /// Returns the tensor data type, or a provided default if not set.
64 ///
65 /// This is useful for cases where [`TensorCreationOptions`] may not have an explicit `dtype`.
66 pub fn dtype_or(&self, dtype: DType) -> DType {
67 self.dtype.unwrap_or(dtype)
68 }
69}
70
71impl<B: Backend> From<&Device<B>> for TensorCreationOptions<B> {
72 /// Convenience conversion from a reference to a device.
73 ///
74 /// Example:
75 /// ```rust
76 /// use burn_tensor::backend::Backend;
77 /// use burn_tensor::TensorCreationOptions;
78 ///
79 /// fn example<B: Backend>(device: B::Device) {
80 /// let options: TensorCreationOptions<B> = (&device).into();
81 /// }
82 /// ```
83 fn from(device: &Device<B>) -> Self {
84 TensorCreationOptions::new(device.clone())
85 }
86}
87
88impl<B: Backend> From<(&Device<B>, DType)> for TensorCreationOptions<B> {
89 /// Convenience conversion for a specified `(&device, dtype)` tuple.
90 fn from(args: (&Device<B>, DType)) -> Self {
91 TensorCreationOptions::new(args.0.clone()).with_dtype(args.1)
92 }
93}