Skip to main content

gen_diff/
graph.rs

1use std::collections::{HashMap, HashSet};
2
3use gen_core::{
4    HashId,
5    Strand::{self, Forward},
6    is_end_node, is_start_node, is_terminal,
7};
8use gen_graph::{GenGraph, GraphEdge, GraphNode};
9use gen_models::{
10    block_group_edge::BlockGroupEdge, changesets::ChangesetModels, edge::Edge, node::Node,
11    sequence::Sequence, session_operations::DependencyModels,
12};
13use itertools::Itertools;
14use petgraph::{Direction, graphmap::DiGraphMap};
15
16#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
17pub struct DiffGraphNode {
18    pub node: GraphNode,
19    pub is_new: bool,
20}
21
22#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
23pub struct DiffGraphEdge {
24    pub edge: GraphEdge,
25    pub is_new: bool,
26}
27
28pub type DiffGenGraph = DiGraphMap<DiffGraphNode, Vec<DiffGraphEdge>>;
29
30pub struct DiffGenGraphRef<'a>(pub &'a DiffGenGraph);
31
32impl<'a> From<&'a DiffGenGraph> for DiffGenGraphRef<'a> {
33    fn from(graph: &'a DiffGenGraph) -> Self {
34        Self(graph)
35    }
36}
37
38impl<'a> From<DiffGenGraphRef<'a>> for GenGraph {
39    fn from(val: DiffGenGraphRef<'a>) -> Self {
40        let mut graph = GenGraph::new();
41        for node in val.0.nodes() {
42            graph.add_node(node.node);
43        }
44        for (src, dest, edges) in val.0.all_edges() {
45            let mapped_edges = edges.iter().map(|edge| edge.edge).collect::<Vec<_>>();
46            graph.add_edge(src.node, dest.node, mapped_edges);
47        }
48        graph
49    }
50}
51
52impl From<DiffGraphNode> for GraphNode {
53    fn from(node: DiffGraphNode) -> Self {
54        node.node
55    }
56}
57
58impl From<DiffGraphEdge> for GraphEdge {
59    fn from(edge: DiffGraphEdge) -> Self {
60        edge.edge
61    }
62}
63
64pub fn get_diff_graph(
65    changes: &ChangesetModels,
66    dependencies: &DependencyModels,
67) -> HashMap<HashId, DiffGenGraph> {
68    let start_node = Node::get_start_node();
69    let end_node = Node::get_end_node();
70    let mut bges_by_bg: HashMap<HashId, Vec<&BlockGroupEdge>> = HashMap::new();
71    // the bool is marking whether or not the Edge/Node is new
72    let mut edges_by_id: HashMap<HashId, (&Edge, bool)> = HashMap::new();
73    let mut nodes_by_id: HashMap<HashId, (&Node, bool)> = HashMap::new();
74    nodes_by_id.insert(start_node.id, (&start_node, false));
75    nodes_by_id.insert(end_node.id, (&end_node, false));
76    let mut sequences_by_hash: HashMap<HashId, &Sequence> = HashMap::new();
77    let mut block_graphs: HashMap<HashId, DiffGenGraph> = HashMap::new();
78
79    for bge in changes.block_group_edges.iter() {
80        bges_by_bg
81            .entry(bge.block_group_id)
82            .and_modify(|l| l.push(bge))
83            .or_insert_with(|| vec![bge]);
84    }
85    for edge in dependencies.edges.iter() {
86        edges_by_id.insert(edge.id, (edge, false));
87    }
88    for edge in changes.edges.iter() {
89        edges_by_id.insert(edge.id, (edge, true));
90    }
91    for node in dependencies.nodes.iter() {
92        nodes_by_id.insert(node.id, (node, false));
93    }
94    for node in changes.nodes.iter() {
95        nodes_by_id.insert(node.id, (node, true));
96    }
97    for seq in dependencies
98        .sequences
99        .iter()
100        .chain(changes.sequences.iter())
101    {
102        sequences_by_hash.insert(seq.hash, seq);
103    }
104
105    for (bg_id, bg_edges) in bges_by_bg.iter() {
106        let mut graph: DiGraphMap<HashId, Vec<(i64, i64, bool)>> = DiGraphMap::new();
107        let mut block_graph = DiffGenGraph::new();
108        block_graph.add_node(DiffGraphNode {
109            node: GraphNode {
110                block_id: -1,
111                node_id: start_node.id,
112                sequence_start: 0,
113                sequence_end: 0,
114            },
115            is_new: false,
116        });
117        block_graph.add_node(DiffGraphNode {
118            node: GraphNode {
119                block_id: -1,
120                node_id: end_node.id,
121                sequence_start: 0,
122                sequence_end: 0,
123            },
124            is_new: false,
125        });
126        for bg_edge in bg_edges {
127            let (edge, edge_is_new) = *edges_by_id.get(&bg_edge.edge_id).unwrap();
128            if let Some(weights) = graph.edge_weight_mut(edge.source_node_id, edge.target_node_id) {
129                weights.push((edge.source_coordinate, edge.target_coordinate, edge_is_new));
130            } else {
131                graph.add_edge(
132                    edge.source_node_id,
133                    edge.target_node_id,
134                    vec![(edge.source_coordinate, edge.target_coordinate, edge_is_new)],
135                );
136            }
137        }
138
139        for node in graph.nodes() {
140            if is_terminal(node) {
141                continue;
142            }
143            let in_ports = graph
144                .edges_directed(node, Direction::Incoming)
145                .flat_map(|(_src, _dest, weights)| {
146                    weights.iter().map(|(_, tp, _)| *tp).collect::<Vec<_>>()
147                })
148                .collect::<Vec<_>>();
149            let out_ports = graph
150                .edges_directed(node, Direction::Outgoing)
151                .flat_map(|(_src, _dest, weights)| {
152                    weights.iter().map(|(fp, _tp, _)| *fp).collect::<Vec<_>>()
153                })
154                .collect::<Vec<_>>();
155
156            let (node_obj, node_is_new) = *nodes_by_id.get(&node).unwrap();
157            let sequence = *sequences_by_hash.get(&node_obj.sequence_hash).unwrap();
158            let s_len = sequence.length;
159            let mut block_starts: HashSet<i64> = HashSet::from_iter(in_ports.iter().copied());
160            block_starts.insert(0);
161            for x in out_ports.iter() {
162                if *x < s_len - 1 {
163                    block_starts.insert(*x);
164                }
165            }
166            let mut block_ends: HashSet<i64> = HashSet::from_iter(out_ports.iter().copied());
167            block_ends.insert(s_len);
168            for x in in_ports.iter() {
169                if *x > 0 {
170                    block_ends.insert(*x);
171                }
172            }
173
174            let block_starts = block_starts.into_iter().sorted().collect::<Vec<_>>();
175            let block_ends = block_ends.into_iter().sorted().collect::<Vec<_>>();
176
177            let mut blocks = vec![];
178            for (i, j) in block_starts.iter().zip(block_ends.iter()) {
179                let node = DiffGraphNode {
180                    node: GraphNode {
181                        block_id: -1,
182                        node_id: node,
183                        sequence_start: *i,
184                        sequence_end: *j,
185                    },
186                    is_new: node_is_new,
187                };
188                block_graph.add_node(node);
189                blocks.push(node);
190            }
191
192            for (i, j) in blocks.iter().tuple_windows() {
193                block_graph.add_edge(
194                    *i,
195                    *j,
196                    vec![DiffGraphEdge {
197                        edge: GraphEdge {
198                            edge_id: HashId::pad_str(1),
199                            source_strand: Strand::Forward,
200                            target_strand: Forward,
201                            chromosome_index: 0,
202                            phased: 0,
203                            created_on: 0,
204                        },
205                        is_new: node_is_new,
206                    }],
207                );
208            }
209        }
210
211        for (src, dest, weights) in graph.all_edges() {
212            for (fp, tp, edge_is_new) in weights {
213                if !(is_end_node(src) && is_start_node(dest)) {
214                    let source_block = block_graph
215                        .nodes()
216                        .find(|node| node.node.node_id == src && node.node.sequence_end == *fp)
217                        .unwrap();
218                    let dest_block = block_graph
219                        .nodes()
220                        .find(|node| node.node.node_id == dest && node.node.sequence_start == *tp)
221                        .unwrap();
222                    block_graph.add_edge(
223                        source_block,
224                        dest_block,
225                        vec![DiffGraphEdge {
226                            edge: GraphEdge {
227                                edge_id: HashId::pad_str(1),
228                                source_strand: Strand::Forward,
229                                target_strand: Forward,
230                                chromosome_index: 0,
231                                phased: 0,
232                                created_on: 0,
233                            },
234                            is_new: *edge_is_new,
235                        }],
236                    );
237                }
238            }
239        }
240
241        block_graphs.insert(*bg_id, block_graph);
242    }
243
244    block_graphs
245}
246
#[cfg(test)]
mod tests {
    use gen_core::{HashId, Strand};
    use gen_graph::{GraphEdge, GraphNode};
    use gen_models::{
        block_group::BlockGroup,
        block_group_edge::BlockGroupEdge,
        changesets::ChangesetModels,
        edge::Edge,
        node::Node,
        sequence::{NewSequence, Sequence},
        session_operations::DependencyModels,
    };

