1use crate::error::Result;
7use crate::operator::traits::{BoundedOperator, LinearOperator};
8use crate::phantom::Bounded;
9use amari_core::Multivector;
10use core::marker::PhantomData;
11
12#[derive(Debug, Clone, Copy)]
17pub struct IdentityOperator<V> {
18 _phantom: PhantomData<V>,
19}
20
21impl<V> Default for IdentityOperator<V> {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl<V> IdentityOperator<V> {
28 pub fn new() -> Self {
30 Self {
31 _phantom: PhantomData,
32 }
33 }
34}
35
36impl<const P: usize, const Q: usize, const R: usize> LinearOperator<Multivector<P, Q, R>>
37 for IdentityOperator<Multivector<P, Q, R>>
38{
39 fn apply(&self, x: &Multivector<P, Q, R>) -> Result<Multivector<P, Q, R>> {
40 Ok(x.clone())
41 }
42
43 fn domain_dimension(&self) -> Option<usize> {
44 Some(1 << (P + Q + R))
45 }
46
47 fn codomain_dimension(&self) -> Option<usize> {
48 Some(1 << (P + Q + R))
49 }
50}
51
52impl<const P: usize, const Q: usize, const R: usize>
53 BoundedOperator<Multivector<P, Q, R>, Multivector<P, Q, R>, Bounded>
54 for IdentityOperator<Multivector<P, Q, R>>
55{
56 fn operator_norm(&self) -> f64 {
57 1.0
58 }
59}
60
61#[derive(Debug, Clone, Copy)]
66pub struct ZeroOperator<V, W = V> {
67 _phantom: PhantomData<(V, W)>,
68}
69
70impl<V, W> Default for ZeroOperator<V, W> {
71 fn default() -> Self {
72 Self::new()
73 }
74}
75
76impl<V, W> ZeroOperator<V, W> {
77 pub fn new() -> Self {
79 Self {
80 _phantom: PhantomData,
81 }
82 }
83}
84
85impl<const P: usize, const Q: usize, const R: usize> LinearOperator<Multivector<P, Q, R>>
86 for ZeroOperator<Multivector<P, Q, R>>
87{
88 fn apply(&self, _x: &Multivector<P, Q, R>) -> Result<Multivector<P, Q, R>> {
89 Ok(Multivector::<P, Q, R>::zero())
90 }
91
92 fn domain_dimension(&self) -> Option<usize> {
93 Some(1 << (P + Q + R))
94 }
95
96 fn codomain_dimension(&self) -> Option<usize> {
97 Some(1 << (P + Q + R))
98 }
99}
100
101impl<const P: usize, const Q: usize, const R: usize>
102 BoundedOperator<Multivector<P, Q, R>, Multivector<P, Q, R>, Bounded>
103 for ZeroOperator<Multivector<P, Q, R>>
104{
105 fn operator_norm(&self) -> f64 {
106 0.0
107 }
108}
109
110#[derive(Debug, Clone, Copy)]
114pub struct ScalingOperator<V> {
115 scalar: f64,
117 _phantom: PhantomData<V>,
118}
119
120impl<V> ScalingOperator<V> {
121 pub fn new(scalar: f64) -> Self {
123 Self {
124 scalar,
125 _phantom: PhantomData,
126 }
127 }
128
129 pub fn scalar(&self) -> f64 {
131 self.scalar
132 }
133}
134
135impl<const P: usize, const Q: usize, const R: usize> LinearOperator<Multivector<P, Q, R>>
136 for ScalingOperator<Multivector<P, Q, R>>
137{
138 fn apply(&self, x: &Multivector<P, Q, R>) -> Result<Multivector<P, Q, R>> {
139 Ok(x * self.scalar)
140 }
141
142 fn domain_dimension(&self) -> Option<usize> {
143 Some(1 << (P + Q + R))
144 }
145
146 fn codomain_dimension(&self) -> Option<usize> {
147 Some(1 << (P + Q + R))
148 }
149}
150
151impl<const P: usize, const Q: usize, const R: usize>
152 BoundedOperator<Multivector<P, Q, R>, Multivector<P, Q, R>, Bounded>
153 for ScalingOperator<Multivector<P, Q, R>>
154{
155 fn operator_norm(&self) -> f64 {
156 self.scalar.abs()
157 }
158}
159
160#[derive(Clone)]
164pub struct ProjectionOperator<const P: usize, const Q: usize, const R: usize> {
165 basis: Vec<Multivector<P, Q, R>>,
167}
168
169impl<const P: usize, const Q: usize, const R: usize> std::fmt::Debug
170 for ProjectionOperator<P, Q, R>
171{
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 f.debug_struct("ProjectionOperator")
174 .field("basis_size", &self.basis.len())
175 .field("signature", &(P, Q, R))
176 .finish()
177 }
178}
179
180impl<const P: usize, const Q: usize, const R: usize> ProjectionOperator<P, Q, R> {
181 pub fn from_orthonormal_basis(basis: Vec<Multivector<P, Q, R>>) -> Self {
185 Self { basis }
186 }
187
188 pub fn onto_direction(direction: Multivector<P, Q, R>) -> Self {
190 Self {
191 basis: vec![direction],
192 }
193 }
194
195 pub fn subspace_dimension(&self) -> usize {
197 self.basis.len()
198 }
199}
200
201impl<const P: usize, const Q: usize, const R: usize> LinearOperator<Multivector<P, Q, R>>
202 for ProjectionOperator<P, Q, R>
203{
204 fn apply(&self, x: &Multivector<P, Q, R>) -> Result<Multivector<P, Q, R>> {
205 let mut result = Multivector::<P, Q, R>::zero();
206
207 for basis_vec in &self.basis {
208 let x_coeffs = x.to_vec();
210 let b_coeffs = basis_vec.to_vec();
211 let inner_product: f64 = x_coeffs
212 .iter()
213 .zip(b_coeffs.iter())
214 .map(|(a, b)| a * b)
215 .sum();
216 result = result.add(&(basis_vec * inner_product));
217 }
218
219 Ok(result)
220 }
221
222 fn domain_dimension(&self) -> Option<usize> {
223 Some(1 << (P + Q + R))
224 }
225
226 fn codomain_dimension(&self) -> Option<usize> {
227 Some(1 << (P + Q + R))
228 }
229}
230
231impl<const P: usize, const Q: usize, const R: usize>
232 BoundedOperator<Multivector<P, Q, R>, Multivector<P, Q, R>, Bounded>
233 for ProjectionOperator<P, Q, R>
234{
235 fn operator_norm(&self) -> f64 {
236 if self.basis.is_empty() {
237 0.0
238 } else {
239 1.0
240 }
241 }
242}
243
244#[derive(Clone)]
246pub struct CompositeOperator<S, T, V, W, U> {
247 outer: S,
249 inner: T,
251 _phantom: PhantomData<(V, W, U)>,
252}
253
254impl<S, T, V, W, U> std::fmt::Debug for CompositeOperator<S, T, V, W, U>
255where
256 S: std::fmt::Debug,
257 T: std::fmt::Debug,
258{
259 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260 f.debug_struct("CompositeOperator")
261 .field("outer", &self.outer)
262 .field("inner", &self.inner)
263 .finish()
264 }
265}
266
267impl<S, T, V, W, U> CompositeOperator<S, T, V, W, U>
268where
269 S: LinearOperator<W, U>,
270 T: LinearOperator<V, W>,
271{
272 pub fn new(outer: S, inner: T) -> Self {
274 Self {
275 outer,
276 inner,
277 _phantom: PhantomData,
278 }
279 }
280}
281
282impl<S, T, V, W, U> LinearOperator<V, U> for CompositeOperator<S, T, V, W, U>
283where
284 S: LinearOperator<W, U>,
285 T: LinearOperator<V, W>,
286{
287 fn apply(&self, x: &V) -> Result<U> {
288 let intermediate = self.inner.apply(x)?;
289 self.outer.apply(&intermediate)
290 }
291
292 fn domain_dimension(&self) -> Option<usize> {
293 self.inner.domain_dimension()
294 }
295
296 fn codomain_dimension(&self) -> Option<usize> {
297 self.outer.codomain_dimension()
298 }
299}
300
301impl<S, T, V, W, U> BoundedOperator<V, U, Bounded> for CompositeOperator<S, T, V, W, U>
302where
303 S: BoundedOperator<W, U, Bounded>,
304 T: BoundedOperator<V, W, Bounded>,
305{
306 fn operator_norm(&self) -> f64 {
307 self.outer.operator_norm() * self.inner.operator_norm()
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::space::MultivectorHilbertSpace;

    /// Absolute tolerance for every floating-point comparison in this module.
    const TOL: f64 = 1e-10;

    /// `true` when `actual` is within [`TOL`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < TOL
    }

    #[test]
    fn test_identity_operator() {
        let input = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let op: IdentityOperator<Multivector<2, 0, 0>> = IdentityOperator::new();

        let output = op.apply(&input).unwrap();

        assert_eq!(input.to_vec(), output.to_vec());
        assert!(close(op.operator_norm(), 1.0));
    }

    #[test]
    fn test_zero_operator() {
        let input = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let op: ZeroOperator<Multivector<2, 0, 0>> = ZeroOperator::new();

        let output = op.apply(&input).unwrap();

        for coeff in output.to_vec() {
            assert!(coeff.abs() < TOL);
        }
        assert!(close(op.operator_norm(), 0.0));
    }

    #[test]
    fn test_scaling_operator() {
        let input = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let op: ScalingOperator<Multivector<2, 0, 0>> = ScalingOperator::new(2.0);

        let output = op.apply(&input).unwrap();

        assert_eq!(output.to_vec(), vec![2.0, 4.0, 6.0, 8.0]);
        assert!(close(op.operator_norm(), 2.0));
    }

    #[test]
    fn test_projection_operator() {
        let space: MultivectorHilbertSpace<2, 0, 0> = MultivectorHilbertSpace::new();
        let proj = ProjectionOperator::onto_direction(space.basis_vector(0).unwrap());

        let input = Multivector::<2, 0, 0>::from_slice(&[3.0, 4.0, 0.0, 0.0]);
        let coeffs = proj.apply(&input).unwrap().to_vec();

        // Only the component along the first basis vector survives.
        assert!(close(coeffs[0], 3.0));
        assert!(coeffs[1].abs() < TOL);
    }

    #[test]
    fn test_composite_operator() {
        let twice: ScalingOperator<Multivector<2, 0, 0>> = ScalingOperator::new(2.0);
        let thrice: ScalingOperator<Multivector<2, 0, 0>> = ScalingOperator::new(3.0);
        let composite = CompositeOperator::new(twice, thrice);

        let input = Multivector::<2, 0, 0>::from_slice(&[1.0, 0.0, 0.0, 0.0]);
        let output = composite.apply(&input).unwrap();

        assert_eq!(output.to_vec(), vec![6.0, 0.0, 0.0, 0.0]);
        assert!(close(composite.operator_norm(), 6.0));
    }

    #[test]
    fn test_projection_is_idempotent() {
        let space: MultivectorHilbertSpace<2, 0, 0> = MultivectorHilbertSpace::new();
        let first = space.basis_vector(0).unwrap();
        let second = space.basis_vector(1).unwrap();
        let proj = ProjectionOperator::from_orthonormal_basis(vec![first, second]);

        let input = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let once = proj.apply(&input).unwrap();
        let twice = proj.apply(&once).unwrap();

        // P(P x) must equal P x coefficient by coefficient.
        let once_coeffs = once.to_vec();
        let twice_coeffs = twice.to_vec();
        for (lhs, rhs) in once_coeffs.iter().zip(twice_coeffs.iter()) {
            assert!(close(*lhs, *rhs));
        }
    }
}