Skip to main content

gen_diff/
graph.rs

1use std::collections::{HashMap, HashSet};
2
3use gen_core::{
4    HashId,
5    Strand::{self, Forward},
6    is_end_node, is_start_node, is_terminal,
7};
8use gen_graph::{GenGraph, GraphEdge, GraphNode};
9use gen_models::{
10    block_group_edge::BlockGroupEdge, changesets::ChangesetModels, edge::Edge, node::Node,
11    sequence::Sequence, session_operations::DependencyModels,
12};
13use itertools::Itertools;
14use petgraph::{Direction, graphmap::DiGraphMap};
15
/// A [`GraphNode`] paired with a flag saying whether it is new in the diff.
///
/// `get_diff_graph` sets the flag to `true` for blocks of nodes that come from the
/// changeset and to `false` for blocks of nodes that come from its dependencies.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct DiffGraphNode {
    /// The wrapped graph node (node id plus sequence start/end coordinates).
    pub node: GraphNode,
    /// Whether this node is considered new in the diff.
    pub is_new: bool,
}
21
/// A [`GraphEdge`] paired with a flag saying whether it is new in the diff.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct DiffGraphEdge {
    /// The wrapped graph edge.
    pub edge: GraphEdge,
    /// Whether this edge is considered new in the diff.
    pub is_new: bool,
}
27
/// Directed graph keyed by [`DiffGraphNode`] whose edge weights are lists of
/// [`DiffGraphEdge`]s.
pub type DiffGenGraph = DiGraphMap<DiffGraphNode, Vec<DiffGraphEdge>>;
29
/// Borrowing newtype around a [`DiffGenGraph`]; it is the source type of the
/// conversion into a [`GenGraph`].
pub struct DiffGenGraphRef<'a>(pub &'a DiffGenGraph);
31
32impl<'a> From<&'a DiffGenGraph> for DiffGenGraphRef<'a> {
33    fn from(graph: &'a DiffGenGraph) -> Self {
34        Self(graph)
35    }
36}
37
38impl<'a> From<DiffGenGraphRef<'a>> for GenGraph {
39    fn from(val: DiffGenGraphRef<'a>) -> Self {
40        let mut graph = GenGraph::new();
41        for node in val.0.nodes() {
42            graph.add_node(node.node);
43        }
44        for (src, dest, edges) in val.0.all_edges() {
45            let mapped_edges = edges.iter().map(|edge| edge.edge).collect::<Vec<_>>();
46            graph.add_edge(src.node, dest.node, mapped_edges);
47        }
48        graph
49    }
50}
51
52impl From<DiffGraphNode> for GraphNode {
53    fn from(node: DiffGraphNode) -> Self {
54        node.node
55    }
56}
57
58impl From<DiffGraphEdge> for GraphEdge {
59    fn from(edge: DiffGraphEdge) -> Self {
60        edge.edge
61    }
62}
63
64pub fn get_diff_graph(
65    changes: &ChangesetModels,
66    dependencies: &DependencyModels,
67) -> HashMap<HashId, DiffGenGraph> {
68    let start_node = Node::get_start_node();
69    let end_node = Node::get_end_node();
70    let mut bges_by_bg: HashMap<HashId, Vec<&BlockGroupEdge>> = HashMap::new();
71    // the bool is marking whether or not the Edge/Node is new
72    let mut edges_by_id: HashMap<HashId, (&Edge, bool)> = HashMap::new();
73    let mut nodes_by_id: HashMap<HashId, (&Node, bool)> = HashMap::new();
74    nodes_by_id.insert(start_node.id, (&start_node, false));
75    nodes_by_id.insert(end_node.id, (&end_node, false));
76    let mut sequences_by_hash: HashMap<HashId, &Sequence> = HashMap::new();
77    let mut block_graphs: HashMap<HashId, DiffGenGraph> = HashMap::new();
78
79    for bge in changes.block_group_edges.iter() {
80        bges_by_bg
81            .entry(bge.block_group_id)
82            .and_modify(|l| l.push(bge))
83            .or_insert_with(|| vec![bge]);
84    }
85    for edge in dependencies.edges.iter() {
86        edges_by_id.insert(edge.id, (edge, false));
87    }
88    for edge in changes.edges.iter() {
89        edges_by_id.insert(edge.id, (edge, true));
90    }
91    for node in dependencies.nodes.iter() {
92        nodes_by_id.insert(node.id, (node, false));
93    }
94    for node in changes.nodes.iter() {
95        nodes_by_id.insert(node.id, (node, true));
96    }
97    for seq in dependencies
98        .sequences
99        .iter()
100        .chain(changes.sequences.iter())
101    {
102        sequences_by_hash.insert(seq.hash, seq);
103    }
104
105    for (bg_id, bg_edges) in bges_by_bg.iter() {
106        let mut graph: DiGraphMap<HashId, Vec<(i64, i64, bool)>> = DiGraphMap::new();
107        let mut block_graph = DiffGenGraph::new();
108        block_graph.add_node(DiffGraphNode {
109            node: GraphNode {
110                node_id: start_node.id,
111                sequence_start: 0,
112                sequence_end: 0,
113            },
114            is_new: false,
115        });
116        block_graph.add_node(DiffGraphNode {
117            node: GraphNode {
118                node_id: end_node.id,
119                sequence_start: 0,
120                sequence_end: 0,
121            },
122            is_new: false,
123        });
124        for bg_edge in bg_edges {
125            let (edge, edge_is_new) = *edges_by_id.get(&bg_edge.edge_id).unwrap();
126            if let Some(weights) = graph.edge_weight_mut(edge.source_node_id, edge.target_node_id) {
127                weights.push((edge.source_coordinate, edge.target_coordinate, edge_is_new));
128            } else {
129                graph.add_edge(
130                    edge.source_node_id,
131                    edge.target_node_id,
132                    vec![(edge.source_coordinate, edge.target_coordinate, edge_is_new)],
133                );
134            }
135        }
136
137        for node in graph.nodes() {
138            if is_terminal(node) {
139                continue;
140            }
141            let in_ports = graph
142                .edges_directed(node, Direction::Incoming)
143                .flat_map(|(_src, _dest, weights)| {
144                    weights.iter().map(|(_, tp, _)| *tp).collect::<Vec<_>>()
145                })
146                .collect::<Vec<_>>();
147            let out_ports = graph
148                .edges_directed(node, Direction::Outgoing)
149                .flat_map(|(_src, _dest, weights)| {
150                    weights.iter().map(|(fp, _tp, _)| *fp).collect::<Vec<_>>()
151                })
152                .collect::<Vec<_>>();
153
154            let (node_obj, node_is_new) = *nodes_by_id.get(&node).unwrap();
155            let sequence = *sequences_by_hash.get(&node_obj.sequence_hash).unwrap();
156            let s_len = sequence.length;
157            let mut block_starts: HashSet<i64> = HashSet::from_iter(in_ports.iter().copied());
158            block_starts.insert(0);
159            for x in out_ports.iter() {
160                if *x < s_len - 1 {
161                    block_starts.insert(*x);
162                }
163            }
164            let mut block_ends: HashSet<i64> = HashSet::from_iter(out_ports.iter().copied());
165            block_ends.insert(s_len);
166            for x in in_ports.iter() {
167                if *x > 0 {
168                    block_ends.insert(*x);
169                }
170            }
171
172            let block_starts = block_starts.into_iter().sorted().collect::<Vec<_>>();
173            let block_ends = block_ends.into_iter().sorted().collect::<Vec<_>>();
174
175            let mut blocks = vec![];
176            for (i, j) in block_starts.iter().zip(block_ends.iter()) {
177                let node = DiffGraphNode {
178                    node: GraphNode {
179                        node_id: node,
180                        sequence_start: *i,
181                        sequence_end: *j,
182                    },
183                    is_new: node_is_new,
184                };
185                block_graph.add_node(node);
186                blocks.push(node);
187            }
188
189            for (i, j) in blocks.iter().tuple_windows() {
190                block_graph.add_edge(
191                    *i,
192                    *j,
193                    vec![DiffGraphEdge {
194                        edge: GraphEdge {
195                            edge_id: HashId::pad_str(1),
196                            source_strand: Strand::Forward,
197                            target_strand: Forward,
198                            chromosome_index: 0,
199                            phased: 0,
200                            created_on: 0,
201                        },
202                        is_new: node_is_new,
203                    }],
204                );
205            }
206        }
207
208        for (src, dest, weights) in graph.all_edges() {
209            for (fp, tp, edge_is_new) in weights {
210                if !(is_end_node(src) && is_start_node(dest)) {
211                    let source_block = block_graph
212                        .nodes()
213                        .find(|node| node.node.node_id == src && node.node.sequence_end == *fp)
214                        .unwrap();
215                    let dest_block = block_graph
216                        .nodes()
217                        .find(|node| node.node.node_id == dest && node.node.sequence_start == *tp)
218                        .unwrap();
219                    block_graph.add_edge(
220                        source_block,
221                        dest_block,
222                        vec![DiffGraphEdge {
223                            edge: GraphEdge {
224                                edge_id: HashId::pad_str(1),
225                                source_strand: Strand::Forward,
226                                target_strand: Forward,
227                                chromosome_index: 0,
228                                phased: 0,
229                                created_on: 0,
230                            },
231                            is_new: *edge_is_new,
232                        }],
233                    );
234                }
235            }
236        }
237
238        block_graphs.insert(*bg_id, block_graph);
239    }
240
241    block_graphs
242}
243
#[cfg(test)]
mod tests {
    use gen_core::{HashId, Strand};
    use gen_graph::{GraphEdge, GraphNode};
    use gen_models::{
        block_group::BlockGroup,
        block_group_edge::BlockGroupEdge,
        changesets::ChangesetModels,
        edge::Edge,
        node::Node,
        sequence::{NewSequence, Sequence},
        session_operations::DependencyModels,
    };

