quantrs2_symengine_pure/serialize/
mod.rs

1//! Serialization support for symbolic expressions.
2//!
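//! This module provides serialization and deserialization of expressions.
//! Per COOLJAPAN policy it does not use bincode; rather than calling a codec
//! crate such as oxicode, it writes a simple length-prefixed binary format
//! and a minimal JSON form.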
5//!
6//! ## Example
7//!
8//! ```ignore
9//! use quantrs2_symengine_pure::{Expression, serialize};
10//!
11//! let expr = Expression::symbol("x");
12//!
13//! // Serialize to bytes
14//! let bytes = serialize::to_bytes(&expr)?;
15//!
16//! // Deserialize from bytes
17//! let decoded: Expression = serialize::from_bytes(&bytes)?;
18//! ```
19
20use crate::error::{SymEngineError, SymEngineResult};
21use crate::expr::Expression;
22use crate::matrix::SymbolicMatrix;
23use crate::parser;
24
25/// A serializable form of an Expression.
26///
27/// Since Expression contains egg's RecExpr which doesn't implement
28/// standard serialization traits, we serialize via the string representation.
29#[derive(Clone, Debug)]
30pub struct SerializedExpression {
31    /// The expression as a string
32    repr: String,
33}
34
35impl SerializedExpression {
36    /// Create a serialized expression from an Expression
37    #[must_use]
38    pub fn from_expr(expr: &Expression) -> Self {
39        Self {
40            repr: expr.to_string(),
41        }
42    }
43
44    /// Convert back to an Expression
45    ///
46    /// # Errors
47    /// Returns error if parsing fails
48    pub fn to_expr(&self) -> SymEngineResult<Expression> {
49        // For simple expressions, try parsing
50        // For complex egg s-expressions, we need special handling
51        if self.repr.starts_with('(') {
52            // It's an s-expression from egg, use Expression::new
53            Ok(Expression::new(&self.repr))
54        } else {
55            // Try parsing as a mathematical expression
56            parser::parse(&self.repr)
57        }
58    }
59}
60
61/// Serialize an Expression to bytes using oxicode.
62///
63/// # Arguments
64/// * `expr` - The expression to serialize
65///
66/// # Returns
67/// A vector of bytes containing the serialized expression.
68///
69/// # Errors
70/// Returns error if serialization fails.
71pub fn to_bytes(expr: &Expression) -> SymEngineResult<Vec<u8>> {
72    let repr = expr.to_string();
73    let len = repr.len() as u32;
74
75    let mut bytes = Vec::with_capacity(4 + repr.len());
76
77    // Write length as little-endian u32
78    bytes.extend_from_slice(&len.to_le_bytes());
79
80    // Write string bytes
81    bytes.extend_from_slice(repr.as_bytes());
82
83    Ok(bytes)
84}
85
86/// Deserialize an Expression from bytes.
87///
88/// # Arguments
89/// * `bytes` - The bytes to deserialize from
90///
91/// # Returns
92/// The deserialized expression.
93///
94/// # Errors
95/// Returns error if deserialization or parsing fails.
96pub fn from_bytes(bytes: &[u8]) -> SymEngineResult<Expression> {
97    if bytes.len() < 4 {
98        return Err(SymEngineError::parse("buffer too short for expression"));
99    }
100
101    let len = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
102
103    if bytes.len() < 4 + len {
104        return Err(SymEngineError::parse(
105            "buffer too short for expression data",
106        ));
107    }
108
109    let repr = std::str::from_utf8(&bytes[4..4 + len])
110        .map_err(|e| SymEngineError::parse(format!("invalid UTF-8: {e}")))?;
111
112    // Parse the representation
113    SerializedExpression {
114        repr: repr.to_string(),
115    }
116    .to_expr()
117}
118
119/// Serialize multiple expressions to bytes.
120///
121/// # Arguments
122/// * `exprs` - The expressions to serialize
123///
124/// # Returns
125/// A vector of bytes containing the serialized expressions.
126///
127/// # Errors
128/// Returns error if serialization fails.
129pub fn to_bytes_many(exprs: &[Expression]) -> SymEngineResult<Vec<u8>> {
130    let mut bytes = Vec::new();
131
132    // Write count as u32
133    let count = exprs.len() as u32;
134    bytes.extend_from_slice(&count.to_le_bytes());
135
136    // Serialize each expression
137    for expr in exprs {
138        let expr_bytes = to_bytes(expr)?;
139        bytes.extend_from_slice(&expr_bytes);
140    }
141
142    Ok(bytes)
143}
144
145/// Deserialize multiple expressions from bytes.
146///
147/// # Arguments
148/// * `bytes` - The bytes to deserialize from
149///
150/// # Returns
151/// The deserialized expressions.
152///
153/// # Errors
154/// Returns error if deserialization or parsing fails.
155pub fn from_bytes_many(bytes: &[u8]) -> SymEngineResult<Vec<Expression>> {
156    if bytes.len() < 4 {
157        return Err(SymEngineError::parse("buffer too short for count"));
158    }
159
160    let count = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
161    let mut offset = 4;
162    let mut exprs = Vec::with_capacity(count);
163
164    for _ in 0..count {
165        if offset + 4 > bytes.len() {
166            return Err(SymEngineError::parse("unexpected end of buffer"));
167        }
168
169        let len = u32::from_le_bytes([
170            bytes[offset],
171            bytes[offset + 1],
172            bytes[offset + 2],
173            bytes[offset + 3],
174        ]) as usize;
175
176        let total_size = 4 + len;
177        if offset + total_size > bytes.len() {
178            return Err(SymEngineError::parse("unexpected end of buffer"));
179        }
180
181        let expr = from_bytes(&bytes[offset..offset + total_size])?;
182        exprs.push(expr);
183        offset += total_size;
184    }
185
186    Ok(exprs)
187}
188
189// =========================================================================
190// Matrix Serialization
191// =========================================================================
192
193/// Serialize a SymbolicMatrix to bytes.
194///
195/// # Arguments
196/// * `matrix` - The matrix to serialize
197///
198/// # Returns
199/// A vector of bytes containing the serialized matrix.
200///
201/// # Errors
202/// Returns error if serialization fails.
203pub fn matrix_to_bytes(matrix: &SymbolicMatrix) -> SymEngineResult<Vec<u8>> {
204    let mut bytes = Vec::new();
205
206    // Write dimensions
207    let rows = matrix.nrows() as u32;
208    let cols = matrix.ncols() as u32;
209    bytes.extend_from_slice(&rows.to_le_bytes());
210    bytes.extend_from_slice(&cols.to_le_bytes());
211
212    // Serialize each element
213    for i in 0..matrix.nrows() {
214        for j in 0..matrix.ncols() {
215            let expr_bytes = to_bytes(matrix.get(i, j))?;
216            bytes.extend_from_slice(&expr_bytes);
217        }
218    }
219
220    Ok(bytes)
221}
222
223/// Deserialize a SymbolicMatrix from bytes.
224///
225/// # Arguments
226/// * `bytes` - The bytes to deserialize from
227///
228/// # Returns
229/// The deserialized matrix.
230///
231/// # Errors
232/// Returns error if deserialization or parsing fails.
233pub fn matrix_from_bytes(bytes: &[u8]) -> SymEngineResult<SymbolicMatrix> {
234    if bytes.len() < 8 {
235        return Err(SymEngineError::parse(
236            "buffer too short for matrix dimensions",
237        ));
238    }
239
240    let rows = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
241    let cols = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]) as usize;
242
243    let mut offset = 8;
244    let mut elements = Vec::with_capacity(rows * cols);
245
246    for _ in 0..(rows * cols) {
247        if offset + 4 > bytes.len() {
248            return Err(SymEngineError::parse("unexpected end of buffer"));
249        }
250
251        let len = u32::from_le_bytes([
252            bytes[offset],
253            bytes[offset + 1],
254            bytes[offset + 2],
255            bytes[offset + 3],
256        ]) as usize;
257
258        let total_size = 4 + len;
259        if offset + total_size > bytes.len() {
260            return Err(SymEngineError::parse("unexpected end of buffer"));
261        }
262
263        let expr = from_bytes(&bytes[offset..offset + total_size])?;
264        elements.push(expr);
265        offset += total_size;
266    }
267
268    SymbolicMatrix::from_flat(elements, rows, cols)
269}
270
271// =========================================================================
272// JSON-like Human Readable Format
273// =========================================================================
274
275/// Serialize an Expression to a JSON-like human-readable format.
276///
277/// This produces a simple JSON object with the expression representation.
278#[must_use]
279pub fn to_json(expr: &Expression) -> String {
280    format!("{{\"expr\":\"{}\"}}", escape_json(&expr.to_string()))
281}
282
283/// Deserialize an Expression from JSON-like format.
284///
285/// # Errors
286/// Returns error if parsing fails.
287pub fn from_json(json: &str) -> SymEngineResult<Expression> {
288    // Simple JSON parsing - extract the "expr" field
289    let json = json.trim();
290
291    if !json.starts_with('{') || !json.ends_with('}') {
292        return Err(SymEngineError::parse("invalid JSON: expected object"));
293    }
294
295    let inner = &json[1..json.len() - 1];
296
297    // Find "expr":"..."
298    if let Some(start) = inner.find("\"expr\":\"") {
299        let value_start = start + 8;
300        if let Some(end) = inner[value_start..].find('"') {
301            let value = &inner[value_start..value_start + end];
302            let unescaped = unescape_json(value);
303            return SerializedExpression { repr: unescaped }.to_expr();
304        }
305    }
306
307    Err(SymEngineError::parse("invalid JSON: missing 'expr' field"))
308}
309
/// Escape a string for embedding inside a JSON string literal.
///
/// Quotes, backslashes, `\n`, `\r` and `\t` get their short escapes. Every
/// other control character (U+0000–U+001F) is written as `\uXXXX`, because
/// RFC 8259 forbids raw control characters inside JSON strings.
fn escape_json(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => result.push_str("\\\""),
            '\\' => result.push_str("\\\\"),
            '\n' => result.push_str("\\n"),
            '\r' => result.push_str("\\r"),
            '\t' => result.push_str("\\t"),
            c if c < ' ' => result.push_str(&format!("\\u{:04x}", u32::from(c))),
            _ => result.push(c),
        }
    }
    result
}

/// Unescape the contents of a JSON string literal (without the quotes).
///
/// Handles the standard escapes `\" \\ \/ \b \f \n \r \t` and `\uXXXX`,
/// including UTF-16 surrogate pairs. Unknown or malformed escapes, and a
/// trailing lone backslash, are kept verbatim rather than rejected.
fn unescape_json(s: &str) -> String {
    /// Read exactly four hex digits as a code unit.
    fn read_hex4(it: &mut std::str::Chars<'_>) -> Option<u32> {
        let mut value = 0u32;
        for _ in 0..4 {
            value = value * 16 + it.next()?.to_digit(16)?;
        }
        Some(value)
    }

    /// Decode the body of a `\u` escape (the part after `\u`), joining a
    /// high surrogate with a following `\uDC00`–`\uDFFF` low surrogate.
    fn decode_unicode_escape(it: &mut std::str::Chars<'_>) -> Option<char> {
        let high = read_hex4(it)?;
        if (0xD800..0xDC00).contains(&high) {
            if it.next()? != '\\' || it.next()? != 'u' {
                return None;
            }
            let low = read_hex4(it)?;
            if !(0xDC00..0xE000).contains(&low) {
                return None;
            }
            char::from_u32(0x10000 + ((high - 0xD800) << 10) + (low - 0xDC00))
        } else {
            // Rejects lone low surrogates.
            char::from_u32(high)
        }
    }

    let mut result = String::with_capacity(s.len());
    let mut chars = s.chars();

    while let Some(c) = chars.next() {
        if c != '\\' {
            result.push(c);
            continue;
        }
        match chars.next() {
            Some('"') => result.push('"'),
            Some('/') => result.push('/'),
            Some('b') => result.push('\u{8}'),
            Some('f') => result.push('\u{c}'),
            Some('n') => result.push('\n'),
            Some('r') => result.push('\r'),
            Some('t') => result.push('\t'),
            Some('\\') | None => result.push('\\'),
            Some('u') => {
                // Decode on a cloned iterator so a malformed escape consumes
                // nothing and is emitted literally.
                let mut lookahead = chars.clone();
                if let Some(decoded) = decode_unicode_escape(&mut lookahead) {
                    result.push(decoded);
                    chars = lookahead;
                } else {
                    result.push('\\');
                    result.push('u');
                }
            }
            Some(other) => {
                result.push('\\');
                result.push(other);
            }
        }
    }

    result
}
351
#[cfg(test)]
// `3.14` below is just a sample float literal; it is not meant to approximate PI.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_serialize_simple() {
        let original = Expression::symbol("x");
        let encoded = to_bytes(&original).expect("should serialize");
        let restored = from_bytes(&encoded).expect("should deserialize");

        // Compare by evaluating both sides at x = 5 rather than structurally.
        let mut env = HashMap::new();
        env.insert("x".to_string(), 5.0);

        let expected = original.eval(&env).expect("should eval");
        let actual = restored.eval(&env).expect("should eval");
        assert!((expected - actual).abs() < 1e-10);
    }

    #[test]
    fn test_serialize_number() {
        let original = Expression::float_unchecked(3.14);
        let restored = from_bytes(&to_bytes(&original).expect("should serialize"))
            .expect("should deserialize");

        let no_vars = HashMap::new();
        let expected = original.eval(&no_vars).expect("should eval");
        let actual = restored.eval(&no_vars).expect("should eval");
        assert!((expected - actual).abs() < 1e-10);
    }

    #[test]
    fn test_serialize_many() {
        let inputs = [
            Expression::symbol("x"),
            Expression::symbol("y"),
            Expression::int(42),
        ];

        let encoded = to_bytes_many(&inputs).expect("should serialize");
        let restored = from_bytes_many(&encoded).expect("should deserialize");

        assert_eq!(restored.len(), 3);
    }

    #[test]
    fn test_serialize_matrix() {
        let identity = SymbolicMatrix::identity(2);
        let encoded = matrix_to_bytes(&identity).expect("should serialize");
        let restored = matrix_from_bytes(&encoded).expect("should deserialize");

        assert_eq!(restored.nrows(), 2);
        assert_eq!(restored.ncols(), 2);
        assert!(restored.get(0, 0).is_one());
        assert!(restored.get(0, 1).is_zero());
    }

    #[test]
    fn test_json_serialize() {
        let json = to_json(&Expression::symbol("x"));

        for needle in ["\"expr\":", "\"x\""] {
            assert!(json.contains(needle));
        }

        let restored = from_json(&json).expect("should parse");
        assert_eq!(restored.as_symbol(), Some("x"));
    }

    #[test]
    fn test_json_escape() {
        let sample = "hello\"world\\test";
        let round_tripped = unescape_json(&escape_json(sample));
        assert_eq!(sample, round_tripped);
    }
}