1use crate::{Arity, Eval, Factory, NodeValue, TreeNode};
2#[cfg(feature = "pgm")]
3use std::sync::Arc;
4use std::{
5 fmt::{Debug, Display},
6 hash::Hash,
7};
8
/// An operation that can sit in an expression tree and be evaluated over a slice of
/// inputs of type `T`.
///
/// Every variant carries a `&'static str` name. `PartialEq`, `Hash` and `Display` for
/// `Op` are all based on that name alone, not on the variant or any stored value.
pub enum Op<T> {
    /// A stateless function `(name, arity, f)`; evaluating it returns `f(inputs)`.
    Fn(&'static str, Arity, fn(&[T]) -> T),
    /// An input variable `(name, index)`; evaluating it returns a clone of
    /// `inputs[index]` and panics if `index` is out of bounds.
    Var(&'static str, usize),
    /// A fixed constant `(name, value)`; evaluating it ignores the inputs and returns a
    /// clone of `value`. Its arity is always `Arity::Zero`.
    Const(&'static str, T),
    /// A constant whose `value` is regenerated from `supplier` whenever a new instance
    /// is created via `Factory::new_instance` (a plain `clone` keeps the current value).
    MutableConst {
        /// Identity/display name of the operation.
        name: &'static str,
        /// Arity reported by `Op::arity` for this operation.
        arity: Arity,
        /// Current constant value; handed to `operation` on every evaluation.
        value: T,
        /// Produces a fresh `value` for new instances.
        supplier: fn() -> T,
        /// Maps a value to a new value. Not called anywhere in this file; presumably
        /// used by mutation/tuning code elsewhere. TODO confirm.
        modifier: fn(&T) -> T,
        /// Computes the result from the inputs and the current `value`.
        operation: fn(&[T], &T) -> T,
    },
    /// A sub-program operation `(name, arity, model, f)`, available with the `pgm`
    /// feature. Evaluating it returns `f(inputs, model)`. The model sits behind an `Arc`,
    /// so clones and new instances share the same program trees instead of copying them.
    #[cfg(feature = "pgm")]
    PGM(
        &'static str,
        Arity,
        Arc<Vec<TreeNode<Op<T>>>>,
        fn(&[T], &[TreeNode<Op<T>>]) -> T,
    ),
}
77
78impl<T> Op<T> {
79 pub fn name(&self) -> &str {
80 match self {
81 Op::Fn(name, _, _) => name,
82 Op::Var(name, _) => name,
83 Op::Const(name, _) => name,
84 Op::MutableConst { name, .. } => name,
85 #[cfg(feature = "pgm")]
86 Op::PGM(name, _, _, _) => name,
87 }
88 }
89
90 pub fn arity(&self) -> Arity {
91 match self {
92 Op::Fn(_, arity, _) => *arity,
93 Op::Var(_, _) => Arity::Zero,
94 Op::Const(_, _) => Arity::Zero,
95 Op::MutableConst { arity, .. } => *arity,
96 #[cfg(feature = "pgm")]
97 Op::PGM(_, arity, _, _) => *arity,
98 }
99 }
100
101 pub fn is_fn(&self) -> bool {
102 matches!(self, Op::Fn(_, _, _))
103 }
104
105 pub fn is_var(&self) -> bool {
106 matches!(self, Op::Var(_, _))
107 }
108
109 pub fn is_const(&self) -> bool {
110 matches!(self, Op::Const(_, _))
111 }
112
113 pub fn is_mutable_const(&self) -> bool {
114 matches!(self, Op::MutableConst { .. })
115 }
116
117 #[cfg(feature = "pgm")]
118 pub fn is_pgm(&self) -> bool {
119 matches!(self, Op::PGM(_, _, _, _))
120 }
121}
122
123unsafe impl<T> Send for Op<T> {}
124unsafe impl<T> Sync for Op<T> {}
125
126impl<T> Eval<[T], T> for Op<T>
127where
128 T: Clone,
129{
130 fn eval(&self, inputs: &[T]) -> T {
131 match self {
132 Op::Fn(_, _, op) => op(inputs),
133 Op::Var(_, index) => inputs[*index].clone(),
134 Op::Const(_, value) => value.clone(),
135 Op::MutableConst {
136 value, operation, ..
137 } => operation(inputs, value),
138 #[cfg(feature = "pgm")]
139 Op::PGM(_, _, model, operation) => operation(inputs, &model),
140 }
141 }
142}
143
144impl<T> Factory<(), Op<T>> for Op<T>
145where
146 T: Clone,
147{
148 fn new_instance(&self, _: ()) -> Op<T> {
149 match self {
150 Op::Fn(name, arity, op) => Op::Fn(name, *arity, *op),
151 Op::Var(name, index) => Op::Var(name, *index),
152 Op::Const(name, value) => Op::Const(name, value.clone()),
153 Op::MutableConst {
154 name,
155 arity,
156 value: _,
157 supplier,
158 modifier,
159 operation,
160 } => Op::MutableConst {
161 name,
162 arity: *arity,
163 value: (*supplier)(),
164 supplier: *supplier,
165 modifier: *modifier,
166 operation: *operation,
167 },
168 #[cfg(feature = "pgm")]
169 Op::PGM(name, arity, model, operation) => {
170 use std::sync::Arc;
171 Op::PGM(name, *arity, Arc::clone(model), *operation)
172 }
173 }
174 }
175}
176
177impl<T> Clone for Op<T>
178where
179 T: Clone,
180{
181 fn clone(&self) -> Self {
182 match self {
183 Op::Fn(name, arity, op) => Op::Fn(name, *arity, *op),
184 Op::Var(name, index) => Op::Var(name, *index),
185 Op::Const(name, value) => Op::Const(name, value.clone()),
186 Op::MutableConst {
187 name,
188 arity,
189 value,
190 supplier,
191 modifier,
192 operation,
193 } => Op::MutableConst {
194 name,
195 arity: *arity,
196 value: value.clone(),
197 supplier: *supplier,
198 modifier: *modifier,
199 operation: *operation,
200 },
201 #[cfg(feature = "pgm")]
202 Op::PGM(name, arity, model, operation) => {
203 use std::sync::Arc;
204 Op::PGM(name, *arity, Arc::clone(model), *operation)
205 }
206 }
207 }
208}
209
210impl<T> PartialEq for Op<T>
211where
212 T: PartialEq,
213{
214 fn eq(&self, other: &Self) -> bool {
215 self.name() == other.name()
216 }
217}
218
219impl<T> Hash for Op<T> {
220 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
221 self.name().hash(state);
222 }
223}
224
225impl<T> Display for Op<T> {
226 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
227 write!(f, "{}", self.name())
228 }
229}
230
231impl<T> Default for Op<T>
232where
233 T: Default,
234{
235 fn default() -> Self {
236 Op::Fn("default", Arity::Zero, |_: &[T]| T::default())
237 }
238}
239
240impl<T> Debug for Op<T>
241where
242 T: Debug,
243{
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 match self {
246 Op::Fn(name, _, _) => write!(f, "Fn: {}", name),
247 Op::Var(name, index) => write!(f, "Var: {}({})", name, index),
248 Op::Const(name, value) => write!(f, "C: {}({:?})", name, value),
249 Op::MutableConst { name, value, .. } => write!(f, "{}({:.2?})", name, value),
250 #[cfg(feature = "pgm")]
251 Op::PGM(name, _, model, _) => {
252 let mut model_str = String::new();
253 for (i, node) in model.iter().enumerate() {
254 use crate::Format;
255
256 let node_str = &node.format();
257 model_str.push_str(&format!("[{}: S {} Prog {}], ", i, node.size(), node_str));
258 }
259 write!(f, "PGM: {}({})", name, model_str)
260 }
261 }
262 }
263}
264
265impl<T> From<Op<T>> for NodeValue<Op<T>> {
266 fn from(value: Op<T>) -> Self {
267 let arity = value.arity();
268 NodeValue::Bounded(value, arity)
269 }
270}
271
272impl<T> From<Op<T>> for TreeNode<Op<T>> {
273 fn from(value: Op<T>) -> Self {
274 let arity = value.arity();
275 TreeNode::with_arity(value, arity)
276 }
277}
278
279impl<T> From<Op<T>> for Vec<TreeNode<Op<T>>> {
280 fn from(value: Op<T>) -> Self {
281 vec![TreeNode::from(value)]
282 }
283}
284
#[cfg(test)]
mod test {
    use super::*;
    use radiate_core::random_provider;

