1use std::fmt::*;
2
3use gen_core::traits::Capnp;
4use gen_graph::GenGraph;
5use rusqlite::{Result as SQLResult, Row, params};
6use serde::{Deserialize, Serialize};
7
8use crate::{block_group::BlockGroup, db::GraphConnection, gen_models_capnp::sample, traits::*};
9
10#[derive(Debug, Deserialize, Serialize, PartialEq)]
11pub struct Sample {
12 pub name: String,
13}
14
15impl<'a> Capnp<'a> for Sample {
16 type Builder = sample::Builder<'a>;
17 type Reader = sample::Reader<'a>;
18
19 fn write_capnp(&self, builder: &mut Self::Builder) {
20 builder.set_name(&self.name);
21 }
22
23 fn read_capnp(reader: Self::Reader) -> Self {
24 let name = reader.get_name().unwrap().to_string().unwrap();
25 Sample { name }
26 }
27}
28
29impl Query for Sample {
30 type Model = Sample;
31
32 const PRIMARY_KEY: &'static str = "name";
33 const TABLE_NAME: &'static str = "samples";
34
35 fn process_row(row: &Row) -> Self::Model {
36 Sample {
37 name: row.get(0).unwrap(),
38 }
39 }
40}
41
42impl Sample {
43 pub fn create(conn: &GraphConnection, name: &str) -> SQLResult<Sample> {
44 let mut stmt = conn
45 .prepare("INSERT INTO samples (name) VALUES (?1) returning (name);")
46 .unwrap();
47 stmt.query_row((name,), |row| Ok(Sample { name: row.get(0)? }))
48 }
49
50 pub fn get_or_create(conn: &GraphConnection, name: &str) -> Sample {
51 match Sample::create(conn, name) {
52 Ok(sample) => sample,
53 Err(rusqlite::Error::SqliteFailure(err, _details)) => {
54 if err.code == rusqlite::ErrorCode::ConstraintViolation {
55 Sample {
56 name: name.to_string(),
57 }
58 } else {
59 panic!("something bad happened querying the database")
60 }
61 }
62 Err(_) => {
63 panic!("something bad happened.")
64 }
65 }
66 }
67
68 pub fn delete_by_name(conn: &GraphConnection, name: &str) {
69 let mut stmt = conn.prepare("delete from samples where name = ?1").unwrap();
70 stmt.execute([name]).unwrap();
71 }
72
73 pub fn get_graph<'a>(
74 conn: &GraphConnection,
75 collection: &str,
76 name: impl Into<Option<&'a str>>,
77 ) -> GenGraph {
78 let name = name.into();
79 let block_groups = Sample::get_block_groups(conn, collection, name);
80 let mut sample_graph = GenGraph::new();
81 for bg in block_groups {
82 let bg_graph = BlockGroup::get_graph(conn, &bg.id);
83 for node in bg_graph.nodes() {
85 sample_graph.add_node(node);
86 }
87 for (source, dest, edges) in bg_graph.all_edges() {
88 if let Some(existing_edges) = sample_graph.edge_weight_mut(source, dest) {
89 existing_edges.extend(edges.clone());
90 } else {
91 sample_graph.add_edge(source, dest, edges.clone());
92 }
93 }
94 }
95 sample_graph
96 }
97
98 pub fn get_or_create_child(
99 conn: &GraphConnection,
100 collection_name: &str,
101 sample_name: &str,
102 parent_sample: Option<&str>,
103 ) -> Sample {
104 if let Ok(new_sample) = Sample::create(conn, sample_name) {
105 let bgs = if let Some(parent) = parent_sample {
106 BlockGroup::query(
107 conn,
108 "select * from block_groups where collection_name = ?1 AND sample_name = ?2",
109 params!(collection_name, parent),
110 )
111 } else {
112 BlockGroup::query(
113 conn,
114 "select * from block_groups where collection_name = ?1 AND sample_name is null;",
115 params!(collection_name),
116 )
117 };
118 for bg in bgs.iter() {
119 BlockGroup::get_or_create_sample_block_group(
120 conn,
121 collection_name,
122 &new_sample.name,
123 &bg.name,
124 parent_sample,
125 )
126 .expect("failed to get or create blockgroup clone.");
127 }
128 new_sample
129 } else {
130 Sample {
131 name: sample_name.to_string(),
132 }
133 }
134 }
135
136 pub fn get_block_groups(
137 conn: &GraphConnection,
138 collection_name: &str,
139 sample_name: Option<&str>,
140 ) -> Vec<BlockGroup> {
141 if let Some(sample) = sample_name {
142 BlockGroup::query(
143 conn,
144 "select * from block_groups where collection_name = ?1 AND sample_name = ?2;",
145 params![collection_name, sample],
146 )
147 } else {
148 BlockGroup::query(
149 conn,
150 "select * from block_groups where collection_name = ?1 AND sample_name IS NULL;",
151 params![collection_name],
152 )
153 }
154 }
155
156 pub fn get_all_names(conn: &GraphConnection) -> Vec<String> {
157 let samples = Sample::query(conn, "select * from samples;", rusqlite::params!());
158 samples.iter().map(|s| s.name.clone()).collect()
159 }
160
161 pub fn get_by_name(conn: &GraphConnection, name: &str) -> SQLResult<Sample> {
162 Sample::get(
163 conn,
164 "select * from samples where name = ?1;",
165 rusqlite::params!(name),
166 )
167 }
168}
169
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;

    use super::*;
    use crate::test_helpers::get_connection;

    /// A sample written into a Cap'n Proto builder reads back equal to the original.
    #[test]
    fn test_capnp_serialization() {
        let original = Sample {
            name: String::from("test_sample"),
        };

        let mut message = TypedBuilder::<sample::Owned>::new_default();
        let mut builder = message.init_root();
        original.write_capnp(&mut builder);

        let round_tripped = Sample::read_capnp(builder.into_reader());
        assert_eq!(original, round_tripped);
    }

    /// Deleting one sample by name removes only that sample.
    #[test]
    fn test_delete_by_name() {
        let connection = get_connection(None).unwrap();
        let conn = &connection;

        for sample_name in ["sample1", "sample2"] {
            let _ = Sample::create(conn, sample_name).unwrap();
        }

        // Both samples are visible before the delete.
        for sample_name in ["sample1", "sample2"] {
            assert!(Sample::get_by_name(conn, sample_name).is_ok());
        }

        Sample::delete_by_name(conn, "sample1");

        // Only the deleted sample is gone afterwards.
        assert!(Sample::get_by_name(conn, "sample1").is_err());
        assert!(Sample::get_by_name(conn, "sample2").is_ok());
    }
}