1use crate::core::expression::LimitDirection;
4use crate::core::{Expression, MathConstant, Number, Symbol};
5use num_traits::ToPrimitive;
6use serde::{Deserialize, Serialize};
7use serde_json;
8
/// Stateless entry point for converting [`Expression`] trees to and from JSON.
///
/// All functionality lives in associated functions (`parse`, `stringify`, `stringify_compact`);
/// the struct itself carries no data, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct MathSerializer;
11
12impl MathSerializer {
13 pub fn parse(data_str: &str) -> Result<Expression, SerializationError> {
15 let expr_data: ExpressionData = serde_json::from_str(data_str)
16 .map_err(|e| SerializationError::InvalidFormat(e.to_string()))?;
17 Self::data_to_expression(expr_data)
18 }
19
20 pub fn stringify(expr: &Expression) -> Result<String, SerializationError> {
22 let expr_data = Self::expression_to_data(expr);
23 serde_json::to_string_pretty(&expr_data)
24 .map_err(|e| SerializationError::SerializationError(e.to_string()))
25 }
26
27 pub fn stringify_compact(expr: &Expression) -> Result<String, SerializationError> {
29 let expr_data = Self::expression_to_data(expr);
30 serde_json::to_string(&expr_data)
31 .map_err(|e| SerializationError::SerializationError(e.to_string()))
32 }
33
34 fn data_to_expression(expr_data: ExpressionData) -> Result<Expression, SerializationError> {
36 match expr_data {
37 ExpressionData::Number { value } => Ok(Expression::integer(value)),
38 ExpressionData::Float { value } => Ok(Expression::number(Number::float(value))),
39 ExpressionData::Symbol { name } => Ok(Expression::symbol(Symbol::new(&name))),
40
41 ExpressionData::Add { terms } => {
42 let expr_terms: Result<Vec<Expression>, SerializationError> =
43 terms.into_iter().map(Self::data_to_expression).collect();
44 Ok(Expression::add(expr_terms?))
45 }
46
47 ExpressionData::Mul { factors } => {
48 let expr_factors: Result<Vec<Expression>, SerializationError> =
49 factors.into_iter().map(Self::data_to_expression).collect();
50 Ok(Expression::mul(expr_factors?))
51 }
52
53 ExpressionData::Pow { base, exponent } => {
54 let base_expr = Self::data_to_expression(*base)?;
55 let exp_expr = Self::data_to_expression(*exponent)?;
56 Ok(Expression::pow(base_expr, exp_expr))
57 }
58
59 ExpressionData::Function { name, args } => {
60 let expr_args: Result<Vec<Expression>, SerializationError> =
61 args.into_iter().map(Self::data_to_expression).collect();
62 Ok(Expression::function(name, expr_args?))
63 }
64
65 ExpressionData::Complex { real, imag } => {
66 let real_expr = Self::data_to_expression(*real)?;
67 let imag_expr = Self::data_to_expression(*imag)?;
68 Ok(Expression::complex(real_expr, imag_expr))
69 }
70
71 ExpressionData::Constant { constant } => Ok(Expression::constant(constant)),
72
73 ExpressionData::Derivative {
74 expression,
75 variable,
76 order,
77 } => {
78 let expr = Self::data_to_expression(*expression)?;
79 Ok(Expression::derivative(expr, Symbol::new(&variable), order))
80 }
81
82 ExpressionData::Integral {
83 integrand,
84 variable,
85 bounds,
86 } => {
87 let integrand_expr = Self::data_to_expression(*integrand)?;
88 let var_symbol = Symbol::new(&variable);
89
90 match bounds {
91 None => Ok(Expression::integral(integrand_expr, var_symbol)),
92 Some((start, end)) => {
93 let start_expr = Self::data_to_expression(*start)?;
94 let end_expr = Self::data_to_expression(*end)?;
95 Ok(Expression::definite_integral(
96 integrand_expr,
97 var_symbol,
98 start_expr,
99 end_expr,
100 ))
101 }
102 }
103 }
104
105 ExpressionData::Limit {
106 expression,
107 variable,
108 approach,
109 direction: _,
110 } => {
111 let expr = Self::data_to_expression(*expression)?;
112 let approach_expr = Self::data_to_expression(*approach)?;
113 Ok(Expression::limit(
114 expr,
115 Symbol::new(&variable),
116 approach_expr,
117 ))
118 }
119 ExpressionData::MethodCall {
120 object,
121 method_name,
122 args,
123 } => {
124 let object_expr = Self::data_to_expression(*object)?;
125 let arg_exprs: Result<Vec<Expression>, SerializationError> =
126 args.into_iter().map(Self::data_to_expression).collect();
127 Ok(Expression::method_call(
128 object_expr,
129 method_name,
130 arg_exprs?,
131 ))
132 }
133 }
134 }
135
136 fn expression_to_data(expr: &Expression) -> ExpressionData {
138 match expr {
139 Expression::Number(Number::Integer(n)) => ExpressionData::Number { value: *n },
140 Expression::Number(Number::BigInteger(n)) => ExpressionData::Number {
141 value: n.to_string().parse().unwrap_or(0),
142 },
143 Expression::Number(Number::Float(f)) => ExpressionData::Float { value: *f },
144 Expression::Number(Number::Rational(r)) => {
145 let float_val =
146 r.numer().to_f64().unwrap_or(0.0) / r.denom().to_f64().unwrap_or(1.0);
147 ExpressionData::Float { value: float_val }
148 }
149
150 Expression::Symbol(s) => ExpressionData::Symbol {
151 name: s.name().to_owned(),
152 },
153
154 Expression::Add(terms) => ExpressionData::Add {
155 terms: terms.iter().map(Self::expression_to_data).collect(),
156 },
157
158 Expression::Mul(factors) => ExpressionData::Mul {
159 factors: factors.iter().map(Self::expression_to_data).collect(),
160 },
161
162 Expression::Pow(base, exp) => ExpressionData::Pow {
163 base: Box::new(Self::expression_to_data(base)),
164 exponent: Box::new(Self::expression_to_data(exp)),
165 },
166
167 Expression::Function { name, args } => ExpressionData::Function {
168 name: name.clone(),
169 args: args.iter().map(Self::expression_to_data).collect(),
170 },
171
172 Expression::Complex(complex_data) => ExpressionData::Complex {
173 real: Box::new(Self::expression_to_data(&complex_data.real)),
174 imag: Box::new(Self::expression_to_data(&complex_data.imag)),
175 },
176
177 Expression::Constant(c) => ExpressionData::Constant { constant: *c },
178
179 Expression::Calculus(calculus_data) => {
180 use crate::core::expression::CalculusData;
181 match calculus_data.as_ref() {
182 CalculusData::Derivative {
183 expression,
184 variable,
185 order,
186 } => ExpressionData::Derivative {
187 expression: Box::new(Self::expression_to_data(expression)),
188 variable: variable.name().to_owned(),
189 order: *order,
190 },
191 CalculusData::Integral {
192 integrand,
193 variable,
194 bounds,
195 } => ExpressionData::Integral {
196 integrand: Box::new(Self::expression_to_data(integrand)),
197 variable: variable.name().to_owned(),
198 bounds: bounds.as_ref().map(|(start, end)| {
199 (
200 Box::new(Self::expression_to_data(start)),
201 Box::new(Self::expression_to_data(end)),
202 )
203 }),
204 },
205 _ => ExpressionData::Function {
206 name: "calculus_operation".to_owned(),
207 args: vec![],
208 },
209 }
210 }
211
212 Expression::MethodCall(method_data) => ExpressionData::MethodCall {
213 object: Box::new(Self::expression_to_data(&method_data.object)),
214 method_name: method_data.method_name.clone(),
215 args: method_data
216 .args
217 .iter()
218 .map(Self::expression_to_data)
219 .collect(),
220 },
221
222 _ => ExpressionData::Symbol {
224 name: format!("unsupported_{}", std::any::type_name_of_val(expr)),
225 },
226 }
227 }
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
232#[serde(tag = "type")]
233pub enum ExpressionData {
234 Number {
235 value: i64,
236 },
237 Float {
238 value: f64,
239 },
240 Symbol {
241 name: String,
242 },
243 Add {
244 terms: Vec<ExpressionData>,
245 },
246 Mul {
247 factors: Vec<ExpressionData>,
248 },
249 Pow {
250 base: Box<ExpressionData>,
251 exponent: Box<ExpressionData>,
252 },
253 Function {
254 name: String,
255 args: Vec<ExpressionData>,
256 },
257 Complex {
258 real: Box<ExpressionData>,
259 imag: Box<ExpressionData>,
260 },
261 Constant {
262 constant: MathConstant,
263 },
264 Derivative {
265 expression: Box<ExpressionData>,
266 variable: String,
267 order: u32,
268 },
269 Integral {
270 integrand: Box<ExpressionData>,
271 variable: String,
272 bounds: Option<(Box<ExpressionData>, Box<ExpressionData>)>,
273 },
274 Limit {
275 expression: Box<ExpressionData>,
276 variable: String,
277 approach: Box<ExpressionData>,
278 direction: LimitDirection,
279 },
280 MethodCall {
281 object: Box<ExpressionData>,
282 method_name: String,
283 args: Vec<ExpressionData>,
284 },
285}
286
/// Failure raised while converting expressions to or from JSON.
///
/// Every variant carries a human-readable detail message. `PartialEq`/`Eq` are derived so callers
/// and tests can compare errors directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SerializationError {
    /// The input text is not valid JSON or does not match the expected schema.
    InvalidFormat(String),
    /// The JSON writer failed while producing output.
    SerializationError(String),
    /// Generic parse failure.
    ///
    /// NOTE(review): not constructed in this file; `MathSerializer::parse` reports malformed
    /// input as `InvalidFormat`.
    ParseError(String),
    /// A value or construct that has no representation in the JSON format.
    UnsupportedType(String),
}
295
296impl std::fmt::Display for SerializationError {
297 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
298 match self {
299 SerializationError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
300 SerializationError::SerializationError(msg) => {
301 write!(f, "Serialization error: {}", msg)
302 }
303 SerializationError::ParseError(msg) => write!(f, "Parse error: {}", msg),
304 SerializationError::UnsupportedType(msg) => write!(f, "Unsupported type: {}", msg),
305 }
306 }
307}
308
// Marker impl: the derived `Debug` and the hand-written `Display` impl supply everything
// `std::error::Error` requires. The default `source()` (returning `None`) is kept because every
// variant carries only a message string, not an underlying error.
impl std::error::Error for SerializationError {}