Skip to main content

gen_models/
annotations.rs

1use std::{
2    collections::HashMap,
3    fs,
4    path::{Path, PathBuf},
5    rc::Rc,
6};
7
8use anyhow::anyhow;
9use gen_core::{HashId, calculate_hash, config::Workspace, region::Region, traits::Capnp};
10use rusqlite::{Row, params, types::Value};
11use serde::{Deserialize, Serialize};
12use thiserror::Error;
13
14use crate::{
15    block_group::{BlockGroup, PathCache},
16    changesets::{ChangesetModels, DatabaseChangeset, write_changeset},
17    db::{DbContext, GraphConnection, OperationsConnection},
18    errors::{FileAdditionError, OperationError},
19    file_types::FileTypes,
20    files::GenDatabase,
21    gen_models_capnp::{annotation, annotation_group, annotation_group_sample},
22    metadata,
23    operations::{FileAddition, Operation, OperationInfo, OperationSummary},
24    sample::Sample,
25    session_operations::{DependencyModels, end_operation, start_operation},
26    traits::Query,
27};
28
29#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
30pub struct AnnotationGroup {
31    pub name: String,
32}
33
34impl Query for AnnotationGroup {
35    type Model = AnnotationGroup;
36
37    const PRIMARY_KEY: &'static str = "name";
38    const TABLE_NAME: &'static str = "annotation_groups";
39
40    fn process_row(row: &Row) -> Self::Model {
41        AnnotationGroup {
42            name: row.get(0).unwrap(),
43        }
44    }
45}
46
47impl AnnotationGroup {
48    pub fn create(conn: &GraphConnection, name: &str) -> rusqlite::Result<AnnotationGroup> {
49        let mut stmt = conn
50            .prepare("INSERT INTO annotation_groups (name) VALUES (?1) returning (name);")
51            .unwrap();
52        stmt.query_row((name,), |row| Ok(AnnotationGroup { name: row.get(0)? }))
53    }
54
55    pub fn get_or_create(
56        conn: &GraphConnection,
57        name: &str,
58    ) -> Result<AnnotationGroup, AnnotationGroupError> {
59        match AnnotationGroup::create(conn, name) {
60            Ok(group) => Ok(group),
61            Err(rusqlite::Error::SqliteFailure(err, _details))
62                if err.code == rusqlite::ErrorCode::ConstraintViolation =>
63            {
64                AnnotationGroup::get_by_id(conn, &name.to_string())
65                    .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
66            }
67            Err(err) => Err(err.into()),
68        }
69    }
70
71    pub fn query_by_sample(conn: &GraphConnection, sample_name: &str) -> Vec<AnnotationGroup> {
72        let query = "\
73            select ag.* \
74            from annotation_groups ag \
75            join annotation_group_samples s \
76                on ag.name = s.annotation_group \
77            where s.sample_name = ?1 \
78            order by ag.name;";
79        AnnotationGroup::query(conn, query, params![sample_name])
80    }
81}
82
83impl<'a> Capnp<'a> for AnnotationGroup {
84    type Builder = annotation_group::Builder<'a>;
85    type Reader = annotation_group::Reader<'a>;
86
87    fn write_capnp(&self, builder: &mut Self::Builder) {
88        builder.set_name(&self.name);
89    }
90
91    fn read_capnp(reader: Self::Reader) -> Self {
92        AnnotationGroup {
93            name: reader.get_name().unwrap().to_string().unwrap(),
94        }
95    }
96}
97
98#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
99pub struct Annotation {
100    pub id: HashId,
101    pub name: String,
102    pub group: String,
103    pub accession_id: HashId,
104}
105
106impl<'a> Capnp<'a> for Annotation {
107    type Builder = annotation::Builder<'a>;
108    type Reader = annotation::Reader<'a>;
109
110    fn write_capnp(&self, builder: &mut Self::Builder) {
111        builder.set_id(&self.id.0).unwrap();
112        builder.set_name(&self.name);
113        builder.set_annotation_group(&self.group);
114        builder.set_accession_id(&self.accession_id.0).unwrap();
115    }
116
117    fn read_capnp(reader: Self::Reader) -> Self {
118        let id = reader
119            .get_id()
120            .unwrap()
121            .as_slice()
122            .unwrap()
123            .try_into()
124            .unwrap();
125        let name = reader.get_name().unwrap().to_string().unwrap();
126        let group = reader.get_annotation_group().unwrap().to_string().unwrap();
127        let accession_id = reader
128            .get_accession_id()
129            .unwrap()
130            .as_slice()
131            .unwrap()
132            .try_into()
133            .unwrap();
134
135        Annotation {
136            id,
137            name,
138            group,
139            accession_id,
140        }
141    }
142}
143
144impl Query for Annotation {
145    type Model = Annotation;
146
147    const TABLE_NAME: &'static str = "annotations";
148
149    fn process_row(row: &Row) -> Self::Model {
150        Annotation {
151            id: row.get(0).unwrap(),
152            name: row.get(1).unwrap(),
153            group: row.get(2).unwrap(),
154            accession_id: row.get(3).unwrap(),
155        }
156    }
157}
158
159#[derive(Debug, Error)]
160pub enum AnnotationError {
161    #[error("Database error: {0}")]
162    DatabaseError(#[from] rusqlite::Error),
163    #[error("Annotation group error: {0}")]
164    AnnotationGroupError(#[from] AnnotationGroupError),
165}
166
167impl Annotation {
168    pub fn generate_id(name: &str, group: &str, accession_id: &HashId) -> HashId {
169        HashId(calculate_hash(&format!("{name}:{group}:{accession_id}",)))
170    }
171
172    pub fn create(
173        conn: &GraphConnection,
174        name: &str,
175        group: &str,
176        accession_id: &HashId,
177    ) -> Result<Annotation, AnnotationError> {
178        let id = Annotation::generate_id(name, group, accession_id);
179        let query = "INSERT INTO annotations (id, name, annotation_group, accession_id) VALUES (?1, ?2, ?3, ?4);";
180        let mut stmt = conn.prepare(query)?;
181        stmt.execute(params![id, name, group, accession_id])?;
182        Ok(Annotation {
183            id,
184            name: name.to_string(),
185            group: group.to_string(),
186            accession_id: *accession_id,
187        })
188    }
189
190    pub fn get_or_create(
191        conn: &GraphConnection,
192        name: &str,
193        group: &str,
194        accession_id: &HashId,
195    ) -> Result<Annotation, AnnotationError> {
196        AnnotationGroup::get_or_create(conn, group)?;
197        match Annotation::create(conn, name, group, accession_id) {
198            Ok(annotation) => Ok(annotation),
199            Err(AnnotationError::DatabaseError(rusqlite::Error::SqliteFailure(err, _details)))
200                if err.code == rusqlite::ErrorCode::ConstraintViolation =>
201            {
202                let id = Annotation::generate_id(name, group, accession_id);
203                Ok(Annotation {
204                    id,
205                    name: name.to_string(),
206                    group: group.to_string(),
207                    accession_id: *accession_id,
208                })
209            }
210            Err(err) => Err(err),
211        }
212    }
213
214    pub fn create_with_samples(
215        conn: &GraphConnection,
216        name: &str,
217        group: &str,
218        accession_id: &HashId,
219        sample_names: &[&str],
220    ) -> Result<Annotation, AnnotationError> {
221        let annotation = Annotation::get_or_create(conn, name, group, accession_id)?;
222        annotation.add_samples(conn, sample_names)?;
223        Ok(annotation)
224    }
225
226    pub fn add_samples(
227        &self,
228        conn: &GraphConnection,
229        sample_names: &[&str],
230    ) -> Result<(), AnnotationError> {
231        if sample_names.is_empty() {
232            return Ok(());
233        }
234        AnnotationGroup::get_or_create(conn, &self.group)?;
235        let query = "INSERT OR IGNORE INTO annotation_group_samples (annotation_group, sample_name) VALUES (?1, ?2);";
236        let mut stmt = conn.prepare(query)?;
237        for sample_name in sample_names {
238            stmt.execute(params![self.group, sample_name])?;
239        }
240        Ok(())
241    }
242
243    pub fn get_samples(
244        conn: &GraphConnection,
245        annotation_group: &str,
246    ) -> Result<Vec<String>, AnnotationError> {
247        let query = "SELECT sample_name FROM annotation_group_samples WHERE annotation_group = ?1;";
248        let mut stmt = conn.prepare(query)?;
249        let rows = stmt.query_map(params![annotation_group], |row| row.get(0))?;
250        let mut samples = Vec::new();
251        for row in rows {
252            samples.push(row?);
253        }
254        Ok(samples)
255    }
256
257    pub fn query_by_sample(
258        conn: &GraphConnection,
259        sample_name: &str,
260    ) -> Result<Vec<Annotation>, AnnotationError> {
261        let query = "select a.* from annotations a left join annotation_group_samples s on (a.annotation_group = s.annotation_group) where s.sample_name = ?1";
262        Ok(Annotation::query(conn, query, params![sample_name]))
263    }
264
265    pub fn query_by_group(
266        conn: &GraphConnection,
267        group: &str,
268    ) -> Result<Vec<Annotation>, AnnotationError> {
269        let query = "select * from annotations where annotation_group = ?1";
270        Ok(Annotation::query(conn, query, params![group]))
271    }
272}
273
274#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
275pub struct AnnotationGroupSample {
276    pub annotation_group: String,
277    pub sample_name: String,
278}
279
280impl<'a> Capnp<'a> for AnnotationGroupSample {
281    type Builder = annotation_group_sample::Builder<'a>;
282    type Reader = annotation_group_sample::Reader<'a>;
283
284    fn write_capnp(&self, builder: &mut Self::Builder) {
285        builder.set_annotation_group(&self.annotation_group);
286        builder.set_sample_name(&self.sample_name);
287    }
288
289    fn read_capnp(reader: Self::Reader) -> Self {
290        let annotation_group = reader.get_annotation_group().unwrap().to_string().unwrap();
291        let sample_name = reader.get_sample_name().unwrap().to_string().unwrap();
292        AnnotationGroupSample {
293            annotation_group,
294            sample_name,
295        }
296    }
297}
298
299impl AnnotationGroupSample {
300    pub fn create(
301        conn: &GraphConnection,
302        annotation_group: &str,
303        sample_name: &str,
304    ) -> Result<(), AnnotationError> {
305        AnnotationGroup::get_or_create(conn, annotation_group)?;
306        let query = "INSERT OR IGNORE INTO annotation_group_samples (annotation_group, sample_name) VALUES (?1, ?2);";
307        let mut stmt = conn.prepare(query)?;
308        stmt.execute(params![annotation_group, sample_name])?;
309        Ok(())
310    }
311
312    pub fn delete(
313        conn: &GraphConnection,
314        annotation_group: &str,
315        sample_name: &str,
316    ) -> Result<(), AnnotationError> {
317        let query = "DELETE FROM annotation_group_samples WHERE annotation_group = ?1 AND sample_name = ?2;";
318        let mut stmt = conn.prepare(query)?;
319        stmt.execute(params![annotation_group, sample_name])?;
320        Ok(())
321    }
322}
323
324#[derive(Debug, Error)]
325pub enum AnnotationGroupError {
326    #[error("Database error: {0}")]
327    DatabaseError(#[from] rusqlite::Error),
328}
329
330#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
331pub struct AnnotationFileInfo {
332    pub file_addition: FileAddition,
333    pub index_file_addition: Option<FileAddition>,
334    pub name: Option<String>,
335}
336
337#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
338pub struct AnnotationFileAdditionInput {
339    pub file_path: String,
340    pub file_type: FileTypes,
341    pub checksum_override: Option<HashId>,
342    pub name: Option<String>,
343    pub index_file_path: Option<String>,
344}
345
346#[derive(Debug, Error)]
347pub enum AnnotationFileError {
348    #[error("Database error: {0}")]
349    DatabaseError(#[from] rusqlite::Error),
350    #[error("File addition error: {0}")]
351    FileAdditionError(#[from] FileAdditionError),
352    #[error("Index file must be Tabix, got: {0:?}")]
353    InvalidIndexFileType(FileTypes),
354    #[error("Unsupported annotation file type: {0}")]
355    UnsupportedFileType(String),
356}
357
358pub fn parse_annotation_file_type(value: &str) -> Result<FileTypes, AnnotationFileError> {
359    match value.trim().to_ascii_lowercase().as_str() {
360        "gff3" | "gff" => Ok(FileTypes::Gff3),
361        "bed" => Ok(FileTypes::Bed),
362        "genbank" | "gb" => Ok(FileTypes::GenBank),
363        other => Err(AnnotationFileError::UnsupportedFileType(other.to_string())),
364    }
365}
366
/// Returns the lowercased extension that identifies an annotation file's format.
///
/// A trailing `gz` or `bgz` extension is skipped, so `genes.gff3.gz` yields
/// `gff3`. Returns `None` when there is no extension, when the extension is
/// not valid UTF-8, or when nothing precedes a compression suffix.
pub fn annotation_file_extension(path: &str) -> Option<String> {
    // Lowercased final extension of a path, if it has one.
    fn lowercase_extension(candidate: &Path) -> Option<String> {
        candidate
            .extension()
            .and_then(|raw| raw.to_str())
            .map(str::to_ascii_lowercase)
    }

    let path = Path::new(path);
    let outer = lowercase_extension(path)?;
    if outer != "gz" && outer != "bgz" {
        return Some(outer);
    }

    // Compressed file: look at the extension of the name without the suffix.
    path.file_stem()
        .and_then(|stem| stem.to_str())
        .and_then(|stem| lowercase_extension(Path::new(stem)))
}
382
383pub fn annotation_index_file_path(
384    workspace: &Workspace,
385    path: &str,
386    explicit_index_path: Option<&str>,
387) -> Option<String> {
388    if let Some(index_path) = explicit_index_path {
389        return Some(index_path.to_string());
390    }
391
392    let mut candidates = vec![format!("{path}.tbi")];
393    let path_buf = PathBuf::from(path);
394    if let Some(extension) = path_buf.extension().and_then(|ext| ext.to_str()) {
395        let mut extension_candidate = path_buf.clone();
396        extension_candidate.set_extension(format!("{extension}.tbi"));
397        let extension_candidate = extension_candidate.to_string_lossy().to_string();
398        if !candidates
399            .iter()
400            .any(|candidate| candidate == &extension_candidate)
401        {
402            candidates.push(extension_candidate);
403        }
404    }
405
406    for candidate in candidates {
407        let exists = if Path::new(&candidate).is_absolute() {
408            Path::new(&candidate).exists()
409        } else {
410            workspace
411                .repo_root()
412                .ok()
413                .is_some_and(|repo_root| repo_root.join(&candidate).exists())
414        };
415        if exists {
416            return Some(candidate);
417        }
418    }
419
420    None
421}
422
423pub fn add_annotation(
424    context: &DbContext,
425    collection: &str,
426    name: &str,
427    group: Option<&str>,
428    sample: &str,
429    region: &str,
430) -> Result<Operation, Box<dyn std::error::Error>> {
431    let graph_conn = context.graph().conn();
432    let operation_conn = context.operations().conn();
433    let parsed_region = Region::parse(region)?;
434    let start = parsed_region.start;
435    let end = parsed_region.end;
436
437    let block_groups = Sample::get_block_groups(graph_conn, collection, sample);
438    let block_group = block_groups
439        .iter()
440        .find(|bg| bg.name == parsed_region.name)
441        .ok_or_else(|| anyhow!("Graph {} not found for sample {sample}", parsed_region.name))?;
442    let path = BlockGroup::get_current_path(graph_conn, &block_group.id);
443    let path_length = path.length(graph_conn);
444    if start < 0 || end < 0 || start > end || end > path_length {
445        return Err(anyhow!("Region {region} is outside the path bounds (0-{path_length})").into());
446    }
447
448    let mut session = start_operation(graph_conn);
449    graph_conn.execute("BEGIN TRANSACTION", [])?;
450    operation_conn.execute("BEGIN TRANSACTION", [])?;
451
452    let mut cache = PathCache::new(graph_conn);
453    let _ = PathCache::lookup(&mut cache, &block_group.id, path.name.clone());
454    let accession = BlockGroup::add_accession(graph_conn, &path, name, start, end, &mut cache);
455
456    let annotation_group = group.unwrap_or("default");
457    let annotation = Annotation::get_or_create(graph_conn, name, annotation_group, &accession.id)?;
458    AnnotationGroupSample::create(graph_conn, &annotation.group, sample)?;
459
460    let operation = end_operation(
461        context,
462        &mut session,
463        &OperationInfo {
464            files: vec![],
465            description: format!("add annotation {name}"),
466        },
467        &format!("add annotation {name}"),
468        None,
469    )?;
470
471    graph_conn.execute("END TRANSACTION", [])?;
472    operation_conn.execute("END TRANSACTION", [])?;
473
474    Ok(operation)
475}
476
477pub fn add_annotation_file(
478    context: &DbContext,
479    path: &str,
480    format: Option<&str>,
481    index: Option<&str>,
482    name: Option<&str>,
483    message: Option<&str>,
484) -> Result<Operation, Box<dyn std::error::Error>> {
485    let workspace = context.workspace();
486    let operation_conn = context.operations().conn();
487    let graph_conn = context.graph().conn();
488    let db_uuid = metadata::get_db_uuid(graph_conn);
489
490    let file_type = match format {
491        Some(format) => parse_annotation_file_type(format)?,
492        None => {
493            let ext = annotation_file_extension(path).ok_or_else(|| {
494                anyhow!(
495                    "Unable to detect annotation file format from the file extension. Use --format to specify it explicitly."
496                )
497            })?;
498            parse_annotation_file_type(&ext)?
499        }
500    };
501    let file_addition =
502        FileAddition::get_or_create(workspace, operation_conn, path, file_type, None)?;
503    let index_file_addition = annotation_index_file_path(workspace, path, index)
504        .map(|index_path| {
505            FileAddition::get_or_create(
506                workspace,
507                operation_conn,
508                &index_path,
509                FileTypes::Tabix,
510                None,
511            )
512        })
513        .transpose()?;
514    let name_value = name.unwrap_or_default();
515    let index_file_addition_id = index_file_addition
516        .as_ref()
517        .map(|index_file| index_file.id.to_string())
518        .unwrap_or_default();
519    let operation_hash = HashId(calculate_hash(&format!(
520        "{file_addition_id}:{name_value}:{index_file_addition_id}",
521        file_addition_id = file_addition.id
522    )));
523    let operation = match Operation::create(operation_conn, "annotation-file", &operation_hash) {
524        Ok(operation) => operation,
525        Err(rusqlite::Error::SqliteFailure(err, _details))
526            if err.code == rusqlite::ErrorCode::ConstraintViolation =>
527        {
528            return Err(OperationError::NoChanges.into());
529        }
530        Err(err) => return Err(err.into()),
531    };
532    AnnotationFile::link_to_operation(
533        operation_conn,
534        &operation.hash,
535        &file_addition.id,
536        index_file_addition
537            .as_ref()
538            .map(|index_file| &index_file.id),
539        name,
540    )?;
541    Operation::add_database(operation_conn, &operation.hash, &db_uuid)?;
542    let summary = message
543        .map(str::to_string)
544        .unwrap_or_else(|| format!("Add annotation file {path}"));
545    OperationSummary::create(operation_conn, &operation.hash, &summary);
546
547    let gen_db = GenDatabase::get_by_uuid(operation_conn, &db_uuid)?;
548    write_changeset(
549        workspace,
550        &operation,
551        DatabaseChangeset {
552            db_path: gen_db.path,
553            changes: ChangesetModels::default(),
554        },
555        &DependencyModels::default(),
556    );
557
558    if file_type != FileTypes::Changeset && file_type != FileTypes::None {
559        let gen_dir = workspace
560            .find_gen_dir()
561            .ok_or_else(|| anyhow!("No .gen directory found. Please run 'gen init' first."))?;
562        let assets_dir = gen_dir.join("assets");
563        fs::create_dir_all(&assets_dir)?;
564        let asset_path = assets_dir.join(file_addition.hashed_filename());
565        if !asset_path.exists() {
566            let source_path = if Path::new(path).is_absolute() {
567                PathBuf::from(path)
568            } else {
569                workspace.repo_root()?.join(path)
570            };
571            fs::copy(source_path, asset_path)?;
572        }
573        if let Some(index_file_addition) = index_file_addition {
574            let index_asset_path = assets_dir.join(index_file_addition.clone().hashed_filename());
575            if !index_asset_path.exists() {
576                let index_source_path = if Path::new(&index_file_addition.file_path).is_absolute() {
577                    PathBuf::from(&index_file_addition.file_path)
578                } else {
579                    workspace.repo_root()?.join(&index_file_addition.file_path)
580                };
581                fs::copy(index_source_path, index_asset_path)?;
582            }
583        }
584    }
585
586    Ok(operation)
587}
588
589#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
590pub struct AnnotationFile {
591    pub id: i64,
592    pub operation_hash: HashId,
593    pub file_addition_id: HashId,
594    pub index_file_addition_id: Option<HashId>,
595    pub name: Option<String>,
596}
597
598impl Query for AnnotationFile {
599    type Model = AnnotationFile;
600
601    const TABLE_NAME: &'static str = "annotation_files";
602
603    fn process_row(row: &Row) -> Self::Model {
604        AnnotationFile {
605            id: row.get(0).unwrap(),
606            operation_hash: row.get(1).unwrap(),
607            file_addition_id: row.get(2).unwrap(),
608            index_file_addition_id: row.get(3).unwrap(),
609            name: row.get(4).unwrap(),
610        }
611    }
612}
613
614impl AnnotationFile {
615    pub fn load_index(
616        conn: &OperationsConnection,
617        file_addition_id: Option<&HashId>,
618    ) -> Result<Option<FileAddition>, AnnotationFileError> {
619        let Some(file_addition_id) = file_addition_id else {
620            return Ok(None);
621        };
622        let index_file_addition = FileAddition::get_by_id(conn, file_addition_id).ok_or(
623            AnnotationFileError::DatabaseError(rusqlite::Error::QueryReturnedNoRows),
624        )?;
625        if index_file_addition.file_type != FileTypes::Tabix {
626            return Err(AnnotationFileError::InvalidIndexFileType(
627                index_file_addition.file_type,
628            ));
629        }
630        Ok(Some(index_file_addition))
631    }
632
633    pub fn link_to_operation(
634        conn: &OperationsConnection,
635        operation_hash: &HashId,
636        file_addition_id: &HashId,
637        index_file_addition_id: Option<&HashId>,
638        name: Option<&str>,
639    ) -> Result<(), AnnotationFileError> {
640        AnnotationFile::load_index(conn, index_file_addition_id)?;
641        let query = "INSERT INTO annotation_files (operation_hash, file_addition_id, index_file_addition_id, name) VALUES (?1, ?2, ?3, ?4);";
642        let mut stmt = conn.prepare(query)?;
643        stmt.execute(params![
644            operation_hash,
645            file_addition_id,
646            index_file_addition_id,
647            name
648        ])?;
649        Ok(())
650    }
651
652    pub fn add_to_operation(
653        workspace: &Workspace,
654        conn: &OperationsConnection,
655        operation_hash: &HashId,
656        input: &AnnotationFileAdditionInput,
657    ) -> Result<FileAddition, AnnotationFileError> {
658        let file_addition = FileAddition::get_or_create(
659            workspace,
660            conn,
661            &input.file_path,
662            input.file_type,
663            input.checksum_override,
664        )?;
665        let index_file_addition = input
666            .index_file_path
667            .as_deref()
668            .map(|path| FileAddition::get_or_create(workspace, conn, path, FileTypes::Tabix, None))
669            .transpose()?;
670        AnnotationFile::link_to_operation(
671            conn,
672            operation_hash,
673            &file_addition.id,
674            index_file_addition.as_ref().map(|index| &index.id),
675            input.name.as_deref(),
676        )?;
677        Ok(file_addition)
678    }
679
680    pub fn get_files_for_operation(
681        conn: &OperationsConnection,
682        operation_hash: &HashId,
683    ) -> Vec<AnnotationFileInfo> {
684        let query = "select fa.*, af.index_file_addition_id, af.name from file_additions fa join annotation_files af on (fa.id = af.file_addition_id) where af.operation_hash = ?1";
685        let mut stmt = conn.prepare(query).unwrap();
686        let rows = stmt
687            .query_map(params![operation_hash], |row| {
688                Ok((
689                    FileAddition::process_row(row),
690                    row.get::<_, Option<HashId>>(4)?,
691                    row.get::<_, Option<String>>(5)?,
692                ))
693            })
694            .unwrap();
695        rows.map(|row| {
696            let (file_addition, index_file_addition_id, name) = row.unwrap();
697            AnnotationFileInfo {
698                file_addition,
699                index_file_addition: AnnotationFile::load_index(
700                    conn,
701                    index_file_addition_id.as_ref(),
702                )
703                .unwrap(),
704                name,
705            }
706        })
707        .collect()
708    }
709
710    pub fn get_all_files(conn: &OperationsConnection) -> Vec<AnnotationFileInfo> {
711        let query = "select fa.*, af.index_file_addition_id, af.name from file_additions fa join annotation_files af on (fa.id = af.file_addition_id)";
712        let mut stmt = conn.prepare(query).unwrap();
713        let rows = stmt
714            .query_map([], |row| {
715                Ok((
716                    FileAddition::process_row(row),
717                    row.get::<_, Option<HashId>>(4)?,
718                    row.get::<_, Option<String>>(5)?,
719                ))
720            })
721            .unwrap();
722        let mut entries: Vec<AnnotationFileInfo> = rows
723            .map(|row| {
724                let (file_addition, index_file_addition_id, name) = row.unwrap();
725                AnnotationFileInfo {
726                    file_addition,
727                    index_file_addition: AnnotationFile::load_index(
728                        conn,
729                        index_file_addition_id.as_ref(),
730                    )
731                    .unwrap(),
732                    name,
733                }
734            })
735            .collect();
736        entries.sort_by(|a, b| {
737            let a_name = std::path::Path::new(&a.file_addition.file_path)
738                .file_name()
739                .map(|name| name.to_string_lossy().to_string())
740                .unwrap_or_else(|| a.file_addition.file_path.clone());
741            let b_name = std::path::Path::new(&b.file_addition.file_path)
742                .file_name()
743                .map(|name| name.to_string_lossy().to_string())
744                .unwrap_or_else(|| b.file_addition.file_path.clone());
745            a_name
746                .cmp(&b_name)
747                .then_with(|| a.file_addition.file_path.cmp(&b.file_addition.file_path))
748        });
749        entries
750    }
751
752    pub fn query_by_operations(
753        conn: &OperationsConnection,
754        operations: &[HashId],
755    ) -> Result<HashMap<HashId, Vec<FileAddition>>, AnnotationFileError> {
756        let query = "select fa.*, af.operation_hash from file_additions fa left join annotation_files af on (fa.id = af.file_addition_id) where af.operation_hash in rarray(?1)";
757        let mut stmt = conn.prepare(query)?;
758        let rows = stmt.query_map(
759            params![Rc::new(
760                operations
761                    .iter()
762                    .map(|h| Value::from(*h))
763                    .collect::<Vec<Value>>()
764            )],
765            |row| Ok((FileAddition::process_row(row), row.get::<_, HashId>(4)?)),
766        )?;
767        rows.into_iter()
768            .try_fold(HashMap::new(), |mut acc: HashMap<_, Vec<_>>, row| {
769                let (item, hash) = row?;
770                acc.entry(hash).or_default().push(item);
771                Ok(acc)
772            })
773            .map_err(AnnotationFileError::DatabaseError)
774    }
775}
776
#[cfg(test)]
mod tests {
    use std::fs;

    use gen_core::HashId;

    use super::*;
    use crate::{
        block_group::{BlockGroup, PathCache},
        errors::OperationError,
        files::GenDatabase,
        metadata,
        sample::Sample,
        test_helpers::{get_connection, setup_block_group, setup_gen},
    };

    /// An annotation linked to two samples is visible by sample and by group.
    #[test]
    fn create_annotation_with_samples() {
        let conn = get_connection(None).unwrap();
        let (block_group_id, path) = setup_block_group(&conn);

        for sample_name in ["sample-1", "sample-2"] {
            let _ = Sample::create(&conn, sample_name).unwrap();
        }

        let mut path_cache = PathCache::new(&conn);
        let _ = PathCache::lookup(&mut path_cache, &block_group_id, path.name.clone());
        let accession =
            BlockGroup::add_accession(&conn, &path, "ann-accession", 0, 5, &mut path_cache);

        let annotation =
            Annotation::get_or_create(&conn, "gene-a", "project-tracks", &accession.id).unwrap();
        annotation
            .add_samples(&conn, &["sample-1", "sample-2"])
            .unwrap();

        let mut linked_samples = Annotation::get_samples(&conn, &annotation.group).unwrap();
        linked_samples.sort();
        let expected_samples = vec!["sample-1".to_string(), "sample-2".to_string()];
        assert_eq!(linked_samples, expected_samples);

        let found_by_sample = Annotation::query_by_sample(&conn, "sample-1").unwrap();
        assert_eq!(found_by_sample.len(), 1);
        assert_eq!(found_by_sample[0], annotation);

        let found_by_group = Annotation::query_by_group(&conn, "project-tracks").unwrap();
        assert_eq!(found_by_group, vec![annotation]);
    }

    /// A file added to an existing operation is returned for that operation.
    #[test]
    fn add_annotation_file_to_operation() {
        let context = setup_gen();
        let op_conn = context.operations().conn();
        let workspace = context.workspace();
        let fixture_path = workspace
            .repo_root()
            .unwrap()
            .join("fixtures")
            .join("annotation.gff3");
        fs::create_dir_all(fixture_path.parent().unwrap()).unwrap();
        fs::write(&fixture_path, "##gff-version 3\n").unwrap();

        let op_hash = HashId::random_str();
        let _ = crate::operations::Operation::create(op_conn, "annotation-file", &op_hash)
            .expect("should create operation");

        let input = AnnotationFileAdditionInput {
            file_path: fixture_path.to_string_lossy().to_string(),
            file_type: FileTypes::Gff3,
            checksum_override: None,
            name: Some("fixtures-annotation".to_string()),
            index_file_path: None,
        };
        let file_addition =
            AnnotationFile::add_to_operation(workspace, op_conn, &op_hash, &input).unwrap();

        let attached = AnnotationFile::get_files_for_operation(op_conn, &op_hash);
        assert_eq!(attached.len(), 1);
        assert_eq!(attached[0].file_addition, file_addition);
        assert!(attached[0].index_file_addition.is_none());
    }

    /// Format names are matched case-insensitively; unknown names are rejected.
    #[test]
    fn parse_annotation_file_type_values() {
        assert_eq!(parse_annotation_file_type("gff3").unwrap(), FileTypes::Gff3);
        assert_eq!(parse_annotation_file_type("GFF").unwrap(), FileTypes::Gff3);
        assert_eq!(parse_annotation_file_type("bed").unwrap(), FileTypes::Bed);
        assert_eq!(
            parse_annotation_file_type("GenBank").unwrap(),
            FileTypes::GenBank
        );
        assert_eq!(
            parse_annotation_file_type("gb").unwrap(),
            FileTypes::GenBank
        );
        let unsupported = parse_annotation_file_type("bam").unwrap_err();
        assert!(matches!(
            unsupported,
            AnnotationFileError::UnsupportedFileType(_)
        ));
    }

    /// `add_annotation` records an operation and stores the annotation.
    #[test]
    fn add_annotation_creates_annotation() {
        let context = setup_gen();
        let graph_conn = context.graph().conn();
        let operation_conn = context.operations().conn();
        let db_uuid = metadata::get_db_uuid(graph_conn);
        let _ = GenDatabase::create(operation_conn, &db_uuid, "test-db", "test-db-path").unwrap();
        let _ = setup_block_group(graph_conn);

        let operation = add_annotation(
            &context,
            "test",
            "gene-a",
            Some("track-1"),
            "test",
            "chr1:1-5",
        )
        .unwrap();
        assert_eq!(operation.change_type, "add annotation gene-a");

        let stored = Annotation::query_by_group(graph_conn, "track-1").unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].name, "gene-a");
    }

    /// Adding a file creates an operation; adding it again reports no changes.
    #[test]
    fn add_annotation_file_creates_operation() {
        let context = setup_gen();
        let graph_conn = context.graph().conn();
        let operation_conn = context.operations().conn();
        let db_uuid = metadata::get_db_uuid(graph_conn);
        let _ = GenDatabase::create(operation_conn, &db_uuid, "test-db", "test-db-path").unwrap();

        let fixture_path = context
            .workspace()
            .repo_root()
            .unwrap()
            .join("fixtures")
            .join("annotation.gff3");
        fs::create_dir_all(fixture_path.parent().unwrap()).unwrap();
        fs::write(&fixture_path, "##gff-version 3\n").unwrap();
        let fixture_path_str = fixture_path.to_string_lossy().to_string();

        let operation = add_annotation_file(
            &context,
            &fixture_path_str,
            None,
            None,
            Some("track-1"),
            None,
        )
        .unwrap();
        assert_eq!(operation.change_type, "annotation-file");

        let attached = AnnotationFile::get_files_for_operation(operation_conn, &operation.hash);
        assert_eq!(attached.len(), 1);
        assert_eq!(attached[0].name.as_deref(), Some("track-1"));

        // Same file, name and (absent) index hash to the same operation.
        let repeat_err = add_annotation_file(
            &context,
            &fixture_path_str,
            None,
            None,
            Some("track-1"),
            None,
        )
        .unwrap_err();
        let op_err = repeat_err
            .downcast_ref::<OperationError>()
            .expect("should be an OperationError");
        assert_eq!(*op_err, OperationError::NoChanges);
    }
}
945}