radiate_gp/ops/
operation.rs

1use crate::{Arity, Eval, Factory, NodeValue, TreeNode};
2#[cfg(feature = "pgm")]
3use std::sync::Arc;
4use std::{
5    fmt::{Debug, Display},
6    hash::Hash,
7};
8
/// [Op] is an enumeration that represents the different types of operations
/// that can be performed within the genetic programming framework. Each variant
/// of the enum encapsulates a different kind of operation, allowing for a flexible
/// and extensible way to define the behavior of nodes within trees and graphs.
///
/// The [Op] depends heavily on its [Arity] to define how many inputs it expects.
/// This is crucial for ensuring that the operations receive the correct number of inputs
/// and that the structures built using these operations are built in ways that respect
/// these input requirements. For example, an addition operation would typically have an arity of 2,
/// while a constant operation would have an arity of 0. This is the _base_ level of the GP system, meaning
/// that everything built on top of it (trees, graphs, etc.) relies *heavily* on how these
/// operations are defined and used.
pub enum Op<T> {
    /// 1) A stateless function operation:
    ///
    /// # Arguments
    ///    - A `&'static str` name (e.g., "Add", "Sigmoid")
    ///    - [Arity] (how many inputs it takes)
    ///    - A plain function pointer `fn(&[T]) -> T` holding the actual function logic
    Fn(&'static str, Arity, fn(&[T]) -> T),
    /// 2) A variable-like operation:
    ///
    /// # Arguments
    ///    - `&'static str` = a name or identifier
    ///    - `usize` = the index of the input this variable reads when evaluated
    Var(&'static str, usize),
    /// 3) A compile-time constant: e.g., 1, 2, 3, etc.
    ///
    /// # Arguments
    ///    - `&'static str` name
    ///    - `T` the actual constant value
    Const(&'static str, T),
    /// 4) A `mutable const` is a constant that can change over time:
    ///
    /// # Fields
    /// - `name`: `&'static str` name
    /// - `arity`: [Arity] of how many inputs it might read
    /// - `value`: current value of type `T`
    /// - `supplier`: `fn() -> T` producing a fresh value (used when a new instance is created)
    /// - `modifier`: `fn(&T) -> T` deriving a new value from the current one
    ///   (NOTE(review): not invoked in this file; presumably used by mutation code elsewhere)
    /// - `operation`: `fn(&[T], &T) -> T` combining the inputs with the current value to
    ///   produce the output when the op is evaluated
    ///
    ///    This suggests a node that can mutate its internal state over time, or
    ///    one that needs a special function to incorporate the inputs into the next state.
    MutableConst {
        name: &'static str,
        arity: Arity,
        value: T,
        supplier: fn() -> T,
        modifier: fn(&T) -> T,
        operation: fn(&[T], &T) -> T,
    },
    /// 5) A Probabilistic Graph Model (PGM) operation that can be used to create complex functions that can
    /// be used to discover _how_ the inputs relate to the output and can be used to sample new inputs
    /// based on the learned relationships.
    ///
    /// Only available when the `pgm` feature is enabled.
    ///
    /// # Arguments
    /// - `&'static str` name
    /// - [Arity] of how many inputs it might read
    /// - An `Arc<Vec<TreeNode<Op<T>>>>`: the shared set of programs that make up the model.
    /// - A function pointer `fn(&[T], &[TreeNode<Op<T>>]) -> T` holding the actual function logic
    ///   that uses the inputs and the model's programs to produce an output.
    #[cfg(feature = "pgm")]
    PGM(
        &'static str,
        Arity,
        Arc<Vec<TreeNode<Op<T>>>>,
        fn(&[T], &[TreeNode<Op<T>>]) -> T,
    ),
}
77
78impl<T> Op<T> {
79    pub fn name(&self) -> &str {
80        match self {
81            Op::Fn(name, _, _) => name,
82            Op::Var(name, _) => name,
83            Op::Const(name, _) => name,
84            Op::MutableConst { name, .. } => name,
85            #[cfg(feature = "pgm")]
86            Op::PGM(name, _, _, _) => name,
87        }
88    }
89
90    pub fn arity(&self) -> Arity {
91        match self {
92            Op::Fn(_, arity, _) => *arity,
93            Op::Var(_, _) => Arity::Zero,
94            Op::Const(_, _) => Arity::Zero,
95            Op::MutableConst { arity, .. } => *arity,
96            #[cfg(feature = "pgm")]
97            Op::PGM(_, arity, _, _) => *arity,
98        }
99    }
100
101    pub fn is_fn(&self) -> bool {
102        matches!(self, Op::Fn(_, _, _))
103    }
104
105    pub fn is_var(&self) -> bool {
106        matches!(self, Op::Var(_, _))
107    }
108
109    pub fn is_const(&self) -> bool {
110        matches!(self, Op::Const(_, _))
111    }
112
113    pub fn is_mutable_const(&self) -> bool {
114        matches!(self, Op::MutableConst { .. })
115    }
116
117    #[cfg(feature = "pgm")]
118    pub fn is_pgm(&self) -> bool {
119        matches!(self, Op::PGM(_, _, _, _))
120    }
121}
122
123unsafe impl<T> Send for Op<T> {}
124unsafe impl<T> Sync for Op<T> {}
125
126impl<T> Eval<[T], T> for Op<T>
127where
128    T: Clone,
129{
130    fn eval(&self, inputs: &[T]) -> T {
131        match self {
132            Op::Fn(_, _, op) => op(inputs),
133            Op::Var(_, index) => inputs[*index].clone(),
134            Op::Const(_, value) => value.clone(),
135            Op::MutableConst {
136                value, operation, ..
137            } => operation(inputs, value),
138            #[cfg(feature = "pgm")]
139            Op::PGM(_, _, model, operation) => operation(inputs, &model),
140        }
141    }
142}
143
144impl<T> Factory<(), Op<T>> for Op<T>
145where
146    T: Clone,
147{
148    fn new_instance(&self, _: ()) -> Op<T> {
149        match self {
150            Op::Fn(name, arity, op) => Op::Fn(name, *arity, *op),
151            Op::Var(name, index) => Op::Var(name, *index),
152            Op::Const(name, value) => Op::Const(name, value.clone()),
153            Op::MutableConst {
154                name,
155                arity,
156                value: _,
157                supplier,
158                modifier,
159                operation,
160            } => Op::MutableConst {
161                name,
162                arity: *arity,
163                value: (*supplier)(),
164                supplier: *supplier,
165                modifier: *modifier,
166                operation: *operation,
167            },
168            #[cfg(feature = "pgm")]
169            Op::PGM(name, arity, model, operation) => {
170                use std::sync::Arc;
171                Op::PGM(name, *arity, Arc::clone(model), *operation)
172            }
173        }
174    }
175}
176
177impl<T> Clone for Op<T>
178where
179    T: Clone,
180{
181    fn clone(&self) -> Self {
182        match self {
183            Op::Fn(name, arity, op) => Op::Fn(name, *arity, *op),
184            Op::Var(name, index) => Op::Var(name, *index),
185            Op::Const(name, value) => Op::Const(name, value.clone()),
186            Op::MutableConst {
187                name,
188                arity,
189                value,
190                supplier,
191                modifier,
192                operation,
193            } => Op::MutableConst {
194                name,
195                arity: *arity,
196                value: value.clone(),
197                supplier: *supplier,
198                modifier: *modifier,
199                operation: *operation,
200            },
201            #[cfg(feature = "pgm")]
202            Op::PGM(name, arity, model, operation) => {
203                use std::sync::Arc;
204                Op::PGM(name, *arity, Arc::clone(model), *operation)
205            }
206        }
207    }
208}
209
210impl<T> PartialEq for Op<T>
211where
212    T: PartialEq,
213{
214    fn eq(&self, other: &Self) -> bool {
215        self.name() == other.name()
216    }
217}
218
219impl<T> Hash for Op<T> {
220    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
221        self.name().hash(state);
222    }
223}
224
225impl<T> Display for Op<T> {
226    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
227        write!(f, "{}", self.name())
228    }
229}
230
231impl<T> Default for Op<T>
232where
233    T: Default,
234{
235    fn default() -> Self {
236        Op::Fn("default", Arity::Zero, |_: &[T]| T::default())
237    }
238}
239
240impl<T> Debug for Op<T>
241where
242    T: Debug,
243{
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        match self {
246            Op::Fn(name, _, _) => write!(f, "Fn: {}", name),
247            Op::Var(name, index) => write!(f, "Var: {}({})", name, index),
248            Op::Const(name, value) => write!(f, "C: {}({:?})", name, value),
249            Op::MutableConst { name, value, .. } => write!(f, "{}({:.2?})", name, value),
250            #[cfg(feature = "pgm")]
251            Op::PGM(name, _, model, _) => {
252                let mut model_str = String::new();
253                for (i, node) in model.iter().enumerate() {
254                    use crate::Format;
255
256                    let node_str = &node.format();
257                    model_str.push_str(&format!("[{}: S {} Prog {}], ", i, node.size(), node_str));
258                }
259                write!(f, "PGM: {}({})", name, model_str)
260            }
261        }
262    }
263}
264
265impl<T> From<Op<T>> for NodeValue<Op<T>> {
266    fn from(value: Op<T>) -> Self {
267        let arity = value.arity();
268        NodeValue::Bounded(value, arity)
269    }
270}
271
272impl<T> From<Op<T>> for TreeNode<Op<T>> {
273    fn from(value: Op<T>) -> Self {
274        let arity = value.arity();
275        TreeNode::with_arity(value, arity)
276    }
277}
278
279impl<T> From<Op<T>> for Vec<TreeNode<Op<T>>> {
280    fn from(value: Op<T>) -> Self {
281        vec![TreeNode::from(value)]
282    }
283}
284
#[cfg(test)]
mod test {
    use super::*;
    use radiate_core::random_provider;

