radiate_gp/collections/graphs/builder.rs
1use super::aggregate::GraphAggregate;
2use crate::{
3 Arity, Factory, NodeStore,
4 collections::{Graph, GraphNode, NodeType},
5};
6
7impl<T: Clone + Default> Graph<T> {
8 /// Creates a directed graph with the given input and output sizes.
9 /// The values are used to initialize the nodes in the graph with the given values.
10 ///
11 /// # Example
12 /// ```
13 /// use radiate::*;
14 /// use radiate_gp::*;
15 ///
16 /// let values = vec![
17 /// (NodeType::Input, vec![Op::var(0), Op::var(1), Op::var(2)]),
18 /// (NodeType::Output, vec![Op::sigmoid()]),
19 /// ];
20 ///
21 /// let graph = Graph::directed(3, 3, values);
22 ///
23 /// assert_eq!(graph.len(), 6);
24 /// ```
25 ///
26 /// The graph will have 6 nodes, 3 input nodes and 3 output nodes where each input node is
27 /// connected to each output node. Such as:
28 /// ``` text
29 /// [0, 1, 2] -> [3, 4, 5]
30 /// ```
31 ///
32 /// # Arguments
33 /// * `input_size` - The number of input nodes.
34 /// * `output_size` - The number of output nodes.
35 /// * `values` - The values to initialize the nodes with.
36 ///
37 /// # Returns
38 /// A new directed graph.
39 pub fn directed(
40 input_size: usize,
41 output_size: usize,
42 values: impl Into<NodeStore<T>>,
43 ) -> Graph<T> {
44 let builder = NodeBuilder::new(values);
45
46 let input_nodes = builder.input(input_size);
47 let output_nodes = builder.output(output_size);
48
49 GraphAggregate::new()
50 .all_to_all(&input_nodes, &output_nodes)
51 .build()
52 }
53
54 /// Creates a recurrent graph with the given input and output sizes.
55 /// The values are used to initialize the nodes in the graph with the given values.
56 /// The graph will have a recurrent connection from each hidden vertex to itself.
57 /// The graph will have a one-to-one connection from each input node to each hidden vertex.
58 /// The graph will have an all-to-all connection from each hidden vertex to each output node.
59 ///
60 /// # Example
61 /// ```
62 /// use radiate::*;
63 /// use radiate_gp::*;
64 ///
65 /// let values = vec![
66 /// (NodeType::Input, vec![Op::var(0), Op::var(1), Op::var(2)]),
67 /// (NodeType::Vertex, vec![Op::linear()]),
68 /// (NodeType::Output, vec![Op::sigmoid()]),
69 /// ];
70 ///
71 /// let graph = Graph::recurrent(3, 3, values);
72 ///
73 /// assert_eq!(graph.len(), 12);
74 /// ```
75 ///
76 /// The graph will have 12 nodes, 3 input nodes, 3 hidden nodes with recurrent connections to themselves,
77 /// and 3 output nodes. Such as:
78 /// ``` text
79 /// [0, 1, 2] -> [3, 4, 5]
80 /// [3, 4, 5] -> [6, 7, 8]
81 /// [6, 7, 8] -> [3, 4, 5]
82 /// [3, 4, 5] -> [9, 10, 11]
83 /// ```
84 ///
85 /// # Arguments
86 /// * `input_size` - The number of input nodes.
87 /// * `output_size` - The number of output nodes.
88 /// * `values` - The values to initialize the nodes with.
89 ///
90 /// # Returns
91 /// A new recurrent graph.
92 pub fn recurrent(
93 input_size: usize,
94 output_size: usize,
95 values: impl Into<NodeStore<T>>,
96 ) -> Graph<T> {
97 let builder = NodeBuilder::new(values);
98
99 let input = builder.input(input_size);
100 let aggregate = builder.vertecies(input_size);
101 let link = builder.vertecies(input_size);
102 let output = builder.output(output_size);
103
104 GraphAggregate::new()
105 .one_to_one(&input, &aggregate)
106 .one_to_self(&aggregate, &link)
107 .all_to_all(&aggregate, &output)
108 .build()
109 }
110
111 /// Creates a weighted directed graph with the given input and output sizes.
112 ///
113 /// This will result in the same graph as `Graph::directed` but with an additional edge
114 /// connecting each input node to each output node.
115 ///
116 /// # Arguments
117 /// * `input_size` - The number of input nodes.
118 /// * `output_size` - The number of output nodes.
119 ///
120 /// # Returns
121 /// A new weighted directed graph.
122 pub fn weighted_directed(
123 input_size: usize,
124 output_size: usize,
125 values: impl Into<NodeStore<T>>,
126 ) -> Graph<T> {
127 let builder = NodeBuilder::new(values);
128
129 let input = builder.input(input_size);
130 let output = builder.output(output_size);
131 let weights = builder.edge(input_size * output_size);
132
133 GraphAggregate::new()
134 .one_to_many(&input, &weights)
135 .many_to_one(&weights, &output)
136 .build()
137 }
138
139 /// Creates a weighted recurrent graph with the given input and output sizes.
140 /// This will result in the same graph as `Graph::recurrent` but with an additional edge
141 /// connecting each hidden vertex to each output node.
142 ///
143 /// # Arguments
144 /// * `input_size` - The number of input nodes.
145 /// * `output_size` - The number of output nodes.
146 ///
147 /// # Returns
148 /// A new weighted recurrent graph.
149 pub fn weighted_recurrent(
150 input_size: usize,
151 output_size: usize,
152 values: impl Into<NodeStore<T>>,
153 ) -> Graph<T> {
154 let builder = NodeBuilder::new(values);
155
156 let input = builder.input(input_size);
157 let aggregate = builder.vertecies(input_size);
158 let link = builder.vertecies(input_size);
159 let output = builder.output(output_size);
160 let weights = builder.edge(input_size * input_size);
161
162 GraphAggregate::new()
163 .one_to_one(&input, &aggregate)
164 .one_to_self(&aggregate, &link)
165 .one_to_many(&link, &weights)
166 .many_to_one(&weights, &output)
167 .build()
168 }
169}
170
171pub struct NodeBuilder<T> {
172 store: NodeStore<T>,
173}
174
175impl<T: Clone + Default> NodeBuilder<T> {
176 pub fn new(store: impl Into<NodeStore<T>>) -> Self {
177 NodeBuilder {
178 store: store.into(),
179 }
180 }
181
182 pub fn input(&self, size: usize) -> Vec<GraphNode<T>> {
183 self.new_nodes(NodeType::Input, size, Arity::Zero)
184 }
185
186 pub fn output(&self, size: usize) -> Vec<GraphNode<T>> {
187 self.new_nodes(NodeType::Output, size, Arity::Any)
188 }
189
190 pub fn edge(&self, size: usize) -> Vec<GraphNode<T>> {
191 self.new_nodes(NodeType::Edge, size, Arity::Exact(1))
192 }
193
194 pub fn vertecies(&self, size: usize) -> Vec<GraphNode<T>> {
195 (0..size)
196 .map(|idx| {
197 self.store
198 .new_instance((idx, NodeType::Vertex, |arity| match arity {
199 Arity::Any => true,
200 _ => false,
201 }))
202 })
203 .collect()
204 }
205
206 fn new_nodes(
207 &self,
208 node_type: NodeType,
209 size: usize,
210 fallback_arity: Arity,
211 ) -> Vec<GraphNode<T>> {
212 if self.store.contains_type(node_type) {
213 (0..size)
214 .map(|idx| self.store.new_instance((idx, node_type)))
215 .collect()
216 } else {
217 (0..size)
218 .map(|idx| {
219 self.store
220 .new_instance((idx, node_type, |arity| arity == fallback_arity))
221 })
222 .collect()
223 }
224 }
225}
226
#[cfg(test)]
mod tests {
    use crate::{Node, Op};

