1use std::{collections::HashMap, fs, str, sync};
2
3use cached::proc_macro::cached;
4use gen_core::{HashId, traits::Capnp};
5use noodles::{
6 bgzf::{self, gzi},
7 core::Region,
8 fasta::{self, fai, io::indexed_reader::Builder as IndexBuilder},
9};
10use rusqlite::{Row, params};
11use serde::{Deserialize, Serialize};
12use sha2::{Digest, Sha256};
13
14use crate::{db::GraphConnection, gen_models_capnp::sequence, traits::*};
15
/// A stored sequence record.
///
/// The residues are either held in memory (`sequence`) or, when `external_sequence` is
/// true, read on demand from the FASTA file at `file_path` under the record `name`
/// (see [`Sequence::get_sequence`]).
#[derive(Clone, Debug, Eq, Hash, PartialEq, Deserialize, Serialize)]
pub struct Sequence {
    /// Content hash produced by `NewSequence::hash` (type, sequence, name and file path).
    pub hash: HashId,
    /// Free-form sequence kind, e.g. "DNA" in the tests.
    pub sequence_type: String,
    /// In-memory residues; empty for sequences that live in an external file.
    /// Read through [`Sequence::get_sequence`].
    sequence: String,
    /// Record name; used to locate the record inside `file_path` for external sequences.
    pub name: String,
    /// Path of the backing FASTA file, or empty when the residues are stored inline.
    pub file_path: String,
    /// Number of residues in the sequence.
    pub length: i64,
    /// True when `file_path` is non-empty, i.e. residues must be fetched from disk.
    pub external_sequence: bool,
}
29
30impl<'a> Capnp<'a> for Sequence {
31 type Builder = sequence::Builder<'a>;
32 type Reader = sequence::Reader<'a>;
33
34 fn write_capnp(&self, builder: &mut Self::Builder) {
35 builder.set_hash(&self.hash.0).unwrap();
36 builder.set_sequence_type(&self.sequence_type);
37 builder.set_sequence(&self.sequence);
38 builder.set_name(&self.name);
39 builder.set_file_path(&self.file_path);
40 builder.set_length(self.length);
41 builder.set_external_sequence(self.external_sequence);
42 }
43
44 fn read_capnp(reader: Self::Reader) -> Self {
45 let hash = reader
46 .get_hash()
47 .unwrap()
48 .as_slice()
49 .unwrap()
50 .try_into()
51 .unwrap();
52 let sequence_type = reader.get_sequence_type().unwrap().to_string().unwrap();
53 let sequence = reader.get_sequence().unwrap().to_string().unwrap();
54 let name = reader.get_name().unwrap().to_string().unwrap();
55 let file_path = reader.get_file_path().unwrap().to_string().unwrap();
56 let length = reader.get_length();
57 let external_sequence = reader.get_external_sequence();
58
59 Sequence {
60 hash,
61 sequence_type,
62 sequence,
63 name,
64 file_path,
65 length,
66 external_sequence,
67 }
68 }
69}
70
/// Builder for [`Sequence`] values; finish with `build` (in memory) or `save` (persisted).
#[derive(Default, Debug)]
pub struct NewSequence<'a> {
    /// Required; `hash` panics if it is unset.
    sequence_type: Option<&'a str>,
    /// Inline residues; setting them also sets `length`.
    sequence: Option<&'a str>,
    /// Record name; required by `save` whenever `file_path` is set.
    name: Option<&'a str>,
    /// Backing FASTA file; only ever set to a non-empty path.
    file_path: Option<&'a str>,
    /// Explicit length; overrides the length derived from `sequence`.
    length: Option<i64>,
    /// When true, `save` writes an empty string to the `sequence` column instead of the
    /// residues. Turned on automatically by `file_path` for non-empty paths.
    shallow: bool,
}
80
81impl<'a> From<&'a Sequence> for NewSequence<'a> {
82 fn from(value: &'a Sequence) -> NewSequence<'a> {
83 NewSequence::new()
84 .sequence_type(&value.sequence_type)
85 .sequence(&value.sequence)
86 .name(&value.name)
87 .file_path(&value.file_path)
88 .length(value.length)
89 }
90}
91
92impl<'a> NewSequence<'a> {
93 pub fn new() -> NewSequence<'static> {
94 NewSequence {
95 shallow: false,
96 ..NewSequence::default()
97 }
98 }
99
100 pub fn shallow(mut self, setting: bool) -> Self {
101 self.shallow = setting;
102 self
103 }
104
105 pub fn sequence_type(mut self, seq_type: &'a str) -> Self {
106 self.sequence_type = Some(seq_type);
107 self
108 }
109
110 pub fn sequence(mut self, sequence: &'a str) -> Self {
111 self.sequence = Some(sequence);
112 self.length = Some(sequence.len() as i64);
113 self
114 }
115
116 pub fn name(mut self, name: &'a str) -> Self {
117 self.name = Some(name);
118 self
119 }
120
121 pub fn file_path(mut self, path: &'a str) -> Self {
122 if !path.is_empty() {
123 self.file_path = Some(path);
124 self.shallow = true;
125 }
126 self
127 }
128
129 pub fn length(mut self, length: i64) -> Self {
130 self.length = Some(length);
131 self
132 }
133
134 pub fn hash(&self) -> HashId {
135 let mut hasher = Sha256::new();
136 hasher.update(self.sequence_type.expect("Sequence type must be defined."));
137 hasher.update(";");
138 if let Some(v) = self.sequence {
139 hasher.update(v);
140 } else {
141 hasher.update("");
142 }
143 hasher.update(";");
144 if let Some(v) = self.name {
145 hasher.update(v);
146 } else {
147 hasher.update("");
148 }
149 hasher.update(";");
150 if let Some(v) = self.file_path {
151 hasher.update(v);
152 } else {
153 hasher.update("");
154 }
155 hasher.update(";");
156
157 HashId(hasher.finalize().into())
158 }
159
160 pub fn build(self) -> Sequence {
161 let file_path = self.file_path.unwrap_or("").to_string();
162 let external_sequence = !file_path.is_empty();
163 Sequence {
164 hash: self.hash(),
165 sequence_type: self.sequence_type.unwrap().to_string(),
166 sequence: self.sequence.unwrap_or("").to_string(),
167 name: self.name.unwrap_or("").to_string(),
168 file_path,
169 length: self.length.unwrap(),
170 external_sequence,
171 }
172 }
173
174 pub fn save(self, conn: &GraphConnection) -> Sequence {
175 let mut length = 0;
176 if self.sequence.is_none() && self.file_path.is_none() {
177 panic!("Sequence or file_path must be set.");
178 }
179 if self.file_path.is_some() && self.name.is_none() {
180 panic!("A filepath must have an accompanying sequence name");
181 }
182 if self.length.is_none() {
183 if let Some(v) = self.sequence {
184 length = v.len() as i64;
185 } else {
186 panic!("Sequence length must be specified.");
188 }
189 }
190 let hash = self.hash();
191 match conn.query_row(
192 "SELECT hash from sequences where hash = ?1;",
193 [hash],
194 |row| row.get::<_, HashId>(0),
195 ) {
196 Ok(_) => {}
197 Err(rusqlite::Error::QueryReturnedNoRows) => {
198 let mut stmt = conn.prepare("INSERT INTO sequences (hash, sequence_type, sequence, name, file_path, length) VALUES (?1, ?2, ?3, ?4, ?5, ?6);").unwrap();
199 stmt.execute(params![
200 hash,
201 self.sequence_type.unwrap().to_string(),
202 if self.shallow {
203 ""
204 } else {
205 self.sequence.unwrap()
206 },
207 self.name.unwrap_or(""),
208 self.file_path.unwrap_or(""),
209 self.length.unwrap_or(length)
210 ])
211 .unwrap();
212 }
213 Err(_e) => {
214 panic!("something bad happened querying the database")
215 }
216 };
217 Sequence {
218 hash,
219 sequence_type: self.sequence_type.unwrap().to_string(),
220 sequence: self.sequence.unwrap_or("").to_string(),
221 name: self.name.unwrap_or("").to_string(),
222 file_path: self.file_path.unwrap_or("").to_string(),
223 length: self.length.unwrap_or(length),
224 external_sequence: !self.file_path.unwrap_or("").is_empty(),
225 }
226 }
227}
228
229#[cached(key = "String", convert = r#"{ format!("{}", path) }"#)]
230fn fasta_index(path: &str) -> Option<fai::Index> {
231 let index_path = format!("{path}.fai");
232 if fs::metadata(&index_path).is_ok() {
233 return Some(fai::fs::read(&index_path).unwrap());
234 }
235 None
236}
237
238#[cached(key = "String", convert = r#"{ format!("{}", path) }"#)]
239fn fasta_gzi_index(path: &str) -> Option<gzi::Index> {
240 let index_path = format!("{path}.gzi");
241 if fs::metadata(&index_path).is_ok() {
242 return Some(gzi::fs::read(&index_path).unwrap());
243 }
244 None
245}
246
247pub fn cached_sequence(file_path: &str, name: &str, start: usize, end: usize) -> Option<String> {
248 static SEQUENCE_CACHE: sync::LazyLock<sync::RwLock<HashMap<String, Option<String>>>> =
249 sync::LazyLock::new(|| sync::RwLock::new(HashMap::new()));
250 let key = format!("{file_path}-{name}");
251
252 {
253 let cache = SEQUENCE_CACHE.read().unwrap();
254 if let Some(cached_sequence) = cache.get(&key) {
255 if let Some(sequence) = cached_sequence {
256 return Some(sequence[start..end].to_string());
257 }
258 return None;
259 }
260 }
261
262 let mut cache = SEQUENCE_CACHE.write().unwrap();
263
264 let mut sequence: Option<String> = None;
265 let region = name.parse::<Region>().unwrap();
266 if let Some(index) = fasta_index(file_path) {
267 let builder = IndexBuilder::default().set_index(index);
268 if let Some(gzi_index) = fasta_gzi_index(file_path) {
269 let bgzf_reader = bgzf::io::indexed_reader::Builder::default()
270 .set_index(gzi_index)
271 .build_from_path(file_path)
272 .unwrap();
273 let mut reader = builder.build_from_reader(bgzf_reader).unwrap();
274 sequence = Some(
275 str::from_utf8(reader.query(®ion).unwrap().sequence().as_ref())
276 .unwrap()
277 .to_string(),
278 )
279 } else {
280 let mut reader = builder.build_from_path(file_path).unwrap();
281 sequence = Some(
282 str::from_utf8(reader.query(®ion).unwrap().sequence().as_ref())
283 .unwrap()
284 .to_string(),
285 );
286 }
287 } else {
288 let mut reader = fasta::io::reader::Builder
289 .build_from_path(file_path)
290 .unwrap();
291 for result in reader.records() {
292 let record = result.unwrap();
293 if String::from_utf8(record.name().to_vec()).unwrap() == name {
294 sequence = Some(
295 str::from_utf8(record.sequence().as_ref())
296 .unwrap()
297 .to_string(),
298 );
299 break;
300 }
301 }
302 }
303 cache.clear();
306 cache.insert(key.clone(), sequence);
307 if let Some(seq) = &cache[&key] {
309 return Some(seq[start..end].to_string());
310 }
311 None
312}
313
314impl Sequence {
315 #[allow(clippy::new_ret_no_self)]
316 pub fn new() -> NewSequence<'static> {
317 NewSequence::new()
318 }
319
320 pub fn get_sequence(
321 &self,
322 start: impl Into<Option<i64>>,
323 end: impl Into<Option<i64>>,
324 ) -> String {
325 let start: Option<i64> = start.into();
328 let end: Option<i64> = end.into();
329 let start = start.unwrap_or(0) as usize;
330 let end = end.unwrap_or(self.length) as usize;
331 if self.external_sequence {
332 if let Some(sequence) = cached_sequence(&self.file_path, &self.name, start, end) {
333 return sequence;
334 } else {
335 panic!(
336 "{name} not found in fasta file {file_path}",
337 name = self.name,
338 file_path = self.file_path
339 );
340 }
341 }
342 if start == 0 && end as i64 == self.length {
343 return self.sequence.clone();
344 }
345 self.sequence[start..end].to_string()
346 }
347
348 pub fn delete_by_hash(conn: &GraphConnection, hash: &HashId) {
349 let mut stmt = conn
350 .prepare("delete from sequences where hash = ?1;")
351 .unwrap();
352 stmt.execute(params![hash]).unwrap();
353 }
354
355 pub fn query_by_blockgroup(conn: &GraphConnection, block_group_id: &HashId) -> Vec<Sequence> {
356 Sequence::query(
357 conn,
358 "select sequences.* from block_group_edges bge left join edges on bge.edge_id = edges.id left join nodes on (edges.source_node_id = nodes.id or edges.target_node_id = nodes.id) left join sequences on (nodes.sequence_hash = sequences.hash) where bge.block_group_id = ?1;",
359 params![block_group_id],
360 )
361 }
362}
363
364impl Query for Sequence {
365 type Model = Sequence;
366
367 const PRIMARY_KEY: &'static str = "hash";
368 const TABLE_NAME: &'static str = "sequences";
369
370 fn process_row(row: &Row) -> Self::Model {
371 let file_path: String = row.get(4).unwrap();
372 let mut external_sequence = false;
373 if !file_path.is_empty() {
374 external_sequence = true;
375 }
376 let hash: HashId = row.get(0).unwrap();
377 let sequence = row.get(2).unwrap();
378 Sequence {
379 hash,
380 sequence_type: row.get(1).unwrap(),
381 sequence,
382 name: row.get(3).unwrap(),
383 file_path,
384 length: row.get(5).unwrap(),
385 external_sequence,
386 }
387 }
388}
389
#[cfg(test)]
mod tests {
    use std::{
        fs::OpenOptions,
        io::{BufWriter, Write},
        time,
    };

