Skip to main content

argenus/
classifier.rs

1
2use anyhow::{Context, Result};
3use rustc_hash::FxHashMap;
4use std::fs::File;
5use std::io::{BufRead, BufReader, BufWriter, Read, Seek, SeekFrom, Write};
6use std::path::Path;
7use std::process::Command;
8
9use crate::snp::{self, SnpStatus};
10
/// Eight-byte signature expected at the very start of a flanking database file.
///
/// `FlankingDatabase::open` compares the first eight bytes of the file against
/// this value and rejects the file when they differ.
const FDB_MAGIC: &[u8; 8] = b"FLANKDB\0";
12
/// Location of one gene hit on a contig, as consumed by `GenusClassifier`.
#[derive(Clone, Debug)]
pub struct ArgPosition {
    /// Gene identifier; used as the lookup key into the flanking database and
    /// forwarded to SNP verification.
    pub arg_name: String,

    /// Name of the contig carrying the hit; copied verbatim into the result.
    pub contig_name: String,

    /// Full contig sequence from which the flanking regions are sliced.
    pub contig_seq: String,

    /// Contig length as supplied by the caller.
    // NOTE(review): not read by any code in this file; presumably equals
    // `contig_seq.len()` — confirm against the producer.
    pub contig_len: usize,

    /// Byte offset into `contig_seq` where the gene starts (the upstream flank
    /// ends here, exclusive).
    pub arg_start: usize,

    /// Byte offset into `contig_seq` where the gene ends (the downstream flank
    /// starts here).
    pub arg_end: usize,

    /// Strand of the hit. `'-'` causes the flanks to be reverse-complemented
    /// and swapped; every other value is treated as the forward strand.
    pub strand: char,
}
30
31#[derive(Debug, Clone)]
32pub struct GenusResult {
33
34    pub arg_name: String,
35
36    pub contig_name: String,
37
38    pub genus: Option<String>,
39
40    pub confidence: f64,
41
42    pub specificity: f64,
43
44    pub upstream_len: usize,
45
46    pub downstream_len: usize,
47
48    pub top_matches: Vec<(String, f64)>,
49
50    pub snp_status: SnpStatus,
51}
52
53impl Default for GenusResult {
54    fn default() -> Self {
55        Self {
56            arg_name: String::new(),
57            contig_name: String::new(),
58            genus: None,
59            confidence: 0.0,
60            specificity: 0.0,
61            upstream_len: 0,
62            downstream_len: 0,
63            top_matches: vec![],
64            snp_status: SnpStatus::NotApplicable,
65        }
66    }
67}
68
/// One reference record of a gene's flanking sequences, parsed from a
/// tab-separated line of a decompressed database block.
#[derive(Clone, Debug)]
pub struct FlankingRecord {
    /// Contig identifier (second tab-separated column of the record line).
    pub contig: String,

    /// Genus label (third column).
    pub genus: String,

    /// Upstream flanking sequence (sixth column); may be empty.
    pub upstream: String,

