radiate_gp/collections/graphs/
transaction.rs

1use super::{Direction, Graph, GraphNode};
2use crate::{Arity, NodeType, node::Node};
3use radiate::{Valid, random_provider};
4use std::{collections::HashSet, fmt::Debug, ops::Deref};
5
6/// Represents a reversible change to the graph
7#[derive(Debug, Clone)]
8pub enum MutationStep {
9    AddNode(usize),
10    AddEdge(usize, usize),
11    RemoveEdge(usize, usize),
12    DirectionChange {
13        index: usize,
14        previous_direction: Direction,
15    },
16}
17
18#[derive(Clone)]
19pub enum ReplayStep<T> {
20    AddNode(usize, Option<GraphNode<T>>),
21    AddEdge(usize, usize),
22    RemoveEdge(usize, usize),
23    DirectionChange(usize, Direction),
24}
25
26pub enum TransactionResult<T> {
27    Valid(Vec<MutationStep>),
28    Invalid(Vec<MutationStep>, Vec<ReplayStep<T>>),
29}
30
/// One edge operation in a plan for inserting a node into the graph, as
/// produced by `GraphTransaction::get_insertion_steps`.
///
/// For both edge variants the first index is the source of the edge and the
/// second is its target.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InsertStep {
    /// Remove the edge `(source, target)`.
    Detach(usize, usize),
    /// Add the edge `(source, target)`.
    Connect(usize, usize),
    /// Marker for an insertion that cannot be performed. It is not produced by
    /// `get_insertion_steps`; presumably used by callers — verify.
    Invalid,
}
37
38/// Tracks changes and provides rollback capability
39pub struct GraphTransaction<'a, T> {
40    graph: &'a mut Graph<T>,
41    steps: Vec<MutationStep>,
42    effects: HashSet<usize>,
43}
44
45impl<'a, T> GraphTransaction<'a, T> {
46    pub fn new(graph: &'a mut Graph<T>) -> Self {
47        GraphTransaction {
48            graph,
49            steps: Vec::new(),
50            effects: HashSet::new(),
51        }
52    }
53
54    pub fn commit(self) -> TransactionResult<T> {
55        self.commit_with(None)
56    }
57
58    pub fn commit_with(
59        mut self,
60        validator: Option<&dyn Fn(&Graph<T>) -> bool>,
61    ) -> TransactionResult<T> {
62        self.set_cycles();
63        let result_steps = self.steps.iter().map(|step| (*step).clone()).collect();
64
65        if let Some(validator) = validator {
66            if validator(self.graph) && self.is_valid() {
67                return TransactionResult::Valid(result_steps);
68            } else {
69                let replay_steps = self.rollback();
70                return TransactionResult::Invalid(result_steps, replay_steps);
71            }
72        }
73
74        if self.is_valid() {
75            TransactionResult::Valid(result_steps)
76        } else {
77            let replay_steps = self.rollback();
78            TransactionResult::Invalid(result_steps, replay_steps)
79        }
80    }
81
82    pub fn len(&self) -> usize {
83        self.graph.len()
84    }
85
86    pub fn add_node(&mut self, node: GraphNode<T>) -> usize {
87        let index = self.graph.len();
88        self.steps.push(MutationStep::AddNode(index));
89        self.graph.push(node);
90        self.effects.insert(index);
91        index
92    }
93
94    pub fn attach(&mut self, from: usize, to: usize) {
95        self.steps.push(MutationStep::AddEdge(from, to));
96        self.graph.attach(from, to);
97        self.effects.insert(from);
98        self.effects.insert(to);
99    }
100
101    pub fn detach(&mut self, from: usize, to: usize) {
102        self.steps.push(MutationStep::RemoveEdge(from, to));
103        self.graph.detach(from, to);
104        self.effects.insert(from);
105        self.effects.insert(to);
106    }
107
108    pub fn change_direction(&mut self, index: usize, direction: Direction) {
109        if let Some(node) = self.graph.get_mut(index) {
110            if node.direction() == direction {
111                return;
112            }
113
114            self.steps.push(MutationStep::DirectionChange {
115                index,
116                previous_direction: node.direction(),
117            });
118            node.set_direction(direction);
119        }
120    }
121
122    pub fn rollback(self) -> Vec<ReplayStep<T>> {
123        let mut replay_steps = Vec::new();
124        for step in self.steps.into_iter().rev() {
125            match step {
126                MutationStep::AddNode(_) => {
127                    let added_node = self.graph.pop();
128                    replay_steps.push(ReplayStep::AddNode(self.graph.len(), added_node));
129                }
130                MutationStep::AddEdge(from, to) => {
131                    self.graph.detach(from, to);
132                    replay_steps.push(ReplayStep::AddEdge(from, to));
133                }
134                MutationStep::RemoveEdge(from, to) => {
135                    self.graph.attach(from, to);
136                    replay_steps.push(ReplayStep::RemoveEdge(from, to));
137                }
138                MutationStep::DirectionChange {
139                    index,
140                    previous_direction,
141                    ..
142                } => {
143                    if let Some(node) = self.graph.get_mut(index) {
144                        let prev_dir = node.direction();
145                        node.set_direction(previous_direction);
146                        replay_steps.push(ReplayStep::DirectionChange(index, prev_dir));
147                    }
148                }
149            }
150        }
151
152        replay_steps.reverse();
153        replay_steps
154    }
155
156    pub fn replay(&mut self, steps: Vec<ReplayStep<T>>) {
157        for step in steps {
158            match step {
159                ReplayStep::AddNode(_, node) => {
160                    if let Some(node) = node {
161                        self.add_node(node);
162                    }
163                }
164                ReplayStep::AddEdge(from, to) => {
165                    self.attach(from, to);
166                }
167                ReplayStep::RemoveEdge(from, to) => {
168                    self.detach(from, to);
169                }
170                ReplayStep::DirectionChange(index, direction) => {
171                    self.change_direction(index, direction);
172                }
173            }
174        }
175    }
176
177    pub fn is_valid(&self) -> bool {
178        self.graph.is_valid()
179    }
180
181    pub fn set_cycles(&mut self) {
182        let effects = self.effects.clone();
183
184        for idx in effects {
185            let node_cycles = self.graph.get_cycles(idx);
186
187            if node_cycles.is_empty() {
188                self.change_direction(idx, Direction::Forward);
189            } else {
190                for cycle_idx in node_cycles {
191                    self.change_direction(cycle_idx, Direction::Backward);
192                }
193            }
194        }
195    }
196
197    pub fn get_insertion_steps(
198        &self,
199        source_idx: usize,
200        target_idx: usize,
201        new_node_idx: usize,
202    ) -> Vec<InsertStep> {
203        let target_node = self.graph.get(target_idx).unwrap();
204        let source_node = self.graph.get(source_idx).unwrap();
205        let new_node = self.graph.get(new_node_idx).unwrap();
206
207        let mut steps = Vec::new();
208
209        let source_is_edge = source_node.node_type() == NodeType::Edge;
210        let target_is_edge = target_node.node_type() == NodeType::Edge;
211        let new_node_arity = new_node.arity();
212
213        if new_node_arity == Arity::Zero && !target_node.is_locked() {
214            steps.push(InsertStep::Connect(new_node_idx, target_idx));
215            return steps;
216        }
217
218        if source_is_edge {
219            let source_outgoing_idxes = source_node.outgoing().iter().collect::<Vec<&usize>>();
220            let source_outgoing = *random_provider::choose(&source_outgoing_idxes);
221
222            if source_outgoing == &new_node_idx {
223                steps.push(InsertStep::Connect(source_idx, new_node_idx));
224            } else {
225                steps.push(InsertStep::Connect(source_idx, new_node_idx));
226                steps.push(InsertStep::Connect(new_node_idx, *source_outgoing));
227                steps.push(InsertStep::Detach(source_idx, *source_outgoing));
228            }
229        } else if target_is_edge || target_node.is_locked() {
230            let target_incoming_idxes = target_node.incoming().iter().collect::<Vec<&usize>>();
231            let target_incoming = *random_provider::choose(&target_incoming_idxes);
232
233            if target_incoming == &new_node_idx {
234                steps.push(InsertStep::Connect(*target_incoming, new_node_idx));
235            } else {
236                steps.push(InsertStep::Connect(*target_incoming, new_node_idx));
237                steps.push(InsertStep::Connect(new_node_idx, target_idx));
238                steps.push(InsertStep::Detach(*target_incoming, target_idx));
239            }
240        } else {
241            steps.push(InsertStep::Connect(source_idx, new_node_idx));
242            steps.push(InsertStep::Connect(new_node_idx, target_idx));
243        }
244
245        steps
246    }
247
248    /// The below functions are used to get random nodes from the graph. These are useful for
249    /// creating connections between nodes. Neither of these functions will return an edge node.
250    /// This is because edge nodes are not valid source or target nodes for connections as they
251    /// they only allow one incoming and one outgoing connection, thus they can't be used to create
252    /// new connections. Instread, edge nodes are used to represent the weights of the connections
253    ///
254    /// Get a random node that can be used as a source node for a connection.
255    /// A source node can be either an input or a vertex node.
256    #[inline]
257    pub fn random_source_node(&self) -> Option<&GraphNode<T>> {
258        self.random_node_of_type(vec![NodeType::Input, NodeType::Vertex, NodeType::Edge])
259    }
260    /// Get a random node that can be used as a target node for a connection.
261    /// A target node can be either an output or a vertex node.
262    #[inline]
263    pub fn random_target_node(&self) -> Option<&GraphNode<T>> {
264        self.random_node_of_type(vec![NodeType::Output, NodeType::Vertex, NodeType::Edge])
265    }
266    /// Helper functions to get a random node of the specified type. If no nodes of the specified
267    /// type are found, the function will try to get a random node of a different type.
268    /// If no nodes are found, the function will panic.
269    #[inline]
270    fn random_node_of_type(&self, node_types: Vec<NodeType>) -> Option<&GraphNode<T>> {
271        if node_types.is_empty() {
272            return None;
273        }
274
275        let gene_node_type_index = random_provider::random_range(0..node_types.len());
276        let gene_node_type = node_types.get(gene_node_type_index).unwrap();
277
278        let genes = match gene_node_type {
279            NodeType::Input => self
280                .iter()
281                .filter(|node| node.node_type() == NodeType::Input)
282                .collect::<Vec<&GraphNode<T>>>(),
283            NodeType::Output => self
284                .iter()
285                .filter(|node| node.node_type() == NodeType::Output)
286                .collect::<Vec<&GraphNode<T>>>(),
287            NodeType::Vertex => self
288                .iter()
289                .filter(|node| node.node_type() == NodeType::Vertex)
290                .collect::<Vec<&GraphNode<T>>>(),
291            NodeType::Edge => self
292                .iter()
293                .filter(|node| node.node_type() == NodeType::Edge)
294                .collect::<Vec<&GraphNode<T>>>(),
295            _ => vec![],
296        };
297
298        if genes.is_empty() {
299            return self.random_node_of_type(
300                node_types
301                    .iter()
302                    .filter(|nt| *nt != gene_node_type)
303                    .cloned()
304                    .collect(),
305            );
306        }
307
308        let index = random_provider::random_range(0..genes.len());
309        genes.get(index).map(|x| *x)
310    }
311}
312
313impl<'a, T> Deref for GraphTransaction<'a, T> {
314    type Target = Graph<T>;
315
316    fn deref(&self) -> &Self::Target {
317        self.graph
318    }
319}
320
321// /// Check if connecting the source node to the target node would create a cycle.
322// ///
323// /// # Arguments
324// /// - source: The index of the source node.
325// /// - target: The index of the target node.
326// ///
327// #[inline]
328// pub fn would_create_cycle(&self, source: usize, target: usize) -> bool {
329//     let mut seen = HashSet::new();
330//     let mut visited = self
331//         .get(target)
332//         .map(|node| node.outgoing().iter().collect())
333//         .unwrap_or(Vec::new());
334
335//     while !visited.is_empty() {
336//         let node_index = visited.pop().unwrap();
337
338//         seen.insert(*node_index);
339
340//         if *node_index == source {
341//             return true;
342//         }
343
344//         let node_edges = self
345//             .get(*node_index)
346//             .map(|node| {
347//                 node.outgoing()
348//                     .iter()
349//                     .filter(|edge_index| !seen.contains(edge_index))
350//                     .collect()
351//             })
352//             .unwrap_or(Vec::new());
353
354//         for edge_index in node_edges {
355//             visited.push(edge_index);
356//         }
357//     }
358
359//     false
360// }