    use rand::{self, Rng};

    use super::*;
    use crate::test_helpers::get_connection;

    #[test]
    fn test_builder() {
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .sequence("ATCG")
            .build();
        assert_eq!(sequence.length, 4);
        assert_eq!(sequence.sequence, "ATCG");
    }

    #[test]
    fn test_builder_with_from_disk() {
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .name("chr1")
            .file_path("/foo/bar")
            .length(50)
            .build();
        assert_eq!(sequence.length, 50);
        assert_eq!(sequence.sequence, "");
    }

    #[test]
    fn test_create_sequence_in_db() {
        let conn = &get_connection(None).unwrap();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .sequence("AACCTT")
            .save(conn);
        assert_eq!(&sequence.sequence, "AACCTT");
        assert_eq!(sequence.sequence_type, "DNA");
        assert!(!sequence.external_sequence);
    }

    #[test]
    fn test_delete_sequence_by_hash() {
        let conn = &get_connection(None).unwrap();
        let before_count = Sequence::all(conn).len();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .sequence("AACCTT")
            .save(conn);
        let sequence2 = Sequence::new()
            .sequence_type("DNA")
            .sequence("AACCTTAA")
            .save(conn);

        let sequences = Sequence::all(conn);
        assert_eq!(sequences.len(), before_count + 2);

        Sequence::delete_by_hash(conn, &sequence.hash);

        let sequences = Sequence::all(conn);
        assert_eq!(sequences.len(), before_count + 1);
        assert!(sequences.iter().any(|s| s.hash == sequence2.hash));
    }

