Skip to main content

gen_models/
session_operations.rs

1use std::{
2    fs,
3    path::{Path as FilePath, PathBuf},
4    str,
5};
6
7use gen_core::{HashId, errors::ConfigError, traits::Capnp};
8use rusqlite::session;
9use serde::{Deserialize, Serialize};
10use sha2::{Digest, Sha256};
11
12use crate::{
13    accession::{Accession, AccessionEdge},
14    block_group::BlockGroup,
15    changesets::{DatabaseChangeset, process_changesetiter, write_changeset},
16    collection::Collection,
17    db::{DbContext, GraphConnection},
18    edge::Edge,
19    errors::OperationError,
20    file_types::FileTypes,
21    files::GenDatabase,
22    gen_models_capnp::dependency_models,
23    metadata::{self, get_db_uuid},
24    node::Node,
25    operations::{FileAddition, Operation, OperationInfo, OperationState, OperationSummary},
26    path::Path,
27    sample::Sample,
28    sequence::Sequence,
29};
30
31pub fn start_operation(conn: &GraphConnection) -> session::Session<'_> {
32    let mut session = session::Session::new(conn).unwrap();
33    attach_session(&mut session);
34    session
35}
36
37#[allow(clippy::too_many_arguments)]
38pub fn end_operation(
39    context: &DbContext,
40    session: &mut session::Session,
41    operation_info: &OperationInfo,
42    summary_str: &str,
43    force_hash: impl Into<Option<HashId>>,
44) -> Result<Operation, OperationError> {
45    let conn = context.graph().conn();
46    let operation_conn = context.operations().conn();
47    let db_uuid = metadata::get_db_uuid(conn);
48    // determine if this operation has already happened
49    let mut output = Vec::new();
50    session.changeset_strm(&mut output).unwrap();
51
52    let (changeset_models, dependencies) = process_changesetiter(conn, &output);
53
54    let hash = if let Some(hash) = force_hash.into() {
55        hash
56    } else {
57        if output.is_empty() {
58            return Err(OperationError::NoChanges);
59        }
60        let mut hasher = Sha256::new();
61        hasher.update(&db_uuid[..]);
62        hasher.update(&output[..]);
63        HashId(hasher.finalize().into())
64    };
65
66    operation_conn
67        .execute("SAVEPOINT new_operation;", [])
68        .unwrap();
69
70    match Operation::create(operation_conn, &operation_info.description, &hash) {
71        Ok(operation) => {
72            let gen_dir = match context.workspace().find_gen_dir() {
73                Some(dir) => dir,
74                None => {
75                    return Err(OperationError::ConfigError(
76                        ConfigError::GenDirectoryNotFound,
77                    ));
78                }
79            };
80            let assets_dir = FilePath::new(&gen_dir).join("assets");
81            fs::create_dir_all(&assets_dir).map_err(|_| OperationError::IOError)?;
82
83            for op_file in operation_info.files.iter() {
84                let fa = match FileAddition::get_or_create(
85                    context.workspace(),
86                    operation_conn,
87                    &op_file.file_path,
88                    op_file.file_type,
89                    None,
90                ) {
91                    Ok(fa) => fa,
92                    Err(err) => return Err(OperationError::SQLError(format!("{err}"))),
93                };
94                Operation::add_file(operation_conn, &operation.hash, &fa.id)
95                    .map_err(|err| OperationError::SQLError(format!("{err}")))?;
96                if fa.file_type != FileTypes::Changeset && fa.file_type != FileTypes::None {
97                    let asset_destination_path = assets_dir.join(fa.hashed_filename());
98                    if !asset_destination_path.exists() {
99                        let source_path = if FilePath::new(&op_file.file_path).is_absolute() {
100                            PathBuf::from(&op_file.file_path)
101                        } else {
102                            context
103                                .workspace()
104                                .repo_root()
105                                .map_err(OperationError::ConfigError)?
106                                .join(&op_file.file_path)
107                        };
108                        match fs::copy(source_path, asset_destination_path) {
109                            Ok(result) => result,
110                            Err(_) => return Err(OperationError::IOError),
111                        };
112                    }
113                }
114            }
115            Operation::add_database(operation_conn, &operation.hash, &db_uuid)
116                .map_err(|err| OperationError::SQLError(format!("{err}")))?;
117            OperationSummary::create(operation_conn, &operation.hash, summary_str);
118            let db_uuid = get_db_uuid(conn);
119            let gen_db = GenDatabase::get_by_uuid(operation_conn, &db_uuid).unwrap();
120            write_changeset(
121                context.workspace(),
122                &operation,
123                DatabaseChangeset {
124                    db_path: gen_db.path,
125                    changes: changeset_models,
126                },
127                &dependencies,
128            );
129            OperationState::set_operation(operation_conn, &operation.hash);
130            operation_conn
131                .execute("RELEASE SAVEPOINT new_operation;", [])
132                .unwrap();
133            Ok(operation)
134        }
135        Err(rusqlite::Error::SqliteFailure(err, details)) => {
136            operation_conn
137                .execute("ROLLBACK TRANSACTION TO SAVEPOINT new_operation;", [])
138                .unwrap();
139            if err.code == rusqlite::ErrorCode::ConstraintViolation {
140                Err(OperationError::OperationExists)
141            } else {
142                panic!("something bad happened querying the database {details:?}");
143            }
144        }
145        Err(e) => {
146            operation_conn
147                .execute("ROLLBACK TRANSACTION TO SAVEPOINT new_operation;", [])
148                .unwrap();
149            panic!("something bad happened querying the database {e:?}");
150        }
151    }
152}
153
154pub fn attach_session(session: &mut session::Session) {
155    for table in [
156        "collections",
157        "samples",
158        "sequences",
159        "block_groups",
160        "paths",
161        "nodes",
162        "edges",
163        "path_edges",
164        "block_group_edges",
165        "accessions",
166        "accession_edges",
167        "accession_paths",
168        "annotation_groups",
169        "annotations",
170        "annotation_group_samples",
171    ] {
172        session.attach(Some(table)).unwrap();
173    }
174}
175
176#[derive(Default, Deserialize, Serialize, Debug, PartialEq)]
177pub struct DependencyModels {
178    pub collections: Vec<Collection>,
179    pub samples: Vec<Sample>,
180    pub sequences: Vec<Sequence>,
181    pub block_group: Vec<BlockGroup>,
182    pub nodes: Vec<Node>,
183    pub edges: Vec<Edge>,
184    pub paths: Vec<Path>,
185    pub accessions: Vec<Accession>,
186    pub accession_edges: Vec<AccessionEdge>,
187}
188
189impl<'a> Capnp<'a> for DependencyModels {
190    type Builder = dependency_models::Builder<'a>;
191    type Reader = dependency_models::Reader<'a>;
192
193    fn write_capnp(&self, builder: &mut Self::Builder) {
194        // Write collections
195        let mut collections_builder = builder
196            .reborrow()
197            .init_collections(self.collections.len() as u32);
198        for (i, collection) in self.collections.iter().enumerate() {
199            let mut collection_builder = collections_builder.reborrow().get(i as u32);
200            collection.write_capnp(&mut collection_builder);
201        }
202
203        // Write samples
204        let mut samples_builder = builder.reborrow().init_samples(self.samples.len() as u32);
205        for (i, sample) in self.samples.iter().enumerate() {
206            let mut sample_builder = samples_builder.reborrow().get(i as u32);
207            sample.write_capnp(&mut sample_builder);
208        }
209
210        // Write sequences
211        let mut sequences_builder = builder
212            .reborrow()
213            .init_sequences(self.sequences.len() as u32);
214        for (i, sequence) in self.sequences.iter().enumerate() {
215            let mut sequence_builder = sequences_builder.reborrow().get(i as u32);
216            sequence.write_capnp(&mut sequence_builder);
217        }
218
219        // Write block groups (note: field name is blockGroup in capnp schema)
220        let mut block_group_builder = builder
221            .reborrow()
222            .init_block_group(self.block_group.len() as u32);
223        for (i, block_group) in self.block_group.iter().enumerate() {
224            let mut bg_builder = block_group_builder.reborrow().get(i as u32);
225            block_group.write_capnp(&mut bg_builder);
226        }
227
228        // Write nodes
229        let mut nodes_builder = builder.reborrow().init_nodes(self.nodes.len() as u32);
230        for (i, node) in self.nodes.iter().enumerate() {
231            let mut node_builder = nodes_builder.reborrow().get(i as u32);
232            node.write_capnp(&mut node_builder);
233        }
234
235        // Write edges
236        let mut edges_builder = builder.reborrow().init_edges(self.edges.len() as u32);
237        for (i, edge) in self.edges.iter().enumerate() {
238            let mut edge_builder = edges_builder.reborrow().get(i as u32);
239            edge.write_capnp(&mut edge_builder);
240        }
241
242        // Write paths
243        let mut paths_builder = builder.reborrow().init_paths(self.paths.len() as u32);
244        for (i, path) in self.paths.iter().enumerate() {
245            let mut path_builder = paths_builder.reborrow().get(i as u32);
246            path.write_capnp(&mut path_builder);
247        }
248
249        // Write accessions
250        let mut accessions_builder = builder
251            .reborrow()
252            .init_accessions(self.accessions.len() as u32);
253        for (i, accession) in self.accessions.iter().enumerate() {
254            let mut accession_builder = accessions_builder.reborrow().get(i as u32);
255            accession.write_capnp(&mut accession_builder);
256        }
257
258        // Write accession edges (note: field name is accessionEdges in capnp schema)
259        let mut accession_edges_builder = builder
260            .reborrow()
261            .init_accession_edges(self.accession_edges.len() as u32);
262        for (i, accession_edge) in self.accession_edges.iter().enumerate() {
263            let mut accession_edge_builder = accession_edges_builder.reborrow().get(i as u32);
264            accession_edge.write_capnp(&mut accession_edge_builder);
265        }
266    }
267
268    fn read_capnp(reader: Self::Reader) -> Self {
269        // Read collections
270        let collections_reader = reader.get_collections().unwrap();
271        let mut collections = Vec::new();
272        for collection_reader in collections_reader.iter() {
273            collections.push(Collection::read_capnp(collection_reader));
274        }
275
276        // Read samples
277        let samples_reader = reader.get_samples().unwrap();
278        let mut samples = Vec::new();
279        for sample_reader in samples_reader.iter() {
280            samples.push(Sample::read_capnp(sample_reader));
281        }
282
283        // Read sequences
284        let sequences_reader = reader.get_sequences().unwrap();
285        let mut sequences = Vec::new();
286        for sequence_reader in sequences_reader.iter() {
287            sequences.push(Sequence::read_capnp(sequence_reader));
288        }
289
290        // Read block groups
291        let block_group_reader = reader.get_block_group().unwrap();
292        let mut block_group = Vec::new();
293        for bg_reader in block_group_reader.iter() {
294            block_group.push(BlockGroup::read_capnp(bg_reader));
295        }
296
297        // Read nodes
298        let nodes_reader = reader.get_nodes().unwrap();
299        let mut nodes = Vec::new();
300        for node_reader in nodes_reader.iter() {
301            nodes.push(Node::read_capnp(node_reader));
302        }
303
304        // Read edges
305        let edges_reader = reader.get_edges().unwrap();
306        let mut edges = Vec::new();
307        for edge_reader in edges_reader.iter() {
308            edges.push(Edge::read_capnp(edge_reader));
309        }
310
311        // Read paths
312        let paths_reader = reader.get_paths().unwrap();
313        let mut paths = Vec::new();
314        for path_reader in paths_reader.iter() {
315            paths.push(Path::read_capnp(path_reader));
316        }
317
318        // Read accessions
319        let accessions_reader = reader.get_accessions().unwrap();
320        let mut accessions = Vec::new();
321        for accession_reader in accessions_reader.iter() {
322            accessions.push(Accession::read_capnp(accession_reader));
323        }
324
325        // Read accession edges
326        let accession_edges_reader = reader.get_accession_edges().unwrap();
327        let mut accession_edges = Vec::new();
328        for accession_edge_reader in accession_edges_reader.iter() {
329            accession_edges.push(AccessionEdge::read_capnp(accession_edge_reader));
330        }
331
332        DependencyModels {
333            collections,
334            samples,
335            sequences,
336            block_group,
337            nodes,
338            edges,
339            paths,
340            accessions,
341            accession_edges,
342        }
343    }
344}
345
#[cfg(test)]
mod tests {
    use capnp::message::TypedBuilder;
    use gen_core::Strand;