    /// Checks the built-in `add` op: its name, arity and evaluation. Also checks that
    /// `new_instance` produces an op that compares equal to the original. `==` on `Op`
    /// compares names only, so that last assertion checks only the name.
    #[test]
    fn test_ops() {
        let op = Op::add();
        assert_eq!(op.name(), "add");
        assert_eq!(op.arity(), Arity::Exact(2));
        assert_eq!(op.eval(&[1_f32, 2_f32]), 3_f32);
        assert_eq!(op.new_instance(()), op);
    }

    /// Seeds the random provider, then builds two `weight` ops and extracts their
    /// `MutableConst` values. `Op::weight` presumably draws its value from the seeded
    /// provider; TODO confirm.
    ///
    /// NOTE(review): this test asserts nothing; it only prints the two values, so it can
    /// only fail if `Op::weight()` is not a `MutableConst`. Re-seeding with the same seed
    /// before each `Op::weight()` and asserting equality would pin reproducibility. That
    /// is only reliable if the provider's state is not shared with tests running in
    /// parallel on other threads — confirm before tightening.
    #[test]
    fn test_random_seed_works() {
        random_provider::set_seed(42);

        let op = Op::weight();
        let op2 = Op::weight();

        let o_one = match op {
            Op::MutableConst { value, .. } => value,
            _ => panic!("Expected MutableConst"),
        };

        let o_two = match op2 {
            Op::MutableConst { value, .. } => value,
            _ => panic!("Expected MutableConst"),
        };

        println!("o_one: {:?}", o_one);
        println!("o_two: {:?}", o_two);
    }

    /// A clone compares equal to its source (by name) and evaluates to the same result.
    #[test]
    fn test_op_clone() {
        let op = Op::add();
        let op2 = op.clone();

        let result = op.eval(&[1_f32, 2_f32]);
        let result2 = op2.eval(&[1_f32, 2_f32]);

        assert_eq!(op, op2);
        assert_eq!(result, result2);
    }

    /// Wraps the tree `add(1, 2)` in a `PGM` op. The op's function sums the evaluation
    /// of every program tree plus the sum of the inputs. With no inputs the result is
    /// therefore `1 + 2 = 3`.
    #[test]
    #[cfg(feature = "pgm")]
    fn test_pgm_op() {
        use std::sync::Arc;
        let model = TreeNode::with_children(
            Op::add(),
            vec![
                TreeNode::new(Op::constant(1_f32)),
                TreeNode::new(Op::constant(2_f32)),
            ],
        );

        let pgm_op = Op::PGM(
            "pgm",
            Arity::Any,
            Arc::new(vec![model]),
            |inputs: &[f32], prog: &[TreeNode<Op<f32>>]| {
                let sum: f32 = prog.iter().map(|node| node.eval(inputs)).sum();
                sum + inputs.iter().sum::<f32>()
            },
        );

        let result = pgm_op.eval(&[]);
        assert_eq!(result, 3_f32);
    }
}