Skip to main content

gen_models/
session_operations.rs

1use std::{
2    fs,
3    path::{Path as FilePath, PathBuf},
4    str,
5};
6
7use gen_core::{HashId, errors::ConfigError, traits::Capnp};
8use rusqlite::session;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11
12use crate::{
13    accession::{Accession, AccessionEdge},
14    block_group::BlockGroup,
15    changesets::{DatabaseChangeset, process_changesetiter, write_changeset},
16    collection::Collection,
17    db::{DbContext, GraphConnection},
18    edge::Edge,
19    errors::OperationError,
20    file_types::FileTypes,
21    files::GenDatabase,
22    gen_models_capnp::dependency_models,
23    metadata::{self, get_db_uuid},
24    node::Node,
25    operations::{FileAddition, Operation, OperationInfo, OperationState, OperationSummary},
26    path::Path,
27    sample::Sample,
28    sequence::Sequence,
29};
30
31pub fn start_operation(conn: &GraphConnection) -> session::Session<'_> {
32    let mut session = session::Session::new(conn).unwrap();
33    attach_session(&mut session);
34    session
35}
36
37#[allow(clippy::too_many_arguments)]
38pub fn end_operation(
39    context: &DbContext,
40    session: &mut session::Session,
41    operation_info: &OperationInfo,
42    summary_str: &str,
43    force_hash: impl Into<Option<HashId>>,
44) -> Result<Operation, OperationError> {
45    let conn = context.graph().conn();
46    let operation_conn = context.operations().conn();
47    let db_uuid = metadata::get_db_uuid(conn);
48    // determine if this operation has already happened
49    let mut output = Vec::new();
50    session.changeset_strm(&mut output).unwrap();
51
52    let (changeset_models, dependencies) = process_changesetiter(conn, &output);
53
54    let hash = if let Some(hash) = force_hash.into() {
55        hash
56    } else {
57        if output.is_empty() {
58            return Err(OperationError::NoChanges);
59        }
60        let mut hasher = Sha256::new();
61        hasher.update(&db_uuid[..]);
62        hasher.update(&output[..]);
63        HashId(hasher.finalize().into())
64    };
65
66    operation_conn
67        .execute("SAVEPOINT new_operation;", [])
68        .unwrap();
69
70    match Operation::create(operation_conn, &operation_info.description, &hash) {
71        Ok(operation) => {
72            let gen_dir = match context.workspace().find_gen_dir() {
73                Some(dir) => dir,
74                None => {
75                    return Err(OperationError::ConfigError(
76                        ConfigError::GenDirectoryNotFound,
77                    ));
78                }
79            };
80            let assets_dir = FilePath::new(&gen_dir).join("assets");
81            fs::create_dir_all(&assets_dir).map_err(|_| OperationError::IOError)?;
82
83            for op_file in operation_info.files.iter() {
84                let fa = match FileAddition::get_or_create(
85                    context.workspace(),
86                    operation_conn,
87                    &op_file.file_path,
88                    op_file.file_type,
89                    None,
90                ) {
91                    Ok(fa) => fa,
92                    Err(err) => return Err(OperationError::SQLError(format!("{err}"))),
93                };
94                Operation::add_file(operation_conn, &operation.hash, &fa.id)
95                    .map_err(|err| OperationError::SQLError(format!("{err}")))?;
96                if fa.file_type != FileTypes::Changeset && fa.file_type != FileTypes::None {
97                    let asset_destination_path = assets_dir.join(fa.hashed_filename());
98                    if !asset_destination_path.exists() {
99                        let source_path = if FilePath::new(&op_file.file_path).is_absolute() {
100                            PathBuf::from(&op_file.file_path)
101                        } else {
102                            context
103                                .workspace()
104                                .repo_root()
105                                .map_err(OperationError::ConfigError)?
106                                .join(&op_file.file_path)
107                        };
108                        match fs::copy(source_path, asset_destination_path) {
109                            Ok(result) => result,
110                            Err(_) => return Err(OperationError::IOError),
111                        };
112                    }
113                }
114            }
115            Operation::add_database(operation_conn, &operation.hash, &db_uuid)
116                .map_err(|err| OperationError::SQLError(format!("{err}")))?;
117            OperationSummary::create(operation_conn, &operation.hash, summary_str);
118            let db_uuid = get_db_uuid(conn);
119            let gen_db = GenDatabase::get_by_uuid(operation_conn, &db_uuid).unwrap();
120            write_changeset(
121                context.workspace(),
122                &operation,
123                DatabaseChangeset {
124                    db_path: gen_db.path,
125                    changes: changeset_models,
126                },
127                &dependencies,
128            );
129            OperationState::set_operation(operation_conn, &operation.hash);
130            operation_conn
131                .execute("RELEASE SAVEPOINT new_operation;", [])
132                .unwrap();
133            Ok(operation)
134        }
135        Err(rusqlite::Error::SqliteFailure(err, details)) => {
136            operation_conn
137                .execute("ROLLBACK TRANSACTION TO SAVEPOINT new_operation;", [])
138                .unwrap();
139            if err.code == rusqlite::ErrorCode::ConstraintViolation {
140                Err(OperationError::OperationExists)
141            } else {
142                panic!("something bad happened querying the database {details:?}");
143            }
144        }
145        Err(e) => {
146            operation_conn
147                .execute("ROLLBACK TRANSACTION TO SAVEPOINT new_operation;", [])
148                .unwrap();
149            panic!("something bad happened querying the database {e:?}");
150        }
151    }
152}
153
154pub fn attach_session(session: &mut session::Session) {
155    for table in [
156        "collections",
157        "samples",
158        "sequences",
159        "block_groups",
160        "paths",
161        "nodes",
162        "edges",
163        "path_edges",
164        "block_group_edges",
165        "accessions",
166        "accession_edges",
167        "accession_paths",
168        "annotation_groups",
169        "annotations",
170        "annotation_group_samples",
171        "sample_lineage",
172    ] {
173        session.attach(Some(table)).unwrap();
174    }
175}
176
/// A bundle of model records, one vector per model type.
///
/// The type is serde-serializable and also implements `Capnp`, which maps each field onto the
/// list of the same name in the `dependency_models` Cap'n Proto schema.
#[derive(Default, Deserialize, Serialize, Debug, PartialEq)]
pub struct DependencyModels {
    /// Collection records.
    pub collections: Vec<Collection>,
    /// Sample records.
    pub samples: Vec<Sample>,
    /// Sequence records.
    pub sequences: Vec<Sequence>,
    /// Block group records; the singular name mirrors the `blockGroup` field of the capnp schema.
    pub block_group: Vec<BlockGroup>,
    /// Node records.
    pub nodes: Vec<Node>,
    /// Edge records.
    pub edges: Vec<Edge>,
    /// Path records.
    pub paths: Vec<Path>,
    /// Accession records.
    pub accessions: Vec<Accession>,
    /// Accession edge records.
    pub accession_edges: Vec<AccessionEdge>,
}
189
190impl<'a> Capnp<'a> for DependencyModels {
191    type Builder = dependency_models::Builder<'a>;
192    type Reader = dependency_models::Reader<'a>;
193
194    fn write_capnp(&self, builder: &mut Self::Builder) {
195        // Write collections
196        let mut collections_builder = builder
197            .reborrow()
198            .init_collections(self.collections.len() as u32);
199        for (i, collection) in self.collections.iter().enumerate() {
200            let mut collection_builder = collections_builder.reborrow().get(i as u32);
201            collection.write_capnp(&mut collection_builder);
202        }
203
204        // Write samples
205        let mut samples_builder = builder.reborrow().init_samples(self.samples.len() as u32);
206        for (i, sample) in self.samples.iter().enumerate() {
207            let mut sample_builder = samples_builder.reborrow().get(i as u32);
208            sample.write_capnp(&mut sample_builder);
209        }
210
211        // Write sequences
212        let mut sequences_builder = builder
213            .reborrow()
214            .init_sequences(self.sequences.len() as u32);
215        for (i, sequence) in self.sequences.iter().enumerate() {
216            let mut sequence_builder = sequences_builder.reborrow().get(i as u32);
217            sequence.write_capnp(&mut sequence_builder);
218        }
219
220        // Write block groups (note: field name is blockGroup in capnp schema)
221        let mut block_group_builder = builder
222            .reborrow()
223            .init_block_group(self.block_group.len() as u32);
224        for (i, block_group) in self.block_group.iter().enumerate() {
225            let mut bg_builder = block_group_builder.reborrow().get(i as u32);
226            block_group.write_capnp(&mut bg_builder);
227        }
228
229        // Write nodes
230        let mut nodes_builder = builder.reborrow().init_nodes(self.nodes.len() as u32);
231        for (i, node) in self.nodes.iter().enumerate() {
232            let mut node_builder = nodes_builder.reborrow().get(i as u32);
233            node.write_capnp(&mut node_builder);
234        }
235
236        // Write edges
237        let mut edges_builder = builder.reborrow().init_edges(self.edges.len() as u32);
238        for (i, edge) in self.edges.iter().enumerate() {
239            let mut edge_builder = edges_builder.reborrow().get(i as u32);
240            edge.write_capnp(&mut edge_builder);
241        }
242
243        // Write paths
244        let mut paths_builder = builder.reborrow().init_paths(self.paths.len() as u32);
245        for (i, path) in self.paths.iter().enumerate() {
246            let mut path_builder = paths_builder.reborrow().get(i as u32);
247            path.write_capnp(&mut path_builder);
248        }
249
250        // Write accessions
251        let mut accessions_builder = builder
252            .reborrow()
253            .init_accessions(self.accessions.len() as u32);
254        for (i, accession) in self.accessions.iter().enumerate() {
255            let mut accession_builder = accessions_builder.reborrow().get(i as u32);
256            accession.write_capnp(&mut accession_builder);
257        }
258
259        // Write accession edges (note: field name is accessionEdges in capnp schema)
260        let mut accession_edges_builder = builder
261            .reborrow()
262            .init_accession_edges(self.accession_edges.len() as u32);
263        for (i, accession_edge) in self.accession_edges.iter().enumerate() {
264            let mut accession_edge_builder = accession_edges_builder.reborrow().get(i as u32);
265            accession_edge.write_capnp(&mut accession_edge_builder);
266        }
267    }
268
269    fn read_capnp(reader: Self::Reader) -> Self {
270        // Read collections
271        let collections_reader = reader.get_collections().unwrap();
272        let mut collections = Vec::new();
273        for collection_reader in collections_reader.iter() {
274            collections.push(Collection::read_capnp(collection_reader));
275        }
276
277        // Read samples
278        let samples_reader = reader.get_samples().unwrap();
279        let mut samples = Vec::new();
280        for sample_reader in samples_reader.iter() {
281            samples.push(Sample::read_capnp(sample_reader));
282        }
283
284        // Read sequences
285        let sequences_reader = reader.get_sequences().unwrap();
286        let mut sequences = Vec::new();
287        for sequence_reader in sequences_reader.iter() {
288            sequences.push(Sequence::read_capnp(sequence_reader));
289        }
290
291        // Read block groups
292        let block_group_reader = reader.get_block_group().unwrap();
293        let mut block_group = Vec::new();
294        for bg_reader in block_group_reader.iter() {
295            block_group.push(BlockGroup::read_capnp(bg_reader));
296        }
297
298        // Read nodes
299        let nodes_reader = reader.get_nodes().unwrap();
300        let mut nodes = Vec::new();
301        for node_reader in nodes_reader.iter() {
302            nodes.push(Node::read_capnp(node_reader));
303        }
304
305        // Read edges
306        let edges_reader = reader.get_edges().unwrap();
307        let mut edges = Vec::new();
308        for edge_reader in edges_reader.iter() {
309            edges.push(Edge::read_capnp(edge_reader));
310        }
311
312        // Read paths
313        let paths_reader = reader.get_paths().unwrap();
314        let mut paths = Vec::new();
315        for path_reader in paths_reader.iter() {
316            paths.push(Path::read_capnp(path_reader));
317        }
318
319        // Read accessions
320        let accessions_reader = reader.get_accessions().unwrap();
321        let mut accessions = Vec::new();
322        for accession_reader in accessions_reader.iter() {
323            accessions.push(Accession::read_capnp(accession_reader));
324        }
325
326        // Read accession edges
327        let accession_edges_reader = reader.get_accession_edges().unwrap();
328        let mut accession_edges = Vec::new();
329        for accession_edge_reader in accession_edges_reader.iter() {
330            accession_edges.push(AccessionEdge::read_capnp(accession_edge_reader));
331        }
332
333        DependencyModels {
334            collections,
335            samples,
336            sequences,
337            block_group,
338            nodes,
339            edges,
340            paths,
341            accessions,
342            accession_edges,
343        }
344    }
345}
346
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;
    use gen_core::Strand;

    use super::*;
    use crate::sequence::NewSequence;

    /// Builds a `DependencyModels` that holds exactly one record of every model type.
    fn one_of_each() -> DependencyModels {
        let collections = vec![Collection {
            name: "test_collection".to_string(),
        }];
        let samples = vec![Sample {
            name: "test_sample".to_string(),
        }];
        let sequences = vec![
            NewSequence::new()
                .sequence_type("DNA")
                .sequence("ATCG")
                .name("test_seq")
                .build(),
        ];
        let block_group = vec![BlockGroup {
            id: HashId::pad_str(1),
            collection_name: "test_collection".to_string(),
            sample_name: "test_sample".to_string(),
            name: "test_bg".to_string(),
            created_on: 0,
            parent_block_group_id: None,
            is_default: false,
        }];
        let nodes = vec![Node {
            id: HashId::convert_str("node_hash"),
            sequence_hash: HashId::convert_str("test_hash"),
        }];
        let edges = vec![Edge {
            id: HashId::pad_str(1),
            source_node_id: HashId::convert_str("1"),
            source_coordinate: 0,
            source_strand: Strand::Forward,
            target_node_id: HashId::convert_str("2"),
            target_coordinate: 0,
            target_strand: Strand::Forward,
        }];
        let paths = vec![Path {
            id: HashId::pad_str(1),
            block_group_id: HashId::pad_str(1),
            name: "test_path".to_string(),
            created_on: 0,
        }];
        let accessions = vec![Accession {
            id: HashId::pad_str(1),
            name: "test_accession".to_string(),
            path_id: HashId::pad_str(1),
            parent_accession_id: None,
        }];
        let accession_edges = vec![AccessionEdge {
            id: HashId::pad_str(1),
            source_node_id: HashId::convert_str("1"),
            source_coordinate: 0,
            source_strand: Strand::Forward,
            target_node_id: HashId::convert_str("2"),
            target_coordinate: 0,
            target_strand: Strand::Forward,
            chromosome_index: 0,
        }];

        DependencyModels {
            collections,
            samples,
            sequences,
            block_group,
            nodes,
            edges,
            paths,
            accessions,
            accession_edges,
        }
    }

    /// Writing a populated `DependencyModels` to capnp and reading it back yields an equal value.
    #[test]
    fn test_dependency_models_capnp_serialization() {
        let expected = one_of_each();

        let mut message = TypedBuilder::<dependency_models::Owned>::new_default();
        let mut root = message.init_root();
        expected.write_capnp(&mut root);

        let round_tripped = DependencyModels::read_capnp(root.into_reader());

        assert_eq!(expected, round_tripped);
    }
}