Skip to main content

quantrs2_symengine_pure/scirs2_bridge/
ndarray.rs

1//! Ndarray integration with SciRS2.
2//!
3//! This module provides conversion between symbolic matrices and
4//! SciRS2's ndarray types.
5
6use std::fmt::Write;
7
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::Complex64;
10
11use crate::error::{SymEngineError, SymEngineResult};
12use crate::expr::Expression;
13
14/// Parse a single matrix cell string (as produced by `from_array2`) into a `Complex64`.
15///
16/// Recognises the three forms emitted by `from_array2`:
17/// - `"{re}"` — pure real
18/// - `"{im}*I"` — pure imaginary
19/// - `"({re}+{im}*I)"` — general (note: negative imaginary looks like `(1+-2*I)`)
20fn parse_cell(s: &str) -> Result<Complex64, SymEngineError> {
21    let s = s.trim();
22
23    // Strip optional outer parentheses — the general format is `({re}+{im}*I)`.
24    let s = if s.starts_with('(') && s.ends_with(')') {
25        &s[1..s.len() - 1]
26    } else {
27        s
28    };
29
30    // Determine form by checking whether the string ends with "*I".
31    if let Some(without_i) = s.strip_suffix("*I") {
32        // Could be:
33        //   pure imaginary  "{im}*I"        → without_i has no '+' (except in exponent)
34        //   general complex "{re}+{im}*I"   → without_i contains a '+' split point
35        if let Some(plus_pos) = find_split_plus(without_i) {
36            // General complex: re = without_i[..plus_pos], im = without_i[plus_pos+1..]
37            let re_str = &without_i[..plus_pos];
38            let im_str = &without_i[plus_pos + 1..];
39            let re = re_str
40                .trim()
41                .parse::<f64>()
42                .map_err(|_| SymEngineError::parse(format!("cannot parse real part: {re_str}")))?;
43            let im = im_str.trim().parse::<f64>().map_err(|_| {
44                SymEngineError::parse(format!("cannot parse imaginary coefficient: {im_str}"))
45            })?;
46            return Ok(Complex64::new(re, im));
47        }
48        // Pure imaginary: no '+' separator found
49        let im = without_i.trim().parse::<f64>().map_err(|_| {
50            SymEngineError::parse(format!("cannot parse imaginary coefficient: {without_i}"))
51        })?;
52        return Ok(Complex64::new(0.0, im));
53    }
54
55    // Pure real fallback
56    let re = s
57        .parse::<f64>()
58        .map_err(|_| SymEngineError::parse(format!("cannot parse cell value: {s}")))?;
59    Ok(Complex64::new(re, 0.0))
60}
61
/// Find the byte index of the `+` that separates a real part from an imaginary part.
///
/// Scans left to right and returns the first `+` that
/// - is not at index 0 (a leading `+` is the sign of the real part, not a separator), and
/// - is not immediately preceded by `e` or `E`, so the exponent sign in
///   scientific notation such as `1e+10` is not mistaken for a separator.
///
/// Nothing else about the preceding text is checked; for example `inf+1`
/// splits right after `inf`. Returns `None` if no such `+` exists.
fn find_split_plus(s: &str) -> Option<usize> {
    let bytes = s.as_bytes();
    // Starting at 1 both skips a leading sign and makes `bytes[i - 1]` safe.
    (1..bytes.len()).find(|&i| bytes[i] == b'+' && !matches!(bytes[i - 1], b'e' | b'E'))
}
82
83/// Parse the string representation of a symbolic matrix expression into a
84/// `Vec<Vec<Complex64>>` row-major matrix.
85///
86/// Accepts the format produced by [`from_array2`] / [`from_array1`]:
87///
88/// ```text
89/// Matrix([[cell, cell, ...], [cell, ...], ...])
90/// ```
91fn parse_matrix_expr(expr: &Expression) -> SymEngineResult<Vec<Vec<Complex64>>> {
92    let raw = expr
93        .as_symbol()
94        .ok_or_else(|| SymEngineError::parse("expression is not a matrix symbol"))?;
95
96    // Strip optional "Matrix(" prefix and matching ")"
97    let inner = if raw.starts_with("Matrix(") && raw.ends_with(')') {
98        &raw["Matrix(".len()..raw.len() - 1]
99    } else {
100        raw
101    };
102
103    // Expect outer "[...]"
104    let inner = inner.trim();
105    if !inner.starts_with('[') || !inner.ends_with(']') {
106        return Err(SymEngineError::parse(format!(
107            "expected outer '[...]' in matrix expression, got: {inner}"
108        )));
109    }
110    let inner = &inner[1..inner.len() - 1];
111
112    // Split into row strings by scanning bracket nesting
113    let rows_strs = split_rows(inner);
114
115    let mut rows: Vec<Vec<Complex64>> = Vec::with_capacity(rows_strs.len());
116    for row_str in rows_strs {
117        let row_str = row_str.trim();
118        if !row_str.starts_with('[') || !row_str.ends_with(']') {
119            return Err(SymEngineError::parse(format!(
120                "expected row '[...]', got: {row_str}"
121            )));
122        }
123        let cells_str = &row_str[1..row_str.len() - 1];
124        let cells = split_cells(cells_str);
125        let row: Vec<Complex64> = cells
126            .iter()
127            .map(|c| parse_cell(c.trim()))
128            .collect::<Result<_, _>>()?;
129        rows.push(row);
130    }
131
132    Ok(rows)
133}
134
/// Break the body of the outer `[...]` into its top-level `[row]` spans.
///
/// A running bracket-nesting counter decides where each row begins (a `[`
/// seen at nesting zero) and ends (the `]` that returns nesting to zero), so
/// brackets nested inside a row stay part of that row. Text outside any
/// bracketed span and an unterminated trailing `[` are not returned.
///
/// A `]` seen at nesting zero leaves the counter at zero and emits the span
/// from the most recent row start through that `]`. Malformed input therefore
/// still shows up as a (malformed) row.
fn split_rows(s: &str) -> Vec<&str> {
    let mut rows = Vec::new();
    let mut nesting = 0usize;
    let mut row_start = 0usize;

    for (pos, byte) in s.bytes().enumerate() {
        if byte == b'[' {
            if nesting == 0 {
                row_start = pos;
            }
            nesting += 1;
        } else if byte == b']' {
            nesting = nesting.saturating_sub(1);
            if nesting == 0 {
                rows.push(&s[row_start..=pos]);
            }
        }
    }

    rows
}
164
/// Split a flat cell list (the text between `[` and `]` of one row) on
/// top-level commas.
///
/// Commas nested inside parentheses are not split points, so a general complex
/// cell such as `(1+-2*I)` stays in one piece. Pieces are returned untrimmed.
///
/// A blank row (empty or whitespace-only, as `from_array2` writes `[]` for an
/// array with zero columns) yields no cells, not one empty cell. A trailing
/// comma still yields a trailing empty piece.
fn split_cells(s: &str) -> Vec<&str> {
    // Without this, `[]` would produce a single "" cell that can never parse,
    // breaking round-trips of arrays with zero columns.
    if s.trim().is_empty() {
        return Vec::new();
    }

    let mut parts = Vec::new();
    let mut depth: usize = 0;
    let mut start: usize = 0;

    for (i, b) in s.bytes().enumerate() {
        match b {
            b'(' => depth += 1,
            b')' => depth = depth.saturating_sub(1),
            b',' if depth == 0 => {
                parts.push(&s[start..i]);
                start = i + 1;
            }
            _ => {}
        }
    }
    // Push the final segment
    parts.push(&s[start..]);
    parts
}
188
189/// Convert a symbolic matrix expression to a numeric `Array2<Complex64>`.
190///
191/// The expression is expected to be in the format produced by [`from_array2`],
192/// i.e. `Matrix([[cell, cell, ...], [cell, ...], ...])`.
193///
194/// The `values` map is accepted for API uniformity but the matrix representation
195/// already contains fully evaluated numeric cells; symbolic cells are not currently
196/// supported.
197///
198/// # Errors
199/// Returns an error if the expression is not a matrix symbol or cell parsing fails.
200pub fn to_array2(
201    expr: &Expression,
202    _values: &std::collections::HashMap<String, f64>,
203) -> SymEngineResult<Array2<Complex64>> {
204    let rows = parse_matrix_expr(expr)?;
205
206    if rows.is_empty() {
207        return Ok(Array2::zeros((0, 0)));
208    }
209
210    let nrows = rows.len();
211    let ncols = rows[0].len();
212
213    // Validate uniform column count
214    for (i, row) in rows.iter().enumerate() {
215        if row.len() != ncols {
216            return Err(SymEngineError::dimension(format!(
217                "row {i} has {} columns, expected {ncols}",
218                row.len()
219            )));
220        }
221    }
222
223    let flat: Vec<Complex64> = rows.into_iter().flatten().collect();
224    Array2::from_shape_vec((nrows, ncols), flat)
225        .map_err(|e| SymEngineError::dimension(e.to_string()))
226}
227
228/// Convert a numeric `Array2<Complex64>` to a symbolic matrix expression.
229pub fn from_array2(arr: &Array2<Complex64>) -> Expression {
230    let (rows, cols) = arr.dim();
231
232    let mut matrix_str = String::from("Matrix([");
233
234    for i in 0..rows {
235        matrix_str.push('[');
236        for j in 0..cols {
237            let c = arr[[i, j]];
238            if c.im.abs() < 1e-15 {
239                let _ = write!(matrix_str, "{}", c.re);
240            } else if c.re.abs() < 1e-15 {
241                let _ = write!(matrix_str, "{}*I", c.im);
242            } else {
243                let _ = write!(matrix_str, "({}+{}*I)", c.re, c.im);
244            }
245            if j < cols - 1 {
246                matrix_str.push_str(", ");
247            }
248        }
249        matrix_str.push(']');
250        if i < rows - 1 {
251            matrix_str.push_str(", ");
252        }
253    }
254
255    matrix_str.push_str("])");
256
257    Expression::new(matrix_str)
258}
259
260/// Convert a symbolic vector expression to a numeric `Array1<Complex64>`.
261///
262/// The expression is expected to be in the format produced by [`from_array1`],
263/// i.e. a column-vector matrix `Matrix([[c1], [c2], ...])`.  Each row must
264/// contain exactly one cell.
265///
266/// The `values` map is accepted for API uniformity (see [`to_array2`]).
267///
268/// # Errors
269/// Returns an error if the expression is not a matrix symbol or cell parsing fails.
270pub fn to_array1(
271    expr: &Expression,
272    _values: &std::collections::HashMap<String, f64>,
273) -> SymEngineResult<Array1<Complex64>> {
274    let rows = parse_matrix_expr(expr)?;
275
276    let flat: Vec<Complex64> = rows
277        .into_iter()
278        .enumerate()
279        .map(|(i, row)| {
280            if row.len() == 1 {
281                Ok(row[0])
282            } else {
283                Err(SymEngineError::dimension(format!(
284                    "row {i} has {} cells; expected 1 for Array1 conversion",
285                    row.len()
286                )))
287            }
288        })
289        .collect::<Result<_, _>>()?;
290
291    Ok(Array1::from_vec(flat))
292}
293
294/// Convert a numeric `Array1<Complex64>` to a symbolic column vector expression.
295pub fn from_array1(arr: &Array1<Complex64>) -> Expression {
296    let n = arr.len();
297
298    let mut matrix_str = String::from("Matrix([");
299
300    for (i, c) in arr.iter().enumerate() {
301        matrix_str.push('[');
302        if c.im.abs() < 1e-15 {
303            let _ = write!(matrix_str, "{}", c.re);
304        } else if c.re.abs() < 1e-15 {
305            let _ = write!(matrix_str, "{}*I", c.im);
306        } else {
307            let _ = write!(matrix_str, "({}+{}*I)", c.re, c.im);
308        }
309        matrix_str.push(']');
310        if i < n - 1 {
311            matrix_str.push_str(", ");
312        }
313    }
314
315    matrix_str.push_str("])");
316
317    Expression::new(matrix_str)
318}
319
320/// Compute the gradient at given values as an `Array1<f64>`.
321///
322/// This is useful for integration with SciRS2 optimization routines.
323pub fn gradient_array(
324    expr: &Expression,
325    params: &[Expression],
326    values: &std::collections::HashMap<String, f64>,
327) -> SymEngineResult<Array1<f64>> {
328    let grad_vec = crate::optimization::gradient_at(expr, params, values)?;
329    Ok(Array1::from_vec(grad_vec))
330}
331
332/// Compute the Hessian at given values as an `Array2<f64>`.
333///
334/// This is useful for integration with SciRS2 optimization routines.
335pub fn hessian_array(
336    expr: &Expression,
337    params: &[Expression],
338    values: &std::collections::HashMap<String, f64>,
339) -> SymEngineResult<Array2<f64>> {
340    let hess_vec = crate::optimization::hessian_at(expr, params, values)?;
341    let n = params.len();
342    let mut arr = Array2::zeros((n, n));
343
344    for (i, row) in hess_vec.iter().enumerate() {
345        for (j, &val) in row.iter().enumerate() {
346            arr[[i, j]] = val;
347        }
348    }
349
350    Ok(arr)
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashMap;

    /// Helper: build a values map (empty, since our matrices are fully numeric).
    fn no_values() -> HashMap<String, f64> {
        HashMap::new()
    }

    /// Helper: the raw matrix text carried by a matrix expression, if any.
    fn raw_text(expr: &Expression) -> Option<String> {
        expr.as_symbol().map(|s| s.to_string())
    }

    #[test]
    fn test_from_array2() {
        let arr: Array2<Complex64> = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
            [Complex64::new(0.0, -1.0), Complex64::new(1.0, 0.0)],
        ];

        let expr = from_array2(&arr);
        // Matrix expressions are stored as symbolic strings
        assert!(expr.to_string().contains("Matrix"));
        // Pin the exact encoding: real-only, pure-imaginary, and row layout.
        assert_eq!(
            raw_text(&expr).as_deref(),
            Some("Matrix([[1, 1*I], [-1*I, 1]])")
        );
    }

    #[test]
    fn test_from_array1() {
        let arr: Array1<Complex64> = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0),];

        let expr = from_array1(&arr);
        // Vector expressions are stored as symbolic matrix strings
        assert!(expr.to_string().contains("Matrix"));
        // Each element is its own single-cell row.
        assert_eq!(raw_text(&expr).as_deref(), Some("Matrix([[1], [1*I]])"));
    }

    #[test]
    fn test_gradient_array() {
        let x = Expression::symbol("x");
        let expr = x.clone() * x.clone(); // x^2
        let params = vec![x];

        let mut values = std::collections::HashMap::new();
        values.insert("x".to_string(), 3.0);

        let grad = gradient_array(&expr, &params, &values).expect("should compute");
        assert!((grad[0] - 6.0).abs() < 1e-6); // d/dx(x^2) = 2x = 6 at x=3
    }

    // =========================================================================
    // to_array1 / to_array2 round-trip tests
    // =========================================================================

    #[test]
    fn test_to_array1_real() {
        // Build a column-vector expression via from_array1 then round-trip through to_array1.
        let src: Array1<Complex64> = array![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
        ];
        let expr = from_array1(&src);
        let arr = to_array1(&expr, &no_values()).expect("to_array1 should succeed");
        assert_eq!(arr.len(), 3);
        assert!((arr[0].re - 1.0).abs() < 1e-10);
        assert!((arr[1].re - 2.0).abs() < 1e-10);
        assert!((arr[2].re - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array1_complex() {
        let src: Array1<Complex64> = array![
            Complex64::new(1.0, 2.0),
            Complex64::new(0.0, 3.0),
            Complex64::new(4.0, 0.0),
        ];
        let expr = from_array1(&src);
        let arr = to_array1(&expr, &no_values()).expect("to_array1 complex should succeed");
        assert_eq!(arr.len(), 3);
        assert!((arr[0].re - 1.0).abs() < 1e-10);
        assert!((arr[0].im - 2.0).abs() < 1e-10);
        assert!((arr[1].re - 0.0).abs() < 1e-10);
        assert!((arr[1].im - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array1_empty_round_trip() {
        let src: Array1<Complex64> = Array1::from_vec(Vec::new());
        let expr = from_array1(&src);
        let arr = to_array1(&expr, &no_values()).expect("empty to_array1 should succeed");
        assert_eq!(arr.len(), 0);
    }

    #[test]
    fn test_to_array1_rejects_row_matrix() {
        // A 1 x 2 matrix has a row with two cells, which is not a column vector.
        let src: Array2<Complex64> = array![[Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)]];
        let expr = from_array2(&src);
        assert!(to_array1(&expr, &no_values()).is_err());
    }

    #[test]
    fn test_to_array2_2x2_real() {
        // Round-trip: from_array2 → Expression → to_array2
        let src: Array2<Complex64> = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("to_array2 should succeed");
        assert_eq!(arr.shape(), &[2, 2]);
        assert!((arr[[0, 0]].re - 1.0).abs() < 1e-10);
        assert!((arr[[0, 1]].re - 2.0).abs() < 1e-10);
        assert!((arr[[1, 0]].re - 3.0).abs() < 1e-10);
        assert!((arr[[1, 1]].re - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_2x2_complex() {
        let src: Array2<Complex64> = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
            [Complex64::new(0.0, -1.0), Complex64::new(1.0, 0.0)],
        ];
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("to_array2 complex should succeed");
        assert_eq!(arr.shape(), &[2, 2]);
        // (0,1) should be pure imaginary 0+1i
        assert!((arr[[0, 1]].re - 0.0).abs() < 1e-10);
        assert!((arr[[0, 1]].im - 1.0).abs() < 1e-10);
        // (1,0) should be pure imaginary 0-1i
        assert!((arr[[1, 0]].re - 0.0).abs() < 1e-10);
        assert!((arr[[1, 0]].im - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_general_complex() {
        let src: Array2<Complex64> = array![[Complex64::new(3.0, 4.0)]];
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("to_array2 general complex should succeed");
        assert_eq!(arr.shape(), &[1, 1]);
        assert!((arr[[0, 0]].re - 3.0).abs() < 1e-10);
        assert!((arr[[0, 0]].im - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_negative_imaginary() {
        // Negative imaginary: from_array2 emits "(2+-3*I)" for Complex(2, -3)
        let src: Array2<Complex64> = array![[Complex64::new(2.0, -3.0)]];
        let expr = from_array2(&src);
        assert_eq!(raw_text(&expr).as_deref(), Some("Matrix([[(2+-3*I)]])"));
        let arr =
            to_array2(&expr, &no_values()).expect("to_array2 negative imaginary should succeed");
        assert_eq!(arr.shape(), &[1, 1]);
        assert!((arr[[0, 0]].re - 2.0).abs() < 1e-10);
        assert!((arr[[0, 0]].im - (-3.0)).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_ragged_rows_rejected() {
        // Rows of different lengths cannot form a rectangular Array2.
        let expr = Expression::new("Matrix([[1, 2], [3]])".to_string());
        assert!(to_array2(&expr, &no_values()).is_err());
    }
}