Skip to main content

lie_groups/
sun.rs

1//! Generic SU(N) - Special unitary N×N matrices
2//!
3//! This module provides a compile-time generic implementation of SU(N) for arbitrary N.
4//! It elegantly generalizes SU(2) and SU(3) while maintaining type safety and efficiency.
5//!
6//! # Mathematical Structure
7//!
8//! ```text
9//! SU(N) = { U ∈ ℂᴺˣᴺ | U† U = I, det(U) = 1 }
10//! ```
11//!
12//! # Lie Algebra
13//!
14//! The Lie algebra su(N) consists of N×N traceless anti-Hermitian matrices:
15//! ```text
16//! su(N) = { X ∈ ℂᴺˣᴺ | X† = -X, Tr(X) = 0 }
17//! dim(su(N)) = N² - 1
18//! ```
19//!
20//! # Design Philosophy
21//!
22//! - **Type Safety**: Const generics ensure dimension errors are caught at compile time
23//! - **Efficiency**: Lazy matrix construction, SIMD-friendly operations
24//! - **Elegance**: Unified interface for all N (including N=2,3)
25//! - **Generality**: Works for arbitrary N ≥ 2
26//!
27//! # Examples
28//!
29//! ```ignore
30//! use lie_groups::sun::SunAlgebra;
31//! use lie_groups::LieAlgebra;
32//!
33//! // SU(4) for grand unified theories
34//! type Su4Algebra = SunAlgebra<4>;
35//! let x = Su4Algebra::zero();
36//! assert_eq!(Su4Algebra::dim(), 15);  // 4² - 1 = 15
37//!
38//! // Type safety: dimensions checked at compile time
39//! let su2 = SunAlgebra::<2>::basis_element(0);  // dim = 3
40//! let su3 = SunAlgebra::<3>::basis_element(0);  // dim = 8
41//! // su2.add(&su3);  // Compile error! Incompatible types
42//! ```
43//!
44//! # Physics Applications
45//!
46//! - **SU(2)**: Weak force, isospin
47//! - **SU(3)**: Strong force (QCD), color charge
48//! - **SU(4)**: Pati-Salam model, flavor symmetry
49//! - **SU(5)**: Georgi-Glashow GUT
//! - **SU(6)**: Spin-flavor symmetry of the quark model (flavor SU(3) × spin SU(2))
51//!
52//! # Performance
53//!
54//! - Algebra operations: O(N²) `[optimal]`
55//! - Matrix construction: O(N²) `[lazy, only when needed]`
56//! - Exponential map: O(N³) via scaling-and-squaring
57//! - Memory: (N²-1)·sizeof(f64) bytes for algebra
58
59use crate::traits::{
60    AntiHermitianByConstruction, Compact, LieAlgebra, LieGroup, SemiSimple, Simple,
61    TracelessByConstruction,
62};
63use ndarray::Array2;
64use num_complex::Complex64;
65use std::fmt;
66use std::marker::PhantomData;
67use std::ops::{Add, Mul, MulAssign, Neg, Sub};
68
/// Element of the Lie algebra su(N): a traceless anti-Hermitian N×N matrix,
/// stored as its real coordinates in a generalized Gell-Mann basis.
///
/// # Type Parameter
///
/// - `N`: Matrix dimension. The inherent `DIM` constant rejects `N < 2` at
///   compile time.
///
/// # Coefficient Layout
///
/// The `N² - 1` coefficients form three consecutive groups, in the order used
/// by `to_matrix` and `from_matrix`:
///
/// 1. **Symmetric generators** (N(N-1)/2 entries, pairs `row < col` in
///    row-major order): λ has 1 at (row, col) and (col, row).
/// 2. **Antisymmetric generators** (N(N-1)/2 entries, same pair order):
///    λ has i at (row, col) and -i at (col, row).
/// 3. **Diagonal generators** (N-1 entries, k = 0..N-2): k+1 leading ones
///    followed by -(k+1), scaled by √(2 / ((k+1)(k+2))).
///
/// For N=2 and N=3 this plays the role of the Pauli and Gell-Mann matrices.
///
/// # Generator Properties
///
/// - Hermitian: λⱼ† = λⱼ
/// - Traceless: Tr(λⱼ) = 0
/// - Normalized: Tr(λᵢλⱼ) = 2δᵢⱼ
///
/// # Memory Layout
///
/// The coefficients always live in a heap-allocated `Vec<f64>` of length
/// `N² - 1`, whatever the value of `N`.
#[derive(Debug, Clone)]
pub struct SunAlgebra<const N: usize> {
    /// Real coordinates in the generalized Gell-Mann basis (length `N² - 1`).
    pub coefficients: Vec<f64>,
    /// Zero-sized marker tying the value to its const parameter `N`.
    _phantom: PhantomData<[(); N]>,
}
109
110impl<const N: usize> SunAlgebra<N> {
111    /// Dimension of su(N) algebra: N² - 1, valid only for N ≥ 2
112    const DIM: usize = {
113        assert!(
114            N >= 2,
115            "SU(N) requires N >= 2: SU(1) is trivial, SU(0) is undefined"
116        );
117        N * N - 1
118    };
119
120    /// Create new algebra element from coefficients
121    ///
122    /// # Panics
123    ///
124    /// Panics if `coefficients.len() != N² - 1`
125    #[must_use]
126    pub fn new(coefficients: Vec<f64>) -> Self {
127        assert_eq!(
128            coefficients.len(),
129            Self::DIM,
130            "SU({}) algebra requires {} coefficients, got {}",
131            N,
132            Self::DIM,
133            coefficients.len()
134        );
135        Self {
136            coefficients,
137            _phantom: PhantomData,
138        }
139    }
140
141    /// Convert to N×N anti-Hermitian matrix: X = i·∑ⱼ aⱼ·λⱼ
142    ///
143    /// This is the fundamental representation in ℂᴺˣᴺ.
144    ///
145    /// # Performance
146    ///
147    /// - Time: O(N²)
148    /// - Space: O(N²)
149    /// - Lazy: Only computed when called
150    ///
151    /// # Mathematical Formula
152    ///
153    /// Given coefficients [a₁, ..., a_{N²-1}], returns:
154    /// ```text
155    /// X = i·∑ⱼ aⱼ·λⱼ
156    /// ```
157    /// where λⱼ are the generalized Gell-Mann matrices.
158    #[must_use]
159    pub fn to_matrix(&self) -> Array2<Complex64> {
160        let mut matrix = Array2::zeros((N, N));
161        let i = Complex64::new(0.0, 1.0);
162
163        let mut idx = 0;
164
165        // Symmetric generators: (i,j) with i < j
166        for row in 0..N {
167            for col in (row + 1)..N {
168                let coeff = self.coefficients[idx];
169                matrix[[row, col]] += i * coeff;
170                matrix[[col, row]] += i * coeff;
171                idx += 1;
172            }
173        }
174
175        // Antisymmetric generators: (i,j) with i < j
176        for row in 0..N {
177            for col in (row + 1)..N {
178                let coeff = self.coefficients[idx];
179                matrix[[row, col]] += Complex64::new(-coeff, 0.0); // -coeff (real)
180                matrix[[col, row]] += Complex64::new(coeff, 0.0); // +coeff (real)
181                idx += 1;
182            }
183        }
184
185        // Diagonal generators
186        // The k-th diagonal generator (k=0..N-2) has:
187        // - First (k+1) diagonal entries = +1
188        // - (k+2)-th diagonal entry = -(k+1)
189        // - Normalized so Tr(λ²) = 2
190        for k in 0..(N - 1) {
191            let coeff = self.coefficients[idx];
192
193            // Normalization: √(2 / (k+1)(k+2))
194            let k_f = k as f64;
195            let normalization = 2.0 / ((k_f + 1.0) * (k_f + 2.0));
196            let scale = normalization.sqrt();
197
198            // First (k+1) entries: +1
199            for j in 0..=k {
200                matrix[[j, j]] += i * coeff * scale;
201            }
202            // (k+2)-th entry: -(k+1)
203            matrix[[k + 1, k + 1]] += i * coeff * scale * (-(k_f + 1.0));
204
205            idx += 1;
206        }
207
208        matrix
209    }
210
211    /// Construct algebra element from matrix
212    ///
213    /// Given X ∈ su(N), extract coefficients in Gell-Mann basis.
214    ///
215    /// # Performance
216    ///
217    /// O(N²) time via inner products with basis elements.
218    #[must_use]
219    pub fn from_matrix(matrix: &Array2<Complex64>) -> Self {
220        assert_eq!(matrix.nrows(), N);
221        assert_eq!(matrix.ncols(), N);
222
223        let mut coefficients = vec![0.0; Self::DIM];
224        let mut idx = 0;
225
226        // Extract symmetric components
227        for row in 0..N {
228            for col in (row + 1)..N {
229                // λ has 1 at (row,col) and (col,row)
230                // i·λ·a has i·a at those positions
231                // X = i·∑ aⱼ·λⱼ, so X[row,col] = i·a
232                let val = matrix[[row, col]];
233                coefficients[idx] = val.im; // Extract imaginary part
234                idx += 1;
235            }
236        }
237
238        // Extract antisymmetric components
239        for row in 0..N {
240            for col in (row + 1)..N {
241                // λ has i at (row,col) and -i at (col,row)
242                // i·λ·a = -a at (row,col), +a at (col,row)
243                let val = matrix[[row, col]];
244                coefficients[idx] = -val.re; // Extract real part, negate
245                idx += 1;
246            }
247        }
248
249        // Extract diagonal components using proper inner product
250        //
251        // The k-th diagonal generator H_k has:
252        //   - entries [[j,j]] = scale_k for j = 0..=k
253        //   - entry [[k+1, k+1]] = -(k+1) * scale_k
254        //   - normalized so Tr(H_k²) = 2
255        //
256        // To extract coefficient a_k, use: a_k = Im(Tr(X · H_k)) / 2
257        // where Tr(X · H_k) = Σ_j X[[j,j]] * H_k[[j,j]]
258        for k in 0..(N - 1) {
259            let k_f = k as f64;
260            let normalization = 2.0 / ((k_f + 1.0) * (k_f + 2.0));
261            let scale = normalization.sqrt();
262
263            // Compute inner product: Tr(X · H_k)
264            let mut trace_prod = Complex64::new(0.0, 0.0);
265
266            // Entries 0..=k contribute +scale
267            for j in 0..=k {
268                trace_prod += matrix[[j, j]] * scale;
269            }
270
271            // Entry k+1 contributes -(k+1)*scale
272            trace_prod += matrix[[k + 1, k + 1]] * (-(k_f + 1.0) * scale);
273
274            // a_k = Im(Tr(X · H_k)) / 2
275            // (The /2 comes from Tr(H_k²) = 2 normalization)
276            coefficients[idx] = trace_prod.im / 2.0;
277            idx += 1;
278        }
279
280        Self::new(coefficients)
281    }
282}
283
284impl<const N: usize> Add for SunAlgebra<N> {
285    type Output = Self;
286    fn add(self, rhs: Self) -> Self {
287        let coefficients = self
288            .coefficients
289            .iter()
290            .zip(&rhs.coefficients)
291            .map(|(a, b)| a + b)
292            .collect();
293        Self::new(coefficients)
294    }
295}
296
297impl<const N: usize> Add<&SunAlgebra<N>> for SunAlgebra<N> {
298    type Output = SunAlgebra<N>;
299    fn add(self, rhs: &SunAlgebra<N>) -> SunAlgebra<N> {
300        let coefficients = self
301            .coefficients
302            .iter()
303            .zip(&rhs.coefficients)
304            .map(|(a, b)| a + b)
305            .collect();
306        Self::new(coefficients)
307    }
308}
309
310impl<const N: usize> Add<SunAlgebra<N>> for &SunAlgebra<N> {
311    type Output = SunAlgebra<N>;
312    fn add(self, rhs: SunAlgebra<N>) -> SunAlgebra<N> {
313        let coefficients = self
314            .coefficients
315            .iter()
316            .zip(&rhs.coefficients)
317            .map(|(a, b)| a + b)
318            .collect();
319        SunAlgebra::<N>::new(coefficients)
320    }
321}
322
323impl<const N: usize> Add<&SunAlgebra<N>> for &SunAlgebra<N> {
324    type Output = SunAlgebra<N>;
325    fn add(self, rhs: &SunAlgebra<N>) -> SunAlgebra<N> {
326        let coefficients = self
327            .coefficients
328            .iter()
329            .zip(&rhs.coefficients)
330            .map(|(a, b)| a + b)
331            .collect();
332        SunAlgebra::<N>::new(coefficients)
333    }
334}
335
336impl<const N: usize> Sub for SunAlgebra<N> {
337    type Output = Self;
338    fn sub(self, rhs: Self) -> Self {
339        let coefficients = self
340            .coefficients
341            .iter()
342            .zip(&rhs.coefficients)
343            .map(|(a, b)| a - b)
344            .collect();
345        Self::new(coefficients)
346    }
347}
348
349impl<const N: usize> Neg for SunAlgebra<N> {
350    type Output = Self;
351    fn neg(self) -> Self {
352        let coefficients = self.coefficients.iter().map(|x| -x).collect();
353        Self::new(coefficients)
354    }
355}
356
357impl<const N: usize> Mul<f64> for SunAlgebra<N> {
358    type Output = Self;
359    fn mul(self, scalar: f64) -> Self {
360        let coefficients = self.coefficients.iter().map(|x| x * scalar).collect();
361        Self::new(coefficients)
362    }
363}
364
365impl<const N: usize> Mul<SunAlgebra<N>> for f64 {
366    type Output = SunAlgebra<N>;
367    fn mul(self, rhs: SunAlgebra<N>) -> SunAlgebra<N> {
368        rhs * self
369    }
370}
371
372impl<const N: usize> LieAlgebra for SunAlgebra<N> {
373    // Model-theoretic guard: SU(N) requires N ≥ 2.
374    // SU(1) = {I} is trivial (algebra is zero-dimensional, bracket undefined).
375    // SU(0) underflows usize arithmetic.
376    // This const assert promotes the degenerate-model failure from runtime panic
377    // to a compile-time error, consistent with the sealed-trait philosophy.
378    const DIM: usize = {
379        assert!(
380            N >= 2,
381            "SU(N) requires N >= 2: SU(1) is trivial, SU(0) is undefined"
382        );
383        N * N - 1
384    };
385
386    fn zero() -> Self {
387        Self {
388            coefficients: vec![0.0; Self::DIM],
389            _phantom: PhantomData,
390        }
391    }
392
393    fn add(&self, other: &Self) -> Self {
394        let coefficients = self
395            .coefficients
396            .iter()
397            .zip(&other.coefficients)
398            .map(|(a, b)| a + b)
399            .collect();
400        Self::new(coefficients)
401    }
402
403    fn scale(&self, scalar: f64) -> Self {
404        let coefficients = self.coefficients.iter().map(|x| x * scalar).collect();
405        Self::new(coefficients)
406    }
407
408    fn norm(&self) -> f64 {
409        self.coefficients
410            .iter()
411            .map(|x| x.powi(2))
412            .sum::<f64>()
413            .sqrt()
414    }
415
416    fn basis_element(i: usize) -> Self {
417        assert!(
418            i < Self::DIM,
419            "Basis index {} out of range for SU({})",
420            i,
421            N
422        );
423        let mut coefficients = vec![0.0; Self::DIM];
424        coefficients[i] = 1.0;
425        Self::new(coefficients)
426    }
427
428    fn from_components(components: &[f64]) -> Self {
429        assert_eq!(
430            components.len(),
431            Self::DIM,
432            "Expected {} components for SU({}), got {}",
433            Self::DIM,
434            N,
435            components.len()
436        );
437        Self::new(components.to_vec())
438    }
439
440    fn to_components(&self) -> Vec<f64> {
441        self.coefficients.clone()
442    }
443
444    /// Lie bracket: [X, Y] = XY - YX
445    ///
446    /// Computed via matrix commutator for generality.
447    ///
448    /// # Performance
449    ///
450    /// - Time: O(N³) [matrix multiplication]
451    /// - Space: O(N²)
452    ///
453    /// # Note
454    ///
455    /// For N=2,3, specialized implementations with structure constants
456    /// would be faster (O(1) and O(1) respectively). This generic version
457    /// prioritizes correctness and simplicity.
458    ///
459    /// # Mathematical Formula
460    ///
461    /// ```text
462    /// [X, Y] = XY - YX
463    /// ```
464    ///
465    /// This satisfies:
466    /// - Antisymmetry: `[X,Y] = -[Y,X]`
467    /// - Jacobi identity: `[X,[Y,Z]] + [Y,[Z,X]] + [Z,[X,Y]] = 0`
468    /// - Bilinearity
469    fn bracket(&self, other: &Self) -> Self {
470        let x = self.to_matrix();
471        let y = other.to_matrix();
472        let commutator = x.dot(&y) - y.dot(&x);
473        Self::from_matrix(&commutator)
474    }
475}
476
477/// SU(N) group element - N×N unitary matrix with det = 1
478///
479/// # Type Parameter
480///
481/// - `N`: Matrix dimension
482///
483/// # Representation
484///
485/// Stored as N×N complex matrix satisfying:
486/// - U†U = I (unitarity)
487/// - det(U) = 1 (special)
488///
489/// # Verification
490///
491/// Use `verify_unitarity()` to check constraints numerically.
492#[derive(Debug, Clone)]
493pub struct SUN<const N: usize> {
494    /// N×N complex unitary matrix
495    pub matrix: Array2<Complex64>,
496}
497
498impl<const N: usize> SUN<N> {
499    /// Identity element: Iₙ
500    #[must_use]
501    pub fn identity() -> Self {
502        Self {
503            matrix: Array2::eye(N),
504        }
505    }
506
507    /// Verify unitarity: ||U†U - I|| < ε
508    ///
509    /// # Arguments
510    ///
511    /// - `tolerance`: Maximum Frobenius norm deviation
512    ///
513    /// # Returns
514    ///
515    /// `true` if U†U ≈ I within tolerance
516    #[must_use]
517    pub fn verify_unitarity(&self, tolerance: f64) -> bool {
518        let adjoint = self.matrix.t().mapv(|z| z.conj());
519        let product = adjoint.dot(&self.matrix);
520        let identity: Array2<Complex64> = Array2::eye(N);
521
522        let diff = &product - &identity;
523        let norm_sq: f64 = diff.iter().map(num_complex::Complex::norm_sqr).sum();
524
525        norm_sq.sqrt() < tolerance
526    }
527
528    /// Compute determinant
529    ///
530    /// For SU(N), the determinant should be exactly 1 by definition.
531    ///
532    /// # Implementation
533    ///
534    /// - **N=2**: Direct formula `ad - bc`
535    /// - **N=3**: Sarrus' rule / cofactor expansion
536    /// - **N>3**: Returns `1.0` (assumes matrix is on SU(N) manifold)
537    ///
538    /// # Limitations
539    ///
540    /// For N > 3, this function **does not compute the actual determinant**.
541    /// It returns 1.0 under the assumption that matrices constructed via
542    /// `exp()` or `reorthogonalize()` remain on the SU(N) manifold.
543    ///
544    /// To verify unitarity for N > 3, use `verify_special_unitarity()` instead,
545    /// which checks `U†U = I` (a stronger condition than det=1).
546    ///
547    /// For actual determinant computation with N > 3, enable the `ndarray-linalg`
548    /// feature (not currently available) or compute via eigenvalue product.
549    #[must_use]
550    #[allow(clippy::many_single_char_names)] // Standard math notation for matrix elements
551    pub fn determinant(&self) -> Complex64 {
552        // For small N, compute directly using Leibniz formula
553        if N == 2 {
554            let a = self.matrix[[0, 0]];
555            let b = self.matrix[[0, 1]];
556            let c = self.matrix[[1, 0]];
557            let d = self.matrix[[1, 1]];
558            return a * d - b * c;
559        }
560
561        if N == 3 {
562            // 3x3 determinant via Sarrus' rule / cofactor expansion
563            let (a, b, c, d, e, f, g, h, i) = {
564                let m = &self.matrix;
565                (
566                    m[[0, 0]],
567                    m[[0, 1]],
568                    m[[0, 2]],
569                    m[[1, 0]],
570                    m[[1, 1]],
571                    m[[1, 2]],
572                    m[[2, 0]],
573                    m[[2, 1]],
574                    m[[2, 2]],
575                )
576            };
577
578            // det = a(ei - fh) - b(di - fg) + c(dh - eg)
579            return a * (e * i - f * h) - b * (d * i - f * g) + c * (d * h - e * g);
580        }
581
582        // For N > 3: LU decomposition would be ideal, but requires ndarray-linalg
583        // For now, return 1.0 since matrices constructed via exp() preserve det=1
584        // This is valid for elements on the SU(N) manifold
585        Complex64::new(1.0, 0.0)
586    }
587
588    /// Gram-Schmidt reorthogonalization for SU(N) matrices
589    ///
590    /// Projects a potentially corrupted matrix back onto the SU(N) manifold
591    /// using Gram-Schmidt orthogonalization followed by determinant correction.
592    ///
593    /// # Algorithm
594    ///
595    /// 1. Orthogonalize columns using Modified Gram-Schmidt (MGS)
596    /// 2. Normalize to ensure unitarity
597    /// 3. Adjust phase to ensure det(U) = 1
598    ///
599    /// This avoids the log-exp round-trip that would cause infinite recursion
600    /// when called from within `exp()`.
601    ///
602    /// # Numerical Stability
603    ///
604    /// Uses Modified Gram-Schmidt (not Classical GS) for better numerical stability.
605    /// Projections are computed against already-orthonormalized vectors and
606    /// subtracted immediately, providing O(ε) backward error vs O(κε) for CGS.
607    ///
608    /// Reference: Björck, "Numerical Methods for Least Squares Problems" (1996)
609    #[must_use]
610    fn gram_schmidt_project(matrix: Array2<Complex64>) -> Array2<Complex64> {
611        let mut result: Array2<Complex64> = Array2::zeros((N, N));
612
613        // Modified Gram-Schmidt on columns
614        for j in 0..N {
615            let mut col = matrix.column(j).to_owned();
616
617            // Subtract projections onto previous columns
618            for k in 0..j {
619                let prev_col = result.column(k);
620                let proj: Complex64 = prev_col
621                    .iter()
622                    .zip(col.iter())
623                    .map(|(p, c)| p.conj() * c)
624                    .sum();
625                for i in 0..N {
626                    col[i] -= proj * prev_col[i];
627                }
628            }
629
630            // Normalize
631            let norm: f64 = col
632                .iter()
633                .map(num_complex::Complex::norm_sqr)
634                .sum::<f64>()
635                .sqrt();
636
637            // Detect linear dependence: column became zero after orthogonalization
638            debug_assert!(
639                norm > 1e-14,
640                "Gram-Schmidt: column {} is linearly dependent (norm = {:.2e}). \
641                 Input matrix is rank-deficient.",
642                j,
643                norm
644            );
645
646            if norm > 1e-14 {
647                for i in 0..N {
648                    result[[i, j]] = col[i] / norm;
649                }
650            }
651            // Note: if norm ≤ 1e-14, column remains zero → det will be ~0 → identity fallback
652        }
653
654        // For N=2 or N=3, compute determinant and fix phase
655        // For larger N, approximate (Gram-Schmidt usually produces det ≈ 1 already)
656        if N <= 3 {
657            let det = if N == 2 {
658                result[[0, 0]] * result[[1, 1]] - result[[0, 1]] * result[[1, 0]]
659            } else {
660                // N=3
661                result[[0, 0]] * (result[[1, 1]] * result[[2, 2]] - result[[1, 2]] * result[[2, 1]])
662                    - result[[0, 1]]
663                        * (result[[1, 0]] * result[[2, 2]] - result[[1, 2]] * result[[2, 0]])
664                    + result[[0, 2]]
665                        * (result[[1, 0]] * result[[2, 1]] - result[[1, 1]] * result[[2, 0]])
666            };
667
668            // Guard against zero determinant (degenerate matrix)
669            let det_norm = det.norm();
670            if det_norm < 1e-14 {
671                // Matrix is degenerate; return identity as fallback
672                return Array2::eye(N);
673            }
674
675            let det_phase = det / det_norm;
676            let correction = (det_phase.conj()).powf(1.0 / N as f64);
677            result.mapv_inplace(|z| z * correction);
678        }
679
680        result
681    }
682
683    /// Distance from identity: ||U - I||_F
684    ///
685    /// Frobenius norm of difference from identity.
686    #[must_use]
687    pub fn distance_to_identity(&self) -> f64 {
688        let identity: Array2<Complex64> = Array2::eye(N);
689        let diff = &self.matrix - &identity;
690        diff.iter()
691            .map(num_complex::Complex::norm_sqr)
692            .sum::<f64>()
693            .sqrt()
694    }
695}
696
697impl<const N: usize> fmt::Display for SunAlgebra<N> {
698    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
699        write!(f, "su({})[", N)?;
700        for (i, c) in self.coefficients.iter().enumerate() {
701            if i > 0 {
702                write!(f, ", ")?;
703            }
704            write!(f, "{:.4}", c)?;
705        }
706        write!(f, "]")
707    }
708}
709
710impl<const N: usize> fmt::Display for SUN<N> {
711    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
712        let dist = self.distance_to_identity();
713        write!(f, "SU({})(d={:.4})", N, dist)
714    }
715}
716
717/// Group multiplication: U₁ · U₂
718impl<const N: usize> Mul<&SUN<N>> for &SUN<N> {
719    type Output = SUN<N>;
720    fn mul(self, rhs: &SUN<N>) -> SUN<N> {
721        SUN {
722            matrix: self.matrix.dot(&rhs.matrix),
723        }
724    }
725}
726
727impl<const N: usize> Mul<&SUN<N>> for SUN<N> {
728    type Output = SUN<N>;
729    fn mul(self, rhs: &SUN<N>) -> SUN<N> {
730        &self * rhs
731    }
732}
733
734impl<const N: usize> MulAssign<&SUN<N>> for SUN<N> {
735    fn mul_assign(&mut self, rhs: &SUN<N>) {
736        self.matrix = self.matrix.dot(&rhs.matrix);
737    }
738}
739
740impl<const N: usize> LieGroup for SUN<N> {
741    type Algebra = SunAlgebra<N>;
742    const DIM: usize = {
743        assert!(
744            N >= 2,
745            "SU(N) requires N >= 2: SU(1) is trivial, SU(0) is undefined"
746        );
747        N
748    };
749
750    fn identity() -> Self {
751        Self::identity()
752    }
753
754    fn compose(&self, other: &Self) -> Self {
755        Self {
756            matrix: self.matrix.dot(&other.matrix),
757        }
758    }
759
760    fn inverse(&self) -> Self {
761        // For unitary matrices: U⁻¹ = U†
762        Self {
763            matrix: self.matrix.t().mapv(|z| z.conj()),
764        }
765    }
766
767    fn adjoint(&self) -> Self {
768        Self {
769            matrix: self.matrix.t().mapv(|z| z.conj()),
770        }
771    }
772
773    /// Adjoint action: `Ad_g(X)` = gXg⁻¹
774    ///
775    /// # Mathematical Formula
776    ///
777    /// For g ∈ SU(N) and X ∈ su(N):
778    /// ```text
779    /// Ad_g(X) = gXg⁻¹
780    /// ```
781    ///
782    /// # Properties
783    ///
784    /// - Group homomorphism: Ad_{gh} = `Ad_g` ∘ `Ad_h`
785    /// - Preserves bracket: `Ad_g([X,Y])` = `[Ad_g(X), Ad_g(Y)]`
786    /// - Preserves norm (SU(N) is compact)
787    fn adjoint_action(&self, algebra_element: &Self::Algebra) -> Self::Algebra {
788        let x = algebra_element.to_matrix();
789        let g_inv = self.inverse();
790
791        // Compute gXg⁻¹
792        let result = self.matrix.dot(&x).dot(&g_inv.matrix);
793
794        SunAlgebra::from_matrix(&result)
795    }
796
797    fn distance_to_identity(&self) -> f64 {
798        self.distance_to_identity()
799    }
800
801    /// Exponential map: exp: su(N) → SU(N)
802    ///
803    /// # Algorithm: Scaling-and-Squaring
804    ///
805    /// For X ∈ su(N) with ||X|| large:
806    /// 1. Scale: X' = X / 2^k such that ||X'|| ≤ 0.5
807    /// 2. Taylor: exp(X') ≈ ∑_{n=0}^{15} X'^n / n!
808    /// 3. Square: exp(X) = [exp(X')]^{2^k}
809    ///
810    /// # Properties
811    ///
812    /// - Preserves unitarity: exp(X) ∈ SU(N) for X ∈ su(N)
813    /// - Preserves det = 1 (Tr(X) = 0 ⟹ det(exp(X)) = exp(Tr(X)) = 1)
814    /// - Numerically stable for all ||X||
815    ///
816    /// # Performance
817    ///
818    /// - Time: O(N³·log(||X||))
819    /// - Space: O(N²)
820    ///
821    /// # Accuracy
822    ///
823    /// - Relative error: ~10⁻¹⁴ (double precision)
824    /// - Unitarity preserved to ~10⁻¹²
825    fn exp(tangent: &Self::Algebra) -> Self {
826        let x = tangent.to_matrix();
827        let norm = matrix_frobenius_norm(&x);
828
829        // Determine scaling factor: k such that ||X/2^k|| ≤ 0.5
830        let k = if norm > 0.5 {
831            (norm / 0.5).log2().ceil() as u32
832        } else {
833            0
834        };
835
836        // Scale down
837        let scale_factor = 2_f64.powi(-(k as i32));
838        let x_scaled = x.mapv(|z| z * scale_factor);
839
840        // Taylor series: exp(X') ≈ ∑ X'^n / n!
841        let exp_scaled = matrix_exp_taylor(&x_scaled, 15);
842
843        // Square k times: exp(X) = [exp(X/2^k)]^{2^k}
844        //
845        // Tao priority: Reorthogonalize periodically to prevent numerical drift
846        // from the SU(N) manifold during repeated squaring operations
847        let mut result = exp_scaled;
848        for i in 0..k {
849            result = result.dot(&result);
850
851            // Reorthogonalize every 4 squarings to maintain manifold constraints
852            // (Prevents accumulation of floating-point errors)
853            if (i + 1) % 4 == 0 && i + 1 < k {
854                result = Self::gram_schmidt_project(result);
855            }
856        }
857
858        // Final reorthogonalization to ensure result is exactly on manifold
859        Self {
860            matrix: Self::gram_schmidt_project(result),
861        }
862    }
863
864    fn log(&self) -> crate::error::LogResult<Self::Algebra> {
865        use crate::error::LogError;
866
867        // Matrix logarithm for SU(N) using inverse scaling-squaring algorithm.
868        //
869        // Algorithm (Higham, "Functions of Matrices", Ch. 11):
870        // 1. Take square roots until ||U^{1/2^k} - I|| < 0.5
871        // 2. Use Taylor series for log(I + X) with ||X|| < 0.5 (fast convergence)
872        // 3. Scale back: log(U) = 2^k × log(U^{1/2^k})
873
874        let dist = self.distance_to_identity();
875        const MAX_DISTANCE: f64 = 2.0;
876
877        if dist > MAX_DISTANCE {
878            return Err(LogError::NotNearIdentity {
879                distance: dist,
880                threshold: MAX_DISTANCE,
881            });
882        }
883
884        if dist < 1e-14 {
885            return Ok(SunAlgebra::zero());
886        }
887
888        // Phase 1: Inverse scaling via matrix square roots
889        let identity_matrix: Array2<Complex64> = Array2::eye(N);
890        let mut current = self.matrix.clone();
891        let mut num_sqrts: u32 = 0;
892        const MAX_SQRTS: u32 = 32;
893        const TARGET_NORM: f64 = 0.5;
894
895        while num_sqrts < MAX_SQRTS {
896            let x_matrix = &current - &identity_matrix;
897            let x_norm = matrix_frobenius_norm(&x_matrix);
898            if x_norm < TARGET_NORM {
899                break;
900            }
901            current = matrix_sqrt_db(&current);
902            num_sqrts += 1;
903        }
904
905        // Phase 2: Taylor series for log(I + X) with ||X|| < 0.5
906        let x_matrix = &current - &identity_matrix;
907        let log_matrix = matrix_log_taylor(&x_matrix, 30);
908
909        // Phase 3: Scale back: log(U) = 2^k × log(U^{1/2^k})
910        let scale_factor = (1_u64 << num_sqrts) as f64;
911        let log_scaled = log_matrix.mapv(|z| z * scale_factor);
912
913        Ok(SunAlgebra::from_matrix(&log_scaled))
914    }
915
916    fn dim() -> usize {
917        N
918    }
919
920    fn trace(&self) -> Complex64 {
921        (0..N).map(|i| self.matrix[[i, i]]).sum()
922    }
923}
924
925/// Compute Frobenius norm of complex matrix: ||A||_F = √(Tr(A†A))
926fn matrix_frobenius_norm(matrix: &Array2<Complex64>) -> f64 {
927    matrix
928        .iter()
929        .map(num_complex::Complex::norm_sqr)
930        .sum::<f64>()
931        .sqrt()
932}
933
934/// Compute matrix exponential via Taylor series
935///
936/// exp(X) = I + X + X²/2! + X³/3! + ... + X^n/n!
937///
938/// Converges rapidly for ||X|| ≤ 0.5.
939fn matrix_exp_taylor(matrix: &Array2<Complex64>, terms: usize) -> Array2<Complex64> {
940    let n = matrix.nrows();
941    let mut result = Array2::eye(n); // I
942    let mut term = Array2::eye(n); // Current term: X^k / k!
943
944    for k in 1..=terms {
945        // term = term · X / k
946        term = term.dot(matrix).mapv(|z| z / (k as f64));
947        result += &term;
948    }
949
950    result
951}
952
953/// Compute matrix logarithm via Taylor series
954///
955/// log(I + X) = X - X²/2 + X³/3 - X⁴/4 + ... + (-1)^{n+1}·X^n/n
956///
957/// Converges for spectral radius ρ(X) < 1. For ||X||_F < 0.5, convergence
958/// is rapid (30 terms gives ~3e-11 truncation error).
959fn matrix_log_taylor(matrix: &Array2<Complex64>, terms: usize) -> Array2<Complex64> {
960    let mut result = matrix.clone(); // First term: X
961    let mut x_power = matrix.clone(); // Current power: X^k
962
963    for k in 2..=terms {
964        // x_power = X^k
965        x_power = x_power.dot(matrix);
966
967        // Coefficient: (-1)^{k+1} / k
968        let sign = if k % 2 == 0 { -1.0 } else { 1.0 };
969        let coefficient = sign / (k as f64);
970
971        // Add term to result
972        result = result + x_power.mapv(|z| z * coefficient);
973    }
974
975    result
976}
977
978/// Complex N×N matrix inverse via Gauss-Jordan elimination with partial pivoting.
979///
980/// Returns None if the matrix is singular (pivot < 1e-15).
981fn matrix_inverse(a: &Array2<Complex64>) -> Option<Array2<Complex64>> {
982    let n = a.nrows();
983    assert_eq!(n, a.ncols());
984
985    // Build augmented matrix [A | I]
986    let mut aug = Array2::<Complex64>::zeros((n, 2 * n));
987    for i in 0..n {
988        for j in 0..n {
989            aug[[i, j]] = a[[i, j]];
990        }
991        aug[[i, n + i]] = Complex64::new(1.0, 0.0);
992    }
993
994    for col in 0..n {
995        // Partial pivoting: find row with largest magnitude in this column
996        let mut max_norm = 0.0;
997        let mut max_row = col;
998        for row in col..n {
999            let norm = aug[[row, col]].norm();
1000            if norm > max_norm {
1001                max_norm = norm;
1002                max_row = row;
1003            }
1004        }
1005        if max_norm < 1e-15 {
1006            return None; // Singular
1007        }
1008
1009        // Swap rows
1010        if max_row != col {
1011            for j in 0..2 * n {
1012                let tmp = aug[[col, j]];
1013                aug[[col, j]] = aug[[max_row, j]];
1014                aug[[max_row, j]] = tmp;
1015            }
1016        }
1017
1018        // Scale pivot row
1019        let pivot = aug[[col, col]];
1020        for j in 0..2 * n {
1021            aug[[col, j]] /= pivot;
1022        }
1023
1024        // Eliminate column in all other rows
1025        for row in 0..n {
1026            if row != col {
1027                let factor = aug[[row, col]];
1028                // Read pivot row values first to avoid borrow conflict
1029                let pivot_row: Vec<Complex64> = (0..2 * n).map(|j| aug[[col, j]]).collect();
1030                for j in 0..2 * n {
1031                    aug[[row, j]] -= factor * pivot_row[j];
1032                }
1033            }
1034        }
1035    }
1036
1037    // Extract inverse from right half
1038    let mut result = Array2::zeros((n, n));
1039    for i in 0..n {
1040        for j in 0..n {
1041            result[[i, j]] = aug[[i, n + j]];
1042        }
1043    }
1044    Some(result)
1045}
1046
1047/// Matrix square root via Denman-Beavers iteration.
1048///
1049/// Computes U^{1/2} for a matrix U close to identity.
1050/// Uses the iteration Y_{k+1} = (Y_k + Z_k^{-1})/2, Z_{k+1} = (Z_k + Y_k^{-1})/2
1051/// which converges quadratically to U^{1/2}.
1052fn matrix_sqrt_db(u: &Array2<Complex64>) -> Array2<Complex64> {
1053    let n = u.nrows();
1054    let mut y = u.clone();
1055    let mut z = Array2::<Complex64>::eye(n);
1056
1057    const MAX_ITERS: usize = 20;
1058    const TOL: f64 = 1e-14;
1059
1060    for _ in 0..MAX_ITERS {
1061        let y_inv = matrix_inverse(&y).unwrap_or_else(|| y.t().mapv(|z| z.conj()));
1062        let z_inv = matrix_inverse(&z).unwrap_or_else(|| z.t().mapv(|z| z.conj()));
1063
1064        let y_new = (&y + &z_inv).mapv(|z| z * 0.5);
1065        let z_new = (&z + &y_inv).mapv(|z| z * 0.5);
1066
1067        let diff = matrix_frobenius_norm(&(&y_new - &y));
1068        y = y_new;
1069        z = z_new;
1070
1071        if diff < TOL {
1072            break;
1073        }
1074    }
1075
1076    y
1077}
1078
1079// ============================================================================
1080// Const Generic Specializations
1081// ============================================================================
1082
1083// ============================================================================
1084// Algebra Marker Traits
1085// ============================================================================
1086
1087/// SU(N) is compact for all N ≥ 2.
1088impl<const N: usize> Compact for SUN<N> {}
1089
1090/// SU(N) is simple for all N ≥ 2.
1091impl<const N: usize> Simple for SUN<N> {}
1092
1093/// SU(N) is semi-simple (implied by simple) for all N ≥ 2.
1094impl<const N: usize> SemiSimple for SUN<N> {}
1095
1096/// su(N) algebra elements are traceless by construction.
1097///
1098/// The representation `SunAlgebra<N>` stores N²-1 coefficients in a
1099/// generalized Gell-Mann basis. All generators are traceless by definition.
1100impl<const N: usize> TracelessByConstruction for SunAlgebra<N> {}
1101
1102/// su(N) algebra elements are anti-Hermitian by construction.
1103///
1104/// The representation uses i·λⱼ where λⱼ are Hermitian generators.
1105impl<const N: usize> AntiHermitianByConstruction for SunAlgebra<N> {}
1106
1107// ============================================================================
1108// Type Aliases
1109// ============================================================================
1110
1111/// Type alias for SU(2) via generic implementation
1112pub type SU2Generic = SUN<2>;
1113/// Type alias for SU(3) via generic implementation
1114pub type SU3Generic = SUN<3>;
1115/// Type alias for SU(4) - Pati-Salam model
1116pub type SU4 = SUN<4>;
1117/// Type alias for SU(5) - Georgi-Glashow GUT
1118pub type SU5 = SUN<5>;
1119
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Euclidean distance between the coefficient vectors of two algebra
    /// elements, used by the exp/log roundtrip tests.
    fn coefficient_distance<const N: usize>(lhs: &SunAlgebra<N>, rhs: &SunAlgebra<N>) -> f64 {
        lhs.coefficients
            .iter()
            .zip(rhs.coefficients.iter())
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    /// dim(su(N)) must equal N² - 1.
    #[test]
    fn test_sun_algebra_dimensions() {
        assert_eq!(SunAlgebra::<2>::dim(), 2 * 2 - 1); // SU(2): 3
        assert_eq!(SunAlgebra::<3>::dim(), 3 * 3 - 1); // SU(3): 8
        assert_eq!(SunAlgebra::<4>::dim(), 4 * 4 - 1); // SU(4): 15
        assert_eq!(SunAlgebra::<5>::dim(), 5 * 5 - 1); // SU(5): 24
    }

    /// The zero element of su(3) has eight coefficients, all exactly zero.
    #[test]
    fn test_sun_algebra_zero() {
        let origin = SunAlgebra::<3>::zero();
        assert_eq!(origin.coefficients.len(), 8);
        assert!(origin.coefficients.iter().all(|&c| c == 0.0));
    }

    /// Addition and scaling act component-wise on the coefficients.
    #[test]
    fn test_sun_algebra_add_scale() {
        let e0 = SunAlgebra::<2>::basis_element(0);
        let e1 = SunAlgebra::<2>::basis_element(1);

        let total = &e0 + &e1;
        assert_eq!(total.coefficients, vec![1.0, 1.0, 0.0]);

        let stretched = e0.scale(2.5);
        assert_eq!(stretched.coefficients, vec![2.5, 0.0, 0.0]);
    }

    /// The identity element is unitary and at distance zero from itself.
    #[test]
    fn test_sun_identity() {
        let identity = SUN::<3>::identity();
        assert!(identity.verify_unitarity(1e-10));
        assert_relative_eq!(identity.distance_to_identity(), 0.0, epsilon = 1e-10);
    }

    /// exp of an arbitrary su(3) element must satisfy U†U = I.
    #[test]
    fn test_sun_exponential_preserves_unitarity() {
        let generator =
            SunAlgebra::<3>::from_components(&[0.5, -0.3, 0.8, 0.2, -0.6, 0.4, 0.1, -0.2]);
        let element = SUN::<3>::exp(&generator);

        assert!(
            element.verify_unitarity(1e-10),
            "Exponential should preserve unitarity"
        );
    }

    /// exp(0) is the identity.
    #[test]
    fn test_sun_exp_identity() {
        let origin = SunAlgebra::<4>::zero();
        let element = SUN::<4>::exp(&origin);
        assert_relative_eq!(element.distance_to_identity(), 0.0, epsilon = 1e-10);
    }

    /// The product of two group elements is again unitary.
    #[test]
    fn test_sun_group_composition() {
        let first = SUN::<2>::exp(&SunAlgebra::<2>::basis_element(0).scale(0.5));
        let second = SUN::<2>::exp(&SunAlgebra::<2>::basis_element(1).scale(0.3));

        let composed = first.compose(&second);

        assert!(composed.verify_unitarity(1e-10));
    }

    /// g · g⁻¹ is the identity.
    #[test]
    fn test_sun_inverse() {
        let generator =
            SunAlgebra::<3>::from_components(&[0.2, 0.3, -0.1, 0.5, -0.2, 0.1, 0.4, -0.3]);
        let element = SUN::<3>::exp(&generator);
        let element_inv = element.inverse();

        let should_be_identity = element.compose(&element_inv);

        assert_relative_eq!(
            should_be_identity.distance_to_identity(),
            0.0,
            epsilon = 1e-9
        );
    }

    /// The adjoint action preserves the norm for compact groups.
    #[test]
    fn test_sun_adjoint_action_preserves_norm() {
        let element = SUN::<3>::exp(&SunAlgebra::<3>::basis_element(0).scale(1.2));
        let vector = SunAlgebra::<3>::basis_element(2).scale(0.5);

        let transported = element.adjoint_action(&vector);

        assert_relative_eq!(vector.norm(), transported.norm(), epsilon = 1e-9);
    }

    /// SU(3): exp followed by log recovers the original algebra element.
    #[test]
    fn test_sun_exp_log_roundtrip() {
        let x = SunAlgebra::<3>::from_components(&[0.1, -0.2, 0.15, 0.08, -0.12, 0.05, 0.1, -0.06]);
        let g = SUN::<3>::exp(&x);
        assert!(g.verify_unitarity(1e-10));

        let x_back = SUN::<3>::log(&g).expect("log should succeed near identity");
        let diff_norm = coefficient_distance(&x, &x_back);

        assert!(
            diff_norm < 1e-8,
            "SU(3) exp/log roundtrip error: {:.2e}",
            diff_norm
        );
    }

    /// SU(2) via the generic implementation: same roundtrip on a smaller algebra.
    #[test]
    fn test_sun_exp_log_roundtrip_su2() {
        let x = SunAlgebra::<2>::from_components(&[0.3, -0.2, 0.4]);
        let g = SUN::<2>::exp(&x);
        assert!(g.verify_unitarity(1e-10));

        let x_back = SUN::<2>::log(&g).expect("log should succeed");
        let diff_norm = coefficient_distance(&x, &x_back);

        assert!(
            diff_norm < 1e-8,
            "SU(2) exp/log roundtrip error: {:.2e}",
            diff_norm
        );
    }

    /// Starting from a group element, log followed by exp returns to it.
    #[test]
    fn test_sun_log_exp_roundtrip() {
        let seed = SunAlgebra::<3>::from_components(&[0.2, 0.3, -0.1, 0.5, -0.2, 0.1, 0.4, -0.3]);
        let g = SUN::<3>::exp(&seed);

        let log_g = SUN::<3>::log(&g).expect("log should succeed");
        let g_back = SUN::<3>::exp(&log_g);

        assert_relative_eq!(
            g.distance_to_identity(),
            g_back.distance_to_identity(),
            epsilon = 1e-8
        );

        // g · g_back⁻¹ should be the identity if the two elements agree.
        let quotient = g.compose(&g_back.inverse());
        assert_relative_eq!(quotient.distance_to_identity(), 0.0, epsilon = 1e-8);
    }

    /// [X, [Y, Z]] + [Y, [Z, X]] + [Z, [X, Y]] = 0 on three basis elements.
    #[test]
    fn test_sun_jacobi_identity() {
        let x = SunAlgebra::<3>::basis_element(0);
        let y = SunAlgebra::<3>::basis_element(1);
        let z = SunAlgebra::<3>::basis_element(2);

        let x_yz = x.bracket(&y.bracket(&z));
        let y_zx = y.bracket(&z.bracket(&x));
        let z_xy = z.bracket(&x.bracket(&y));
        let sum = x_yz.add(&y_zx).add(&z_xy);

        assert!(
            sum.norm() < 1e-10,
            "Jacobi identity violated for SU(3): ||sum|| = {:.2e}",
            sum.norm()
        );
    }

    /// [X, Y] = -[Y, X], checked component by component in su(4).
    #[test]
    fn test_sun_bracket_antisymmetry() {
        let x = SunAlgebra::<4>::basis_element(0);
        let y = SunAlgebra::<4>::basis_element(3);

        let forward = x.bracket(&y);
        let backward = y.bracket(&x);

        for i in 0..SunAlgebra::<4>::dim() {
            assert_relative_eq!(
                forward.coefficients[i],
                -backward.coefficients[i],
                epsilon = 1e-10
            );
        }
    }

    /// The bracket is linear in each argument.
    #[test]
    fn test_sun_bracket_bilinearity() {
        let x = SunAlgebra::<3>::basis_element(0);
        let y = SunAlgebra::<3>::basis_element(3);
        let z = SunAlgebra::<3>::basis_element(5);
        let alpha = 2.5;

        // Left linearity: [αX + Y, Z] = α[X, Z] + [Y, Z]
        let left_lhs = x.scale(alpha).add(&y).bracket(&z);
        let left_rhs = x.bracket(&z).scale(alpha).add(&y.bracket(&z));
        for i in 0..SunAlgebra::<3>::dim() {
            assert!(
                (left_lhs.coefficients[i] - left_rhs.coefficients[i]).abs() < 1e-14,
                "Left linearity failed at component {}: {} vs {}",
                i,
                left_lhs.coefficients[i],
                left_rhs.coefficients[i]
            );
        }

        // Right linearity: [Z, αX + Y] = α[Z, X] + [Z, Y]
        let right_lhs = z.bracket(&x.scale(alpha).add(&y));
        let right_rhs = z.bracket(&x).scale(alpha).add(&z.bracket(&y));
        for i in 0..SunAlgebra::<3>::dim() {
            assert!(
                (right_lhs.coefficients[i] - right_rhs.coefficients[i]).abs() < 1e-14,
                "Right linearity failed at component {}: {} vs {}",
                i,
                right_lhs.coefficients[i],
                right_rhs.coefficients[i]
            );
        }
    }
}