use std::fmt::Write;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::Complex64;
use crate::error::{SymEngineError, SymEngineResult};
use crate::expr::Expression;
/// Parses one matrix cell into a complex number.
///
/// Surrounding whitespace and one pair of outer parentheses (plus any
/// whitespace just inside them) are ignored. Accepted forms:
/// - a real literal: `3`, `-1.5`, `1e-3`;
/// - a pure imaginary term: `2*I`, `-0.5*I`, or a bare `I` / `-I`;
/// - a real part plus or minus an imaginary term: `1+2*I`, `1+-2*I`,
///   `1 - 2*I`, `1 + I`.
///
/// # Errors
///
/// Returns a parse error if the real part, the imaginary coefficient, or the
/// whole cell is not a valid `f64` literal.
fn parse_cell(s: &str) -> Result<Complex64, SymEngineError> {
    /// Byte index of the first binary `-`: a leading sign (index 0) and
    /// exponent signs such as the one in `1e-5` are skipped.
    fn find_split_minus(s: &str) -> Option<usize> {
        let bytes = s.as_bytes();
        (1..bytes.len()).find(|&i| bytes[i] == b'-' && !matches!(bytes[i - 1], b'e' | b'E'))
    }

    let s = s.trim();
    let s = match s.strip_prefix('(').and_then(|t| t.strip_suffix(')')) {
        Some(inner) => inner.trim(),
        None => s,
    };

    if let Some(head) = s.strip_suffix('I') {
        let head = head.trim_end();
        // `b*I` carries an explicit coefficient. A bare `I` (alone or right
        // after a sign) means a coefficient of one. Anything else, such as
        // `2I`, is rejected rather than misread.
        let coeff = if let Some(explicit) = head.strip_suffix('*') {
            explicit.trim_end().to_owned()
        } else if head.is_empty() || head.ends_with('+') || head.ends_with('-') {
            format!("{head}1")
        } else {
            return Err(SymEngineError::parse(format!("cannot parse cell value: {s}")));
        };
        let coeff = coeff.as_str();

        // A binary '+' wins over a binary '-', so `a+-b` keeps the sign on b.
        // With a binary '-', the imaginary coefficient is negated.
        let split = find_split_plus(coeff)
            .map(|pos| (pos, false))
            .or_else(|| find_split_minus(coeff).map(|pos| (pos, true)));
        if let Some((pos, negate)) = split {
            let re_str = &coeff[..pos];
            let im_str = &coeff[pos + 1..];
            let re = re_str
                .trim()
                .parse::<f64>()
                .map_err(|_| SymEngineError::parse(format!("cannot parse real part: {re_str}")))?;
            let im = im_str.trim().parse::<f64>().map_err(|_| {
                SymEngineError::parse(format!("cannot parse imaginary coefficient: {im_str}"))
            })?;
            return Ok(Complex64::new(re, if negate { -im } else { im }));
        }

        let im = coeff.parse::<f64>().map_err(|_| {
            SymEngineError::parse(format!("cannot parse imaginary coefficient: {coeff}"))
        })?;
        return Ok(Complex64::new(0.0, im));
    }

    s.parse::<f64>()
        .map(|re| Complex64::new(re, 0.0))
        .map_err(|_| SymEngineError::parse(format!("cannot parse cell value: {s}")))
}
/// Returns the byte index of the first `+` that separates two terms.
///
/// A `+` at index 0 is treated as a leading sign and never returned. A `+`
/// directly after `e`/`E` is treated as an exponent sign (as in `1e+5`) and
/// skipped. Returns `None` when no separating `+` exists.
fn find_split_plus(s: &str) -> Option<usize> {
    s.as_bytes()
        .windows(2)
        .position(|pair| pair[1] == b'+' && !matches!(pair[0], b'e' | b'E'))
        // `position` indexes the window; the '+' is its second byte.
        .map(|window_start| window_start + 1)
}
/// Parses a matrix symbol such as `Matrix([[1, 2*I], [(3+4*I), 5]])` into
/// rows of complex cells.
///
/// The `Matrix(...)` wrapper is optional. The remaining text must be an outer
/// `[...]` list whose entries are `[...]` rows separated by commas and/or
/// whitespace. An empty row `[]` yields a row with zero cells, and
/// `Matrix([])` yields no rows.
///
/// # Errors
///
/// Returns a parse error in any of these cases:
/// - `expr` is not a symbol;
/// - the outer list or a row is not bracketed;
/// - the outer list contains anything other than balanced rows (for example
///   the flat form `Matrix([1, 2, 3])`, or an unclosed row);
/// - any cell cannot be parsed.
fn parse_matrix_expr(expr: &Expression) -> SymEngineResult<Vec<Vec<Complex64>>> {
    let raw = expr
        .as_symbol()
        .ok_or_else(|| SymEngineError::parse("expression is not a matrix symbol"))?;
    let inner = if raw.starts_with("Matrix(") && raw.ends_with(')') {
        &raw["Matrix(".len()..raw.len() - 1]
    } else {
        raw
    };
    let inner = inner.trim();
    if !inner.starts_with('[') || !inner.ends_with(']') {
        return Err(SymEngineError::parse(format!(
            "expected outer '[...]' in matrix expression, got: {inner}"
        )));
    }
    let inner = &inner[1..inner.len() - 1];
    // `split_rows` ignores text outside brackets and drops rows that never
    // close. Check the layout first so malformed input is an error rather
    // than a silently empty or truncated matrix.
    validate_row_list(inner)?;
    let rows_strs = split_rows(inner);
    let mut rows: Vec<Vec<Complex64>> = Vec::with_capacity(rows_strs.len());
    for row_str in rows_strs {
        let row_str = row_str.trim();
        if !row_str.starts_with('[') || !row_str.ends_with(']') {
            return Err(SymEngineError::parse(format!(
                "expected row '[...]', got: {row_str}"
            )));
        }
        let cells_str = &row_str[1..row_str.len() - 1];
        // Splitting "" yields one empty cell that would fail to parse; an
        // empty row `[]` simply has no cells (e.g. an n x 0 array).
        let row: Vec<Complex64> = if cells_str.trim().is_empty() {
            Vec::new()
        } else {
            split_cells(cells_str)
                .iter()
                .map(|c| parse_cell(c.trim()))
                .collect::<Result<_, _>>()?
        };
        rows.push(row);
    }
    Ok(rows)
}

