// gen_models/annotations.rs

1use std::{
2    collections::HashMap,
3    fs,
4    path::{Path, PathBuf},
5    rc::Rc,
6};
7
8use anyhow::anyhow;
9use gen_core::{HashId, calculate_hash, config::Workspace, traits::Capnp};
10use noodles::core::Region;
11use rusqlite::{Row, params, types::Value};
12use serde::{Deserialize, Serialize};
13use thiserror::Error;
14
15use crate::{
16    block_group::{BlockGroup, PathCache},
17    changesets::{ChangesetModels, DatabaseChangeset, write_changeset},
18    db::{DbContext, GraphConnection, OperationsConnection},
19    errors::{FileAdditionError, OperationError},
20    file_types::FileTypes,
21    files::GenDatabase,
22    gen_models_capnp::{annotation, annotation_group, annotation_group_sample},
23    metadata,
24    operations::{FileAddition, Operation, OperationInfo, OperationSummary},
25    sample::Sample,
26    session_operations::{DependencyModels, end_operation, start_operation},
27    traits::Query,
28};
29
/// A named group of annotations, persisted as a row of the `annotation_groups` table.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct AnnotationGroup {
    /// Name of the group; the `Query` implementation uses the `name` column as primary key.
    pub name: String,
}
34
35impl Query for AnnotationGroup {
36    type Model = AnnotationGroup;
37
38    const PRIMARY_KEY: &'static str = "name";
39    const TABLE_NAME: &'static str = "annotation_groups";
40
41    fn process_row(row: &Row) -> Self::Model {
42        AnnotationGroup {
43            name: row.get(0).unwrap(),
44        }
45    }
46}
47
48impl AnnotationGroup {
49    pub fn create(conn: &GraphConnection, name: &str) -> rusqlite::Result<AnnotationGroup> {
50        let mut stmt = conn
51            .prepare("INSERT INTO annotation_groups (name) VALUES (?1) returning (name);")
52            .unwrap();
53        stmt.query_row((name,), |row| Ok(AnnotationGroup { name: row.get(0)? }))
54    }
55
56    pub fn get_or_create(
57        conn: &GraphConnection,
58        name: &str,
59    ) -> Result<AnnotationGroup, AnnotationGroupError> {
60        match AnnotationGroup::create(conn, name) {
61            Ok(group) => Ok(group),
62            Err(rusqlite::Error::SqliteFailure(err, _details))
63                if err.code == rusqlite::ErrorCode::ConstraintViolation =>
64            {
65                AnnotationGroup::get_by_id(conn, &name.to_string())
66                    .ok_or_else(|| rusqlite::Error::QueryReturnedNoRows.into())
67            }
68            Err(err) => Err(err.into()),
69        }
70    }
71
72    pub fn query_by_sample(conn: &GraphConnection, sample_name: &str) -> Vec<AnnotationGroup> {
73        let query = "\
74            select ag.* \
75            from annotation_groups ag \
76            join annotation_group_samples s \
77                on ag.name = s.annotation_group \
78            where s.sample_name = ?1 \
79            order by ag.name;";
80        AnnotationGroup::query(conn, query, params![sample_name])
81    }
82}
83
84impl<'a> Capnp<'a> for AnnotationGroup {
85    type Builder = annotation_group::Builder<'a>;
86    type Reader = annotation_group::Reader<'a>;
87
88    fn write_capnp(&self, builder: &mut Self::Builder) {
89        builder.set_name(&self.name);
90    }
91
92    fn read_capnp(reader: Self::Reader) -> Self {
93        AnnotationGroup {
94            name: reader.get_name().unwrap().to_string().unwrap(),
95        }
96    }
97}
98
/// A named annotation attached to an accession, persisted as a row of the `annotations` table.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct Annotation {
    /// Identifier; `Annotation::generate_id` derives it from name, group and accession id.
    pub id: HashId,
    /// Name of the annotation.
    pub name: String,
    /// Name of the annotation group; written to the `annotation_group` column.
    pub group: String,
    /// Identifier of the accession this annotation is attached to.
    pub accession_id: HashId,
}
106
107impl<'a> Capnp<'a> for Annotation {
108    type Builder = annotation::Builder<'a>;
109    type Reader = annotation::Reader<'a>;
110
111    fn write_capnp(&self, builder: &mut Self::Builder) {
112        builder.set_id(&self.id.0).unwrap();
113        builder.set_name(&self.name);
114        builder.set_annotation_group(&self.group);
115        builder.set_accession_id(&self.accession_id.0).unwrap();
116    }
117
118    fn read_capnp(reader: Self::Reader) -> Self {
119        let id = reader
120            .get_id()
121            .unwrap()
122            .as_slice()
123            .unwrap()
124            .try_into()
125            .unwrap();
126        let name = reader.get_name().unwrap().to_string().unwrap();
127        let group = reader.get_annotation_group().unwrap().to_string().unwrap();
128        let accession_id = reader
129            .get_accession_id()
130            .unwrap()
131            .as_slice()
132            .unwrap()
133            .try_into()
134            .unwrap();
135
136        Annotation {
137            id,
138            name,
139            group,
140            accession_id,
141        }
142    }
143}
144
145impl Query for Annotation {
146    type Model = Annotation;
147
148    const TABLE_NAME: &'static str = "annotations";
149
150    fn process_row(row: &Row) -> Self::Model {
151        Annotation {
152            id: row.get(0).unwrap(),
153            name: row.get(1).unwrap(),
154            group: row.get(2).unwrap(),
155            accession_id: row.get(3).unwrap(),
156        }
157    }
158}
159
/// Errors returned by the [`Annotation`] and [`AnnotationGroupSample`] operations.
#[derive(Debug, Error)]
pub enum AnnotationError {
    /// A `rusqlite` call failed.
    #[error("Database error: {0}")]
    DatabaseError(#[from] rusqlite::Error),
    /// Looking up or creating the annotation group failed.
    #[error("Annotation group error: {0}")]
    AnnotationGroupError(#[from] AnnotationGroupError),
}
167
168impl Annotation {
169    pub fn generate_id(name: &str, group: &str, accession_id: &HashId) -> HashId {
170        HashId(calculate_hash(&format!("{name}:{group}:{accession_id}",)))
171    }
172
173    pub fn create(
174        conn: &GraphConnection,
175        name: &str,
176        group: &str,
177        accession_id: &HashId,
178    ) -> Result<Annotation, AnnotationError> {
179        let id = Annotation::generate_id(name, group, accession_id);
180        let query = "INSERT INTO annotations (id, name, annotation_group, accession_id) VALUES (?1, ?2, ?3, ?4);";
181        let mut stmt = conn.prepare(query)?;
182        stmt.execute(params![id, name, group, accession_id])?;
183        Ok(Annotation {
184            id,
185            name: name.to_string(),
186            group: group.to_string(),
187            accession_id: *accession_id,
188        })
189    }
190
191    pub fn get_or_create(
192        conn: &GraphConnection,
193        name: &str,
194        group: &str,
195        accession_id: &HashId,
196    ) -> Result<Annotation, AnnotationError> {
197        AnnotationGroup::get_or_create(conn, group)?;
198        match Annotation::create(conn, name, group, accession_id) {
199            Ok(annotation) => Ok(annotation),
200            Err(AnnotationError::DatabaseError(rusqlite::Error::SqliteFailure(err, _details)))
201                if err.code == rusqlite::ErrorCode::ConstraintViolation =>
202            {
203                let id = Annotation::generate_id(name, group, accession_id);
204                Ok(Annotation {
205                    id,
206                    name: name.to_string(),
207                    group: group.to_string(),
208                    accession_id: *accession_id,
209                })
210            }
211            Err(err) => Err(err),
212        }
213    }
214
215    pub fn create_with_samples(
216        conn: &GraphConnection,
217        name: &str,
218        group: &str,
219        accession_id: &HashId,
220        sample_names: &[&str],
221    ) -> Result<Annotation, AnnotationError> {
222        let annotation = Annotation::get_or_create(conn, name, group, accession_id)?;
223        annotation.add_samples(conn, sample_names)?;
224        Ok(annotation)
225    }
226
227    pub fn add_samples(
228        &self,
229        conn: &GraphConnection,
230        sample_names: &[&str],
231    ) -> Result<(), AnnotationError> {
232        if sample_names.is_empty() {
233            return Ok(());
234        }
235        AnnotationGroup::get_or_create(conn, &self.group)?;
236        let query = "INSERT OR IGNORE INTO annotation_group_samples (annotation_group, sample_name) VALUES (?1, ?2);";
237        let mut stmt = conn.prepare(query)?;
238        for sample_name in sample_names {
239            stmt.execute(params![self.group, sample_name])?;
240        }
241        Ok(())
242    }
243
244    pub fn get_samples(
245        conn: &GraphConnection,
246        annotation_group: &str,
247    ) -> Result<Vec<String>, AnnotationError> {
248        let query = "SELECT sample_name FROM annotation_group_samples WHERE annotation_group = ?1;";
249        let mut stmt = conn.prepare(query)?;
250        let rows = stmt.query_map(params![annotation_group], |row| row.get(0))?;
251        let mut samples = Vec::new();
252        for row in rows {
253            samples.push(row?);
254        }
255        Ok(samples)
256    }
257
258    pub fn query_by_sample(
259        conn: &GraphConnection,
260        sample_name: &str,
261    ) -> Result<Vec<Annotation>, AnnotationError> {
262        let query = "select a.* from annotations a left join annotation_group_samples s on (a.annotation_group = s.annotation_group) where s.sample_name = ?1";
263        Ok(Annotation::query(conn, query, params![sample_name]))
264    }
265
266    pub fn query_by_group(
267        conn: &GraphConnection,
268        group: &str,
269    ) -> Result<Vec<Annotation>, AnnotationError> {
270        let query = "select * from annotations where annotation_group = ?1";
271        Ok(Annotation::query(conn, query, params![group]))
272    }
273}
274
/// A link between an annotation group and a sample, stored in `annotation_group_samples`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct AnnotationGroupSample {
    /// Name of the linked annotation group.
    pub annotation_group: String,
    /// Name of the linked sample.
    pub sample_name: String,
}
280
281impl<'a> Capnp<'a> for AnnotationGroupSample {
282    type Builder = annotation_group_sample::Builder<'a>;
283    type Reader = annotation_group_sample::Reader<'a>;
284
285    fn write_capnp(&self, builder: &mut Self::Builder) {
286        builder.set_annotation_group(&self.annotation_group);
287        builder.set_sample_name(&self.sample_name);
288    }
289
290    fn read_capnp(reader: Self::Reader) -> Self {
291        let annotation_group = reader.get_annotation_group().unwrap().to_string().unwrap();
292        let sample_name = reader.get_sample_name().unwrap().to_string().unwrap();
293        AnnotationGroupSample {
294            annotation_group,
295            sample_name,
296        }
297    }
298}
299
300impl AnnotationGroupSample {
301    pub fn create(
302        conn: &GraphConnection,
303        annotation_group: &str,
304        sample_name: &str,
305    ) -> Result<(), AnnotationError> {
306        AnnotationGroup::get_or_create(conn, annotation_group)?;
307        let query = "INSERT OR IGNORE INTO annotation_group_samples (annotation_group, sample_name) VALUES (?1, ?2);";
308        let mut stmt = conn.prepare(query)?;
309        stmt.execute(params![annotation_group, sample_name])?;
310        Ok(())
311    }
312
313    pub fn delete(
314        conn: &GraphConnection,
315        annotation_group: &str,
316        sample_name: &str,
317    ) -> Result<(), AnnotationError> {
318        let query = "DELETE FROM annotation_group_samples WHERE annotation_group = ?1 AND sample_name = ?2;";
319        let mut stmt = conn.prepare(query)?;
320        stmt.execute(params![annotation_group, sample_name])?;
321        Ok(())
322    }
323}
324
/// Errors returned by [`AnnotationGroup`] operations.
#[derive(Debug, Error)]
pub enum AnnotationGroupError {
    /// A `rusqlite` call failed.
    #[error("Database error: {0}")]
    DatabaseError(#[from] rusqlite::Error),
}
330
/// An annotation file together with its optional index file and display name, as returned
/// by the `AnnotationFile` listing queries.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub struct AnnotationFileInfo {
    /// File addition record of the annotation file itself.
    pub file_addition: FileAddition,
    /// File addition record of the index file, if one is linked.
    pub index_file_addition: Option<FileAddition>,
    /// Optional name stored alongside the annotation file.
    pub name: Option<String>,
}
337
/// Input describing an annotation file to attach to an operation
/// (see `AnnotationFile::add_to_operation`).
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct AnnotationFileAdditionInput {
    /// Path of the annotation file.
    pub file_path: String,
    /// Type the file is registered as.
    pub file_type: FileTypes,
    /// Optional checksum forwarded to `FileAddition::get_or_create`.
    pub checksum_override: Option<HashId>,
    /// Optional name stored with the annotation file.
    pub name: Option<String>,
    /// Optional path of an index file; it is registered with type `FileTypes::Tabix`.
    pub index_file_path: Option<String>,
}
346
/// Errors returned while registering or loading annotation files.
#[derive(Debug, Error)]
pub enum AnnotationFileError {
    /// A `rusqlite` call failed or an expected row was missing.
    #[error("Database error: {0}")]
    DatabaseError(#[from] rusqlite::Error),
    /// Creating or fetching a file addition failed.
    #[error("File addition error: {0}")]
    FileAdditionError(#[from] FileAdditionError),
    /// The file referenced as an index is not of type `FileTypes::Tabix`.
    #[error("Index file must be Tabix, got: {0:?}")]
    InvalidIndexFileType(FileTypes),
    /// The given format name or extension is not a supported annotation format.
    #[error("Unsupported annotation file type: {0}")]
    UnsupportedFileType(String),
}
358
359pub fn parse_annotation_file_type(value: &str) -> Result<FileTypes, AnnotationFileError> {
360    match value.trim().to_ascii_lowercase().as_str() {
361        "gff3" | "gff" => Ok(FileTypes::Gff3),
362        "bed" => Ok(FileTypes::Bed),
363        "genbank" | "gb" => Ok(FileTypes::GenBank),
364        other => Err(AnnotationFileError::UnsupportedFileType(other.to_string())),
365    }
366}
367
/// Returns the lowercased extension that identifies the annotation format of `path`.
///
/// A trailing `.gz` or `.bgz` is skipped, so `genes.gff3.gz` yields `gff3`. Returns `None`
/// when there is no extension, when it is not valid UTF-8, or when nothing precedes the
/// compression suffix (e.g. `genes.gz`).
pub fn annotation_file_extension(path: &str) -> Option<String> {
    // Lowercased UTF-8 extension of a path, if it has one.
    fn lowercase_extension(candidate: &Path) -> Option<String> {
        let raw = candidate.extension()?.to_str()?;
        Some(raw.to_ascii_lowercase())
    }

    let full_path = Path::new(path);
    let outer = lowercase_extension(full_path);
    match outer.as_deref() {
        Some("gz" | "bgz") => {
            // Look one level deeper: the extension of the file stem.
            let stem = full_path.file_stem()?.to_str()?;
            lowercase_extension(Path::new(stem))
        }
        _ => outer,
    }
}
383
384pub fn annotation_index_file_path(
385    workspace: &Workspace,
386    path: &str,
387    explicit_index_path: Option<&str>,
388) -> Option<String> {
389    if let Some(index_path) = explicit_index_path {
390        return Some(index_path.to_string());
391    }
392
393    let mut candidates = vec![format!("{path}.tbi")];
394    let path_buf = PathBuf::from(path);
395    if let Some(extension) = path_buf.extension().and_then(|ext| ext.to_str()) {
396        let mut extension_candidate = path_buf.clone();
397        extension_candidate.set_extension(format!("{extension}.tbi"));
398        let extension_candidate = extension_candidate.to_string_lossy().to_string();
399        if !candidates
400            .iter()
401            .any(|candidate| candidate == &extension_candidate)
402        {
403            candidates.push(extension_candidate);
404        }
405    }
406
407    for candidate in candidates {
408        let exists = if Path::new(&candidate).is_absolute() {
409            Path::new(&candidate).exists()
410        } else {
411            workspace
412                .repo_root()
413                .ok()
414                .is_some_and(|repo_root| repo_root.join(&candidate).exists())
415        };
416        if exists {
417            return Some(candidate);
418        }
419    }
420
421    None
422}
423
424pub fn add_annotation(
425    context: &DbContext,
426    collection: &str,
427    name: &str,
428    group: Option<&str>,
429    sample: Option<&str>,
430    region: &str,
431) -> Result<Operation, Box<dyn std::error::Error>> {
432    let graph_conn = context.graph().conn();
433    let operation_conn = context.operations().conn();
434    let parsed_region = region.parse::<Region>()?;
435    let interval = parsed_region.interval();
436    let start = interval
437        .start()
438        .ok_or_else(|| anyhow!("Region missing start"))?
439        .get() as i64;
440    let end = interval
441        .end()
442        .ok_or_else(|| anyhow!("Region missing end"))?
443        .get() as i64;
444
445    let block_groups = Sample::get_block_groups(graph_conn, collection, sample);
446    let block_group = block_groups
447        .iter()
448        .find(|bg| bg.name == parsed_region.name())
449        .ok_or_else(|| {
450            let sample_label = match sample {
451                Some(name) => format!("sample {name}"),
452                None => "default sample".to_string(),
453            };
454            anyhow!(
455                "Graph {} not found for {sample_label}",
456                parsed_region.name()
457            )
458        })?;
459    let path = BlockGroup::get_current_path(graph_conn, &block_group.id);
460    let path_length = path.length(graph_conn);
461    if start < 0 || end < 0 || start > end || end > path_length {
462        return Err(anyhow!("Region {region} is outside the path bounds (0-{path_length})").into());
463    }
464
465    let mut session = start_operation(graph_conn);
466    graph_conn.execute("BEGIN TRANSACTION", [])?;
467    operation_conn.execute("BEGIN TRANSACTION", [])?;
468
469    let mut cache = PathCache::new(graph_conn);
470    let _ = PathCache::lookup(&mut cache, &block_group.id, path.name.clone());
471    let accession = BlockGroup::add_accession(graph_conn, &path, name, start, end, &mut cache);
472
473    let annotation_group = group.unwrap_or("default");
474    let annotation = Annotation::get_or_create(graph_conn, name, annotation_group, &accession.id)?;
475    if let Some(sample_name) = sample {
476        AnnotationGroupSample::create(graph_conn, &annotation.group, sample_name)?;
477    }
478
479    let operation = end_operation(
480        context,
481        &mut session,
482        &OperationInfo {
483            files: vec![],
484            description: format!("add annotation {name}"),
485        },
486        &format!("add annotation {name}"),
487        None,
488    )?;
489
490    graph_conn.execute("END TRANSACTION", [])?;
491    operation_conn.execute("END TRANSACTION", [])?;
492
493    Ok(operation)
494}
495
496pub fn add_annotation_file(
497    context: &DbContext,
498    path: &str,
499    format: Option<&str>,
500    index: Option<&str>,
501    name: Option<&str>,
502    message: Option<&str>,
503) -> Result<Operation, Box<dyn std::error::Error>> {
504    let workspace = context.workspace();
505    let operation_conn = context.operations().conn();
506    let graph_conn = context.graph().conn();
507    let db_uuid = metadata::get_db_uuid(graph_conn);
508
509    let file_type = match format {
510        Some(format) => parse_annotation_file_type(format)?,
511        None => {
512            let ext = annotation_file_extension(path).ok_or_else(|| {
513                anyhow!(
514                    "Unable to detect annotation file format from the file extension. Use --format to specify it explicitly."
515                )
516            })?;
517            parse_annotation_file_type(&ext)?
518        }
519    };
520    let file_addition =
521        FileAddition::get_or_create(workspace, operation_conn, path, file_type, None)?;
522    let index_file_addition = annotation_index_file_path(workspace, path, index)
523        .map(|index_path| {
524            FileAddition::get_or_create(
525                workspace,
526                operation_conn,
527                &index_path,
528                FileTypes::Tabix,
529                None,
530            )
531        })
532        .transpose()?;
533    let name_value = name.unwrap_or_default();
534    let index_file_addition_id = index_file_addition
535        .as_ref()
536        .map(|index_file| index_file.id.to_string())
537        .unwrap_or_default();
538    let operation_hash = HashId(calculate_hash(&format!(
539        "{file_addition_id}:{name_value}:{index_file_addition_id}",
540        file_addition_id = file_addition.id
541    )));
542    let operation = match Operation::create(operation_conn, "annotation-file", &operation_hash) {
543        Ok(operation) => operation,
544        Err(rusqlite::Error::SqliteFailure(err, _details))
545            if err.code == rusqlite::ErrorCode::ConstraintViolation =>
546        {
547            return Err(OperationError::NoChanges.into());
548        }
549        Err(err) => return Err(err.into()),
550    };
551    AnnotationFile::link_to_operation(
552        operation_conn,
553        &operation.hash,
554        &file_addition.id,
555        index_file_addition
556            .as_ref()
557            .map(|index_file| &index_file.id),
558        name,
559    )?;
560    Operation::add_database(operation_conn, &operation.hash, &db_uuid)?;
561    let summary = message
562        .map(str::to_string)
563        .unwrap_or_else(|| format!("Add annotation file {path}"));
564    OperationSummary::create(operation_conn, &operation.hash, &summary);
565
566    let gen_db = GenDatabase::get_by_uuid(operation_conn, &db_uuid)?;
567    write_changeset(
568        workspace,
569        &operation,
570        DatabaseChangeset {
571            db_path: gen_db.path,
572            changes: ChangesetModels::default(),
573        },
574        &DependencyModels::default(),
575    );
576
577    if file_type != FileTypes::Changeset && file_type != FileTypes::None {
578        let gen_dir = workspace
579            .find_gen_dir()
580            .ok_or_else(|| anyhow!("No .gen directory found. Please run 'gen init' first."))?;
581        let assets_dir = gen_dir.join("assets");
582        fs::create_dir_all(&assets_dir)?;
583        let asset_path = assets_dir.join(file_addition.hashed_filename());
584        if !asset_path.exists() {
585            let source_path = if Path::new(path).is_absolute() {
586                PathBuf::from(path)
587            } else {
588                workspace.repo_root()?.join(path)
589            };
590            fs::copy(source_path, asset_path)?;
591        }
592        if let Some(index_file_addition) = index_file_addition {
593            let index_asset_path = assets_dir.join(index_file_addition.clone().hashed_filename());
594            if !index_asset_path.exists() {
595                let index_source_path = if Path::new(&index_file_addition.file_path).is_absolute() {
596                    PathBuf::from(&index_file_addition.file_path)
597                } else {
598                    workspace.repo_root()?.join(&index_file_addition.file_path)
599                };
600                fs::copy(index_source_path, index_asset_path)?;
601            }
602        }
603    }
604
605    Ok(operation)
606}
607
/// A row of the `annotation_files` table, linking a file addition to an operation.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct AnnotationFile {
    /// Row identifier.
    pub id: i64,
    /// Hash of the operation the file is linked to.
    pub operation_hash: HashId,
    /// Id of the annotation file's file addition.
    pub file_addition_id: HashId,
    /// Id of the index file's file addition, if any.
    pub index_file_addition_id: Option<HashId>,
    /// Optional name stored with the annotation file.
    pub name: Option<String>,
}
616
617impl Query for AnnotationFile {
618    type Model = AnnotationFile;
619
620    const TABLE_NAME: &'static str = "annotation_files";
621
622    fn process_row(row: &Row) -> Self::Model {
623        AnnotationFile {
624            id: row.get(0).unwrap(),
625            operation_hash: row.get(1).unwrap(),
626            file_addition_id: row.get(2).unwrap(),
627            index_file_addition_id: row.get(3).unwrap(),
628            name: row.get(4).unwrap(),
629        }
630    }
631}
632
633impl AnnotationFile {
634    pub fn load_index(
635        conn: &OperationsConnection,
636        file_addition_id: Option<&HashId>,
637    ) -> Result<Option<FileAddition>, AnnotationFileError> {
638        let Some(file_addition_id) = file_addition_id else {
639            return Ok(None);
640        };
641        let index_file_addition = FileAddition::get_by_id(conn, file_addition_id).ok_or(
642            AnnotationFileError::DatabaseError(rusqlite::Error::QueryReturnedNoRows),
643        )?;
644        if index_file_addition.file_type != FileTypes::Tabix {
645            return Err(AnnotationFileError::InvalidIndexFileType(
646                index_file_addition.file_type,
647            ));
648        }
649        Ok(Some(index_file_addition))
650    }
651
652    pub fn link_to_operation(
653        conn: &OperationsConnection,
654        operation_hash: &HashId,
655        file_addition_id: &HashId,
656        index_file_addition_id: Option<&HashId>,
657        name: Option<&str>,
658    ) -> Result<(), AnnotationFileError> {
659        AnnotationFile::load_index(conn, index_file_addition_id)?;
660        let query = "INSERT INTO annotation_files (operation_hash, file_addition_id, index_file_addition_id, name) VALUES (?1, ?2, ?3, ?4);";
661        let mut stmt = conn.prepare(query)?;
662        stmt.execute(params![
663            operation_hash,
664            file_addition_id,
665            index_file_addition_id,
666            name
667        ])?;
668        Ok(())
669    }
670
671    pub fn add_to_operation(
672        workspace: &Workspace,
673        conn: &OperationsConnection,
674        operation_hash: &HashId,
675        input: &AnnotationFileAdditionInput,
676    ) -> Result<FileAddition, AnnotationFileError> {
677        let file_addition = FileAddition::get_or_create(
678            workspace,
679            conn,
680            &input.file_path,
681            input.file_type,
682            input.checksum_override,
683        )?;
684        let index_file_addition = input
685            .index_file_path
686            .as_deref()
687            .map(|path| FileAddition::get_or_create(workspace, conn, path, FileTypes::Tabix, None))
688            .transpose()?;
689        AnnotationFile::link_to_operation(
690            conn,
691            operation_hash,
692            &file_addition.id,
693            index_file_addition.as_ref().map(|index| &index.id),
694            input.name.as_deref(),
695        )?;
696        Ok(file_addition)
697    }
698
699    pub fn get_files_for_operation(
700        conn: &OperationsConnection,
701        operation_hash: &HashId,
702    ) -> Vec<AnnotationFileInfo> {
703        let query = "select fa.*, af.index_file_addition_id, af.name from file_additions fa join annotation_files af on (fa.id = af.file_addition_id) where af.operation_hash = ?1";
704        let mut stmt = conn.prepare(query).unwrap();
705        let rows = stmt
706            .query_map(params![operation_hash], |row| {
707                Ok((
708                    FileAddition::process_row(row),
709                    row.get::<_, Option<HashId>>(4)?,
710                    row.get::<_, Option<String>>(5)?,
711                ))
712            })
713            .unwrap();
714        rows.map(|row| {
715            let (file_addition, index_file_addition_id, name) = row.unwrap();
716            AnnotationFileInfo {
717                file_addition,
718                index_file_addition: AnnotationFile::load_index(
719                    conn,
720                    index_file_addition_id.as_ref(),
721                )
722                .unwrap(),
723                name,
724            }
725        })
726        .collect()
727    }
728
729    pub fn get_all_files(conn: &OperationsConnection) -> Vec<AnnotationFileInfo> {
730        let query = "select fa.*, af.index_file_addition_id, af.name from file_additions fa join annotation_files af on (fa.id = af.file_addition_id)";
731        let mut stmt = conn.prepare(query).unwrap();
732        let rows = stmt
733            .query_map([], |row| {
734                Ok((
735                    FileAddition::process_row(row),
736                    row.get::<_, Option<HashId>>(4)?,
737                    row.get::<_, Option<String>>(5)?,
738                ))
739            })
740            .unwrap();
741        let mut entries: Vec<AnnotationFileInfo> = rows
742            .map(|row| {
743                let (file_addition, index_file_addition_id, name) = row.unwrap();
744                AnnotationFileInfo {
745                    file_addition,
746                    index_file_addition: AnnotationFile::load_index(
747                        conn,
748                        index_file_addition_id.as_ref(),
749                    )
750                    .unwrap(),
751                    name,
752                }
753            })
754            .collect();
755        entries.sort_by(|a, b| {
756            let a_name = std::path::Path::new(&a.file_addition.file_path)
757                .file_name()
758                .map(|name| name.to_string_lossy().to_string())
759                .unwrap_or_else(|| a.file_addition.file_path.clone());
760            let b_name = std::path::Path::new(&b.file_addition.file_path)
761                .file_name()
762                .map(|name| name.to_string_lossy().to_string())
763                .unwrap_or_else(|| b.file_addition.file_path.clone());
764            a_name
765                .cmp(&b_name)
766                .then_with(|| a.file_addition.file_path.cmp(&b.file_addition.file_path))
767        });
768        entries
769    }
770
771    pub fn query_by_operations(
772        conn: &OperationsConnection,
773        operations: &[HashId],
774    ) -> Result<HashMap<HashId, Vec<FileAddition>>, AnnotationFileError> {
775        let query = "select fa.*, af.operation_hash from file_additions fa left join annotation_files af on (fa.id = af.file_addition_id) where af.operation_hash in rarray(?1)";
776        let mut stmt = conn.prepare(query)?;
777        let rows = stmt.query_map(
778            params![Rc::new(
779                operations
780                    .iter()
781                    .map(|h| Value::from(*h))
782                    .collect::<Vec<Value>>()
783            )],
784            |row| Ok((FileAddition::process_row(row), row.get::<_, HashId>(4)?)),
785        )?;
786        rows.into_iter()
787            .try_fold(HashMap::new(), |mut acc: HashMap<_, Vec<_>>, row| {
788                let (item, hash) = row?;
789                acc.entry(hash).or_default().push(item);
790                Ok(acc)
791            })
792            .map_err(AnnotationFileError::DatabaseError)
793    }
794}
795
796#[cfg(test)]
797mod tests {
798    use std::fs;
799
800    use gen_core::HashId;
801
802    use super::*;
803    use crate::{
804        block_group::{BlockGroup, PathCache},
805        errors::OperationError,
806        files::GenDatabase,
807        metadata,
808        sample::Sample,
809        test_helpers::{get_connection, setup_block_group, setup_gen},
810    };
811
812    #[test]
813    fn create_annotation_with_samples() {
814        let conn = get_connection(None).unwrap();
815        let (block_group_id, path) = setup_block_group(&conn);
816
817        let _ = Sample::create(&conn, "sample-1").unwrap();
818        let _ = Sample::create(&conn, "sample-2").unwrap();
819
820        let mut cache = PathCache::new(&conn);
821        let _ = PathCache::lookup(&mut cache, &block_group_id, path.name.clone());
822        let accession = BlockGroup::add_accession(&conn, &path, "ann-accession", 0, 5, &mut cache);
823
824        let annotation =
825            Annotation::get_or_create(&conn, "gene-a", "project-tracks", &accession.id).unwrap();
826        annotation
827            .add_samples(&conn, &["sample-1", "sample-2"])
828            .unwrap();
829
830        let mut samples = Annotation::get_samples(&conn, &annotation.group).unwrap();
831        samples.sort();
832        assert_eq!(
833            samples,
834            vec!["sample-1".to_string(), "sample-2".to_string()]
835        );
836
837        let by_sample = Annotation::query_by_sample(&conn, "sample-1").unwrap();
838        assert_eq!(by_sample.len(), 1);
839        assert_eq!(by_sample[0], annotation);
840
841        let by_group = Annotation::query_by_group(&conn, "project-tracks").unwrap();
842        assert_eq!(by_group, vec![annotation]);
843    }
844
845    #[test]
846    fn add_annotation_file_to_operation() {
847        let context = setup_gen();
848        let op_conn = context.operations().conn();
849        let workspace = context.workspace();
850        let repo_root = workspace.repo_root().unwrap();
851        let annotation_path = repo_root.join("fixtures").join("annotation.gff3");
852        fs::create_dir_all(annotation_path.parent().unwrap()).unwrap();
853        fs::write(&annotation_path, "##gff-version 3\n").unwrap();
854
855        let op_hash = HashId::random_str();
856        let _ = crate::operations::Operation::create(op_conn, "annotation-file", &op_hash)
857            .expect("should create operation");
858
859        let file_addition = AnnotationFile::add_to_operation(
860            workspace,
861            op_conn,
862            &op_hash,
863            &AnnotationFileAdditionInput {
864                file_path: annotation_path.to_string_lossy().to_string(),
865                file_type: FileTypes::Gff3,
866                checksum_override: None,
867                name: Some("fixtures-annotation".to_string()),
868                index_file_path: None,
869            },
870        )
871        .unwrap();
872
873        let files = AnnotationFile::get_files_for_operation(op_conn, &op_hash);
874        assert_eq!(files.len(), 1);
875        assert_eq!(files[0].file_addition, file_addition);
876        assert!(files[0].index_file_addition.is_none());
877    }
878
879    #[test]
880    fn parse_annotation_file_type_values() {
881        assert_eq!(parse_annotation_file_type("gff3").unwrap(), FileTypes::Gff3);
882        assert_eq!(parse_annotation_file_type("GFF").unwrap(), FileTypes::Gff3);
883        assert_eq!(parse_annotation_file_type("bed").unwrap(), FileTypes::Bed);
884        assert_eq!(
885            parse_annotation_file_type("GenBank").unwrap(),
886            FileTypes::GenBank
887        );
888        assert_eq!(
889            parse_annotation_file_type("gb").unwrap(),
890            FileTypes::GenBank
891        );
892        let err = parse_annotation_file_type("bam").unwrap_err();
893        assert!(matches!(err, AnnotationFileError::UnsupportedFileType(_)));
894    }
895
896    #[test]
897    fn add_annotation_creates_annotation() {
898        let context = setup_gen();
899        let graph_conn = context.graph().conn();
900        let operation_conn = context.operations().conn();
901        let db_uuid = metadata::get_db_uuid(graph_conn);
902        let _ = GenDatabase::create(operation_conn, &db_uuid, "test-db", "test-db-path").unwrap();
903        let _ = setup_block_group(graph_conn);
904
905        let operation = add_annotation(
906            &context,
907            "test",
908            "gene-a",
909            Some("track-1"),
910            None,
911            "chr1:1-5",
912        )
913        .unwrap();
914        assert_eq!(operation.change_type, "add annotation gene-a");
915
916        let annotations = Annotation::query_by_group(graph_conn, "track-1").unwrap();
917        assert_eq!(annotations.len(), 1);
918        assert_eq!(annotations[0].name, "gene-a");
919    }
920
921    #[test]
922    fn add_annotation_file_creates_operation() {
923        let context = setup_gen();
924        let graph_conn = context.graph().conn();
925        let operation_conn = context.operations().conn();
926        let db_uuid = metadata::get_db_uuid(graph_conn);
927        let _ = GenDatabase::create(operation_conn, &db_uuid, "test-db", "test-db-path").unwrap();
928
929        let repo_root = context.workspace().repo_root().unwrap();
930        let annotation_path = repo_root.join("fixtures").join("annotation.gff3");
931        fs::create_dir_all(annotation_path.parent().unwrap()).unwrap();
932        fs::write(&annotation_path, "##gff-version 3\n").unwrap();
933        let annotation_path_str = annotation_path.to_string_lossy().to_string();
934
935        let operation = add_annotation_file(
936            &context,
937            &annotation_path_str,
938            None,
939            None,
940            Some("track-1"),
941            None,
942        )
943        .unwrap();
944        assert_eq!(operation.change_type, "annotation-file");
945
946        let files = AnnotationFile::get_files_for_operation(operation_conn, &operation.hash);
947        assert_eq!(files.len(), 1);
948        assert_eq!(files[0].name.as_deref(), Some("track-1"));
949
950        let err = add_annotation_file(
951            &context,
952            &annotation_path_str,
953            None,
954            None,
955            Some("track-1"),
956            None,
957        )
958        .unwrap_err();
959        let op_err = err
960            .downcast_ref::<OperationError>()
961            .expect("should be an OperationError");
962        assert_eq!(*op_err, OperationError::NoChanges);
963    }
964}