1use std::{collections::HashMap, rc::Rc};
2
3use gen_core::{HashId, traits::Capnp};
4use rusqlite::{Result as SQLResult, Row, params, types::Value};
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7
8use crate::{db::OperationsConnection, gen_models_capnp::gen_database, traits::*};
9
10#[derive(Debug, Error)]
11pub enum GenDatabaseError {
12 #[error("Database error: {0}")]
13 DatabaseError(#[from] rusqlite::Error),
14}
15
16#[derive(Clone, Debug, Eq, Hash, Serialize, Deserialize, PartialEq)]
17pub struct GenDatabase {
18 pub db_uuid: String,
19 pub name: String,
20 pub path: String,
21}
22
23impl<'a> Capnp<'a> for GenDatabase {
24 type Builder = gen_database::Builder<'a>;
25 type Reader = gen_database::Reader<'a>;
26
27 fn write_capnp(&self, builder: &mut Self::Builder) {
28 builder.set_db_uuid(&self.db_uuid);
29 builder.set_name(&self.name);
30 builder.set_path(&self.path);
31 }
32
33 fn read_capnp(reader: Self::Reader) -> Self {
34 let db_uuid = reader.get_db_uuid().unwrap().to_string().unwrap();
35 let name = reader.get_name().unwrap().to_string().unwrap();
36 let path = reader.get_path().unwrap().to_string().unwrap();
37
38 GenDatabase {
39 db_uuid,
40 name,
41 path,
42 }
43 }
44}
45
46impl Query for GenDatabase {
47 type Model = GenDatabase;
48
49 const PRIMARY_KEY: &'static str = "db_uuid";
50 const TABLE_NAME: &'static str = "gen_databases";
51
52 fn process_row(row: &Row) -> Self::Model {
53 GenDatabase {
54 db_uuid: row.get(0).unwrap(),
55 name: row.get(1).unwrap(),
56 path: row.get(2).unwrap(),
57 }
58 }
59}
60
61impl GenDatabase {
62 pub fn create(
63 conn: &OperationsConnection,
64 db_uuid: &str,
65 name: &str,
66 path: &str,
67 ) -> SQLResult<GenDatabase> {
68 let query = "INSERT INTO gen_databases (db_uuid, name, path) VALUES (?1, ?2, ?3);";
69 let mut stmt = conn.prepare(query)?;
70 stmt.execute(params![db_uuid, name, path])?;
71 Ok(GenDatabase {
72 db_uuid: db_uuid.to_string(),
73 name: name.to_string(),
74 path: path.to_string(),
75 })
76 }
77
78 pub fn delete_by_uuid(conn: &OperationsConnection, db_uuid: &str) -> SQLResult<GenDatabase> {
79 GenDatabase::get(
80 conn,
81 "DELETE FROM gen_databases WHERE db_uuid = ?1",
82 params![db_uuid],
83 )
84 }
85
86 pub fn get_by_uuid(conn: &OperationsConnection, db_uuid: &str) -> SQLResult<GenDatabase> {
87 GenDatabase::get(
88 conn,
89 "SELECT * FROM gen_databases WHERE db_uuid = ?1",
90 params![db_uuid],
91 )
92 }
93
94 pub fn get_by_path(conn: &OperationsConnection, path: &str) -> SQLResult<GenDatabase> {
95 GenDatabase::get(
96 conn,
97 "SELECT * FROM gen_databases WHERE path = ?1",
98 params![path],
99 )
100 }
101
102 pub fn get_or_create(
103 conn: &OperationsConnection,
104 db_uuid: &str,
105 name: &str,
106 path: &str,
107 ) -> SQLResult<GenDatabase> {
108 match GenDatabase::create(conn, db_uuid, name, path) {
109 Ok(new) => Ok(new),
110 Err(rusqlite::Error::SqliteFailure(err, _details)) => {
111 if err.code == rusqlite::ErrorCode::ConstraintViolation {
112 match GenDatabase::get(
113 conn,
114 "select * from gen_databases where db_uuid = ?1 AND name = ?2 AND path = ?3",
115 params![db_uuid, name, path],
116 ) {
117 Ok(result) => Ok(result),
118 Err(e) => Err(e),
119 }
120 } else {
121 panic!("something bad happened querying the database")
122 }
123 }
124 Err(_) => {
125 panic!("something bad happened.")
126 }
127 }
128 }
129
130 pub fn query_by_operations(
131 conn: &OperationsConnection,
132 operations: &[HashId],
133 ) -> Result<HashMap<HashId, Vec<GenDatabase>>, GenDatabaseError> {
134 let query = "select gd.*, od.operation_hash from gen_databases gd left join operation_databases od on (gd.db_uuid = od.database_uuid) where od.operation_hash in rarray(?1)";
135 let mut stmt = conn.prepare(query).unwrap();
136 let rows = stmt
137 .query_map(
138 params![Rc::new(
139 operations
140 .iter()
141 .map(|h| Value::from(*h))
142 .collect::<Vec<Value>>()
143 )],
144 |row| Ok((GenDatabase::process_row(row), row.get::<_, HashId>(3)?)),
145 )
146 .unwrap();
147 rows.into_iter()
148 .try_fold(HashMap::new(), |mut acc: HashMap<_, Vec<_>>, row| {
149 let (item, hash) = row?;
150 acc.entry(hash).or_default().push(item);
151 Ok(acc)
152 })
153 }
154}
155
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;
    use crate::test_helpers::get_operation_connection;

    /// A record written to Cap'n Proto and read back must compare equal to the original.
    #[test]
    fn test_gen_database_capnp_serialization() {
        let original = GenDatabase {
            db_uuid: "test-uuid-123".to_string(),
            name: "test_database".to_string(),
            path: "/path/to/test.db".to_string(),
        };

        let mut message = TypedBuilder::<gen_database::Owned>::new_default();
        let mut root_builder = message.init_root();
        original.write_capnp(&mut root_builder);

        let round_tripped = GenDatabase::read_capnp(root_builder.into_reader());
        assert_eq!(original, round_tripped);
    }

    /// `create` echoes back exactly the values it was given.
    #[test]
    fn test_create_gen_database() {
        let op_conn = get_operation_connection(None).unwrap();

        let inserted =
            GenDatabase::create(&op_conn, "test-uuid-123", "test_db", "path/to/db.db").unwrap();

        assert_eq!(inserted.db_uuid, "test-uuid-123");
        assert_eq!(inserted.name, "test_db");
        assert_eq!(inserted.path, "path/to/db.db");
    }

    /// A freshly inserted row can be fetched again by its UUID.
    #[test]
    fn test_get_by_uuid() {
        let op_conn = get_operation_connection(None).unwrap();
        let inserted =
            GenDatabase::create(&op_conn, "test-uuid-456", "test_db2", "path/to/db2.db").unwrap();

        let fetched = GenDatabase::get_by_uuid(&op_conn, &inserted.db_uuid).unwrap();
        assert_eq!(fetched, inserted);
    }

    /// A freshly inserted row can be fetched again by its path.
    #[test]
    fn test_get_by_path() {
        let op_conn = get_operation_connection(None).unwrap();
        let inserted =
            GenDatabase::create(&op_conn, "test-uuid-789", "test_db3", "path/to/db3.db").unwrap();

        let fetched = GenDatabase::get_by_path(&op_conn, "path/to/db3.db").unwrap();
        assert_eq!(fetched, inserted);
    }

    /// `get_or_create` returns the stored row when an identical one already exists.
    #[test]
    fn test_get_or_create_existing() {
        let op_conn = get_operation_connection(None).unwrap();
        let (uuid, name, path) = ("test-uuid-existing", "existing_db", "path/to/existing.db");

        let inserted = GenDatabase::create(&op_conn, uuid, name, path).unwrap();
        let fetched = GenDatabase::get_or_create(&op_conn, uuid, name, path).unwrap();

        assert_eq!(fetched, inserted);
    }

    /// Reusing a UUID with a different name and path is reported as an error.
    #[test]
    fn test_get_or_create_conflict() {
        let op_conn = get_operation_connection(None).unwrap();
        GenDatabase::create(
            &op_conn,
            "test-uuid-existing",
            "existing_db",
            "path/to/existing.db",
        )
        .unwrap();

        let outcome = GenDatabase::get_or_create(
            &op_conn,
            "test-uuid-existing",
            "something_else",
            "path/to/something_else.db",
        );
        assert!(outcome.is_err());
    }

    /// With no existing row, `get_or_create` inserts and returns the new record.
    #[test]
    fn test_get_or_create_new() {
        let op_conn = get_operation_connection(None).unwrap();

        let created =
            GenDatabase::get_or_create(&op_conn, "test-uuid-new", "new_db", "path/to/new.db")
                .unwrap();

        assert_eq!(created.db_uuid, "test-uuid-new");
        assert_eq!(created.name, "new_db");
        assert_eq!(created.path, "path/to/new.db");
    }

    /// Looking up an unknown UUID fails.
    #[test]
    fn test_get_by_uuid_not_found() {
        let op_conn = get_operation_connection(None).unwrap();
        assert!(GenDatabase::get_by_uuid(&op_conn, "non-existing-uuid").is_err());
    }

    /// Looking up an unknown path fails.
    #[test]
    fn test_get_by_path_not_found() {
        let op_conn = get_operation_connection(None).unwrap();
        assert!(GenDatabase::get_by_path(&op_conn, "non/existing/path.db").is_err());
    }
}