/// Checks that `s` (the text inside the outer `[...]`) consists only of
/// balanced `[...]` groups separated by commas and/or whitespace.
fn validate_row_list(s: &str) -> SymEngineResult<()> {
    let mut depth: usize = 0;
    for ch in s.chars() {
        match ch {
            '[' => depth += 1,
            ']' => {
                depth = depth.checked_sub(1).ok_or_else(|| {
                    SymEngineError::parse(format!("unbalanced ']' in matrix rows: {s}"))
                })?;
            }
            c if depth == 0 && c != ',' && !c.is_whitespace() => {
                return Err(SymEngineError::parse(format!(
                    "unexpected '{c}' outside a row in matrix expression: {s}"
                )));
            }
            _ => {}
        }
    }
    if depth != 0 {
        return Err(SymEngineError::parse(format!(
            "unbalanced '[' in matrix rows: {s}"
        )));
    }
    Ok(())
}
/// Splits `s` into its top-level `[...]` groups, each returned with its
/// brackets.
///
/// Nested brackets stay inside their enclosing group. Text outside any group
/// is ignored, and a group that is never closed is not returned.
fn split_rows(s: &str) -> Vec<&str> {
    let mut groups = Vec::new();
    let mut level = 0usize;
    let mut open_at = 0usize;
    for (pos, ch) in s.char_indices() {
        if ch == '[' {
            if level == 0 {
                open_at = pos;
            }
            level += 1;
        } else if ch == ']' {
            // A stray ']' at level 0 does not underflow; it closes a group
            // starting at the last recorded top-level '['.
            level = level.saturating_sub(1);
            if level == 0 {
                groups.push(&s[open_at..pos + 1]);
            }
        }
    }
    groups
}
/// Splits a row body on commas that are not inside parentheses.
///
/// The pieces are returned untrimmed. Like `str::split`, an empty input
/// yields a single empty piece, and a trailing comma yields a trailing empty
/// piece.
fn split_cells(s: &str) -> Vec<&str> {
    let mut paren_level = 0usize;
    let top_level_commas: Vec<usize> = s
        .char_indices()
        .filter_map(|(pos, ch)| match ch {
            '(' => {
                paren_level += 1;
                None
            }
            ')' => {
                paren_level = paren_level.saturating_sub(1);
                None
            }
            ',' if paren_level == 0 => Some(pos),
            _ => None,
        })
        .collect();

    let mut cells = Vec::with_capacity(top_level_commas.len() + 1);
    let mut from = 0;
    for comma in top_level_commas {
        cells.push(&s[from..comma]);
        from = comma + 1;
    }
    cells.push(&s[from..]);
    cells
}
/// Converts a `Matrix([[...], ...])` symbol into a dense 2-D complex array.
///
/// `_values` is accepted but not consulted: every cell must already be a
/// numeric literal. An expression with no rows yields a `0 x 0` array.
///
/// # Errors
///
/// Returns a parse error if the expression is not a well-formed matrix
/// symbol. Returns a dimension error if the rows do not all have the same
/// number of cells.
pub fn to_array2(
    expr: &Expression,
    _values: &std::collections::HashMap<String, f64>,
) -> SymEngineResult<Array2<Complex64>> {
    let rows = parse_matrix_expr(expr)?;
    let ncols = match rows.first() {
        Some(first_row) => first_row.len(),
        None => return Ok(Array2::zeros((0, 0))),
    };
    // Report the first row whose width differs from the first row's width.
    if let Some((i, bad_row)) = rows
        .iter()
        .enumerate()
        .find(|(_, row)| row.len() != ncols)
    {
        return Err(SymEngineError::dimension(format!(
            "row {i} has {} columns, expected {ncols}",
            bad_row.len()
        )));
    }
    let shape = (rows.len(), ncols);
    Array2::from_shape_vec(shape, rows.concat())
        .map_err(|e| SymEngineError::dimension(e.to_string()))
}
/// Renders a 2-D complex array as a `Matrix([[...], ...])` symbol.
///
/// Each cell is written as follows:
/// - as a plain real literal when `|im| < 1e-15`;
/// - as `b*I` when `|re| < 1e-15`;
/// - as `(a+b*I)` otherwise.
///
/// Components below that threshold are therefore dropped. An array with no
/// rows renders as `Matrix([])`, so its column count is not preserved.
pub fn from_array2(arr: &Array2<Complex64>) -> Expression {
    let mut text = String::from("Matrix([");
    for (r, row) in arr.outer_iter().enumerate() {
        if r > 0 {
            text.push_str(", ");
        }
        text.push('[');
        for (k, z) in row.iter().enumerate() {
            if k > 0 {
                text.push_str(", ");
            }
            // Writing into a String cannot fail; the fmt::Result is ignored.
            let _ = match (z.im.abs() < 1e-15, z.re.abs() < 1e-15) {
                (true, _) => write!(text, "{}", z.re),
                (false, true) => write!(text, "{}*I", z.im),
                (false, false) => write!(text, "({}+{}*I)", z.re, z.im),
            };
        }
        text.push(']');
    }
    text.push_str("])");
    Expression::new(text)
}
/// Converts a column-vector symbol `Matrix([[a], [b], ...])` into a 1-D
/// complex array.
///
/// `_values` is accepted but not consulted: every cell must already be a
/// numeric literal.
///
/// # Errors
///
/// Returns a parse error if the expression is not a well-formed matrix
/// symbol. Returns a dimension error for the first row that does not contain
/// exactly one cell.
pub fn to_array1(
    expr: &Expression,
    _values: &std::collections::HashMap<String, f64>,
) -> SymEngineResult<Array1<Complex64>> {
    let rows = parse_matrix_expr(expr)?;
    let mut column = Vec::with_capacity(rows.len());
    for (i, row) in rows.into_iter().enumerate() {
        match row.as_slice() {
            [only] => column.push(*only),
            _ => {
                return Err(SymEngineError::dimension(format!(
                    "row {i} has {} cells; expected 1 for Array1 conversion",
                    row.len()
                )))
            }
        }
    }
    Ok(Array1::from_vec(column))
}
/// Renders a 1-D complex array as a column-vector symbol
/// `Matrix([[x0], [x1], ...])`, with one single-cell row per element.
///
/// Each cell is written as follows:
/// - as a plain real literal when `|im| < 1e-15`;
/// - as `b*I` when `|re| < 1e-15`;
/// - as `(a+b*I)` otherwise.
///
/// Components below that threshold are therefore dropped.
pub fn from_array1(arr: &Array1<Complex64>) -> Expression {
    let mut text = String::from("Matrix([");
    for (idx, z) in arr.iter().enumerate() {
        if idx != 0 {
            text.push_str(", ");
        }
        text.push('[');
        let negligible_im = z.im.abs() < 1e-15;
        let negligible_re = z.re.abs() < 1e-15;
        // Writing into a String cannot fail; the fmt::Result is ignored.
        let _ = match (negligible_im, negligible_re) {
            (true, _) => write!(text, "{}", z.re),
            (false, true) => write!(text, "{}*I", z.im),
            (false, false) => write!(text, "({}+{}*I)", z.re, z.im),
        };
        text.push(']');
    }
    text.push_str("])");
    Expression::new(text)
}
/// Evaluates the gradient of `expr` with respect to `params` at `values`,
/// using `crate::optimization::gradient_at`, and returns it as a 1-D array.
///
/// # Errors
///
/// Propagates any error from `crate::optimization::gradient_at`.
pub fn gradient_array(
    expr: &Expression,
    params: &[Expression],
    values: &std::collections::HashMap<String, f64>,
) -> SymEngineResult<Array1<f64>> {
    crate::optimization::gradient_at(expr, params, values)
        .map(Array1::from_vec)
        .map_err(Into::into)
}
/// Evaluates the second-derivative matrix of `expr` with respect to `params`
/// at `values` via `crate::optimization::hessian_at`. The result is packed
/// into an `n x n` array, where `n = params.len()`.
///
/// Entries are copied by `(row, column)` position into a zero-initialised
/// array, so any position not supplied by `hessian_at` stays `0.0`.
///
/// # Errors
///
/// Propagates any error from `crate::optimization::hessian_at`.
///
/// # Panics
///
/// Panics (index out of bounds) if the returned matrix has more than `n`
/// rows or columns.
pub fn hessian_array(
    expr: &Expression,
    params: &[Expression],
    values: &std::collections::HashMap<String, f64>,
) -> SymEngineResult<Array2<f64>> {
    let second_derivs = crate::optimization::hessian_at(expr, params, values)?;
    let dim = params.len();
    let mut hessian = Array2::<f64>::zeros((dim, dim));
    second_derivs.iter().enumerate().for_each(|(r, row)| {
        row.iter()
            .enumerate()
            .for_each(|(c, &entry)| hessian[[r, c]] = entry);
    });
    Ok(hessian)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashMap;

    /// Empty substitution map for the conversion functions.
    fn no_values() -> HashMap<String, f64> {
        HashMap::new()
    }

    #[test]
    fn test_from_array2() {
        let arr: Array2<Complex64> = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
            [Complex64::new(0.0, -1.0), Complex64::new(1.0, 0.0)],
        ];
        let expr = from_array2(&arr);
        assert!(expr.to_string().contains("Matrix"));
    }

    #[test]
    fn test_from_array1() {
        let arr: Array1<Complex64> = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0),];
        let expr = from_array1(&arr);
        assert!(expr.to_string().contains("Matrix"));
    }

    #[test]
    fn test_gradient_array() {
        let x = Expression::symbol("x");
        let expr = x.clone() * x.clone();
        let params = vec![x];
        let mut values = HashMap::new();
        values.insert("x".to_string(), 3.0);
        let grad = gradient_array(&expr, &params, &values).expect("should compute");
        // d(x^2)/dx = 2x = 6 at x = 3.
        assert!((grad[0] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_to_array1_real() {
        let src: Array1<Complex64> = array![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
        ];
        let expr = from_array1(&src);
        let arr = to_array1(&expr, &no_values()).expect("to_array1 should succeed");
        assert_eq!(arr.len(), 3);
        assert!((arr[0].re - 1.0).abs() < 1e-10);
        assert!((arr[1].re - 2.0).abs() < 1e-10);
        assert!((arr[2].re - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array1_complex() {
        let src: Array1<Complex64> = array![
            Complex64::new(1.0, 2.0),
            Complex64::new(0.0, 3.0),
            Complex64::new(4.0, 0.0),
        ];
        let expr = from_array1(&src);
        let arr = to_array1(&expr, &no_values()).expect("to_array1 complex should succeed");
        assert_eq!(arr.len(), 3);
        assert!((arr[0].re - 1.0).abs() < 1e-10);
        assert!((arr[0].im - 2.0).abs() < 1e-10);
        assert!((arr[1].re - 0.0).abs() < 1e-10);
        assert!((arr[1].im - 3.0).abs() < 1e-10);
        assert!((arr[2].re - 4.0).abs() < 1e-10);
        assert!((arr[2].im - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array1_rejects_row_vector() {
        let src: Array2<Complex64> = array![[Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)]];
        let expr = from_array2(&src);
        assert!(to_array1(&expr, &no_values()).is_err());
    }

    #[test]
    fn test_to_array2_2x2_real() {
        let src: Array2<Complex64> = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("to_array2 should succeed");
        assert_eq!(arr.shape(), &[2, 2]);
        assert!((arr[[0, 0]].re - 1.0).abs() < 1e-10);
        assert!((arr[[0, 1]].re - 2.0).abs() < 1e-10);
        assert!((arr[[1, 0]].re - 3.0).abs() < 1e-10);
        assert!((arr[[1, 1]].re - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_2x2_complex() {
        let src: Array2<Complex64> = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
            [Complex64::new(0.0, -1.0), Complex64::new(1.0, 0.0)],
        ];
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("to_array2 complex should succeed");
        assert_eq!(arr.shape(), &[2, 2]);
        assert!((arr[[0, 1]].re - 0.0).abs() < 1e-10);
        assert!((arr[[0, 1]].im - 1.0).abs() < 1e-10);
        assert!((arr[[1, 0]].re - 0.0).abs() < 1e-10);
        assert!((arr[[1, 0]].im - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_general_complex() {
        let src: Array2<Complex64> = array![[Complex64::new(3.0, 4.0)]];
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("to_array2 general complex should succeed");
        assert_eq!(arr.shape(), &[1, 1]);
        assert!((arr[[0, 0]].re - 3.0).abs() < 1e-10);
        assert!((arr[[0, 0]].im - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_negative_imaginary() {
        let src: Array2<Complex64> = array![[Complex64::new(2.0, -3.0)]];
        let expr = from_array2(&src);
        let arr =
            to_array2(&expr, &no_values()).expect("to_array2 negative imaginary should succeed");
        assert_eq!(arr.shape(), &[1, 1]);
        assert!((arr[[0, 0]].re - 2.0).abs() < 1e-10);
        assert!((arr[[0, 0]].im - (-3.0)).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_empty_roundtrip() {
        let src: Array2<Complex64> = Array2::zeros((0, 0));
        let arr = to_array2(&from_array2(&src), &no_values()).expect("empty matrix should parse");
        assert_eq!(arr.shape(), &[0, 0]);
    }

    #[test]
    fn test_to_array2_rejects_ragged_rows() {
        let expr = Expression::new("Matrix([[1, 2], [3]])".to_string());
        assert!(to_array2(&expr, &no_values()).is_err());
    }

    #[test]
    fn test_to_array2_rejects_non_matrix_symbol() {
        let expr = Expression::symbol("x");
        assert!(to_array2(&expr, &no_values()).is_err());
    }

    #[test]
    fn test_parse_cell_negative_real_positive_imag() {
        let z = parse_cell("(-1+2*I)").expect("cell should parse");
        assert!((z.re - (-1.0)).abs() < 1e-12);
        assert!((z.im - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_find_split_plus_skips_exponent_sign() {
        assert_eq!(find_split_plus("1e+5+2"), Some(4));
        assert_eq!(find_split_plus("-1+-2"), Some(2));
        assert_eq!(find_split_plus("+3"), None);
    }
}