Skip to main content

gen_models/
sequence.rs

1use std::{collections::HashMap, fs, str, sync};
2
3use cached::proc_macro::cached;
4use gen_core::{HashId, traits::Capnp};
5use noodles::{
6    bgzf::{self, gzi},
7    core::Region,
8    fasta::{self, fai, io::indexed_reader::Builder as IndexBuilder},
9};
10use rusqlite::{Row, params};
11use serde::{Deserialize, Serialize};
12use sha2::{Digest, Sha256};
13
14use crate::{db::GraphConnection, gen_models_capnp::sequence, traits::*};
15
16#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
17pub struct Sequence {
18    pub hash: HashId,
19    pub sequence_type: String,
20    sequence: String,
21    // these 2 fields are only relevant when the sequence is stored externally
22    pub name: String,
23    pub file_path: String,
24    pub length: i64,
25    // indicates whether the sequence is stored externally, a quick flag instead of having to
26    // check sequence or file_path and do the logic in function calls.
27    pub external_sequence: bool,
28}
29
30impl<'a> Capnp<'a> for Sequence {
31    type Builder = sequence::Builder<'a>;
32    type Reader = sequence::Reader<'a>;
33
34    fn write_capnp(&self, builder: &mut Self::Builder) {
35        builder.set_hash(&self.hash.0).unwrap();
36        builder.set_sequence_type(&self.sequence_type);
37        builder.set_sequence(&self.sequence);
38        builder.set_name(&self.name);
39        builder.set_file_path(&self.file_path);
40        builder.set_length(self.length);
41        builder.set_external_sequence(self.external_sequence);
42    }
43
44    fn read_capnp(reader: Self::Reader) -> Self {
45        let hash = reader
46            .get_hash()
47            .unwrap()
48            .as_slice()
49            .unwrap()
50            .try_into()
51            .unwrap();
52        let sequence_type = reader.get_sequence_type().unwrap().to_string().unwrap();
53        let sequence = reader.get_sequence().unwrap().to_string().unwrap();
54        let name = reader.get_name().unwrap().to_string().unwrap();
55        let file_path = reader.get_file_path().unwrap().to_string().unwrap();
56        let length = reader.get_length();
57        let external_sequence = reader.get_external_sequence();
58
59        Sequence {
60            hash,
61            sequence_type,
62            sequence,
63            name,
64            file_path,
65            length,
66            external_sequence,
67        }
68    }
69}
70
/// Borrowing builder for [`Sequence`].
///
/// Every field starts unset; the setter methods fill them in, and `build` or `save` turn the
/// builder into a [`Sequence`].
#[derive(Debug, Default)]
pub struct NewSequence<'a> {
    /// Type label of the sequence; must be set before hashing, building or saving.
    sequence_type: Option<&'a str>,
    /// Inline sequence text, if any.
    sequence: Option<&'a str>,
    /// Name of the record inside the external file, if any.
    name: Option<&'a str>,
    /// Path of the external file, if any.
    file_path: Option<&'a str>,
    /// Explicit length; the `sequence` setter fills it with the byte length of its argument.
    length: Option<i64>,
    /// When `true`, `save` writes an empty string to the `sequence` column. The `file_path`
    /// setter turns it on for any non-empty path.
    shallow: bool,
}
80
81impl<'a> From<&'a Sequence> for NewSequence<'a> {
82    fn from(value: &'a Sequence) -> NewSequence<'a> {
83        NewSequence::new()
84            .sequence_type(&value.sequence_type)
85            .sequence(&value.sequence)
86            .name(&value.name)
87            .file_path(&value.file_path)
88            .length(value.length)
89    }
90}
91
92impl<'a> NewSequence<'a> {
93    pub fn new() -> NewSequence<'static> {
94        NewSequence {
95            shallow: false,
96            ..NewSequence::default()
97        }
98    }
99
100    pub fn shallow(mut self, setting: bool) -> Self {
101        self.shallow = setting;
102        self
103    }
104
105    pub fn sequence_type(mut self, seq_type: &'a str) -> Self {
106        self.sequence_type = Some(seq_type);
107        self
108    }
109
110    pub fn sequence(mut self, sequence: &'a str) -> Self {
111        self.sequence = Some(sequence);
112        self.length = Some(sequence.len() as i64);
113        self
114    }
115
116    pub fn name(mut self, name: &'a str) -> Self {
117        self.name = Some(name);
118        self
119    }
120
121    pub fn file_path(mut self, path: &'a str) -> Self {
122        if !path.is_empty() {
123            self.file_path = Some(path);
124            self.shallow = true;
125        }
126        self
127    }
128
129    pub fn length(mut self, length: i64) -> Self {
130        self.length = Some(length);
131        self
132    }
133
134    pub fn hash(&self) -> HashId {
135        let mut hasher = Sha256::new();
136        hasher.update(self.sequence_type.expect("Sequence type must be defined."));
137        hasher.update(";");
138        if let Some(v) = self.sequence {
139            hasher.update(v);
140        } else {
141            hasher.update("");
142        }
143        hasher.update(";");
144        if let Some(v) = self.name {
145            hasher.update(v);
146        } else {
147            hasher.update("");
148        }
149        hasher.update(";");
150        if let Some(v) = self.file_path {
151            hasher.update(v);
152        } else {
153            hasher.update("");
154        }
155        hasher.update(";");
156
157        HashId(hasher.finalize().into())
158    }
159
160    pub fn build(self) -> Sequence {
161        let file_path = self.file_path.unwrap_or("").to_string();
162        let external_sequence = !file_path.is_empty();
163        Sequence {
164            hash: self.hash(),
165            sequence_type: self.sequence_type.unwrap().to_string(),
166            sequence: self.sequence.unwrap_or("").to_string(),
167            name: self.name.unwrap_or("").to_string(),
168            file_path,
169            length: self.length.unwrap(),
170            external_sequence,
171        }
172    }
173
174    pub fn save(self, conn: &GraphConnection) -> Sequence {
175        let mut length = 0;
176        if self.sequence.is_none() && self.file_path.is_none() {
177            panic!("Sequence or file_path must be set.");
178        }
179        if self.file_path.is_some() && self.name.is_none() {
180            panic!("A filepath must have an accompanying sequence name");
181        }
182        if self.length.is_none() {
183            if let Some(v) = self.sequence {
184                length = v.len() as i64;
185            } else {
186                // TODO: if name/path specified, grab length automatically
187                panic!("Sequence length must be specified.");
188            }
189        }
190        let hash = self.hash();
191        match conn.query_row(
192            "SELECT hash from sequences where hash = ?1;",
193            [hash],
194            |row| row.get::<_, HashId>(0),
195        ) {
196            Ok(_) => {}
197            Err(rusqlite::Error::QueryReturnedNoRows) => {
198                let mut stmt = conn.prepare("INSERT INTO sequences (hash, sequence_type, sequence, name, file_path, length) VALUES (?1, ?2, ?3, ?4, ?5, ?6);").unwrap();
199                stmt.execute(params![
200                    hash,
201                    self.sequence_type.unwrap().to_string(),
202                    if self.shallow {
203                        ""
204                    } else {
205                        self.sequence.unwrap()
206                    },
207                    self.name.unwrap_or(""),
208                    self.file_path.unwrap_or(""),
209                    self.length.unwrap_or(length)
210                ])
211                .unwrap();
212            }
213            Err(_e) => {
214                panic!("something bad happened querying the database")
215            }
216        };
217        Sequence {
218            hash,
219            sequence_type: self.sequence_type.unwrap().to_string(),
220            sequence: self.sequence.unwrap_or("").to_string(),
221            name: self.name.unwrap_or("").to_string(),
222            file_path: self.file_path.unwrap_or("").to_string(),
223            length: self.length.unwrap_or(length),
224            external_sequence: !self.file_path.unwrap_or("").is_empty(),
225        }
226    }
227}
228
229#[cached(key = "String", convert = r#"{ format!("{}", path) }"#)]
230fn fasta_index(path: &str) -> Option<fai::Index> {
231    let index_path = format!("{path}.fai");
232    if fs::metadata(&index_path).is_ok() {
233        return Some(fai::fs::read(&index_path).unwrap());
234    }
235    None
236}
237
238#[cached(key = "String", convert = r#"{ format!("{}", path) }"#)]
239fn fasta_gzi_index(path: &str) -> Option<gzi::Index> {
240    let index_path = format!("{path}.gzi");
241    if fs::metadata(&index_path).is_ok() {
242        return Some(gzi::fs::read(&index_path).unwrap());
243    }
244    None
245}
246
247pub fn cached_sequence(file_path: &str, name: &str, start: usize, end: usize) -> Option<String> {
248    static SEQUENCE_CACHE: sync::LazyLock<sync::RwLock<HashMap<String, Option<String>>>> =
249        sync::LazyLock::new(|| sync::RwLock::new(HashMap::new()));
250    let key = format!("{file_path}-{name}");
251
252    {
253        let cache = SEQUENCE_CACHE.read().unwrap();
254        if let Some(cached_sequence) = cache.get(&key) {
255            if let Some(sequence) = cached_sequence {
256                return Some(sequence[start..end].to_string());
257            }
258            return None;
259        }
260    }
261
262    let mut cache = SEQUENCE_CACHE.write().unwrap();
263
264    let mut sequence: Option<String> = None;
265    let region = name.parse::<Region>().unwrap();
266    if let Some(index) = fasta_index(file_path) {
267        let builder = IndexBuilder::default().set_index(index);
268        if let Some(gzi_index) = fasta_gzi_index(file_path) {
269            let bgzf_reader = bgzf::io::indexed_reader::Builder::default()
270                .set_index(gzi_index)
271                .build_from_path(file_path)
272                .unwrap();
273            let mut reader = builder.build_from_reader(bgzf_reader).unwrap();
274            sequence = Some(
275                str::from_utf8(reader.query(&region).unwrap().sequence().as_ref())
276                    .unwrap()
277                    .to_string(),
278            )
279        } else {
280            let mut reader = builder.build_from_path(file_path).unwrap();
281            sequence = Some(
282                str::from_utf8(reader.query(&region).unwrap().sequence().as_ref())
283                    .unwrap()
284                    .to_string(),
285            );
286        }
287    } else {
288        let mut reader = fasta::io::reader::Builder
289            .build_from_path(file_path)
290            .unwrap();
291        for result in reader.records() {
292            let record = result.unwrap();
293            if String::from_utf8(record.name().to_vec()).unwrap() == name {
294                sequence = Some(
295                    str::from_utf8(record.sequence().as_ref())
296                        .unwrap()
297                        .to_string(),
298                );
299                break;
300            }
301        }
302    }
303    // this is a LRU cache setup, we just keep the last sequence we fetched so we don't end up loading
304    // plant genomes into memory.
305    cache.clear();
306    cache.insert(key.clone(), sequence);
307    // we do this to avoid a clone of potentially large data.
308    if let Some(seq) = &cache[&key] {
309        return Some(seq[start..end].to_string());
310    }
311    None
312}
313
314impl Sequence {
315    #[allow(clippy::new_ret_no_self)]
316    pub fn new() -> NewSequence<'static> {
317        NewSequence::new()
318    }
319
320    pub fn get_sequence(
321        &self,
322        start: impl Into<Option<i64>>,
323        end: impl Into<Option<i64>>,
324    ) -> String {
325        // todo: handle circles
326
327        let start: Option<i64> = start.into();
328        let end: Option<i64> = end.into();
329        let start = start.unwrap_or(0) as usize;
330        let end = end.unwrap_or(self.length) as usize;
331        if self.external_sequence {
332            if let Some(sequence) = cached_sequence(&self.file_path, &self.name, start, end) {
333                return sequence;
334            } else {
335                panic!(
336                    "{name} not found in fasta file {file_path}",
337                    name = self.name,
338                    file_path = self.file_path
339                );
340            }
341        }
342        if start == 0 && end as i64 == self.length {
343            return self.sequence.clone();
344        }
345        self.sequence[start..end].to_string()
346    }
347
348    pub fn delete_by_hash(conn: &GraphConnection, hash: &HashId) {
349        let mut stmt = conn
350            .prepare("delete from sequences where hash = ?1;")
351            .unwrap();
352        stmt.execute(params![hash]).unwrap();
353    }
354
355    pub fn query_by_blockgroup(conn: &GraphConnection, block_group_id: &HashId) -> Vec<Sequence> {
356        Sequence::query(
357            conn,
358            "select sequences.* from block_group_edges bge left join edges on bge.edge_id = edges.id left join nodes on (edges.source_node_id = nodes.id or edges.target_node_id = nodes.id) left join sequences on (nodes.sequence_hash = sequences.hash) where bge.block_group_id = ?1;",
359            params![block_group_id],
360        )
361    }
362}
363
364impl Query for Sequence {
365    type Model = Sequence;
366
367    const PRIMARY_KEY: &'static str = "hash";
368    const TABLE_NAME: &'static str = "sequences";
369
370    fn process_row(row: &Row) -> Self::Model {
371        let file_path: String = row.get(4).unwrap();
372        let mut external_sequence = false;
373        if !file_path.is_empty() {
374            external_sequence = true;
375        }
376        let hash: HashId = row.get(0).unwrap();
377        let sequence = row.get(2).unwrap();
378        Sequence {
379            hash,
380            sequence_type: row.get(1).unwrap(),
381            sequence,
382            name: row.get(3).unwrap(),
383            file_path,
384            length: row.get(5).unwrap(),
385            external_sequence,
386        }
387    }
388}
389
/// Unit tests for the sequence builder, database persistence, FASTA-backed retrieval and
/// Cap'n Proto round-tripping.
#[cfg(test)]
mod tests {
    #[allow(unused_imports)]
    use std::time;
    use std::{
        fs::OpenOptions,
        io::{BufWriter, Write},
    };

