evalexpr_jit/
system.rs

1//! System of equations evaluation with JIT compilation.
2//!
3//! This module provides functionality for evaluating multiple mathematical equations
4//! simultaneously using JIT compilation. The equations are combined into a single
5//! optimized function for efficient evaluation.
6//!
7//! # Features
8//!
9//! - JIT compilation of multiple equations into a single function
10//! - Consistent variable ordering across equations
11//! - Automatic variable extraction and mapping
12//! - Efficient batch evaluation
13//! - Automatic derivative computation and Jacobian matrix generation
14//! - Higher-order derivative support
15//! - Parallel evaluation capabilities
16//!
17//! # Example
18//!
19//! ```
20//! use evalexpr_jit::system::EquationSystem;
21//!
22//! let system = EquationSystem::new(vec![
23//!     "2*x + y".to_string(),   // first equation
24//!     "x^2 + z".to_string(),   // second equation
25//! ]).unwrap();
26//!
27//! // Variables are automatically sorted (x, y, z)
28//! let results = system.eval(&vec![1.0, 2.0, 3.0]).unwrap();
29//! assert_eq!(results.as_slice(), vec![4.0, 4.0]); // [2*1 + 2, 1^2 + 3]
30//!
31//! // Compute derivatives
32//! let dx = system.gradient(&vec![1.0, 2.0, 3.0], "x").unwrap();
33//! assert_eq!(dx.as_slice(), vec![2.0, 2.0]); // [d/dx(2x + y), d/dx(x^2 + z)]
34//! ```
35//!
36//! # Computing Derivatives
37//!
38//! There are two main ways to create new equation systems with derivatives from existing ones:
39//!
40//! ## 1. Using `derive_wrt` for Gradients
41//!
42//! Creates a new equation system that computes derivatives with respect to specified variables:
43//!
44//! ```
45//! # use evalexpr_jit::system::EquationSystem;
46//! let system = EquationSystem::new(vec![
47//!     "x^2*y".to_string(),     // f1
48//!     "x*y^2".to_string(),     // f2
49//! ]).unwrap();
50//!
51//! // First-order derivative with respect to x
52//! let dx = system.derive_wrt(&["x"]).unwrap();
53//! let results = dx.eval(&[2.0, 3.0]).unwrap();
54//! assert_eq!(results.as_slice(), &[12.0, 9.0]);  // [d(x^2*y)/dx = 2xy, d(x*y^2)/dx = y^2]
55//!
56//! // Second-order mixed derivative (first x, then y)
57//! let dxy = system.derive_wrt(&["x", "y"]).unwrap();
58//! let results = dxy.eval(&[2.0, 3.0]).unwrap();
59//! assert_eq!(results.as_slice(), &[4.0, 6.0]);   // [d²(x^2*y)/dxdy = 2x, d²(x*y^2)/dxdy = 2y]
60//! ```
61//!
62//! ## 2. Using `jacobian_wrt` for Jacobian Matrices
63//!
64//! Creates a new equation system that computes the Jacobian matrix with respect to specified variables:
65//!
66//! ```
67//! # use evalexpr_jit::system::EquationSystem;
68//! # use ndarray::Array2;
69//! let system = EquationSystem::new(vec![
70//!     "x^2*y + z".to_string(),    // f1
71//!     "x*y^2 - z^2".to_string(),  // f2
72//! ]).unwrap();
73//!
74//! // Create Jacobian system for x and y
75//! let jacobian_system = system.jacobian_wrt(&["x", "y"]).unwrap();
76//!
77//! // Prepare matrix to store results (2 equations × 2 variables)
78//! let mut results = Array2::zeros((2, 2));
79//!
80//! // Evaluate Jacobian at point (x=2, y=3, z=1)
81//! jacobian_system.eval_into_matrix(&[2.0, 3.0, 1.0], &mut results).unwrap();
82//!
83//! // results now contains:
84//! // [
85//! //   [12.0, 4.0],   // [∂f1/∂x, ∂f1/∂y]
86//! //   [9.0,  12.0],  // [∂f2/∂x, ∂f2/∂y]
87//! // ]
88//! ```
89//!
90
91use crate::backends::vector::Vector;
92use crate::builder::build_combined_function;
93use crate::convert::build_ast;
94use crate::equation::{extract_all_symbols, extract_symbols};
95use crate::errors::EquationError;
96use crate::expr::Expr;
97use crate::prelude::Matrix;
98use crate::types::CombinedJITFunction;
99use evalexpr::build_operator_tree;
100use itertools::Itertools;
101use rayon::prelude::*;
102use std::collections::HashMap;
103use std::sync::Arc;
104
105/// Represents a system of mathematical equations that can be evaluated together.
106pub struct EquationSystem {
107    /// The original string representations of the equations
108    pub equations: Vec<String>,
109    /// The AST representations of the equations
110    pub asts: Vec<Expr>,
111    /// Maps variable names to their indices in the input array
112    pub variable_map: HashMap<String, u32>,
113    /// Variables in sorted order for consistent input ordering
114    pub sorted_variables: Vec<String>,
115    /// The JIT-compiled function that evaluates all equations
116    pub combined_fun: CombinedJITFunction,
117    /// Partial derivatives of the system - maps variable names to their derivative functions
118    /// E.g. {"x": df(x, y, z)/dx, "y": df(x, y, z)/dy}
119    pub partial_derivatives: HashMap<String, CombinedJITFunction>,
120    /// The type of output
121    output_type: OutputType,
122}
123
124impl EquationSystem {
125    /// Creates a new equation system from a vector of expression strings.
126    ///
127    /// This constructor automatically extracts variables from the expressions and assigns them indices
128    /// in alphabetical order. The resulting system can evaluate all equations efficiently and compute
129    /// derivatives with respect to any variable.
130    ///
131    /// # Arguments
132    /// * `expressions` - Vector of mathematical expressions as strings
133    ///
134    /// # Returns
135    /// A new `EquationSystem` with JIT-compiled evaluation function and derivative capabilities
136    ///
137    /// # Example
138    /// ```
139    /// # use evalexpr_jit::system::EquationSystem;
140    /// let system = EquationSystem::new(vec![
141    ///     "2*x + y".to_string(),
142    ///     "x^2 + z".to_string()
143    /// ]).unwrap();
144    ///
145    /// // Evaluate system
146    /// let results = system.eval(&vec![1.0, 2.0, 3.0]).unwrap();
147    ///
148    /// // Compute derivatives
149    /// let dx = system.gradient(&vec![1.0, 2.0, 3.0], "x").unwrap();
150    /// ```
151    pub fn new(expressions: Vec<String>) -> Result<Self, EquationError> {
152        let sorted_variables = extract_all_symbols(&expressions);
153        let variable_map: HashMap<String, u32> = sorted_variables
154            .iter()
155            .enumerate()
156            .map(|(i, v)| (v.clone(), i as u32))
157            .collect();
158        let asts = Self::create_asts(&expressions, &variable_map)?;
159        Self::build(asts, expressions, variable_map, OutputType::Vector)
160    }
161
162    /// Creates a new equation system from a vector of expressions and a variable map.
163    ///
164    /// This constructor allows explicit control over variable ordering by providing a map
165    /// of variable names to their indices. This is useful when you need to ensure specific
166    /// variable ordering or when integrating with other systems that expect variables in
167    /// a particular order.
168    ///
169    /// # Arguments
170    /// * `expressions` - Vector of mathematical expressions as strings
171    /// * `variable_map` - Map of variable names to their indices, defining input order
172    ///
173    /// # Returns
174    /// A new `EquationSystem` with JIT-compiled evaluation function using the specified variable ordering
175    ///
176    /// # Example
177    /// ```
178    /// use evalexpr_jit::prelude::*;
179    /// use std::collections::HashMap;
180    ///
181    /// let var_map: HashMap<String, u32> = [
182    ///     ("x".to_string(), 0),
183    ///     ("y".to_string(), 1),
184    ///     ("z".to_string(), 2),
185    /// ].into_iter().collect();
186    ///
187    /// let system = EquationSystem::from_var_map(
188    ///     vec!["2*x + y".to_string(), "x^2 + z".to_string()],
189    ///     &var_map
190    /// ).unwrap();
191    /// ```
192    pub fn from_var_map(
193        expressions: Vec<String>,
194        variable_map: &HashMap<String, u32>,
195    ) -> Result<Self, EquationError> {
196        let asts = Self::create_asts(&expressions, variable_map)?;
197        Self::build(asts, expressions, variable_map.clone(), OutputType::Vector)
198    }
199
200    /// Creates a new equation system directly from AST nodes and a variable map.
201    ///
202    /// This constructor allows creating a system from pre-built AST nodes rather than parsing
203    /// expressions from strings. This is useful when augmenting existing systems or creating derivatives
204    /// of existing systems.
205    ///
206    /// Please note, this constructor is not meant to be used by the end user and is only available for internal use
207    /// to create derivatives of existing systems.
208    ///
209    /// # Arguments
210    /// * `asts` - Vector of expression AST nodes
211    /// * `variable_map` - Map of variable names to their indices, defining input order
212    /// * `output_type` - The type of output. Used to determine the shape of the output vector
213    /// # Returns
214    /// A new `EquationSystem` with JIT-compiled evaluation function using the specified ASTs
215    pub(crate) fn from_asts(
216        asts: Vec<Expr>,
217        variable_map: &HashMap<String, u32>,
218        output_type: OutputType,
219    ) -> Result<Self, EquationError> {
220        let expressions = asts.iter().map(|ast| ast.to_string()).collect();
221        Self::build(asts, expressions, variable_map.clone(), output_type)
222    }
223
224    /// Core builder function used by both `new()` and `from_var_map()`.
225    ///
226    /// This internal function handles the common construction logic for both public constructors.
227    /// It builds ASTs and creates JIT-compiled functions for both evaluation and derivatives.
228    ///
229    /// # Arguments
230    /// * `asts` - Vector of AST nodes
231    /// * `equations` - Original expression strings
232    /// * `variable_map` - Map of variable names to indices
233    /// * `output_type` - The type of output (vector or matrix)
234    ///
235    /// # Returns
236    /// A new `EquationSystem` with JIT-compiled evaluation function and derivative capabilities
237    fn build(
238        asts: Vec<Expr>,
239        equations: Vec<String>,
240        variable_map: HashMap<String, u32>,
241        output_type: OutputType,
242    ) -> Result<Self, EquationError> {
243        // Create combined JIT function
244        let combined_fun = build_combined_function(asts.clone(), equations.len())?;
245
246        // Create derivative functions for each variable forming a Jacobian matrix
247        let mut jacobian_funs = HashMap::with_capacity(variable_map.len());
248
249        let sorted_variables: Vec<String> = variable_map
250            .iter()
251            .sorted_by_key(|(_, idx)| *idx)
252            .map(|(var, _)| var.clone())
253            .collect();
254
255        for var in sorted_variables {
256            let derivative_ast = asts
257                .iter()
258                .map(|ast| *ast.derivative(&var))
259                .collect::<Vec<Expr>>();
260            let jacobian_fun = build_combined_function(derivative_ast, asts.len())?;
261            jacobian_funs.insert(var, jacobian_fun);
262        }
263
264        Ok(Self {
265            equations,
266            asts,
267            variable_map: variable_map.clone(),
268            sorted_variables: variable_map.keys().sorted().cloned().collect(),
269            combined_fun,
270            partial_derivatives: jacobian_funs,
271            output_type,
272        })
273    }
274
275    /// Creates abstract syntax trees (ASTs) from a vector of mathematical expressions.
276    ///
277    /// This internal function parses each expression string into an AST, validates that all
278    /// variables used in the expressions exist in the provided variable map, and returns
279    /// simplified ASTs ready for compilation.
280    ///
281    /// # Arguments
282    /// * `expressions` - Vector of mathematical expression strings to parse
283    /// * `variable_map` - Map of valid variable names to their indices
284    ///
285    /// # Returns
286    /// A vector of simplified ASTs, one for each input expression
287    ///
288    /// # Errors
289    /// Returns `EquationError::VariableNotFound` if an expression uses a variable
290    /// that doesn't exist in the variable map
291    fn create_asts(
292        expressions: &[String],
293        variable_map: &HashMap<String, u32>,
294    ) -> Result<Vec<Expr>, EquationError> {
295        expressions
296            .iter()
297            .map(|expr| {
298                let node = build_operator_tree(expr)?;
299
300                // Validate variables
301                let expr_vars = extract_symbols(&node);
302                for var in expr_vars.keys() {
303                    if !variable_map.contains_key(var) {
304                        return Err(EquationError::VariableNotFound(var.clone()));
305                    }
306                }
307
308                // Build and simplify AST
309                let ast = build_ast(&node, variable_map)?;
310                Ok(*ast.simplify())
311            })
312            .collect::<Result<Vec<_>, EquationError>>()
313    }
314
315    /// Evaluates all equations in the system with the given input values
316    /// into a pre-allocated buffer.
317    ///
318    /// # Arguments
319    /// * `inputs` - Input vector implementing the Vector trait
320    /// * `results` - Pre-allocated vector to store results
321    ///
322    /// # Returns
323    /// Reference to the results vector containing the evaluated values
324    pub fn eval_into<V: Vector, R: Vector>(
325        &self,
326        inputs: &V,
327        results: &mut R,
328    ) -> Result<(), EquationError> {
329        self.validate_input_length(inputs.as_slice())?;
330        if results.len() != self.equations.len() {
331            return Err(EquationError::InvalidInputLength {
332                expected: self.equations.len(),
333                got: results.len(),
334            });
335        }
336
337        (self.combined_fun)(inputs.as_slice(), results.as_mut_slice());
338        Ok(())
339    }
340
341    /// Evaluates all equations in the system with the given input values.
342    /// Allocates a new vector for results.
343    ///
344    /// # Arguments
345    /// * `inputs` - Input vector implementing the Vector trait
346    ///
347    /// # Returns
348    /// Vector of results, one for each equation in the system
349    pub fn eval<V: Vector>(&self, inputs: &V) -> Result<V, EquationError> {
350        let mut results = V::zeros(self.equations.len());
351        self.eval_into(inputs, &mut results)?;
352        Ok(results)
353    }
354
355    /// Evaluates all equations in the system with the given input values into a pre-allocated matrix.
356    ///
357    /// This method is used when the equation system is configured to output a matrix rather than a vector.
358    /// The results matrix must have the correct dimensions matching the system's output type.
359    ///
360    /// # Arguments
361    /// * `inputs` - Input vector implementing the Vector trait
362    /// * `results` - Pre-allocated matrix to store results
363    ///
364    /// # Returns
365    /// Ok(()) if successful, or an error if the system is not configured for matrix output
366    pub fn eval_into_matrix<V: Vector, R: Matrix>(
367        &self,
368        inputs: &V,
369        results: &mut R,
370    ) -> Result<(), EquationError> {
371        match self.output_type {
372            OutputType::Vector => {
373                // If the system is not configured to output a matrix, throw an error
374                return Err(EquationError::MatrixOutputRequired);
375            }
376            OutputType::Matrix(n_rows, n_cols) => {
377                self.validate_matrix_dimensions(n_rows, n_cols)?;
378            }
379        }
380
381        (self.combined_fun)(inputs.as_slice(), results.flat_mut_slice());
382        Ok(())
383    }
384
385    /// Evaluates all equations in the system with the given input values into a new matrix.
386    ///
387    /// This method allocates a new matrix with the correct dimensions and evaluates the system into it.
388    /// It should only be used when the equation system is configured to output a matrix.
389    ///
390    /// # Arguments
391    /// * `inputs` - Input vector implementing the Vector trait
392    ///
393    /// # Returns
394    /// Matrix containing the evaluated results, or an error if the system is not configured for matrix output
395    pub fn eval_matrix<V: Vector, R: Matrix>(&self, inputs: &V) -> Result<R, EquationError> {
396        match self.output_type {
397            OutputType::Vector => Err(EquationError::MatrixOutputRequired),
398            OutputType::Matrix(n_rows, n_cols) => {
399                let mut results = R::zeros(n_rows, n_cols);
400                self.eval_into_matrix(inputs, &mut results)?;
401                Ok(results)
402            }
403        }
404    }
405
406    /// Evaluates the equation system in parallel for multiple input sets.
407    ///
408    /// # Arguments
409    /// * `input_sets` - Slice of input vectors, each must match the number of variables
410    ///
411    /// # Returns
412    /// Vector of result vectors, one for each input set
413    pub fn eval_parallel<V: Vector + Send + Sync>(
414        &self,
415        input_sets: &[V],
416    ) -> Result<Vec<V>, EquationError> {
417        let num_threads = std::thread::available_parallelism()
418            .map(|n| n.get())
419            .unwrap_or(8);
420
421        let chunk_size = (input_sets.len() / (num_threads * 4)).max(1);
422        let n_equations = self.equations.len();
423
424        // Since we're using Arc, we can efficiently clone the system for parallel processing
425        let systems: Vec<_> = (0..num_threads).map(|_| self.clone()).collect();
426
427        Ok(input_sets
428            .par_chunks(chunk_size)
429            .enumerate()
430            .map(|(chunk_idx, chunk)| {
431                let system = &systems[chunk_idx % systems.len()];
432                chunk
433                    .iter()
434                    .map(|inputs| {
435                        let mut results = V::zeros(n_equations);
436                        (system.combined_fun)(inputs.as_slice(), results.as_mut_slice());
437                        results
438                    })
439                    .collect::<Vec<_>>()
440            })
441            .flatten()
442            .collect())
443    }
444
445    /// Returns the gradient of the equation system with respect to a specific variable.
446    ///
447    /// The gradient contains the partial derivatives of all equations with respect to the given variable,
448    /// evaluated at the provided input values. This is equivalent to one row of the Jacobian matrix.
449    ///
450    /// # Arguments
451    /// * `inputs` - Slice of input values at which to evaluate the gradient
452    /// * `variable` - Name of the variable to compute derivatives with respect to
453    ///
454    /// # Returns
455    /// Vector containing the partial derivatives of each equation with respect to
456    /// the specified variable, evaluated at the given input values
457    ///
458    /// # Example
459    /// ```
460    /// use evalexpr_jit::prelude::*;
461    ///
462    /// let system = EquationSystem::new(vec![
463    ///     "x^2*y".to_string(),  // f1
464    ///     "x*y^2".to_string(),  // f2
465    /// ]).unwrap();
466    ///
467    /// let gradient = system.gradient(&[2.0, 3.0], "x").unwrap();
468    /// // gradient contains [12.0, 9.0] (∂f1/∂x, ∂f2/∂x)
469    /// ```
470    pub fn gradient(&self, inputs: &[f64], variable: &str) -> Result<Vec<f64>, EquationError> {
471        self.validate_input_length(inputs)?;
472        let n_equations = self.equations.len();
473        let mut results = vec![0.0; n_equations];
474        self.partial_derivatives
475            .get(variable)
476            .ok_or(EquationError::VariableNotFound(variable.to_string()))?(
477            inputs, &mut results
478        );
479        Ok(results)
480    }
481
482    /// Computes the Jacobian matrix of the equation system at the given input values.
483    ///
484    /// The Jacobian matrix contains all first-order partial derivatives of the system.
485    /// Each row corresponds to an equation, and each column corresponds to a variable.
486    /// The entry at position (i,j) is the partial derivative of equation i with respect to variable j.
487    ///
488    /// # Arguments
489    /// * `inputs` - Slice of input values at which to evaluate the Jacobian
490    /// * `variables` - Optional slice of variable names to include in the Jacobian. If None, includes all variables in sorted order.
491    ///
492    /// # Returns
493    /// The Jacobian matrix as a vector of vectors, where each inner vector
494    /// contains the partial derivatives of one equation with respect to all variables
495    ///
496    /// # Example
497    /// ```
498    /// use evalexpr_jit::prelude::*;
499    ///
500    /// let system = EquationSystem::new(vec![
501    ///     "x^2*y".to_string(),  // f1
502    ///     "x*y^2".to_string(),  // f2
503    /// ]).unwrap();
504    ///
505    /// let jacobian = system.eval_jacobian(&[2.0, 3.0], None).unwrap();
506    /// // jacobian[0] contains [12.0, 4.0]   // ∂f1/∂x, ∂f1/∂y
507    /// // jacobian[1] contains [9.0,  12.0]   // ∂f2/∂x, ∂f2/∂y
508    /// ```
509    pub fn eval_jacobian(
510        &self,
511        inputs: &[f64],
512        variables: Option<&[String]>,
513    ) -> Result<Vec<Vec<f64>>, EquationError> {
514        self.validate_input_length(inputs)?;
515
516        let sorted_variables = variables.unwrap_or(&self.sorted_variables);
517        let mut results = Vec::with_capacity(self.equations.len());
518        let n_vars = sorted_variables.len();
519
520        // Initialize vectors for each equation
521        for _ in 0..self.equations.len() {
522            results.push(Vec::with_capacity(n_vars));
523        }
524
525        // Fill the transposed matrix
526        let n_equations = self.equations.len();
527        for var in sorted_variables {
528            let fun = self.partial_derivatives.get(var).unwrap();
529            let mut derivatives = vec![0.0; n_equations];
530            fun(inputs, &mut derivatives);
531            for (eq_idx, &value) in derivatives.iter().enumerate() {
532                results[eq_idx].push(value);
533            }
534        }
535
536        Ok(results)
537    }
538
539    /// Creates a new equation system that computes the Jacobian matrix with respect to specific variables.
540    ///
541    /// This method creates a new equation system optimized for computing partial derivatives
542    /// with respect to a subset of variables. The resulting system evaluates all equations' derivatives
543    /// with respect to the specified variables and arranges them in matrix form.
544    ///
545    /// # Arguments
546    /// * `variables` - Slice of variable names to include in the Jacobian matrix
547    ///
548    /// # Returns
549    /// A new `EquationSystem` that computes the Jacobian matrix when evaluated. The output matrix
550    /// has dimensions `[n_equations × n_variables]`, where each row contains the derivatives
551    /// of one equation with respect to the specified variables.
552    ///
553    /// # Errors
554    /// Returns `EquationError::VariableNotFound` if any of the specified variables doesn't exist
555    /// in the system.
556    ///
557    /// # Example
558    /// ```
559    /// use evalexpr_jit::prelude::*;
560    /// use ndarray::Array2;
561    /// let system = EquationSystem::new(vec![
562    ///     "x^2*y + z".to_string(),    // f1
563    ///     "x*y^2 - z^2".to_string(),  // f2
564    /// ]).unwrap();
565    ///
566    /// // Create Jacobian system for x and y only
567    /// let jacobian_system = system.jacobian_wrt(&["x", "y"]).unwrap();
568    ///
569    /// // Prepare matrix to store results (2 equations × 2 variables)
570    /// let mut results = Array2::zeros((2, 2));
571    ///
572    /// // Evaluate Jacobian at point (x=2, y=3, z=1)
573    /// jacobian_system.eval_into_matrix(&[2.0, 3.0, 1.0], &mut results).unwrap();
574    ///
575    /// // results now contains:
576    /// // [
577    /// //   [12.0, 4.0],   // [∂f1/∂x, ∂f1/∂y]
578    /// //   [9.0,  12.0],  // [∂f2/∂x, ∂f2/∂y]
579    /// // ]
580    /// ```
581    ///
582    /// # Performance Notes
583    /// - The returned system is JIT-compiled and optimized for repeated evaluations
584    /// - Pre-allocate the results matrix and reuse it for better performance
585    /// - The matrix dimensions will be `[n_equations × n_variables]`
586    pub fn jacobian_wrt(&self, variables: &[&str]) -> Result<EquationSystem, EquationError> {
587        // Verify all variables exist
588        for var in variables {
589            if !self.variable_map.contains_key(*var) {
590                return Err(EquationError::VariableNotFound(var.to_string()));
591            }
592        }
593
594        let mut asts = vec![];
595        for ast in self.asts.iter() {
596            for var in variables {
597                asts.push(*ast.derivative(var));
598            }
599        }
600
601        let output_type = OutputType::Matrix(self.num_equations(), variables.len());
602
603        EquationSystem::from_asts(asts, &self.variable_map, output_type)
604    }
605
606    /// Creates a new equation system containing the higher-order derivatives of all equations
607    /// with respect to multiple variables.
608    ///
609    /// This method allows computing mixed partial derivatives by specifying the variables
610    /// in the order of differentiation. For example, passing ["x", "y"] computes ∂²f/∂x∂y
611    /// for each equation f.
612    ///
613    /// # Arguments
614    /// * `variables` - Slice of variable names to differentiate with respect to, in order
615    ///
616    /// # Returns
617    /// A JIT-compiled function that computes the higher-order derivatives
618    ///
619    /// # Example
620    /// ```
621    /// # use evalexpr_jit::system::EquationSystem;
622    /// # use ndarray::Array1;
623    /// let system = EquationSystem::new(vec![
624    ///     "x^2*y".to_string(),  // f1
625    ///     "x*y^2".to_string(),  // f2
626    /// ]).unwrap();
627    ///
628    /// let derivatives = system.derive_wrt(&["x", "y"]).unwrap();
629    /// let mut results = Array1::zeros(2);
630    /// derivatives.eval_into(&vec![2.0, 3.0], &mut results).unwrap();
631    /// assert_eq!(results.as_slice().unwrap(), vec![4.0, 6.0]); // ∂²f1/∂x∂y = 2x, ∂²f2/∂x∂y = 2y
632    /// ```
633    pub fn derive_wrt(&self, variables: &[&str]) -> Result<EquationSystem, EquationError> {
634        // Verify all variables exist
635        for var in variables {
636            if !self.variable_map.contains_key(*var) {
637                return Err(EquationError::VariableNotFound(var.to_string()));
638            }
639        }
640
641        // Create higher-order derivatives of all ASTs
642        let mut derivative_asts = self.asts.clone();
643        for var in variables {
644            derivative_asts = derivative_asts
645                .into_iter()
646                .map(|ast| *ast.derivative(var))
647                .collect();
648        }
649
650        // Create new system from derivative ASTs
651        EquationSystem::from_asts(derivative_asts, &self.variable_map, OutputType::Vector)
652    }
653
654    pub fn validate_matrix_dimensions(
655        &self,
656        n_rows: usize,
657        n_cols: usize,
658    ) -> Result<(), EquationError> {
659        match self.output_type {
660            OutputType::Vector => {
661                return Err(EquationError::MatrixOutputRequired);
662            }
663            OutputType::Matrix(expected_rows, expected_cols) => {
664                if n_rows != expected_rows || n_cols != expected_cols {
665                    return Err(EquationError::InvalidMatrixDimensions {
666                        expected_rows,
667                        expected_cols,
668                        got_rows: n_rows,
669                        got_cols: n_cols,
670                    });
671                }
672            }
673        }
674        Ok(())
675    }
676
677    /// Returns the sorted variables in the system.
678    pub fn sorted_variables(&self) -> &[String] {
679        &self.sorted_variables
680    }
681
682    /// Returns the map of variable names to their indices.
683    pub fn variables(&self) -> &HashMap<String, u32> {
684        &self.variable_map
685    }
686
687    /// Returns the original equation strings.
688    pub fn equations(&self) -> &[String] {
689        &self.equations
690    }
691
692    /// Returns the compiled evaluation function.
693    pub fn fun(&self) -> &CombinedJITFunction {
694        &self.combined_fun
695    }
696
697    /// Returns the map of variable names to their derivative functions.
698    pub fn jacobian_funs(&self) -> &HashMap<String, CombinedJITFunction> {
699        &self.partial_derivatives
700    }
701
702    /// Returns the derivative function for a specific variable.
703    pub fn gradient_fun(&self, variable: &str) -> &CombinedJITFunction {
704        self.partial_derivatives.get(variable).unwrap()
705    }
706
707    /// Returns the number of equations in the system.
708    pub fn num_equations(&self) -> usize {
709        self.equations.len()
710    }
711
712    /// Validates that the number of input values matches the number of variables.
713    fn validate_input_length(&self, inputs: &[f64]) -> Result<(), EquationError> {
714        if inputs.len() != self.sorted_variables.len() {
715            return Err(EquationError::InvalidInputLength {
716                expected: self.sorted_variables.len(),
717                got: inputs.len(),
718            });
719        }
720        Ok(())
721    }
722}
723
724impl Clone for EquationSystem {
725    fn clone(&self) -> Self {
726        Self {
727            equations: self.equations.clone(),
728            asts: self.asts.clone(),
729            variable_map: self.variable_map.clone(),
730            sorted_variables: self.sorted_variables.clone(),
731            combined_fun: Arc::clone(&self.combined_fun),
732            partial_derivatives: self.partial_derivatives.clone(),
733            output_type: self.output_type,
734        }
735    }
736}
737
/// Describes how the outputs of an equation system are interpreted.
#[derive(Debug, Clone, Copy)]
pub(crate) enum OutputType {
    /// The outputs form a flat vector.
    Vector,
    /// The outputs form a matrix with the given number of rows and columns,
    /// in that order.
    Matrix(usize, usize),
}
743
744#[cfg(test)]
745mod tests {
746    use super::*;
747    use nalgebra::DVector;
748    use ndarray::{Array1, Array2};
749
/// Equations that use different subsets of the variables still share one input layout.
#[test]
fn test_system_with_different_variables() -> Result<(), Box<dyn std::error::Error>> {
    let system = EquationSystem::new(vec![
        "2*x + y".to_string(),   // x and y
        "z^2".to_string(),       // z only
        "x + y + z".to_string(), // x, y and z
    ])?;

    // Every variable that occurs anywhere must be tracked.
    assert_eq!(system.sorted_variables, &["x", "y", "z"]);

    // x = 1, y = 2, z = 3
    let results = system.eval(&[1.0, 2.0, 3.0])?;

    let expected = vec![
        4.0, // 2*1 + 2
        9.0, // 3^2
        6.0, // 1 + 2 + 3
    ];
    assert_eq!(results.as_slice(), expected);

    Ok(())
}
778
/// The input order does not depend on the order in which variables appear in the equations.
#[test]
fn test_consistent_variable_ordering() -> Result<(), Box<dyn std::error::Error>> {
    let system = EquationSystem::new(vec![
        "y + x".to_string(), // y is written before x
        "x + z".to_string(), // another subset of the variables
    ])?;

    // The ordering is alphabetical regardless of first appearance.
    assert_eq!(system.sorted_variables, &["x", "y", "z"]);

    // Inputs follow that ordering: [x, y, z]
    let results = system.eval(&vec![1.0, 2.0, 3.0])?;

    let expected = vec![
        3.0, // y(2.0) + x(1.0)
        4.0, // x(1.0) + z(3.0)
    ];
    assert_eq!(results.as_slice(), expected);

    Ok(())
}
804
805    #[test]
806    #[should_panic]
807    fn test_invalid_input_length() {
808        let system = EquationSystem::new(vec!["x + y".to_string(), "y + z".to_string()]).unwrap();
809
810        // Should panic: providing only 2 values when system needs 3 (x, y, z)
811        let _ = system.eval(&[1.0, 2.0]).unwrap();
812    }
813
814    #[test]
815    fn test_complex_expressions() -> Result<(), Box<dyn std::error::Error>> {
816        let expressions = vec![
817            "(x + y) * (x - y)".to_string(),     // difference of squares
818            "x^3 + y^2 * z".to_string(),         // polynomial
819            "(x + y + z) / (x + 1)".to_string(), // division
820        ];
821
822        let system = EquationSystem::new(expressions)?;
823        let results = system.eval(&[2.0, 3.0, 4.0])?;
824
825        assert_eq!(results[0], -5.0); // (2 + 3) * (2 - 3) = 5 * -1 = -5
826        assert_eq!(results[1], 44.0); // 2^3 + 3^2 * 4 = 8 + 9 * 4 = 44
827        assert_eq!(results[2], 3.0); // (2 + 3 + 4) / (2 + 1) = 9 / 3 = 3
828
829        Ok(())
830    }
831
832    #[test]
833    fn test_custom_variable_map() -> Result<(), Box<dyn std::error::Error>> {
834        let mut var_map = HashMap::new();
835        var_map.insert("alpha".to_string(), 1);
836        var_map.insert("beta".to_string(), 0);
837
838        let expressions = vec!["2*alpha + beta".to_string(), "alpha^2 - beta".to_string()];
839
840        let system = EquationSystem::from_var_map(expressions, &var_map)?;
841        let results = system.eval(&[2.0, 1.0])?;
842
843        assert_eq!(results.as_slice(), &[4.0, -1.0]);
844
845        Ok(())
846    }
847
848    #[test]
849    fn test_error_undefined_variable() {
850        let expressions = vec![
851            "x + y".to_string(),
852            "x + undefined_var".to_string(), // undefined variable
853        ];
854
855        let mut var_map = HashMap::new();
856        var_map.insert("x".to_string(), 0);
857        var_map.insert("y".to_string(), 1);
858
859        let result = EquationSystem::from_var_map(expressions, &var_map);
860        assert!(matches!(result, Err(EquationError::VariableNotFound(_))));
861    }
862
863    #[test]
864    fn test_empty_system() -> Result<(), Box<dyn std::error::Error>> {
865        let system = EquationSystem::new(vec![])?;
866        let results = system.eval(&[])?;
867        assert!(results.is_empty());
868        Ok(())
869    }
870
871    #[test]
872    fn test_derive_wrt() -> Result<(), Box<dyn std::error::Error>> {
873        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
874
875        // First order derivative
876        let dx = system.derive_wrt(&["x"]).unwrap();
877        let mut dx_results = vec![0.0, 0.0];
878        dx.eval_into(&[2.0, 3.0], &mut dx_results).unwrap();
879        assert_eq!(dx_results, vec![12.0, 9.0]); // d/dx[x^2*y] = 2xy, d/dx[x*y^2] = y^2
880
881        // Second order derivative
882        let dxy = system.derive_wrt(&["x", "y"]).unwrap();
883        let mut dxy_results = vec![0.0, 0.0];
884        dxy.eval_into(&[2.0, 3.0], &mut dxy_results).unwrap();
885        assert_eq!(dxy_results, vec![4.0, 6.0]); // d²/dxdy[x^2*y] = 2x, d²/dxdy[x*y^2] = 2y
886
887        Ok(())
888    }
889
890    #[test]
891    fn test_derive_wrt_invalid_variable() {
892        let system =
893            EquationSystem::new(vec!["2*x + y^2".to_string(), "x^2 + z".to_string()]).unwrap();
894
895        let result = system.derive_wrt(&["w"]);
896        assert!(matches!(result, Err(EquationError::VariableNotFound(_))));
897    }
898
899    #[test]
900    fn test_jacobian() -> Result<(), Box<dyn std::error::Error>> {
901        let system = EquationSystem::new(vec![
902            "x^2*y".to_string(), // f1
903            "x*y^2".to_string(), // f2
904        ])?;
905
906        let jacobian = system.eval_jacobian(&[2.0, 3.0], None)?;
907
908        // Jacobian matrix should be:
909        // [∂f1/∂x  ∂f1/∂y] = [12.0  4.0]   // derivatives of first equation
910        // [∂f2/∂x  ∂f2/∂y] = [9.0   12.0]  // derivatives of second equation
911
912        assert_eq!(jacobian.len(), 2); // Two rows (one per equation)
913        assert_eq!(jacobian[0], vec![12.0, 4.0]); // Derivatives of first equation
914        assert_eq!(jacobian[1], vec![9.0, 12.0]); // Derivatives of second equation
915
916        Ok(())
917    }
918
919    #[test]
920    fn test_jacobian_wrt() -> Result<(), Box<dyn std::error::Error>> {
921        let system = EquationSystem::new(vec![
922            "x^2*y + z".to_string(),   // f1
923            "x*y^2 - z^2".to_string(), // f2
924        ])?;
925
926        // Test subset of variables (x and y only)
927        let jacobian_fn = system.jacobian_wrt(&["x", "y"]).unwrap();
928        let mut results = Array2::zeros((2, 2));
929        jacobian_fn
930            .eval_into_matrix(&vec![2.0, 3.0, 1.0], &mut results)
931            .unwrap();
932
933        // Expected derivatives:
934        // ∂f1/∂x = 2xy = 2(2)(3) = 12
935        // ∂f1/∂y = x^2 = 4
936        // ∂f2/∂x = y^2 = 9
937        // ∂f2/∂y = 2xy = 2(2)(3) = 12
938        assert_eq!(results[[0, 0]], 12.0); // ∂f1/∂x
939        assert_eq!(results[[0, 1]], 4.0); // ∂f1/∂y
940        assert_eq!(results[[1, 0]], 9.0); // ∂f2/∂x
941        assert_eq!(results[[1, 1]], 12.0); // ∂f2/∂y
942
943        Ok(())
944    }
945
946    #[test]
947    fn test_jacobian_wrt_single_variable() -> Result<(), Box<dyn std::error::Error>> {
948        let system = EquationSystem::new(vec![
949            "x^2*y".to_string(), // f1
950            "x*y^2".to_string(), // f2
951        ])?;
952
953        // Test with single variable
954        let jacobian_fn = system.jacobian_wrt(&["x"])?;
955        let mut results = Array2::zeros((2, 1));
956        jacobian_fn
957            .eval_into_matrix(&vec![2.0, 3.0], &mut results)
958            .unwrap();
959
960        // Expected derivatives:
961        // ∂f1/∂x = 2xy = 2(2)(3) = 12
962        // ∂f2/∂x = y^2 = 9
963        assert_eq!(results[[0, 0]], 12.0); // [∂f1/∂x]
964        assert_eq!(results[[1, 0]], 9.0); // [∂f2/∂x]
965
966        Ok(())
967    }
968
969    #[test]
970    fn test_jacobian_wrt_all_variables() -> Result<(), Box<dyn std::error::Error>> {
971        let system = EquationSystem::new(vec![
972            "x^2*y + z".to_string(),   // f1
973            "x*y^2 - z^2".to_string(), // f2
974        ])?;
975
976        // Test with all variables
977        let jacobian_fn = system.jacobian_wrt(&["x", "y", "z"])?;
978        let mut results = Array2::zeros((2, 3));
979        jacobian_fn
980            .eval_into_matrix(&vec![2.0, 3.0, 1.0], &mut results)
981            .unwrap();
982
983        // Expected derivatives:
984        // ∂f1/∂x = 2xy = 2(2)(3) = 12
985        // ∂f1/∂y = x^2 = 4
986        // ∂f1/∂z = 1
987        // ∂f2/∂x = y^2 = 9
988        // ∂f2/∂y = 2xy = 2(2)(3) = 12
989        // ∂f2/∂z = -2z = -2
990        assert_eq!(results[[0, 0]], 12.0); // ∂f1/∂x
991        assert_eq!(results[[0, 1]], 4.0); // ∂f1/∂y
992        assert_eq!(results[[0, 2]], 1.0); // ∂f1/∂z
993        assert_eq!(results[[1, 0]], 9.0); // ∂f2/∂x
994        assert_eq!(results[[1, 1]], 12.0); // ∂f2/∂y
995        assert_eq!(results[[1, 2]], -2.0); // ∂f2/∂z
996
997        Ok(())
998    }
999
1000    #[test]
1001    fn test_jacobian_wrt_invalid_variable() {
1002        let system =
1003            EquationSystem::new(vec!["x^2*y + z".to_string(), "x*y^2 - z^2".to_string()]).unwrap();
1004
1005        // Test with non-existent variable
1006        let result = system.jacobian_wrt(&["x", "w"]);
1007        assert!(matches!(result, Err(EquationError::VariableNotFound(_))));
1008    }
1009
1010    #[test]
1011    fn test_jacobian_wrt_reuse_buffer() -> Result<(), Box<dyn std::error::Error>> {
1012        let system = EquationSystem::new(vec![
1013            "x^2*y".to_string(), // f1
1014            "x*y^2".to_string(), // f2
1015        ])?;
1016
1017        let jacobian_fn = system.jacobian_wrt(&["x", "y"])?;
1018        let mut results = Array2::zeros((2, 2));
1019
1020        // First evaluation
1021        jacobian_fn
1022            .eval_into_matrix(&vec![2.0, 3.0], &mut results)
1023            .unwrap();
1024        assert_eq!(results[[0, 0]], 12.0);
1025        assert_eq!(results[[0, 1]], 4.0);
1026        assert_eq!(results[[1, 0]], 9.0);
1027        assert_eq!(results[[1, 1]], 12.0);
1028
1029        // Reuse buffer for second evaluation
1030        jacobian_fn
1031            .eval_into_matrix(&vec![1.0, 2.0], &mut results)
1032            .unwrap();
1033        assert_eq!(results[[0, 0]], 4.0);
1034        assert_eq!(results[[0, 1]], 1.0);
1035        assert_eq!(results[[1, 0]], 4.0);
1036        assert_eq!(results[[1, 1]], 4.0);
1037
1038        Ok(())
1039    }
1040
1041    #[test]
1042    fn test_different_vector_types() -> Result<(), Box<dyn std::error::Error>> {
1043        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
1044
1045        // Test with Vec<f64>
1046        let vec_inputs = vec![2.0, 3.0];
1047        let vec_results = system.eval(&vec_inputs)?;
1048        assert_eq!(vec_results.as_slice(), &[12.0, 18.0]);
1049
1050        // Test with ndarray::Array1
1051        let ndarray_inputs = Array1::from_vec(vec![2.0, 3.0]);
1052        let ndarray_results = system.eval(&ndarray_inputs)?;
1053        assert_eq!(ndarray_results.as_slice().unwrap(), &[12.0, 18.0]);
1054
1055        // Test with nalgebra::DVector
1056        let nalgebra_inputs = DVector::from_vec(vec![2.0, 3.0]);
1057        let nalgebra_results = system.eval(&nalgebra_inputs)?;
1058        assert_eq!(nalgebra_results.as_slice(), &[12.0, 18.0]);
1059
1060        Ok(())
1061    }
1062
1063    #[test]
1064    fn test_eval_parallel() -> Result<(), Box<dyn std::error::Error>> {
1065        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
1066
1067        let input_sets = vec![
1068            vec![2.0, 3.0],
1069            vec![1.0, 2.0],
1070            vec![3.0, 4.0],
1071            vec![0.0, 1.0],
1072        ];
1073
1074        let results = system.eval_parallel(&input_sets)?;
1075
1076        assert_eq!(results[0].as_slice(), &[12.0, 18.0]); // [2^2 * 3, 2 * 3^2]
1077        assert_eq!(results[1].as_slice(), &[2.0, 4.0]); // [1^2 * 2, 1 * 2^2]
1078        assert_eq!(results[2].as_slice(), &[36.0, 48.0]); // [3^2 * 4, 3 * 4^2]
1079        assert_eq!(results[3].as_slice(), &[0.0, 0.0]); // [0^2 * 1, 0 * 1^2]
1080
1081        Ok(())
1082    }
1083
1084    #[test]
1085    fn test_eval_into() -> Result<(), Box<dyn std::error::Error>> {
1086        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
1087
1088        // Test with Vec<f64>
1089        let mut results = vec![0.0; 2];
1090        system.eval_into(&vec![2.0, 3.0], &mut results)?;
1091        assert_eq!(results, vec![12.0, 18.0]);
1092
1093        // Test with ndarray
1094        let mut ndarray_results = Array1::zeros(2);
1095        system.eval_into(&Array1::from_vec(vec![2.0, 3.0]), &mut ndarray_results)?;
1096        assert_eq!(ndarray_results.as_slice().unwrap(), &[12.0, 18.0]);
1097
1098        // Test error case: wrong buffer size
1099        let mut wrong_size = vec![0.0; 3];
1100        assert!(matches!(
1101            system.eval_into(&vec![2.0, 3.0], &mut wrong_size),
1102            Err(EquationError::InvalidInputLength { .. })
1103        ));
1104
1105        Ok(())
1106    }
1107
1108    #[test]
1109    fn test_matrix_output_errors() -> Result<(), Box<dyn std::error::Error>> {
1110        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
1111
1112        // Regular vector system should error when trying to output as matrix
1113        let mut results = Array2::zeros((2, 2));
1114        assert!(matches!(
1115            system.eval_into_matrix(&vec![2.0, 3.0], &mut results),
1116            Err(EquationError::MatrixOutputRequired)
1117        ));
1118
1119        assert!(matches!(
1120            system.eval_matrix::<_, Array2<f64>>(&vec![2.0, 3.0]),
1121            Err(EquationError::MatrixOutputRequired)
1122        ));
1123
1124        Ok(())
1125    }
1126
1127    #[test]
1128    fn test_gradient() -> Result<(), Box<dyn std::error::Error>> {
1129        let system = EquationSystem::new(vec![
1130            "x^2*y + z".to_string(),   // f1
1131            "x*y^2 - z^2".to_string(), // f2
1132        ])?;
1133
1134        // Test gradient with respect to x
1135        let dx = system.gradient(&[2.0, 3.0, 1.0], "x")?;
1136        assert_eq!(dx, vec![12.0, 9.0]); // [∂f1/∂x = 2xy, ∂f2/∂x = y^2]
1137
1138        // Test gradient with respect to y
1139        let dy = system.gradient(&[2.0, 3.0, 1.0], "y")?;
1140        assert_eq!(dy, vec![4.0, 12.0]); // [∂f1/∂y = x^2, ∂f2/∂y = 2xy]
1141
1142        // Test gradient with respect to z
1143        let dz = system.gradient(&[2.0, 3.0, 1.0], "z")?;
1144        assert_eq!(dz, vec![1.0, -2.0]); // [∂f1/∂z = 1, ∂f2/∂z = -2z]
1145
1146        // Test error case: invalid variable
1147        let result = system.gradient(&[2.0, 3.0, 1.0], "w");
1148        assert!(matches!(result, Err(EquationError::VariableNotFound(_))));
1149
1150        // Test error case: wrong input length
1151        let result = system.gradient(&[2.0, 3.0], "x");
1152        assert!(matches!(
1153            result,
1154            Err(EquationError::InvalidInputLength { .. })
1155        ));
1156
1157        Ok(())
1158    }
1159
1160    #[test]
1161    fn test_eval_matrix_on_vector_system() -> Result<(), Box<dyn std::error::Error>> {
1162        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
1163
1164        // Attempt to evaluate as matrix should fail
1165        let mut results = Array2::zeros((2, 2));
1166        assert!(matches!(
1167            system.eval_into_matrix(&vec![2.0, 3.0], &mut results),
1168            Err(EquationError::MatrixOutputRequired)
1169        ));
1170
1171        // Direct matrix evaluation should also fail
1172        assert!(matches!(
1173            system.eval_matrix::<_, Array2<f64>>(&vec![2.0, 3.0]),
1174            Err(EquationError::MatrixOutputRequired)
1175        ));
1176
1177        Ok(())
1178    }
1179
1180    #[test]
1181    fn test_getters() -> Result<(), Box<dyn std::error::Error>> {
1182        let system = EquationSystem::new(vec!["x^2*y".to_string(), "x*y^2".to_string()])?;
1183
1184        // Test sorted_variables()
1185        assert_eq!(system.sorted_variables(), &["x", "y"]);
1186
1187        // Test variables()
1188        let var_map = system.variables();
1189        assert_eq!(var_map.get("x"), Some(&0));
1190        assert_eq!(var_map.get("y"), Some(&1));
1191
1192        // Test equations()
1193        assert_eq!(system.equations(), &["x^2*y", "x*y^2"]);
1194
1195        // Test fun() returns a valid function
1196        let fun = system.fun();
1197        let mut results = vec![0.0; 2];
1198        fun(&[2.0, 3.0], &mut results);
1199        assert_eq!(results, vec![12.0, 18.0]);
1200
1201        // Test jacobian_funs() returns valid derivative functions
1202        let jacobian_funs = system.jacobian_funs();
1203        assert!(jacobian_funs.contains_key("x"));
1204        assert!(jacobian_funs.contains_key("y"));
1205
1206        // Test gradient_fun() returns valid derivative function
1207        let dx_fun = system.gradient_fun("x");
1208        let mut dx_results = vec![0.0; 2];
1209        dx_fun(&[2.0, 3.0], &mut dx_results);
1210        assert_eq!(dx_results, vec![12.0, 9.0]); // [∂(x^2*y)/∂x, ∂(x*y^2)/∂x]
1211
1212        // Test num_equations()
1213        assert_eq!(system.num_equations(), 2);
1214
1215        Ok(())
1216    }
1217}