1use std::collections::HashMap;
7use std::path::Path;
8
9use memmap2::Mmap;
10use noodles_fasta::fai;
11
12use crate::fasta::SidecarCache;
13use refget_model::SequenceMetadata;
14
15use crate::fasta::{DigestCache, FastaSequenceSummary, index_digests, read_fai_index};
16use crate::{SequenceStore, StoreError, StoreResult};
17
18pub struct MmapSequenceStore {
27 digest_index: HashMap<String, usize>,
29 records: Vec<MmapRecordInfo>,
31 mmaps: Vec<Mmap>,
33}
34
35struct MmapRecordInfo {
36 mmap_idx: usize,
37 metadata: SequenceMetadata,
38 fai_offset: u64,
40 fai_line_bases: u64,
42 fai_line_width: u64,
44}
45
46impl MmapSequenceStore {
47 pub fn new() -> Self {
49 Self { digest_index: HashMap::new(), records: Vec::new(), mmaps: Vec::new() }
50 }
51
52 pub fn mark_circular(&mut self, circular_names: &[String]) {
54 for record in &mut self.records {
55 if circular_names.iter().any(|n| n.as_str() == record.metadata.aliases[0].value) {
56 record.metadata.circular = true;
57 }
58 }
59 }
60
61 pub fn add_fasta<P: AsRef<Path>>(&mut self, path: P) -> StoreResult<Vec<FastaSequenceSummary>> {
66 let path = path.as_ref();
67
68 let cache = DigestCache::load_if_fresh(path).ok_or_else(|| {
70 StoreError::Fasta(format!(
71 "Disk mode requires a fresh .refget.json cache for {}. \
72 Run `refget-tools cache {}` first.",
73 path.display(),
74 path.display()
75 ))
76 })?;
77
78 let index = read_fai_index(path)?;
80 let fai_records: &[fai::Record] = index.as_ref();
81
82 if cache.sequences.len() != fai_records.len() {
83 return Err(StoreError::Fasta(format!(
84 "Digest cache has {} entries but FAI index has {} for {}",
85 cache.sequences.len(),
86 fai_records.len(),
87 path.display()
88 )));
89 }
90
91 let file = std::fs::File::open(path)?;
96 let mmap = unsafe { Mmap::map(&file) }
97 .map_err(|e| StoreError::Fasta(format!("Failed to mmap {}: {e}", path.display())))?;
98 let mmap_idx = self.mmaps.len();
99 self.mmaps.push(mmap);
100
101 let mut summaries = Vec::new();
102
103 for (cached, fai_rec) in cache.sequences.iter().zip(fai_records.iter()) {
104 if fai_rec.length() != cached.length {
106 return Err(StoreError::Fasta(format!(
107 "FAI/cache length mismatch for {}: FAI says {}, cache says {}",
108 cached.name,
109 fai_rec.length(),
110 cached.length
111 )));
112 }
113
114 let metadata = cached.to_metadata();
115 summaries.push(cached.to_summary());
116
117 let record_idx = self.records.len();
118 index_digests(&mut self.digest_index, cached, record_idx);
119
120 self.records.push(MmapRecordInfo {
121 mmap_idx,
122 metadata,
123 fai_offset: fai_rec.offset(),
124 fai_line_bases: fai_rec.line_bases(),
125 fai_line_width: fai_rec.line_width(),
126 });
127 }
128
129 Ok(summaries)
130 }
131
132 fn extract_bases(&self, info: &MmapRecordInfo, start: u64, end: u64) -> Vec<u8> {
136 let mmap = &self.mmaps[info.mmap_idx];
137 let len = (end - start) as usize;
138 let mut result = Vec::with_capacity(len);
139 let mut pos = start;
140
141 while pos < end {
142 let line_idx = pos / info.fai_line_bases;
143 let col = pos % info.fai_line_bases;
144 let line_start = info.fai_offset + line_idx * info.fai_line_width;
145 let remaining_on_line = info.fai_line_bases - col;
147 let to_read = remaining_on_line.min(end - pos) as usize;
148 let byte_start = (line_start + col) as usize;
149
150 for &b in &mmap[byte_start..byte_start + to_read] {
151 result.push(b.to_ascii_uppercase());
152 }
153 pos += to_read as u64;
154 }
155
156 result
157 }
158}
159
160impl Default for MmapSequenceStore {
161 fn default() -> Self {
162 Self::new()
163 }
164}
165
166impl SequenceStore for MmapSequenceStore {
167 fn get_sequence(
168 &self,
169 digest: &str,
170 start: Option<u64>,
171 end: Option<u64>,
172 ) -> StoreResult<Option<Vec<u8>>> {
173 let Some(&record_idx) = self.digest_index.get(digest) else {
174 return Ok(None);
175 };
176 let info = &self.records[record_idx];
177 let length = info.metadata.length;
178
179 let start = start.unwrap_or(0);
180 let end = end.unwrap_or(length).min(length);
181
182 if start >= length {
183 return Ok(Some(vec![]));
184 }
185
186 Ok(Some(self.extract_bases(info, start, end)))
187 }
188
189 fn get_metadata(&self, digest: &str) -> StoreResult<Option<SequenceMetadata>> {
190 Ok(self.digest_index.get(digest).map(|&idx| self.records[idx].metadata.clone()))
191 }
192
193 fn get_length(&self, digest: &str) -> StoreResult<Option<u64>> {
194 Ok(self.digest_index.get(digest).map(|&idx| self.records[idx].metadata.length))
195 }
196}
197
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};
    use tempfile::TempDir;

    /// Writes `contents` to the file `name` inside `dir` and returns its path.
    fn write_file(dir: &TempDir, name: &str, contents: &str) -> PathBuf {
        let path = dir.path().join(name);
        std::fs::write(&path, contents).unwrap();
        path
    }

    /// Creates `test.fa` (two sequences, the first wrapped over two lines),
    /// its `.fai` index and its digest cache, and returns the FASTA path.
    fn write_test_fasta_with_cache(dir: &TempDir) -> PathBuf {
        let fasta_path = write_file(dir, "test.fa", ">seq1\nACGT\nNNNN\n>seq2\nTTTT\n");
        // FAI columns: name, length, offset, line_bases, line_width.
        write_file(dir, "test.fa.fai", "seq1\t8\t6\t4\t5\nseq2\t4\t22\t4\t5\n");

        DigestCache::from_fasta(&fasta_path).unwrap().write(&fasta_path).unwrap();

        fasta_path
    }

    /// Builds a fresh store containing only the FASTA at `fasta_path`.
    fn load(fasta_path: &Path) -> (MmapSequenceStore, Vec<FastaSequenceSummary>) {
        let mut store = MmapSequenceStore::new();
        let summaries = store.add_fasta(fasta_path).unwrap();
        (store, summaries)
    }

    /// Fetches a sequence that is expected to exist.
    fn fetch(
        store: &MmapSequenceStore,
        digest: &str,
        start: Option<u64>,
        end: Option<u64>,
    ) -> Vec<u8> {
        store.get_sequence(digest, start, end).unwrap().unwrap()
    }

    #[test]
    fn test_mmap_load_and_retrieve() {
        let dir = TempDir::new().unwrap();
        let (store, summaries) = load(&write_test_fasta_with_cache(&dir));
        assert_eq!(summaries.len(), 2);

        assert_eq!(fetch(&store, &summaries[0].sha512t24u, None, None), b"ACGTNNNN");
        assert_eq!(fetch(&store, &summaries[1].sha512t24u, None, None), b"TTTT");
    }

    #[test]
    fn test_mmap_subsequence() {
        let dir = TempDir::new().unwrap();
        let (store, summaries) = load(&write_test_fasta_with_cache(&dir));
        let digest = &summaries[0].sha512t24u;

        // Entirely within the first line.
        assert_eq!(fetch(&store, digest, Some(1), Some(3)), b"CG");
        // Straddles the line break between "ACGT" and "NNNN".
        assert_eq!(fetch(&store, digest, Some(2), Some(6)), b"GTNN");
    }

    #[test]
    fn test_mmap_metadata() {
        let dir = TempDir::new().unwrap();
        let (store, summaries) = load(&write_test_fasta_with_cache(&dir));

        let meta = store.get_metadata(&summaries[0].sha512t24u).unwrap().unwrap();
        assert_eq!(meta.length, 8);
        assert!(meta.sha512t24u.starts_with("SQ."));
    }

    #[test]
    fn test_mmap_get_length() {
        let dir = TempDir::new().unwrap();
        let (store, summaries) = load(&write_test_fasta_with_cache(&dir));

        assert_eq!(store.get_length(&summaries[0].sha512t24u).unwrap(), Some(8));
        assert_eq!(store.get_length(&summaries[1].sha512t24u).unwrap(), Some(4));
        assert_eq!(store.get_length("nonexistent").unwrap(), None);
    }

    #[test]
    fn test_mmap_not_found() {
        let store = MmapSequenceStore::new();
        assert!(store.get_sequence("missing", None, None).unwrap().is_none());
        assert!(store.get_metadata("missing").unwrap().is_none());
    }

    #[test]
    fn test_mmap_start_beyond_length() {
        let dir = TempDir::new().unwrap();
        let (store, summaries) = load(&write_test_fasta_with_cache(&dir));

        let seq = fetch(&store, &summaries[0].sha512t24u, Some(100), None);
        assert!(seq.is_empty());
    }

    #[test]
    fn test_mmap_requires_cache() {
        let dir = TempDir::new().unwrap();
        // FASTA and index exist, but no digest cache is written.
        let fasta_path = write_file(&dir, "nocache.fa", ">seq1\nACGT\n");
        write_file(&dir, "nocache.fa.fai", "seq1\t4\t6\t4\t5\n");

        let mut store = MmapSequenceStore::new();
        let err = store.add_fasta(&fasta_path).unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("requires a fresh .refget.json"), "Unexpected: {msg}");
    }

    #[test]
    fn test_mmap_multiple_fastas() {
        let dir = TempDir::new().unwrap();

        let fa1 = write_file(&dir, "a.fa", ">s1\nAAAA\n");
        write_file(&dir, "a.fa.fai", "s1\t4\t4\t4\t5\n");
        DigestCache::from_fasta(&fa1).unwrap().write(&fa1).unwrap();

        let fa2 = write_file(&dir, "b.fa", ">s2\nCCCC\n");
        write_file(&dir, "b.fa.fai", "s2\t4\t4\t4\t5\n");
        DigestCache::from_fasta(&fa2).unwrap().write(&fa2).unwrap();

        // Both files go into the same store.
        let mut store = MmapSequenceStore::new();
        let s1 = store.add_fasta(&fa1).unwrap();
        let s2 = store.add_fasta(&fa2).unwrap();

        assert_eq!(fetch(&store, &s1[0].sha512t24u, None, None), b"AAAA");
        assert_eq!(fetch(&store, &s2[0].sha512t24u, None, None), b"CCCC");
    }

    #[test]
    fn test_mmap_lowercase_uppercased() {
        let dir = TempDir::new().unwrap();
        let fasta_path = write_file(&dir, "lower.fa", ">seq1\nacgt\n");
        write_file(&dir, "lower.fa.fai", "seq1\t4\t6\t4\t5\n");
        DigestCache::from_fasta(&fasta_path).unwrap().write(&fasta_path).unwrap();

        let (store, summaries) = load(&fasta_path);

        assert_eq!(fetch(&store, &summaries[0].sha512t24u, None, None), b"ACGT");
    }

    #[test]
    fn test_mmap_matches_memory_store() {
        let dir = TempDir::new().unwrap();
        let fasta_path = write_test_fasta_with_cache(&dir);

        let (mem_store, mem_summaries) =
            crate::fasta::FastaSequenceStore::from_fasta(&fasta_path).unwrap();
        let (mmap_store, mmap_summaries) = load(&fasta_path);

        // Summaries must agree field by field.
        assert_eq!(mem_summaries.len(), mmap_summaries.len());
        for (mem, mapped) in mem_summaries.iter().zip(mmap_summaries.iter()) {
            assert_eq!(mem.name, mapped.name);
            assert_eq!(mem.length, mapped.length);
            assert_eq!(mem.md5, mapped.md5);
            assert_eq!(mem.sha512t24u, mapped.sha512t24u);
        }

        // And so must the sequence bytes served by each backend.
        for summary in &mem_summaries {
            let digest = &summary.sha512t24u;
            let mem_seq = mem_store.get_sequence(digest, None, None).unwrap().unwrap();
            let mmap_seq = fetch(&mmap_store, digest, None, None);
            assert_eq!(mem_seq, mmap_seq, "Mismatch for {}", summary.name);
        }
    }
}