1use crate::codegen;
4use crate::error::Result;
5use oxieml::EmlTree;
6use scirs2_core::ndarray::{Array1, Array2};
7
8#[derive(Debug, Clone)]
10pub struct Solution {
11 pub tree: EmlTree,
13 pub mse: f64,
15 pub complexity: usize,
17}
18
19impl Solution {
20 #[must_use]
22 pub fn new(tree: EmlTree, mse: f64) -> Self {
23 let complexity = tree.size();
24 Self {
25 tree,
26 mse,
27 complexity,
28 }
29 }
30
31 #[must_use]
33 pub fn latex(&self) -> String {
34 self.tree.lower().simplify().to_latex()
35 }
36
37 #[must_use]
39 pub fn pretty(&self) -> String {
40 format!("{}", self.tree)
41 }
42
43 #[cfg(feature = "egraph")]
48 #[must_use]
49 pub fn latex_egraph(&self) -> String {
50 crate::egraph::canonical_latex_egraph(&self.tree)
51 }
52
53 #[must_use]
55 pub fn rust_code(&self) -> String {
56 codegen::rust_code(&self.tree)
57 }
58
59 #[must_use]
61 pub fn numpy_code(&self) -> String {
62 codegen::numpy_code(&self.tree)
63 }
64
65 #[must_use]
67 pub fn sympy_code(&self) -> String {
68 codegen::sympy_code(&self.tree)
69 }
70
71 #[must_use]
74 pub fn distill(&self) -> crate::distill::Distilled {
75 crate::distill::distill(&self.tree)
76 }
77
78 pub fn predict(&self, x: &Array2<f64>) -> Result<Array1<f64>> {
84 crate::forest::eval_tree(&self.tree, x)
85 }
86
87 pub fn to_model_json(&self) -> Result<String> {
93 serde_json::to_string(&self.tree)
94 .map_err(|e| crate::error::PhopError::Symbolic(format!("model serialize: {e}")))
95 }
96
97 pub fn from_model_json(s: &str) -> Result<Self> {
103 let tree: EmlTree = serde_json::from_str(s)
104 .map_err(|e| crate::error::PhopError::Parse(format!("model deserialize: {e}")))?;
105 Ok(Self::new(tree, f64::NAN))
106 }
107
108 #[must_use]
111 pub fn analyze(&self, wrt: usize, series_order: usize) -> crate::analyze::Analysis {
112 crate::analyze::analyze(&self.tree, wrt, series_order)
113 }
114
115 pub fn certified_root(
121 &self,
122 wrt: usize,
123 others: &[f64],
124 lo: f64,
125 hi: f64,
126 ) -> Result<oxieml::RootCertificate> {
127 crate::analyze::certified_root(&self.tree, wrt, others, lo, hi)
128 }
129
130 #[must_use]
133 pub fn certified_range(&self, domain: &[(f64, f64)]) -> (f64, f64) {
134 crate::analyze::certified_range(&self.tree, domain)
135 }
136
137 #[must_use]
142 pub fn compile_rust(&self, fn_name: &str) -> String {
143 oxieml::compile::compile_to_rust(&self.tree, fn_name)
144 }
145
146 #[cfg(feature = "smt")]
150 #[must_use]
151 pub fn prove_no_root(&self, bounds: &[(f64, f64)]) -> crate::verify::Verdict {
152 crate::verify::prove_no_root(&self.tree, bounds)
153 }
154
155 #[cfg(feature = "smt")]
157 #[must_use]
158 pub fn prove_lower_bound(&self, c: f64, bounds: &[(f64, f64)]) -> crate::verify::Verdict {
159 crate::verify::prove_lower_bound(&self.tree, c, bounds)
160 }
161
162 #[cfg(feature = "smt")]
164 #[must_use]
165 pub fn prove_upper_bound(&self, c: f64, bounds: &[(f64, f64)]) -> crate::verify::Verdict {
166 crate::verify::prove_upper_bound(&self.tree, c, bounds)
167 }
168
169 #[cfg(feature = "smt")]
171 #[must_use]
172 pub fn prove_equivalent(
173 &self,
174 other: &Solution,
175 bounds: &[(f64, f64)],
176 ) -> crate::verify::Verdict {
177 crate::verify::prove_equivalent(&self.tree, &other.tree, bounds)
178 }
179
180 #[cfg(feature = "tensorlogic")]
185 #[must_use]
186 pub fn to_tlexpr(&self) -> tensorlogic_ir::TLExpr {
187 oxieml::tensorlogic::to_tlexpr(&self.tree.lower().simplify())
188 }
189
190 #[cfg(feature = "tensorlogic")]
193 #[must_use]
194 pub fn to_tl_weighted_rule(&self, weight: f64) -> tensorlogic_ir::TLExpr {
195 tensorlogic_ir::TLExpr::WeightedRule {
196 weight,
197 rule: Box::new(self.to_tlexpr()),
198 }
199 }
200
201 #[cfg(feature = "tensorlogic")]
204 #[must_use]
205 pub fn to_tl_weighted_equation(&self, target_var: &str, weight: f64) -> tensorlogic_ir::TLExpr {
206 let lhs = tensorlogic_ir::TLExpr::Pred {
207 name: target_var.to_string(),
208 args: vec![tensorlogic_ir::Term::var(target_var)],
209 };
210 let eq = tensorlogic_ir::TLExpr::Eq(Box::new(lhs), Box::new(self.to_tlexpr()));
211 tensorlogic_ir::TLExpr::WeightedRule {
212 weight,
213 rule: Box::new(eq),
214 }
215 }
216
217 #[must_use]
221 pub fn dominates(&self, other: &Solution) -> bool {
222 self.complexity <= other.complexity
223 && self.mse <= other.mse
224 && (self.complexity < other.complexity || self.mse < other.mse)
225 }
226}
227
#[cfg(test)]
mod tests {
    // `use super::*` also brings in the parent's imports (`EmlTree`, `Array2`, ...).
    use super::*;

