mod argument;
mod convenience_expressions;
mod derivation_rules;
mod deriver;
mod equivalence_disbatchers;
mod expressions;
pub mod graph;
mod graph_traversal;
mod mathxml;
mod optimization_profiles;
pub use deriver::DerivationDebugInfo;
use deriver::Deriver;
use expressions::read_from_json::read_object_from_json;
pub use expressions::Expression;
use graph::Graph;
use graph_traversal::better_solution_cmp;
pub use optimization_profiles::{BruteForceProfile, EvaluateFirstProfile, OptimizationProfile};
use petgraph::{algo::astar, graph::NodeIndex, visit::IntoNodeReferences};
use serde::Serialize;
use serde_json::json;
use std::{cell::RefCell, collections::HashSet, rc::Rc};
pub use expressions::absolute_value;
pub use expressions::constant;
pub use expressions::derivative;
pub use expressions::exponent;
pub use expressions::fraction;
pub use expressions::integer;
pub use expressions::integral;
pub use expressions::logarithm;
pub use expressions::negation;
pub use expressions::product;
pub use expressions::substitution;
pub use expressions::sum;
pub use expressions::trig_expression;
pub use expressions::undefined;
pub use expressions::variable;
pub use graph_traversal::Path;
/// Parses an [`Expression`] from its JSON text representation.
///
/// This delegates to the crate-internal JSON reader.
///
/// # Errors
///
/// Returns the error produced by the reader when `exp` cannot be parsed into an expression.
pub fn expression_from_json(exp: &str) -> Result<Expression, anyhow::Error> {
    let expression = read_object_from_json(exp)?;
    Ok(expression)
}
/// Simplifies `expression` in one shot by expanding its derivation graph and picking the best node.
///
/// # Arguments
///
/// * `expression` - the expression to simplify; a clone becomes the root of the search graph.
/// * `search_depth` - depth limit forwarded to the deriver's `expand_to_constraint`.
/// * `optimizer` - profile that steers the search.
/// * `allowed_rules` - when `Some`, handed to the optimizer's `set_rules` before searching.
/// * `debug_data` - optional shared debug sink passed to the deriver.
/// * `max_derivations` - derivation cap forwarded to `expand_to_constraint`.
///
/// # Returns
///
/// A [`DerivationResult`] with these fields:
/// * `success` is `true` when the best node found (according to `better_solution_cmp`)
///   differs from `expression`.
/// * `steps` holds the path from the root to that node when `success` is `true`, and `None` otherwise.
///
/// # Panics
///
/// Panics only if an internal invariant is broken, e.g. the chosen node is unreachable from the root.
pub fn simplify(
    expression: &Expression,
    search_depth: u32,
    optimizer: Box<dyn OptimizationProfile>,
    allowed_rules: Option<Vec<String>>,
    debug_data: Option<Rc<RefCell<DerivationDebugInfo>>>,
    max_derivations: u32,
) -> DerivationResult {
    let mut profile = optimizer;
    if let Some(rules) = allowed_rules {
        // NOTE(review): the result of `set_rules` is discarded, so a rejected rule list
        // goes unnoticed here — confirm this is intended.
        let _ = profile.set_rules(&rules);
    }

    let mut search_graph = Graph::new();
    let root = search_graph.add_node(expression.clone());
    let mut deriver = Deriver::new(search_graph, profile, None, debug_data);
    deriver.expand_to_constraint(search_depth, max_derivations);

    // The root node was inserted above, so the graph is never empty.
    let best = deriver
        .node_references()
        .min_by(|lhs, rhs| better_solution_cmp(lhs.1, rhs.1))
        .expect("There must be at least one node");
    let success = expression != best.1;
    let steps = success.then(|| build_path(&deriver, expression, root, best.0));

    DerivationResult { success, steps }
}
/// Outcome of a one-shot [`simplify`] call.
#[derive(Serialize)]
pub struct DerivationResult {
    /// As set by `simplify`: `true` when the best expression found differs from the input.
    pub success: bool,
    /// As set by `simplify`: the derivation steps from the input to the best expression,
    /// or `None` when `success` is `false`.
    pub steps: Option<Path>,
}
/// Prepares an incremental simplification of `expression` without running any derivations yet.
///
/// The returned [`DerivationHandle`] owns a fresh graph rooted at a clone of `expression`.
/// Call [`DerivationHandle::do_pass`] repeatedly to advance the search.
///
/// `optimizer` and `allowed_rules` are handed straight to the deriver's constructor, whereas
/// [`simplify`] applies `allowed_rules` through the optimizer's `set_rules` instead.
pub fn simplify_incremental(
    expression: &Expression,
    optimizer: Box<dyn OptimizationProfile>,
    allowed_rules: Option<Vec<String>>,
) -> DerivationHandle {
    let start_exp = expression.clone();
    let mut seed_graph = Graph::new();
    let start = seed_graph.add_node(expression.clone());
    DerivationHandle {
        deriver: Deriver::new(seed_graph, optimizer, allowed_rules, None),
        start,
        start_exp,
    }
}
/// Resumable simplification state created by [`simplify_incremental`].
pub struct DerivationHandle {
    /// Search engine; it takes ownership of the derivation graph when constructed.
    deriver: Deriver,
    /// Graph node that holds the starting expression.
    start: NodeIndex,
    /// Clone of the starting expression, used as the `start` of every reported path.
    start_exp: Expression,
}
impl DerivationHandle {
    /// Runs up to `derivations` more derivation steps, then reports the current best result.
    ///
    /// The returned [`IncrementalResult`] is filled in as follows:
    /// * `finished` is whatever the deriver's `expand_increment` returned for this pass.
    /// * `steps` is always `Some`. It holds the path from the starting expression to the best
    ///   node found so far, and that path has no steps when the start node is still the best.
    /// * `failed` is always `None`.
    pub fn do_pass(&mut self, derivations: u32) -> IncrementalResult {
        let finished = self.deriver.expand_increment(derivations);

        // The start node is always present, so there is at least one candidate.
        let best = self
            .deriver
            .node_references()
            .min_by(|lhs, rhs| better_solution_cmp(lhs.1, rhs.1))
            .expect("There must be at least one node");
        let path = build_path(&self.deriver, &self.start_exp, self.start, best.0);

        IncrementalResult {
            failed: None,
            finished,
            steps: Some(path),
        }
    }

