1use std::fs::File;
2use std::io::BufReader;
3use std::path::Path;
4
5use memmap2::Mmap;
6
7use crate::error::{Error, Result};
8use crate::header::FileHeader;
9use crate::matrix::RectangularBinaryMatrix;
10use crate::mer::MerDna;
11
12pub struct QueryMerFile {
30 mmap: Mmap,
31 data_offset: usize,
32 key_len_bytes: usize,
33 val_len_bytes: usize,
34 record_len: usize,
35 k: usize,
36 matrix: RectangularBinaryMatrix,
37 size_mask: u64,
38 num_records: usize,
39 header: FileHeader,
40}
41
42impl QueryMerFile {
43 pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
45 let file = File::open(path.as_ref())?;
46
47 let mut buf_reader = BufReader::new(File::open(path.as_ref())?);
49 let header = FileHeader::read(&mut buf_reader)?;
50
51 if !header.is_binary() {
52 return Err(Error::UnsupportedFormat(format!(
53 "QueryMerFile requires binary/sorted format, got {:?}",
54 header.format()
55 )));
56 }
57
58 let key_len_bytes = header
59 .key_bytes()
60 .ok_or_else(|| Error::MissingField("key_len".to_string()))?;
61 let val_len_bytes = header
62 .data_val_len()
63 .ok_or_else(|| Error::MissingField("counter_len or val_len".to_string()))?;
64 let k = header
65 .k()
66 .ok_or_else(|| Error::MissingField("key_len".to_string()))?;
67 let size = header
68 .size()
69 .ok_or_else(|| Error::MissingField("size".to_string()))?;
70
71 let matrix = header.matrix(0)?;
72
73 let record_len = key_len_bytes + val_len_bytes;
74 let data_offset = header.offset();
75
76 let mmap = unsafe { Mmap::map(&file)? };
78 let file_data_len = mmap.len().saturating_sub(data_offset);
79 let num_records = if record_len > 0 {
80 file_data_len / record_len
81 } else {
82 0
83 };
84
85 Ok(Self {
86 mmap,
87 data_offset,
88 key_len_bytes,
89 val_len_bytes,
90 record_len,
91 k,
92 matrix,
93 size_mask: size - 1,
94 num_records,
95 header,
96 })
97 }
98
99 pub fn get(&self, mer: &MerDna) -> Option<u64> {
101 if mer.k() != self.k || self.num_records == 0 {
102 return None;
103 }
104
105 let hash_pos = self.matrix.times(mer.words()) & self.size_mask;
107
108 self.binary_search_record(mer, hash_pos)
111 }
112
113 pub fn query(&self, kmer_str: &str) -> Result<Option<u64>> {
115 let mer: MerDna = kmer_str.parse()?;
116 Ok(self.get(&mer))
117 }
118
119 pub fn header(&self) -> &FileHeader {
121 &self.header
122 }
123
124 pub fn k(&self) -> usize {
126 self.k
127 }
128
129 pub fn num_records(&self) -> usize {
131 self.num_records
132 }
133
134 fn read_key_at(&self, index: usize) -> Option<MerDna> {
136 if index >= self.num_records {
137 return None;
138 }
139 let offset = self.data_offset + index * self.record_len;
140 let end = offset + self.key_len_bytes;
141 if end > self.mmap.len() {
142 return None;
143 }
144 Some(MerDna::from_bytes(&self.mmap[offset..end], self.k))
145 }
146
147 fn read_val_at(&self, index: usize) -> Option<u64> {
149 if index >= self.num_records {
150 return None;
151 }
152 let offset = self.data_offset + index * self.record_len + self.key_len_bytes;
153 let end = offset + self.val_len_bytes;
154 if end > self.mmap.len() {
155 return None;
156 }
157 let mut count = 0u64;
158 for (i, &byte) in self.mmap[offset..end].iter().enumerate() {
159 count |= (byte as u64) << (i * 8);
160 }
161 Some(count)
162 }
163
164 fn binary_search_record(&self, mer: &MerDna, _hint_pos: u64) -> Option<u64> {
166 if self.num_records == 0 {
167 return None;
168 }
169
170 let mut lo = 0usize;
171 let mut hi = self.num_records;
172
173 while lo < hi {
174 let mid = lo + (hi - lo) / 2;
175 let key = self.read_key_at(mid)?;
176 match key.cmp(mer) {
177 std::cmp::Ordering::Equal => return self.read_val_at(mid),
178 std::cmp::Ordering::Less => lo = mid + 1,
179 std::cmp::Ordering::Greater => hi = mid,
180 }
181 }
182
183 None
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes a temporary database containing `records`.
    ///
    /// Layout produced: a 9-digit zero-padded header length, the JSON header,
    /// then one record per k-mer in ascending `MerDna` order. Each record is the
    /// low-order-first bytes of the k-mer's words (truncated to the key width)
    /// followed by `val_len` low-order-first bytes of the count.
    fn create_test_jf_file(k: usize, val_len: usize, records: &[(&str, u64)]) -> NamedTempFile {
        let key_bits = k * 2;
        let key_bytes = (key_bits + 7) / 8;

        let mut entries: Vec<(MerDna, u64)> = Vec::with_capacity(records.len());
        for &(seq, count) in records {
            entries.push((seq.parse::<MerDna>().unwrap(), count));
        }
        entries.sort_by(|lhs, rhs| lhs.0.cmp(&rhs.0));

        let table_size = entries.len().next_power_of_two().max(2);
        let header = serde_json::json!({
            "format": "binary/sorted",
            "key_len": key_bits,
            "val_len": val_len,
            "counter_len": val_len,
            "size": table_size,
            "canonical": false,
            "max_reprobe": 126
        });
        let header_text = serde_json::to_string(&header).unwrap();

        let mut file = NamedTempFile::new().unwrap();
        write!(file, "{:09}", header_text.len()).unwrap();
        file.write_all(header_text.as_bytes()).unwrap();

        for (mer, count) in &entries {
            // Key: every word contributes its eight bytes, low byte first; the
            // result is cut down to the key width.
            let mut key = Vec::with_capacity(key_bytes);
            for &word in mer.words() {
                key.extend((0..8).map(|byte_idx| (word >> (byte_idx * 8)) as u8));
            }
            key.truncate(key_bytes);
            file.write_all(&key).unwrap();

            // Value: `val_len` bytes of the count, low byte first.
            let value: Vec<u8> = (0..val_len).map(|i| (*count >> (i * 8)) as u8).collect();
            file.write_all(&value).unwrap();
        }

        file.flush().unwrap();
        file
    }

    /// Parses a k-mer literal used by the tests.
    fn mer(seq: &str) -> MerDna {
        seq.parse().unwrap()
    }

    #[test]
    fn test_open_and_query_single() {
        let tmp = create_test_jf_file(4, 4, &[("ACGT", 42)]);
        let db = QueryMerFile::open(tmp.path()).unwrap();

        assert_eq!(db.k(), 4);
        assert_eq!(db.num_records(), 1);
        assert_eq!(db.get(&mer("ACGT")), Some(42));
    }

    #[test]
    fn test_query_not_found() {
        let tmp = create_test_jf_file(4, 4, &[("ACGT", 42)]);
        let db = QueryMerFile::open(tmp.path()).unwrap();

        assert_eq!(db.get(&mer("TTTT")), None);
    }

    #[test]
    fn test_query_multiple() {
        let tmp = create_test_jf_file(
            4,
            4,
            &[("AAAA", 10), ("ACGT", 42), ("CCCC", 7), ("TTTT", 100)],
        );
        let db = QueryMerFile::open(tmp.path()).unwrap();

        let expectations: &[(&str, Option<u64>)] = &[
            ("AAAA", Some(10)),
            ("ACGT", Some(42)),
            ("CCCC", Some(7)),
            ("TTTT", Some(100)),
            ("GGGG", None),
        ];
        for &(seq, expected) in expectations {
            assert_eq!(db.get(&mer(seq)), expected);
        }
    }

    #[test]
    fn test_query_string_convenience() {
        let tmp = create_test_jf_file(4, 4, &[("ACGT", 42)]);
        let db = QueryMerFile::open(tmp.path()).unwrap();

        assert_eq!(db.query("ACGT").unwrap(), Some(42));
        assert_eq!(db.query("TTTT").unwrap(), None);
    }

    #[test]
    fn test_query_wrong_k() {
        let tmp = create_test_jf_file(4, 4, &[("ACGT", 42)]);
        let db = QueryMerFile::open(tmp.path()).unwrap();

        // An 8-mer queried against a database built for k = 4.
        assert_eq!(db.get(&mer("ACGTACGT")), None);
    }

    #[test]
    fn test_query_longer_kmer() {
        let seq = "ACGTACGTACGTACGTACGTACGTA";
        let tmp = create_test_jf_file(25, 4, &[(seq, 99)]);
        let db = QueryMerFile::open(tmp.path()).unwrap();

        assert_eq!(db.get(&mer(seq)), Some(99));
    }

    #[test]
    fn test_header_access() {
        let tmp = create_test_jf_file(4, 4, &[("ACGT", 42)]);
        let db = QueryMerFile::open(tmp.path()).unwrap();

        assert!(db.header().is_binary());
        assert_eq!(db.header().k(), Some(4));
    }

    #[test]
    fn test_empty_database() {
        let tmp = create_test_jf_file(4, 4, &[]);
        let db = QueryMerFile::open(tmp.path()).unwrap();

        assert_eq!(db.num_records(), 0);
        assert_eq!(db.get(&mer("ACGT")), None);
    }
}