amari_functional/operator/
basic.rs

1//! Basic linear operators.
2//!
3//! This module provides fundamental operator types that serve as
4//! building blocks for more complex operators.
5
6use crate::error::Result;
7use crate::operator::traits::{BoundedOperator, LinearOperator};
8use crate::phantom::Bounded;
9use amari_core::Multivector;
10use core::marker::PhantomData;
11
/// The identity operator I: x ↦ x.
///
/// The identity operator is the simplest non-trivial operator.
/// It has operator norm 1 and is self-adjoint.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand instead of derived:
/// `#[derive]` would add `V: Clone` / `V: Copy` / `V: Debug` bounds even
/// though no `V` is ever stored. That would make, e.g., an identity over a
/// non-`Copy` vector type non-`Copy` despite being zero-sized.
pub struct IdentityOperator<V> {
    _phantom: PhantomData<V>,
}

impl<V> Clone for IdentityOperator<V> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<V> Copy for IdentityOperator<V> {}

impl<V> std::fmt::Debug for IdentityOperator<V> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the previously derived output.
        f.debug_struct("IdentityOperator")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
20
21impl<V> Default for IdentityOperator<V> {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl<V> IdentityOperator<V> {
28    /// Create a new identity operator.
29    pub fn new() -> Self {
30        Self {
31            _phantom: PhantomData,
32        }
33    }
34}
35
36impl<const P: usize, const Q: usize, const R: usize> LinearOperator<Multivector<P, Q, R>>
37    for IdentityOperator<Multivector<P, Q, R>>
38{
39    fn apply(&self, x: &Multivector<P, Q, R>) -> Result<Multivector<P, Q, R>> {
40        Ok(x.clone())
41    }
42
43    fn domain_dimension(&self) -> Option<usize> {
44        Some(1 << (P + Q + R))
45    }
46
47    fn codomain_dimension(&self) -> Option<usize> {
48        Some(1 << (P + Q + R))
49    }
50}
51
52impl<const P: usize, const Q: usize, const R: usize>
53    BoundedOperator<Multivector<P, Q, R>, Multivector<P, Q, R>, Bounded>
54    for IdentityOperator<Multivector<P, Q, R>>
55{
56    fn operator_norm(&self) -> f64 {
57        1.0
58    }
59}
60
/// The zero operator 0: x ↦ 0.
///
/// The zero operator maps everything to zero.
/// It has operator norm 0.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand instead of derived,
/// because `#[derive]` would require `V` and `W` to implement those traits
/// even though only `PhantomData<(V, W)>` is stored.
pub struct ZeroOperator<V, W = V> {
    _phantom: PhantomData<(V, W)>,
}

impl<V, W> Clone for ZeroOperator<V, W> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<V, W> Copy for ZeroOperator<V, W> {}

impl<V, W> std::fmt::Debug for ZeroOperator<V, W> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the previously derived output.
        f.debug_struct("ZeroOperator")
            .field("_phantom", &self._phantom)
            .finish()
    }
}
69
70impl<V, W> Default for ZeroOperator<V, W> {
71    fn default() -> Self {
72        Self::new()
73    }
74}
75
76impl<V, W> ZeroOperator<V, W> {
77    /// Create a new zero operator.
78    pub fn new() -> Self {
79        Self {
80            _phantom: PhantomData,
81        }
82    }
83}
84
85impl<const P: usize, const Q: usize, const R: usize> LinearOperator<Multivector<P, Q, R>>
86    for ZeroOperator<Multivector<P, Q, R>>
87{
88    fn apply(&self, _x: &Multivector<P, Q, R>) -> Result<Multivector<P, Q, R>> {
89        Ok(Multivector::<P, Q, R>::zero())
90    }
91
92    fn domain_dimension(&self) -> Option<usize> {
93        Some(1 << (P + Q + R))
94    }
95
96    fn codomain_dimension(&self) -> Option<usize> {
97        Some(1 << (P + Q + R))
98    }
99}
100
101impl<const P: usize, const Q: usize, const R: usize>
102    BoundedOperator<Multivector<P, Q, R>, Multivector<P, Q, R>, Bounded>
103    for ZeroOperator<Multivector<P, Q, R>>
104{
105    fn operator_norm(&self) -> f64 {
106        0.0
107    }
108}
109
/// Scaling operator αI: x ↦ αx.
///
/// Scales all elements by a fixed scalar.
///
/// `Clone`, `Copy` and `Debug` are implemented by hand instead of derived,
/// because `#[derive]` would require `V` to implement those traits even
/// though only a `PhantomData<V>` marker is stored alongside the `f64`.
pub struct ScalingOperator<V> {
    /// The scaling factor.
    scalar: f64,
    _phantom: PhantomData<V>,
}

