radiate_gp/collections/graphs/builder.rs

1use super::aggregate::GraphAggregate;
2use crate::{
3    Arity, Factory, NodeStore,
4    collections::{Graph, GraphNode, NodeType},
5};
6
7impl<T: Clone + Default> Graph<T> {
8    /// Creates a directed graph with the given input and output sizes.
9    /// The values are used to initialize the nodes in the graph with the given values.
10    ///
11    /// # Example
12    /// ```
13    /// use radiate::*;
14    /// use radiate_gp::*;
15    ///
16    /// let values = vec![
17    ///     (NodeType::Input, vec![Op::var(0), Op::var(1), Op::var(2)]),
18    ///     (NodeType::Output, vec![Op::sigmoid()]),
19    /// ];
20    ///
21    /// let graph = Graph::directed(3, 3, values);
22    ///
23    /// assert_eq!(graph.len(), 6);
24    /// ```
25    ///
26    /// The graph will have 6 nodes, 3 input nodes and 3 output nodes where each input node is
27    /// connected to each output node. Such as:
28    /// ``` text
29    /// [0, 1, 2] -> [3, 4, 5]
30    /// ```
31    ///
32    /// # Arguments
33    /// * `input_size` - The number of input nodes.
34    /// * `output_size` - The number of output nodes.
35    /// * `values` - The values to initialize the nodes with.
36    ///
37    /// # Returns
38    /// A new directed graph.
39    pub fn directed(
40        input_size: usize,
41        output_size: usize,
42        values: impl Into<NodeStore<T>>,
43    ) -> Graph<T> {
44        let builder = NodeBuilder::new(values);
45
46        let input_nodes = builder.input(input_size);
47        let output_nodes = builder.output(output_size);
48
49        GraphAggregate::new()
50            .all_to_all(&input_nodes, &output_nodes)
51            .build()
52    }
53
54    /// Creates a recurrent graph with the given input and output sizes.
55    /// The values are used to initialize the nodes in the graph with the given values.
56    /// The graph will have a recurrent connection from each hidden vertex to itself.
57    /// The graph will have a one-to-one connection from each input node to each hidden vertex.
58    /// The graph will have an all-to-all connection from each hidden vertex to each output node.
59    ///
60    /// # Example
61    /// ```
62    /// use radiate::*;
63    /// use radiate_gp::*;
64    ///
65    /// let values = vec![
66    ///   (NodeType::Input, vec![Op::var(0), Op::var(1), Op::var(2)]),
67    ///   (NodeType::Vertex, vec![Op::linear()]),
68    ///   (NodeType::Output, vec![Op::sigmoid()]),
69    /// ];
70    ///
71    /// let graph = Graph::recurrent(3, 3, values);
72    ///
73    /// assert_eq!(graph.len(), 12);
74    /// ```
75    ///
76    /// The graph will have 12 nodes, 3 input nodes, 3 hidden nodes with recurrent connections to themselves,
77    /// and 3 output nodes. Such as:
78    /// ``` text
79    /// [0, 1, 2] -> [3, 4, 5]
80    ///     [3, 4, 5] -> [6, 7, 8]
81    ///         [6, 7, 8] -> [3, 4, 5]
82    /// [3, 4, 5] -> [9, 10, 11]
83    /// ```
84    ///
85    /// # Arguments
86    /// * `input_size` - The number of input nodes.
87    /// * `output_size` - The number of output nodes.
88    /// * `values` - The values to initialize the nodes with.
89    ///
90    /// # Returns
91    /// A new recurrent graph.
92    pub fn recurrent(
93        input_size: usize,
94        output_size: usize,
95        values: impl Into<NodeStore<T>>,
96    ) -> Graph<T> {
97        let builder = NodeBuilder::new(values);
98
99        let input = builder.input(input_size);
100        let aggregate = builder.vertecies(input_size);
101        let link = builder.vertecies(input_size);
102        let output = builder.output(output_size);
103
104        GraphAggregate::new()
105            .one_to_one(&input, &aggregate)
106            .one_to_self(&aggregate, &link)
107            .all_to_all(&aggregate, &output)
108            .build()
109    }
110
111    /// Creates a weighted directed graph with the given input and output sizes.
112    ///
113    /// This will result in the same graph as `Graph::directed` but with an additional edge
114    /// connecting each input node to each output node.
115    ///
116    /// # Arguments
117    /// * `input_size` - The number of input nodes.
118    /// * `output_size` - The number of output nodes.
119    ///
120    /// # Returns
121    /// A new weighted directed graph.
122    pub fn weighted_directed(
123        input_size: usize,
124        output_size: usize,
125        values: impl Into<NodeStore<T>>,
126    ) -> Graph<T> {
127        let builder = NodeBuilder::new(values);
128
129        let input = builder.input(input_size);
130        let output = builder.output(output_size);
131        let weights = builder.edge(input_size * output_size);
132
133        GraphAggregate::new()
134            .one_to_many(&input, &weights)
135            .many_to_one(&weights, &output)
136            .build()
137    }
138
139    /// Creates a weighted recurrent graph with the given input and output sizes.
140    /// This will result in the same graph as `Graph::recurrent` but with an additional edge
141    /// connecting each hidden vertex to each output node.
142    ///
143    /// # Arguments
144    /// * `input_size` - The number of input nodes.
145    /// * `output_size` - The number of output nodes.
146    ///
147    /// # Returns
148    /// A new weighted recurrent graph.
149    pub fn weighted_recurrent(
150        input_size: usize,
151        output_size: usize,
152        values: impl Into<NodeStore<T>>,
153    ) -> Graph<T> {
154        let builder = NodeBuilder::new(values);
155
156        let input = builder.input(input_size);
157        let aggregate = builder.vertecies(input_size);
158        let link = builder.vertecies(input_size);
159        let output = builder.output(output_size);
160        let weights = builder.edge(input_size * input_size);
161
162        GraphAggregate::new()
163            .one_to_one(&input, &aggregate)
164            .one_to_self(&aggregate, &link)
165            .one_to_many(&link, &weights)
166            .many_to_one(&weights, &output)
167            .build()
168    }
169}
170
171pub struct NodeBuilder<T> {
172    store: NodeStore<T>,
173}
174
175impl<T: Clone + Default> NodeBuilder<T> {
176    pub fn new(store: impl Into<NodeStore<T>>) -> Self {
177        NodeBuilder {
178            store: store.into(),
179        }
180    }
181
182    pub fn input(&self, size: usize) -> Vec<GraphNode<T>> {
183        self.new_nodes(NodeType::Input, size, Arity::Zero)
184    }
185
186    pub fn output(&self, size: usize) -> Vec<GraphNode<T>> {
187        self.new_nodes(NodeType::Output, size, Arity::Any)
188    }
189
190    pub fn edge(&self, size: usize) -> Vec<GraphNode<T>> {
191        self.new_nodes(NodeType::Edge, size, Arity::Exact(1))
192    }
193
194    pub fn vertecies(&self, size: usize) -> Vec<GraphNode<T>> {
195        (0..size)
196            .map(|idx| {
197                self.store
198                    .new_instance((idx, NodeType::Vertex, |arity| match arity {
199                        Arity::Any => true,
200                        _ => false,
201                    }))
202            })
203            .collect()
204    }
205
206    fn new_nodes(
207        &self,
208        node_type: NodeType,
209        size: usize,
210        fallback_arity: Arity,
211    ) -> Vec<GraphNode<T>> {
212        if self.store.contains_type(node_type) {
213            (0..size)
214                .map(|idx| self.store.new_instance((idx, node_type)))
215                .collect()
216        } else {
217            (0..size)
218                .map(|idx| {
219                    self.store
220                        .new_instance((idx, node_type, |arity| arity == fallback_arity))
221                })
222                .collect()
223        }
224    }
225}
226
#[cfg(test)]
mod tests {
    use crate::{Node, Op};