    use super::*;

    /// Dependencies holding only the terminal nodes and an empty sequence for each.
    fn base_dependencies(start_node: &Node, end_node: &Node) -> DependencyModels {
        let mut start_seq = Sequence::new()
            .sequence_type("DNA")
            .sequence("")
            .name("start")
            .build();
        let mut end_seq = Sequence::new()
            .sequence_type("DNA")
            .sequence("")
            .name("end")
            .build();
        // Tie each empty sequence to the hash its terminal node points at.
        start_seq.hash = start_node.sequence_hash;
        end_seq.hash = end_node.sequence_hash;

        DependencyModels {
            collections: vec![],
            samples: vec![],
            sequences: vec![start_seq, end_seq],
            block_group: vec![],
            nodes: vec![start_node.clone(), end_node.clone()],
            edges: vec![],
            paths: vec![],
            accessions: vec![],
            accession_edges: vec![],
        }
    }

    /// Returns the weights of the edge running from `src` to `dest`, if there is one.
    fn find_edge(
        graph: &DiffGenGraph,
        src: DiffGraphNode,
        dest: DiffGraphNode,
    ) -> Option<&Vec<DiffGraphEdge>> {
        graph
            .all_edges()
            .find_map(|(from, to, weights)| (from == src && to == dest).then_some(weights))
    }

