1use std::path::{Path, PathBuf};
7
8use clap::Args;
9
10use crate::cli::OutputFormat;
11use crate::core::header::QueryHeader;
12use crate::core::reference::KnownReference;
13use crate::core::types::{Assembly, ReferenceSource};
14use crate::matching::engine::ScoringWeights;
15use crate::matching::scoring::MatchScore;
16use crate::parsing;
17
18#[derive(Args)]
20pub struct ScoreArgs {
21 #[arg(required = true)]
24 pub query: PathBuf,
25
26 #[arg(required = true)]
29 pub reference: PathBuf,
30
31 #[arg(long)]
35 pub symmetric: bool,
36
37 #[arg(long, default_value = "70", value_parser = clap::value_parser!(u32).range(0..=100))]
41 pub weight_match: u32,
42
43 #[arg(long, default_value = "20", value_parser = clap::value_parser!(u32).range(0..=100))]
46 pub weight_coverage: u32,
47
48 #[arg(long, default_value = "10", value_parser = clap::value_parser!(u32).range(0..=100))]
51 pub weight_order: u32,
52}
53
54struct ScoreResult {
56 query_path: PathBuf,
57 reference_path: PathBuf,
58 query_header: QueryHeader,
59 reference_header: QueryHeader,
60 score: MatchScore,
61}
62
63#[allow(clippy::needless_pass_by_value)]
69pub fn run(args: ScoreArgs, format: OutputFormat, verbose: bool) -> anyhow::Result<()> {
70 let weights = ScoringWeights {
72 contig_match: f64::from(args.weight_match) / 100.0,
73 coverage: f64::from(args.weight_coverage) / 100.0,
74 order: f64::from(args.weight_order) / 100.0,
75 conflict_penalty: 0.1,
76 };
77
78 if verbose {
79 eprintln!(
80 "Scoring weights: {:.0}% match, {:.0}% coverage, {:.0}% order",
81 weights.contig_match * 100.0,
82 weights.coverage * 100.0,
83 weights.order * 100.0,
84 );
85 }
86
87 let query_header = parse_input(&args.query)?;
89 if verbose {
90 eprintln!(
91 "Query: {} contigs ({:.0}% have MD5)",
92 query_header.contigs.len(),
93 query_header.md5_coverage() * 100.0,
94 );
95 }
96
97 let reference_header = parse_input(&args.reference)?;
99 if verbose {
100 eprintln!(
101 "Reference: {} contigs ({:.0}% have MD5)",
102 reference_header.contigs.len(),
103 reference_header.md5_coverage() * 100.0,
104 );
105 }
106
107 let forward_result = compute_score(
109 args.query.clone(),
110 args.reference.clone(),
111 &query_header,
112 &reference_header,
113 &weights,
114 );
115
116 let reverse_result = if args.symmetric {
118 Some(compute_score(
119 args.reference.clone(),
120 args.query.clone(),
121 &reference_header,
122 &query_header,
123 &weights,
124 ))
125 } else {
126 None
127 };
128
129 match format {
131 OutputFormat::Text => {
132 print_text_result(&forward_result, &weights, "");
133 if let Some(ref reverse) = reverse_result {
134 println!("\n{}", "─".repeat(60));
135 print_text_result(reverse, &weights, " (reverse)");
136 }
137 }
138 OutputFormat::Json => {
139 print_json_results(&forward_result, reverse_result.as_ref(), &weights)?;
140 }
141 OutputFormat::Tsv => {
142 print_tsv_results(&forward_result, reverse_result.as_ref(), &weights);
143 }
144 }
145
146 Ok(())
147}
148
149fn compute_score(
150 query_path: PathBuf,
151 reference_path: PathBuf,
152 query_header: &QueryHeader,
153 reference_header: &QueryHeader,
154 weights: &ScoringWeights,
155) -> ScoreResult {
156 let reference = KnownReference::new(
158 "reference",
159 reference_path.display().to_string().as_str(),
160 Assembly::Other("unknown".to_string()),
161 ReferenceSource::Custom("file".to_string()),
162 )
163 .with_contigs(reference_header.contigs.clone());
164
165 let score = MatchScore::calculate_with_weights(query_header, &reference, weights);
166
167 ScoreResult {
168 query_path,
169 reference_path,
170 query_header: query_header.clone(),
171 reference_header: reference_header.clone(),
172 score,
173 }
174}
175
176fn parse_input(path: &Path) -> anyhow::Result<QueryHeader> {
177 let ext = path
178 .extension()
179 .and_then(|e| e.to_str())
180 .map(str::to_lowercase);
181
182 match ext.as_deref() {
183 Some("dict") => Ok(parsing::dict::parse_dict_file(path)?),
184 Some("fai") => Ok(parsing::fai::parse_fai_file(path)?),
185 Some("fa" | "fasta" | "fna") => Ok(parsing::fasta::parse_fasta_file(path)?),
186 Some("vcf" | "vcf.gz") => Ok(parsing::vcf::parse_vcf_file(path)?),
187 Some("tsv") => Ok(parsing::tsv::parse_tsv_file(path, '\t')?),
188 Some("csv") => Ok(parsing::tsv::parse_tsv_file(path, ',')?),
189 _ => Ok(parsing::sam::parse_file(path)?),
191 }
192}
193
194fn print_text_result(result: &ScoreResult, weights: &ScoringWeights, suffix: &str) {
195 let norm = weights.normalized();
196
197 println!(
198 "\nScoring{}: {} vs {}",
199 suffix,
200 result.query_path.display(),
201 result.reference_path.display()
202 );
203
204 println!(
206 "\n Score: {:.1}% = {:.0}%×match + {:.0}%×coverage + {:.0}%×order",
207 result.score.composite * 100.0,
208 result.score.match_quality * 100.0,
209 result.score.coverage_score * 100.0,
210 result.score.order_score * 100.0,
211 );
212 println!(
213 " (weights: {:.0}% match, {:.0}% coverage, {:.0}% order)",
214 norm.contig_match * 100.0,
215 norm.coverage * 100.0,
216 norm.order * 100.0,
217 );
218
219 let total_query = result.query_header.contigs.len();
221 let exact = result.score.exact_matches;
222 let name_len = result.score.name_length_matches;
223 let conflicts = result.score.md5_conflicts;
224 let unmatched = result.score.unmatched;
225
226 println!(
227 "\n Query contigs: {total_query} total → {exact} exact, {name_len} name+length, {conflicts} conflicts, {unmatched} unmatched"
228 );
229
230 let total_ref = result.reference_header.contigs.len();
232 let matched_ref = exact + name_len;
233 let uncovered_ref = total_ref.saturating_sub(matched_ref);
234 println!(
235 " Reference contigs: {total_ref} total, {matched_ref} matched, {uncovered_ref} not in query"
236 );
237
238 if !result.score.order_preserved {
240 println!(" Order: DIFFERENT from reference");
241 }
242
243 println!(" Confidence: {:?}", result.score.confidence);
245}
246
247fn print_json_results(
248 forward: &ScoreResult,
249 reverse: Option<&ScoreResult>,
250 weights: &ScoringWeights,
251) -> anyhow::Result<()> {
252 let norm = weights.normalized();
253
254 let make_result_json = |result: &ScoreResult| {
255 let ref_total = result.reference_header.contigs.len();
256 let ref_matched = result.score.exact_matches + result.score.name_length_matches;
257 let ref_uncovered = ref_total.saturating_sub(ref_matched);
258
259 serde_json::json!({
260 "query": {
261 "file": result.query_path.display().to_string(),
262 "contigs": result.query_header.contigs.len(),
263 },
264 "reference": {
265 "file": result.reference_path.display().to_string(),
266 "contigs": result.reference_header.contigs.len(),
267 },
268 "score": {
269 "composite": result.score.composite,
270 "confidence": format!("{:?}", result.score.confidence),
271 "match_quality": result.score.match_quality,
272 "coverage_score": result.score.coverage_score,
273 "order_score": result.score.order_score,
274 "weights": {
275 "match": norm.contig_match,
276 "coverage": norm.coverage,
277 "order": norm.order,
278 },
279 },
280 "query_contigs": {
281 "exact_matches": result.score.exact_matches,
282 "name_length_matches": result.score.name_length_matches,
283 "md5_conflicts": result.score.md5_conflicts,
284 "unmatched": result.score.unmatched,
285 },
286 "reference_coverage": {
287 "total": ref_total,
288 "matched": ref_matched,
289 "not_in_query": ref_uncovered,
290 },
291 "order_preserved": result.score.order_preserved,
292 })
293 };
294
295 let output = if let Some(rev) = reverse {
296 serde_json::json!({
297 "forward": make_result_json(forward),
298 "reverse": make_result_json(rev),
299 })
300 } else {
301 make_result_json(forward)
302 };
303
304 println!("{}", serde_json::to_string_pretty(&output)?);
305 Ok(())
306}
307
308fn print_tsv_results(
309 forward: &ScoreResult,
310 reverse: Option<&ScoreResult>,
311 weights: &ScoringWeights,
312) {
313 let norm = weights.normalized();
314
315 println!(
317 "direction\tquery\treference\tscore\tmatch_score\tcoverage_score\torder_score\tweight_match\tweight_coverage\tweight_order\tconfidence\texact\tname_length\tconflicts\tunmatched\tref_total\tref_matched\tref_uncovered"
318 );
319
320 let print_row = |direction: &str, result: &ScoreResult| {
321 let ref_total = result.reference_header.contigs.len();
322 let ref_matched = result.score.exact_matches + result.score.name_length_matches;
323 let ref_uncovered = ref_total.saturating_sub(ref_matched);
324
325 println!(
326 "{}\t{}\t{}\t{:.4}\t{:.4}\t{:.4}\t{:.4}\t{:.2}\t{:.2}\t{:.2}\t{:?}\t{}\t{}\t{}\t{}\t{}\t{}\t{}",
327 direction,
328 result.query_path.display(),
329 result.reference_path.display(),
330 result.score.composite,
331 result.score.match_quality,
332 result.score.coverage_score,
333 result.score.order_score,
334 norm.contig_match,
335 norm.coverage,
336 norm.order,
337 result.score.confidence,
338 result.score.exact_matches,
339 result.score.name_length_matches,
340 result.score.md5_conflicts,
341 result.score.unmatched,
342 ref_total,
343 ref_matched,
344 ref_uncovered,
345 );
346 };
347
348 print_row("forward", forward);
349 if let Some(rev) = reverse {
350 print_row("reverse", rev);
351 }
352}
353
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::contig::Contig;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes a temporary `.dict` file with an `@HD` line followed by one
    /// `@SQ` line per `(name, length, optional MD5)` entry.
    fn create_temp_dict_file(contigs: &[(&str, u64, Option<&str>)]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(".dict").unwrap();
        writeln!(file, "@HD\tVN:1.0\tSO:unsorted").unwrap();
        for &(name, len, md5) in contigs {
            let md5_field = match md5 {
                Some(m) => format!("\tM5:{m}"),
                None => String::new(),
            };
            writeln!(file, "@SQ\tSN:{name}\tLN:{len}{md5_field}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    /// Header with `chr1` (1000) and `chr2` (2000), no MD5s.
    fn two_contig_header() -> QueryHeader {
        QueryHeader::new(vec![Contig::new("chr1", 1000), Contig::new("chr2", 2000)])
    }

    /// Header with `chr1` (1000), `chr2` (2000) and `chr3` (3000), no MD5s.
    fn three_contig_header() -> QueryHeader {
        QueryHeader::new(vec![
            Contig::new("chr1", 1000),
            Contig::new("chr2", 2000),
            Contig::new("chr3", 3000),
        ])
    }

    /// Header with `chr1`/`chr2` carrying the given MD5 strings.
    fn md5_header(chr1_md5: &str, chr2_md5: &str) -> QueryHeader {
        QueryHeader::new(vec![
            Contig::new("chr1", 1000).with_md5(chr1_md5),
            Contig::new("chr2", 2000).with_md5(chr2_md5),
        ])
    }

    /// Runs `compute_score` with the given file names and returns only the score.
    fn score_files(
        query_file: &str,
        reference_file: &str,
        query: &QueryHeader,
        reference: &QueryHeader,
        weights: &ScoringWeights,
    ) -> MatchScore {
        compute_score(
            PathBuf::from(query_file),
            PathBuf::from(reference_file),
            query,
            reference,
            weights,
        )
        .score
    }

    #[test]
    fn test_parse_dict_input() {
        let valid_md5 = "6aef897c3d6ff0c78aff06ac189178dd";
        let file = create_temp_dict_file(&[("chr1", 1000, Some(valid_md5)), ("chr2", 2000, None)]);

        let header = parse_input(file.path()).unwrap();
        assert_eq!(header.contigs.len(), 2);

        let first = &header.contigs[0];
        assert_eq!(first.name, "chr1");
        assert_eq!(first.length, 1000);
        assert_eq!(first.md5.as_deref(), Some(valid_md5));

        let second = &header.contigs[1];
        assert_eq!(second.name, "chr2");
        assert_eq!(second.length, 2000);
        assert!(second.md5.is_none());
    }

    #[test]
    fn test_compute_score_perfect_match() {
        let score = score_files(
            "query.dict",
            "reference.dict",
            &two_contig_header(),
            &two_contig_header(),
            &ScoringWeights::default(),
        );

        // No MD5s on either side, so both contigs match by name and length.
        assert_eq!(score.name_length_matches, 2);
        assert_eq!(score.unmatched, 0);
        assert!(score.composite > 0.9, "Perfect match should score > 90%");
    }

    #[test]
    fn test_compute_score_partial_match() {
        let score = score_files(
            "query.dict",
            "reference.dict",
            &three_contig_header(),
            &two_contig_header(),
            &ScoringWeights::default(),
        );

        // chr3 exists only in the query.
        assert_eq!(score.name_length_matches, 2);
        assert_eq!(score.unmatched, 1);
        assert!(
            score.match_quality < 1.0,
            "Partial match should have match_quality < 1.0"
        );
    }

    #[test]
    fn test_compute_score_asymmetric() {
        let small = QueryHeader::new(vec![Contig::new("chr1", 1000)]);
        let large = three_contig_header();
        let weights = ScoringWeights::default();

        // Forward: the one-contig header is the query.
        let forward = score_files("query.dict", "reference.dict", &small, &large, &weights);
        // Reverse: roles (and file names) swapped.
        let reverse = score_files("reference.dict", "query.dict", &large, &small, &weights);

        assert_eq!(forward.name_length_matches, 1);
        assert!(
            (forward.match_quality - 1.0).abs() < 0.001,
            "All query contigs match"
        );
        assert!(forward.coverage_score < 0.5, "Coverage should be 1/3 = 0.33");

        assert_eq!(reverse.unmatched, 2);
        assert!(reverse.match_quality < 0.5, "Only 1/3 query contigs match");
        assert!(
            (reverse.coverage_score - 1.0).abs() < 0.001,
            "Reference fully covered"
        );
    }

    #[test]
    fn test_custom_weights() {
        let query_header = two_contig_header();
        let reference_header = three_contig_header();

        // Same penalty and order weight; only the match/coverage balance differs.
        let high_coverage_weights = ScoringWeights {
            contig_match: 0.2,
            coverage: 0.7,
            order: 0.1,
            conflict_penalty: 0.1,
        };
        let high_match_weights = ScoringWeights {
            contig_match: 0.8,
            coverage: 0.1,
            order: 0.1,
            conflict_penalty: 0.1,
        };

        let high_cov = score_files(
            "q.dict",
            "r.dict",
            &query_header,
            &reference_header,
            &high_coverage_weights,
        );
        let high_match = score_files(
            "q.dict",
            "r.dict",
            &query_header,
            &reference_header,
            &high_match_weights,
        );

        assert!(
            high_match.composite > high_cov.composite,
            "High match weight should yield higher score when matches are 100% but coverage is 66%"
        );
    }

    #[test]
    fn test_score_with_md5_match() {
        let score = score_files(
            "query.dict",
            "reference.dict",
            &md5_header("abc123", "def456"),
            &md5_header("abc123", "def456"),
            &ScoringWeights::default(),
        );

        // Identical MD5s count as exact matches, not name+length matches.
        assert_eq!(score.exact_matches, 2);
        assert_eq!(score.name_length_matches, 0);
        assert!(
            score.composite > 0.95,
            "Exact MD5 match should score very high"
        );
    }

    #[test]
    fn test_score_with_md5_conflict() {
        let score = score_files(
            "query.dict",
            "reference.dict",
            &md5_header("abc123", "def456"),
            &md5_header("DIFFERENT1", "DIFFERENT2"),
            &ScoringWeights::default(),
        );

        // Same names and lengths, different MD5s: conflicts, not matches.
        assert_eq!(score.md5_conflicts, 2);
        assert_eq!(score.exact_matches, 0);
        assert!(
            score.composite < 0.3,
            "MD5 conflicts should result in low score, got {:.1}%",
            score.composite * 100.0
        );
    }
}