quantrs2_symengine_pure/scirs2_bridge/
ndarray.rs1use std::fmt::Write;
7
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::Complex64;
10
11use crate::error::{SymEngineError, SymEngineResult};
12use crate::expr::Expression;
13
14fn parse_cell(s: &str) -> Result<Complex64, SymEngineError> {
21 let s = s.trim();
22
23 let s = if s.starts_with('(') && s.ends_with(')') {
25 &s[1..s.len() - 1]
26 } else {
27 s
28 };
29
30 if let Some(without_i) = s.strip_suffix("*I") {
32 if let Some(plus_pos) = find_split_plus(without_i) {
36 let re_str = &without_i[..plus_pos];
38 let im_str = &without_i[plus_pos + 1..];
39 let re = re_str
40 .trim()
41 .parse::<f64>()
42 .map_err(|_| SymEngineError::parse(format!("cannot parse real part: {re_str}")))?;
43 let im = im_str.trim().parse::<f64>().map_err(|_| {
44 SymEngineError::parse(format!("cannot parse imaginary coefficient: {im_str}"))
45 })?;
46 return Ok(Complex64::new(re, im));
47 }
48 let im = without_i.trim().parse::<f64>().map_err(|_| {
50 SymEngineError::parse(format!("cannot parse imaginary coefficient: {without_i}"))
51 })?;
52 return Ok(Complex64::new(0.0, im));
53 }
54
55 let re = s
57 .parse::<f64>()
58 .map_err(|_| SymEngineError::parse(format!("cannot parse cell value: {s}")))?;
59 Ok(Complex64::new(re, 0.0))
60}
61
/// Returns the byte index of the first `+` that separates a real part from an
/// imaginary coefficient, or `None` when there is no such `+`.
///
/// Index 0 is never returned (a leading `+` is a sign). A `+` that directly
/// follows `e` or `E` is skipped because it belongs to an exponent (`1e+5`).
fn find_split_plus(s: &str) -> Option<usize> {
    s.as_bytes()
        .windows(2)
        .position(|pair| pair[1] == b'+' && !matches!(pair[0], b'e' | b'E'))
        .map(|prev_index| prev_index + 1)
}
82
83fn parse_matrix_expr(expr: &Expression) -> SymEngineResult<Vec<Vec<Complex64>>> {
92 let raw = expr
93 .as_symbol()
94 .ok_or_else(|| SymEngineError::parse("expression is not a matrix symbol"))?;
95
96 let inner = if raw.starts_with("Matrix(") && raw.ends_with(')') {
98 &raw["Matrix(".len()..raw.len() - 1]
99 } else {
100 raw
101 };
102
103 let inner = inner.trim();
105 if !inner.starts_with('[') || !inner.ends_with(']') {
106 return Err(SymEngineError::parse(format!(
107 "expected outer '[...]' in matrix expression, got: {inner}"
108 )));
109 }
110 let inner = &inner[1..inner.len() - 1];
111
112 let rows_strs = split_rows(inner);
114
115 let mut rows: Vec<Vec<Complex64>> = Vec::with_capacity(rows_strs.len());
116 for row_str in rows_strs {
117 let row_str = row_str.trim();
118 if !row_str.starts_with('[') || !row_str.ends_with(']') {
119 return Err(SymEngineError::parse(format!(
120 "expected row '[...]', got: {row_str}"
121 )));
122 }
123 let cells_str = &row_str[1..row_str.len() - 1];
124 let cells = split_cells(cells_str);
125 let row: Vec<Complex64> = cells
126 .iter()
127 .map(|c| parse_cell(c.trim()))
128 .collect::<Result<_, _>>()?;
129 rows.push(row);
130 }
131
132 Ok(rows)
133}
134
/// Splits the interior of the outer matrix brackets into row segments.
///
/// The input is split on commas at bracket depth 0, so the commas that
/// separate cells inside a row are left alone. Each segment is trimmed, and
/// whitespace-only segments (for example from a trailing comma) are skipped.
///
/// Anything else at depth 0 is returned as its own segment, so the caller's
/// `[...]` check rejects it instead of the data being dropped silently. This
/// covers a stray value such as the `3` in `[1], 3`, and an unterminated row
/// such as `[1, 2`.
fn split_rows(s: &str) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth: usize = 0;
    let mut start: usize = 0;

    for (i, b) in s.bytes().enumerate() {
        match b {
            b'[' => depth += 1,
            // Saturating so that a stray ']' cannot underflow the depth counter.
            b']' => depth = depth.saturating_sub(1),
            b',' if depth == 0 => {
                parts.push(s[start..i].trim());
                start = i + 1;
            }
            _ => {}
        }
    }
    parts.push(s[start..].trim());

    parts.retain(|segment| !segment.is_empty());
    parts
}
164
/// Splits the interior of a row (e.g. `"a, (b+c*I), d"`) on commas that are
/// not nested inside parentheses.
///
/// Segments are returned untrimmed. Input that is empty or whitespace-only
/// yields no cells at all. As a result, an empty row `[]` (as written for a
/// zero-column array) parses as a zero-length row instead of failing on an
/// empty cell string.
fn split_cells(s: &str) -> Vec<&str> {
    if s.trim().is_empty() {
        return Vec::new();
    }

    let mut parts = Vec::new();
    let mut depth: usize = 0;
    let mut start: usize = 0;

    for (i, b) in s.bytes().enumerate() {
        match b {
            b'(' => depth += 1,
            // Saturating so that a stray ')' cannot underflow the depth counter.
            b')' => depth = depth.saturating_sub(1),
            b',' if depth == 0 => {
                parts.push(&s[start..i]);
                start = i + 1;
            }
            _ => {}
        }
    }
    parts.push(&s[start..]);

    parts
}
188
189pub fn to_array2(
201 expr: &Expression,
202 _values: &std::collections::HashMap<String, f64>,
203) -> SymEngineResult<Array2<Complex64>> {
204 let rows = parse_matrix_expr(expr)?;
205
206 if rows.is_empty() {
207 return Ok(Array2::zeros((0, 0)));
208 }
209
210 let nrows = rows.len();
211 let ncols = rows[0].len();
212
213 for (i, row) in rows.iter().enumerate() {
215 if row.len() != ncols {
216 return Err(SymEngineError::dimension(format!(
217 "row {i} has {} columns, expected {ncols}",
218 row.len()
219 )));
220 }
221 }
222
223 let flat: Vec<Complex64> = rows.into_iter().flatten().collect();
224 Array2::from_shape_vec((nrows, ncols), flat)
225 .map_err(|e| SymEngineError::dimension(e.to_string()))
226}
227
228pub fn from_array2(arr: &Array2<Complex64>) -> Expression {
230 let (rows, cols) = arr.dim();
231
232 let mut matrix_str = String::from("Matrix([");
233
234 for i in 0..rows {
235 matrix_str.push('[');
236 for j in 0..cols {
237 let c = arr[[i, j]];
238 if c.im.abs() < 1e-15 {
239 let _ = write!(matrix_str, "{}", c.re);
240 } else if c.re.abs() < 1e-15 {
241 let _ = write!(matrix_str, "{}*I", c.im);
242 } else {
243 let _ = write!(matrix_str, "({}+{}*I)", c.re, c.im);
244 }
245 if j < cols - 1 {
246 matrix_str.push_str(", ");
247 }
248 }
249 matrix_str.push(']');
250 if i < rows - 1 {
251 matrix_str.push_str(", ");
252 }
253 }
254
255 matrix_str.push_str("])");
256
257 Expression::new(matrix_str)
258}
259
260pub fn to_array1(
271 expr: &Expression,
272 _values: &std::collections::HashMap<String, f64>,
273) -> SymEngineResult<Array1<Complex64>> {
274 let rows = parse_matrix_expr(expr)?;
275
276 let flat: Vec<Complex64> = rows
277 .into_iter()
278 .enumerate()
279 .map(|(i, row)| {
280 if row.len() == 1 {
281 Ok(row[0])
282 } else {
283 Err(SymEngineError::dimension(format!(
284 "row {i} has {} cells; expected 1 for Array1 conversion",
285 row.len()
286 )))
287 }
288 })
289 .collect::<Result<_, _>>()?;
290
291 Ok(Array1::from_vec(flat))
292}
293
294pub fn from_array1(arr: &Array1<Complex64>) -> Expression {
296 let n = arr.len();
297
298 let mut matrix_str = String::from("Matrix([");
299
300 for (i, c) in arr.iter().enumerate() {
301 matrix_str.push('[');
302 if c.im.abs() < 1e-15 {
303 let _ = write!(matrix_str, "{}", c.re);
304 } else if c.re.abs() < 1e-15 {
305 let _ = write!(matrix_str, "{}*I", c.im);
306 } else {
307 let _ = write!(matrix_str, "({}+{}*I)", c.re, c.im);
308 }
309 matrix_str.push(']');
310 if i < n - 1 {
311 matrix_str.push_str(", ");
312 }
313 }
314
315 matrix_str.push_str("])");
316
317 Expression::new(matrix_str)
318}
319
320pub fn gradient_array(
324 expr: &Expression,
325 params: &[Expression],
326 values: &std::collections::HashMap<String, f64>,
327) -> SymEngineResult<Array1<f64>> {
328 let grad_vec = crate::optimization::gradient_at(expr, params, values)?;
329 Ok(Array1::from_vec(grad_vec))
330}
331
332pub fn hessian_array(
336 expr: &Expression,
337 params: &[Expression],
338 values: &std::collections::HashMap<String, f64>,
339) -> SymEngineResult<Array2<f64>> {
340 let hess_vec = crate::optimization::hessian_at(expr, params, values)?;
341 let n = params.len();
342 let mut arr = Array2::zeros((n, n));
343
344 for (i, row) in hess_vec.iter().enumerate() {
345 for (j, &val) in row.iter().enumerate() {
346 arr[[i, j]] = val;
347 }
348 }
349
350 Ok(arr)
351}
352
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;
    use std::collections::HashMap;

    /// Empty substitution map; `to_array1`/`to_array2` do not read it.
    fn no_values() -> HashMap<String, f64> {
        HashMap::new()
    }

    #[test]
    fn test_from_array2() {
        let arr: Array2<Complex64> = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
            [Complex64::new(0.0, -1.0), Complex64::new(1.0, 0.0)],
        ];

        let expr = from_array2(&arr);
        assert!(expr.to_string().contains("Matrix"));
    }

    #[test]
    fn test_from_array1() {
        let arr: Array1<Complex64> = array![Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)];

        let expr = from_array1(&arr);
        assert!(expr.to_string().contains("Matrix"));
    }

    #[test]
    fn test_gradient_array() {
        let x = Expression::symbol("x");
        // d/dx (x * x) = 2x, which is 6 at x = 3.
        let expr = x.clone() * x.clone();
        let params = vec![x];

        let mut values = HashMap::new();
        values.insert("x".to_string(), 3.0);

        let grad = gradient_array(&expr, &params, &values).expect("should compute");
        assert!((grad[0] - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_to_array1_real() {
        let src: Array1<Complex64> = array![
            Complex64::new(1.0, 0.0),
            Complex64::new(2.0, 0.0),
            Complex64::new(3.0, 0.0),
        ];
        let expr = from_array1(&src);
        let arr = to_array1(&expr, &no_values()).expect("to_array1 should succeed");
        assert_eq!(arr.len(), 3);
        assert!((arr[0].re - 1.0).abs() < 1e-10);
        assert!((arr[1].re - 2.0).abs() < 1e-10);
        assert!((arr[2].re - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array1_complex() {
        let src: Array1<Complex64> = array![
            Complex64::new(1.0, 2.0),
            Complex64::new(0.0, 3.0),
            Complex64::new(4.0, 0.0),
        ];
        let expr = from_array1(&src);
        let arr = to_array1(&expr, &no_values()).expect("to_array1 complex should succeed");
        assert_eq!(arr.len(), 3);
        assert!((arr[0].re - 1.0).abs() < 1e-10);
        assert!((arr[0].im - 2.0).abs() < 1e-10);
        assert!((arr[1].re - 0.0).abs() < 1e-10);
        assert!((arr[1].im - 3.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array1_rejects_multi_column_rows() {
        let src: Array2<Complex64> = array![[Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)]];
        let expr = from_array2(&src);
        assert!(to_array1(&expr, &no_values()).is_err());
    }

    #[test]
    fn test_to_array2_2x2_real() {
        let src: Array2<Complex64> = array![
            [Complex64::new(1.0, 0.0), Complex64::new(2.0, 0.0)],
            [Complex64::new(3.0, 0.0), Complex64::new(4.0, 0.0)],
        ];
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("to_array2 should succeed");
        assert_eq!(arr.shape(), &[2, 2]);
        assert!((arr[[0, 0]].re - 1.0).abs() < 1e-10);
        assert!((arr[[0, 1]].re - 2.0).abs() < 1e-10);
        assert!((arr[[1, 0]].re - 3.0).abs() < 1e-10);
        assert!((arr[[1, 1]].re - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_2x2_complex() {
        let src: Array2<Complex64> = array![
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 1.0)],
            [Complex64::new(0.0, -1.0), Complex64::new(1.0, 0.0)],
        ];
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("to_array2 complex should succeed");
        assert_eq!(arr.shape(), &[2, 2]);
        assert!((arr[[0, 1]].re - 0.0).abs() < 1e-10);
        assert!((arr[[0, 1]].im - 1.0).abs() < 1e-10);
        assert!((arr[[1, 0]].re - 0.0).abs() < 1e-10);
        assert!((arr[[1, 0]].im - (-1.0)).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_general_complex() {
        let src: Array2<Complex64> = array![[Complex64::new(3.0, 4.0)]];
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("to_array2 general complex should succeed");
        assert_eq!(arr.shape(), &[1, 1]);
        assert!((arr[[0, 0]].re - 3.0).abs() < 1e-10);
        assert!((arr[[0, 0]].im - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_negative_imaginary() {
        // from_array2 writes 2 - 3i as "(2+-3*I)"; the parser must keep the sign.
        let src: Array2<Complex64> = array![[Complex64::new(2.0, -3.0)]];
        let expr = from_array2(&src);
        let arr =
            to_array2(&expr, &no_values()).expect("to_array2 negative imaginary should succeed");
        assert_eq!(arr.shape(), &[1, 1]);
        assert!((arr[[0, 0]].re - 2.0).abs() < 1e-10);
        assert!((arr[[0, 0]].im - (-3.0)).abs() < 1e-10);
    }

    #[test]
    fn test_to_array2_empty_round_trip() {
        let src: Array2<Complex64> = Array2::zeros((0, 0));
        let expr = from_array2(&src);
        let arr = to_array2(&expr, &no_values()).expect("empty matrix should convert");
        assert_eq!(arr.shape(), &[0, 0]);
    }

    #[test]
    fn test_to_array2_ragged_rows_rejected() {
        let expr = Expression::new("Matrix([[1, 2], [3]])".to_string());
        assert!(to_array2(&expr, &no_values()).is_err());
    }
}