    /// Pulls the current value out of a `MutableConst`, panicking for any other variant.
    fn mutable_const_value<T>(op: Op<T>) -> T {
        match op {
            Op::MutableConst { value, .. } => value,
            _ => panic!("Expected MutableConst"),
        }
    }

    /// Name, arity, evaluation and `new_instance` of the `add` op.
    #[test]
    fn test_ops() {
        let add = Op::add();

        assert_eq!(add.name(), "add");
        assert_eq!(add.arity(), Arity::Exact(2));
        assert_eq!(add.eval(&[1_f32, 2_f32]), 3_f32);
        assert_eq!(add.new_instance(()), add);
    }

    /// Creates two weights after seeding and prints their values.
    /// No assertion is made on the values themselves.
    #[test]
    fn test_random_seed_works() {
        random_provider::set_seed(42);

        let first = Op::weight();
        let second = Op::weight();

        let o_one = mutable_const_value(first);
        let o_two = mutable_const_value(second);

        println!("o_one: {:?}", o_one);
        println!("o_two: {:?}", o_two);
    }

    /// A cloned op compares equal to the original and evaluates to the same result.
    #[test]
    fn test_op_clone() {
        let original = Op::add();
        let copy = original.clone();

        let result = original.eval(&[1_f32, 2_f32]);
        let result2 = copy.eval(&[1_f32, 2_f32]);

        assert_eq!(original, copy);
        assert_eq!(result, result2);
    }

    /// A PGM op whose single program is `1 + 2` evaluates to 3 on empty inputs.
    #[test]
    #[cfg(feature = "pgm")]
    fn test_pgm_op() {
        use std::sync::Arc;

        let one = TreeNode::new(Op::constant(1_f32));
        let two = TreeNode::new(Op::constant(2_f32));
        let model = TreeNode::with_children(Op::add(), vec![one, two]);

        let pgm_op = Op::PGM(
            "pgm",
            Arity::Any,
            Arc::new(vec![model]),
            |inputs: &[f32], prog: &[TreeNode<Op<f32>>]| {
                // Sum of every program's output plus the sum of the raw inputs.
                let sum: f32 = prog.iter().map(|node| node.eval(inputs)).sum();
                sum + inputs.iter().sum::<f32>()
            },
        );

        let result = pgm_op.eval(&[]);
        assert_eq!(result, 3_f32);
    }
}