Skip to main content

gen_models/
sample.rs

1use std::fmt::*;
2
3use gen_core::traits::Capnp;
4use gen_graph::GenGraph;
5use rusqlite::{Result as SQLResult, Row, params};
6use serde::{Deserialize, Serialize};
7
8use crate::{block_group::BlockGroup, db::GraphConnection, gen_models_capnp::sample, traits::*};
9
/// A named sample; the model type behind the `samples` table (see the `Query` impl).
#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct Sample {
    /// Sample name. This is the column declared as `Query::PRIMARY_KEY` (`"name"`).
    pub name: String,
}
14
15impl<'a> Capnp<'a> for Sample {
16    type Builder = sample::Builder<'a>;
17    type Reader = sample::Reader<'a>;
18
19    fn write_capnp(&self, builder: &mut Self::Builder) {
20        builder.set_name(&self.name);
21    }
22
23    fn read_capnp(reader: Self::Reader) -> Self {
24        let name = reader.get_name().unwrap().to_string().unwrap();
25        Sample { name }
26    }
27}
28
29impl Query for Sample {
30    type Model = Sample;
31
32    const PRIMARY_KEY: &'static str = "name";
33    const TABLE_NAME: &'static str = "samples";
34
35    fn process_row(row: &Row) -> Self::Model {
36        Sample {
37            name: row.get(0).unwrap(),
38        }
39    }
40}
41
42impl Sample {
43    pub fn create(conn: &GraphConnection, name: &str) -> SQLResult<Sample> {
44        let mut stmt = conn
45            .prepare("INSERT INTO samples (name) VALUES (?1) returning (name);")
46            .unwrap();
47        stmt.query_row((name,), |row| Ok(Sample { name: row.get(0)? }))
48    }
49
50    pub fn get_or_create(conn: &GraphConnection, name: &str) -> Sample {
51        match Sample::create(conn, name) {
52            Ok(sample) => sample,
53            Err(rusqlite::Error::SqliteFailure(err, _details)) => {
54                if err.code == rusqlite::ErrorCode::ConstraintViolation {
55                    Sample {
56                        name: name.to_string(),
57                    }
58                } else {
59                    panic!("something bad happened querying the database")
60                }
61            }
62            Err(_) => {
63                panic!("something bad happened.")
64            }
65        }
66    }
67
68    pub fn delete_by_name(conn: &GraphConnection, name: &str) {
69        let mut stmt = conn.prepare("delete from samples where name = ?1").unwrap();
70        stmt.execute([name]).unwrap();
71    }
72
73    pub fn get_graph<'a>(
74        conn: &GraphConnection,
75        collection: &str,
76        name: impl Into<Option<&'a str>>,
77    ) -> GenGraph {
78        let name = name.into();
79        let block_groups = Sample::get_block_groups(conn, collection, name);
80        let mut sample_graph = GenGraph::new();
81        for bg in block_groups {
82            let bg_graph = BlockGroup::get_graph(conn, &bg.id);
83            // Add nodes and edges from block group graph to sample graph
84            for node in bg_graph.nodes() {
85                sample_graph.add_node(node);
86            }
87            for (source, dest, edges) in bg_graph.all_edges() {
88                if let Some(existing_edges) = sample_graph.edge_weight_mut(source, dest) {
89                    existing_edges.extend(edges.clone());
90                } else {
91                    sample_graph.add_edge(source, dest, edges.clone());
92                }
93            }
94        }
95        sample_graph
96    }
97
98    pub fn get_or_create_child(
99        conn: &GraphConnection,
100        collection_name: &str,
101        sample_name: &str,
102        parent_sample: Option<&str>,
103    ) -> Sample {
104        if let Ok(new_sample) = Sample::create(conn, sample_name) {
105            let bgs = if let Some(parent) = parent_sample {
106                BlockGroup::query(
107                    conn,
108                    "select * from block_groups where collection_name = ?1 AND sample_name = ?2",
109                    params!(collection_name, parent),
110                )
111            } else {
112                BlockGroup::query(
113                    conn,
114                    "select * from block_groups where collection_name = ?1 AND sample_name is null;",
115                    params!(collection_name),
116                )
117            };
118            for bg in bgs.iter() {
119                BlockGroup::get_or_create_sample_block_group(
120                    conn,
121                    collection_name,
122                    &new_sample.name,
123                    &bg.name,
124                    parent_sample,
125                )
126                .expect("failed to get or create blockgroup clone.");
127            }
128            new_sample
129        } else {
130            Sample {
131                name: sample_name.to_string(),
132            }
133        }
134    }
135
136    pub fn get_block_groups(
137        conn: &GraphConnection,
138        collection_name: &str,
139        sample_name: Option<&str>,
140    ) -> Vec<BlockGroup> {
141        if let Some(sample) = sample_name {
142            BlockGroup::query(
143                conn,
144                "select * from block_groups where collection_name = ?1 AND sample_name = ?2;",
145                params![collection_name, sample],
146            )
147        } else {
148            BlockGroup::query(
149                conn,
150                "select * from block_groups where collection_name = ?1 AND sample_name IS NULL;",
151                params![collection_name],
152            )
153        }
154    }
155
156    pub fn get_all_names(conn: &GraphConnection) -> Vec<String> {
157        let samples = Sample::query(conn, "select * from samples;", rusqlite::params!());
158        samples.iter().map(|s| s.name.clone()).collect()
159    }
160
161    pub fn get_by_name(conn: &GraphConnection, name: &str) -> SQLResult<Sample> {
162        Sample::get(
163            conn,
164            "select * from samples where name = ?1;",
165            rusqlite::params!(name),
166        )
167    }
168}
169
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;
    use crate::test_helpers::get_connection;

    /// A `Sample` written to a Cap'n Proto builder reads back equal to the original.
    #[test]
    fn test_capnp_serialization() {
        let original = Sample {
            name: String::from("test_sample"),
        };

        let mut message = TypedBuilder::<sample::Owned>::new_default();
        let mut builder = message.init_root();
        original.write_capnp(&mut builder);

        let round_tripped = Sample::read_capnp(builder.into_reader());
        assert_eq!(original, round_tripped);
    }

    /// Deleting one sample removes only that sample.
    #[test]
    fn test_delete_by_name() {
        let connection = get_connection(None).unwrap();
        let conn = &connection;

        Sample::create(conn, "sample1").unwrap();
        Sample::create(conn, "sample2").unwrap();

        // Both samples are retrievable before the delete.
        assert!(Sample::get_by_name(conn, "sample1").is_ok());
        assert!(Sample::get_by_name(conn, "sample2").is_ok());

        Sample::delete_by_name(conn, "sample1");

        // Only the deleted sample is gone afterwards.
        assert!(Sample::get_by_name(conn, "sample1").is_err());
        assert!(Sample::get_by_name(conn, "sample2").is_ok());
    }
}