symrs/expr/
function.rs

1use std::{collections::HashMap, sync::LazyLock};
2
3use itertools::Itertools;
4
5use super::*;
6
7#[derive(Clone, PartialEq, Eq, Hash)]
8pub struct Func {
9    pub name: String,
10    pub args: Vec<Box<dyn Expr>>,
11}
12
13impl Func {
14    pub fn new<'a, T: IntoIterator<Item = &'a dyn Expr>>(name: &str, args: T) -> Self {
15        Func {
16            name: name.to_string(),
17            args: args.into_iter().map(|expr| expr.clone_box()).collect(),
18        }
19    }
20
21    pub fn new_move(name: String, args: Vec<Box<dyn Expr>>) -> Func {
22        Func { name, args }
23    }
24
25    pub fn new_move_box(name: String, args: Vec<Box<dyn Expr>>) -> Box<dyn Expr> {
26        Box::new(Func::new_move(name, args))
27    }
28
29    pub fn time_discretize(&self) -> [Func; 2] {
30        return [
31            Func {
32                name: format!("{}^n-1", self.name),
33                args: self.args.clone(),
34            },
35            Func {
36                name: format!("{}^n", self.name),
37                args: self.args.clone(),
38            },
39        ];
40    }
41
42    pub fn to_vector(&self) -> Func {
43        let pieces: Vec<_> = self.name.split("^").collect();
44        let mut name = pieces[0].to_uppercase();
45
46        for piece in &pieces[1..] {
47            name += &format!("^{piece}");
48        }
49
50        Func {
51            name,
52            args: self.args.clone(),
53        }
54    }
55}
56
57impl Expr for Func {
58    fn get_ref<'a>(&'a self) -> &'a dyn Expr {
59        self as &dyn Expr
60    }
61    fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
62        f(&self.name);
63        f(&self.args);
64    }
65
66    fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
67        let name = args[0]
68            .as_any()
69            .downcast_ref::<String>()
70            .expect("First arg should be string")
71            .clone();
72        let params = args[1]
73            .as_any()
74            .downcast_ref::<Vec<Box<dyn Expr>>>()
75            .unwrap();
76
77        Box::new(Func {
78            name,
79            args: params.to_vec(),
80        })
81    }
82
83    fn clone_box(&self) -> Box<dyn Expr> {
84        Box::new(self.clone())
85    }
86
87    fn str(&self) -> String {
88        format!(
89            "{}",
90            self.name,
91            // self.args.iter().map(|x| x.str()).collect::<String>()
92        )
93    }
94
95    fn to_cpp(&self) -> String {
96        if !self.name.contains("^") && self.name.len() > 1 {
97            format!(
98                "std::{}({})",
99                self.name,
100                self.args
101                    .iter()
102                    .map(|x| x.to_cpp())
103                    .collect_vec()
104                    .join(", ")
105            )
106        } else {
107            self.name
108                .replace("^n-1", "_prev")
109                .replace("^n", "")
110                .to_lowercase()
111        }
112    }
113
114    fn as_function(&self) -> Option<&Func> {
115        Some(self)
116    }
117}
118
/// Lookup table from function names to their C++ standard-library spellings
/// (`sin`, `cos`, `sqrt`, each mapping to itself).
///
/// NOTE(review): not referenced anywhere in the visible part of this file;
/// `Func::to_cpp` emits `std::<name>` without consulting it. Confirm whether it
/// is used elsewhere or was meant to gate `to_cpp`.
static CPP_FUNC_NAMES: LazyLock<HashMap<&'static str, &'static str>> =
    LazyLock::new(|| HashMap::from([("sin", "sin"), ("cos", "cos"), ("sqrt", "sqrt")]));
124
125impl fmt::Debug for Func {
126    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127        write!(f, "{}", self.str())
128    }
129}