Skip to main content

numra_ode/
index_reduction.rs

1//! DAE index analysis and automatic index reduction.
2//!
3//! This module implements the Pantelides algorithm for structural index analysis
4//! of Differential-Algebraic Equation (DAE) systems, and automatic index reduction
5//! from higher-index (index-2, index-3) to index-1 via constraint differentiation.
6//!
7//! # Background
8//!
9//! A DAE in semi-explicit form is:
10//! ```text
11//! y' = f(t, y, z)     (differential equations)
12//! 0  = g(t, y, z)     (algebraic constraints)
13//! ```
14//!
15//! The **differential index** measures how many times the algebraic constraints
16//! must be differentiated before the system can be solved as an ODE. Index-1 DAEs
17//! can be solved directly by existing BDF/Radau solvers. Higher-index systems
18//! need reduction first.
19//!
20//! # Index Reduction Strategy
21//!
22//! 1. **Structural analysis** via the Pantelides algorithm determines which
23//!    equations need differentiation and how many times.
24//! 2. **Constraint differentiation** generates the time derivatives of algebraic
25//!    constraints, introducing new variables for the derivatives.
26//! 3. The result is an augmented index-1 system solvable by standard methods.
27//!
28//! # Example: Simple Pendulum (Index-3 → Index-1)
29//!
30//! ```text
31//! Original (index-3):
32//!   x'' = -λx           y'' = -λy - g
33//!   x² + y² = L²
34//!
35//! After reduction:
36//!   x' = vx             y' = vy
37//!   vx' = -λx           vy' = -λy - g
38//!   0 = x*vx + y*vy                      (differentiated once: velocity constraint)
//!   0 = vx² + vy² + x*(-λx) + y*(-λy - g)   (differentiated twice: acceleration constraint)
//!   (stopping after the first differentiation gives the index-2 velocity form instead)
40//! ```
41//!
42//! Author: Moussa Leblouba
43//! Date: 9 February 2026
44//! Modified: 2 May 2026
45
46use crate::OdeSystem;
47use numra_core::Scalar;
48
49// ============================================================================
50// Structural Analysis Types
51// ============================================================================
52
53/// Result of structural DAE index analysis.
54#[derive(Clone, Debug)]
55pub struct DaeIndexInfo {
56    /// Structural (Pantelides) index of the DAE system.
57    /// - 0: pure ODE (no algebraic equations)
58    /// - 1: index-1 DAE (algebraic constraints directly solvable)
59    /// - 2: index-2 DAE (constraints need 1 differentiation)
60    /// - 3: index-3 DAE (constraints need 2 differentiations)
61    pub structural_index: usize,
62
63    /// Number of hidden constraints discovered by the algorithm.
64    pub n_hidden_constraints: usize,
65
66    /// Schedule of differentiations: (equation_index, n_times_to_differentiate).
67    ///
68    /// Each entry says "equation `equation_index` must be differentiated
69    /// `n_times_to_differentiate` times to reduce the index."
70    pub differentiation_schedule: Vec<(usize, usize)>,
71
72    /// Assignment from the Pantelides algorithm: equation `i` is matched to variable `assign[i]`.
73    /// Only meaningful after a successful structural analysis.
74    pub assignment: Vec<Option<usize>>,
75
76    /// Number of differential variables in the original system.
77    pub n_diff: usize,
78
79    /// Number of algebraic variables in the original system.
80    pub n_alg: usize,
81}
82
83/// Incidence structure for a DAE system.
84///
85/// Describes which variables appear in which equations, partitioned into
86/// differential and algebraic variables.
87#[derive(Clone, Debug)]
88pub struct DaeStructure {
89    /// Number of differential variables (y)
90    pub n_diff: usize,
91    /// Number of algebraic variables (z)
92    pub n_alg: usize,
93    /// Number of differential equations
94    pub n_diff_eqs: usize,
95    /// Number of algebraic equations (constraints)
96    pub n_alg_eqs: usize,
97    /// Incidence: for each equation i, the set of variable indices it depends on.
98    /// Variable indices 0..n_diff are differential, n_diff..n_diff+n_alg are algebraic.
99    pub incidence: Vec<Vec<usize>>,
100}
101
102impl DaeStructure {
103    /// Total number of variables.
104    pub fn n_vars(&self) -> usize {
105        self.n_diff + self.n_alg
106    }
107
108    /// Total number of equations.
109    pub fn n_eqs(&self) -> usize {
110        self.n_diff_eqs + self.n_alg_eqs
111    }
112}
113
114// ============================================================================
115// Pantelides Algorithm (Structural Index Analysis)
116// ============================================================================
117
118/// Analyze the structural index of a DAE system.
119///
120/// Uses a constraint-based approach for semi-explicit DAEs:
121/// - Differential equations `y_i' = f_i(t, y, z)` are "pre-matched" to their variables y_i
122/// - Algebraic constraints `0 = g_j(t, y, z)` must be matched to algebraic variables z_j
123/// - If a constraint depends only on differential variables (no algebraic ones),
124///   it cannot be matched and must be differentiated → this signals index > 1
125///
126/// # Arguments
127///
128/// * `structure` - The incidence structure of the DAE
129///
130/// # Returns
131///
132/// A `DaeIndexInfo` with the structural index and differentiation schedule.
133pub fn analyze_dae_index(structure: &DaeStructure) -> DaeIndexInfo {
134    let n_eqs = structure.n_eqs();
135    let _n_vars = structure.n_vars();
136    let n_diff = structure.n_diff;
137    let n_alg = structure.n_alg;
138    let n_diff_eqs = structure.n_diff_eqs;
139    let n_alg_eqs = structure.n_alg_eqs;
140
141    // If no algebraic equations, it's a pure ODE (index 0)
142    if n_alg_eqs == 0 {
143        return DaeIndexInfo {
144            structural_index: 0,
145            n_hidden_constraints: 0,
146            differentiation_schedule: Vec::new(),
147            assignment: vec![None; n_eqs],
148            n_diff,
149            n_alg,
150        };
151    }
152
153    // For semi-explicit DAEs, the index is determined by whether the algebraic
154    // constraints can be "resolved" for the algebraic variables.
155    //
156    // The key criterion: for each algebraic constraint g_j(t, y, z) = 0,
157    // check if it depends on at least one algebraic variable z_k that is not
158    // yet "claimed" by another constraint.
159    //
160    // If a constraint depends only on differential variables (no algebraic ones),
161    // it's a "hidden constraint" that must be differentiated (index >= 2).
162
163    // Algebraic variable indices are n_diff..n_diff+n_alg
164    let alg_var_start = n_diff;
165
166    // Build bipartite matching: algebraic equations ↔ algebraic variables
167    // This is a restricted matching (only algebraic vars are matchable targets)
168    let mut alg_incidence: Vec<Vec<usize>> = Vec::new();
169    for alg_eq_idx in 0..n_alg_eqs {
170        let eq_idx = n_diff_eqs + alg_eq_idx;
171        let vars = if eq_idx < structure.incidence.len() {
172            &structure.incidence[eq_idx]
173        } else {
174            continue;
175        };
176
177        // Filter to only algebraic variables (remap to 0..n_alg)
178        let alg_vars: Vec<usize> = vars
179            .iter()
180            .filter(|&&v| v >= alg_var_start && v < alg_var_start + n_alg)
181            .map(|&v| v - alg_var_start)
182            .collect();
183        alg_incidence.push(alg_vars);
184    }
185
186    // Try to find a maximum matching in the algebraic bipartite graph
187    let mut eq_to_var: Vec<Option<usize>> = vec![None; n_alg_eqs];
188    let mut var_to_eq: Vec<Option<usize>> = vec![None; n_alg];
189
190    for eq in 0..n_alg_eqs {
191        let mut visited = vec![false; n_alg];
192        augmenting_path_restricted(
193            eq,
194            &alg_incidence,
195            &mut eq_to_var,
196            &mut var_to_eq,
197            &mut visited,
198        );
199    }
200
201    // Count unmatched algebraic equations
202    let mut diff_schedule: Vec<(usize, usize)> = Vec::new();
203    let mut max_differentiations = 0usize;
204
205    for alg_eq_idx in 0..n_alg_eqs {
206        if eq_to_var[alg_eq_idx].is_none() {
207            // This constraint could not be matched to any algebraic variable.
208            // It must be differentiated (at least once).
209            // For the structural index, each differentiation needed adds 1 to the index.
210            diff_schedule.push((n_diff_eqs + alg_eq_idx, 1));
211            max_differentiations = max_differentiations.max(1);
212        }
213    }
214
215    // If any constraints needed differentiation, check if a second round is needed.
216    // For index-3: differentiated constraints might also fail to match.
217    // We do a simple iterative check: after differentiating, do the new constraints match?
218    if !diff_schedule.is_empty() {
219        // After differentiating, the new constraints depend on y' variables
220        // (which are "known" from the differential equations) plus potentially
221        // the algebraic variables. If the original constraint depended on NO
222        // algebraic variables, its derivative likely depends on y' (matched by
223        // diff eqs) but may also introduce new dependencies.
224        //
225        // For a simple index-2 detection, one round of differentiation suffices.
226        // For index-3, we'd need to check the differentiated constraints too.
227        // We handle up to index-3 with a second pass.
228
229        // Augment: for each differentiated constraint, create a new "virtual" equation
230        // that depends on the original differential variables' derivatives (new vars)
231        // plus the original algebraic variables.
232        let n_new_eqs = diff_schedule.len();
233        let mut aug_alg_incidence: Vec<Vec<usize>> = alg_incidence.clone();
234
235        // New "algebraic variables" for derivatives of diff vars appearing in constraints
236        let mut n_aug_alg = n_alg;
237        for &(orig_eq_idx, _) in &diff_schedule {
238            // The differentiated constraint introduces dependencies on y'_j for each
239            // differential variable y_j in the original constraint.
240            // These y'_j are "known" from the differential equations, so the
241            // differentiated constraint effectively gains new algebraic dependencies
242            // (derivative vars become new algebraic vars in the augmented system).
243            let orig_vars = if orig_eq_idx < structure.incidence.len() {
244                &structure.incidence[orig_eq_idx]
245            } else {
246                continue;
247            };
248
249            let mut new_alg_vars: Vec<usize> = Vec::new();
250            for &v in orig_vars {
251                if v < n_diff {
252                    // y'_v is a new variable
253                    new_alg_vars.push(n_aug_alg);
254                    n_aug_alg += 1;
255                }
256            }
257
258            // Original algebraic dependencies still present
259            let orig_alg: Vec<usize> = orig_vars
260                .iter()
261                .filter(|&&v| v >= alg_var_start && v < alg_var_start + n_alg)
262                .map(|&v| v - alg_var_start)
263                .collect();
264
265            let mut combined = orig_alg;
266            combined.extend(new_alg_vars);
267            aug_alg_incidence.push(combined);
268        }
269
270        // Re-run matching on augmented system
271        let total_alg_eqs = n_alg_eqs + n_new_eqs;
272        let mut eq_to_var2: Vec<Option<usize>> = vec![None; total_alg_eqs];
273        let mut var_to_eq2: Vec<Option<usize>> = vec![None; n_aug_alg];
274
275        for eq in 0..total_alg_eqs {
276            let mut visited = vec![false; n_aug_alg];
277            augmenting_path_restricted(
278                eq,
279                &aug_alg_incidence,
280                &mut eq_to_var2,
281                &mut var_to_eq2,
282                &mut visited,
283            );
284        }
285
286        // Check if any of the NEW equations still couldn't match
287        let mut second_round_unmatched = 0;
288        for new_eq in n_alg_eqs..total_alg_eqs {
289            if eq_to_var2[new_eq].is_none() {
290                second_round_unmatched += 1;
291            }
292        }
293
294        if second_round_unmatched > 0 {
295            // Need second differentiation → index 3
296            for entry in &mut diff_schedule {
297                entry.1 += 1; // Each needs one more differentiation
298            }
299            max_differentiations += 1;
300        }
301    }
302
303    let structural_index = if diff_schedule.is_empty() {
304        1 // Index-1: all constraints matched to algebraic variables
305    } else {
306        max_differentiations + 1 // Index = 1 + max differentiations needed
307    };
308
309    let n_hidden = diff_schedule.iter().map(|&(_, n)| n).sum::<usize>();
310
311    // Build full assignment
312    let mut assignment: Vec<Option<usize>> = vec![None; n_eqs];
313    // Differential equations matched to their own variables
314    for i in 0..n_diff_eqs.min(n_diff) {
315        assignment[i] = Some(i);
316    }
317    // Algebraic equations matched to algebraic variables
318    for (alg_eq_idx, &matched_var) in eq_to_var.iter().enumerate() {
319        if let Some(v) = matched_var {
320            assignment[n_diff_eqs + alg_eq_idx] = Some(alg_var_start + v);
321        }
322    }
323
324    DaeIndexInfo {
325        structural_index,
326        n_hidden_constraints: n_hidden,
327        differentiation_schedule: diff_schedule,
328        assignment,
329        n_diff,
330        n_alg,
331    }
332}
333
334/// Find an augmenting path in the restricted bipartite graph (algebraic eqs ↔ algebraic vars).
335fn augmenting_path_restricted(
336    eq: usize,
337    incidence: &[Vec<usize>],
338    eq_to_var: &mut [Option<usize>],
339    var_to_eq: &mut [Option<usize>],
340    visited: &mut [bool],
341) -> bool {
342    if eq >= incidence.len() {
343        return false;
344    }
345
346    for &var in &incidence[eq] {
347        if var >= visited.len() || visited[var] {
348            continue;
349        }
350        visited[var] = true;
351
352        let can_reassign = match var_to_eq[var] {
353            None => true,
354            Some(other_eq) => {
355                augmenting_path_restricted(other_eq, incidence, eq_to_var, var_to_eq, visited)
356            }
357        };
358
359        if can_reassign {
360            eq_to_var[eq] = Some(var);
361            var_to_eq[var] = Some(eq);
362            return true;
363        }
364    }
365
366    false
367}
368
369// ============================================================================
370// Automatic Structure Detection
371// ============================================================================
372
373/// Detect the incidence structure of a DAE system automatically using finite differences.
374///
375/// Probes the RHS function to determine which variables appear in which equations.
376///
377/// # Arguments
378///
379/// * `system` - The DAE system (must report `is_singular_mass()` and `algebraic_indices()`)
380/// * `t0` - Time point for probing
381/// * `y0` - State point for probing
382///
383/// # Returns
384///
385/// A `DaeStructure` describing the equation-variable dependencies.
386pub fn detect_structure<S, Sys>(system: &Sys, t0: S, y0: &[S]) -> DaeStructure
387where
388    S: Scalar,
389    Sys: OdeSystem<S>,
390{
391    let n = system.dim();
392    let alg_indices = system.algebraic_indices();
393    let n_alg = alg_indices.len();
394    let n_diff = n - n_alg;
395
396    // Map variable index to whether it's algebraic
397    let is_algebraic: Vec<bool> = (0..n).map(|i| alg_indices.contains(&i)).collect();
398
399    // Differential equation indices
400    let diff_eq_indices: Vec<usize> = (0..n).filter(|i| !is_algebraic[*i]).collect();
401    let alg_eq_indices: Vec<usize> = (0..n).filter(|i| is_algebraic[*i]).collect();
402
403    let n_diff_eqs = diff_eq_indices.len();
404    let n_alg_eqs = alg_eq_indices.len();
405
406    // Probe RHS at y0
407    let eps = S::from_f64(1e-7);
408    let mut f0 = vec![S::ZERO; n];
409    system.rhs(t0, y0, &mut f0);
410
411    // For each equation, determine which variables it depends on
412    let mut incidence = vec![Vec::new(); n_diff_eqs + n_alg_eqs];
413    let mut y_pert = y0.to_vec();
414
415    for j in 0..n {
416        let yj_save = y_pert[j];
417        let h = eps * (S::ONE + yj_save.abs());
418        y_pert[j] = yj_save + h;
419
420        let mut f1 = vec![S::ZERO; n];
421        system.rhs(t0, &y_pert, &mut f1);
422        y_pert[j] = yj_save;
423
424        // Check which equations changed
425        let threshold = S::from_f64(1e-12);
426
427        // Check differential equations
428        for (local_idx, &eq_idx) in diff_eq_indices.iter().enumerate() {
429            let diff = (f1[eq_idx] - f0[eq_idx]).abs();
430            if diff > threshold * h {
431                incidence[local_idx].push(j);
432            }
433        }
434
435        // Check algebraic equations
436        for (local_idx, &eq_idx) in alg_eq_indices.iter().enumerate() {
437            let diff = (f1[eq_idx] - f0[eq_idx]).abs();
438            if diff > threshold * h {
439                incidence[n_diff_eqs + local_idx].push(j);
440            }
441        }
442    }
443
444    DaeStructure {
445        n_diff,
446        n_alg,
447        n_diff_eqs,
448        n_alg_eqs,
449        incidence,
450    }
451}
452
453// ============================================================================
454// Index Reduction via Constraint Differentiation
455// ============================================================================
456
457/// Type alias for boxed RHS function.
458type RhsFn<S> = Box<dyn Fn(S, &[S], &mut [S]) + Send + Sync>;
459/// Type alias for boxed mass matrix function.
460type MassFn<S> = Box<dyn Fn(&mut [S]) + Send + Sync>;
461
462/// A reduced (index-1) DAE system produced by differentiating constraints.
463///
464/// The reduced system augments the original with:
465/// - New variables for the time derivatives of algebraic variables
466/// - Differentiated constraint equations
467///
468/// It implements `OdeSystem<S>` so it can be used directly with BDF/Radau solvers.
469impl<S: Scalar> core::fmt::Debug for ReducedDaeSystem<S> {
470    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
471        f.debug_struct("ReducedDaeSystem")
472            .field("n_orig", &self.n_orig)
473            .field("n_diff", &self.n_diff)
474            .field("aug_dim", &self.aug_dim)
475            .field("n_new_vars", &self.n_new_vars)
476            .field("info", &self.info)
477            .finish()
478    }
479}
480
481pub struct ReducedDaeSystem<S: Scalar> {
482    /// Original system dimension
483    n_orig: usize,
484    /// Number of original differential equations
485    n_diff: usize,
486    /// Indices of algebraic equations in the original system
487    alg_eq_indices: Vec<usize>,
488    /// Indices of differential equations in the original system
489    diff_eq_indices: Vec<usize>,
490    /// The original RHS function (boxed for type erasure)
491    rhs_fn: RhsFn<S>,
492    /// The original mass matrix function (fills row-major)
493    mass_fn: MassFn<S>,
494    /// Differentiation info
495    info: DaeIndexInfo,
496    /// Augmented dimension (original + new derivative variables)
497    aug_dim: usize,
498    /// Number of new variables added by differentiation
499    n_new_vars: usize,
500    /// Finite difference epsilon for constraint differentiation
501    fd_eps: S,
502}
503
504impl<S: Scalar> ReducedDaeSystem<S> {
505    /// Get the differentiation info.
506    pub fn info(&self) -> &DaeIndexInfo {
507        &self.info
508    }
509
510    /// Get the number of differential variables in the original system.
511    pub fn n_diff(&self) -> usize {
512        self.n_diff
513    }
514
515    /// Get the augmented system dimension.
516    pub fn augmented_dim(&self) -> usize {
517        self.aug_dim
518    }
519
520    /// Get the original system dimension.
521    pub fn original_dim(&self) -> usize {
522        self.n_orig
523    }
524
525    /// Extract the original state variables from an augmented state vector.
526    pub fn extract_original(&self, y_aug: &[S]) -> Vec<S> {
527        y_aug[..self.n_orig].to_vec()
528    }
529
530    /// Build augmented initial conditions.
531    ///
532    /// The new variables (time derivatives of algebraic variables) are initialized
533    /// by evaluating the constraints' time derivatives at the initial point.
534    pub fn augment_initial_conditions(&self, t0: S, y0: &[S]) -> Vec<S> {
535        assert_eq!(y0.len(), self.n_orig, "y0 must have original dimension");
536
537        let mut y_aug = vec![S::ZERO; self.aug_dim];
538        // Copy original variables
539        y_aug[..self.n_orig].copy_from_slice(y0);
540
541        // Estimate initial values for new derivative variables using FD
542        // For each algebraic variable z_i, the new variable is dz_i/dt
543        // We approximate this from the original RHS
544        let mut f0 = vec![S::ZERO; self.n_orig];
545        (self.rhs_fn)(t0, y0, &mut f0);
546
547        // The new variables correspond to the time derivatives of algebraic variables
548        // For index-2 reduction: new var = dz/dt, estimated from the RHS
549        for (k, &alg_eq) in self.alg_eq_indices.iter().enumerate() {
550            if self.n_orig + k < self.aug_dim {
551                // Use the residual rate of change as initial estimate
552                // For a consistent index-1 system, the algebraic residual should be ~0
553                // and its time derivative gives the new variable
554                y_aug[self.n_orig + k] = f0[alg_eq];
555            }
556        }
557
558        y_aug
559    }
560
561    /// Evaluate the time derivative of an algebraic constraint using FD.
562    ///
563    /// For constraint g(t, y) = 0, computes dg/dt = ∂g/∂t + (∂g/∂y) * y'
564    fn differentiate_constraint(&self, t: S, y: &[S], eq_idx: usize, dydt_diff: &[S]) -> S {
565        let n = self.n_orig;
566        let eps = self.fd_eps;
567
568        // Compute g(t, y)
569        let mut f0 = vec![S::ZERO; n];
570        (self.rhs_fn)(t, y, &mut f0);
571        let g0 = f0[eq_idx];
572
573        // ∂g/∂t (explicit time dependence)
574        let mut f_tp = vec![S::ZERO; n];
575        let ht = eps * (S::ONE + t.abs());
576        (self.rhs_fn)(t + ht, y, &mut f_tp);
577        let dgdt = (f_tp[eq_idx] - g0) / ht;
578
579        // ∂g/∂y_j * dy_j/dt for all variables j
580        let mut dgdy_dot = S::ZERO;
581        let mut y_pert = y.to_vec();
582
583        for j in 0..n {
584            let yj_save = y_pert[j];
585            let h = eps * (S::ONE + yj_save.abs());
586            y_pert[j] = yj_save + h;
587
588            let mut f1 = vec![S::ZERO; n];
589            (self.rhs_fn)(t, &y_pert, &mut f1);
590            y_pert[j] = yj_save;
591
592            let dgdyj = (f1[eq_idx] - g0) / h;
593            dgdy_dot = dgdy_dot + dgdyj * dydt_diff[j];
594        }
595
596        dgdt + dgdy_dot
597    }
598}
599
600impl<S: Scalar> OdeSystem<S> for ReducedDaeSystem<S> {
601    fn dim(&self) -> usize {
602        self.aug_dim
603    }
604
605    fn rhs(&self, t: S, y: &[S], dydt: &mut [S]) {
606        let n = self.n_orig;
607
608        // Evaluate original RHS for the original variables
609        let y_orig = &y[..n];
610        let mut f_orig = vec![S::ZERO; n];
611        (self.rhs_fn)(t, y_orig, &mut f_orig);
612
613        // Copy differential equation RHS
614        for &i in &self.diff_eq_indices {
615            dydt[i] = f_orig[i];
616        }
617
618        // For algebraic equations: keep the original constraint as residual
619        // (these become 0 = g(t, y) in the mass matrix form)
620        for &i in &self.alg_eq_indices {
621            dydt[i] = f_orig[i];
622        }
623
624        // For the new variables (differentiated constraints):
625        // The differentiated constraint dg/dt = 0 becomes a new algebraic equation.
626        // We need the current y' to compute dg/dt.
627        // Use the original f values as the current y' estimate.
628        for (k, &(eq_idx, n_diffs)) in self.info.differentiation_schedule.iter().enumerate() {
629            if k >= self.n_new_vars {
630                break;
631            }
632
633            // For the first differentiation: dg/dt = ∂g/∂t + Σ (∂g/∂y_j) * y_j'
634            // The differentiated constraint = 0 is a new algebraic equation
635            // whose residual we place in the augmented slot
636            let new_var_idx = n + k;
637            if new_var_idx < self.aug_dim && n_diffs >= 1 {
638                // Compute dg/dt using the current state
639                let dg = self.differentiate_constraint(t, y_orig, eq_idx, &f_orig);
640                dydt[new_var_idx] = dg;
641            }
642        }
643    }
644
645    fn has_mass_matrix(&self) -> bool {
646        true
647    }
648
649    fn mass_matrix(&self, mass: &mut [S]) {
650        let aug = self.aug_dim;
651        // Zero out
652        for i in 0..aug * aug {
653            mass[i] = S::ZERO;
654        }
655
656        // Fill original mass matrix block
657        let n = self.n_orig;
658        let mut orig_mass = vec![S::ZERO; n * n];
659        (self.mass_fn)(&mut orig_mass);
660
661        for i in 0..n {
662            for j in 0..n {
663                mass[i * aug + j] = orig_mass[i * n + j];
664            }
665        }
666
667        // New equations are algebraic (mass = 0 on their rows)
668        // They are already zero from initialization
669        // No need to set M[new_row, new_row] = 0, it's already 0
670    }
671
672    fn is_singular_mass(&self) -> bool {
673        true
674    }
675
676    fn algebraic_indices(&self) -> Vec<usize> {
677        let mut indices = Vec::new();
678
679        // Original algebraic indices
680        for &i in &self.alg_eq_indices {
681            indices.push(i);
682        }
683
684        // New differentiated constraint equations are also algebraic
685        for k in 0..self.n_new_vars {
686            indices.push(self.n_orig + k);
687        }
688
689        indices
690    }
691}
692
693// ============================================================================
694// Index Reduction Entry Points
695// ============================================================================
696
697/// Reduce a higher-index DAE to index-1 by differentiating constraints.
698///
699/// This function:
700/// 1. Detects the DAE structure via FD probing
701/// 2. Analyzes the structural index via the Pantelides algorithm
702/// 3. If index > 1, builds a reduced system with differentiated constraints
703///
704/// # Arguments
705///
706/// * `rhs_fn` - The RHS function f(t, y, dydt)
707/// * `mass_fn` - The mass matrix function M(mass_out)
708/// * `alg_indices` - Indices of algebraic equations
709/// * `n` - System dimension
710/// * `t0` - Initial time for structure probing
711/// * `y0` - Initial state for structure probing
712///
713/// # Returns
714///
715/// A `ReducedDaeSystem` if the index is > 1, or an error if reduction fails.
716pub fn reduce_index<S, F, M>(
717    rhs_fn: F,
718    mass_fn: M,
719    alg_indices: &[usize],
720    n: usize,
721    t0: S,
722    y0: &[S],
723) -> Result<ReducedDaeSystem<S>, String>
724where
725    S: Scalar,
726    F: Fn(S, &[S], &mut [S]) + Send + Sync + 'static,
727    M: Fn(&mut [S]) + Send + Sync + 'static,
728{
729    // Build structure info
730    let n_alg = alg_indices.len();
731    let n_diff = n - n_alg;
732
733    let is_algebraic: Vec<bool> = (0..n).map(|i| alg_indices.contains(&i)).collect();
734    let diff_eq_indices: Vec<usize> = (0..n).filter(|i| !is_algebraic[*i]).collect();
735    let alg_eq_indices: Vec<usize> = alg_indices.to_vec();
736
737    // Detect incidence structure
738    let structure = detect_structure_from_fn(&rhs_fn, n, &diff_eq_indices, &alg_eq_indices, t0, y0);
739
740    // Analyze index
741    let info = analyze_dae_index(&structure);
742
743    if info.structural_index <= 1 {
744        return Err("System is already index-1 or index-0; no reduction needed".to_string());
745    }
746
747    // Count new variables needed
748    let n_new_vars: usize = info
749        .differentiation_schedule
750        .iter()
751        .map(|&(_, nd)| nd)
752        .sum();
753
754    let aug_dim = n + n_new_vars;
755
756    Ok(ReducedDaeSystem {
757        n_orig: n,
758        n_diff,
759        alg_eq_indices,
760        diff_eq_indices,
761        rhs_fn: Box::new(rhs_fn),
762        mass_fn: Box::new(mass_fn),
763        info,
764        aug_dim,
765        n_new_vars,
766        fd_eps: S::from_f64(1e-7),
767    })
768}
769
770/// Detect structure from a bare function (not wrapped in OdeSystem).
771fn detect_structure_from_fn<S, F>(
772    rhs_fn: &F,
773    n: usize,
774    diff_eq_indices: &[usize],
775    alg_eq_indices: &[usize],
776    t0: S,
777    y0: &[S],
778) -> DaeStructure
779where
780    S: Scalar,
781    F: Fn(S, &[S], &mut [S]),
782{
783    let n_diff_eqs = diff_eq_indices.len();
784    let n_alg_eqs = alg_eq_indices.len();
785    let n_diff = n - n_alg_eqs;
786    let n_alg = n_alg_eqs;
787
788    let eps = S::from_f64(1e-7);
789    let mut f0 = vec![S::ZERO; n];
790    rhs_fn(t0, y0, &mut f0);
791
792    let mut incidence = vec![Vec::new(); n_diff_eqs + n_alg_eqs];
793    let mut y_pert = y0.to_vec();
794
795    for j in 0..n {
796        let yj_save = y_pert[j];
797        let h = eps * (S::ONE + yj_save.abs());
798        y_pert[j] = yj_save + h;
799
800        let mut f1 = vec![S::ZERO; n];
801        rhs_fn(t0, &y_pert, &mut f1);
802        y_pert[j] = yj_save;
803
804        let threshold = S::from_f64(1e-12);
805
806        for (local_idx, &eq_idx) in diff_eq_indices.iter().enumerate() {
807            let diff = (f1[eq_idx] - f0[eq_idx]).abs();
808            if diff > threshold * h {
809                incidence[local_idx].push(j);
810            }
811        }
812
813        for (local_idx, &eq_idx) in alg_eq_indices.iter().enumerate() {
814            let diff = (f1[eq_idx] - f0[eq_idx]).abs();
815            if diff > threshold * h {
816                incidence[n_diff_eqs + local_idx].push(j);
817            }
818        }
819    }
820
821    DaeStructure {
822        n_diff,
823        n_alg,
824        n_diff_eqs,
825        n_alg_eqs,
826        incidence,
827    }
828}
829
830/// Convenience: analyze and reduce a `DaeProblem` directly.
831///
832/// Returns a `ReducedDaeSystem` if the system is higher-index,
833/// or an `Err` if it's already index-1.
834pub fn reduce_dae_problem<S, F, M>(
835    problem: &DaeProblem<S, F, M>,
836    t0: S,
837    y0: &[S],
838) -> Result<(DaeIndexInfo, ReducedDaeSystem<S>), String>
839where
840    S: Scalar,
841    F: Fn(S, &[S], &mut [S]) + Clone + Send + Sync + 'static,
842    M: Fn(&mut [S]) + Clone + Send + Sync + 'static,
843{
844    let structure = detect_structure(problem, t0, y0);
845    let info = analyze_dae_index(&structure);
846
847    if info.structural_index <= 1 {
848        return Err(format!(
849            "System is index-{}, no reduction needed",
850            info.structural_index
851        ));
852    }
853
854    let rhs_fn = problem.f.clone();
855    let mass_fn = problem.mass.clone();
856    let alg_indices = problem.alg_indices.clone();
857    let n = problem.dim();
858
859    let n_new_vars: usize = info
860        .differentiation_schedule
861        .iter()
862        .map(|&(_, nd)| nd)
863        .sum();
864    let aug_dim = n + n_new_vars;
865
866    let is_algebraic: Vec<bool> = (0..n).map(|i| alg_indices.contains(&i)).collect();
867    let diff_eq_indices: Vec<usize> = (0..n).filter(|i| !is_algebraic[*i]).collect();
868
869    let reduced = ReducedDaeSystem {
870        n_orig: n,
871        n_diff: n - alg_indices.len(),
872        alg_eq_indices: alg_indices,
873        diff_eq_indices,
874        rhs_fn: Box::new(rhs_fn),
875        mass_fn: Box::new(mass_fn),
876        info: info.clone(),
877        aug_dim,
878        n_new_vars,
879        fd_eps: S::from_f64(1e-7),
880    };
881
882    Ok((info, reduced))
883}
884
885/// Analyze a DAE system and return its structural index without reducing.
886///
887/// Useful for diagnostic purposes.
888pub fn analyze_system<S, Sys>(system: &Sys, t0: S, y0: &[S]) -> DaeIndexInfo
889where
890    S: Scalar,
891    Sys: OdeSystem<S>,
892{
893    let structure = detect_structure(system, t0, y0);
894    analyze_dae_index(&structure)
895}
896
897use crate::DaeProblem;
898
899// ============================================================================
900// Tests
901// ============================================================================
902
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Shared fixtures ----
    //
    // Plain `fn` items are used instead of repeated closures. They satisfy the
    // same `Fn + Clone + Send + Sync + 'static` bounds as capture-free closures.

    /// RHS of the index-1 system `y0' = -y0`, `0 = y1 - y0^2`.
    fn index1_rhs(_t: f64, y: &[f64], dydt: &mut [f64]) {
        dydt[0] = -y[0];
        dydt[1] = y[1] - y[0] * y[0];
    }

    /// Row-major 2x2 mass matrix `diag(1, 0)` for `index1_rhs` (y1 is algebraic).
    fn index1_mass(mass: &mut [f64]) {
        mass[0] = 1.0;
        mass[1] = 0.0;
        mass[2] = 0.0;
        mass[3] = 0.0;
    }

    /// RHS of the pendulum-like system with state `[x, v, lambda]`:
    /// `x' = v`, `v' = -lambda * x`, `0 = x^2 - 1`.
    fn pendulum_rhs(_t: f64, y: &[f64], dydt: &mut [f64]) {
        let (x, v, lam) = (y[0], y[1], y[2]);
        dydt[0] = v;
        dydt[1] = -lam * x;
        dydt[2] = x * x - 1.0;
    }

    /// Row-major 3x3 mass matrix `diag(1, 1, 0)` for `pendulum_rhs`
    /// (x and v are differential, lambda is algebraic).
    fn pendulum_mass(mass: &mut [f64]) {
        mass[..9].fill(0.0);
        mass[0] = 1.0;
        mass[4] = 1.0;
    }

    // ---- Structural analysis tests ----

    #[test]
    fn test_pure_ode_index_0() {
        // No algebraic equations => index 0
        let structure = DaeStructure {
            n_diff: 2,
            n_alg: 0,
            n_diff_eqs: 2,
            n_alg_eqs: 0,
            incidence: vec![
                vec![0, 1], // eq0 depends on y0, y1
                vec![0, 1], // eq1 depends on y0, y1
            ],
        };

        let info = analyze_dae_index(&structure);
        assert_eq!(info.structural_index, 0);
        assert_eq!(info.n_hidden_constraints, 0);
        assert!(info.differentiation_schedule.is_empty());
    }

    #[test]
    fn test_index_1_dae() {
        // Index-1 DAE:
        //   y0' = -y0 + y1      (diff eq, depends on y0, y1)
        //   0   = y1 - y0^2     (alg eq, depends on y0, y1)
        //
        // The algebraic equation can be matched to y1 => index 1
        let structure = DaeStructure {
            n_diff: 1,
            n_alg: 1,
            n_diff_eqs: 1,
            n_alg_eqs: 1,
            incidence: vec![
                vec![0, 1], // diff eq depends on y0, y1
                vec![0, 1], // alg eq depends on y0, y1
            ],
        };

        let info = analyze_dae_index(&structure);
        assert_eq!(info.structural_index, 1);
        assert_eq!(info.n_hidden_constraints, 0);
    }

    #[test]
    fn test_index_2_dae() {
        // Index-2 DAE:
        //   x'  = v           (diff; the incidence below lists both x and v)
        //   v'  = -lambda*x   (diff, depends on x, lambda)
        //   0   = x^2 - 1     (alg, depends on x only)
        //
        // Variables: 0=x, 1=v, 2=lambda
        // Diff eqs: eq0 (for x), eq1 (for v)
        // Alg eq: eq2 (constraint)
        //
        // The constraint only depends on x (var 0). Var 0 is already matched
        // to eq0. The constraint can't match to lambda (var 2) since it doesn't
        // depend on it. So it needs differentiation => index 2.
        let structure = DaeStructure {
            n_diff: 2,
            n_alg: 1,
            n_diff_eqs: 2,
            n_alg_eqs: 1,
            incidence: vec![
                vec![0, 1], // eq0: x' = v (depends on x, v)
                vec![0, 2], // eq1: v' = -lambda*x (depends on x, lambda)
                vec![0],    // eq2: 0 = x^2 - 1 (depends on x only)
            ],
        };

        let info = analyze_dae_index(&structure);
        assert!(
            info.structural_index >= 2,
            "Expected index >= 2, got {}",
            info.structural_index
        );
        assert!(info.n_hidden_constraints >= 1);
        assert!(!info.differentiation_schedule.is_empty());
    }

    // ---- Structure detection tests ----

    #[test]
    fn test_detect_structure_index1() {
        // Semi-explicit index-1 DAE (`index1_rhs` / `index1_mass`):
        //   y0' = -y0        (differential)
        //   0 = y1 - y0^2    (algebraic: y1 = y0^2)
        let dae = DaeProblem::new(index1_rhs, index1_mass, 0.0, 1.0, vec![2.0, 4.0], vec![1]);

        let structure = detect_structure(&dae, 0.0, &[2.0, 4.0]);
        assert_eq!(structure.n_diff, 1);
        assert_eq!(structure.n_alg, 1);
        assert_eq!(structure.n_diff_eqs, 1);
        assert_eq!(structure.n_alg_eqs, 1);

        let info = analyze_dae_index(&structure);
        assert_eq!(info.structural_index, 1);
    }

    #[test]
    fn test_detect_structure_index2() {
        // Index-2 DAE (simplified pendulum-like, `pendulum_rhs` / `pendulum_mass`):
        //   x'  = v                    (eq for x)
        //   v'  = -lambda * x          (eq for v)
        //   0   = x^2 - 1.0            (constraint: x on unit circle)
        //
        // State: [x, v, lambda]; the algebraic variable is lambda (state index 2).
        let dae = DaeProblem::new(
            pendulum_rhs,
            pendulum_mass,
            0.0,
            1.0,
            vec![1.0, 0.0, 0.0],
            vec![2],
        );

        let structure = detect_structure(&dae, 0.0, &[1.0, 0.0, 0.0]);
        let info = analyze_dae_index(&structure);
        assert!(
            info.structural_index >= 2,
            "Expected index >= 2, got {}. Schedule: {:?}",
            info.structural_index,
            info.differentiation_schedule
        );
    }

    // ---- Index reduction tests ----

    #[test]
    fn test_reduce_index_creates_augmented_system() {
        // Index-2 system from the pendulum-like example
        let n = 3;
        let alg_indices = vec![2usize];

        let result = reduce_index(
            pendulum_rhs,
            pendulum_mass,
            &alg_indices,
            n,
            0.0_f64,
            &[1.0, 0.0, 0.0],
        );

        assert!(
            result.is_ok(),
            "Reduction should succeed: {:?}",
            result.err()
        );
        let reduced = result.unwrap();

        // Augmented system should be larger
        assert!(
            reduced.augmented_dim() > n,
            "Augmented dim {} should be > original dim {}",
            reduced.augmented_dim(),
            n
        );

        // Should still report as DAE
        assert!(reduced.is_singular_mass());

        // Test augmented IC
        let y0_aug = reduced.augment_initial_conditions(0.0, &[1.0, 0.0, 0.0]);
        assert_eq!(y0_aug.len(), reduced.augmented_dim());
        // Original part should be preserved
        assert!((y0_aug[0] - 1.0).abs() < 1e-10);
        assert!(y0_aug[1].abs() < 1e-10);
        assert!(y0_aug[2].abs() < 1e-10);
    }

    #[test]
    fn test_reduced_system_rhs_evaluates() {
        // Verify the reduced system can evaluate its RHS without panicking
        let n = 3;
        let alg_indices = vec![2usize];

        let reduced = reduce_index(
            pendulum_rhs,
            pendulum_mass,
            &alg_indices,
            n,
            0.0_f64,
            &[1.0, 0.0, 0.0],
        )
        .unwrap();

        let aug_dim = reduced.augmented_dim();
        let y0_aug = reduced.augment_initial_conditions(0.0, &[1.0, 0.0, 0.0]);
        let mut dydt = vec![0.0; aug_dim];
        reduced.rhs(0.0, &y0_aug, &mut dydt);

        // Original differential equations should give correct values
        // x' = v = 0
        assert!(dydt[0].abs() < 1e-10, "x' should be 0, got {}", dydt[0]);
        // v' = -lambda*x = 0*1 = 0
        assert!(dydt[1].abs() < 1e-10, "v' should be 0, got {}", dydt[1]);
        // Original constraint: x^2 - 1 = 0 (consistent)
        assert!(
            dydt[2].abs() < 1e-8,
            "constraint should be ~0, got {}",
            dydt[2]
        );
    }

    #[test]
    fn test_reduced_system_mass_matrix() {
        let n = 3;
        let alg_indices = vec![2usize];

        let reduced = reduce_index(
            pendulum_rhs,
            pendulum_mass,
            &alg_indices,
            n,
            0.0_f64,
            &[1.0, 0.0, 0.0],
        )
        .unwrap();

        let aug_dim = reduced.augmented_dim();
        let mut mass = vec![0.0; aug_dim * aug_dim];
        reduced.mass_matrix(&mut mass);

        // Original mass entries preserved (row-major, stride `aug_dim`)
        assert!((mass[0] - 1.0).abs() < 1e-10); // M[0,0] = 1
        assert!((mass[aug_dim + 1] - 1.0).abs() < 1e-10); // M[1,1] = 1
        assert!(mass[2 * aug_dim + 2].abs() < 1e-10); // M[2,2] = 0 (algebraic)

        // New rows should be zero (algebraic)
        for (k, row) in mass.chunks_exact(aug_dim).enumerate().skip(n) {
            for (j, &entry) in row.iter().enumerate() {
                assert!(
                    entry.abs() < 1e-10,
                    "New row {} should be all zero, but M[{},{}] = {}",
                    k,
                    k,
                    j,
                    entry
                );
            }
        }
    }

    #[test]
    fn test_already_index1_returns_err() {
        // Index-1 system should return Err (no reduction needed)
        let result = reduce_index(index1_rhs, index1_mass, &[1], 2, 0.0_f64, &[2.0, 4.0]);

        assert!(result.is_err());
        assert!(result.unwrap_err().contains("index-1"));
    }

    #[test]
    fn test_reduce_dae_problem_index1_returns_err() {
        // `reduce_dae_problem` must also refuse to reduce an index-1 problem.
        let dae = DaeProblem::new(index1_rhs, index1_mass, 0.0, 1.0, vec![2.0, 4.0], vec![1]);

        // `.err()` avoids requiring `Debug` on the `Ok` payload.
        let msg = reduce_dae_problem(&dae, 0.0, &[2.0, 4.0])
            .err()
            .expect("index-1 problem must not be reduced");
        assert!(msg.contains("no reduction needed"), "unexpected message: {}", msg);
    }

    #[test]
    fn test_analyze_system_convenience() {
        let dae = DaeProblem::new(index1_rhs, index1_mass, 0.0, 1.0, vec![2.0, 4.0], vec![1]);

        let info = analyze_system(&dae, 0.0, &[2.0, 4.0]);
        assert_eq!(info.structural_index, 1);
    }

    #[test]
    fn test_extract_original() {
        let n = 3;
        let reduced = reduce_index(
            pendulum_rhs,
            pendulum_mass,
            &[2],
            n,
            0.0_f64,
            &[1.0, 0.0, 0.0],
        )
        .unwrap();

        // Fill every augmented slot with a distinct value (1, 2, 3, ...), sized
        // from the reduced system rather than hard-coding the augmented length.
        let y_aug: Vec<f64> = (1..=reduced.augmented_dim()).map(|k| k as f64).collect();
        let y_orig = reduced.extract_original(&y_aug);
        assert_eq!(y_orig.len(), n);
        assert!((y_orig[0] - 1.0).abs() < 1e-10);
        assert!((y_orig[1] - 2.0).abs() < 1e-10);
        assert!((y_orig[2] - 3.0).abs() < 1e-10);
    }

    // ---- Edge cases and helpers ----

    #[test]
    fn test_dae_structure_helpers() {
        let structure = DaeStructure {
            n_diff: 2,
            n_alg: 1,
            n_diff_eqs: 2,
            n_alg_eqs: 1,
            incidence: vec![vec![0, 1], vec![0, 2], vec![0]],
        };

        assert_eq!(structure.n_vars(), 3);
        assert_eq!(structure.n_eqs(), 3);
    }

    #[test]
    fn test_pantelides_empty_incidence() {
        // Edge case: algebraic equation with no dependencies
        let structure = DaeStructure {
            n_diff: 1,
            n_alg: 1,
            n_diff_eqs: 1,
            n_alg_eqs: 1,
            incidence: vec![
                vec![0], // diff eq depends on y0
                vec![],  // alg eq depends on nothing — structurally singular
            ],
        };

        let info = analyze_dae_index(&structure);
        // Should handle gracefully, may report index >= 1
        assert!(info.structural_index >= 1);
    }

    #[test]
    fn test_multiple_algebraic_equations() {
        // Two algebraic constraints, both index-1
        // y0' = -y0
        // 0 = y1 - y0
        // 0 = y2 - y0^2
        let structure = DaeStructure {
            n_diff: 1,
            n_alg: 2,
            n_diff_eqs: 1,
            n_alg_eqs: 2,
            incidence: vec![
                vec![0],    // diff eq: y0' = -y0
                vec![0, 1], // alg eq1: y1 - y0 = 0 (depends on y0, y1)
                vec![0, 2], // alg eq2: y2 - y0^2 = 0 (depends on y0, y2)
            ],
        };

        let info = analyze_dae_index(&structure);
        assert_eq!(info.structural_index, 1);
        assert_eq!(info.n_hidden_constraints, 0);
    }
}