1use crate::{Arity, Eval, Factory, NodeValue, TreeNode};
2use std::{
3 fmt::{Debug, Display},
4 sync::Arc,
5};
6
/// An operation that can be evaluated against a slice of inputs of type `T`.
///
/// Every variant carries a `&'static str` name; `name()` returns it and the
/// `Display` and `PartialEq` impls for `Op` are based on it.
pub enum Op<T> {
    /// A function of its inputs: `(name, arity, function)`.
    /// On evaluation the function is called with the full input slice.
    Fn(&'static str, Arity, Arc<dyn Fn(&[T]) -> T>),
    /// An input variable: `(name, index)`.
    /// On evaluation it returns a clone of `inputs[index]`; its arity is always zero.
    Var(&'static str, usize),
    /// A fixed constant: `(name, value)`.
    /// On evaluation it returns a clone of `value`; its arity is always zero.
    Const(&'static str, T),
    /// A constant whose value is regenerated for each new instance.
    MutableConst {
        /// Name reported by `name()`.
        name: &'static str,
        /// Arity reported by `arity()`.
        arity: Arity,
        /// Current value; passed to `operation` on evaluation and cloned by `clone()`.
        value: T,
        /// Called by `new_instance` to produce a fresh `value`.
        supplier: Arc<dyn Fn() -> T>,
        /// Not invoked anywhere in this file.
        /// NOTE(review): presumably used by mutation code to derive a new value
        /// from the current one — confirm against callers.
        modifier: Arc<dyn Fn(&T) -> T>,
        /// Combines the inputs with `value` to produce the evaluation result.
        operation: Arc<dyn Fn(&[T], &T) -> T>,
    },
    /// A stored value with a custom evaluation: `(name, arity, value, operation)`.
    /// On evaluation `operation(inputs, &value)` is returned. Unlike
    /// `MutableConst`, `new_instance` clones the stored value rather than
    /// regenerating it.
    Value(&'static str, Arity, T, Arc<dyn Fn(&[T], &T) -> T>),
}
56
57impl<T> Op<T> {
59 pub fn name(&self) -> &str {
60 match self {
61 Op::Fn(name, _, _) => name,
62 Op::Var(name, _) => name,
63 Op::Const(name, _) => name,
64 Op::MutableConst { name, .. } => name,
65 Op::Value(name, _, _, _) => name,
66 }
67 }
68
69 pub fn arity(&self) -> Arity {
70 match self {
71 Op::Fn(_, arity, _) => *arity,
72 Op::Var(_, _) => Arity::Zero,
73 Op::Const(_, _) => Arity::Zero,
74 Op::MutableConst { arity, .. } => *arity,
75 Op::Value(_, arity, _, _) => *arity,
76 }
77 }
78
79 pub fn constant(value: T) -> Self
80 where
81 T: Display,
82 {
83 let name = Box::leak(Box::new(format!("{}", value)));
84 Op::Const(name, value)
85 }
86
87 pub fn named_constant(name: &'static str, value: T) -> Self {
88 Op::Const(name, value)
89 }
90
91 pub fn identity() -> Self
92 where
93 T: Clone,
94 {
95 Op::Fn(
96 "identity",
97 1.into(),
98 Arc::new(|inputs: &[T]| inputs[0].clone()),
99 )
100 }
101
102 pub fn var(index: usize) -> Self {
103 let name = Box::leak(Box::new(format!("{}", index)));
104 Op::Var(name, index)
105 }
106}
107
// SAFETY: NOTE(review): the compiler does not check this assertion, and it is
// not guaranteed by the type. The closure fields of `Op` are
// `Arc<dyn Fn ...>` trait objects with no `Send`/`Sync` bounds. An `Op<f32>`
// built from a closure that captures non-thread-safe state (e.g. `Rc`, `Cell`)
// would therefore be unsound to move or share across threads. These impls are
// sound only if every closure ever stored in an `Op<f32>` is itself
// `Send + Sync`; callers must uphold that.
// Adding `+ Send + Sync` to the trait-object types in `Op` would let the
// compiler derive both traits and remove the need for `unsafe` — confirm that
// all constructors' closures satisfy it before changing. Note these impls
// cover only `Op<f32>`; other `T` get no such impls.
unsafe impl Send for Op<f32> {}
unsafe impl Sync for Op<f32> {}
110
111impl<T> Into<NodeValue<Op<T>>> for Op<T>
112where
113 T: Clone,
114{
115 fn into(self) -> NodeValue<Op<T>> {
116 let arity = self.arity();
117 NodeValue::Bounded(self, arity)
118 }
119}
120
121impl<T> Into<TreeNode<Op<T>>> for Op<T> {
122 fn into(self) -> TreeNode<Op<T>> {
123 TreeNode::new(self)
124 }
125}
126
127impl<T> Eval<[T], T> for Op<T>
128where
129 T: Clone,
130{
131 fn eval(&self, inputs: &[T]) -> T {
132 match self {
133 Op::Fn(_, _, op) => op(inputs),
134 Op::Var(_, index) => inputs[*index].clone(),
135 Op::Const(_, value) => value.clone(),
136 Op::MutableConst {
137 value, operation, ..
138 } => operation(inputs, value),
139 Op::Value(_, _, value, operation) => operation(inputs, value),
140 }
141 }
142}
143
144impl<T> Factory<(), Op<T>> for Op<T>
145where
146 T: Clone,
147{
148 fn new_instance(&self, _: ()) -> Op<T> {
149 match self {
150 Op::Fn(name, arity, op) => Op::Fn(name, *arity, Arc::clone(op)),
151 Op::Var(name, index) => Op::Var(name, *index),
152 Op::Const(name, value) => Op::Const(name, value.clone()),
153 Op::MutableConst {
154 name,
155 arity,
156 value: _,
157 supplier,
158 modifier,
159 operation,
160 } => Op::MutableConst {
161 name,
162 arity: *arity,
163 value: (*supplier)(),
164 supplier: Arc::clone(supplier),
165 modifier: Arc::clone(modifier),
166 operation: Arc::clone(operation),
167 },
168 Op::Value(name, arity, value, operation) => {
169 Op::Value(name, *arity, value.clone(), Arc::clone(operation))
170 }
171 }
172 }
173}
174
175impl<T> Clone for Op<T>
176where
177 T: Clone,
178{
179 fn clone(&self) -> Self {
180 match self {
181 Op::Fn(name, arity, op) => Op::Fn(name, *arity, Arc::clone(op)),
182 Op::Var(name, index) => Op::Var(name, *index),
183 Op::Const(name, value) => Op::Const(name, value.clone()),
184 Op::MutableConst {
185 name,
186 arity,
187 value,
188 supplier,
189 modifier,
190 operation,
191 } => Op::MutableConst {
192 name,
193 arity: *arity,
194 value: value.clone(),
195 supplier: Arc::clone(supplier),
196 modifier: Arc::clone(modifier),
197 operation: Arc::clone(operation),
198 },
199 Op::Value(name, arity, value, operation) => {
200 Op::Value(name, *arity, value.clone(), Arc::clone(operation))
201 }
202 }
203 }
204}
205
206impl<T> PartialEq for Op<T>
207where
208 T: PartialEq,
209{
210 fn eq(&self, other: &Self) -> bool {
211 self.name() == other.name()
212 }
213}
214
215impl<T> Display for Op<T> {
216 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
217 write!(f, "{}", self.name())
218 }
219}
220
221impl<T> Default for Op<T>
222where
223 T: Default,
224{
225 fn default() -> Self {
226 Op::Fn("default", Arity::Zero, Arc::new(|_| T::default()))
227 }
228}
229
230impl<T> Debug for Op<T>
231where
232 T: Debug,
233{
234 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235 match self {
236 Op::Fn(name, _, _) => write!(f, "Fn: {}", name),
237 Op::Var(name, index) => write!(f, "Var: {}({})", name, index),
238 Op::Const(name, value) => write!(f, "C: {}({:?})", name, value),
239 Op::MutableConst { name, value, .. } => write!(f, "{}({:.2?})", name, value),
240 Op::Value(name, _, value, _) => write!(f, "{}({:.2?})", name, value),
241 }
242 }
243}
244
245impl Into<Op<f32>> for f32 {
246 fn into(self) -> Op<f32> {
247 Op::Value("Value(f32)", Arity::Any, self, Arc::new(|_, v| *v))
248 }
249}
250
251impl Into<Op<i32>> for i32 {
252 fn into(self) -> Op<i32> {
253 Op::Value("Value(i32)", Arity::Any, self, Arc::new(|_, v| *v))
254 }
255}
256
257impl Into<Op<bool>> for bool {
258 fn into(self) -> Op<bool> {
259 Op::Value("Value(bool)", Arity::Any, self, Arc::new(|_, v| *v))
260 }
261}
262
#[cfg(test)]
mod test {
    use super::*;
    use radiate::random_provider;

    // Checks name, arity and evaluation of the `add` op, and that an instance
    // produced by `new_instance` compares equal to it.
    #[test]
    fn test_ops() {
        let op = Op::add();
        assert_eq!(op.name(), "add");
        assert_eq!(op.arity(), Arity::Exact(2));
        assert_eq!(op.eval(&vec![1_f32, 2_f32]), 3_f32);
        assert_eq!(op.new_instance(()), op);
    }

    // NOTE(review): this test contains no assertions. It only prints the two
    // weight values, so it passes whether or not seeding is deterministic
    // (apart from panicking if `Op::weight()` is not a `MutableConst`).
    // Both ops are created after a single `set_seed(42)`, so their values are
    // presumably not expected to match each other. A reproducibility check
    // would reseed before each `Op::weight()` and compare the values. Whether
    // that is race-free under parallel test threads depends on whether
    // `random_provider` keeps a global or thread-local RNG — confirm before
    // adding an assertion.
    #[test]
    fn test_random_seed_works() {
        random_provider::set_seed(42);

        let op = Op::weight();
        let op2 = Op::weight();

        let o_one = match op {
            Op::MutableConst { value, .. } => value,
            _ => panic!("Expected MutableConst"),
        };

        let o_two = match op2 {
            Op::MutableConst { value, .. } => value,
            _ => panic!("Expected MutableConst"),
        };

        println!("o_one: {:?}", o_one);
        println!("o_two: {:?}", o_two);
    }

    // Checks that a clone of the `add` op compares equal to the original and
    // evaluates to the same result for the same inputs.
    #[test]
    fn test_op_clone() {
        let op = Op::add();
        let op2 = op.clone();

        let result = op.eval(&vec![1_f32, 2_f32]);
        let result2 = op2.eval(&vec![1_f32, 2_f32]);

        assert_eq!(op, op2);
        assert_eq!(result, result2);
    }
}
309}