    use super::*;

    /// `Graph::directed` connects every input to every output, and output nodes carry the
    /// provided value.
    #[test]
    fn test_graph_builder() {
        let graph = Graph::directed(3, 3, Op::sigmoid());

        assert_eq!(graph.len(), 6);

        for node in graph.iter() {
            if node.node_type() == NodeType::Input {
                assert_eq!(node.arity(), Arity::Zero);
                assert_eq!(node.incoming().iter().count(), 0);
                assert_eq!(node.outgoing().iter().count(), 3);
            } else if node.node_type() == NodeType::Output {
                assert_eq!(node.arity(), Arity::Any);
                assert_eq!(node.incoming().iter().count(), 3);
                assert_eq!(node.outgoing().iter().count(), 0);
                assert_eq!(node.value(), &Op::sigmoid());
            }
        }
    }

    /// `Graph::recurrent` builds input, hidden, link and output layers, and hidden vertices
    /// are marked recurrent.
    #[test]
    fn test_graph_builder_recurrent() {
        let graph = Graph::recurrent(3, 3, Op::sigmoid());

        assert_eq!(graph.len(), 12);

        for node in graph.iter() {
            if node.node_type() == NodeType::Input {
                assert_eq!(node.arity(), Arity::Zero);
                assert_eq!(node.incoming().iter().count(), 0);
                assert_eq!(node.outgoing().iter().count(), 1);
            } else if node.node_type() == NodeType::Vertex {
                assert_eq!(node.arity(), Arity::Any);
                assert!(node.is_recurrent());
                assert_eq!(node.value(), &Op::sigmoid());
            } else if node.node_type() == NodeType::Output {
                assert_eq!(node.arity(), Arity::Any);
                assert_eq!(node.incoming().iter().count(), 3);
                assert_eq!(node.outgoing().iter().count(), 0);
                assert_eq!(node.value(), &Op::sigmoid());
            }
        }
    }
}
277
278// pub fn gru(
279//     mut self,
280//     input_size: usize,
281//     output_size: usize,
282//     memory_size: usize,
283//     output: Op<f32>,
284// ) -> GraphBuilder<f32> {
285//     self.with_values(NodeType::Input, (0..input_size).map(Op::var).collect());
286//     self.with_values(NodeType::Output, vec![output]);
287
288//     let input = self.input(input_size);
289//     let output = self.output(output_size);
290
291//     let output_weights = self.edge(memory_size * output_size);
292
293//     let reset_gate = self.aggregates(memory_size);
294//     let update_gate = self.aggregates(memory_size);
295//     let candidate_gate = self.aggregates(memory_size);
296
297//     let input_to_reset_weights = self.edge(input_size * memory_size);
298//     let input_to_update_weights = self.edge(input_size * memory_size);
299//     let input_to_candidate_weights = self.edge(input_size * memory_size);
300
301//     let hidden_to_reset_weights = self.edge(memory_size * memory_size);
302//     let hidden_to_update_weights = self.edge(memory_size * memory_size);
303//     let hidden_to_candidate_weights = self.edge(memory_size * memory_size);
304
305//     let hidden_reset_gate = self.aggregates(memory_size);
306//     let update_candidate_mul_gate = self.aggregates(memory_size);
307//     let invert_update_gate = self.aggregates(memory_size);
308//     let hidden_invert_mul_gate = self.aggregates(memory_size);
309//     let candidate_hidden_add_gate = self.aggregates(memory_size);
310
311//     let graph = GraphAggregate::new()
312//         .one_to_many(&input, &input_to_reset_weights)
313//         .one_to_many(&input, &input_to_update_weights)
314//         .one_to_many(&input, &input_to_candidate_weights)
315//         .one_to_many(&candidate_hidden_add_gate, &hidden_to_reset_weights)
316//         .one_to_many(&candidate_hidden_add_gate, &hidden_to_update_weights)
317//         .many_to_one(&input_to_reset_weights, &reset_gate)
318//         .many_to_one(&hidden_to_reset_weights, &reset_gate)
319//         .many_to_one(&input_to_update_weights, &update_gate)
320//         .many_to_one(&hidden_to_update_weights, &update_gate)
321//         .one_to_one(&reset_gate, &hidden_reset_gate)
322//         .one_to_one(&candidate_hidden_add_gate, &hidden_reset_gate)
323//         .one_to_many(&hidden_reset_gate, &hidden_to_candidate_weights)
324//         .many_to_one(&input_to_candidate_weights, &candidate_gate)
325//         .many_to_one(&hidden_to_candidate_weights, &candidate_gate)
326//         .one_to_one(&update_gate, &update_candidate_mul_gate)
327//         .one_to_one(&candidate_gate, &update_candidate_mul_gate)
328//         .one_to_one(&update_gate, &invert_update_gate)
329//         .one_to_one(&candidate_hidden_add_gate, &hidden_invert_mul_gate)
330//         .one_to_one(&invert_update_gate, &hidden_invert_mul_gate)
331//         .one_to_one(&hidden_invert_mul_gate, &candidate_hidden_add_gate)
332//         .one_to_one(&update_candidate_mul_gate, &candidate_hidden_add_gate)
333//         .one_to_many(&candidate_hidden_add_gate, &output_weights)
334//         .many_to_one(&output_weights, &output)
335//         .build();
336
337//     self.node_cache = Some(graph.into_iter().collect());
338//     self
339// }
340
341// pub fn lstm(mut self, input_size: usize, output_size: usize, output: Op<f32>) -> GraphBuilder<f32> {
342//     self.with_values(NodeType::Input, (0..input_size).map(Op::var).collect());
343//     self.with_values(NodeType::Output, vec![output]);
344
345//     let input = self.input(input_size);
346//     let output = self.output(output_size);
347
348//     let cell_state = self.aggregates(1);
349//     let hidden_state = self.aggregates(1);
350
351//     let forget_gate = self.aggregates(1);
352//     let input_gate = self.aggregates(1);
353//     let output_gate = self.aggregates(1);
354//     let candidate = self.aggregates(1);
355
356//     let input_to_forget_weights = self.edge(input_size);
357//     let input_to_input_weights = self.edge(input_size);
358//     let input_to_output_weights = self.edge(input_size);
359//     let input_to_candidate_weights = self.edge(input_size);
360
361//     let hidden_to_forget_weights = self.edge(1);
362//     let hidden_to_input_weights = self.edge(1);
363//     let hidden_to_output_weights = self.edge(1);
364//     let hidden_to_candidate_weights = self.edge(1);
365
366//     let final_weights = self.edge(output_size);
367
368//     let graph = GraphAggregate::new()
369//         .one_to_one(&input, &input_to_forget_weights)
370//         .one_to_one(&input, &input_to_input_weights)
371//         .one_to_one(&input, &input_to_output_weights)
372//         .one_to_one(&input, &input_to_candidate_weights)
373//         .one_to_one(&hidden_state, &hidden_to_forget_weights)
374//         .one_to_one(&hidden_state, &hidden_to_input_weights)
375//         .one_to_one(&hidden_state, &hidden_to_output_weights)
376//         .one_to_one(&hidden_state, &hidden_to_candidate_weights)
377//         .many_to_one(&input_to_forget_weights, &forget_gate)
378//         .many_to_one(&hidden_to_forget_weights, &forget_gate)
379//         .many_to_one(&input_to_input_weights, &input_gate)
380//         .many_to_one(&hidden_to_input_weights, &input_gate)
381//         .many_to_one(&input_to_output_weights, &output_gate)
382//         .many_to_one(&hidden_to_output_weights, &output_gate)
383//         .many_to_one(&input_to_candidate_weights, &candidate)
384//         .many_to_one(&hidden_to_candidate_weights, &candidate)
385//         .one_to_one(&forget_gate, &cell_state)
386//         .one_to_one(&input_gate, &candidate)
387//         .one_to_one(&candidate, &cell_state)
388//         .one_to_one(&cell_state, &hidden_state)
389//         .one_to_one(&output_gate, &hidden_state)
390//         .one_to_many(&hidden_state, &final_weights)
391//         .one_to_one(&final_weights, &output)
392//         .build();
393
394//     self.node_cache = Some(graph.into_iter().collect());
395//     self
396// }