Skip to main content

gen_models/
accession.rs

1use std::collections::HashSet;
2
3use gen_core::{HashId, Strand, calculate_hash, traits::Capnp};
4use itertools::Itertools;
5use rusqlite::{Result as SQLResult, Row, params};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    block_group_edge::AugmentedEdgeData,
10    db::GraphConnection,
11    gen_models_capnp::{accession, accession_edge, accession_path},
12    traits::*,
13};
14
15#[derive(Deserialize, Serialize, Debug, Eq, PartialEq)]
16pub struct Accession {
17    pub id: HashId,
18    pub name: String,
19    pub path_id: HashId,
20    pub parent_accession_id: Option<HashId>,
21}
22
23impl<'a> Capnp<'a> for Accession {
24    type Builder = accession::Builder<'a>;
25    type Reader = accession::Reader<'a>;
26
27    fn write_capnp(&self, builder: &mut Self::Builder) {
28        builder.set_id(&self.id.0).unwrap();
29        builder.set_name(&self.name);
30        builder.set_path_id(&self.path_id.0).unwrap();
31        match &self.parent_accession_id {
32            None => {
33                builder.reborrow().get_parent_accession_id().set_none(());
34            }
35            Some(n) => {
36                builder
37                    .reborrow()
38                    .get_parent_accession_id()
39                    .set_some(&n.0)
40                    .unwrap();
41            }
42        }
43    }
44
45    fn read_capnp(reader: Self::Reader) -> Self {
46        let id = reader
47            .get_id()
48            .unwrap()
49            .as_slice()
50            .unwrap()
51            .try_into()
52            .unwrap();
53        let name = reader.get_name().unwrap().to_string().unwrap();
54        let path_id = reader
55            .get_path_id()
56            .unwrap()
57            .as_slice()
58            .unwrap()
59            .try_into()
60            .unwrap();
61        let parent_accession_id: Option<HashId> =
62            match reader.get_parent_accession_id().which().unwrap() {
63                accession::parent_accession_id::None(()) => None,
64                accession::parent_accession_id::Some(n) => {
65                    Some(n.unwrap().as_slice().unwrap().try_into().unwrap())
66                }
67            };
68
69        Accession {
70            id,
71            name,
72            path_id,
73            parent_accession_id,
74        }
75    }
76}
77
78#[derive(Deserialize, Serialize, Debug, PartialEq, Eq, Hash)]
79pub struct AccessionEdge {
80    pub id: HashId,
81    pub source_node_id: HashId,
82    pub source_coordinate: i64,
83    pub source_strand: Strand,
84    pub target_node_id: HashId,
85    pub target_coordinate: i64,
86    pub target_strand: Strand,
87    pub chromosome_index: i64,
88}
89
90impl<'a> Capnp<'a> for AccessionEdge {
91    type Builder = accession_edge::Builder<'a>;
92    type Reader = accession_edge::Reader<'a>;
93
94    fn write_capnp(&self, builder: &mut Self::Builder) {
95        builder.set_id(&self.id.0).unwrap();
96        builder.set_source_node_id(&self.source_node_id.0).unwrap();
97        builder.set_source_coordinate(self.source_coordinate);
98        builder.set_source_strand(self.source_strand.into());
99        builder.set_target_node_id(&self.target_node_id.0).unwrap();
100        builder.set_target_coordinate(self.target_coordinate);
101        builder.set_target_strand(self.target_strand.into());
102        builder.set_chromosome_index(self.chromosome_index);
103    }
104
105    fn read_capnp(reader: Self::Reader) -> Self {
106        let id = reader
107            .get_id()
108            .unwrap()
109            .as_slice()
110            .unwrap()
111            .try_into()
112            .unwrap();
113        let source_node_id = reader
114            .get_source_node_id()
115            .unwrap()
116            .as_slice()
117            .unwrap()
118            .try_into()
119            .unwrap();
120        let source_coordinate = reader.get_source_coordinate();
121        let source_strand = reader.get_source_strand().unwrap().into();
122        let target_node_id = reader
123            .get_target_node_id()
124            .unwrap()
125            .as_slice()
126            .unwrap()
127            .try_into()
128            .unwrap();
129        let target_coordinate = reader.get_target_coordinate();
130        let target_strand = reader.get_target_strand().unwrap().into();
131        let chromosome_index = reader.get_chromosome_index();
132
133        AccessionEdge {
134            id,
135            source_node_id,
136            source_coordinate,
137            source_strand,
138            target_node_id,
139            target_coordinate,
140            target_strand,
141            chromosome_index,
142        }
143    }
144}
145
146#[derive(Deserialize, Serialize, Debug, PartialEq)]
147pub struct AccessionPath {
148    pub id: HashId,
149    pub accession_id: HashId,
150    pub index_in_path: i64,
151    pub edge_id: HashId,
152}
153
154impl<'a> Capnp<'a> for AccessionPath {
155    type Builder = accession_path::Builder<'a>;
156    type Reader = accession_path::Reader<'a>;
157
158    fn write_capnp(&self, builder: &mut Self::Builder) {
159        builder.set_id(&self.id.0).unwrap();
160        builder.set_accession_id(&self.accession_id.0).unwrap();
161        builder.set_index_in_path(self.index_in_path);
162        builder.set_edge_id(&self.edge_id.0).unwrap();
163    }
164
165    fn read_capnp(reader: Self::Reader) -> Self {
166        let id = reader
167            .get_id()
168            .unwrap()
169            .as_slice()
170            .unwrap()
171            .try_into()
172            .unwrap();
173        let accession_id = reader
174            .get_accession_id()
175            .unwrap()
176            .as_slice()
177            .unwrap()
178            .try_into()
179            .unwrap();
180        let index_in_path = reader.get_index_in_path();
181        let edge_id = reader
182            .get_edge_id()
183            .unwrap()
184            .as_slice()
185            .unwrap()
186            .try_into()
187            .unwrap();
188
189        AccessionPath {
190            id,
191            accession_id,
192            index_in_path,
193            edge_id,
194        }
195    }
196}
197
198#[derive(Clone, Debug, Eq, Hash, PartialEq)]
199pub struct AccessionEdgeData {
200    pub source_node_id: HashId,
201    pub source_coordinate: i64,
202    pub source_strand: Strand,
203    pub target_node_id: HashId,
204    pub target_coordinate: i64,
205    pub target_strand: Strand,
206    pub chromosome_index: i64,
207}
208
209impl AccessionEdgeData {
210    pub fn id_hash(&self) -> HashId {
211        HashId(calculate_hash(&format!(
212            "{}:{}:{}:{}:{}:{}:{}",
213            self.source_node_id,
214            self.source_coordinate,
215            self.source_strand,
216            self.target_node_id,
217            self.target_coordinate,
218            self.target_strand,
219            self.chromosome_index
220        )))
221    }
222}
223
224impl From<&AccessionEdge> for AccessionEdgeData {
225    fn from(item: &AccessionEdge) -> Self {
226        AccessionEdgeData {
227            source_node_id: item.source_node_id,
228            source_coordinate: item.source_coordinate,
229            source_strand: item.source_strand,
230            target_node_id: item.target_node_id,
231            target_coordinate: item.target_coordinate,
232            target_strand: item.target_strand,
233            chromosome_index: item.chromosome_index,
234        }
235    }
236}
237
238impl From<&AugmentedEdgeData> for AccessionEdgeData {
239    fn from(item: &AugmentedEdgeData) -> Self {
240        AccessionEdgeData {
241            source_node_id: item.edge_data.source_node_id,
242            source_coordinate: item.edge_data.source_coordinate,
243            source_strand: item.edge_data.source_strand,
244            target_node_id: item.edge_data.target_node_id,
245            target_coordinate: item.edge_data.target_coordinate,
246            target_strand: item.edge_data.target_strand,
247            chromosome_index: item.chromosome_index,
248        }
249    }
250}
251
252impl Accession {
253    pub fn create(
254        conn: &GraphConnection,
255        name: &str,
256        path_id: &HashId,
257        parent_accession_id: Option<&HashId>,
258    ) -> SQLResult<Accession> {
259        let hash = HashId(calculate_hash(&format!(
260            "{path_id}:{parent_accession_id:?}:{name}"
261        )));
262        let query = "INSERT INTO accessions (id, name, path_id, parent_accession_id) VALUES (?1, ?2, ?3, ?4);";
263        let mut stmt = conn.prepare(query).unwrap();
264
265        stmt.execute((hash, name, path_id, parent_accession_id))?;
266        Ok(Accession {
267            id: hash,
268            name: name.to_string(),
269            path_id: *path_id,
270            parent_accession_id: parent_accession_id.copied(),
271        })
272    }
273
274    pub fn get_or_create(
275        conn: &GraphConnection,
276        name: &str,
277        path_id: &HashId,
278        parent_accession_id: Option<&HashId>,
279    ) -> Accession {
280        match Accession::create(conn, name, path_id, parent_accession_id) {
281            Ok(accession) => accession,
282            Err(rusqlite::Error::SqliteFailure(err, _details)) => {
283                if err.code == rusqlite::ErrorCode::ConstraintViolation {
284                    let existing_id: HashId;
285                    if let Some(id) = parent_accession_id {
286                        existing_id = conn.query_row("select id from accessions where name = ?1 and path_id = ?2 and parent_accession_id = ?3;", params![name.to_string(), path_id, id], |row| row.get(0)).unwrap();
287                    } else {
288                        existing_id = conn.query_row("select id from accessions where name = ?1 and path_id = ?2 and parent_accession_id is null;", params![name.to_string(), path_id], |row| row.get(0)).unwrap();
289                    }
290                    Accession {
291                        id: existing_id,
292                        name: name.to_string(),
293                        path_id: *path_id,
294                        parent_accession_id: parent_accession_id.copied(),
295                    }
296                } else {
297                    panic!("something bad happened querying the database")
298                }
299            }
300            Err(_) => {
301                panic!("something bad happened.")
302            }
303        }
304    }
305
306    pub fn get_edges_by_id(conn: &GraphConnection, accession_id: &HashId) -> Vec<AccessionEdge> {
307        let query = "\
308            select ae.* \
309            from accession_edges ae \
310            join accession_paths ap on ap.edge_id = ae.id \
311            where ap.accession_id = ?1 \
312            order by ap.index_in_path;";
313        AccessionEdge::query(conn, query, params![accession_id])
314    }
315}
316
317impl Query for Accession {
318    type Model = Accession;
319
320    const TABLE_NAME: &'static str = "accessions";
321
322    fn process_row(row: &Row) -> Self::Model {
323        Accession {
324            id: row.get(0).unwrap(),
325            name: row.get(1).unwrap(),
326            path_id: row.get(2).unwrap(),
327            parent_accession_id: row.get(3).unwrap(),
328        }
329    }
330}
331
332impl AccessionEdge {
333    pub fn create(conn: &GraphConnection, edge: AccessionEdgeData) -> AccessionEdge {
334        let hash = HashId(calculate_hash(&format!(
335            "{}:{}:{}:{}:{}:{}:{}",
336            edge.source_node_id,
337            edge.source_coordinate,
338            edge.source_strand,
339            edge.target_node_id,
340            edge.target_coordinate,
341            edge.target_strand,
342            edge.chromosome_index
343        )));
344        // TODO: handle get-or-create
345        let insert_statement = "INSERT INTO accession_edges (id, source_node_id, source_coordinate, source_strand, target_node_id, target_coordinate, target_strand, chromosome_index) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8);";
346        let mut stmt = conn.prepare(insert_statement).unwrap();
347        match stmt.execute(params![
348            hash,
349            edge.source_node_id,
350            edge.source_coordinate,
351            edge.source_strand,
352            edge.target_node_id,
353            edge.target_coordinate,
354            edge.target_strand,
355            edge.chromosome_index
356        ]) {
357            Ok(_) => {}
358            Err(rusqlite::Error::SqliteFailure(_err, _details)) => {}
359            Err(_) => {
360                panic!("something bad happened querying the database")
361            }
362        };
363        AccessionEdge {
364            id: hash,
365            source_node_id: edge.source_node_id,
366            source_coordinate: edge.source_coordinate,
367            source_strand: edge.source_strand,
368            target_node_id: edge.target_node_id,
369            target_coordinate: edge.target_coordinate,
370            target_strand: edge.target_strand,
371            chromosome_index: edge.chromosome_index,
372        }
373    }
374
375    pub fn bulk_create(conn: &GraphConnection, edges: &[AccessionEdgeData]) -> Vec<HashId> {
376        let edge_ids = edges.iter().map(|edge| edge.id_hash()).collect::<Vec<_>>();
377        let query = AccessionEdge::query_by_ids(conn, &edge_ids);
378        let existing_edges = query.iter().map(|edge| &edge.id).collect::<HashSet<_>>();
379
380        let mut edges_to_insert = HashSet::new();
381        for (index, edge) in edge_ids.iter().enumerate() {
382            if !existing_edges.contains(edge) {
383                edges_to_insert.insert(&edges[index]);
384            }
385        }
386
387        let batch_size = max_rows_per_batch(conn, 8);
388
389        for chunk in &edges_to_insert.iter().chunks(batch_size) {
390            let mut rows = vec![];
391            let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
392            for edge in chunk {
393                params.push(Box::new(edge.id_hash()));
394                params.push(Box::new(edge.source_node_id));
395                params.push(Box::new(edge.source_coordinate));
396                params.push(Box::new(edge.source_strand));
397                params.push(Box::new(edge.target_node_id));
398                params.push(Box::new(edge.target_coordinate));
399                params.push(Box::new(edge.target_strand));
400                params.push(Box::new(edge.chromosome_index));
401                rows.push("(?, ?, ?, ?, ?, ?, ?, ?)");
402            }
403            let sql = format!(
404                "INSERT INTO accession_edges (id, source_node_id, source_coordinate, source_strand, target_node_id, target_coordinate, target_strand, chromosome_index) VALUES {};",
405                rows.join(",")
406            );
407            conn.execute(&sql, rusqlite::params_from_iter(params))
408                .unwrap();
409        }
410        edge_ids
411    }
412
413    pub fn bulk_delete(conn: &GraphConnection, edges: &[AccessionEdgeData]) {
414        let ids = edges.iter().map(|e| e.id_hash()).collect::<Vec<_>>();
415        AccessionEdge::delete_by_ids(conn, &ids);
416    }
417
418    pub fn to_data(edge: AccessionEdge) -> AccessionEdgeData {
419        AccessionEdgeData {
420            source_node_id: edge.source_node_id,
421            source_coordinate: edge.source_coordinate,
422            source_strand: edge.source_strand,
423            target_node_id: edge.target_node_id,
424            target_coordinate: edge.target_coordinate,
425            target_strand: edge.target_strand,
426            chromosome_index: edge.chromosome_index,
427        }
428    }
429}
430
431impl Query for AccessionEdge {
432    type Model = AccessionEdge;
433
434    const TABLE_NAME: &'static str = "accession_edges";
435
436    fn process_row(row: &Row) -> Self::Model {
437        AccessionEdge {
438            id: row.get(0).unwrap(),
439            source_node_id: row.get(1).unwrap(),
440            source_coordinate: row.get(2).unwrap(),
441            source_strand: row.get(3).unwrap(),
442            target_node_id: row.get(4).unwrap(),
443            target_coordinate: row.get(5).unwrap(),
444            target_strand: row.get(6).unwrap(),
445            chromosome_index: row.get(7).unwrap(),
446        }
447    }
448}
449
450impl AccessionPath {
451    pub fn create(conn: &GraphConnection, accession_id: &HashId, edge_ids: &[HashId]) {
452        let batch_size = max_rows_per_batch(conn, 4);
453
454        for (index1, chunk) in edge_ids.chunks(batch_size).enumerate() {
455            let mut rows_to_insert = vec![];
456            let mut params: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
457            for (index2, edge_id) in chunk.iter().enumerate() {
458                rows_to_insert.push("(?, ?, ?, ?)".to_string());
459                let index_of = index1 * 100000 + index2;
460                let hash = HashId(calculate_hash(&format!(
461                    "{accession_id}:{edge_ids:?}:{index_of}",
462                )));
463                params.push(Box::new(hash));
464                params.push(Box::new(accession_id));
465                params.push(Box::new(edge_id));
466                params.push(Box::new(index_of));
467            }
468
469            let sql = format!(
470                "INSERT OR IGNORE INTO accession_paths (id, accession_id, edge_id, index_in_path) VALUES {};",
471                rows_to_insert.join(", ")
472            );
473
474            let mut stmt = conn.prepare(&sql).unwrap();
475            stmt.execute(rusqlite::params_from_iter(params)).unwrap();
476        }
477    }
478}
479
480impl Query for AccessionPath {
481    type Model = AccessionPath;
482
483    const TABLE_NAME: &'static str = "accession_paths";
484
485    fn process_row(row: &Row) -> AccessionPath {
486        AccessionPath {
487            id: row.get(0).unwrap(),
488            accession_id: row.get(1).unwrap(),
489            index_in_path: row.get(2).unwrap(),
490            edge_id: row.get(3).unwrap(),
491        }
492    }
493}
494
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;
    use crate::test_helpers::{get_connection, setup_block_group};

    /// Converts a string literal into a `HashId`, panicking if it is rejected.
    fn hash_id(literal: &'static str) -> HashId {
        literal.try_into().unwrap()
    }

    /// An accession with a parent survives a capnp write/read round trip.
    #[test]
    fn test_accession_capnp_serialization() {
        let original = Accession {
            id: hash_id("0000000000000000000000000000000000000000000000000000000000000200"),
            name: "test_accession".to_string(),
            path_id: hash_id("0000000000000000000000000000000000000000000000000000000000000150"),
            parent_accession_id: Some(hash_id(
                "0000000000000000000000000000000000000000000000000000000000000100",
            )),
        };

        let mut message = TypedBuilder::<accession::Owned>::new_default();
        let mut root = message.init_root();
        original.write_capnp(&mut root);

        let round_tripped = Accession::read_capnp(root.into_reader());
        assert_eq!(original, round_tripped);
    }

    /// An accession without a parent survives a capnp write/read round trip.
    #[test]
    fn test_accession_capnp_serialization_no_parent() {
        let original = Accession {
            id: hash_id("0000000000000000000000000000000000000000000000000000000000000201"),
            name: "test_accession_2".to_string(),
            path_id: hash_id("0000000000000000000000000000000000000000000000000000000000000151"),
            parent_accession_id: None,
        };

        let mut message = TypedBuilder::<accession::Owned>::new_default();
        let mut root = message.init_root();
        original.write_capnp(&mut root);

        let round_tripped = Accession::read_capnp(root.into_reader());
        assert_eq!(original, round_tripped);
    }

    /// An accession edge survives a capnp write/read round trip.
    #[test]
    fn test_accession_edge_capnp_serialization() {
        let original = AccessionEdge {
            id: hash_id("0000000000000000000000000000030000000000000000000000000000000000"),
            source_node_id: HashId::convert_str("10"),
            source_coordinate: 100,
            source_strand: Strand::Forward,
            target_node_id: HashId::convert_str("20"),
            target_coordinate: 200,
            target_strand: Strand::Reverse,
            chromosome_index: 1,
        };

        let mut message = TypedBuilder::<accession_edge::Owned>::new_default();
        let mut root = message.init_root();
        original.write_capnp(&mut root);

        let round_tripped = AccessionEdge::read_capnp(root.into_reader());
        assert_eq!(original, round_tripped);
    }

    /// Querying by name returns only the matching accession out of the two
    /// that were created.
    #[test]
    fn test_accession_create_query() {
        let conn = &get_connection(None).unwrap();
        let (_bg, path) = setup_block_group(conn);
        let accession = Accession::create(conn, "test", &path.id, None).unwrap();
        let _accession_2 = Accession::create(conn, "test2", &path.id, None).unwrap();

        let found = Accession::query(
            conn,
            "select * from accessions where name = ?1",
            params!["test"],
        );
        let expected = vec![Accession {
            id: accession.id,
            name: "test".to_string(),
            path_id: path.id,
            parent_accession_id: None,
        }];
        assert_eq!(found, expected);
    }
}