radiate_gp/collections/graphs/
transaction.rs1use super::{Direction, Graph, GraphNode};
2use crate::{Arity, NodeType, node::Node};
3use radiate::{Valid, random_provider};
4use std::{collections::HashSet, fmt::Debug, ops::Deref};
5
6#[derive(Debug, Clone)]
8pub enum MutationStep {
9 AddNode(usize),
10 AddEdge(usize, usize),
11 RemoveEdge(usize, usize),
12 DirectionChange {
13 index: usize,
14 previous_direction: Direction,
15 },
16}
17
18#[derive(Clone)]
19pub enum ReplayStep<T> {
20 AddNode(usize, Option<GraphNode<T>>),
21 AddEdge(usize, usize),
22 RemoveEdge(usize, usize),
23 DirectionChange(usize, Direction),
24}
25
26pub enum TransactionResult<T> {
27 Valid(Vec<MutationStep>),
28 Invalid(Vec<MutationStep>, Vec<ReplayStep<T>>),
29}
30
/// One action in a plan for splicing a new node into the graph.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum InsertStep {
    /// Remove the `from -> to` connection.
    Detach(usize, usize),
    /// Add a `from -> to` connection.
    Connect(usize, usize),
    /// Marker for an insertion that cannot be carried out.
    /// NOTE(review): callers are presumably expected to abandon the plan — confirm at call sites.
    Invalid,
}
37
38pub struct GraphTransaction<'a, T> {
40 graph: &'a mut Graph<T>,
41 steps: Vec<MutationStep>,
42 effects: HashSet<usize>,
43}
44
45impl<'a, T> GraphTransaction<'a, T> {
46 pub fn new(graph: &'a mut Graph<T>) -> Self {
47 GraphTransaction {
48 graph,
49 steps: Vec::new(),
50 effects: HashSet::new(),
51 }
52 }
53
54 pub fn commit(self) -> TransactionResult<T> {
55 self.commit_with(None)
56 }
57
58 pub fn commit_with(
59 mut self,
60 validator: Option<&dyn Fn(&Graph<T>) -> bool>,
61 ) -> TransactionResult<T> {
62 self.set_cycles();
63 let result_steps = self.steps.iter().map(|step| (*step).clone()).collect();
64
65 if let Some(validator) = validator {
66 if validator(self.graph) && self.is_valid() {
67 return TransactionResult::Valid(result_steps);
68 } else {
69 let replay_steps = self.rollback();
70 return TransactionResult::Invalid(result_steps, replay_steps);
71 }
72 }
73
74 if self.is_valid() {
75 TransactionResult::Valid(result_steps)
76 } else {
77 let replay_steps = self.rollback();
78 TransactionResult::Invalid(result_steps, replay_steps)
79 }
80 }
81
82 pub fn len(&self) -> usize {
83 self.graph.len()
84 }
85
86 pub fn add_node(&mut self, node: GraphNode<T>) -> usize {
87 let index = self.graph.len();
88 self.steps.push(MutationStep::AddNode(index));
89 self.graph.push(node);
90 self.effects.insert(index);
91 index
92 }
93
94 pub fn attach(&mut self, from: usize, to: usize) {
95 self.steps.push(MutationStep::AddEdge(from, to));
96 self.graph.attach(from, to);
97 self.effects.insert(from);
98 self.effects.insert(to);
99 }
100
101 pub fn detach(&mut self, from: usize, to: usize) {
102 self.steps.push(MutationStep::RemoveEdge(from, to));
103 self.graph.detach(from, to);
104 self.effects.insert(from);
105 self.effects.insert(to);
106 }
107
108 pub fn change_direction(&mut self, index: usize, direction: Direction) {
109 if let Some(node) = self.graph.get_mut(index) {
110 if node.direction() == direction {
111 return;
112 }
113
114 self.steps.push(MutationStep::DirectionChange {
115 index,
116 previous_direction: node.direction(),
117 });
118 node.set_direction(direction);
119 }
120 }
121
122 pub fn rollback(self) -> Vec<ReplayStep<T>> {
123 let mut replay_steps = Vec::new();
124 for step in self.steps.into_iter().rev() {
125 match step {
126 MutationStep::AddNode(_) => {
127 let added_node = self.graph.pop();
128 replay_steps.push(ReplayStep::AddNode(self.graph.len(), added_node));
129 }
130 MutationStep::AddEdge(from, to) => {
131 self.graph.detach(from, to);
132 replay_steps.push(ReplayStep::AddEdge(from, to));
133 }
134 MutationStep::RemoveEdge(from, to) => {
135 self.graph.attach(from, to);
136 replay_steps.push(ReplayStep::RemoveEdge(from, to));
137 }
138 MutationStep::DirectionChange {
139 index,
140 previous_direction,
141 ..
142 } => {
143 if let Some(node) = self.graph.get_mut(index) {
144 let prev_dir = node.direction();
145 node.set_direction(previous_direction);
146 replay_steps.push(ReplayStep::DirectionChange(index, prev_dir));
147 }
148 }
149 }
150 }
151
152 replay_steps.reverse();
153 replay_steps
154 }
155
156 pub fn replay(&mut self, steps: Vec<ReplayStep<T>>) {
157 for step in steps {
158 match step {
159 ReplayStep::AddNode(_, node) => {
160 if let Some(node) = node {
161 self.add_node(node);
162 }
163 }
164 ReplayStep::AddEdge(from, to) => {
165 self.attach(from, to);
166 }
167 ReplayStep::RemoveEdge(from, to) => {
168 self.detach(from, to);
169 }
170 ReplayStep::DirectionChange(index, direction) => {
171 self.change_direction(index, direction);
172 }
173 }
174 }
175 }
176
177 pub fn is_valid(&self) -> bool {
178 self.graph.is_valid()
179 }
180
181 pub fn set_cycles(&mut self) {
182 let effects = self.effects.clone();
183
184 for idx in effects {
185 let node_cycles = self.graph.get_cycles(idx);
186
187 if node_cycles.is_empty() {
188 self.change_direction(idx, Direction::Forward);
189 } else {
190 for cycle_idx in node_cycles {
191 self.change_direction(cycle_idx, Direction::Backward);
192 }
193 }
194 }
195 }
196
197 pub fn get_insertion_steps(
198 &self,
199 source_idx: usize,
200 target_idx: usize,
201 new_node_idx: usize,
202 ) -> Vec<InsertStep> {
203 let target_node = self.graph.get(target_idx).unwrap();
204 let source_node = self.graph.get(source_idx).unwrap();
205 let new_node = self.graph.get(new_node_idx).unwrap();
206
207 let mut steps = Vec::new();
208
209 let source_is_edge = source_node.node_type() == NodeType::Edge;
210 let target_is_edge = target_node.node_type() == NodeType::Edge;
211 let new_node_arity = new_node.arity();
212
213 if new_node_arity == Arity::Zero && !target_node.is_locked() {
214 steps.push(InsertStep::Connect(new_node_idx, target_idx));
215 return steps;
216 }
217
218 if source_is_edge {
219 let source_outgoing_idxes = source_node.outgoing().iter().collect::<Vec<&usize>>();
220 let source_outgoing = *random_provider::choose(&source_outgoing_idxes);
221
222 if source_outgoing == &new_node_idx {
223 steps.push(InsertStep::Connect(source_idx, new_node_idx));
224 } else {
225 steps.push(InsertStep::Connect(source_idx, new_node_idx));
226 steps.push(InsertStep::Connect(new_node_idx, *source_outgoing));
227 steps.push(InsertStep::Detach(source_idx, *source_outgoing));
228 }
229 } else if target_is_edge || target_node.is_locked() {
230 let target_incoming_idxes = target_node.incoming().iter().collect::<Vec<&usize>>();
231 let target_incoming = *random_provider::choose(&target_incoming_idxes);
232
233 if target_incoming == &new_node_idx {
234 steps.push(InsertStep::Connect(*target_incoming, new_node_idx));
235 } else {
236 steps.push(InsertStep::Connect(*target_incoming, new_node_idx));
237 steps.push(InsertStep::Connect(new_node_idx, target_idx));
238 steps.push(InsertStep::Detach(*target_incoming, target_idx));
239 }
240 } else {
241 steps.push(InsertStep::Connect(source_idx, new_node_idx));
242 steps.push(InsertStep::Connect(new_node_idx, target_idx));
243 }
244
245 steps
246 }
247
248 #[inline]
257 pub fn random_source_node(&self) -> Option<&GraphNode<T>> {
258 self.random_node_of_type(vec![NodeType::Input, NodeType::Vertex, NodeType::Edge])
259 }
260 #[inline]
263 pub fn random_target_node(&self) -> Option<&GraphNode<T>> {
264 self.random_node_of_type(vec![NodeType::Output, NodeType::Vertex, NodeType::Edge])
265 }
266 #[inline]
270 fn random_node_of_type(&self, node_types: Vec<NodeType>) -> Option<&GraphNode<T>> {
271 if node_types.is_empty() {
272 return None;
273 }
274
275 let gene_node_type_index = random_provider::random_range(0..node_types.len());
276 let gene_node_type = node_types.get(gene_node_type_index).unwrap();
277
278 let genes = match gene_node_type {
279 NodeType::Input => self
280 .iter()
281 .filter(|node| node.node_type() == NodeType::Input)
282 .collect::<Vec<&GraphNode<T>>>(),
283 NodeType::Output => self
284 .iter()
285 .filter(|node| node.node_type() == NodeType::Output)
286 .collect::<Vec<&GraphNode<T>>>(),
287 NodeType::Vertex => self
288 .iter()
289 .filter(|node| node.node_type() == NodeType::Vertex)
290 .collect::<Vec<&GraphNode<T>>>(),
291 NodeType::Edge => self
292 .iter()
293 .filter(|node| node.node_type() == NodeType::Edge)
294 .collect::<Vec<&GraphNode<T>>>(),
295 _ => vec![],
296 };
297
298 if genes.is_empty() {
299 return self.random_node_of_type(
300 node_types
301 .iter()
302 .filter(|nt| *nt != gene_node_type)
303 .cloned()
304 .collect(),
305 );
306 }
307
308 let index = random_provider::random_range(0..genes.len());
309 genes.get(index).map(|x| *x)
310 }
311}
312
313impl<'a, T> Deref for GraphTransaction<'a, T> {
314 type Target = Graph<T>;
315
316 fn deref(&self) -> &Self::Target {
317 self.graph
318 }
319}
320
321