    use rand::{self, Rng};

    // Note this useful idiom: importing names from outer (for mod tests) scope.
    use super::*;
    use crate::test_helpers::get_connection;

    /// An inline sequence records its own length.
    #[test]
    fn test_builder() {
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .sequence("ATCG")
            .build();
        assert_eq!(sequence.length, 4);
        assert_eq!(sequence.sequence, "ATCG");
    }

    /// A file-backed sequence keeps the explicit length and has no inline text.
    #[test]
    fn test_builder_with_from_disk() {
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .name("chr1")
            .file_path("/foo/bar")
            .length(50)
            .build();
        assert_eq!(sequence.length, 50);
        assert_eq!(sequence.sequence, "");
    }

    #[test]
    fn test_create_sequence_in_db() {
        let conn = &get_connection(None).unwrap();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .sequence("AACCTT")
            .save(conn);
        assert_eq!(&sequence.sequence, "AACCTT");
        assert_eq!(sequence.sequence_type, "DNA");
        assert!(!sequence.external_sequence);
    }

    #[test]
    fn test_delete_sequence_by_hash() {
        let conn = &get_connection(None).unwrap();
        let before_count = Sequence::all(conn).len();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .sequence("AACCTT")
            .save(conn);
        let sequence2 = Sequence::new()
            .sequence_type("DNA")
            .sequence("AACCTTAA")
            .save(conn);

        let sequences = Sequence::all(conn);
        assert_eq!(sequences.len(), before_count + 2);

        Sequence::delete_by_hash(conn, &sequence.hash);

        let sequences = Sequence::all(conn);
        assert_eq!(sequences.len(), before_count + 1);
        assert!(sequences.iter().any(|s| s.hash == sequence2.hash));
    }

