Skip to main content

gen_models/
sample.rs

1use std::{collections::HashSet, rc::Rc};
2
3use gen_core::traits::Capnp;
4use gen_graph::GenGraph;
5use rusqlite::{Result as SQLResult, Row, params, types::Value as SQLValue};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9    block_group::BlockGroup, db::GraphConnection, errors::SampleError, gen_models_capnp::sample,
10    sample_lineage::SampleLineage, traits::Query,
11};
12
13#[derive(Debug, Deserialize, Serialize, PartialEq)]
14pub struct Sample {
15    pub name: String,
16}
17
18impl<'a> Capnp<'a> for Sample {
19    type Builder = sample::Builder<'a>;
20    type Reader = sample::Reader<'a>;
21
22    fn write_capnp(&self, builder: &mut Self::Builder) {
23        builder.set_name(&self.name);
24    }
25
26    fn read_capnp(reader: Self::Reader) -> Self {
27        let name = reader.get_name().unwrap().to_string().unwrap();
28        Sample { name }
29    }
30}
31
32impl Query for Sample {
33    type Model = Sample;
34
35    const PRIMARY_KEY: &'static str = "name";
36    const TABLE_NAME: &'static str = "samples";
37
38    fn process_row(row: &Row) -> Self::Model {
39        Sample {
40            name: row.get(0).unwrap(),
41        }
42    }
43}
44
45impl Sample {
    /// The sample name `"reference"`, exposed as the default name.
    // NOTE(review): nothing in this file reads this constant; confirm with its callers how the
    // default is applied.
    pub const DEFAULT_NAME: &str = "reference";
47
    /// Returns the parent names recorded for `sample_name`.
    ///
    /// Thin wrapper that forwards to `SampleLineage::get_parents` with the same arguments.
    pub fn get_parent_names(conn: &GraphConnection, sample_name: &str) -> Vec<String> {
        SampleLineage::get_parents(conn, sample_name)
    }
51
52    pub fn create(conn: &GraphConnection, name: &str) -> SQLResult<Sample> {
53        let mut stmt = conn
54            .prepare("INSERT INTO samples (name) VALUES (?1) returning (name);")
55            .unwrap();
56        stmt.query_row((name,), |row| Ok(Sample { name: row.get(0)? }))
57    }
58
59    pub fn get_or_create(conn: &GraphConnection, name: &str) -> Sample {
60        match Sample::create(conn, name) {
61            Ok(sample) => sample,
62            Err(rusqlite::Error::SqliteFailure(err, _details)) => {
63                if err.code == rusqlite::ErrorCode::ConstraintViolation {
64                    Sample {
65                        name: name.to_string(),
66                    }
67                } else {
68                    panic!("something bad happened querying the database")
69                }
70            }
71            Err(_) => {
72                panic!("something bad happened.")
73            }
74        }
75    }
76
77    pub fn delete_by_name(conn: &GraphConnection, name: &str) {
78        let mut stmt = conn.prepare("delete from samples where name = ?1").unwrap();
79        stmt.execute([name]).unwrap();
80    }
81
82    pub fn get_graph(conn: &GraphConnection, collection: &str, name: &str) -> GenGraph {
83        let block_groups = Sample::get_block_groups(conn, collection, name);
84        let mut sample_graph = GenGraph::new();
85        for bg in block_groups {
86            let bg_graph = BlockGroup::get_graph(conn, &bg.id);
87            // Add nodes and edges from block group graph to sample graph
88            for node in bg_graph.nodes() {
89                sample_graph.add_node(node);
90            }
91            for (source, dest, edges) in bg_graph.all_edges() {
92                if let Some(existing_edges) = sample_graph.edge_weight_mut(source, dest) {
93                    existing_edges.extend(edges.clone());
94                } else {
95                    sample_graph.add_edge(source, dest, edges.clone());
96                }
97            }
98        }
99        sample_graph
100    }
101
102    pub fn get_all_sequences(
103        conn: &GraphConnection,
104        collection_name: &str,
105        sample_name: &str,
106        prune: bool,
107    ) -> HashSet<String> {
108        Sample::get_block_groups(conn, collection_name, sample_name)
109            .into_iter()
110            .flat_map(|block_group| BlockGroup::get_all_sequences(conn, &block_group.id, prune))
111            .collect()
112    }
113
114    pub fn get_or_create_child(
115        conn: &GraphConnection,
116        collection_name: &str,
117        sample_name: &str,
118        parent_samples: Vec<String>,
119    ) -> Result<Sample, SampleError> {
120        match Sample::create(conn, sample_name) {
121            Ok(new_sample) => {
122                if !parent_samples.is_empty() {
123                    let parent_block_groups = BlockGroup::query(
124                        conn,
125                        "select * from block_groups
126                         where collection_name = ?1 AND sample_name IN rarray(?2)
127                         ORDER BY name, sample_name, created_on, id",
128                        params![
129                            collection_name,
130                            Rc::new(
131                                parent_samples
132                                    .iter()
133                                    .cloned()
134                                    .map(SQLValue::from)
135                                    .collect::<Vec<_>>()
136                            ),
137                        ],
138                    );
139                    let group_names = parent_block_groups
140                        .into_iter()
141                        .map(|parent_block_group| parent_block_group.name)
142                        .collect::<HashSet<_>>();
143
144                    for group_name in group_names {
145                        BlockGroup::get_or_create_sample_block_groups(
146                            conn,
147                            collection_name,
148                            &new_sample.name,
149                            &group_name,
150                            parent_samples.clone(),
151                        )
152                        .map_err(SampleError::from)?;
153                    }
154
155                    for parent_sample in parent_samples {
156                        SampleLineage::create(conn, &parent_sample, &new_sample.name)
157                            .map_err(SampleError::from)?;
158                    }
159                }
160
161                Ok(new_sample)
162            }
163            Err(rusqlite::Error::SqliteFailure(err, _details)) => {
164                if err.code == rusqlite::ErrorCode::ConstraintViolation {
165                    Ok(Sample {
166                        name: sample_name.to_string(),
167                    })
168                } else {
169                    Err(SampleError::SqliteError(rusqlite::Error::SqliteFailure(
170                        err, _details,
171                    )))
172                }
173            }
174            Err(err) => Err(SampleError::SqliteError(err)),
175        }
176    }
177
178    pub fn get_block_groups(
179        conn: &GraphConnection,
180        collection_name: &str,
181        sample_name: &str,
182    ) -> Vec<BlockGroup> {
183        BlockGroup::query(
184            conn,
185            "select * from block_groups where collection_name = ?1 AND sample_name = ?2;",
186            params![collection_name, sample_name],
187        )
188    }
189
190    pub fn get_all_names(conn: &GraphConnection) -> Vec<String> {
191        let samples = Sample::query(conn, "select * from samples;", rusqlite::params!());
192        samples.iter().map(|s| s.name.clone()).collect()
193    }
194
195    pub fn get_by_name(conn: &GraphConnection, name: &str) -> SQLResult<Sample> {
196        Sample::get(
197            conn,
198            "select * from samples where name = ?1;",
199            rusqlite::params!(name),
200        )
201    }
202
203    pub fn search_name(conn: &GraphConnection, name: &str) -> Vec<Sample> {
204        Sample::query(
205            conn,
206            "select * from samples
207             where instr(lower(name), lower(?1)) > 0
208             order by name;",
209            rusqlite::params!(name),
210        )
211    }
212}
213
214#[cfg(test)]
215mod tests {
216    use capnp::message::TypedBuilder;
217
218    use super::*;
219    use crate::{
220        collection::Collection,
221        errors::SampleError,
222        test_helpers::{create_bg, get_connection},
223    };
224
225    #[test]
226    fn test_capnp_serialization() {
227        let sample = Sample {
228            name: "test_sample".to_string(),
229        };
230
231        let mut message = TypedBuilder::<sample::Owned>::new_default();
232        let mut root = message.init_root();
233        sample.write_capnp(&mut root);
234
235        let deserialized = Sample::read_capnp(root.into_reader());
236        assert_eq!(sample, deserialized);
237    }
238
239    #[test]
240    fn test_delete_by_name() {
241        let conn = &get_connection(None).unwrap();
242
243        let _ = Sample::create(conn, "sample1").unwrap();
244        let _ = Sample::create(conn, "sample2").unwrap();
245
246        assert!(Sample::get_by_name(conn, "sample1").is_ok());
247        assert!(Sample::get_by_name(conn, "sample2").is_ok());
248
249        Sample::delete_by_name(conn, "sample1");
250
251        assert!(Sample::get_by_name(conn, "sample1").is_err());
252        assert!(Sample::get_by_name(conn, "sample2").is_ok());
253    }
254
255    #[test]
256    fn test_search_name_returns_partial_matches() {
257        let conn = &get_connection(None).unwrap();
258
259        for sample in ["alpha", "BarFooBaz", "foo", "QuxFood", "zzz"] {
260            Sample::create(conn, sample).unwrap();
261        }
262
263        let matches = Sample::search_name(conn, "FoO")
264            .into_iter()
265            .map(|sample| sample.name)
266            .collect::<Vec<_>>();
267
268        assert_eq!(matches, vec!["BarFooBaz", "QuxFood", "foo"]);
269    }
270
271    #[test]
272    fn test_get_or_create_child_does_not_add_lineage_for_existing_sample() {
273        let conn = &get_connection(None).unwrap();
274        Sample::get_or_create(conn, "parent");
275        Sample::get_or_create(conn, "child");
276
277        Sample::get_or_create_child(conn, "test", "child", vec!["parent".to_string()]).unwrap();
278
279        assert!(SampleLineage::get_parents(conn, "child").is_empty());
280    }
281
282    #[test]
283    fn test_get_or_create_child_returns_sample_error_for_invalid_lineage() {
284        let conn = &get_connection(None).unwrap();
285
286        let err = Sample::get_or_create_child(conn, "test", "child", vec!["child".to_string()])
287            .unwrap_err();
288
289        assert!(matches!(
290            err,
291            SampleError::SqliteError(rusqlite::Error::SqliteFailure(code, _))
292                if code.code == rusqlite::ErrorCode::ConstraintViolation
293        ));
294    }
295
296    #[test]
297    fn test_get_or_create_child_multiple_parents() {
298        let conn = &get_connection(None).unwrap();
299        Collection::create(conn, "test");
300
301        create_bg(conn, "test", "parent_a", "chr1");
302        create_bg(conn, "test", "parent_a", "chr2");
303        create_bg(conn, "test", "parent_b", "chr2");
304        create_bg(conn, "test", "parent_c", "chr3");
305
306        let child = Sample::get_or_create_child(
307            conn,
308            "test",
309            "child",
310            vec![
311                "parent_a".to_string(),
312                "parent_b".to_string(),
313                "parent_c".to_string(),
314            ],
315        )
316        .unwrap();
317
318        let mut block_group_names = Sample::get_block_groups(conn, "test", &child.name)
319            .into_iter()
320            .map(|block_group| block_group.name)
321            .collect::<Vec<_>>();
322        block_group_names.sort();
323        assert_eq!(block_group_names, vec!["chr1", "chr2", "chr2", "chr3"]);
324        assert_eq!(
325            SampleLineage::get_parents(conn, &child.name),
326            vec![
327                "parent_a".to_string(),
328                "parent_b".to_string(),
329                "parent_c".to_string(),
330            ]
331        );
332    }
333}