    use super::*;

    /// Minimal dependency set: the two terminal nodes plus an empty sequence registered under
    /// each terminal node's sequence hash.
    fn base_dependencies(start_node: &Node, end_node: &Node) -> DependencyModels {
        fn empty_sequence(name: &'static str, terminal: &Node) -> Sequence {
            let mut sequence = Sequence::new()
                .sequence_type("DNA")
                .sequence("")
                .name(name)
                .build();
            sequence.hash = terminal.sequence_hash;
            sequence
        }

        let sequences = vec![
            empty_sequence("start", start_node),
            empty_sequence("end", end_node),
        ];
        let nodes = vec![start_node.clone(), end_node.clone()];

        DependencyModels {
            collections: vec![],
            samples: vec![],
            sequences,
            block_group: vec![],
            nodes,
            edges: vec![],
            paths: vec![],
            accessions: vec![],
            accession_edges: vec![],
        }
    }

    /// Returns the weights of the edge running from `src` to `dest`, if there is one.
    fn find_edge(
        graph: &DiffGenGraph,
        src: DiffGraphNode,
        dest: DiffGraphNode,
    ) -> Option<&Vec<DiffGraphEdge>> {
        graph.edge_weight(src, dest)
    }

    /// Builds a forward/forward edge from `(node id, coordinate)` to `(node id, coordinate)`
    /// whose id is derived from `label`.
    fn forward_edge(label: &str, source: (HashId, i64), target: (HashId, i64)) -> Edge {
        let (source_node_id, source_coordinate) = source;
        let (target_node_id, target_coordinate) = target;
        Edge {
            id: HashId::convert_str(label),
            source_node_id,
            source_coordinate,
            source_strand: Strand::Forward,
            target_node_id,
            target_coordinate,
            target_strand: Strand::Forward,
        }
    }

