quantrs2_symengine_pure/serialize/
mod.rs1use crate::error::{SymEngineError, SymEngineResult};
21use crate::expr::Expression;
22use crate::matrix::SymbolicMatrix;
23use crate::parser;
24
25#[derive(Clone, Debug)]
30pub struct SerializedExpression {
31 repr: String,
33}
34
35impl SerializedExpression {
36 #[must_use]
38 pub fn from_expr(expr: &Expression) -> Self {
39 Self {
40 repr: expr.to_string(),
41 }
42 }
43
44 pub fn to_expr(&self) -> SymEngineResult<Expression> {
49 if self.repr.starts_with('(') {
52 Ok(Expression::new(&self.repr))
54 } else {
55 parser::parse(&self.repr)
57 }
58 }
59}
60
61pub fn to_bytes(expr: &Expression) -> SymEngineResult<Vec<u8>> {
72 let repr = expr.to_string();
73 let len = repr.len() as u32;
74
75 let mut bytes = Vec::with_capacity(4 + repr.len());
76
77 bytes.extend_from_slice(&len.to_le_bytes());
79
80 bytes.extend_from_slice(repr.as_bytes());
82
83 Ok(bytes)
84}
85
86pub fn from_bytes(bytes: &[u8]) -> SymEngineResult<Expression> {
97 if bytes.len() < 4 {
98 return Err(SymEngineError::parse("buffer too short for expression"));
99 }
100
101 let len = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
102
103 if bytes.len() < 4 + len {
104 return Err(SymEngineError::parse(
105 "buffer too short for expression data",
106 ));
107 }
108
109 let repr = std::str::from_utf8(&bytes[4..4 + len])
110 .map_err(|e| SymEngineError::parse(format!("invalid UTF-8: {e}")))?;
111
112 SerializedExpression {
114 repr: repr.to_string(),
115 }
116 .to_expr()
117}
118
119pub fn to_bytes_many(exprs: &[Expression]) -> SymEngineResult<Vec<u8>> {
130 let mut bytes = Vec::new();
131
132 let count = exprs.len() as u32;
134 bytes.extend_from_slice(&count.to_le_bytes());
135
136 for expr in exprs {
138 let expr_bytes = to_bytes(expr)?;
139 bytes.extend_from_slice(&expr_bytes);
140 }
141
142 Ok(bytes)
143}
144
145pub fn from_bytes_many(bytes: &[u8]) -> SymEngineResult<Vec<Expression>> {
156 if bytes.len() < 4 {
157 return Err(SymEngineError::parse("buffer too short for count"));
158 }
159
160 let count = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
161 let mut offset = 4;
162 let mut exprs = Vec::with_capacity(count);
163
164 for _ in 0..count {
165 if offset + 4 > bytes.len() {
166 return Err(SymEngineError::parse("unexpected end of buffer"));
167 }
168
169 let len = u32::from_le_bytes([
170 bytes[offset],
171 bytes[offset + 1],
172 bytes[offset + 2],
173 bytes[offset + 3],
174 ]) as usize;
175
176 let total_size = 4 + len;
177 if offset + total_size > bytes.len() {
178 return Err(SymEngineError::parse("unexpected end of buffer"));
179 }
180
181 let expr = from_bytes(&bytes[offset..offset + total_size])?;
182 exprs.push(expr);
183 offset += total_size;
184 }
185
186 Ok(exprs)
187}
188
189pub fn matrix_to_bytes(matrix: &SymbolicMatrix) -> SymEngineResult<Vec<u8>> {
204 let mut bytes = Vec::new();
205
206 let rows = matrix.nrows() as u32;
208 let cols = matrix.ncols() as u32;
209 bytes.extend_from_slice(&rows.to_le_bytes());
210 bytes.extend_from_slice(&cols.to_le_bytes());
211
212 for i in 0..matrix.nrows() {
214 for j in 0..matrix.ncols() {
215 let expr_bytes = to_bytes(matrix.get(i, j))?;
216 bytes.extend_from_slice(&expr_bytes);
217 }
218 }
219
220 Ok(bytes)
221}
222
223pub fn matrix_from_bytes(bytes: &[u8]) -> SymEngineResult<SymbolicMatrix> {
234 if bytes.len() < 8 {
235 return Err(SymEngineError::parse(
236 "buffer too short for matrix dimensions",
237 ));
238 }
239
240 let rows = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) as usize;
241 let cols = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]) as usize;
242
243 let mut offset = 8;
244 let mut elements = Vec::with_capacity(rows * cols);
245
246 for _ in 0..(rows * cols) {
247 if offset + 4 > bytes.len() {
248 return Err(SymEngineError::parse("unexpected end of buffer"));
249 }
250
251 let len = u32::from_le_bytes([
252 bytes[offset],
253 bytes[offset + 1],
254 bytes[offset + 2],
255 bytes[offset + 3],
256 ]) as usize;
257
258 let total_size = 4 + len;
259 if offset + total_size > bytes.len() {
260 return Err(SymEngineError::parse("unexpected end of buffer"));
261 }
262
263 let expr = from_bytes(&bytes[offset..offset + total_size])?;
264 elements.push(expr);
265 offset += total_size;
266 }
267
268 SymbolicMatrix::from_flat(elements, rows, cols)
269}
270
271#[must_use]
279pub fn to_json(expr: &Expression) -> String {
280 format!("{{\"expr\":\"{}\"}}", escape_json(&expr.to_string()))
281}
282
283pub fn from_json(json: &str) -> SymEngineResult<Expression> {
288 let json = json.trim();
290
291 if !json.starts_with('{') || !json.ends_with('}') {
292 return Err(SymEngineError::parse("invalid JSON: expected object"));
293 }
294
295 let inner = &json[1..json.len() - 1];
296
297 if let Some(start) = inner.find("\"expr\":\"") {
299 let value_start = start + 8;
300 if let Some(end) = inner[value_start..].find('"') {
301 let value = &inner[value_start..value_start + end];
302 let unescaped = unescape_json(value);
303 return SerializedExpression { repr: unescaped }.to_expr();
304 }
305 }
306
307 Err(SymEngineError::parse("invalid JSON: missing 'expr' field"))
308}
309
/// Escapes `s` so it can be embedded between the quotes of a JSON string.
///
/// Handles quotes and backslashes, plus every control character in
/// U+0000–U+001F, as RFC 8259 requires:
/// - the short forms `\b \f \n \r \t` are used where they exist;
/// - all other control characters become `\u00XX`.
fn escape_json(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '"' => result.push_str("\\\""),
            '\\' => result.push_str("\\\\"),
            '\n' => result.push_str("\\n"),
            '\r' => result.push_str("\\r"),
            '\t' => result.push_str("\\t"),
            '\u{08}' => result.push_str("\\b"),
            '\u{0C}' => result.push_str("\\f"),
            // Raw control characters are not allowed inside JSON strings.
            c if c < ' ' => result.push_str(&format!("\\u{:04x}", u32::from(c))),
            _ => result.push(c),
        }
    }
    result
}