    #[test]
    fn diff_graph_to_gen_graph_maps_nodes_and_edges() {
        let block = |node_id: HashId, sequence_start: i64, sequence_end: i64, is_new: bool| {
            DiffGraphNode {
                node: GraphNode {
                    node_id,
                    sequence_start,
                    sequence_end,
                },
                is_new,
            }
        };
        let first = block(HashId::pad_str(1), 0, 5, true);
        let second = block(HashId::pad_str(2), 5, 10, false);
        let link = DiffGraphEdge {
            edge: GraphEdge {
                edge_id: HashId::pad_str(9),
                source_strand: Strand::Forward,
                target_strand: Strand::Forward,
                chromosome_index: 0,
                phased: 0,
                created_on: 0,
            },
            is_new: true,
        };

        let mut diff_graph = DiffGenGraph::new();
        diff_graph.add_node(first);
        diff_graph.add_node(second);
        diff_graph.add_edge(first, second, vec![link]);

        let converted = GenGraph::from(DiffGenGraphRef(&diff_graph));

        assert_eq!(converted.nodes().count(), 2);
        assert_eq!(converted.all_edges().count(), 1);
        let (_, _, converted_weights) = converted.all_edges().next().expect("graph edge");
        assert_eq!(converted_weights.clone(), vec![link.edge]);
    }

