Skip to main content

gen_models/
node.rs

1use std::collections::HashMap;
2
3use gen_core::{HashId, PATH_END_NODE_ID, PATH_START_NODE_ID, calculate_hash, traits::Capnp};
4use rusqlite::{Row, params};
5use serde::{Deserialize, Serialize};
6
7use crate::{db::GraphConnection, gen_models_capnp::node, sequence::Sequence, traits::*};
8
9#[derive(Clone, Debug, Eq, Deserialize, Hash, Serialize, PartialEq)]
10pub struct Node {
11    pub id: HashId,
12    pub sequence_hash: HashId,
13}
14
15impl<'a> Capnp<'a> for Node {
16    type Builder = node::Builder<'a>;
17    type Reader = node::Reader<'a>;
18
19    fn write_capnp(&self, builder: &mut Self::Builder) {
20        builder.set_id(&self.id.0).unwrap();
21        builder.set_sequence_hash(&self.sequence_hash.0).unwrap();
22    }
23
24    fn read_capnp(reader: Self::Reader) -> Self {
25        let id = reader
26            .get_id()
27            .unwrap()
28            .as_slice()
29            .unwrap()
30            .try_into()
31            .unwrap();
32        let sequence_hash = reader
33            .get_sequence_hash()
34            .unwrap()
35            .as_slice()
36            .unwrap()
37            .try_into()
38            .unwrap();
39
40        Node { id, sequence_hash }
41    }
42}
43
44impl Query for Node {
45    type Model = Node;
46
47    const TABLE_NAME: &'static str = "nodes";
48
49    fn process_row(row: &Row) -> Self::Model {
50        Node {
51            id: row.get(0).unwrap(),
52            sequence_hash: row.get(1).unwrap(),
53        }
54    }
55}
56
57impl Node {
58    pub fn create(conn: &GraphConnection, sequence_hash: &HashId, node_hash: &HashId) -> HashId {
59        let insert_statement = "INSERT INTO nodes (id, sequence_hash) VALUES (?1, ?2);";
60        let mut stmt = conn.prepare_cached(insert_statement).unwrap();
61        match stmt.execute(params![node_hash, sequence_hash]) {
62            Ok(_) => *node_hash,
63            Err(rusqlite::Error::SqliteFailure(err, _details)) => {
64                if err.code == rusqlite::ErrorCode::ConstraintViolation {
65                    *node_hash
66                } else {
67                    panic!("something bad happened querying the database")
68                }
69            }
70            Err(_) => {
71                panic!("something bad happened querying the database")
72            }
73        }
74    }
75
76    pub fn get_sequences_by_node_ids(
77        conn: &GraphConnection,
78        node_ids: &[HashId],
79    ) -> HashMap<HashId, Sequence> {
80        let nodes = Node::query_by_ids(conn, node_ids);
81        let sequence_hashes_by_node_id = nodes
82            .iter()
83            .map(|node| (node.id, node.sequence_hash))
84            .collect::<HashMap<HashId, HashId>>();
85        let sequences_by_hash: HashMap<HashId, Sequence> = HashMap::from_iter(
86            Sequence::query_by_ids(
87                conn,
88                &sequence_hashes_by_node_id
89                    .values()
90                    .cloned()
91                    .collect::<Vec<_>>(),
92            )
93            .iter()
94            .map(|seq| (seq.hash, seq.clone())),
95        );
96        sequence_hashes_by_node_id
97            .clone()
98            .into_iter()
99            .map(|(node_id, sequence_hash)| {
100                (
101                    node_id,
102                    sequences_by_hash.get(&sequence_hash).unwrap().clone(),
103                )
104            })
105            .collect::<HashMap<HashId, Sequence>>()
106    }
107
108    pub fn get_start_node() -> Node {
109        Node {
110            id: PATH_START_NODE_ID,
111            sequence_hash: HashId(calculate_hash(
112                "start-node-yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy",
113            )),
114        }
115    }
116
117    pub fn get_end_node() -> Node {
118        Node {
119            id: PATH_END_NODE_ID,
120            sequence_hash: HashId(calculate_hash(
121                "end-node-zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz",
122            )),
123        }
124    }
125}
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;

    /// Writing a node into a Cap'n Proto message and reading it back yields an equal node.
    #[test]
    fn test_capnp_serialization() {
        let original = Node {
            id: HashId::convert_str("1"),
            sequence_hash: HashId::convert_str("test_sequence_hash"),
        };

        let mut message = TypedBuilder::<node::Owned>::new_default();
        let mut builder = message.init_root();
        original.write_capnp(&mut builder);

        let round_tripped = Node::read_capnp(builder.into_reader());
        assert_eq!(original, round_tripped);
    }
}