    #[test]
    fn test_create_sequence_on_disk() {
        let conn = &get_connection(None).unwrap();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .name("chr1")
            .file_path("/some/path.fa")
            .length(10)
            .save(conn);
        assert_eq!(sequence.sequence_type, "DNA");
        assert_eq!(&sequence.sequence, "");
        assert_eq!(sequence.name, "chr1");
        assert_eq!(sequence.file_path, "/some/path.fa");
        assert_eq!(sequence.length, 10);
        assert!(sequence.external_sequence);
    }

    #[test]
    fn test_get_sequence() {
        let conn = &get_connection(None).unwrap();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .sequence("ATCGATCGATCGATCGATCGGGAACACACAGAGA")
            .save(conn);
        assert_eq!(
            sequence.get_sequence(None, None),
            "ATCGATCGATCGATCGATCGGGAACACACAGAGA"
        );
        assert_eq!(sequence.get_sequence(0, 5), "ATCGA");
        assert_eq!(sequence.get_sequence(10, 15), "CGATC");
        assert_eq!(
            sequence.get_sequence(3, None),
            "GATCGATCGATCGATCGGGAACACACAGAGA"
        );
        assert_eq!(sequence.get_sequence(None, 5), "ATCGA");
    }

    #[test]
    fn test_get_sequence_from_disk() {
        let conn = &get_connection(None).unwrap();
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_file_path = temp_dir.path().join("simple.fa");
        fs::write(
            &temp_file_path,
            ">m123\nATCGATCGATCGATCGATCGGGAACACACAGAGA\n",
        )
        .unwrap();
        let seq = Sequence::new()
            .sequence_type("DNA")
            .name("m123")
            .file_path(temp_file_path.to_str().unwrap())
            .length(34)
            .save(conn);
        assert_eq!(
            seq.get_sequence(None, None),
            "ATCGATCGATCGATCGATCGGGAACACACAGAGA"
        );
        assert_eq!(seq.get_sequence(0, 5), "ATCGA");
        assert_eq!(seq.get_sequence(10, 15), "CGATC");
        assert_eq!(seq.get_sequence(3, None), "GATCGATCGATCGATCGGGAACACACAGAGA");
        assert_eq!(seq.get_sequence(None, 5), "ATCGA");
    }

