1#[cfg(feature = "nightly")]
2use alloc::alloc::Allocator;
3#[cfg(not(feature = "std"))]
4use alloc::vec::Vec;
5
6#[cfg(not(feature = "nightly"))]
7use crate::allocator::Allocator;
8use crate::expr::adapters::{Cloned, Copied, Enumerate, Map, Zip};
9use crate::expr::iter::Iter;
10use crate::shape::Shape;
11use crate::tensor::Tensor;
12use crate::traits::IntoCloned;
13
14pub trait Apply<T>: IntoExpression {
16 type Output<F: FnMut(Self::Item) -> T>: IntoExpression<Item = T, Shape = Self::Shape>;
18
19 type ZippedWith<I: IntoExpression, F>: IntoExpression<Item = T>
21 where
22 F: FnMut((Self::Item, I::Item)) -> T;
23
24 fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output<F>;
26
27 fn zip_with<I: IntoExpression, F>(self, expr: I, f: F) -> Self::ZippedWith<I, F>
29 where
30 F: FnMut((Self::Item, I::Item)) -> T;
31}
32
33pub trait Expression: IntoIterator {
35 type Shape: Shape;
37
38 const IS_REPEATABLE: bool;
40
41 fn shape(&self) -> &Self::Shape;
43
44 #[inline]
46 fn cloned<'a, T: 'a + Clone>(self) -> Cloned<Self>
47 where
48 Self: Expression<Item = &'a T> + Sized,
49 {
50 Cloned::new(self)
51 }
52
53 #[inline]
55 fn copied<'a, T: 'a + Copy>(self) -> Copied<Self>
56 where
57 Self: Expression<Item = &'a T> + Sized,
58 {
59 Copied::new(self)
60 }
61
62 #[inline]
68 fn dim(&self, index: usize) -> usize {
69 self.shape().dim(index)
70 }
71
72 #[inline]
74 fn enumerate(self) -> Enumerate<Self>
75 where
76 Self: Sized,
77 {
78 Enumerate::new(self)
79 }
80
81 #[inline]
83 fn eq<I: IntoExpression>(self, other: I) -> bool
84 where
85 Self: Expression<Item: PartialEq<I::Item>> + Sized,
86 {
87 self.eq_by(other, |x, y| x == y)
88 }
89
90 #[inline]
93 fn eq_by<I: IntoExpression, F>(self, other: I, mut eq: F) -> bool
94 where
95 Self: Sized,
96 F: FnMut(Self::Item, I::Item) -> bool,
97 {
98 let other = other.into_expr();
99
100 self.shape().with_dims(|dims| other.shape().with_dims(|other| dims == other))
101 && self.zip(other).into_iter().all(|(x, y)| eq(x, y))
102 }
103
104 #[inline]
110 fn eval(self) -> <Self::Shape as Shape>::Owned<Self::Item>
111 where
112 Self: Sized,
113 {
114 FromExpression::from_expr(self)
115 }
116
117 #[inline]
127 fn eval_into<S: Shape, A: Allocator>(
128 self,
129 tensor: &mut Tensor<Self::Item, S, A>,
130 ) -> &mut Tensor<Self::Item, S, A>
131 where
132 Self: Sized,
133 {
134 tensor.expand(self);
135 tensor
136 }
137
138 #[inline]
140 fn fold<T, F: FnMut(T, Self::Item) -> T>(self, init: T, f: F) -> T
141 where
142 Self: Sized,
143 {
144 Iter::new(self).fold(init, f)
145 }
146
147 #[inline]
149 fn for_each<F: FnMut(Self::Item)>(self, mut f: F)
150 where
151 Self: Sized,
152 {
153 self.fold((), |(), x| f(x));
154 }
155
156 #[inline]
158 fn is_empty(&self) -> bool {
159 self.shape().is_empty()
160 }
161
162 #[inline]
164 fn len(&self) -> usize {
165 self.shape().len()
166 }
167
168 #[inline]
170 fn map<T, F: FnMut(Self::Item) -> T>(self, f: F) -> Map<Self, F>
171 where
172 Self: Sized,
173 {
174 Map::new(self, f)
175 }
176
177 #[inline]
179 fn ne<I: IntoExpression>(self, other: I) -> bool
180 where
181 Self: Expression<Item: PartialEq<I::Item>> + Sized,
182 {
183 !self.eq(other)
184 }
185
186 #[inline]
188 fn rank(&self) -> usize {
189 self.shape().rank()
190 }
191
192 #[inline]
198 fn zip<I: IntoExpression>(self, other: I) -> Zip<Self, I::IntoExpr>
199 where
200 Self: Sized,
201 {
202 Zip::new(self, other.into_expr())
203 }
204
205 #[doc(hidden)]
206 unsafe fn get_unchecked(&mut self, index: usize) -> Self::Item;
207
208 #[doc(hidden)]
209 fn inner_rank(&self) -> usize;
210
211 #[doc(hidden)]
212 unsafe fn reset_dim(&mut self, index: usize, count: usize);
213
214 #[doc(hidden)]
215 unsafe fn step_dim(&mut self, index: usize);
216
217 #[cfg(not(feature = "nightly"))]
218 #[doc(hidden)]
219 #[inline]
220 fn clone_into_vec<T>(self, vec: &mut Vec<T>)
221 where
222 Self: Expression<Item: IntoCloned<T>> + Sized,
223 {
224 assert!(self.len() <= vec.capacity() - vec.len(), "length exceeds capacity");
225
226 self.for_each(|x| unsafe {
227 vec.as_mut_ptr().add(vec.len()).write(x.into_cloned());
228 vec.set_len(vec.len() + 1);
229 });
230 }
231
232 #[cfg(feature = "nightly")]
233 #[doc(hidden)]
234 #[inline]
235 fn clone_into_vec<T, A: Allocator>(self, vec: &mut Vec<T, A>)
236 where
237 Self: Expression<Item: IntoCloned<T>> + Sized,
238 {
239 assert!(self.len() <= vec.capacity() - vec.len(), "length exceeds capacity");
240
241 self.for_each(|x| unsafe {
242 vec.as_mut_ptr().add(vec.len()).write(x.into_cloned());
243 vec.set_len(vec.len() + 1);
244 });
245 }
246}
247
248pub trait FromExpression<T, S: Shape>: Sized {
250 fn from_expr<I: IntoExpression<Item = T, Shape = S>>(expr: I) -> Self;
252}
253
254pub trait IntoExpression: IntoIterator {
256 type Shape: Shape;
258
259 type IntoExpr: Expression<Item = Self::Item, Shape = Self::Shape>;
261
262 fn into_expr(self) -> Self::IntoExpr;
264}
265
266impl<T, E: Expression> Apply<T> for E {
267 type Output<F: FnMut(Self::Item) -> T> = Map<E, F>;
268 type ZippedWith<I: IntoExpression, F: FnMut((Self::Item, I::Item)) -> T> =
269 Map<Zip<Self, I::IntoExpr>, F>;
270
271 #[inline]
272 fn apply<F: FnMut(Self::Item) -> T>(self, f: F) -> Self::Output<F> {
273 self.map(f)
274 }
275
276 #[inline]
277 fn zip_with<I: IntoExpression, F>(self, expr: I, f: F) -> Self::ZippedWith<I, F>
278 where
279 F: FnMut((Self::Item, I::Item)) -> T,
280 {
281 self.zip(expr).map(f)
282 }
283}
284
285impl<E: Expression> IntoExpression for E {
286 type Shape = E::Shape;
287 type IntoExpr = E;
288
289 #[inline]
290 fn into_expr(self) -> Self {
291 self
292 }
293}