1use std::{collections::HashMap, sync::LazyLock};
2
3use itertools::Itertools;
4
5use super::*;
6
7#[derive(Clone, PartialEq, Eq, Hash)]
8pub struct Func {
9 pub name: String,
10 pub args: Vec<Box<dyn Expr>>,
11}
12
13impl Func {
14 pub fn new<'a, T: IntoIterator<Item = &'a dyn Expr>>(name: &str, args: T) -> Self {
15 Func {
16 name: name.to_string(),
17 args: args.into_iter().map(|expr| expr.clone_box()).collect(),
18 }
19 }
20
21 pub fn new_move(name: String, args: Vec<Box<dyn Expr>>) -> Func {
22 Func { name, args }
23 }
24
25 pub fn new_move_box(name: String, args: Vec<Box<dyn Expr>>) -> Box<dyn Expr> {
26 Box::new(Func::new_move(name, args))
27 }
28
29 pub fn time_discretize(&self) -> [Func; 2] {
30 return [
31 Func {
32 name: format!("{}^n-1", self.name),
33 args: self.args.clone(),
34 },
35 Func {
36 name: format!("{}^n", self.name),
37 args: self.args.clone(),
38 },
39 ];
40 }
41
42 pub fn to_vector(&self) -> Func {
43 let pieces: Vec<_> = self.name.split("^").collect();
44 let mut name = pieces[0].to_uppercase();
45
46 for piece in &pieces[1..] {
47 name += &format!("^{piece}");
48 }
49
50 Func {
51 name,
52 args: self.args.clone(),
53 }
54 }
55}
56
57impl Expr for Func {
58 fn get_ref<'a>(&'a self) -> &'a dyn Expr {
59 self as &dyn Expr
60 }
61 fn for_each_arg(&self, f: &mut dyn FnMut(&dyn Arg) -> ()) {
62 f(&self.name);
63 f(&self.args);
64 }
65
66 fn from_args(&self, args: Vec<Box<dyn Arg>>) -> Box<dyn Expr> {
67 let name = args[0]
68 .as_any()
69 .downcast_ref::<String>()
70 .expect("First arg should be string")
71 .clone();
72 let params = args[1]
73 .as_any()
74 .downcast_ref::<Vec<Box<dyn Expr>>>()
75 .unwrap();
76
77 Box::new(Func {
78 name,
79 args: params.to_vec(),
80 })
81 }
82
83 fn clone_box(&self) -> Box<dyn Expr> {
84 Box::new(self.clone())
85 }
86
87 fn str(&self) -> String {
88 format!(
89 "{}",
90 self.name,
91 )
93 }
94
95 fn to_cpp(&self) -> String {
96 if !self.name.contains("^") && self.name.len() > 1 {
97 format!(
98 "std::{}({})",
99 self.name,
100 self.args
101 .iter()
102 .map(|x| x.to_cpp())
103 .collect_vec()
104 .join(", ")
105 )
106 } else {
107 self.name
108 .replace("^n-1", "_prev")
109 .replace("^n", "")
110 .to_lowercase()
111 }
112 }
113
114 fn as_function(&self) -> Option<&Func> {
115 Some(self)
116 }
117}
118
/// Lazily initialised table of function names: `sin`, `cos` and `sqrt`, each
/// currently mapped to itself.
///
/// NOTE(review): presumably the key is the name used in expressions and the
/// value is the name to emit in generated C++. Nothing in the visible part of
/// this file reads the table; confirm the intended use.
static CPP_FUNC_NAMES: LazyLock<HashMap<&'static str, &'static str>> =
    LazyLock::new(|| HashMap::from([("sin", "sin"), ("cos", "cos"), ("sqrt", "sqrt")]));
124
125impl fmt::Debug for Func {
126 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
127 write!(f, "{}", self.str())
128 }
129}