1use crate::error::{FunctionalError, Result};
18use crate::phantom::Complete;
19use crate::space::traits::{
20 BanachSpace, HilbertSpace, InnerProductSpace, NormedSpace, VectorSpace,
21};
22use amari_core::Multivector;
23use core::marker::PhantomData;
24use std::sync::Arc;
25
/// Integration domain for functions in an L² space.
///
/// Only one-dimensional bounded domains — an [`Domain::Interval`], or a
/// [`Domain::Rectangle`] with exactly one `(lower, upper)` pair — report
/// bounds through [`Domain::bounds_1d`].
#[derive(Debug, Clone, PartialEq)]
pub enum Domain<T> {
    /// Interval between the endpoints `a` and `b`.
    Interval {
        /// First endpoint.
        a: T,
        /// Second endpoint.
        b: T,
    },
    /// Axis-aligned box described by one `(lower, upper)` pair per coordinate.
    Rectangle {
        /// Per-axis bounds; the number of pairs is the domain's dimension.
        bounds: Vec<(T, T)>,
    },
    /// The whole (unbounded) real line.
    RealLine,
}

impl Domain<f64> {
    /// Creates the interval with endpoints `a` and `b`.
    ///
    /// No ordering check is performed. NOTE(review): callers presumably pass
    /// `a < b`; confirm whether reversed endpoints should be rejected.
    pub fn interval(a: f64, b: f64) -> Self {
        Domain::Interval { a, b }
    }

    /// Creates a box domain from per-axis `(lower, upper)` bounds.
    pub fn rectangle(bounds: Vec<(f64, f64)>) -> Self {
        Domain::Rectangle { bounds }
    }

    /// Returns `Some((a, b))` for a one-dimensional bounded domain.
    ///
    /// Returns `None` for [`Domain::RealLine`] and for rectangles that do not
    /// have exactly one pair of bounds.
    pub fn bounds_1d(&self) -> Option<(f64, f64)> {
        match self {
            Domain::Interval { a, b } => Some((*a, *b)),
            Domain::Rectangle { bounds } => match bounds.as_slice() {
                [only] => Some(*only),
                _ => None,
            },
            Domain::RealLine => None,
        }
    }