    #[test]
    fn test_cached_sequence_performance() {
        let conn = &get_connection(None).unwrap();
        let temp_dir = tempfile::tempdir().unwrap();
        let temp_file_path = temp_dir.path().join("large.fa");
        // Buffer the ~200 MB fixture; an unbuffered File issues one write syscall per line.
        let mut file = BufWriter::new(
            OpenOptions::new()
                .append(true)
                .create(true)
                .open(&temp_file_path)
                .unwrap(),
        );
        writeln!(file, ">chr22").unwrap();
        for _ in 1..3_000_000 {
            writeln!(
                file,
                "ATCGATCGATCGATCGATCGGGAACACACAGAGAATCGATCGATCGATCGATCGGGAACACACAGAGA"
            )
            .unwrap();
        }
        // Flush explicitly so write errors surface here (Drop would swallow them) and the
        // file is complete before it is read back.
        file.flush().unwrap();
        drop(file);
        let index_path = temp_dir.path().join("large.fa.fai");
        // 2,999,999 lines x 68 bases = 203,999,932 bases. Residues start after ">chr22\n"
        // (offset 7). Each line has 68 bases and is 69 bytes with the newline. FAI fields
        // are tab-separated.
        fs::write(&index_path, "chr22\t203999932\t7\t68\t69\n").unwrap();
        let sequence = Sequence::new()
            .sequence_type("DNA")
            .file_path(temp_file_path.to_str().unwrap())
            .name("chr22")
            .length(203_999_932)
            .save(conn);
        let s = time::Instant::now();
        for _ in 1..1_000_000 {
            let start = rand::rng().random_range(1..200_000_000);

            sequence.get_sequence(start, start + 20);
        }
        let elapsed = s.elapsed().as_secs();
        assert!(
            elapsed < 5,
            "Cached sequence benchmark failed: {elapsed}s elapsed"
        );
    }

    #[test]
    fn test_capnp_serialization() {
        use capnp::message::TypedBuilder;

        let sequence = Sequence {
            hash: HashId::convert_str("test_hash"),
            sequence_type: "DNA".to_string(),
            sequence: "ATCG".to_string(),
            name: "test_seq".to_string(),
            file_path: "/path/to/file".to_string(),
            length: 4,
            external_sequence: false,
        };

        let mut message = TypedBuilder::<sequence::Owned>::new_default();
        let mut root = message.init_root();
        sequence.write_capnp(&mut root);

        let deserialized = Sequence::read_capnp(root.into_reader());
        assert_eq!(sequence, deserialized);
    }
}