    /// Borrows the underlying deriver, e.g. to inspect the explored graph.
    pub fn deriver(&self) -> &Deriver {
        &self.deriver
    }
}
/// Report produced by [`DerivationHandle::do_pass`], or built by [`IncrementalResult::failed`].
#[derive(Serialize)]
pub struct IncrementalResult {
    /// Path from the starting expression to the best expression found so far, if available.
    pub steps: Option<Path>,
    /// Reason the derivation failed; `None` for a normal pass.
    pub failed: Option<String>,
    /// For a pass, the value returned by the deriver's `expand_increment`
    /// (presumably "search exhausted" — confirm); always `false` for a failure.
    pub finished: bool,
}
impl IncrementalResult {
    /// Builds a result describing a pass that failed because of `reason`.
    ///
    /// The result carries no steps and is marked as not finished.
    pub fn failed(reason: String) -> Self {
        let failed = Some(reason);
        Self {
            failed,
            steps: None,
            finished: false,
        }
    }
}
/// Expands the derivation graph of `expression` and reports everything it reached as JSON.
///
/// # Arguments
///
/// * `expression` - root of the search; a clone is inserted as the first graph node.
/// * `optimizer` - profile that steers the search.
/// * `search_depth`, `max_derivations` - limits forwarded to the deriver's `expand_to_constraint`.
///
/// # Returns
///
/// A JSON object serialized to a string, with two keys:
/// * `"equivalents"` — the `as_stringable().to_json()` form of every node in the graph,
///   in the graph's node order.
/// * `"rules_used"` — every distinct rule message recorded on any edge, sorted
///   lexicographically so the output is reproducible.
pub fn get_all_equivalents(
    expression: &Expression,
    optimizer: Box<dyn OptimizationProfile>,
    search_depth: u32,
    max_derivations: u32,
) -> String {
    let mut graph = Graph::new();
    graph.add_node(expression.clone());
    let mut deriver = Deriver::new(graph, optimizer, None, None);
    deriver.expand_to_constraint(search_depth, max_derivations);

    let equivalents: Vec<_> = deriver
        .node_weights()
        .map(|e| e.as_stringable().to_json())
        .collect();

    // The `HashSet` removes duplicate messages, but its iteration order is unspecified
    // and varies between runs. Sort afterwards so identical inputs always serialize identically.
    let mut rules: Vec<String> = deriver
        .edge_weights()
        .flat_map(|e| &e.derived_from)
        .map(|arg| arg.message().to_string())
        .collect::<HashSet<String>>()
        .into_iter()
        .collect();
    rules.sort_unstable();

    json!({
        "equivalents": &equivalents,
        "rules_used": rules,
    })
    .to_string()
}
/// Builds the [`Path`] of derivation steps leading from node `start` to node `end`.
///
/// The route is found with A* using unit edge cost and a zero heuristic, so it is a
/// fewest-edges path. Each step is built from one traversed edge:
/// * the first entry of that edge's `derived_from` collection, cloned;
/// * a clone of the expression stored at the edge's target node.
///
/// `start_exp` is cloned into `Path::start`; callers pass the expression stored at `start`.
///
/// # Panics
///
/// Panics in either of these cases:
/// * `end` is not reachable from `start`;
/// * an edge on the chosen route has an empty `derived_from` collection.
fn build_path(graph: &Graph, start_exp: &Expression, start: NodeIndex, end: NodeIndex) -> Path {
    let (_cost, nodes) = astar(&graph, start, |n| n == end, |_| 1, |_| 0)
        .expect("There must be a path because the graph is connected");

    // The A* route begins at `start`, so consecutive pairs are exactly the traversed edges.
    let steps = nodes
        .windows(2)
        .map(|pair| {
            let (from, to) = (pair[0], pair[1]);
            let edge_id = graph
                .find_edge(from, to)
                .expect("consecutive nodes on an A* route must be joined by an edge");
            let rule = graph
                .edge_weight(edge_id)
                .expect("edge index returned by find_edge must be valid")
                .derived_from
                .iter()
                .next()
                .expect("derivation edge on the route has no recorded rule")
                .clone();
            let expression = graph
                .node_weight(to)
                .expect("node index on an A* route must be valid")
                .clone();
            (rule, expression)
        })
        .collect();

    Path {
        start: start_exp.clone(),
        steps,
    }
}
#[cfg(test)]
mod tests {
    use crate::{expressions::Exponent, integral::Integral, product::product_of, sum::sum_of};
    use integer::Integer;
    use optimization_profiles::BruteForceProfile;
    use variable::Variable;