    #[test]
    fn get_diff_graph_splits_blocks_and_marks_new() {
        // Forward/forward edge between two (node id, coordinate) endpoints.
        fn forward_edge(label: &str, source: (HashId, i64), target: (HashId, i64)) -> Edge {
            Edge {
                id: HashId::convert_str(label),
                source_node_id: source.0,
                source_coordinate: source.1,
                source_strand: Strand::Forward,
                target_node_id: target.0,
                target_coordinate: target.1,
                target_strand: Strand::Forward,
            }
        }

        let start_node = Node::get_start_node();
        let end_node = Node::get_end_node();
        let block_group = BlockGroup {
            id: HashId::pad_str(10),
            collection_name: "collection".to_string(),
            sample_name: "sample".to_string(),
            name: "bg".to_string(),
            created_on: 0,
            parent_block_group_id: None,
            is_default: false,
        };

        // Pre-existing node and its edges; these go into the dependencies.
        let seq = NewSequence::new()
            .sequence_type("dna")
            .sequence("AAAAAAAAAA")
            .name("seq")
            .build();
        let node = Node {
            id: HashId::pad_str(11),
            sequence_hash: seq.hash,
        };
        let old_edges = vec![
            forward_edge("start-node", (start_node.id, 0), (node.id, 0)),
            forward_edge("node-deletion", (node.id, 3), (node.id, 5)),
            forward_edge("node-end", (node.id, seq.length), (end_node.id, 0)),
        ];

        // Inserted node and the edges splicing it in; these form the changeset.
        let new_seq = NewSequence::new()
            .sequence_type("dna")
            .sequence("TTTT")
            .name("new-seq")
            .build();
        let new_node = Node {
            id: HashId::pad_str(12),
            sequence_hash: new_seq.hash,
        };
        let edges = vec![
            forward_edge("node-insertion-start", (node.id, 5), (new_node.id, 0)),
            forward_edge(
                "node-insertion-end",
                (new_node.id, new_seq.length),
                (node.id, 8),
            ),
        ];
        let mut block_group_edges = Vec::new();
        for (index, edge) in edges.iter().enumerate() {
            block_group_edges.push(BlockGroupEdge {
                id: HashId::convert_str(&format!("bge-{index}")),
                block_group_id: block_group.id,
                edge_id: edge.id,
                chromosome_index: 0,
                phased: 0,
                created_on: 0,
            });
        }

        let changes = ChangesetModels {
            collections: vec![],
            samples: vec![],
            sample_lineages: vec![],
            sequences: vec![new_seq.clone()],
            block_groups: vec![block_group.clone()],
            nodes: vec![new_node.clone()],
            edges,
            block_group_edges,
            paths: vec![],
            path_edges: vec![],
            accessions: vec![],
            accession_edges: vec![],
            accession_paths: vec![],
            annotation_groups: vec![],
            annotations: vec![],
            annotation_group_samples: vec![],
        };
        let mut dependencies = base_dependencies(&start_node, &end_node);
        dependencies.sequences.push(seq.clone());
        dependencies.nodes.push(node.clone());
        dependencies.edges.extend(old_edges);

        let diff_graphs = get_diff_graph(&changes, &dependencies);
        let graph = diff_graphs.get(&block_group.id).expect("block group graph");

        // The pre-existing node is cut into three blocks, none of them new.
        let old_blocks: Vec<_> = graph
            .nodes()
            .filter(|candidate| candidate.node.node_id == node.id)
            .collect();
        assert_eq!(old_blocks.len(), 3);
        assert!(!old_blocks.iter().any(|candidate| candidate.is_new));

        // The inserted node shows up flagged as new.
        let inserted_block = graph
            .nodes()
            .find(|candidate| candidate.node.node_id == new_node.id)
            .expect("new node block");
        assert!(inserted_block.is_new);

        // The edge leaving the old node at coordinate 5 reaches the inserted block
        // and is itself flagged as new.
        let block_before_insert = graph
            .nodes()
            .find(|candidate| candidate.node.node_id == node.id && candidate.node.sequence_end == 5)
            .expect("insert node block");
        let insertion_edges =
            find_edge(graph, block_before_insert, inserted_block).expect("internal edge");
        assert_eq!(insertion_edges.len(), 1);
        assert!(insertion_edges[0].is_new);
    }
}