impl<V> Clone for ScalingOperator<V> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<V> Copy for ScalingOperator<V> {}

impl<V> std::fmt::Debug for ScalingOperator<V> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same shape as the previously derived output.
        f.debug_struct("ScalingOperator")
            .field("scalar", &self.scalar)
            .field("_phantom", &self._phantom)
            .finish()
    }
}
119
120impl<V> ScalingOperator<V> {
121    /// Create a new scaling operator.
122    pub fn new(scalar: f64) -> Self {
123        Self {
124            scalar,
125            _phantom: PhantomData,
126        }
127    }
128
129    /// Get the scaling factor.
130    pub fn scalar(&self) -> f64 {
131        self.scalar
132    }
133}
134
135impl<const P: usize, const Q: usize, const R: usize> LinearOperator<Multivector<P, Q, R>>
136    for ScalingOperator<Multivector<P, Q, R>>
137{
138    fn apply(&self, x: &Multivector<P, Q, R>) -> Result<Multivector<P, Q, R>> {
139        Ok(x * self.scalar)
140    }
141
142    fn domain_dimension(&self) -> Option<usize> {
143        Some(1 << (P + Q + R))
144    }
145
146    fn codomain_dimension(&self) -> Option<usize> {
147        Some(1 << (P + Q + R))
148    }
149}
150
151impl<const P: usize, const Q: usize, const R: usize>
152    BoundedOperator<Multivector<P, Q, R>, Multivector<P, Q, R>, Bounded>
153    for ScalingOperator<Multivector<P, Q, R>>
154{
155    fn operator_norm(&self) -> f64 {
156        self.scalar.abs()
157    }
158}
159
160/// Orthogonal projection operator onto a subspace.
161///
162/// Projects onto the span of a set of orthonormal basis vectors.
163#[derive(Clone)]
164pub struct ProjectionOperator<const P: usize, const Q: usize, const R: usize> {
165    /// Orthonormal basis for the projection subspace.
166    basis: Vec<Multivector<P, Q, R>>,
167}
168
169impl<const P: usize, const Q: usize, const R: usize> std::fmt::Debug
170    for ProjectionOperator<P, Q, R>
171{
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        f.debug_struct("ProjectionOperator")
174            .field("basis_size", &self.basis.len())
175            .field("signature", &(P, Q, R))
176            .finish()
177    }
178}
179
180impl<const P: usize, const Q: usize, const R: usize> ProjectionOperator<P, Q, R> {
181    /// Create a projection operator from an orthonormal basis.
182    ///
183    /// The basis vectors should already be orthonormal.
184    pub fn from_orthonormal_basis(basis: Vec<Multivector<P, Q, R>>) -> Self {
185        Self { basis }
186    }
187
188    /// Create a projection onto a single normalized direction.
189    pub fn onto_direction(direction: Multivector<P, Q, R>) -> Self {
190        Self {
191            basis: vec![direction],
192        }
193    }
194
195    /// Get the dimension of the projection subspace.
196    pub fn subspace_dimension(&self) -> usize {
197        self.basis.len()
198    }
199}
200
201impl<const P: usize, const Q: usize, const R: usize> LinearOperator<Multivector<P, Q, R>>
202    for ProjectionOperator<P, Q, R>
203{
204    fn apply(&self, x: &Multivector<P, Q, R>) -> Result<Multivector<P, Q, R>> {
205        let mut result = Multivector::<P, Q, R>::zero();
206
207        for basis_vec in &self.basis {
208            // Compute ⟨x, basis_vec⟩ * basis_vec
209            let x_coeffs = x.to_vec();
210            let b_coeffs = basis_vec.to_vec();
211            let inner_product: f64 = x_coeffs
212                .iter()
213                .zip(b_coeffs.iter())
214                .map(|(a, b)| a * b)
215                .sum();
216            result = result.add(&(basis_vec * inner_product));
217        }
218
219        Ok(result)
220    }
221
222    fn domain_dimension(&self) -> Option<usize> {
223        Some(1 << (P + Q + R))
224    }
225
226    fn codomain_dimension(&self) -> Option<usize> {
227        Some(1 << (P + Q + R))
228    }
229}
230
231impl<const P: usize, const Q: usize, const R: usize>
232    BoundedOperator<Multivector<P, Q, R>, Multivector<P, Q, R>, Bounded>
233    for ProjectionOperator<P, Q, R>
234{
235    fn operator_norm(&self) -> f64 {
236        if self.basis.is_empty() {
237            0.0
238        } else {
239            1.0
240        }
241    }
242}
243
/// Composition of two operators: (S ∘ T)(x) = S(T(x)).
///
/// `V`, `W` and `U` are the domain, intermediate and codomain spaces. They
/// appear only in `PhantomData`. `Clone` is implemented by hand so that it
/// requires only `S: Clone` and `T: Clone`, rather than the extra
/// `V: Clone`, `W: Clone`, `U: Clone` bounds that `#[derive(Clone)]` would add.
pub struct CompositeOperator<S, T, V, W, U> {
    /// The outer operator S: W → U.
    outer: S,
    /// The inner operator T: V → W.
    inner: T,
    _phantom: PhantomData<(V, W, U)>,
}

