1use smallvec::SmallVec;
2
/// Handle to a node stored in an [`ExprArena`].
///
/// The wrapped `u32` is the node's position in the arena's node list;
/// [`ExprArena::push`] hands ids out in insertion order.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ExprId(pub u32);

impl ExprId {
    /// Returns the id as a `usize`, ready to index the arena's node list.
    #[inline]
    pub fn index(self) -> usize {
        // u32 -> usize is a lossless widening on every 32- and 64-bit target.
        let ExprId(raw) = self;
        raw as usize
    }
}
12
/// Identifier of a variable.
///
/// Appears as the payload of [`ExprNode::Var`] and in the coefficient list of
/// [`ExprNode::Linear`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct VarId(pub u32);

impl VarId {
    /// Returns the raw id widened to `usize`, for use as a slice index.
    #[inline]
    pub fn index(self) -> usize {
        // Lossless on every target whose pointer width is at least 32 bits.
        let Self(raw) = self;
        raw as usize
    }
}
22
/// Identifier of a parameter.
///
/// The wrapped `u32` is the parameter's position in the arena's parameter
/// value table; ids are created by [`ExprArena::new_param`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct ParamId(pub u32);

impl ParamId {
    /// Returns the position of this parameter as a `usize`.
    #[inline]
    pub fn index(self) -> usize {
        // Widening cast; cannot truncate on 32- or 64-bit targets.
        let Self(n) = self;
        n as usize
    }
}
32
33pub type Children = SmallVec<[ExprId; 4]>;
34
35#[derive(Clone, Debug)]
40pub enum ExprNode {
41 Const(f64),
42 Var(VarId),
43 Param(ParamId),
44 Add(Children),
45 Mul(Children),
46 Neg(ExprId),
47 Pow(ExprId, ExprId),
48 Div(ExprId, ExprId),
49 Sin(ExprId),
50 Cos(ExprId),
51 Exp(ExprId),
52 Log(ExprId),
53 Abs(ExprId),
54 Linear { coeffs: Vec<(VarId, f64)>, constant: f64 },
55}
56
57#[derive(Clone, Debug, Default)]
58pub struct ExprArena {
59 nodes: Vec<ExprNode>,
60 param_values: Vec<f64>,
61}
62
63impl ExprArena {
64 pub fn new() -> Self {
65 Self::default()
66 }
67
68 pub fn with_capacity(cap: usize) -> Self {
69 Self { nodes: Vec::with_capacity(cap), param_values: Vec::new() }
70 }
71
72 #[inline]
73 pub fn len(&self) -> usize {
74 self.nodes.len()
75 }
76
77 #[inline]
78 pub fn is_empty(&self) -> bool {
79 self.nodes.is_empty()
80 }
81
82 pub fn push(&mut self, node: ExprNode) -> ExprId {
86 let id = ExprId(u32::try_from(self.nodes.len()).expect("expression arena overflow"));
87 self.nodes.push(node);
88 id
89 }
90
91 #[inline]
92 pub fn get(&self, id: ExprId) -> &ExprNode {
93 &self.nodes[id.index()]
94 }
95
96 #[inline]
97 pub fn get_mut(&mut self, id: ExprId) -> &mut ExprNode {
98 &mut self.nodes[id.index()]
99 }
100
101 pub fn nodes(&self) -> &[ExprNode] {
102 &self.nodes
103 }
104
105 pub fn constant(&mut self, v: f64) -> ExprId {
106 self.push(ExprNode::Const(v))
107 }
108
109 pub fn var(&mut self, v: VarId) -> ExprId {
110 self.push(ExprNode::Var(v))
111 }
112
113 pub fn param(&mut self, p: ParamId) -> ExprId {
114 self.push(ExprNode::Param(p))
115 }
116
117 pub fn new_param(&mut self, value: f64) -> ParamId {
125 let id = ParamId(u32::try_from(self.param_values.len()).expect("parameter arena overflow"));
126 self.param_values.push(value);
127 id
128 }
129
130 #[inline]
131 pub fn num_params(&self) -> usize {
132 self.param_values.len()
133 }
134
135 #[inline]
141 pub fn param_value(&self, p: ParamId) -> f64 {
142 self.param_values[p.index()]
143 }
144
145 #[inline]
147 pub fn try_param_value(&self, p: ParamId) -> Option<f64> {
148 self.param_values.get(p.index()).copied()
149 }
150
151 #[inline]
158 pub fn set_param_value(&mut self, p: ParamId, value: f64) {
159 self.param_values[p.index()] = value;
160 }
161
162 pub fn linear(&mut self, coeffs: Vec<(VarId, f64)>, constant: f64) -> ExprId {
163 self.push(ExprNode::Linear { coeffs, constant })
164 }
165}