    /// Number of coordinates of a point in this domain.
    ///
    /// Currently always `Some`: intervals and the real line are 1-D, and a
    /// rectangle's dimension is its number of bound pairs.
    pub fn dimension(&self) -> Option<usize> {
        match self {
            Domain::Interval { .. } => Some(1),
            Domain::Rectangle { bounds } => Some(bounds.len()),
            Domain::RealLine => Some(1),
        }
    }
}
78
79#[derive(Clone)]
84pub struct L2Function<const P: usize, const Q: usize, const R: usize> {
85 func: Arc<dyn Fn(&[f64]) -> Multivector<P, Q, R> + Send + Sync>,
87 cached_norm: Option<f64>,
89}
90
91impl<const P: usize, const Q: usize, const R: usize> std::fmt::Debug for L2Function<P, Q, R> {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("L2Function")
94 .field("signature", &(P, Q, R))
95 .field("cached_norm", &self.cached_norm)
96 .finish()
97 }
98}
99
100impl<const P: usize, const Q: usize, const R: usize> L2Function<P, Q, R> {
101 pub fn new<F>(f: F) -> Self
103 where
104 F: Fn(&[f64]) -> Multivector<P, Q, R> + Send + Sync + 'static,
105 {
106 Self {
107 func: Arc::new(f),
108 cached_norm: None,
109 }
110 }
111
112 pub fn eval(&self, point: &[f64]) -> Multivector<P, Q, R> {
114 (self.func)(point)
115 }
116
117 pub fn zero_function() -> Self {
119 Self::new(|_| Multivector::<P, Q, R>::zero())
120 }
121
122 pub fn constant(value: Multivector<P, Q, R>) -> Self {
124 Self::new(move |_| value.clone())
125 }
126}
127
128#[derive(Clone)]
138pub struct MultivectorL2<const P: usize, const Q: usize, const R: usize> {
139 domain: Domain<f64>,
141 quadrature_points: usize,
143 _phantom: PhantomData<Multivector<P, Q, R>>,
144}
145
146impl<const P: usize, const Q: usize, const R: usize> std::fmt::Debug for MultivectorL2<P, Q, R> {
147 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
148 f.debug_struct("MultivectorL2")
149 .field("signature", &(P, Q, R))
150 .field("domain", &self.domain)
151 .field("quadrature_points", &self.quadrature_points)
152 .finish()
153 }
154}
155
156impl<const P: usize, const Q: usize, const R: usize> MultivectorL2<P, Q, R> {
157 pub const CODOMAIN_DIM: usize = 1 << (P + Q + R);
159
160 pub fn new(domain: Domain<f64>) -> Self {
162 Self {
163 domain,
164 quadrature_points: 32, _phantom: PhantomData,
166 }
167 }
168
169 pub fn unit_interval() -> Self {
171 Self::new(Domain::interval(0.0, 1.0))
172 }
173
174 pub fn interval(a: f64, b: f64) -> Self {
176 Self::new(Domain::interval(a, b))
177 }
178
179 pub fn with_quadrature_points(mut self, n: usize) -> Self {
181 self.quadrature_points = n;
182 self
183 }
184
185 pub fn domain(&self) -> &Domain<f64> {
187 &self.domain
188 }
189
190 pub fn signature(&self) -> (usize, usize, usize) {
192 (P, Q, R)
193 }
194
195 fn integrate<F>(&self, f: F) -> f64
197 where
198 F: Fn(&[f64]) -> f64,
199 {
200 let (a, b) = self.domain.bounds_1d().unwrap_or((0.0, 1.0));
202 gauss_legendre_integrate(&f, a, b, self.quadrature_points)
203 }
204
205 pub fn l2_inner_product(&self, f: &L2Function<P, Q, R>, g: &L2Function<P, Q, R>) -> f64 {
207 self.integrate(|x| {
208 let fx = f.eval(x);
209 let gx = g.eval(x);
210 fx.to_vec()
212 .iter()
213 .zip(gx.to_vec().iter())
214 .map(|(a, b)| a * b)
215 .sum()
216 })
217 }
218
219 pub fn l2_norm(&self, f: &L2Function<P, Q, R>) -> f64 {
221 self.l2_inner_product(f, f).sqrt()
222 }
223}
224
/// Approximates `∫_a^b f(x) dx` with `n`-point Gauss–Legendre quadrature.
///
/// The integrand receives one-coordinate points (`&[x]`). The rule is exact for
/// polynomials of degree up to `2n - 1`.
///
/// Nodes and weights are recomputed on every call. Each root of the Legendre
/// polynomial `P_n` is found by Newton iteration, starting from the estimate
/// `cos(π (i + 3/4) / (n + 1/2))`.
///
/// `n == 0` is treated as `n == 1` (the midpoint rule), so a zero point count
/// cannot produce a non-finite result.
fn gauss_legendre_integrate<F>(f: &F, a: f64, b: f64, n: usize) -> f64
where
    F: Fn(&[f64]) -> f64,
{
    const MAX_NEWTON_ITERS: usize = 100;
    const NEWTON_TOL: f64 = 1e-14;

    let n = n.max(1);
    // Affine map from the reference interval [-1, 1] onto [a, b].
    let mid = 0.5 * (a + b);
    let half = 0.5 * (b - a);

    // Roots of P_n are symmetric about 0, so only the non-negative ones are located.
    let non_negative_roots = n / 2 + n % 2;

    let mut sum = 0.0;
    for i in 0..non_negative_roots {
        let mut z = (std::f64::consts::PI * (i as f64 + 0.75) / (n as f64 + 0.5)).cos();
        for _ in 0..MAX_NEWTON_ITERS {
            let (p, dp) = legendre_and_derivative(n, z);
            let step = p / dp;
            z -= step;
            if step.abs() <= NEWTON_TOL {
                break;
            }
        }
        let (_, dp) = legendre_and_derivative(n, z);
        let weight = 2.0 / ((1.0 - z * z) * dp * dp);

        if 2 * i + 1 == n {
            // Odd n: the middle root is 0 and must be counted only once.
            sum += weight * f(&[mid]);
        } else {
            let offset = half * z;
            sum += weight * (f(&[mid - offset]) + f(&[mid + offset]));
        }
    }
    sum * half
}