    #[test]
    fn test_create_sequence_on_disk() {
        let conn = &get_connection(None).unwrap();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .name("chr1")
            .file_path("/some/path.fa")
            .length(10)
            .save(conn);
        assert_eq!(sequence.sequence_type, "DNA");
        assert_eq!(&sequence.sequence, "");
        assert_eq!(sequence.name, "chr1");
        assert_eq!(sequence.file_path, "/some/path.fa");
        assert_eq!(sequence.length, 10);
        assert!(sequence.external_sequence);
    }

    #[test]
    fn test_get_sequence() {
        let conn = &get_connection(None).unwrap();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .sequence("ATCGATCGATCGATCGATCGGGAACACACAGAGA")
            .save(conn);
        assert_eq!(
            sequence.get_sequence(None, None),
            "ATCGATCGATCGATCGATCGGGAACACACAGAGA"
        );
        assert_eq!(sequence.get_sequence(0, 5), "ATCGA");
        assert_eq!(sequence.get_sequence(10, 15), "CGATC");
        assert_eq!(
            sequence.get_sequence(3, None),
            "GATCGATCGATCGATCGGGAACACACAGAGA"
        );
        assert_eq!(sequence.get_sequence(None, 5), "ATCGA");
    }

    #[test]
    fn test_get_sequence_from_disk() {
        let conn = &get_connection(None).unwrap();
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_file_path = temp_dir.path().join("simple.fa");
        fs::write(
            &temp_file_path,
            ">m123\nATCGATCGATCGATCGATCGGGAACACACAGAGA\n",
        )
        .unwrap();
        let seq = Sequence::new()
            .sequence_type("DNA")
            .name("m123")
            .file_path(temp_file_path.to_str().unwrap())
            .length(34)
            .save(conn);
        assert_eq!(
            seq.get_sequence(None, None),
            "ATCGATCGATCGATCGATCGGGAACACACAGAGA"
        );
        assert_eq!(seq.get_sequence(0, 5), "ATCGA");
        assert_eq!(seq.get_sequence(10, 15), "CGATC");
        assert_eq!(seq.get_sequence(3, None), "GATCGATCGATCGATCGGGAACACACAGAGA");
        assert_eq!(seq.get_sequence(None, 5), "ATCGA");
    }

