Skip to main content

ref_solver/cli/
score.rs

1//! Score command - compare two files directly using the scoring algorithm.
2//!
3//! This command compares a query file against a reference file without using
4//! the reference catalog. Useful for comparing arbitrary files.
5
6use std::path::{Path, PathBuf};
7
8use clap::Args;
9
10use crate::cli::OutputFormat;
11use crate::core::header::QueryHeader;
12use crate::core::reference::KnownReference;
13use crate::core::types::{Assembly, ReferenceSource};
14use crate::matching::engine::ScoringWeights;
15use crate::matching::scoring::MatchScore;
16use crate::parsing;
17
18/// Arguments for the score command
19#[derive(Args)]
20pub struct ScoreArgs {
21    /// Query file (the file you want to score)
22    /// Supports: BAM, SAM, CRAM, FASTA, FAI, VCF, .dict, TSV, CSV
23    #[arg(required = true)]
24    pub query: PathBuf,
25
26    /// Reference file (the file to compare against)
27    /// Supports: BAM, SAM, CRAM, FASTA, FAI, VCF, .dict, TSV, CSV
28    #[arg(required = true)]
29    pub reference: PathBuf,
30
31    /// Also compute the reverse comparison (reference as query, query as reference).
32    /// By default, scoring is asymmetric: it measures how well the query matches
33    /// the reference. With --symmetric, both directions are computed.
34    #[arg(long)]
35    pub symmetric: bool,
36
37    // === Scoring weight options ===
38    /// Weight for contig match score (0-100, default 70)
39    /// How well query contigs match reference contigs
40    #[arg(long, default_value = "70", value_parser = clap::value_parser!(u32).range(0..=100))]
41    pub weight_match: u32,
42
43    /// Weight for coverage score (0-100, default 20)
44    /// What fraction of reference contigs are covered by query
45    #[arg(long, default_value = "20", value_parser = clap::value_parser!(u32).range(0..=100))]
46    pub weight_coverage: u32,
47
48    /// Weight for order score (0-100, default 10)
49    /// Whether contigs appear in the same order
50    #[arg(long, default_value = "10", value_parser = clap::value_parser!(u32).range(0..=100))]
51    pub weight_order: u32,
52}
53
54/// Result of scoring in one direction
55struct ScoreResult {
56    query_path: PathBuf,
57    reference_path: PathBuf,
58    query_header: QueryHeader,
59    reference_header: QueryHeader,
60    score: MatchScore,
61}
62
63/// Execute the score command
64///
65/// # Errors
66///
67/// Returns an error if inputs cannot be parsed or comparison fails.
68#[allow(clippy::needless_pass_by_value)]
69pub fn run(args: ScoreArgs, format: OutputFormat, verbose: bool) -> anyhow::Result<()> {
70    // Build scoring weights from command line args
71    let weights = ScoringWeights {
72        contig_match: f64::from(args.weight_match) / 100.0,
73        coverage: f64::from(args.weight_coverage) / 100.0,
74        order: f64::from(args.weight_order) / 100.0,
75        conflict_penalty: 0.1,
76    };
77
78    if verbose {
79        eprintln!(
80            "Scoring weights: {:.0}% match, {:.0}% coverage, {:.0}% order",
81            weights.contig_match * 100.0,
82            weights.coverage * 100.0,
83            weights.order * 100.0,
84        );
85    }
86
87    // Parse query file
88    let query_header = parse_input(&args.query)?;
89    if verbose {
90        eprintln!(
91            "Query: {} contigs ({:.0}% have MD5)",
92            query_header.contigs.len(),
93            query_header.md5_coverage() * 100.0,
94        );
95    }
96
97    // Parse reference file
98    let reference_header = parse_input(&args.reference)?;
99    if verbose {
100        eprintln!(
101            "Reference: {} contigs ({:.0}% have MD5)",
102            reference_header.contigs.len(),
103            reference_header.md5_coverage() * 100.0,
104        );
105    }
106
107    // Compute forward score (query vs reference)
108    let forward_result = compute_score(
109        args.query.clone(),
110        args.reference.clone(),
111        &query_header,
112        &reference_header,
113        &weights,
114    );
115
116    // Compute reverse score if symmetric
117    let reverse_result = if args.symmetric {
118        Some(compute_score(
119            args.reference.clone(),
120            args.query.clone(),
121            &reference_header,
122            &query_header,
123            &weights,
124        ))
125    } else {
126        None
127    };
128
129    // Output results
130    match format {
131        OutputFormat::Text => {
132            print_text_result(&forward_result, &weights, "");
133            if let Some(ref reverse) = reverse_result {
134                println!("\n{}", "─".repeat(60));
135                print_text_result(reverse, &weights, " (reverse)");
136            }
137        }
138        OutputFormat::Json => {
139            print_json_results(&forward_result, reverse_result.as_ref(), &weights)?;
140        }
141        OutputFormat::Tsv => {
142            print_tsv_results(&forward_result, reverse_result.as_ref(), &weights);
143        }
144    }
145
146    Ok(())
147}
148
149fn compute_score(
150    query_path: PathBuf,
151    reference_path: PathBuf,
152    query_header: &QueryHeader,
153    reference_header: &QueryHeader,
154    weights: &ScoringWeights,
155) -> ScoreResult {
156    // Convert reference header to KnownReference for scoring
157    let reference = KnownReference::new(
158        "reference",
159        reference_path.display().to_string().as_str(),
160        Assembly::Other("unknown".to_string()),
161        ReferenceSource::Custom("file".to_string()),
162    )
163    .with_contigs(reference_header.contigs.clone());
164
165    let score = MatchScore::calculate_with_weights(query_header, &reference, weights);
166
167    ScoreResult {
168        query_path,
169        reference_path,
170        query_header: query_header.clone(),
171        reference_header: reference_header.clone(),
172        score,
173    }
174}
175
176fn parse_input(path: &Path) -> anyhow::Result<QueryHeader> {
177    let ext = path
178        .extension()
179        .and_then(|e| e.to_str())
180        .map(str::to_lowercase);
181
182    match ext.as_deref() {
183        Some("dict") => Ok(parsing::dict::parse_dict_file(path)?),
184        Some("fai") => Ok(parsing::fai::parse_fai_file(path)?),
185        Some("fa" | "fasta" | "fna") => Ok(parsing::fasta::parse_fasta_file(path)?),
186        Some("vcf" | "vcf.gz") => Ok(parsing::vcf::parse_vcf_file(path)?),
187        Some("tsv") => Ok(parsing::tsv::parse_tsv_file(path, '\t')?),
188        Some("csv") => Ok(parsing::tsv::parse_tsv_file(path, ',')?),
189        // Default to SAM/BAM/CRAM parsing
190        _ => Ok(parsing::sam::parse_file(path)?),
191    }
192}
193
194fn print_text_result(result: &ScoreResult, weights: &ScoringWeights, suffix: &str) {
195    let norm = weights.normalized();
196
197    println!(
198        "\nScoring{}: {} vs {}",
199        suffix,
200        result.query_path.display(),
201        result.reference_path.display()
202    );
203
204    // Score breakdown
205    println!(
206        "\n   Score: {:.1}% = {:.0}%×match + {:.0}%×coverage + {:.0}%×order",
207        result.score.composite * 100.0,
208        result.score.match_quality * 100.0,
209        result.score.coverage_score * 100.0,
210        result.score.order_score * 100.0,
211    );
212    println!(
213        "          (weights: {:.0}% match, {:.0}% coverage, {:.0}% order)",
214        norm.contig_match * 100.0,
215        norm.coverage * 100.0,
216        norm.order * 100.0,
217    );
218
219    // Query contigs
220    let total_query = result.query_header.contigs.len();
221    let exact = result.score.exact_matches;
222    let name_len = result.score.name_length_matches;
223    let conflicts = result.score.md5_conflicts;
224    let unmatched = result.score.unmatched;
225
226    println!(
227        "\n   Query contigs: {total_query} total → {exact} exact, {name_len} name+length, {conflicts} conflicts, {unmatched} unmatched"
228    );
229
230    // Reference coverage
231    let total_ref = result.reference_header.contigs.len();
232    let matched_ref = exact + name_len;
233    let uncovered_ref = total_ref.saturating_sub(matched_ref);
234    println!(
235        "   Reference contigs: {total_ref} total, {matched_ref} matched, {uncovered_ref} not in query"
236    );
237
238    // Order
239    if !result.score.order_preserved {
240        println!("   Order: DIFFERENT from reference");
241    }
242
243    // Confidence
244    println!("   Confidence: {:?}", result.score.confidence);
245}
246
247fn print_json_results(
248    forward: &ScoreResult,
249    reverse: Option<&ScoreResult>,
250    weights: &ScoringWeights,
251) -> anyhow::Result<()> {
252    let norm = weights.normalized();
253
254    let make_result_json = |result: &ScoreResult| {
255        let ref_total = result.reference_header.contigs.len();
256        let ref_matched = result.score.exact_matches + result.score.name_length_matches;
257        let ref_uncovered = ref_total.saturating_sub(ref_matched);
258
259        serde_json::json!({
260            "query": {
261                "file": result.query_path.display().to_string(),
262                "contigs": result.query_header.contigs.len(),
263            },
264            "reference": {
265                "file": result.reference_path.display().to_string(),
266                "contigs": result.reference_header.contigs.len(),
267            },
268            "score": {
269                "composite": result.score.composite,
270                "confidence": format!("{:?}", result.score.confidence),
271                "match_quality": result.score.match_quality,
272                "coverage_score": result.score.coverage_score,
273                "order_score": result.score.order_score,
274                "weights": {
275                    "match": norm.contig_match,
276                    "coverage": norm.coverage,
277                    "order": norm.order,
278                },
279            },
280            "query_contigs": {
281                "exact_matches": result.score.exact_matches,
282                "name_length_matches": result.score.name_length_matches,
283                "md5_conflicts": result.score.md5_conflicts,
284                "unmatched": result.score.unmatched,
285            },
286            "reference_coverage": {
287                "total": ref_total,
288                "matched": ref_matched,
289                "not_in_query": ref_uncovered,
290            },
291            "order_preserved": result.score.order_preserved,
292        })
293    };
294
295    let output = if let Some(rev) = reverse {
296        serde_json::json!({
297            "forward": make_result_json(forward),
298            "reverse": make_result_json(rev),
299        })
300    } else {
301        make_result_json(forward)
302    };
303
304    println!("{}", serde_json::to_string_pretty(&output)?);
305    Ok(())
306}
307
308fn print_tsv_results(
309    forward: &ScoreResult,
310    reverse: Option<&ScoreResult>,
311    weights: &ScoringWeights,
312) {
313    let norm = weights.normalized();
314
315    // Header
316    println!(
317        "direction\tquery\treference\tscore\tmatch_score\tcoverage_score\torder_score\tweight_match\tweight_coverage\tweight_order\tconfidence\texact\tname_length\tconflicts\tunmatched\tref_total\tref_matched\tref_uncovered"
318    );
319
320    let print_row = |direction: &str, result: &ScoreResult| {
321        let ref_total = result.reference_header.contigs.len();
322        let ref_matched = result.score.exact_matches + result.score.name_length_matches;
323        let ref_uncovered = ref_total.saturating_sub(ref_matched);
324
325        println!(
326            "{}\t{}\t{}\t{:.4}\t{:.4}\t{:.4}\t{:.4}\t{:.2}\t{:.2}\t{:.2}\t{:?}\t{}\t{}\t{}\t{}\t{}\t{}\t{}",
327            direction,
328            result.query_path.display(),
329            result.reference_path.display(),
330            result.score.composite,
331            result.score.match_quality,
332            result.score.coverage_score,
333            result.score.order_score,
334            norm.contig_match,
335            norm.coverage,
336            norm.order,
337            result.score.confidence,
338            result.score.exact_matches,
339            result.score.name_length_matches,
340            result.score.md5_conflicts,
341            result.score.unmatched,
342            ref_total,
343            ref_matched,
344            ref_uncovered,
345        );
346    };
347
348    print_row("forward", forward);
349    if let Some(rev) = reverse {
350        print_row("reverse", rev);
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::contig::Contig;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Create a temporary `.dict` file: one `@HD` line, then one `@SQ` line per
    /// `(name, length, md5)` entry. The `M5:` tag is written only when an MD5
    /// is supplied.
    fn create_temp_dict_file(contigs: &[(&str, u64, Option<&str>)]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(".dict").unwrap();
        writeln!(file, "@HD\tVN:1.0\tSO:unsorted").unwrap();
        for &(name, len, md5) in contigs {
            match md5 {
                Some(m) => writeln!(file, "@SQ\tSN:{name}\tLN:{len}\tM5:{m}").unwrap(),
                None => writeln!(file, "@SQ\tSN:{name}\tLN:{len}").unwrap(),
            }
        }
        file.flush().unwrap();
        file
    }

    /// Score `query` against `reference`, labelling them with the given paths.
    fn score_between(
        query_path: &str,
        reference_path: &str,
        query: &QueryHeader,
        reference: &QueryHeader,
        weights: &ScoringWeights,
    ) -> ScoreResult {
        compute_score(
            PathBuf::from(query_path),
            PathBuf::from(reference_path),
            query,
            reference,
            weights,
        )
    }

    #[test]
    fn test_parse_dict_input() {
        // MD5 must be exactly 32 hex characters
        let valid_md5 = "6aef897c3d6ff0c78aff06ac189178dd";
        let file = create_temp_dict_file(&[("chr1", 1000, Some(valid_md5)), ("chr2", 2000, None)]);

        let header = parse_input(file.path()).unwrap();
        assert_eq!(header.contigs.len(), 2);

        let (first, second) = (&header.contigs[0], &header.contigs[1]);
        assert_eq!(first.name, "chr1");
        assert_eq!(first.length, 1000);
        assert_eq!(first.md5.as_deref(), Some(valid_md5));
        assert_eq!(second.name, "chr2");
        assert_eq!(second.length, 2000);
        assert!(second.md5.is_none());
    }

    #[test]
    fn test_compute_score_perfect_match() {
        let query_header =
            QueryHeader::new(vec![Contig::new("chr1", 1000), Contig::new("chr2", 2000)]);
        let reference_header =
            QueryHeader::new(vec![Contig::new("chr1", 1000), Contig::new("chr2", 2000)]);

        let result = score_between(
            "query.dict",
            "reference.dict",
            &query_header,
            &reference_header,
            &ScoringWeights::default(),
        );
        let score = &result.score;

        // Identical headers without MD5s: every contig pairs up by name + length.
        assert_eq!(score.name_length_matches, 2);
        assert_eq!(score.unmatched, 0);
        assert!(score.composite > 0.9, "Perfect match should score > 90%");
    }

    #[test]
    fn test_compute_score_partial_match() {
        let query_header = QueryHeader::new(vec![
            Contig::new("chr1", 1000),
            Contig::new("chr2", 2000),
            Contig::new("chr3", 3000),
        ]);
        let reference_header =
            QueryHeader::new(vec![Contig::new("chr1", 1000), Contig::new("chr2", 2000)]);

        let result = score_between(
            "query.dict",
            "reference.dict",
            &query_header,
            &reference_header,
            &ScoringWeights::default(),
        );
        let score = &result.score;

        // chr3 exists only in the query, so 2 of 3 query contigs match.
        assert_eq!(score.name_length_matches, 2);
        assert_eq!(score.unmatched, 1);
        assert!(
            score.match_quality < 1.0,
            "Partial match should have match_quality < 1.0"
        );
    }

    #[test]
    fn test_compute_score_asymmetric() {
        // The query is a strict subset of the reference.
        let query_header = QueryHeader::new(vec![Contig::new("chr1", 1000)]);
        let reference_header = QueryHeader::new(vec![
            Contig::new("chr1", 1000),
            Contig::new("chr2", 2000),
            Contig::new("chr3", 3000),
        ]);
        let weights = ScoringWeights::default();

        let forward = score_between(
            "query.dict",
            "reference.dict",
            &query_header,
            &reference_header,
            &weights,
        );
        let reverse = score_between(
            "reference.dict",
            "query.dict",
            &reference_header,
            &query_header,
            &weights,
        );

        // Forward: the single query contig matches, but only 1/3 of the reference is covered.
        assert_eq!(forward.score.name_length_matches, 1);
        assert!(
            (forward.score.match_quality - 1.0).abs() < 0.001,
            "All query contigs match"
        );
        assert!(
            forward.score.coverage_score < 0.5,
            "Coverage should be 1/3 = 0.33"
        );

        // Reverse: only 1/3 of the (now larger) query matches, yet the reference is fully covered.
        assert_eq!(reverse.score.unmatched, 2);
        assert!(
            reverse.score.match_quality < 0.5,
            "Only 1/3 query contigs match"
        );
        assert!(
            (reverse.score.coverage_score - 1.0).abs() < 0.001,
            "Reference fully covered"
        );
    }

    #[test]
    fn test_custom_weights() {
        let query_header =
            QueryHeader::new(vec![Contig::new("chr1", 1000), Contig::new("chr2", 2000)]);
        let reference_header = QueryHeader::new(vec![
            Contig::new("chr1", 1000),
            Contig::new("chr2", 2000),
            Contig::new("chr3", 3000),
        ]);

        let composite_for = |weights: &ScoringWeights| {
            score_between("q.dict", "r.dict", &query_header, &reference_header, weights)
                .score
                .composite
        };

        // Emphasize coverage (2/3 here) over match quality (100% here).
        let high_cov = composite_for(&ScoringWeights {
            contig_match: 0.2,
            coverage: 0.7,
            order: 0.1,
            conflict_penalty: 0.1,
        });

        // Emphasize match quality over coverage.
        let high_match = composite_for(&ScoringWeights {
            contig_match: 0.8,
            coverage: 0.1,
            order: 0.1,
            conflict_penalty: 0.1,
        });

        assert!(
            high_match > high_cov,
            "High match weight should yield higher score when matches are 100% but coverage is 66%"
        );
    }

    #[test]
    fn test_score_with_md5_match() {
        let query_header = QueryHeader::new(vec![
            Contig::new("chr1", 1000).with_md5("abc123"),
            Contig::new("chr2", 2000).with_md5("def456"),
        ]);
        let reference_header = QueryHeader::new(vec![
            Contig::new("chr1", 1000).with_md5("abc123"),
            Contig::new("chr2", 2000).with_md5("def456"),
        ]);

        let result = score_between(
            "query.dict",
            "reference.dict",
            &query_header,
            &reference_header,
            &ScoringWeights::default(),
        );
        let score = &result.score;

        // Matching MD5s upgrade name+length matches to exact matches.
        assert_eq!(score.exact_matches, 2);
        assert_eq!(score.name_length_matches, 0);
        assert!(
            score.composite > 0.95,
            "Exact MD5 match should score very high"
        );
    }

    #[test]
    fn test_score_with_md5_conflict() {
        let query_header = QueryHeader::new(vec![
            Contig::new("chr1", 1000).with_md5("abc123"),
            Contig::new("chr2", 2000).with_md5("def456"),
        ]);
        let reference_header = QueryHeader::new(vec![
            Contig::new("chr1", 1000).with_md5("DIFFERENT1"),
            Contig::new("chr2", 2000).with_md5("DIFFERENT2"),
        ]);

        let result = score_between(
            "query.dict",
            "reference.dict",
            &query_header,
            &reference_header,
            &ScoringWeights::default(),
        );
        let score = &result.score;

        // Same name and length but different MD5s must be penalized as conflicts.
        assert_eq!(score.md5_conflicts, 2);
        assert_eq!(score.exact_matches, 0);
        assert!(
            score.composite < 0.3,
            "MD5 conflicts should result in low score, got {:.1}%",
            score.composite * 100.0
        );
    }
}