/// Evaluates the Legendre polynomial `P_n(z)` and its derivative `P_n'(z)` for `|z| < 1`.
///
/// Uses the three-term recurrence `j P_j = (2j - 1) z P_{j-1} - (j - 1) P_{j-2}`.
/// The derivative comes from `P_n' = n (z P_n - P_{n-1}) / (z^2 - 1)`.
fn legendre_and_derivative(n: usize, z: f64) -> (f64, f64) {
    let mut p_prev = 0.0_f64; // P_{j-2}
    let mut p_curr = 1.0_f64; // P_{j-1}, starting from P_0
    for j in 1..=n {
        let jf = j as f64;
        let p_next = ((2.0 * jf - 1.0) * z * p_curr - (jf - 1.0) * p_prev) / jf;
        p_prev = p_curr;
        p_curr = p_next;
    }
    let deriv = n as f64 * (z * p_curr - p_prev) / (z * z - 1.0);
    (p_curr, deriv)
}
239
240impl<const P: usize, const Q: usize, const R: usize> VectorSpace<L2Function<P, Q, R>, f64>
241 for MultivectorL2<P, Q, R>
242{
243 fn add(&self, f: &L2Function<P, Q, R>, g: &L2Function<P, Q, R>) -> L2Function<P, Q, R> {
244 let f_clone = f.func.clone();
245 let g_clone = g.func.clone();
246 L2Function::new(move |x| f_clone(x).add(&g_clone(x)))
247 }
248
249 fn sub(&self, f: &L2Function<P, Q, R>, g: &L2Function<P, Q, R>) -> L2Function<P, Q, R> {
250 let f_clone = f.func.clone();
251 let g_clone = g.func.clone();
252 L2Function::new(move |x| &f_clone(x) - &g_clone(x))
253 }
254
255 fn scale(&self, scalar: f64, f: &L2Function<P, Q, R>) -> L2Function<P, Q, R> {
256 let f_clone = f.func.clone();
257 L2Function::new(move |x| &f_clone(x) * scalar)
258 }
259
260 fn zero(&self) -> L2Function<P, Q, R> {
261 L2Function::zero_function()
262 }
263
264 fn dimension(&self) -> Option<usize> {
265 None
267 }
268}
269
270impl<const P: usize, const Q: usize, const R: usize> NormedSpace<L2Function<P, Q, R>, f64>
271 for MultivectorL2<P, Q, R>
272{
273 fn norm(&self, f: &L2Function<P, Q, R>) -> f64 {
274 self.l2_norm(f)
275 }
276
277 fn normalize(&self, f: &L2Function<P, Q, R>) -> Option<L2Function<P, Q, R>> {
278 let n = self.norm(f);
279 if n < 1e-15 {
280 None
281 } else {
282 Some(self.scale(1.0 / n, f))
283 }
284 }
285}
286
287impl<const P: usize, const Q: usize, const R: usize> InnerProductSpace<L2Function<P, Q, R>, f64>
288 for MultivectorL2<P, Q, R>
289{
290 fn inner_product(&self, f: &L2Function<P, Q, R>, g: &L2Function<P, Q, R>) -> f64 {
291 self.l2_inner_product(f, g)
292 }
293
294 fn project(&self, f: &L2Function<P, Q, R>, g: &L2Function<P, Q, R>) -> L2Function<P, Q, R> {
295 let ip_fg = self.inner_product(f, g);
296 let ip_gg = self.inner_product(g, g);
297 if ip_gg.abs() < 1e-15 {
298 self.zero()
299 } else {
300 self.scale(ip_fg / ip_gg, g)
301 }
302 }
303
304 fn gram_schmidt(&self, functions: &[L2Function<P, Q, R>]) -> Vec<L2Function<P, Q, R>> {
305 let mut orthonormal = Vec::new();
306 for f in functions {
307 let mut u = f.clone();
308 for q in &orthonormal {
309 let proj = self.project(&u, q);
310 u = self.sub(&u, &proj);
311 }
312 if let Some(normalized) = self.normalize(&u) {
313 orthonormal.push(normalized);
314 }
315 }
316 orthonormal
317 }
318}
319
320impl<const P: usize, const Q: usize, const R: usize> BanachSpace<L2Function<P, Q, R>, f64, Complete>
321 for MultivectorL2<P, Q, R>
322{
323 fn is_cauchy_sequence(&self, sequence: &[L2Function<P, Q, R>], tolerance: f64) -> bool {
324 if sequence.len() < 2 {
325 return true;
326 }
327
328 let n = sequence.len();
329 for i in (n.saturating_sub(5))..n {
330 for j in (i + 1)..n {
331 if self.distance(&sequence[i], &sequence[j]) > tolerance {
332 return false;
333 }
334 }
335 }
336 true
337 }
338
339 fn sequence_limit(
340 &self,
341 sequence: &[L2Function<P, Q, R>],
342 tolerance: f64,
343 ) -> Result<L2Function<P, Q, R>> {
344 if sequence.is_empty() {
345 return Err(FunctionalError::convergence_error(
346 0,
347 "Empty sequence has no limit",
348 ));
349 }
350
351 if !self.is_cauchy_sequence(sequence, tolerance) {
352 return Err(FunctionalError::convergence_error(
353 sequence.len(),
354 "Sequence is not Cauchy",
355 ));
356 }
357
358 Ok(sequence.last().unwrap().clone())
360 }
361}
362
363impl<const P: usize, const Q: usize, const R: usize>
364 HilbertSpace<L2Function<P, Q, R>, f64, Complete> for MultivectorL2<P, Q, R>
365{
366 fn riesz_representative<F>(&self, _functional: F) -> Result<L2Function<P, Q, R>>
367 where
368 F: Fn(&L2Function<P, Q, R>) -> f64,
369 {
370 Err(FunctionalError::not_complete(
373 "Riesz representative computation not implemented for infinite-dimensional L² spaces",
374 ))
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_l2_space_creation() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval();
        assert_eq!(l2.signature(), (2, 0, 0));
        assert_eq!(l2.domain().bounds_1d(), Some((0.0, 1.0)));
    }

    #[test]
    fn test_domain_dimensions() {
        let rect = Domain::rectangle(vec![(0.0, 1.0), (-1.0, 1.0)]);
        assert_eq!(rect.dimension(), Some(2));
        assert_eq!(rect.bounds_1d(), None);
        assert_eq!(Domain::rectangle(vec![(2.0, 3.0)]).bounds_1d(), Some((2.0, 3.0)));
        assert_eq!(Domain::<f64>::RealLine.dimension(), Some(1));
        assert_eq!(Domain::<f64>::RealLine.bounds_1d(), None);
    }

    #[test]
    fn test_l2_function_evaluation() {
        let f = L2Function::<2, 0, 0>::new(|x| Multivector::<2, 0, 0>::scalar(x[0]));

        let result = f.eval(&[0.5]);
        assert!((result.scalar_part() - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_zero_function() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval();
        let zero = l2.zero();

        let norm = l2.norm(&zero);
        assert!(norm < 1e-10);
    }

    #[test]
    fn test_constant_function() {
        let mv = Multivector::<2, 0, 0>::scalar(1.0);
        let f = L2Function::constant(mv);

        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval();

        // ‖1‖ on [0, 1] is 1.
        let norm = l2.norm(&f);
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_vector_space_operations() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval();

        let f = L2Function::new(|x| Multivector::<2, 0, 0>::scalar(x[0]));
        let g = L2Function::new(|x| Multivector::<2, 0, 0>::scalar(1.0 - x[0]));

        // (f + g)(x) = 1 everywhere.
        let sum = l2.add(&f, &g);
        let result = sum.eval(&[0.5]);
        assert!((result.scalar_part() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_inner_product_orthogonality() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval().with_quadrature_points(64);

        // sin(πx) and sin(2πx) are orthogonal on [0, 1].
        let f = L2Function::new(|x| {
            let val = (std::f64::consts::PI * x[0]).sin();
            Multivector::<2, 0, 0>::scalar(val)
        });
        let g = L2Function::new(|x| {
            let val = (2.0 * std::f64::consts::PI * x[0]).sin();
            Multivector::<2, 0, 0>::scalar(val)
        });

        let ip = l2.inner_product(&f, &g);
        assert!(ip.abs() < 0.1);
    }

    #[test]
    fn test_l2_norm_squared() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval().with_quadrature_points(64);

        // ∫_0^1 sin²(πx) dx = 1/2.
        let f = L2Function::new(|x| {
            let val = (std::f64::consts::PI * x[0]).sin();
            Multivector::<2, 0, 0>::scalar(val)
        });

        let ip = l2.inner_product(&f, &f);
        assert!((ip - 0.5).abs() < 0.05);
    }

    #[test]
    fn test_scaling() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval();

        let f = L2Function::new(|x| Multivector::<2, 0, 0>::scalar(x[0]));

        let scaled = l2.scale(2.0, &f);
        let result = scaled.eval(&[0.5]);
        assert!((result.scalar_part() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_normalization() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval().with_quadrature_points(64);

        let f = L2Function::new(|x| Multivector::<2, 0, 0>::scalar(x[0]));

        let normalized = l2.normalize(&f).unwrap();
        let norm = l2.norm(&normalized);
        assert!((norm - 1.0).abs() < 0.05);
    }

    #[test]
    fn test_projection_onto_constant() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval().with_quadrature_points(64);
        let f = L2Function::new(|x: &[f64]| Multivector::<2, 0, 0>::scalar(x[0]));
        let one = L2Function::constant(Multivector::<2, 0, 0>::scalar(1.0));

        // Projecting x onto the constants gives its mean, 1/2.
        let proj = l2.project(&f, &one);
        assert!((proj.eval(&[0.3]).scalar_part() - 0.5).abs() < 1e-3);

        // Projecting onto the zero function yields zero.
        let onto_zero = l2.project(&f, &l2.zero());
        assert!(l2.norm(&onto_zero) < 1e-12);
    }

    #[test]
    fn test_gram_schmidt_orthonormal() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval().with_quadrature_points(64);
        let monomials: Vec<L2Function<2, 0, 0>> = (0..4)
            .map(|k| {
                L2Function::new(move |x: &[f64]| Multivector::<2, 0, 0>::scalar(x[0].powi(k)))
            })
            .collect();

        let basis = l2.gram_schmidt(&monomials);
        assert_eq!(basis.len(), 4);
        for (i, qi) in basis.iter().enumerate() {
            for (j, qj) in basis.iter().enumerate() {
                let expected = if i == j { 1.0 } else { 0.0 };
                assert!((l2.inner_product(qi, qj) - expected).abs() < 1e-8);
            }
        }
    }

    #[test]
    fn test_gram_schmidt_skips_zero_function() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval();
        let f = L2Function::new(|x: &[f64]| Multivector::<2, 0, 0>::scalar(x[0]));

        let basis = l2.gram_schmidt(&[l2.zero(), f]);
        assert_eq!(basis.len(), 1);
        assert!((l2.norm(&basis[0]) - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_sequence_limit() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval();

        let empty: [L2Function<2, 0, 0>; 0] = [];
        assert!(l2.sequence_limit(&empty, 1e-6).is_err());

        let f = L2Function::new(|x: &[f64]| Multivector::<2, 0, 0>::scalar(x[0]));
        let constant_seq = vec![f.clone(), f.clone(), f.clone()];
        let limit = match l2.sequence_limit(&constant_seq, 1e-6) {
            Ok(limit) => limit,
            Err(_) => panic!("a constant sequence must converge"),
        };
        assert!((limit.eval(&[0.25]).scalar_part() - 0.25).abs() < 1e-12);

        let one = L2Function::constant(Multivector::<2, 0, 0>::scalar(1.0));
        let jumping = vec![l2.zero(), one];
        assert!(l2.sequence_limit(&jumping, 0.5).is_err());
    }

    #[test]
    fn test_riesz_representative_unsupported() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval();
        assert!(l2
            .riesz_representative(|_f: &L2Function<2, 0, 0>| 0.0)
            .is_err());
    }

    #[test]
    fn test_infinite_dimension() {
        let l2: MultivectorL2<2, 0, 0> = MultivectorL2::unit_interval();
        assert_eq!(l2.dimension(), None);
    }
}