    /// Repeated random slices of one large indexed record must be served from the cache.
    #[test]
    // #[cfg(feature = "benchmark")]
    fn test_cached_sequence_performance() {
        let conn = &get_connection(None).unwrap();
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_file_path = temp_dir.path().join("large.fa");
        let file = OpenOptions::new()
            .append(true)
            .create(true)
            .open(&temp_file_path)
            .unwrap();
        // Buffer the writes: the fixture is about three million lines, and an unbuffered
        // `writeln!` on a `File` costs one write call per line.
        let mut writer = BufWriter::new(file);
        writeln!(writer, ">chr22").unwrap();
        for _ in 1..3_000_000 {
            writeln!(
                writer,
                "ATCGATCGATCGATCGATCGGGAACACACAGAGAATCGATCGATCGATCGATCGGGAACACACAGAGA"
            )
            .unwrap();
        }
        // Flush explicitly so the data is on disk before it is read back and so that a write
        // error is reported instead of being swallowed when the writer is dropped.
        writer.flush().unwrap();
        drop(writer);
        // write index: name, length, offset of the first base, bases per line, bytes per line
        let index_path = temp_dir.path().join("large.fa.fai");
        fs::write(&index_path, "chr22\t203999932\t7\t68\t69\n").unwrap();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .file_path(temp_file_path.to_str().unwrap())
            .name("chr22")
            .length(203_999_932)
            .save(conn);
        let s = time::Instant::now();
        for _ in 1..1_000_000 {
            let start = rand::rng().random_range(1..200_000_000);

            sequence.get_sequence(start, start + 20);
        }
        let elapsed = s.elapsed().as_secs();
        assert!(
            elapsed < 5,
            "Cached sequence benchmark failed: {elapsed}s elapsed"
        );
    }

    #[test]
    fn test_capnp_serialization() {
        use capnp::message::TypedBuilder;

        let sequence = Sequence {
            hash: HashId::convert_str("test_hash"),
            sequence_type: "DNA".to_string(),
            sequence: "ATCG".to_string(),
            name: "test_seq".to_string(),
            file_path: "/path/to/file".to_string(),
            length: 4,
            external_sequence: false,
        };

        let mut message = TypedBuilder::<sequence::Owned>::new_default();
        let mut root = message.init_root();
        sequence.write_capnp(&mut root);

        let deserialized = Sequence::read_capnp(root.into_reader());
        assert_eq!(sequence, deserialized);
    }
}