    #[test]
    fn diff_graph_to_gen_graph_maps_nodes_and_edges() {
        let first = DiffGraphNode {
            node: GraphNode {
                block_id: 1,
                node_id: HashId::pad_str(1),
                sequence_start: 0,
                sequence_end: 5,
            },
            is_new: true,
        };
        let second = DiffGraphNode {
            node: GraphNode {
                block_id: 2,
                node_id: HashId::pad_str(2),
                sequence_start: 5,
                sequence_end: 10,
            },
            is_new: false,
        };
        let link = DiffGraphEdge {
            edge: GraphEdge {
                edge_id: HashId::pad_str(9),
                source_strand: Strand::Forward,
                target_strand: Strand::Forward,
                chromosome_index: 0,
                phased: 0,
                created_on: 0,
            },
            is_new: true,
        };

        let mut diff_graph = DiffGenGraph::new();
        for diff_node in [first, second] {
            diff_graph.add_node(diff_node);
        }
        diff_graph.add_edge(first, second, vec![link]);

        let converted = GenGraph::from(DiffGenGraphRef(&diff_graph));

        assert_eq!(converted.nodes().count(), 2);
        assert_eq!(converted.all_edges().count(), 1);
        let (_, _, weights) = converted.all_edges().next().expect("graph edge");
        assert_eq!(weights, &vec![link.edge]);
    }