    /// Downstream flanking sequence (seventh column); may be empty.
    pub downstream: String,
}
80
/// Location of one gene's compressed record block inside the database file.
#[derive(Clone, Debug)]
struct FdbIndexEntry {
    /// Absolute byte offset of the block from the start of the file.
    offset: u64,
    /// Number of compressed bytes to read at `offset`.
    compressed_len: u32,
    /// Record count stored in the index; only used as a capacity hint.
    record_count: u32,
}
87
88pub struct FlankingDatabase {
89    file: File,
90    index: FxHashMap<String, FdbIndexEntry>,
91
92    gene_name_to_key: FxHashMap<String, String>,
93}
94
95impl FlankingDatabase {
96
97    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
98        let mut file = File::open(path.as_ref())
99            .with_context(|| format!("Failed to open fdb: {}", path.as_ref().display()))?;
100
101        let mut magic = [0u8; 8];
102        file.read_exact(&mut magic)?;
103        if &magic != FDB_MAGIC {
104            anyhow::bail!("Invalid fdb magic");
105        }
106
107        let mut buf4 = [0u8; 4];
108        let mut buf8 = [0u8; 8];
109
110        file.read_exact(&mut buf4)?;
111        let _version = u32::from_le_bytes(buf4);
112
113        file.read_exact(&mut buf4)?;
114        let gene_count = u32::from_le_bytes(buf4);
115
116        file.read_exact(&mut buf8)?;
117        let index_offset = u64::from_le_bytes(buf8);
118
119        file.seek(SeekFrom::Start(index_offset))?;
120        let mut index = FxHashMap::default();
121
122        for _ in 0..gene_count {
123            let mut buf2 = [0u8; 2];
124            file.read_exact(&mut buf2)?;
125            let name_len = u16::from_le_bytes(buf2) as usize;
126
127            let mut name_buf = vec![0u8; name_len];
128            file.read_exact(&mut name_buf)?;
129            let gene = String::from_utf8(name_buf)?;
130
131            file.read_exact(&mut buf8)?;
132            let offset = u64::from_le_bytes(buf8);
133
134            file.read_exact(&mut buf4)?;
135            let compressed_len = u32::from_le_bytes(buf4);
136
137            file.read_exact(&mut buf4)?;
138            let record_count = u32::from_le_bytes(buf4);
139
140            index.insert(gene, FdbIndexEntry {
141                offset,
142                compressed_len,
143                record_count,
144            });
145        }
146
147        let mut gene_name_to_key = FxHashMap::default();
148        for full_key in index.keys() {
149
150            let gene_name = full_key.split('|').next().unwrap_or(full_key);
151
152            if !gene_name_to_key.contains_key(gene_name) {
153                gene_name_to_key.insert(gene_name.to_string(), full_key.clone());
154            }
155        }
156
157        Ok(Self { file, index, gene_name_to_key })
158    }
159
160    pub fn has_gene(&self, gene: &str) -> bool {
161
162        if self.index.contains_key(gene) {
163            return true;
164        }
165
166        self.gene_name_to_key.contains_key(gene)
167    }
168
169    fn resolve_gene_key(&self, gene: &str) -> Option<&String> {
170        if self.index.contains_key(gene) {
171
172            None
173        } else {
174
175            self.gene_name_to_key.get(gene)
176        }
177    }
178
179    pub fn get_gene_records(&mut self, gene: &str) -> Result<Vec<FlankingRecord>> {
180
181        let lookup_key = if self.index.contains_key(gene) {
182            gene.to_string()
183        } else if let Some(full_key) = self.gene_name_to_key.get(gene) {
184            full_key.clone()
185        } else {
186            anyhow::bail!("Gene not found: {}", gene);
187        };
188
189        let entry = self.index.get(&lookup_key)
190            .ok_or_else(|| anyhow::anyhow!("Gene not found in index: {}", lookup_key))?
191            .clone();
192
193        self.file.seek(SeekFrom::Start(entry.offset))?;
194        let mut compressed = vec![0u8; entry.compressed_len as usize];
195        self.file.read_exact(&mut compressed)?;
196
197        let decompressed = zstd::decode_all(&compressed[..])?;
198        let content = String::from_utf8(decompressed)?;
199
200        let mut records = Vec::with_capacity(entry.record_count as usize);
201        let mut lines = content.lines();
202
203        let _header = lines.next();
204
205        for line in lines {
206            if line.is_empty() {
207                continue;
208            }
209            let fields: Vec<&str> = line.split('\t').collect();
210
211            if fields.len() < 7 {
212                continue;
213            }
214
215            records.push(FlankingRecord {
216                contig: fields[1].to_string(),
217                genus: fields[2].to_string(),
218                upstream: fields[5].to_string(),
219                downstream: fields[6].to_string(),
220            });
221        }
222
223        Ok(records)
224    }
225
226    pub fn get_genus_distribution(&mut self, gene: &str) -> Result<FxHashMap<String, usize>> {
227        let records = self.get_gene_records(gene)?;
228        let mut dist: FxHashMap<String, usize> = FxHashMap::default();
229
230        for rec in records {
231            *dist.entry(rec.genus).or_default() += 1;
232        }
233
234        Ok(dist)
235    }
236}
237
238pub struct GenusClassifier {
239    db: FlankingDatabase,
240    minimap2_path: String,
241    min_identity: f64,
242    min_align_len: usize,
243    max_flanking: usize,
244}
245
246impl GenusClassifier {
247
248    pub fn new<P: AsRef<Path>>(
249        db_path: P,
250        minimap2_path: &str,
251        min_identity: f64,
252        min_align_len: usize,
253        max_flanking: usize,
254    ) -> Result<Self> {
255        let db = FlankingDatabase::open(db_path)?;
256        Ok(Self {
257            db,
258            minimap2_path: minimap2_path.to_string(),
259            min_identity,
260            min_align_len,
261            max_flanking,
262        })
263    }
264
265    pub fn classify_batch(&mut self, positions: &[ArgPosition], threads: usize) -> Result<Vec<GenusResult>> {
266        let mut results = Vec::with_capacity(positions.len());
267
268        for pos in positions {
269            let result = self.classify_single(pos, threads)?;
270            results.push(result);
271        }
272
273        Ok(results)
274    }
275
276    pub fn classify_single(&mut self, pos: &ArgPosition, threads: usize) -> Result<GenusResult> {
277
278        let (upstream, downstream) = self.extract_flanking_regions(pos);
279
280        let upstream_len = upstream.len();
281        let downstream_len = downstream.len();
282
283        let snp_status = snp::verify_snp(
284            &pos.contig_seq,
285            &pos.arg_name,
286            0,
287            pos.arg_end - pos.arg_start,
288            pos.arg_start,
289            pos.arg_end,
290            pos.strand,
291        );
292
293        if upstream_len < 50 && downstream_len < 50 {
294            return Ok(GenusResult {
295                arg_name: pos.arg_name.clone(),
296                contig_name: pos.contig_name.clone(),
297                genus: None,
298                confidence: 0.0,
299                specificity: 0.0,
300                upstream_len,
301                downstream_len,
302                top_matches: vec![],
303                snp_status,
304            });
305        }
306
307        if !self.db.has_gene(&pos.arg_name) {
308            return Ok(GenusResult {
309                arg_name: pos.arg_name.clone(),
310                contig_name: pos.contig_name.clone(),
311                genus: None,
312                confidence: 0.0,
313                specificity: 0.0,
314                upstream_len,
315                downstream_len,
316                top_matches: vec![("gene_not_in_db".to_string(), 0.0)],
317                snp_status,
318            });
319        }
320
321        let ref_records = self.db.get_gene_records(&pos.arg_name)?;
322        if ref_records.is_empty() {
323            return Ok(GenusResult {
324                arg_name: pos.arg_name.clone(),
325                contig_name: pos.contig_name.clone(),
326                genus: None,
327                confidence: 0.0,
328                specificity: 0.0,
329                upstream_len,
330                downstream_len,
331                top_matches: vec![("no_ref_records".to_string(), 0.0)],
332                snp_status,
333            });
334        }
335
336        let temp_dir = std::env::temp_dir();
337        let pid = std::process::id();
338        let query_path = temp_dir.join(format!("argenus_query_{}.fas", pid));
339        let ref_path = temp_dir.join(format!("argenus_ref_{}.fas", pid));
340        let paf_path = temp_dir.join(format!("argenus_align_{}.paf", pid));
341
342        {
343            let mut query_file = BufWriter::new(File::create(&query_path)?);
344            if !upstream.is_empty() {
345                writeln!(query_file, ">upstream")?;
346                writeln!(query_file, "{}", upstream)?;
347            }
348            if !downstream.is_empty() {
349                writeln!(query_file, ">downstream")?;
350                writeln!(query_file, "{}", downstream)?;
351            }
352        }
353
354        {
355            let mut ref_file = BufWriter::new(File::create(&ref_path)?);
356            for (i, rec) in ref_records.iter().enumerate() {
357                if !rec.upstream.is_empty() {
358                    writeln!(ref_file, ">{}|{}|up_{}", rec.genus, rec.contig, i)?;
359                    writeln!(ref_file, "{}", rec.upstream)?;
360                }
361                if !rec.downstream.is_empty() {
362                    writeln!(ref_file, ">{}|{}|down_{}", rec.genus, rec.contig, i)?;
363                    writeln!(ref_file, "{}", rec.downstream)?;
364                }
365            }
366        }
367
368        let output = Command::new(&self.minimap2_path)
369            .args(["-x", "sr", "-t", &threads.to_string(), "-c", "--secondary=yes", "-N", "100", "-k", "15", "-w", "5"])
370            .arg(&ref_path)
371            .arg(&query_path)
372            .arg("-o").arg(&paf_path)
373            .stderr(std::process::Stdio::null())
374            .output()
375            .context("Failed to run minimap2")?;
376
377        if !output.status.success() {
378
379            let _ = std::fs::remove_file(&query_path);
380            let _ = std::fs::remove_file(&ref_path);
381            let _ = std::fs::remove_file(&paf_path);
382
383            return Ok(GenusResult {
384                arg_name: pos.arg_name.clone(),
385                contig_name: pos.contig_name.clone(),
386                genus: None,
387                confidence: 0.0,
388                specificity: 0.0,
389                upstream_len,
390                downstream_len,
391                top_matches: vec![("minimap2_failed".to_string(), 0.0)],
392                snp_status,
393            });
394        }
395
396        let genus_scores = self.calculate_genus_scores(&paf_path)?;
397
398        let _ = std::fs::remove_file(&query_path);
399        let _ = std::fs::remove_file(&ref_path);
400        let _ = std::fs::remove_file(&paf_path);
401
402        let genus_dist = self.db.get_genus_distribution(&pos.arg_name)?;
403        let total_in_db: usize = genus_dist.values().sum();
404
405        let mut sorted_scores: Vec<(String, f64)> = genus_scores.into_iter().collect();
406        sorted_scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
407
408        let (genus, confidence, specificity) = if let Some((best_genus, best_score)) = sorted_scores.first() {
409            let genus_count = genus_dist.get(best_genus).copied().unwrap_or(0);
410            let specificity = if total_in_db > 0 {
411                (genus_count as f64 / total_in_db as f64) * 100.0
412            } else {
413                0.0
414            };
415
416            (Some(best_genus.clone()), *best_score, specificity)
417        } else {
418            (None, 0.0, 0.0)
419        };
420
421        let top_matches: Vec<(String, f64)> = sorted_scores.into_iter().take(5).collect();
422
423        Ok(GenusResult {
424            arg_name: pos.arg_name.clone(),
425            contig_name: pos.contig_name.clone(),
426            genus,
427            confidence,
428            specificity,
429            upstream_len,
430            downstream_len,
431            top_matches,
432            snp_status,
433        })
434    }
435
436    fn extract_flanking_regions(&self, pos: &ArgPosition) -> (String, String) {
437        let seq = &pos.contig_seq;
438
439        let upstream_end = pos.arg_start;
440        let upstream_start = upstream_end.saturating_sub(self.max_flanking);
441        let upstream = if upstream_end > upstream_start {
442            seq[upstream_start..upstream_end].to_string()
443        } else {
444            String::new()
445        };
446
447        let downstream_start = pos.arg_end;
448        let downstream_end = (downstream_start + self.max_flanking).min(seq.len());
449        let downstream = if downstream_end > downstream_start {
450            seq[downstream_start..downstream_end].to_string()
451        } else {
452            String::new()
453        };
454
455        if pos.strand == '-' {
456            (reverse_complement(&downstream), reverse_complement(&upstream))
457        } else {
458            (upstream, downstream)
459        }
460    }
461
462    fn calculate_genus_scores(&self, paf_path: &Path) -> Result<FxHashMap<String, f64>> {
463        let file = File::open(paf_path)?;
464        let reader = BufReader::new(file);
465
466        let mut genus_matches: FxHashMap<String, Vec<f64>> = FxHashMap::default();
467        let min_identity_pct = self.min_identity * 100.0;
468
469        for line in reader.lines() {
470            let line = line?;
471            let fields: Vec<&str> = line.split('\t').collect();
472            if fields.len() < 12 {
473                continue;
474            }
475
476            let block_len: usize = fields[10].parse().unwrap_or(0);
477            let matches: usize = fields[9].parse().unwrap_or(0);
478
479            if block_len < self.min_align_len {
480                continue;
481            }
482
483            let identity = if block_len > 0 {
484                (matches as f64 / block_len as f64) * 100.0
485            } else {
486                0.0
487            };
488
489            if identity < min_identity_pct {
490                continue;
491            }
492
493            let target_name = fields[5];
494            if let Some(genus) = target_name.split('|').next() {
495                genus_matches.entry(genus.to_string()).or_default().push(identity);
496            }
497        }
498
499        let mut genus_scores: FxHashMap<String, f64> = FxHashMap::default();
500        for (genus, scores) in genus_matches {
501            if scores.is_empty() {
502                continue;
503            }
504
505            let avg_identity = scores.iter().sum::<f64>() / scores.len() as f64;
506            let count_bonus = (scores.len() as f64).ln().max(1.0);
507            genus_scores.insert(genus, avg_identity * count_bonus / count_bonus.max(1.0));
508        }
509
510        Ok(genus_scores)
511    }
512
513}
514
/// Returns the reverse complement of a nucleotide sequence.
///
/// The output is always upper-case. Any character other than A, C, G or T
/// (matched case-insensitively) is emitted as `N`.
fn reverse_complement(seq: &str) -> String {
    let mut out = String::with_capacity(seq.len());
    for base in seq.chars().rev() {
        let complement = match base.to_ascii_uppercase() {
            'A' => 'T',
            'C' => 'G',
            'G' => 'C',
            'T' => 'A',
            _ => 'N',
        };
        out.push(complement);
    }
    out
}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `reverse_complement` against known input/output pairs,
    /// including the empty sequence.
    #[test]
    fn test_reverse_complement() {
        let cases = [("ATGC", "GCAT"), ("AAAA", "TTTT"), ("", "")];
        for (input, expected) in cases.iter() {
            assert_eq!(reverse_complement(input), *expected);
        }
    }

    /// The default result carries no genus call and a zero confidence.
    #[test]
    fn test_genus_result_default() {
        let GenusResult {
            genus, confidence, ..
        } = GenusResult::default();
        assert!(genus.is_none());
        assert_eq!(confidence, 0.0);
    }
}
545}