/// Reverses JSON string escaping.
///
/// Decodes `\" \\ \/ \b \f \n \r \t` and `\uXXXX`, including UTF-16 surrogate
/// pairs as emitted by many JSON encoders for non-BMP characters.
/// Unknown or malformed escapes are kept literally (backslash included), and a
/// trailing lone backslash is kept as `\`.
fn unescape_json(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    let mut chars = s.chars();

    while let Some(c) = chars.next() {
        if c != '\\' {
            result.push(c);
            continue;
        }
        match chars.next() {
            Some('"') => result.push('"'),
            Some('/') => result.push('/'),
            Some('b') => result.push('\u{08}'),
            Some('f') => result.push('\u{0C}'),
            Some('n') => result.push('\n'),
            Some('r') => result.push('\r'),
            Some('t') => result.push('\t'),
            Some('\\') | None => result.push('\\'),
            Some('u') => match decode_unicode_escape(&mut chars) {
                Some(decoded) => result.push(decoded),
                // Malformed: emit `\u` and let the following chars pass through.
                None => result.push_str("\\u"),
            },
            Some(other) => {
                result.push('\\');
                result.push(other);
            }
        }
    }

    result
}

/// Decodes the hex digits that follow a `\u` escape.
///
/// Also handles a `\uDC00`–`\uDFFF` low surrogate that follows a high
/// surrogate. `chars` is only advanced on success; on failure it is left
/// untouched so the caller can treat the text literally.
fn decode_unicode_escape(chars: &mut std::str::Chars<'_>) -> Option<char> {
    let mut probe = chars.clone();
    let first = read_hex4(&mut probe)?;

    let code = if (0xD800..0xDC00).contains(&first) {
        // High surrogate: only valid when immediately followed by `\u<low>`.
        if probe.next()? != '\\' || probe.next()? != 'u' {
            return None;
        }
        let low = read_hex4(&mut probe)?;
        if !(0xDC00..0xE000).contains(&low) {
            return None;
        }
        0x10000 + ((first - 0xD800) << 10) + (low - 0xDC00)
    } else {
        first
    };

    // `from_u32` rejects a lone low surrogate.
    let decoded = char::from_u32(code)?;
    *chars = probe;
    Some(decoded)
}