    use super::*;
    use crate::sequence::NewSequence;

    /// Writing a populated `DependencyModels` to a capnp message and reading it
    /// back must yield an equal value.
    #[test]
    fn test_dependency_models_capnp_serialization() {
        // One record per model type so every list is exercised.
        let collections = vec![Collection {
            name: "test_collection".to_string(),
        }];
        let samples = vec![Sample {
            name: "test_sample".to_string(),
        }];
        let sequences = vec![
            NewSequence::new()
                .sequence_type("DNA")
                .sequence("ATCG")
                .name("test_seq")
                .build(),
        ];
        let block_group = vec![BlockGroup {
            id: HashId::pad_str(1),
            collection_name: "test_collection".to_string(),
            sample_name: Some("test_sample".to_string()),
            name: "test_bg".to_string(),
            created_on: 0,
        }];
        let nodes = vec![Node {
            id: HashId::convert_str("node_hash"),
            sequence_hash: HashId::convert_str("test_hash"),
        }];
        let edges = vec![Edge {
            id: HashId::pad_str(1),
            source_node_id: HashId::convert_str("1"),
            source_coordinate: 0,
            source_strand: Strand::Forward,
            target_node_id: HashId::convert_str("2"),
            target_coordinate: 0,
            target_strand: Strand::Forward,
        }];
        let paths = vec![Path {
            id: HashId::pad_str(1),
            block_group_id: HashId::pad_str(1),
            name: "test_path".to_string(),
            created_on: 0,
        }];
        let accessions = vec![Accession {
            id: HashId::pad_str(1),
            name: "test_accession".to_string(),
            path_id: HashId::pad_str(1),
            parent_accession_id: None,
        }];
        let accession_edges = vec![AccessionEdge {
            id: HashId::pad_str(1),
            source_node_id: HashId::convert_str("1"),
            source_coordinate: 0,
            source_strand: Strand::Forward,
            target_node_id: HashId::convert_str("2"),
            target_coordinate: 0,
            target_strand: Strand::Forward,
            chromosome_index: 0,
        }];

        let expected = DependencyModels {
            collections,
            samples,
            sequences,
            block_group,
            nodes,
            edges,
            paths,
            accessions,
            accession_edges,
        };

        // Serialize into a fresh message, then read straight back from it.
        let mut message = TypedBuilder::<dependency_models::Owned>::new_default();
        let mut root = message.init_root();
        expected.write_capnp(&mut root);

        let round_tripped = DependencyModels::read_capnp(root.into_reader());

        assert_eq!(expected, round_tripped);
    }
}