1use std::{
2 fs,
3 path::{Path as FilePath, PathBuf},
4 str,
5};
6
7use gen_core::{HashId, errors::ConfigError, traits::Capnp};
8use rusqlite::session;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11
12use crate::{
13 accession::{Accession, AccessionEdge},
14 block_group::BlockGroup,
15 changesets::{DatabaseChangeset, process_changesetiter, write_changeset},
16 collection::Collection,
17 db::{DbContext, GraphConnection},
18 edge::Edge,
19 errors::OperationError,
20 file_types::FileTypes,
21 files::GenDatabase,
22 gen_models_capnp::dependency_models,
23 metadata::{self, get_db_uuid},
24 node::Node,
25 operations::{FileAddition, Operation, OperationInfo, OperationState, OperationSummary},
26 path::Path,
27 sample::Sample,
28 sequence::Sequence,
29};
30
31pub fn start_operation(conn: &GraphConnection) -> session::Session<'_> {
32 let mut session = session::Session::new(conn).unwrap();
33 attach_session(&mut session);
34 session
35}
36
37#[allow(clippy::too_many_arguments)]
38pub fn end_operation(
39 context: &DbContext,
40 session: &mut session::Session,
41 operation_info: &OperationInfo,
42 summary_str: &str,
43 force_hash: impl Into<Option<HashId>>,
44) -> Result<Operation, OperationError> {
45 let conn = context.graph().conn();
46 let operation_conn = context.operations().conn();
47 let db_uuid = metadata::get_db_uuid(conn);
48 let mut output = Vec::new();
50 session.changeset_strm(&mut output).unwrap();
51
52 let (changeset_models, dependencies) = process_changesetiter(conn, &output);
53
54 let hash = if let Some(hash) = force_hash.into() {
55 hash
56 } else {
57 if output.is_empty() {
58 return Err(OperationError::NoChanges);
59 }
60 let mut hasher = Sha256::new();
61 hasher.update(&db_uuid[..]);
62 hasher.update(&output[..]);
63 HashId(hasher.finalize().into())
64 };
65
66 operation_conn
67 .execute("SAVEPOINT new_operation;", [])
68 .unwrap();
69
70 match Operation::create(operation_conn, &operation_info.description, &hash) {
71 Ok(operation) => {
72 let gen_dir = match context.workspace().find_gen_dir() {
73 Some(dir) => dir,
74 None => {
75 return Err(OperationError::ConfigError(
76 ConfigError::GenDirectoryNotFound,
77 ));
78 }
79 };
80 let assets_dir = FilePath::new(&gen_dir).join("assets");
81 fs::create_dir_all(&assets_dir).map_err(|_| OperationError::IOError)?;
82
83 for op_file in operation_info.files.iter() {
84 let fa = match FileAddition::get_or_create(
85 context.workspace(),
86 operation_conn,
87 &op_file.file_path,
88 op_file.file_type,
89 None,
90 ) {
91 Ok(fa) => fa,
92 Err(err) => return Err(OperationError::SQLError(format!("{err}"))),
93 };
94 Operation::add_file(operation_conn, &operation.hash, &fa.id)
95 .map_err(|err| OperationError::SQLError(format!("{err}")))?;
96 if fa.file_type != FileTypes::Changeset && fa.file_type != FileTypes::None {
97 let asset_destination_path = assets_dir.join(fa.hashed_filename());
98 if !asset_destination_path.exists() {
99 let source_path = if FilePath::new(&op_file.file_path).is_absolute() {
100 PathBuf::from(&op_file.file_path)
101 } else {
102 context
103 .workspace()
104 .repo_root()
105 .map_err(OperationError::ConfigError)?
106 .join(&op_file.file_path)
107 };
108 match fs::copy(source_path, asset_destination_path) {
109 Ok(result) => result,
110 Err(_) => return Err(OperationError::IOError),
111 };
112 }
113 }
114 }
115 Operation::add_database(operation_conn, &operation.hash, &db_uuid)
116 .map_err(|err| OperationError::SQLError(format!("{err}")))?;
117 OperationSummary::create(operation_conn, &operation.hash, summary_str);
118 let db_uuid = get_db_uuid(conn);
119 let gen_db = GenDatabase::get_by_uuid(operation_conn, &db_uuid).unwrap();
120 write_changeset(
121 context.workspace(),
122 &operation,
123 DatabaseChangeset {
124 db_path: gen_db.path,
125 changes: changeset_models,
126 },
127 &dependencies,
128 );
129 OperationState::set_operation(operation_conn, &operation.hash);
130 operation_conn
131 .execute("RELEASE SAVEPOINT new_operation;", [])
132 .unwrap();
133 Ok(operation)
134 }
135 Err(rusqlite::Error::SqliteFailure(err, details)) => {
136 operation_conn
137 .execute("ROLLBACK TRANSACTION TO SAVEPOINT new_operation;", [])
138 .unwrap();
139 if err.code == rusqlite::ErrorCode::ConstraintViolation {
140 Err(OperationError::OperationExists)
141 } else {
142 panic!("something bad happened querying the database {details:?}");
143 }
144 }
145 Err(e) => {
146 operation_conn
147 .execute("ROLLBACK TRANSACTION TO SAVEPOINT new_operation;", [])
148 .unwrap();
149 panic!("something bad happened querying the database {e:?}");
150 }
151 }
152}
153
154pub fn attach_session(session: &mut session::Session) {
155 for table in [
156 "collections",
157 "samples",
158 "sequences",
159 "block_groups",
160 "paths",
161 "nodes",
162 "edges",
163 "path_edges",
164 "block_group_edges",
165 "accessions",
166 "accession_edges",
167 "accession_paths",
168 "annotation_groups",
169 "annotations",
170 "annotation_group_samples",
171 "sample_lineage",
172 ] {
173 session.attach(Some(table)).unwrap();
174 }
175}
176
/// Graph models grouped by kind, serializable with serde and (via its `Capnp` impl) Cap'n Proto.
///
/// NOTE(review): presumably this holds the existing rows that a changeset depends on
/// (cf. the `dependencies` value from `process_changesetiter` passed to `write_changeset`) —
/// confirm against `changesets`.
#[derive(Default, Deserialize, Serialize, Debug, PartialEq)]
pub struct DependencyModels {
    /// `Collection` models.
    pub collections: Vec<Collection>,
    /// `Sample` models.
    pub samples: Vec<Sample>,
    /// `Sequence` models.
    pub sequences: Vec<Sequence>,
    /// `BlockGroup` models. The field name is singular, unlike the other fields, and matches
    /// the `init_block_group` / `get_block_group` Cap'n Proto accessors.
    pub block_group: Vec<BlockGroup>,
    /// `Node` models.
    pub nodes: Vec<Node>,
    /// `Edge` models.
    pub edges: Vec<Edge>,
    /// `Path` models (the crate's graph path type, not `std::path::Path`).
    pub paths: Vec<Path>,
    /// `Accession` models.
    pub accessions: Vec<Accession>,
    /// `AccessionEdge` models.
    pub accession_edges: Vec<AccessionEdge>,
}
189
190impl<'a> Capnp<'a> for DependencyModels {
191 type Builder = dependency_models::Builder<'a>;
192 type Reader = dependency_models::Reader<'a>;
193
194 fn write_capnp(&self, builder: &mut Self::Builder) {
195 let mut collections_builder = builder
197 .reborrow()
198 .init_collections(self.collections.len() as u32);
199 for (i, collection) in self.collections.iter().enumerate() {
200 let mut collection_builder = collections_builder.reborrow().get(i as u32);
201 collection.write_capnp(&mut collection_builder);
202 }
203
204 let mut samples_builder = builder.reborrow().init_samples(self.samples.len() as u32);
206 for (i, sample) in self.samples.iter().enumerate() {
207 let mut sample_builder = samples_builder.reborrow().get(i as u32);
208 sample.write_capnp(&mut sample_builder);
209 }
210
211 let mut sequences_builder = builder
213 .reborrow()
214 .init_sequences(self.sequences.len() as u32);
215 for (i, sequence) in self.sequences.iter().enumerate() {
216 let mut sequence_builder = sequences_builder.reborrow().get(i as u32);
217 sequence.write_capnp(&mut sequence_builder);
218 }
219
220 let mut block_group_builder = builder
222 .reborrow()
223 .init_block_group(self.block_group.len() as u32);
224 for (i, block_group) in self.block_group.iter().enumerate() {
225 let mut bg_builder = block_group_builder.reborrow().get(i as u32);
226 block_group.write_capnp(&mut bg_builder);
227 }
228
229 let mut nodes_builder = builder.reborrow().init_nodes(self.nodes.len() as u32);
231 for (i, node) in self.nodes.iter().enumerate() {
232 let mut node_builder = nodes_builder.reborrow().get(i as u32);
233 node.write_capnp(&mut node_builder);
234 }
235
236 let mut edges_builder = builder.reborrow().init_edges(self.edges.len() as u32);
238 for (i, edge) in self.edges.iter().enumerate() {
239 let mut edge_builder = edges_builder.reborrow().get(i as u32);
240 edge.write_capnp(&mut edge_builder);
241 }
242
243 let mut paths_builder = builder.reborrow().init_paths(self.paths.len() as u32);
245 for (i, path) in self.paths.iter().enumerate() {
246 let mut path_builder = paths_builder.reborrow().get(i as u32);
247 path.write_capnp(&mut path_builder);
248 }
249
250 let mut accessions_builder = builder
252 .reborrow()
253 .init_accessions(self.accessions.len() as u32);
254 for (i, accession) in self.accessions.iter().enumerate() {
255 let mut accession_builder = accessions_builder.reborrow().get(i as u32);
256 accession.write_capnp(&mut accession_builder);
257 }
258
259 let mut accession_edges_builder = builder
261 .reborrow()
262 .init_accession_edges(self.accession_edges.len() as u32);
263 for (i, accession_edge) in self.accession_edges.iter().enumerate() {
264 let mut accession_edge_builder = accession_edges_builder.reborrow().get(i as u32);
265 accession_edge.write_capnp(&mut accession_edge_builder);
266 }
267 }
268
269 fn read_capnp(reader: Self::Reader) -> Self {
270 let collections_reader = reader.get_collections().unwrap();
272 let mut collections = Vec::new();
273 for collection_reader in collections_reader.iter() {
274 collections.push(Collection::read_capnp(collection_reader));
275 }
276
277 let samples_reader = reader.get_samples().unwrap();
279 let mut samples = Vec::new();
280 for sample_reader in samples_reader.iter() {
281 samples.push(Sample::read_capnp(sample_reader));
282 }
283
284 let sequences_reader = reader.get_sequences().unwrap();
286 let mut sequences = Vec::new();
287 for sequence_reader in sequences_reader.iter() {
288 sequences.push(Sequence::read_capnp(sequence_reader));
289 }
290
291 let block_group_reader = reader.get_block_group().unwrap();
293 let mut block_group = Vec::new();
294 for bg_reader in block_group_reader.iter() {
295 block_group.push(BlockGroup::read_capnp(bg_reader));
296 }
297
298 let nodes_reader = reader.get_nodes().unwrap();
300 let mut nodes = Vec::new();
301 for node_reader in nodes_reader.iter() {
302 nodes.push(Node::read_capnp(node_reader));
303 }
304
305 let edges_reader = reader.get_edges().unwrap();
307 let mut edges = Vec::new();
308 for edge_reader in edges_reader.iter() {
309 edges.push(Edge::read_capnp(edge_reader));
310 }
311
312 let paths_reader = reader.get_paths().unwrap();
314 let mut paths = Vec::new();
315 for path_reader in paths_reader.iter() {
316 paths.push(Path::read_capnp(path_reader));
317 }
318
319 let accessions_reader = reader.get_accessions().unwrap();
321 let mut accessions = Vec::new();
322 for accession_reader in accessions_reader.iter() {
323 accessions.push(Accession::read_capnp(accession_reader));
324 }
325
326 let accession_edges_reader = reader.get_accession_edges().unwrap();
328 let mut accession_edges = Vec::new();
329 for accession_edge_reader in accession_edges_reader.iter() {
330 accession_edges.push(AccessionEdge::read_capnp(accession_edge_reader));
331 }
332
333 DependencyModels {
334 collections,
335 samples,
336 sequences,
337 block_group,
338 nodes,
339 edges,
340 paths,
341 accessions,
342 accession_edges,
343 }
344 }
345}
346
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;
    use gen_core::Strand;

    use super::*;
    use crate::sequence::NewSequence;

    /// Writes `models` into a fresh Cap'n Proto message and reads it straight back.
    fn round_trip(models: &DependencyModels) -> DependencyModels {
        let mut message = TypedBuilder::<dependency_models::Owned>::new_default();
        let mut root = message.init_root();
        models.write_capnp(&mut root);
        DependencyModels::read_capnp(root.into_reader())
    }

    /// One model of every kind must survive a Cap'n Proto write/read cycle unchanged.
    #[test]
    fn test_dependency_models_capnp_serialization() {
        let sequence = NewSequence::new()
            .sequence_type("DNA")
            .sequence("ATCG")
            .name("test_seq")
            .build();

        let block_group = BlockGroup {
            id: HashId::pad_str(1),
            collection_name: "test_collection".to_string(),
            sample_name: "test_sample".to_string(),
            name: "test_bg".to_string(),
            created_on: 0,
            parent_block_group_id: None,
            is_default: false,
        };

        let edge = Edge {
            id: HashId::pad_str(1),
            source_node_id: HashId::convert_str("1"),
            source_coordinate: 0,
            source_strand: Strand::Forward,
            target_node_id: HashId::convert_str("2"),
            target_coordinate: 0,
            target_strand: Strand::Forward,
        };

        let accession_edge = AccessionEdge {
            id: HashId::pad_str(1),
            source_node_id: HashId::convert_str("1"),
            source_coordinate: 0,
            source_strand: Strand::Forward,
            target_node_id: HashId::convert_str("2"),
            target_coordinate: 0,
            target_strand: Strand::Forward,
            chromosome_index: 0,
        };

        let original = DependencyModels {
            collections: vec![Collection {
                name: "test_collection".to_string(),
            }],
            samples: vec![Sample {
                name: "test_sample".to_string(),
            }],
            sequences: vec![sequence],
            block_group: vec![block_group],
            nodes: vec![Node {
                id: HashId::convert_str("node_hash"),
                sequence_hash: HashId::convert_str("test_hash"),
            }],
            edges: vec![edge],
            paths: vec![Path {
                id: HashId::pad_str(1),
                block_group_id: HashId::pad_str(1),
                name: "test_path".to_string(),
                created_on: 0,
            }],
            accessions: vec![Accession {
                id: HashId::pad_str(1),
                name: "test_accession".to_string(),
                path_id: HashId::pad_str(1),
                parent_accession_id: None,
            }],
            accession_edges: vec![accession_edge],
        };

        let restored = round_trip(&original);

        assert_eq!(original, restored);
    }
}
425}