/// Reads exactly four hex digits from `chars` and returns their value.
fn read_hex4(chars: &mut std::str::Chars<'_>) -> Option<u32> {
    let mut value = 0u32;
    for _ in 0..4 {
        value = value * 16 + chars.next()?.to_digit(16)?;
    }
    Some(value)
}
351
#[cfg(test)]
// `3.14` below is just an arbitrary sample float, not an approximation of PI.
#[allow(clippy::approx_constant)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// A symbol survives a binary round trip and evaluates to the same value.
    #[test]
    fn test_serialize_simple() {
        let original = Expression::symbol("x");
        let encoded = to_bytes(&original).expect("should serialize");
        let restored = from_bytes(&encoded).expect("should deserialize");

        let mut bindings = HashMap::new();
        bindings.insert("x".to_string(), 5.0);

        let expected = original.eval(&bindings).expect("should eval");
        let actual = restored.eval(&bindings).expect("should eval");
        assert!((expected - actual).abs() < 1e-10);
    }

    /// A numeric literal survives a binary round trip.
    #[test]
    fn test_serialize_number() {
        let original = Expression::float_unchecked(3.14);
        let encoded = to_bytes(&original).expect("should serialize");
        let restored = from_bytes(&encoded).expect("should deserialize");

        let expected = original.eval(&HashMap::new()).expect("should eval");
        let actual = restored.eval(&HashMap::new()).expect("should eval");
        assert!((expected - actual).abs() < 1e-10);
    }

    /// The multi-expression container preserves the element count.
    #[test]
    fn test_serialize_many() {
        let batch = [
            Expression::symbol("x"),
            Expression::symbol("y"),
            Expression::int(42),
        ];

        let encoded = to_bytes_many(&batch).expect("should serialize");
        let restored = from_bytes_many(&encoded).expect("should deserialize");
        assert_eq!(restored.len(), 3);
    }

    /// A 2x2 identity matrix keeps its shape and its 1/0 entries.
    #[test]
    fn test_serialize_matrix() {
        let identity = SymbolicMatrix::identity(2);
        let encoded = matrix_to_bytes(&identity).expect("should serialize");
        let restored = matrix_from_bytes(&encoded).expect("should deserialize");

        assert_eq!(restored.nrows(), 2);
        assert_eq!(restored.ncols(), 2);
        assert!(restored.get(0, 0).is_one());
        assert!(restored.get(0, 1).is_zero());
    }

    /// JSON output carries an `expr` field and parses back to the same symbol.
    #[test]
    fn test_json_serialize() {
        let symbol = Expression::symbol("x");
        let text = to_json(&symbol);

        assert!(text.contains("\"expr\":"));
        assert!(text.contains("\"x\""));

        let parsed = from_json(&text).expect("should parse");
        assert_eq!(parsed.as_symbol(), Some("x"));
    }

    /// Escaping followed by unescaping is the identity for quotes/backslashes.
    #[test]
    fn test_json_escape() {
        let raw = "hello\"world\\test";
        let round_tripped = unescape_json(&escape_json(raw));
        assert_eq!(raw, round_tripped);
    }
}