1use std::{collections::HashSet, rc::Rc};
2
3use gen_core::traits::Capnp;
4use gen_graph::GenGraph;
5use rusqlite::{Result as SQLResult, Row, params, types::Value as SQLValue};
6use serde::{Deserialize, Serialize};
7
8use crate::{
9 block_group::BlockGroup, db::GraphConnection, errors::SampleError, gen_models_capnp::sample,
10 sample_lineage::SampleLineage, traits::Query,
11};
12
13#[derive(Debug, Deserialize, Serialize, PartialEq)]
14pub struct Sample {
15 pub name: String,
16}
17
18impl<'a> Capnp<'a> for Sample {
19 type Builder = sample::Builder<'a>;
20 type Reader = sample::Reader<'a>;
21
22 fn write_capnp(&self, builder: &mut Self::Builder) {
23 builder.set_name(&self.name);
24 }
25
26 fn read_capnp(reader: Self::Reader) -> Self {
27 let name = reader.get_name().unwrap().to_string().unwrap();
28 Sample { name }
29 }
30}
31
32impl Query for Sample {
33 type Model = Sample;
34
35 const PRIMARY_KEY: &'static str = "name";
36 const TABLE_NAME: &'static str = "samples";
37
38 fn process_row(row: &Row) -> Self::Model {
39 Sample {
40 name: row.get(0).unwrap(),
41 }
42 }
43}
44
45impl Sample {
46 pub const DEFAULT_NAME: &str = "reference";
47
48 pub fn get_parent_names(conn: &GraphConnection, sample_name: &str) -> Vec<String> {
49 SampleLineage::get_parents(conn, sample_name)
50 }
51
52 pub fn create(conn: &GraphConnection, name: &str) -> SQLResult<Sample> {
53 let mut stmt = conn
54 .prepare("INSERT INTO samples (name) VALUES (?1) returning (name);")
55 .unwrap();
56 stmt.query_row((name,), |row| Ok(Sample { name: row.get(0)? }))
57 }
58
59 pub fn get_or_create(conn: &GraphConnection, name: &str) -> Sample {
60 match Sample::create(conn, name) {
61 Ok(sample) => sample,
62 Err(rusqlite::Error::SqliteFailure(err, _details)) => {
63 if err.code == rusqlite::ErrorCode::ConstraintViolation {
64 Sample {
65 name: name.to_string(),
66 }
67 } else {
68 panic!("something bad happened querying the database")
69 }
70 }
71 Err(_) => {
72 panic!("something bad happened.")
73 }
74 }
75 }
76
77 pub fn delete_by_name(conn: &GraphConnection, name: &str) {
78 let mut stmt = conn.prepare("delete from samples where name = ?1").unwrap();
79 stmt.execute([name]).unwrap();
80 }
81
82 pub fn get_graph(conn: &GraphConnection, collection: &str, name: &str) -> GenGraph {
83 let block_groups = Sample::get_block_groups(conn, collection, name);
84 let mut sample_graph = GenGraph::new();
85 for bg in block_groups {
86 let bg_graph = BlockGroup::get_graph(conn, &bg.id);
87 for node in bg_graph.nodes() {
89 sample_graph.add_node(node);
90 }
91 for (source, dest, edges) in bg_graph.all_edges() {
92 if let Some(existing_edges) = sample_graph.edge_weight_mut(source, dest) {
93 existing_edges.extend(edges.clone());
94 } else {
95 sample_graph.add_edge(source, dest, edges.clone());
96 }
97 }
98 }
99 sample_graph
100 }
101
102 pub fn get_all_sequences(
103 conn: &GraphConnection,
104 collection_name: &str,
105 sample_name: &str,
106 prune: bool,
107 ) -> HashSet<String> {
108 Sample::get_block_groups(conn, collection_name, sample_name)
109 .into_iter()
110 .flat_map(|block_group| BlockGroup::get_all_sequences(conn, &block_group.id, prune))
111 .collect()
112 }
113
114 pub fn get_or_create_child(
115 conn: &GraphConnection,
116 collection_name: &str,
117 sample_name: &str,
118 parent_samples: Vec<String>,
119 ) -> Result<Sample, SampleError> {
120 match Sample::create(conn, sample_name) {
121 Ok(new_sample) => {
122 if !parent_samples.is_empty() {
123 let parent_block_groups = BlockGroup::query(
124 conn,
125 "select * from block_groups
126 where collection_name = ?1 AND sample_name IN rarray(?2)
127 ORDER BY name, sample_name, created_on, id",
128 params![
129 collection_name,
130 Rc::new(
131 parent_samples
132 .iter()
133 .cloned()
134 .map(SQLValue::from)
135 .collect::<Vec<_>>()
136 ),
137 ],
138 );
139 let group_names = parent_block_groups
140 .into_iter()
141 .map(|parent_block_group| parent_block_group.name)
142 .collect::<HashSet<_>>();
143
144 for group_name in group_names {
145 BlockGroup::get_or_create_sample_block_groups(
146 conn,
147 collection_name,
148 &new_sample.name,
149 &group_name,
150 parent_samples.clone(),
151 )
152 .map_err(SampleError::from)?;
153 }
154
155 for parent_sample in parent_samples {
156 SampleLineage::create(conn, &parent_sample, &new_sample.name)
157 .map_err(SampleError::from)?;
158 }
159 }
160
161 Ok(new_sample)
162 }
163 Err(rusqlite::Error::SqliteFailure(err, _details)) => {
164 if err.code == rusqlite::ErrorCode::ConstraintViolation {
165 Ok(Sample {
166 name: sample_name.to_string(),
167 })
168 } else {
169 Err(SampleError::SqliteError(rusqlite::Error::SqliteFailure(
170 err, _details,
171 )))
172 }
173 }
174 Err(err) => Err(SampleError::SqliteError(err)),
175 }
176 }
177
178 pub fn get_block_groups(
179 conn: &GraphConnection,
180 collection_name: &str,
181 sample_name: &str,
182 ) -> Vec<BlockGroup> {
183 BlockGroup::query(
184 conn,
185 "select * from block_groups where collection_name = ?1 AND sample_name = ?2;",
186 params![collection_name, sample_name],
187 )
188 }
189
190 pub fn get_all_names(conn: &GraphConnection) -> Vec<String> {
191 let samples = Sample::query(conn, "select * from samples;", rusqlite::params!());
192 samples.iter().map(|s| s.name.clone()).collect()
193 }
194
195 pub fn get_by_name(conn: &GraphConnection, name: &str) -> SQLResult<Sample> {
196 Sample::get(
197 conn,
198 "select * from samples where name = ?1;",
199 rusqlite::params!(name),
200 )
201 }
202
203 pub fn search_name(conn: &GraphConnection, name: &str) -> Vec<Sample> {
204 Sample::query(
205 conn,
206 "select * from samples
207 where instr(lower(name), lower(?1)) > 0
208 order by name;",
209 rusqlite::params!(name),
210 )
211 }
212}
213
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;
    use crate::{
        collection::Collection,
        errors::SampleError,
        test_helpers::{create_bg, get_connection},
    };

    /// A sample written to a Cap'n Proto message reads back unchanged.
    #[test]
    fn test_capnp_serialization() {
        let original = Sample {
            name: "test_sample".to_string(),
        };

        let mut message = TypedBuilder::<sample::Owned>::new_default();
        let mut builder = message.init_root();
        original.write_capnp(&mut builder);

        let round_tripped = Sample::read_capnp(builder.into_reader());
        assert_eq!(original, round_tripped);
    }

    /// Deleting one sample removes it and leaves the other in place.
    #[test]
    fn test_delete_by_name() {
        let conn = &get_connection(None).unwrap();

        for name in ["sample1", "sample2"] {
            let _ = Sample::create(conn, name).unwrap();
        }

        // Both samples are visible before the delete.
        assert!(Sample::get_by_name(conn, "sample1").is_ok());
        assert!(Sample::get_by_name(conn, "sample2").is_ok());

        Sample::delete_by_name(conn, "sample1");

        // Only the deleted sample is gone afterwards.
        assert!(Sample::get_by_name(conn, "sample1").is_err());
        assert!(Sample::get_by_name(conn, "sample2").is_ok());
    }

    /// The search is a case-insensitive substring match ordered by name.
    #[test]
    fn test_search_name_returns_partial_matches() {
        let conn = &get_connection(None).unwrap();

        let stored_names = ["alpha", "BarFooBaz", "foo", "QuxFood", "zzz"];
        for stored_name in stored_names {
            Sample::create(conn, stored_name).unwrap();
        }

        let hits = Sample::search_name(conn, "FoO");
        let hit_names: Vec<_> = hits.into_iter().map(|hit| hit.name).collect();

        assert_eq!(hit_names, vec!["BarFooBaz", "QuxFood", "foo"]);
    }

    /// An already existing child keeps its (empty) lineage.
    #[test]
    fn test_get_or_create_child_does_not_add_lineage_for_existing_sample() {
        let conn = &get_connection(None).unwrap();
        for existing in ["parent", "child"] {
            Sample::get_or_create(conn, existing);
        }

        let parents = vec!["parent".to_string()];
        Sample::get_or_create_child(conn, "test", "child", parents).unwrap();

        let recorded_parents = SampleLineage::get_parents(conn, "child");
        assert!(recorded_parents.is_empty());
    }

    /// A sample listed as its own parent surfaces the constraint violation.
    #[test]
    fn test_get_or_create_child_returns_sample_error_for_invalid_lineage() {
        let conn = &get_connection(None).unwrap();

        let outcome =
            Sample::get_or_create_child(conn, "test", "child", vec!["child".to_string()]);
        let err = outcome.unwrap_err();

        let is_constraint_violation = matches!(
            err,
            SampleError::SqliteError(rusqlite::Error::SqliteFailure(code, _))
                if code.code == rusqlite::ErrorCode::ConstraintViolation
        );
        assert!(is_constraint_violation);
    }

    /// A child of several parents gets their block groups and all lineage rows.
    #[test]
    fn test_get_or_create_child_multiple_parents() {
        let conn = &get_connection(None).unwrap();
        Collection::create(conn, "test");

        // (parent sample, block group name) pairs to seed the collection with.
        let seeded_block_groups = [
            ("parent_a", "chr1"),
            ("parent_a", "chr2"),
            ("parent_b", "chr2"),
            ("parent_c", "chr3"),
        ];
        for (parent_name, block_group_name) in seeded_block_groups {
            create_bg(conn, "test", parent_name, block_group_name);
        }

        let parents: Vec<String> = ["parent_a", "parent_b", "parent_c"]
            .iter()
            .map(|parent_name| parent_name.to_string())
            .collect();

        let child = Sample::get_or_create_child(conn, "test", "child", parents.clone()).unwrap();

        let child_block_groups = Sample::get_block_groups(conn, "test", &child.name);
        let mut block_group_names: Vec<_> = child_block_groups
            .into_iter()
            .map(|block_group| block_group.name)
            .collect();
        block_group_names.sort();

        assert_eq!(block_group_names, vec!["chr1", "chr2", "chr2", "chr3"]);
        assert_eq!(SampleLineage::get_parents(conn, &child.name), parents);
    }
}