radiate_gp/ops/operation.rs

1use crate::{Arity, Eval, Factory, NodeValue, TreeNode};
2use std::{
3    fmt::{Debug, Display},
4    sync::Arc,
5};
6
7/// A generic operation type that can represent several kinds of “ops”.
8pub enum Op<T> {
9    /// 1) A stateless function operation:
10    ///
11    /// # Arguments
12    ///    - A `&'static str` name (e.g., "Add", "Sigmoid")
13    ///    - Arity (how many inputs it takes)
14    ///    - Arc<dyn Fn(&[T]) -> T> for the actual function logic
15    Fn(&'static str, Arity, Arc<dyn Fn(&[T]) -> T>),
16    /// 2) A variable-like operation:
17    ///
18    /// # Arguments
19    ///    - `String` = a name or identifier
20    ///    - `usize` = perhaps an index to retrieve from some external context
21    Var(&'static str, usize),
22    /// 3) A compile-time constant: e.g., 1, 2, 3, etc.
23    ///
24    /// # Arguments
25    ///    - `&'static str` name
26    ///    - `T` the actual constant value
27    Const(&'static str, T),
28    /// 4) A `mutable const` is a constant that can change over time:
29    ///
30    ///  # Arguments
31    /// - `&'static str` name
32    /// - `Arity` of how many inputs it might read
33    /// - Current value of type `T`
34    /// - An `Arc<dyn Fn() -> T>` for retrieving (or resetting) the value
35    /// - An `Arc<dyn Fn(&[T], &T) -> T>` for updating or combining inputs & old value -> new value
36    ///
37    ///    This suggests a node that can mutate its internal state over time, or
38    ///    one that needs a special function to incorporate the inputs into the next state.
39    MutableConst {
40        name: &'static str,
41        arity: Arity,
42        value: T,
43        supplier: Arc<dyn Fn() -> T>,
44        modifier: Arc<dyn Fn(&T) -> T>,
45        operation: Arc<dyn Fn(&[T], &T) -> T>,
46    },
47    /// 5) A 'Value' operation that can be used as a 'stateful' constant:
48    ///
49    /// # Arguments
50    /// - `&'static str` name
51    /// - `Arity` of how many inputs it might read
52    /// - Current value of type `T`
53    /// - An `Arc<dyn Fn(&[T], &T) -> T>` for updating or combining inputs & old value -> new value
54    Value(&'static str, Arity, T, Arc<dyn Fn(&[T], &T) -> T>),
55}
56
57/// Base functionality for operations.
58impl<T> Op<T> {
59    pub fn name(&self) -> &str {
60        match self {
61            Op::Fn(name, _, _) => name,
62            Op::Var(name, _) => name,
63            Op::Const(name, _) => name,
64            Op::MutableConst { name, .. } => name,
65            Op::Value(name, _, _, _) => name,
66        }
67    }
68
69    pub fn arity(&self) -> Arity {
70        match self {
71            Op::Fn(_, arity, _) => *arity,
72            Op::Var(_, _) => Arity::Zero,
73            Op::Const(_, _) => Arity::Zero,
74            Op::MutableConst { arity, .. } => *arity,
75            Op::Value(_, arity, _, _) => *arity,
76        }
77    }
78
79    pub fn constant(value: T) -> Self
80    where
81        T: Display,
82    {
83        let name = Box::leak(Box::new(format!("{}", value)));
84        Op::Const(name, value)
85    }
86
87    pub fn named_constant(name: &'static str, value: T) -> Self {
88        Op::Const(name, value)
89    }
90
91    pub fn identity() -> Self
92    where
93        T: Clone,
94    {
95        Op::Fn(
96            "identity",
97            1.into(),
98            Arc::new(|inputs: &[T]| inputs[0].clone()),
99        )
100    }
101
102    pub fn var(index: usize) -> Self {
103        let name = Box::leak(Box::new(format!("{}", index)));
104        Op::Var(name, index)
105    }
106}
107
108unsafe impl Send for Op<f32> {}
109unsafe impl Sync for Op<f32> {}
110
111impl<T> Into<NodeValue<Op<T>>> for Op<T>
112where
113    T: Clone,
114{
115    fn into(self) -> NodeValue<Op<T>> {
116        let arity = self.arity();
117        NodeValue::Bounded(self, arity)
118    }
119}
120
121impl<T> Into<TreeNode<Op<T>>> for Op<T> {
122    fn into(self) -> TreeNode<Op<T>> {
123        TreeNode::new(self)
124    }
125}
126
127impl<T> Eval<[T], T> for Op<T>
128where
129    T: Clone,
130{
131    fn eval(&self, inputs: &[T]) -> T {
132        match self {
133            Op::Fn(_, _, op) => op(inputs),
134            Op::Var(_, index) => inputs[*index].clone(),
135            Op::Const(_, value) => value.clone(),
136            Op::MutableConst {
137                value, operation, ..
138            } => operation(inputs, value),
139            Op::Value(_, _, value, operation) => operation(inputs, value),
140        }
141    }
142}
143
144impl<T> Factory<(), Op<T>> for Op<T>
145where
146    T: Clone,
147{
148    fn new_instance(&self, _: ()) -> Op<T> {
149        match self {
150            Op::Fn(name, arity, op) => Op::Fn(name, *arity, Arc::clone(op)),
151            Op::Var(name, index) => Op::Var(name, *index),
152            Op::Const(name, value) => Op::Const(name, value.clone()),
153            Op::MutableConst {
154                name,
155                arity,
156                value: _,
157                supplier,
158                modifier,
159                operation,
160            } => Op::MutableConst {
161                name,
162                arity: *arity,
163                value: (*supplier)(),
164                supplier: Arc::clone(supplier),
165                modifier: Arc::clone(modifier),
166                operation: Arc::clone(operation),
167            },
168            Op::Value(name, arity, value, operation) => {
169                Op::Value(name, *arity, value.clone(), Arc::clone(operation))
170            }
171        }
172    }
173}
174
175impl<T> Clone for Op<T>
176where
177    T: Clone,
178{
179    fn clone(&self) -> Self {
180        match self {
181            Op::Fn(name, arity, op) => Op::Fn(name, *arity, Arc::clone(op)),
182            Op::Var(name, index) => Op::Var(name, *index),
183            Op::Const(name, value) => Op::Const(name, value.clone()),
184            Op::MutableConst {
185                name,
186                arity,
187                value,
188                supplier,
189                modifier,
190                operation,
191            } => Op::MutableConst {
192                name,
193                arity: *arity,
194                value: value.clone(),
195                supplier: Arc::clone(supplier),
196                modifier: Arc::clone(modifier),
197                operation: Arc::clone(operation),
198            },
199            Op::Value(name, arity, value, operation) => {
200                Op::Value(name, *arity, value.clone(), Arc::clone(operation))
201            }
202        }
203    }
204}
205
206impl<T> PartialEq for Op<T>
207where
208    T: PartialEq,
209{
210    fn eq(&self, other: &Self) -> bool {
211        self.name() == other.name()
212    }
213}
214
215impl<T> Display for Op<T> {
216    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
217        write!(f, "{}", self.name())
218    }
219}
220
221impl<T> Default for Op<T>
222where
223    T: Default,
224{
225    fn default() -> Self {
226        Op::Fn("default", Arity::Zero, Arc::new(|_| T::default()))
227    }
228}
229
230impl<T> Debug for Op<T>
231where
232    T: Debug,
233{
234    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
235        match self {
236            Op::Fn(name, _, _) => write!(f, "Fn: {}", name),
237            Op::Var(name, index) => write!(f, "Var: {}({})", name, index),
238            Op::Const(name, value) => write!(f, "C: {}({:?})", name, value),
239            Op::MutableConst { name, value, .. } => write!(f, "{}({:.2?})", name, value),
240            Op::Value(name, _, value, _) => write!(f, "{}({:.2?})", name, value),
241        }
242    }
243}
244
245impl Into<Op<f32>> for f32 {
246    fn into(self) -> Op<f32> {
247        Op::Value("Value(f32)", Arity::Any, self, Arc::new(|_, v| *v))
248    }
249}
250
251impl Into<Op<i32>> for i32 {
252    fn into(self) -> Op<i32> {
253        Op::Value("Value(i32)", Arity::Any, self, Arc::new(|_, v| *v))
254    }
255}
256
257impl Into<Op<bool>> for bool {
258    fn into(self) -> Op<bool> {
259        Op::Value("Value(bool)", Arity::Any, self, Arc::new(|_, v| *v))
260    }
261}
262
#[cfg(test)]
mod test {
    use super::*;
    use radiate::random_provider;

    #[test]
    fn test_ops() {
        let op = Op::add();
        assert_eq!(op.name(), "add");
        assert_eq!(op.arity(), Arity::Exact(2));
        assert_eq!(op.eval(&vec![1_f32, 2_f32]), 3_f32);
        assert_eq!(op.new_instance(()), op);
    }

    /// Re-seeding with the same seed must reproduce the same weight value.
    #[test]
    fn test_random_seed_works() {
        // NOTE(review): assumes `random_provider::set_seed` resets the RNG that
        // `Op::weight` draws from. If that RNG is shared across threads, other
        // tests drawing random numbers concurrently could make this flaky.
        random_provider::set_seed(42);
        let op = Op::weight();

        random_provider::set_seed(42);
        let op2 = Op::weight();

        let o_one = match op {
            Op::MutableConst { value, .. } => value,
            _ => panic!("Expected MutableConst"),
        };

        let o_two = match op2 {
            Op::MutableConst { value, .. } => value,
            _ => panic!("Expected MutableConst"),
        };

        assert_eq!(o_one, o_two);
    }

    #[test]
    fn test_op_clone() {
        let op = Op::add();
        let op2 = op.clone();

        let result = op.eval(&vec![1_f32, 2_f32]);
        let result2 = op2.eval(&vec![1_f32, 2_f32]);

        assert_eq!(op, op2);
        assert_eq!(result, result2);
    }
}