impl<S: Clone, T: Clone, V, W, U> Clone for CompositeOperator<S, T, V, W, U> {
    fn clone(&self) -> Self {
        Self {
            outer: self.outer.clone(),
            inner: self.inner.clone(),
            _phantom: PhantomData,
        }
    }
}
253
254impl<S, T, V, W, U> std::fmt::Debug for CompositeOperator<S, T, V, W, U>
255where
256    S: std::fmt::Debug,
257    T: std::fmt::Debug,
258{
259    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
260        f.debug_struct("CompositeOperator")
261            .field("outer", &self.outer)
262            .field("inner", &self.inner)
263            .finish()
264    }
265}
266
267impl<S, T, V, W, U> CompositeOperator<S, T, V, W, U>
268where
269    S: LinearOperator<W, U>,
270    T: LinearOperator<V, W>,
271{
272    /// Create a composite operator S ∘ T.
273    pub fn new(outer: S, inner: T) -> Self {
274        Self {
275            outer,
276            inner,
277            _phantom: PhantomData,
278        }
279    }
280}
281
282impl<S, T, V, W, U> LinearOperator<V, U> for CompositeOperator<S, T, V, W, U>
283where
284    S: LinearOperator<W, U>,
285    T: LinearOperator<V, W>,
286{
287    fn apply(&self, x: &V) -> Result<U> {
288        let intermediate = self.inner.apply(x)?;
289        self.outer.apply(&intermediate)
290    }
291
292    fn domain_dimension(&self) -> Option<usize> {
293        self.inner.domain_dimension()
294    }
295
296    fn codomain_dimension(&self) -> Option<usize> {
297        self.outer.codomain_dimension()
298    }
299}
300
301impl<S, T, V, W, U> BoundedOperator<V, U, Bounded> for CompositeOperator<S, T, V, W, U>
302where
303    S: BoundedOperator<W, U, Bounded>,
304    T: BoundedOperator<V, W, Bounded>,
305{
306    fn operator_norm(&self) -> f64 {
307        // ||ST|| ≤ ||S|| ||T||
308        self.outer.operator_norm() * self.inner.operator_norm()
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;
    use crate::space::MultivectorHilbertSpace;

    /// Tolerance used for every floating-point comparison in this module.
    const EPS: f64 = 1e-10;

    /// Asserts that `actual` and `expected` have equal length and agree
    /// component-wise within `EPS`.
    fn assert_coeffs_close(actual: &[f64], expected: &[f64]) {
        assert_eq!(actual.len(), expected.len(), "coefficient count mismatch");
        for (a, e) in actual.iter().zip(expected) {
            assert!(
                (a - e).abs() < EPS,
                "got {:?}, expected {:?}",
                actual,
                expected
            );
        }
    }

    #[test]
    fn test_identity_operator() {
        let identity: IdentityOperator<Multivector<2, 0, 0>> = IdentityOperator::new();
        let x = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let y = identity.apply(&x).unwrap();
        assert_eq!(x.to_vec(), y.to_vec());
        assert!((identity.operator_norm() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_zero_operator() {
        let zero: ZeroOperator<Multivector<2, 0, 0>> = ZeroOperator::new();
        let x = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let y = zero.apply(&x).unwrap();
        assert!(y.to_vec().iter().all(|&c| c.abs() < EPS));
        assert!(zero.operator_norm().abs() < EPS);
    }

    #[test]
    fn test_scaling_operator() {
        let scale: ScalingOperator<Multivector<2, 0, 0>> = ScalingOperator::new(2.0);
        let x = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let y = scale.apply(&x).unwrap();
        assert_eq!(y.to_vec(), vec![2.0, 4.0, 6.0, 8.0]);
        assert!((scale.operator_norm() - 2.0).abs() < EPS);
    }

    #[test]
    fn test_negative_scaling_norm_is_absolute_value() {
        let flip: ScalingOperator<Multivector<2, 0, 0>> = ScalingOperator::new(-2.0);
        let x = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let y = flip.apply(&x).unwrap();
        assert_coeffs_close(&y.to_vec(), &[-2.0, -4.0, -6.0, -8.0]);
        // ‖αI‖ = |α|
        assert!((flip.operator_norm() - 2.0).abs() < EPS);
    }

    #[test]
    fn test_projection_operator() {
        let space: MultivectorHilbertSpace<2, 0, 0> = MultivectorHilbertSpace::new();

        // Create a projection onto the first basis vector (e₀)
        let e0 = space.basis_vector(0).unwrap();
        let proj = ProjectionOperator::onto_direction(e0);
        assert_eq!(proj.subspace_dimension(), 1);

        let x = Multivector::<2, 0, 0>::from_slice(&[3.0, 4.0, 0.0, 0.0]);
        let y = proj.apply(&x).unwrap();

        // Projection of (3, 4, 0, 0) onto (1, 0, 0, 0) should be (3, 0, 0, 0);
        // check every coefficient, not just the first two.
        assert_coeffs_close(&y.to_vec(), &[3.0, 0.0, 0.0, 0.0]);
        assert!((proj.operator_norm() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_empty_projection_is_zero_operator() {
        let proj = ProjectionOperator::<2, 0, 0>::from_orthonormal_basis(Vec::new());
        assert_eq!(proj.subspace_dimension(), 0);

        let x = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let y = proj.apply(&x).unwrap();
        assert!(y.to_vec().iter().all(|&c| c.abs() < EPS));
        assert!(proj.operator_norm().abs() < EPS);
    }

    #[test]
    fn test_composite_operator() {
        let scale2: ScalingOperator<Multivector<2, 0, 0>> = ScalingOperator::new(2.0);
        let scale3: ScalingOperator<Multivector<2, 0, 0>> = ScalingOperator::new(3.0);

        let composite = CompositeOperator::new(scale2, scale3);

        let x = Multivector::<2, 0, 0>::from_slice(&[1.0, 0.0, 0.0, 0.0]);
        let y = composite.apply(&x).unwrap();

        // 2 * (3 * x) = 6x
        assert_eq!(y.to_vec(), vec![6.0, 0.0, 0.0, 0.0]);

        // The reported value is the bound ||S|| ||T|| = 6 (exact for scalings).
        assert!((composite.operator_norm() - 6.0).abs() < EPS);
    }

    #[test]
    fn test_projection_is_idempotent() {
        let space: MultivectorHilbertSpace<2, 0, 0> = MultivectorHilbertSpace::new();

        // Create orthonormal basis for a 2D subspace
        let e0 = space.basis_vector(0).unwrap();
        let e1 = space.basis_vector(1).unwrap();
        let proj = ProjectionOperator::from_orthonormal_basis(vec![e0, e1]);

        let x = Multivector::<2, 0, 0>::from_slice(&[1.0, 2.0, 3.0, 4.0]);
        let y = proj.apply(&x).unwrap();
        let z = proj.apply(&y).unwrap();

        // P² = P (idempotent)
        assert_coeffs_close(&z.to_vec(), &y.to_vec());
    }
}