    use super::*;

    /// `Graph::directed(3, 3, ..)` yields 3 inputs and 3 outputs, each input feeding every output.
    #[test]
    fn test_graph_builder() {
        let graph = Graph::directed(3, 3, Op::sigmoid());

        assert_eq!(graph.len(), 6);

        for node in graph.iter() {
            match node.node_type() {
                NodeType::Input => {
                    assert_eq!(node.arity(), Arity::Zero);
                    assert_eq!(node.incoming().iter().count(), 0);
                    assert_eq!(node.outgoing().iter().count(), 3);
                }
                NodeType::Output => {
                    assert_eq!(node.arity(), Arity::Any);
                    assert_eq!(node.incoming().iter().count(), 3);
                    assert_eq!(node.outgoing().iter().count(), 0);
                    assert_eq!(node.value(), &Op::sigmoid());
                }
                _ => {}
            }
        }
    }

    /// `Graph::recurrent(3, 3, ..)` yields 12 nodes with recurrent hidden vertices.
    #[test]
    fn test_graph_builder_recurrent() {
        let graph = Graph::recurrent(3, 3, Op::sigmoid());

        assert_eq!(graph.len(), 12);

        for node in graph.iter() {
            match node.node_type() {
                NodeType::Input => {
                    assert_eq!(node.arity(), Arity::Zero);
                    assert_eq!(node.incoming().iter().count(), 0);
                    assert_eq!(node.outgoing().iter().count(), 1);
                }
                NodeType::Vertex => {
                    assert_eq!(node.arity(), Arity::Any);
                    assert!(node.is_recurrent());
                    assert_eq!(node.value(), &Op::sigmoid());
                }
                NodeType::Output => {
                    assert_eq!(node.arity(), Arity::Any);
                    assert_eq!(node.incoming().iter().count(), 3);
                    assert_eq!(node.outgoing().iter().count(), 0);
                    assert_eq!(node.value(), &Op::sigmoid());
                }
                _ => {}
            }
        }
    }
}
277
278// pub fn gru(
279// mut self,
280// input_size: usize,
281// output_size: usize,
282// memory_size: usize,
283// output: Op<f32>,
284// ) -> GraphBuilder<f32> {
285// self.with_values(NodeType::Input, (0..input_size).map(Op::var).collect());
286// self.with_values(NodeType::Output, vec![output]);
287
288// let input = self.input(input_size);
289// let output = self.output(output_size);
290
291// let output_weights = self.edge(memory_size * output_size);
292
293// let reset_gate = self.aggregates(memory_size);
294// let update_gate = self.aggregates(memory_size);
295// let candidate_gate = self.aggregates(memory_size);
296
297// let input_to_reset_weights = self.edge(input_size * memory_size);
298// let input_to_update_weights = self.edge(input_size * memory_size);
299// let input_to_candidate_weights = self.edge(input_size * memory_size);
300
301// let hidden_to_reset_weights = self.edge(memory_size * memory_size);
302// let hidden_to_update_weights = self.edge(memory_size * memory_size);
303// let hidden_to_candidate_weights = self.edge(memory_size * memory_size);
304
305// let hidden_reset_gate = self.aggregates(memory_size);
306// let update_candidate_mul_gate = self.aggregates(memory_size);
307// let invert_update_gate = self.aggregates(memory_size);
308// let hidden_invert_mul_gate = self.aggregates(memory_size);
309// let candidate_hidden_add_gate = self.aggregates(memory_size);
310
311// let graph = GraphAggregate::new()
312// .one_to_many(&input, &input_to_reset_weights)
313// .one_to_many(&input, &input_to_update_weights)
314// .one_to_many(&input, &input_to_candidate_weights)
315// .one_to_many(&candidate_hidden_add_gate, &hidden_to_reset_weights)
316// .one_to_many(&candidate_hidden_add_gate, &hidden_to_update_weights)
317// .many_to_one(&input_to_reset_weights, &reset_gate)
318// .many_to_one(&hidden_to_reset_weights, &reset_gate)
319// .many_to_one(&input_to_update_weights, &update_gate)
320// .many_to_one(&hidden_to_update_weights, &update_gate)
321// .one_to_one(&reset_gate, &hidden_reset_gate)
322// .one_to_one(&candidate_hidden_add_gate, &hidden_reset_gate)
323// .one_to_many(&hidden_reset_gate, &hidden_to_candidate_weights)
324// .many_to_one(&input_to_candidate_weights, &candidate_gate)
325// .many_to_one(&hidden_to_candidate_weights, &candidate_gate)
326// .one_to_one(&update_gate, &update_candidate_mul_gate)
327// .one_to_one(&candidate_gate, &update_candidate_mul_gate)
328// .one_to_one(&update_gate, &invert_update_gate)
329// .one_to_one(&candidate_hidden_add_gate, &hidden_invert_mul_gate)
330// .one_to_one(&invert_update_gate, &hidden_invert_mul_gate)
331// .one_to_one(&hidden_invert_mul_gate, &candidate_hidden_add_gate)
332// .one_to_one(&update_candidate_mul_gate, &candidate_hidden_add_gate)
333// .one_to_many(&candidate_hidden_add_gate, &output_weights)
334// .many_to_one(&output_weights, &output)
335// .build();
336
337// self.node_cache = Some(graph.into_iter().collect());
338// self
339// }
340
341// pub fn lstm(mut self, input_size: usize, output_size: usize, output: Op<f32>) -> GraphBuilder<f32> {
342// self.with_values(NodeType::Input, (0..input_size).map(Op::var).collect());
343// self.with_values(NodeType::Output, vec![output]);
344
345// let input = self.input(input_size);
346// let output = self.output(output_size);
347
348// let cell_state = self.aggregates(1);
349// let hidden_state = self.aggregates(1);
350
351// let forget_gate = self.aggregates(1);
352// let input_gate = self.aggregates(1);
353// let output_gate = self.aggregates(1);
354// let candidate = self.aggregates(1);
355
356// let input_to_forget_weights = self.edge(input_size);
357// let input_to_input_weights = self.edge(input_size);
358// let input_to_output_weights = self.edge(input_size);
359// let input_to_candidate_weights = self.edge(input_size);
360
361// let hidden_to_forget_weights = self.edge(1);
362// let hidden_to_input_weights = self.edge(1);
363// let hidden_to_output_weights = self.edge(1);
364// let hidden_to_candidate_weights = self.edge(1);
365
366// let final_weights = self.edge(output_size);
367
368// let graph = GraphAggregate::new()
369// .one_to_one(&input, &input_to_forget_weights)
370// .one_to_one(&input, &input_to_input_weights)
371// .one_to_one(&input, &input_to_output_weights)
372// .one_to_one(&input, &input_to_candidate_weights)
373// .one_to_one(&hidden_state, &hidden_to_forget_weights)
374// .one_to_one(&hidden_state, &hidden_to_input_weights)
375// .one_to_one(&hidden_state, &hidden_to_output_weights)
376// .one_to_one(&hidden_state, &hidden_to_candidate_weights)
377// .many_to_one(&input_to_forget_weights, &forget_gate)
378// .many_to_one(&hidden_to_forget_weights, &forget_gate)
379// .many_to_one(&input_to_input_weights, &input_gate)
380// .many_to_one(&hidden_to_input_weights, &input_gate)
381// .many_to_one(&input_to_output_weights, &output_gate)
382// .many_to_one(&hidden_to_output_weights, &output_gate)
383// .many_to_one(&input_to_candidate_weights, &candidate)
384// .many_to_one(&hidden_to_candidate_weights, &candidate)
385// .one_to_one(&forget_gate, &cell_state)
386// .one_to_one(&input_gate, &candidate)
387// .one_to_one(&candidate, &cell_state)
388// .one_to_one(&cell_state, &hidden_state)
389// .one_to_one(&output_gate, &hidden_state)
390// .one_to_many(&hidden_state, &final_weights)
391// .one_to_one(&final_weights, &output)
392// .build();
393
394// self.node_cache = Some(graph.into_iter().collect());
395// self
396// }