    use super::*;

    /// Running 100 brute-force derivations on `8 * 1 + 1` should yield a best
    /// expression that `build_path` reaches in exactly two steps.
    #[test]
    fn incremental_derivation() {
        let initial = sum_of(&[
            product_of(&[Integer::of(8), Integer::of(1)]),
            Integer::of(1),
        ]);
        let mut seed_graph = Graph::new();
        let root = seed_graph.add_node(initial.clone());

        let mut deriver = Deriver::new(seed_graph, BruteForceProfile::new(), None, None);
        deriver.expand_increment(100);

        let best = deriver
            .node_references()
            .min_by(|lhs, rhs| better_solution_cmp(lhs.1, rhs.1))
            .expect("There must be at least one node");
        let path = build_path(&deriver, &initial, root, best.0);

        assert_eq!(path.steps.len(), 2);
    }

    /// Integrating `2x` with respect to `x` should finish within a 20-derivation pass
    /// and end on `x^2`.
    #[test]
    fn crate_doc_example() {
        let integrand = product_of(&[Integer::of(2), Variable::of("x")]);
        let problem = Integral::of(integrand, Variable::of("x"));

        let mut handle = simplify_incremental(&problem, EvaluateFirstProfile::new(), None);
        let outcome = handle.do_pass(20);

        assert!(outcome.finished);
        assert_eq!(outcome.failed, None);
        assert_eq!(
            outcome.steps.unwrap().steps.last().unwrap().1,
            Exponent::of(Variable::of("x"), Integer::of(2))
        );

        // Checks that the `&Deriver` returned by the handle coerces to `&Graph`.
        let _graph: &Graph = handle.deriver();
    }
}