1use std::collections::HashSet;
2
3use gen_core::{HashId, Strand, calculate_hash, traits::Capnp};
4use itertools::Itertools;
5use rusqlite::{Result as SQLResult, Row, params};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 block_group_edge::AugmentedEdgeData,
10 db::GraphConnection,
11 gen_models_capnp::{accession, accession_edge, accession_path},
12 traits::*,
13};
14
15#[derive(Deserialize, Serialize, Debug, Eq, PartialEq)]
16pub struct Accession {
17 pub id: HashId,
18 pub name: String,
19 pub path_id: HashId,
20 pub parent_accession_id: Option<HashId>,
21}
22
23impl<'a> Capnp<'a> for Accession {
24 type Builder = accession::Builder<'a>;
25 type Reader = accession::Reader<'a>;
26
27 fn write_capnp(&self, builder: &mut Self::Builder) {
28 builder.set_id(&self.id.0).unwrap();
29 builder.set_name(&self.name);
30 builder.set_path_id(&self.path_id.0).unwrap();
31 match &self.parent_accession_id {
32 None => {
33 builder.reborrow().get_parent_accession_id().set_none(());
34 }
35 Some(n) => {
36 builder
37 .reborrow()
38 .get_parent_accession_id()
39 .set_some(&n.0)
40 .unwrap();
41 }
42 }
43 }
44
45 fn read_capnp(reader: Self::Reader) -> Self {
46 let id = reader
47 .get_id()
48 .unwrap()
49 .as_slice()
50 .unwrap()
51 .try_into()
52 .unwrap();
53 let name = reader.get_name().unwrap().to_string().unwrap();
54 let path_id = reader
55 .get_path_id()
56 .unwrap()
57 .as_slice()
58 .unwrap()
59 .try_into()
60 .unwrap();
61 let parent_accession_id: Option<HashId> =
62 match reader.get_parent_accession_id().which().unwrap() {
63 accession::parent_accession_id::None(()) => None,
64 accession::parent_accession_id::Some(n) => {
65 Some(n.unwrap().as_slice().unwrap().try_into().unwrap())
66 }
67 };
68
69 Accession {
70 id,
71 name,
72 path_id,
73 parent_accession_id,
74 }
75 }
76}
77
78#[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Hash)]
79pub struct AccessionEdge {
80 pub id: HashId,
81 pub source_node_id: HashId,
82 pub source_coordinate: i64,
83 pub source_strand: Strand,
84 pub target_node_id: HashId,
85 pub target_coordinate: i64,
86 pub target_strand: Strand,
87 pub chromosome_index: i64,
88}
89
90impl<'a> Capnp<'a> for AccessionEdge {
91 type Builder = accession_edge::Builder<'a>;
92 type Reader = accession_edge::Reader<'a>;
93
94 fn write_capnp(&self, builder: &mut Self::Builder) {
95 builder.set_id(&self.id.0).unwrap();
96 builder.set_source_node_id(&self.source_node_id.0).unwrap();
97 builder.set_source_coordinate(self.source_coordinate);
98 builder.set_source_strand(self.source_strand.into());
99 builder.set_target_node_id(&self.target_node_id.0).unwrap();
100 builder.set_target_coordinate(self.target_coordinate);
101 builder.set_target_strand(self.target_strand.into());
102 builder.set_chromosome_index(self.chromosome_index);
103 }
104
105 fn read_capnp(reader: Self::Reader) -> Self {
106 let id = reader
107 .get_id()
108 .unwrap()
109 .as_slice()
110 .unwrap()
111 .try_into()
112 .unwrap();
113 let source_node_id = reader
114 .get_source_node_id()
115 .unwrap()
116 .as_slice()
117 .unwrap()
118 .try_into()
119 .unwrap();
120 let source_coordinate = reader.get_source_coordinate();
121 let source_strand = reader.get_source_strand().unwrap().into();
122 let target_node_id = reader
123 .get_target_node_id()
124 .unwrap()
125 .as_slice()
126 .unwrap()
127 .try_into()
128 .unwrap();
129 let target_coordinate = reader.get_target_coordinate();
130 let target_strand = reader.get_target_strand().unwrap().into();
131 let chromosome_index = reader.get_chromosome_index();
132
133 AccessionEdge {
134 id,
135 source_node_id,
136 source_coordinate,
137 source_strand,
138 target_node_id,
139 target_coordinate,
140 target_strand,
141 chromosome_index,
142 }
143 }
144}
145
146#[derive(Deserialize, Serialize, Debug, PartialEq)]
147pub struct AccessionPath {
148 pub id: HashId,
149 pub accession_id: HashId,
150 pub index_in_path: i64,
151 pub edge_id: HashId,
152}
153
154impl<'a> Capnp<'a> for AccessionPath {
155 type Builder = accession_path::Builder<'a>;
156 type Reader = accession_path::Reader<'a>;
157
158 fn write_capnp(&self, builder: &mut Self::Builder) {
159 builder.set_id(&self.id.0).unwrap();
160 builder.set_accession_id(&self.accession_id.0).unwrap();
161 builder.set_index_in_path(self.index_in_path);
162 builder.set_edge_id(&self.edge_id.0).unwrap();
163 }
164
165 fn read_capnp(reader: Self::Reader) -> Self {
166 let id = reader
167 .get_id()
168 .unwrap()
169 .as_slice()
170 .unwrap()
171 .try_into()
172 .unwrap();
173 let accession_id = reader
174 .get_accession_id()
175 .unwrap()
176 .as_slice()
177 .unwrap()
178 .try_into()
179 .unwrap();
180 let index_in_path = reader.get_index_in_path();
181 let edge_id = reader
182 .get_edge_id()
183 .unwrap()
184 .as_slice()
185 .unwrap()
186 .try_into()
187 .unwrap();
188
189 AccessionPath {
190 id,
191 accession_id,
192 index_in_path,
193 edge_id,
194 }
195 }
196}
197
198#[derive(Clone, Debug, Eq, Hash, PartialEq)]
199pub struct AccessionEdgeData {
200 pub source_node_id: HashId,
201 pub source_coordinate: i64,
202 pub source_strand: Strand,
203 pub target_node_id: HashId,
204 pub target_coordinate: i64,
205 pub target_strand: Strand,
206 pub chromosome_index: i64,
207}
208
209impl AccessionEdgeData {
210 pub fn id_hash(&self) -> HashId {
211 HashId(calculate_hash(&format!(
212 "{}:{}:{}:{}:{}:{}:{}",
213 self.source_node_id,
214 self.source_coordinate,
215 self.source_strand,
216 self.target_node_id,
217 self.target_coordinate,
218 self.target_strand,
219 self.chromosome_index
220 )))
221 }
222}
223
224impl From<&AccessionEdge> for AccessionEdgeData {
225 fn from(item: &AccessionEdge) -> Self {
226 AccessionEdgeData {
227 source_node_id: item.source_node_id,
228 source_coordinate: item.source_coordinate,
229 source_strand: item.source_strand,
230 target_node_id: item.target_node_id,
231 target_coordinate: item.target_coordinate,
232 target_strand: item.target_strand,
233 chromosome_index: item.chromosome_index,
234 }
235 }
236}
237
238impl From<&AugmentedEdgeData> for AccessionEdgeData {
239 fn from(item: &AugmentedEdgeData) -> Self {
240 AccessionEdgeData {
241 source_node_id: item.edge_data.source_node_id,
242 source_coordinate: item.edge_data.source_coordinate,
243 source_strand: item.edge_data.source_strand,
244 target_node_id: item.edge_data.target_node_id,
245 target_coordinate: item.edge_data.target_coordinate,
246 target_strand: item.edge_data.target_strand,
247 chromosome_index: item.chromosome_index,
248 }
249 }
250}
251
252impl Accession {
253 pub fn create(
254 conn: &GraphConnection,
255 name: &str,
256 path_id: &HashId,
257 parent_accession_id: Option<&HashId>,
258 ) -> SQLResult<Accession> {
259 let hash = HashId(calculate_hash(&format!(
260 "{path_id}:{parent_accession_id:?}:{name}"
261 )));
262 let query = "INSERT INTO accessions (id, name, path_id, parent_accession_id) VALUES (?1, ?2, ?3, ?4);";
263 let mut stmt = conn.prepare(query).unwrap();
264
265 stmt.execute((hash, name, path_id, parent_accession_id))?;
266 Ok(Accession {
267 id: hash,
268 name: name.to_string(),
269 path_id: *path_id,
270 parent_accession_id: parent_accession_id.copied(),
271 })
272 }
273
274 pub fn get_or_create(
275 conn: &GraphConnection,
276 name: &str,
277 path_id: &HashId,
278 parent_accession_id: Option<&HashId>,
279 ) -> Accession {
280 match Accession::create(conn, name, path_id, parent_accession_id) {
281 Ok(accession) => accession,
282 Err(rusqlite::Error::SqliteFailure(err, _details)) => {
283 if err.code == rusqlite::ErrorCode::ConstraintViolation {
284 let existing_id: HashId;
285 if let Some(id) = parent_accession_id {
286 existing_id = conn.query_row("select id from accessions where name = ?1 and path_id = ?2 and parent_accession_id = ?3;", params![name.to_string(), path_id, id], |row| row.get(0)).unwrap();
287 } else {
288 existing_id = conn.query_row("select id from accessions where name = ?1 and path_id = ?2 and parent_accession_id is null;", params![name.to_string(), path_id], |row| row.get(0)).unwrap();
289 }
290 Accession {
291 id: existing_id,
292 name: name.to_string(),
293 path_id: *path_id,
294 parent_accession_id: parent_accession_id.copied(),
295 }
296 } else {
297 panic!("something bad happened querying the database")
298 }
299 }
300 Err(_) => {
301 panic!("something bad happened.")
302 }
303 }
304 }
305
306 pub fn get_edges_by_id(conn: &GraphConnection, accession_id: &HashId) -> Vec<AccessionEdge> {
307 let query = "\
308 select ae.* \
309 from accession_edges ae \
310 join accession_paths ap on ap.edge_id = ae.id \
311 where ap.accession_id = ?1 \
312 order by ap.index_in_path;";
313 AccessionEdge::query(conn, query, params![accession_id])
314 }
315}
316
317impl Query for Accession {
318 type Model = Accession;
319
320 const TABLE_NAME: &'static str = "accessions";
321
322 fn process_row(row: &Row) -> Self::Model {
323 Accession {
324 id: row.get(0).unwrap(),
325 name: row.get(1).unwrap(),
326 path_id: row.get(2).unwrap(),
327 parent_accession_id: row.get(3).unwrap(),
328 }
329 }
330}
331
332impl AccessionEdge {
333 pub fn create(conn: &GraphConnection, edge: AccessionEdgeData) -> AccessionEdge {
334 let hash = HashId(calculate_hash(&format!(
335 "{}:{}:{}:{}:{}:{}:{}",
336 edge.source_node_id,
337 edge.source_coordinate,
338 edge.source_strand,
339 edge.target_node_id,
340 edge.target_coordinate,
341 edge.target_strand,
342 edge.chromosome_index
343 )));
344 let insert_statement = "INSERT INTO accession_edges (id, source_node_id, source_coordinate, source_strand, target_node_id, target_coordinate, target_strand, chromosome_index) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8);";
346 let mut stmt = conn.prepare(insert_statement).unwrap();
347 match stmt.execute(params![
348 hash,
349 edge.source_node_id,
350 edge.source_coordinate,
351 edge.source_strand,
352 edge.target_node_id,
353 edge.target_coordinate,
354 edge.target_strand,
355 edge.chromosome_index
356 ]) {
357 Ok(_) => {}
358 Err(rusqlite::Error::SqliteFailure(_err, _details)) => {}
359 Err(_) => {
360 panic!("something bad happened querying the database")
361 }
362 };
363 AccessionEdge {
364 id: hash,
365 source_node_id: edge.source_node_id,
366 source_coordinate: edge.source_coordinate,
367 source_strand: edge.source_strand,
368 target_node_id: edge.target_node_id,
369 target_coordinate: edge.target_coordinate,
370 target_strand: edge.target_strand,
371 chromosome_index: edge.chromosome_index,
372 }
373 }
374
375 pub fn bulk_create(conn: &GraphConnection, edges: &[AccessionEdgeData]) -> Vec<HashId> {
376 let edge_ids = edges.iter().map(|edge| edge.id_hash()).collect::<Vec<_>>();
377 let query = AccessionEdge::query_by_ids(conn, &edge_ids);
378 let existing_edges = query.iter().map(|edge| &edge.id).collect::<HashSet<_>>();
379
380 let mut edges_to_insert = HashSet::new();
381 for (index, edge) in edge_ids.iter().enumerate() {
382 if !existing_edges.contains(edge) {
383 edges_to_insert.insert(&edges[index]);
384 }
385 }
386
387 let batch_size = max_rows_per_batch(conn, 8);
388
389 for chunk in &edges_to_insert.iter().chunks(batch_size) {
390 let mut rows = vec![];
391 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
392 for edge in chunk {
393 params.push(Box::new(edge.id_hash()));
394 params.push(Box::new(edge.source_node_id));
395 params.push(Box::new(edge.source_coordinate));
396 params.push(Box::new(edge.source_strand));
397 params.push(Box::new(edge.target_node_id));
398 params.push(Box::new(edge.target_coordinate));
399 params.push(Box::new(edge.target_strand));
400 params.push(Box::new(edge.chromosome_index));
401 rows.push("(?, ?, ?, ?, ?, ?, ?, ?)");
402 }
403 let sql = format!(
404 "INSERT INTO accession_edges (id, source_node_id, source_coordinate, source_strand, target_node_id, target_coordinate, target_strand, chromosome_index) VALUES {};",
405 rows.join(",")
406 );
407 conn.execute(&sql, rusqlite::params_from_iter(params))
408 .unwrap();
409 }
410 edge_ids
411 }
412
413 pub fn bulk_delete(conn: &GraphConnection, edges: &[AccessionEdgeData]) {
414 let ids = edges.iter().map(|e| e.id_hash()).collect::<Vec<_>>();
415 AccessionEdge::delete_by_ids(conn, &ids);
416 }
417
418 pub fn to_data(edge: AccessionEdge) -> AccessionEdgeData {
419 AccessionEdgeData {
420 source_node_id: edge.source_node_id,
421 source_coordinate: edge.source_coordinate,
422 source_strand: edge.source_strand,
423 target_node_id: edge.target_node_id,
424 target_coordinate: edge.target_coordinate,
425 target_strand: edge.target_strand,
426 chromosome_index: edge.chromosome_index,
427 }
428 }
429}
430
431impl Query for AccessionEdge {
432 type Model = AccessionEdge;
433
434 const TABLE_NAME: &'static str = "accession_edges";
435
436 fn process_row(row: &Row) -> Self::Model {
437 AccessionEdge {
438 id: row.get(0).unwrap(),
439 source_node_id: row.get(1).unwrap(),
440 source_coordinate: row.get(2).unwrap(),
441 source_strand: row.get(3).unwrap(),
442 target_node_id: row.get(4).unwrap(),
443 target_coordinate: row.get(5).unwrap(),
444 target_strand: row.get(6).unwrap(),
445 chromosome_index: row.get(7).unwrap(),
446 }
447 }
448}
449
450impl AccessionPath {
451 pub fn create(conn: &GraphConnection, accession_id: &HashId, edge_ids: &[HashId]) {
452 let batch_size = max_rows_per_batch(conn, 4);
453
454 for (index1, chunk) in edge_ids.chunks(batch_size).enumerate() {
455 let mut rows_to_insert = vec![];
456 let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
457 for (index2, edge_id) in chunk.iter().enumerate() {
458 rows_to_insert.push("(?, ?, ?, ?)".to_string());
459 let index_of = index1 * 100000 + index2;
460 let hash = HashId(calculate_hash(&format!(
461 "{accession_id}:{edge_ids:?}:{index_of}",
462 )));
463 params.push(Box::new(hash));
464 params.push(Box::new(accession_id));
465 params.push(Box::new(edge_id));
466 params.push(Box::new(index_of));
467 }
468
469 let sql = format!(
470 "INSERT OR IGNORE INTO accession_paths (id, accession_id, edge_id, index_in_path) VALUES {};",
471 rows_to_insert.join(", ")
472 );
473
474 let mut stmt = conn.prepare(&sql).unwrap();
475 stmt.execute(rusqlite::params_from_iter(params)).unwrap();
476 }
477 }
478}
479
480impl Query for AccessionPath {
481 type Model = AccessionPath;
482
483 const TABLE_NAME: &'static str = "accession_paths";
484
485 fn process_row(row: &Row) -> AccessionPath {
486 AccessionPath {
487 id: row.get(0).unwrap(),
488 accession_id: row.get(1).unwrap(),
489 index_in_path: row.get(2).unwrap(),
490 edge_id: row.get(3).unwrap(),
491 }
492 }
493}
494
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;
    use crate::test_helpers::{get_connection, setup_block_group};

    /// An accession with a parent survives a Cap'n Proto write/read round trip.
    #[test]
    fn test_accession_capnp_serialization() {
        let id: HashId = "0000000000000000000000000000000000000000000000000000000000000200"
            .try_into()
            .unwrap();
        let path_id: HashId = "0000000000000000000000000000000000000000000000000000000000000150"
            .try_into()
            .unwrap();
        let parent_id: HashId = "0000000000000000000000000000000000000000000000000000000000000100"
            .try_into()
            .unwrap();
        let accession = Accession {
            id,
            name: String::from("test_accession"),
            path_id,
            parent_accession_id: Some(parent_id),
        };

        let mut message = TypedBuilder::<accession::Owned>::new_default();
        let mut root = message.init_root();
        accession.write_capnp(&mut root);
        let deserialized = Accession::read_capnp(root.into_reader());

        assert_eq!(accession, deserialized);
    }

    /// An accession without a parent survives a Cap'n Proto write/read round trip.
    #[test]
    fn test_accession_capnp_serialization_no_parent() {
        let id: HashId = "0000000000000000000000000000000000000000000000000000000000000201"
            .try_into()
            .unwrap();
        let path_id: HashId = "0000000000000000000000000000000000000000000000000000000000000151"
            .try_into()
            .unwrap();
        let accession = Accession {
            id,
            name: String::from("test_accession_2"),
            path_id,
            parent_accession_id: None,
        };

        let mut message = TypedBuilder::<accession::Owned>::new_default();
        let mut root = message.init_root();
        accession.write_capnp(&mut root);
        let deserialized = Accession::read_capnp(root.into_reader());

        assert_eq!(accession, deserialized);
    }

    /// An accession edge survives a Cap'n Proto write/read round trip.
    #[test]
    fn test_accession_edge_capnp_serialization() {
        let id: HashId = "0000000000000000000000000000030000000000000000000000000000000000"
            .try_into()
            .unwrap();
        let accession_edge = AccessionEdge {
            id,
            source_node_id: HashId::convert_str("10"),
            source_coordinate: 100,
            source_strand: Strand::Forward,
            target_node_id: HashId::convert_str("20"),
            target_coordinate: 200,
            target_strand: Strand::Reverse,
            chromosome_index: 1,
        };

        let mut message = TypedBuilder::<accession_edge::Owned>::new_default();
        let mut root = message.init_root();
        accession_edge.write_capnp(&mut root);
        let deserialized = AccessionEdge::read_capnp(root.into_reader());

        assert_eq!(accession_edge, deserialized);
    }

    /// Querying by name returns only the matching accession, not its sibling.
    #[test]
    fn test_accession_create_query() {
        let conn = &get_connection(None).unwrap();
        let (_bg, path) = setup_block_group(conn);
        let accession = Accession::create(conn, "test", &path.id, None).unwrap();
        let _accession_2 = Accession::create(conn, "test2", &path.id, None).unwrap();

        let found = Accession::query(
            conn,
            "select * from accessions where name = ?1",
            params!["test"],
        );
        let expected = vec![Accession {
            id: accession.id,
            name: "test".to_string(),
            path_id: path.id,
            parent_accession_id: None,
        }];

        assert_eq!(found, expected);
    }
}