    #[test]
    fn get_diff_graph_splits_blocks_and_marks_new() {
        let start_node = Node::get_start_node();
        let end_node = Node::get_end_node();
        let block_group = BlockGroup {
            id: HashId::pad_str(10),
            collection_name: "collection".to_string(),
            sample_name: Some("sample".to_string()),
            name: "bg".to_string(),
            created_on: 0,
        };

        // Pre-existing node and the edges that already touched it.
        let seq = NewSequence::new()
            .sequence_type("dna")
            .sequence("AAAAAAAAAA")
            .name("seq")
            .build();
        let node = Node {
            id: HashId::pad_str(11),
            sequence_hash: seq.hash,
        };
        let old_edges = vec![
            forward_edge("start-node", (start_node.id, 0), (node.id, 0)),
            forward_edge("node-deletion", (node.id, 3), (node.id, 5)),
            forward_edge("node-end", (node.id, seq.length), (end_node.id, 0)),
        ];

        // Newly inserted node, entered from position 5 of the old node and rejoining it at 8.
        let new_seq = NewSequence::new()
            .sequence_type("dna")
            .sequence("TTTT")
            .name("new-seq")
            .build();
        let new_node = Node {
            id: HashId::pad_str(12),
            sequence_hash: new_seq.hash,
        };
        let edges = vec![
            forward_edge("node-insertion-start", (node.id, 5), (new_node.id, 0)),
            forward_edge(
                "node-insertion-end",
                (new_node.id, new_seq.length),
                (node.id, 8),
            ),
        ];

        let mut block_group_edges = Vec::with_capacity(edges.len());
        for (index, edge) in edges.iter().enumerate() {
            block_group_edges.push(BlockGroupEdge {
                id: HashId::convert_str(&format!("bge-{index}")),
                block_group_id: block_group.id,
                edge_id: edge.id,
                chromosome_index: 0,
                phased: 0,
                created_on: 0,
            });
        }

        let changes = ChangesetModels {
            collections: vec![],
            samples: vec![],
            sequences: vec![new_seq.clone()],
            block_groups: vec![block_group.clone()],
            nodes: vec![new_node.clone()],
            edges,
            block_group_edges,
            paths: vec![],
            path_edges: vec![],
            accessions: vec![],
            accession_edges: vec![],
            accession_paths: vec![],
            annotation_groups: vec![],
            annotations: vec![],
            annotation_group_samples: vec![],
        };
        let mut dependencies = base_dependencies(&start_node, &end_node);
        dependencies.sequences.push(seq.clone());
        dependencies.nodes.push(node.clone());
        dependencies.edges.extend(old_edges);

        let diff_graphs = get_diff_graph(&changes, &dependencies);
        let graph = diff_graphs.get(&block_group.id).expect("block group graph");

        // The old node is cut into three blocks, none of which is flagged as new.
        let old_blocks: Vec<_> = graph
            .nodes()
            .filter(|block| block.node.node_id == node.id)
            .collect();
        assert_eq!(old_blocks.len(), 3);
        assert!(!old_blocks.iter().any(|block| block.is_new));

        // The inserted node shows up as a block flagged as new.
        let new_node_block = graph
            .nodes()
            .find(|block| block.node.node_id == new_node.id)
            .expect("new node block");
        assert!(new_node_block.is_new);

        // The edge leaving the old node at position 5 is flagged as new as well.
        let block_node_insert_start = graph
            .nodes()
            .find(|block| block.node.node_id == node.id && block.node.sequence_end == 5)
            .expect("insert node block");
        let internal_edges =
            find_edge(graph, block_node_insert_start, new_node_block).expect("internal edge");
        assert_eq!(internal_edges.len(), 1);
        assert!(internal_edges[0].is_new);
    }
}