1use std::{collections::HashMap, rc::Rc};
2
3use gen_core::{HashId, calculate_hash, traits::Capnp};
4use itertools::Itertools;
5use rusqlite::{self, Row, params, types::Value};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 db::GraphConnection, edge::Edge, gen_models_capnp::path_edge as PathEdgeCapnp, traits::*,
10};
11
12#[derive(Clone, Debug, PartialEq, Deserialize, Serialize)]
13pub struct PathEdge {
14 pub id: HashId,
15 pub path_id: HashId,
16 pub edge_id: HashId,
17 pub index_in_path: i64,
18}
19
20impl<'a> Capnp<'a> for PathEdge {
21 type Builder = PathEdgeCapnp::Builder<'a>;
22 type Reader = PathEdgeCapnp::Reader<'a>;
23
24 fn write_capnp(&self, builder: &mut Self::Builder) {
25 builder.set_id(&self.id.0).unwrap();
26 builder.set_path_id(&self.path_id.0).unwrap();
27 builder.set_index_in_path(self.index_in_path);
28 builder.set_edge_id(&self.edge_id.0).unwrap();
29 }
30
31 fn read_capnp(reader: Self::Reader) -> Self {
32 let id = reader
33 .get_id()
34 .unwrap()
35 .as_slice()
36 .unwrap()
37 .try_into()
38 .unwrap();
39 let path_id = reader
40 .get_path_id()
41 .unwrap()
42 .as_slice()
43 .unwrap()
44 .try_into()
45 .unwrap();
46 let index_in_path = reader.get_index_in_path();
47 let edge_id = reader
48 .get_edge_id()
49 .unwrap()
50 .as_slice()
51 .unwrap()
52 .try_into()
53 .unwrap();
54
55 PathEdge {
56 id,
57 path_id,
58 index_in_path,
59 edge_id,
60 }
61 }
62}
63
64impl Query for PathEdge {
65 type Model = PathEdge;
66
67 const TABLE_NAME: &'static str = "path_edges";
68
69 fn process_row(row: &Row) -> Self::Model {
70 PathEdge {
71 id: row.get(0).unwrap(),
72 path_id: row.get(1).unwrap(),
73 edge_id: row.get(2).unwrap(),
74 index_in_path: row.get(3).unwrap(),
75 }
76 }
77}
78
79impl PathEdge {
80 pub fn create(
81 conn: &GraphConnection,
82 path_id: &HashId,
83 index_in_path: i64,
84 edge_id: HashId,
85 ) -> PathEdge {
86 let query =
87 "INSERT INTO path_edges (id, path_id, edge_id, index_in_path) VALUES (?1, ?2, ?3, ?4);";
88 let mut stmt = conn.prepare(query).unwrap();
89 let hash = HashId(calculate_hash(&format!(
90 "{path_id}:{edge_id}:{index_in_path}"
91 )));
92 match stmt.execute(params![hash, path_id, edge_id, index_in_path]) {
93 Ok(_) => {}
94 Err(rusqlite::Error::SqliteFailure(err, _details)) => {
95 if err.code != rusqlite::ErrorCode::ConstraintViolation {
96 panic!("something bad happened querying the database")
97 }
98 }
99 Err(_) => {
100 panic!("something bad happened querying the database")
101 }
102 }
103 PathEdge {
104 id: hash,
105 path_id: *path_id,
106 index_in_path,
107 edge_id,
108 }
109 }
110
111 pub fn bulk_create(conn: &GraphConnection, path_id: &HashId, edge_ids: &[HashId]) {
112 let batch_size = max_rows_per_batch(conn, 4);
113
114 for (index1, chunk) in edge_ids.chunks(batch_size).enumerate() {
115 let mut rows_to_insert = vec![];
116 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
117 for (index2, edge_id) in chunk.iter().enumerate() {
118 rows_to_insert.push("(?, ?, ?, ?)".to_string());
119 let index_in = index1 * 100000 + index2;
120 let hash = HashId(calculate_hash(&format!("{path_id}:{edge_id}:{index_in}")));
121 params.push(Box::new(hash));
122 params.push(Box::new(path_id));
123 params.push(Box::new(edge_id));
124 params.push(Box::new(index_in));
125 }
126
127 let sql = format!(
128 "INSERT OR IGNORE INTO path_edges (id, path_id, edge_id, index_in_path) VALUES {};",
129 rows_to_insert.join(", ")
130 );
131
132 let mut stmt = conn.prepare(&sql).unwrap();
133 stmt.execute(rusqlite::params_from_iter(params)).unwrap();
134 }
135 }
136
137 pub fn delete(conn: &GraphConnection, path_id: &HashId) {
138 let statement = "DELETE from path_edges WHERE path_id = ?1;";
139 conn.execute(statement, params![path_id]).unwrap();
140 }
141
142 pub fn edges_for_path(conn: &GraphConnection, path_id: &HashId) -> Vec<Edge> {
143 let path_edges = PathEdge::query(
144 conn,
145 "select * from path_edges where path_id = ?1 order by index_in_path ASC",
146 params![path_id],
147 );
148 let edge_ids = path_edges
149 .into_iter()
150 .map(|path_edge| path_edge.edge_id)
151 .collect::<Vec<HashId>>();
152 let edges = Edge::query_by_ids(conn, &edge_ids);
153 let edges_by_id = edges
154 .into_iter()
155 .map(|edge| (edge.id, edge))
156 .collect::<HashMap<HashId, Edge>>();
157 edge_ids
158 .into_iter()
159 .map(|edge_id| edges_by_id[&edge_id].clone())
160 .collect::<Vec<Edge>>()
161 }
162
163 pub fn edges_for_paths(
164 conn: &GraphConnection,
165 path_ids: Vec<HashId>,
166 ) -> HashMap<HashId, Vec<Edge>> {
167 let query_path_ids = path_ids
168 .iter()
169 .map(|path_id| Value::from(*path_id))
170 .collect::<Vec<Value>>();
171 let path_edges = PathEdge::query(
172 conn,
173 "select * from path_edges where path_id in rarray(?1) ORDER BY path_id, index_in_path",
174 params![Rc::new(query_path_ids)],
175 );
176 let edge_ids = path_edges
177 .iter()
178 .map(|path_edge| path_edge.edge_id)
179 .collect::<Vec<_>>();
180 let edges = Edge::query_by_ids(conn, &edge_ids);
181 let edges_by_id = edges
182 .into_iter()
183 .map(|edge| (edge.id, edge))
184 .collect::<HashMap<_, Edge>>();
185 let path_edges_by_path_id = path_edges
186 .iter()
187 .map(|path_edge| (path_edge.path_id, path_edge.edge_id))
188 .into_group_map();
189 path_edges_by_path_id
190 .into_iter()
191 .map(|(path_id, edge_ids)| {
192 (
193 path_id,
194 edge_ids
195 .into_iter()
196 .map(|edge_id| edges_by_id[&edge_id].clone())
197 .collect::<Vec<Edge>>(),
198 )
199 })
200 .collect::<HashMap<HashId, Vec<Edge>>>()
201 }
202}
203
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;
    use crate::gen_models_capnp::path_edge;

    /// Writing a `PathEdge` to a Cap'n Proto message and reading it back
    /// must yield an equal value.
    #[test]
    fn test_path_edge_capnp_serialization() {
        let original = PathEdge {
            id: HashId::pad_str(100),
            path_id: HashId::pad_str(200),
            index_in_path: 5,
            edge_id: HashId::pad_str(300),
        };

        // Serialize into a fresh message root.
        let mut message = TypedBuilder::<path_edge::Owned>::new_default();
        let mut root = message.init_root();
        original.write_capnp(&mut root);

        // Read the same root back and compare with the input.
        let round_tripped = PathEdge::read_capnp(root.into_reader());
        assert_eq!(original, round_tripped);
    }
}