    /// `exp(x0)` built with the canonical helper; shared fixture tree.
    fn exp_x0() -> EmlTree {
        oxieml::Canonical::exp(&EmlTree::var(0))
    }

    /// A solution with hand-picked scores, bypassing `Solution::new` so the
    /// complexity can be chosen independently of the tree.
    fn scored(mse: f64, complexity: usize) -> Solution {
        Solution {
            tree: exp_x0(),
            mse,
            complexity,
        }
    }

    #[test]
    fn latex_renders() {
        let sol = Solution::new(exp_x0(), 0.0);
        assert!(!sol.latex().is_empty());
    }

    #[test]
    fn predict_matches_forward_eval() {
        // eml(x0, 1) should evaluate to exp(x0) on every row.
        let tree = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
        let sol = Solution::new(tree, 0.0);
        let xs = [0.0_f64, 1.0, 2.0];
        let x = Array2::from_shape_vec((xs.len(), 1), xs.to_vec()).unwrap();
        let p = sol.predict(&x).unwrap();
        assert_eq!(p.len(), xs.len());
        for (i, (&got, &xi)) in p.iter().zip(xs.iter()).enumerate() {
            let want = xi.exp();
            assert!((got - want).abs() < 1e-9, "row {i}: {got} vs {want}");
        }
    }

    #[test]
    fn model_json_round_trips() {
        let tree = EmlTree::eml(&EmlTree::var(0), &EmlTree::one());
        let sol = Solution::new(tree, 0.0);
        let json = sol.to_model_json().expect("serialize");
        let back = Solution::from_model_json(&json).expect("deserialize");

        // The error metric is not serialized; complexity is recomputed from the tree.
        assert!(back.mse.is_nan(), "restored mse should be NaN, got {}", back.mse);
        assert_eq!(back.complexity, sol.complexity);

        let x = Array2::from_shape_vec((3, 1), vec![0.0, 0.5, 1.0]).unwrap();
        let a = sol.predict(&x).unwrap();
        let b = back.predict(&x).unwrap();
        for (pa, pb) in a.iter().zip(b.iter()) {
            assert!(
                (pa - pb).abs() < 1e-12,
                "round-trip changed prediction: {pa} vs {pb}"
            );
        }
        assert!(Solution::from_model_json("not json").is_err());
    }

    #[test]
    fn compile_rust_emits_a_function() {
        let code = Solution::new(exp_x0(), 0.0).compile_rust("kepler");
        assert!(
            code.contains("fn kepler"),
            "expected a named fn, got: {code}"
        );
    }

    #[cfg(feature = "tensorlogic")]
    #[test]
    fn maps_to_weighted_logic_rule() {
        let sol = Solution::new(exp_x0(), 0.0);

        let weight = match sol.to_tl_weighted_rule(0.7) {
            tensorlogic_ir::TLExpr::WeightedRule { weight, .. } => weight,
            _ => panic!("expected a WeightedRule"),
        };
        assert!(
            (weight - 0.7).abs() < 1e-12,
            "expected a WeightedRule carrying 0.7, got {weight}"
        );

        let eq = sol.to_tl_weighted_equation("y", 1.0);
        assert!(matches!(
            eq,
            tensorlogic_ir::TLExpr::WeightedRule { weight, .. } if (weight - 1.0).abs() < 1e-12
        ));
    }

    #[test]
    fn domination_is_correct() {
        let a = scored(0.1, 3);
        let b = scored(0.2, 5);
        assert!(a.dominates(&b));
        assert!(!b.dominates(&a));
    }

    #[test]
    fn domination_edge_cases() {
        // Strictly better on one axis, tied on the other: still dominates.
        assert!(scored(0.1, 3).dominates(&scored(0.2, 3)));
        assert!(scored(0.1, 3).dominates(&scored(0.1, 4)));

        // Identical scores: strictness is required, so no self-domination.
        let same = scored(0.1, 3);
        assert!(!same.dominates(&same));

        // Trade-off (each better on one axis): incomparable both ways.
        let lean = scored(0.2, 3);
        let accurate = scored(0.1, 5);
        assert!(!lean.dominates(&accurate));
        assert!(!accurate.dominates(&lean));

        // NaN MSE (as after `from_model_json`) never takes part in domination.
        let unknown = scored(f64::NAN, 1);
        let known = scored(0.2, 5);
        assert!(!unknown.dominates(&known));
        assert!(!known.dominates(&unknown));
        assert!(!scored(0.2, 1).dominates(&scored(f64::NAN, 5)));
    }
}