easy_ml/tensors/views/
traits.rs

/*!
 * Trait implementations for [TensorRef] and [TensorMut].
 *
 * These implementations are written here but Rust docs will display them on the
 * traits' pages.
 *
 * An owned or referenced [Tensor] is a TensorRef, and a TensorMut if not a shared
 * reference. Therefore, you can pass a Tensor to any function which takes a TensorRef.
 *
 * Boxed TensorRef and TensorMut values also implement TensorRef and TensorMut respectively.
 */
12
13use crate::tensors::Dimension;
14#[allow(unused_imports)] // used in doc links
15use crate::tensors::Tensor;
16use crate::tensors::views::{DataLayout, TensorMut, TensorRef};
17
18// # Safety
19//
20// The type implementing TensorRef must implement it correctly, so by delegating to it
21// without changing any indexes or introducing interior mutability, we implement TensorRef
22// correctly as well.
23/**
24 * If some type implements TensorRef, then a reference to it implements TensorRef as well
25 */
26unsafe impl<'source, T, S, const D: usize> TensorRef<T, D> for &'source S
27where
28    S: TensorRef<T, D>,
29{
30    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
31        TensorRef::get_reference(*self, indexes)
32    }
33
34    fn view_shape(&self) -> [(Dimension, usize); D] {
35        TensorRef::view_shape(*self)
36    }
37
38    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
39        unsafe { TensorRef::get_reference_unchecked(*self, indexes) }
40    }
41
42    fn data_layout(&self) -> DataLayout<D> {
43        TensorRef::data_layout(*self)
44    }
45}
46
47// # Safety
48//
49// The type implementing TensorRef must implement it correctly, so by delegating to it
50// without changing any indexes or introducing interior mutability, we implement TensorRef
51// correctly as well.
52/**
53 * If some type implements TensorRef, then an exclusive reference to it implements TensorRef
54 * as well
55 */
56unsafe impl<'source, T, S, const D: usize> TensorRef<T, D> for &'source mut S
57where
58    S: TensorRef<T, D>,
59{
60    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
61        TensorRef::get_reference(*self, indexes)
62    }
63
64    fn view_shape(&self) -> [(Dimension, usize); D] {
65        TensorRef::view_shape(*self)
66    }
67
68    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
69        unsafe { TensorRef::get_reference_unchecked(*self, indexes) }
70    }
71
72    fn data_layout(&self) -> DataLayout<D> {
73        TensorRef::data_layout(*self)
74    }
75}
76
77// # Safety
78//
79// The type implementing TensorMut must implement it correctly, so by delegating to it
80// without changing any indexes or introducing interior mutability, we implement TensorMut
81// correctly as well.
82/**
83 * If some type implements TensorMut, then an exclusive reference to it implements TensorMut
84 * as well
85 */
86unsafe impl<'source, T, S, const D: usize> TensorMut<T, D> for &'source mut S
87where
88    S: TensorMut<T, D>,
89{
90    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
91        TensorMut::get_reference_mut(*self, indexes)
92    }
93
94    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
95        unsafe { TensorMut::get_reference_unchecked_mut(*self, indexes) }
96    }
97}
98
99// # Safety
100//
101// The type implementing TensorRef must implement it correctly, so by delegating to it
102// without changing any indexes or introducing interior mutability, we implement TensorRef
103// correctly as well.
104/**
105 * A box of a TensorRef also implements TensorRef.
106 */
107unsafe impl<T, S, const D: usize> TensorRef<T, D> for Box<S>
108where
109    S: TensorRef<T, D>,
110{
111    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
112        self.as_ref().get_reference(indexes)
113    }
114
115    fn view_shape(&self) -> [(Dimension, usize); D] {
116        self.as_ref().view_shape()
117    }
118
119    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
120        unsafe { self.as_ref().get_reference_unchecked(indexes) }
121    }
122
123    fn data_layout(&self) -> DataLayout<D> {
124        self.as_ref().data_layout()
125    }
126}
127
128// # Safety
129//
130// The type implementing TensorMut must implement it correctly, so by delegating to it
131// without changing any indexes or introducing interior mutability, we implement TensorMut
132// correctly as well.
133/**
134 * A box of a TensorMut also implements TensorMut.
135 */
136unsafe impl<T, S, const D: usize> TensorMut<T, D> for Box<S>
137where
138    S: TensorMut<T, D>,
139{
140    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
141        self.as_mut().get_reference_mut(indexes)
142    }
143
144    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
145        unsafe { self.as_mut().get_reference_unchecked_mut(indexes) }
146    }
147}
148
149// # Safety
150//
151// The type implementing TensorRef must implement it correctly, so by delegating to it
152// without changing any indexes or introducing interior mutability, we implement TensorRef
153// correctly as well.
154/**
155 * A box of a dynamic TensorRef also implements TensorRef.
156 */
157unsafe impl<T, const D: usize> TensorRef<T, D> for Box<dyn TensorRef<T, D>> {
158    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
159        self.as_ref().get_reference(indexes)
160    }
161
162    fn view_shape(&self) -> [(Dimension, usize); D] {
163        self.as_ref().view_shape()
164    }
165
166    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
167        unsafe { self.as_ref().get_reference_unchecked(indexes) }
168    }
169
170    fn data_layout(&self) -> DataLayout<D> {
171        self.as_ref().data_layout()
172    }
173}
174
175// # Safety
176//
177// The type implementing TensorMut must implement TensorRef correctly, so by delegating to it
178// without changing any indexes or introducing interior mutability, we implement TensorRef
179// correctly as well.
180/**
181 * A box of a dynamic TensorMut also implements TensorRef.
182 */
183unsafe impl<T, const D: usize> TensorRef<T, D> for Box<dyn TensorMut<T, D>> {
184    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
185        self.as_ref().get_reference(indexes)
186    }
187
188    fn view_shape(&self) -> [(Dimension, usize); D] {
189        self.as_ref().view_shape()
190    }
191
192    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
193        unsafe { self.as_ref().get_reference_unchecked(indexes) }
194    }
195
196    fn data_layout(&self) -> DataLayout<D> {
197        self.as_ref().data_layout()
198    }
199}
200
201// # Safety
202//
203// The type implementing TensorMut must implement it correctly, so by delegating to it
204// without changing any indexes or introducing interior mutability, we implement TensorMut
205// correctly as well.
206/**
207 * A box of a dynamic TensorMut also implements TensorMut.
208 */
209unsafe impl<T, const D: usize> TensorMut<T, D> for Box<dyn TensorMut<T, D>> {
210    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
211        self.as_mut().get_reference_mut(indexes)
212    }
213
214    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
215        unsafe { self.as_mut().get_reference_unchecked_mut(indexes) }
216    }
217}