1use std::collections::HashMap;
2
3use gen_core::{HashId, PATH_END_NODE_ID, PATH_START_NODE_ID, calculate_hash, traits::Capnp};
4use rusqlite::{Row, params};
5use serde::{Deserialize, Serialize};
6
7use crate::{db::GraphConnection, gen_models_capnp::node, sequence::Sequence, traits::*};
8
9#[derive(Clone, Debug, Eq, Deserialize, Hash, Serialize, PartialEq)]
10pub struct Node {
11 pub id: HashId,
12 pub sequence_hash: HashId,
13}
14
15impl<'a> Capnp<'a> for Node {
16 type Builder = node::Builder<'a>;
17 type Reader = node::Reader<'a>;
18
19 fn write_capnp(&self, builder: &mut Self::Builder) {
20 builder.set_id(&self.id.0).unwrap();
21 builder.set_sequence_hash(&self.sequence_hash.0).unwrap();
22 }
23
24 fn read_capnp(reader: Self::Reader) -> Self {
25 let id = reader
26 .get_id()
27 .unwrap()
28 .as_slice()
29 .unwrap()
30 .try_into()
31 .unwrap();
32 let sequence_hash = reader
33 .get_sequence_hash()
34 .unwrap()
35 .as_slice()
36 .unwrap()
37 .try_into()
38 .unwrap();
39
40 Node { id, sequence_hash }
41 }
42}
43
44impl Query for Node {
45 type Model = Node;
46
47 const TABLE_NAME: &'static str = "nodes";
48
49 fn process_row(row: &Row) -> Self::Model {
50 Node {
51 id: row.get(0).unwrap(),
52 sequence_hash: row.get(1).unwrap(),
53 }
54 }
55}
56
57impl Node {
58 pub fn create(conn: &GraphConnection, sequence_hash: &HashId, node_hash: &HashId) -> HashId {
59 let insert_statement = "INSERT INTO nodes (id, sequence_hash) VALUES (?1, ?2);";
60 let mut stmt = conn.prepare_cached(insert_statement).unwrap();
61 match stmt.execute(params![node_hash, sequence_hash]) {
62 Ok(_) => *node_hash,
63 Err(rusqlite::Error::SqliteFailure(err, _details)) => {
64 if err.code == rusqlite::ErrorCode::ConstraintViolation {
65 *node_hash
66 } else {
67 panic!("something bad happened querying the database")
68 }
69 }
70 Err(_) => {
71 panic!("something bad happened querying the database")
72 }
73 }
74 }
75
76 pub fn get_sequences_by_node_ids(
77 conn: &GraphConnection,
78 node_ids: &[HashId],
79 ) -> HashMap<HashId, Sequence> {
80 let nodes = Node::query_by_ids(conn, node_ids);
81 let sequence_hashes_by_node_id = nodes
82 .iter()
83 .map(|node| (node.id, node.sequence_hash))
84 .collect::<HashMap<HashId, HashId>>();
85 let sequences_by_hash: HashMap<HashId, Sequence> = HashMap::from_iter(
86 Sequence::query_by_ids(
87 conn,
88 &sequence_hashes_by_node_id
89 .values()
90 .cloned()
91 .collect::<Vec<_>>(),
92 )
93 .iter()
94 .map(|seq| (seq.hash, seq.clone())),
95 );
96 sequence_hashes_by_node_id
97 .clone()
98 .into_iter()
99 .map(|(node_id, sequence_hash)| {
100 (
101 node_id,
102 sequences_by_hash.get(&sequence_hash).unwrap().clone(),
103 )
104 })
105 .collect::<HashMap<HashId, Sequence>>()
106 }
107
108 pub fn get_start_node() -> Node {
109 Node {
110 id: PATH_START_NODE_ID,
111 sequence_hash: HashId(calculate_hash(
112 "start-node-yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy",
113 )),
114 }
115 }
116
117 pub fn get_end_node() -> Node {
118 Node {
119 id: PATH_END_NODE_ID,
120 sequence_hash: HashId(calculate_hash(
121 "end-node-zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz",
122 )),
123 }
124 }
125}
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;

    /// A `Node` written into a Cap'n Proto message must read back unchanged.
    #[test]
    fn test_capnp_serialization() {
        let original = Node {
            id: HashId::convert_str("1"),
            sequence_hash: HashId::convert_str("test_sequence_hash"),
        };

        let mut message = TypedBuilder::<node::Owned>::new_default();
        let mut root_builder = message.init_root();
        original.write_capnp(&mut root_builder);

        let round_tripped = Node::read_capnp(root_builder.into_reader());
        assert_eq!(original, round_tripped);
    }
}