// gen_models/path_edge.rs

1use std::{collections::HashMap, rc::Rc};
2
3use gen_core::{HashId, calculate_hash, traits::Capnp};
4use itertools::Itertools;
5use rusqlite::{self, Row, params, types::Value};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    db::GraphConnection, edge::Edge, gen_models_capnp::path_edge as PathEdgeCapnp, traits::*,
10};
11
12#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
13pub struct PathEdge {
14    pub id: HashId,
15    pub path_id: HashId,
16    pub edge_id: HashId,
17    pub index_in_path: i64,
18}
19
20impl<'a> Capnp<'a> for PathEdge {
21    type Builder = PathEdgeCapnp::Builder<'a>;
22    type Reader = PathEdgeCapnp::Reader<'a>;
23
24    fn write_capnp(&self, builder: &mut Self::Builder) {
25        builder.set_id(&self.id.0).unwrap();
26        builder.set_path_id(&self.path_id.0).unwrap();
27        builder.set_index_in_path(self.index_in_path);
28        builder.set_edge_id(&self.edge_id.0).unwrap();
29    }
30
31    fn read_capnp(reader: Self::Reader) -> Self {
32        let id = reader
33            .get_id()
34            .unwrap()
35            .as_slice()
36            .unwrap()
37            .try_into()
38            .unwrap();
39        let path_id = reader
40            .get_path_id()
41            .unwrap()
42            .as_slice()
43            .unwrap()
44            .try_into()
45            .unwrap();
46        let index_in_path = reader.get_index_in_path();
47        let edge_id = reader
48            .get_edge_id()
49            .unwrap()
50            .as_slice()
51            .unwrap()
52            .try_into()
53            .unwrap();
54
55        PathEdge {
56            id,
57            path_id,
58            index_in_path,
59            edge_id,
60        }
61    }
62}
63
64impl Query for PathEdge {
65    type Model = PathEdge;
66
67    const TABLE_NAME: &'static str = "path_edges";
68
69    fn process_row(row: &Row) -> Self::Model {
70        PathEdge {
71            id: row.get(0).unwrap(),
72            path_id: row.get(1).unwrap(),
73            edge_id: row.get(2).unwrap(),
74            index_in_path: row.get(3).unwrap(),
75        }
76    }
77}
78
79impl PathEdge {
80    pub fn create(
81        conn: &GraphConnection,
82        path_id: &HashId,
83        index_in_path: i64,
84        edge_id: HashId,
85    ) -> PathEdge {
86        let query =
87            "INSERT INTO path_edges (id, path_id, edge_id, index_in_path) VALUES (?1, ?2, ?3, ?4);";
88        let mut stmt = conn.prepare(query).unwrap();
89        let hash = HashId(calculate_hash(&format!(
90            "{path_id}:{edge_id}:{index_in_path}"
91        )));
92        match stmt.execute(params![hash, path_id, edge_id, index_in_path]) {
93            Ok(_) => {}
94            Err(rusqlite::Error::SqliteFailure(err, _details)) => {
95                if err.code != rusqlite::ErrorCode::ConstraintViolation {
96                    panic!("something bad happened querying the database")
97                }
98            }
99            Err(_) => {
100                panic!("something bad happened querying the database")
101            }
102        }
103        PathEdge {
104            id: hash,
105            path_id: *path_id,
106            index_in_path,
107            edge_id,
108        }
109    }
110
111    pub fn bulk_create(conn: &GraphConnection, path_id: &HashId, edge_ids: &[HashId]) {
112        let batch_size = max_rows_per_batch(conn, 4);
113
114        for (index1, chunk) in edge_ids.chunks(batch_size).enumerate() {
115            let mut rows_to_insert = vec![];
116            let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
117            for (index2, edge_id) in chunk.iter().enumerate() {
118                rows_to_insert.push("(?, ?, ?, ?)".to_string());
119                let index_in = index1 * 100000 + index2;
120                let hash = HashId(calculate_hash(&format!("{path_id}:{edge_id}:{index_in}")));
121                params.push(Box::new(hash));
122                params.push(Box::new(path_id));
123                params.push(Box::new(edge_id));
124                params.push(Box::new(index_in));
125            }
126
127            let sql = format!(
128                "INSERT OR IGNORE INTO path_edges (id, path_id, edge_id, index_in_path) VALUES {};",
129                rows_to_insert.join(", ")
130            );
131
132            let mut stmt = conn.prepare(&sql).unwrap();
133            stmt.execute(rusqlite::params_from_iter(params)).unwrap();
134        }
135    }
136
137    pub fn delete(conn: &GraphConnection, path_id: &HashId) {
138        let statement = "DELETE from path_edges WHERE path_id = ?1;";
139        conn.execute(statement, params![path_id]).unwrap();
140    }
141
142    pub fn edges_for_path(conn: &GraphConnection, path_id: &HashId) -> Vec<Edge> {
143        let path_edges = PathEdge::query(
144            conn,
145            "select * from path_edges where path_id = ?1 order by index_in_path ASC",
146            params![path_id],
147        );
148        let edge_ids = path_edges
149            .into_iter()
150            .map(|path_edge| path_edge.edge_id)
151            .collect::<Vec<HashId>>();
152        let edges = Edge::query_by_ids(conn, &edge_ids);
153        let edges_by_id = edges
154            .into_iter()
155            .map(|edge| (edge.id, edge))
156            .collect::<HashMap<HashId, Edge>>();
157        edge_ids
158            .into_iter()
159            .map(|edge_id| edges_by_id[&edge_id].clone())
160            .collect::<Vec<Edge>>()
161    }
162
163    pub fn edges_for_paths(
164        conn: &GraphConnection,
165        path_ids: Vec<HashId>,
166    ) -> HashMap<HashId, Vec<Edge>> {
167        let query_path_ids = path_ids
168            .iter()
169            .map(|path_id| Value::from(*path_id))
170            .collect::<Vec<Value>>();
171        let path_edges = PathEdge::query(
172            conn,
173            "select * from path_edges where path_id in rarray(?1) ORDER BY path_id, index_in_path",
174            params![Rc::new(query_path_ids)],
175        );
176        let edge_ids = path_edges
177            .iter()
178            .map(|path_edge| path_edge.edge_id)
179            .collect::<Vec<_>>();
180        let edges = Edge::query_by_ids(conn, &edge_ids);
181        let edges_by_id = edges
182            .into_iter()
183            .map(|edge| (edge.id, edge))
184            .collect::<HashMap<_, Edge>>();
185        let path_edges_by_path_id = path_edges
186            .iter()
187            .map(|path_edge| (path_edge.path_id, path_edge.edge_id))
188            .into_group_map();
189        path_edges_by_path_id
190            .into_iter()
191            .map(|(path_id, edge_ids)| {
192                (
193                    path_id,
194                    edge_ids
195                        .into_iter()
196                        .map(|edge_id| edges_by_id[&edge_id].clone())
197                        .collect::<Vec<Edge>>(),
198                )
199            })
200            .collect::<HashMap<HashId, Vec<Edge>>>()
201    }
202}
203
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;
    use crate::gen_models_capnp::path_edge;

    /// A `PathEdge` written into a Cap'n Proto message and read back out must compare equal
    /// to the original.
    #[test]
    fn test_path_edge_capnp_serialization() {
        let original = PathEdge {
            id: HashId::pad_str(100),
            path_id: HashId::pad_str(200),
            index_in_path: 5,
            edge_id: HashId::pad_str(300),
        };

        let mut builder = TypedBuilder::<path_edge::Owned>::new_default();
        let mut root = builder.init_root();
        original.write_capnp(&mut root);

        let round_tripped = PathEdge::read_capnp(root.into_reader());
        assert_eq!(original, round_tripped);
    }
}