1use std::fmt;
18
19use crate::parse::{CmpOp, Expr, FieldRef, LogOp, Value};
20
/// Errors produced while evaluating a filter expression against a VCF line.
#[derive(Debug, Clone, PartialEq)]
pub enum EvalError {
    /// The line could not be split into the required fixed columns; the
    /// payload is a human-readable description of what was wrong.
    MalformedLine(String),
    /// A referenced field could not be located; the payload names the field.
    FieldNotFound(String),
    /// A value could not be interpreted as the type the comparison needs.
    TypeMismatch {
        /// Name of the offending field, e.g. `QUAL` or `FMT/DP`.
        field: String,
        /// Short name of the type that was required (`"numeric"`, `"string"`).
        expected: &'static str,
        /// The raw text that failed to convert.
        got: String,
    },
    /// An evaluator-internal failure; the payload is a free-form description.
    Internal(String),
}
39
40impl fmt::Display for EvalError {
41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42 match self {
43 EvalError::MalformedLine(s) => write!(f, "malformed VCF line: {s}"),
44 EvalError::FieldNotFound(s) => write!(f, "field not found: {s}"),
45 EvalError::TypeMismatch {
46 field,
47 expected,
48 got,
49 } => {
50 write!(
51 f,
52 "type mismatch for {field}: expected {expected}, got '{got}'"
53 )
54 }
55 EvalError::Internal(s) => write!(f, "internal evaluator error: {s}"),
56 }
57 }
58}
59
60impl std::error::Error for EvalError {}
61
/// Returns the `n`-th (0-based) colon-separated field of `s`.
///
/// Yields `None` when `s` contains fewer than `n` colons. An empty field
/// (for example between two adjacent colons, or after a trailing colon) is
/// returned as `Some("")`. No allocation takes place; the result borrows
/// from `s`.
#[inline]
fn nth_colon_field(s: &str, n: usize) -> Option<&str> {
    s.split(':').nth(n)
}
88
/// Outcome of evaluating one expression against one VCF line.
#[derive(Debug, Clone)]
pub struct SampleResult {
    /// One flag per evaluated sample, in sample-column order: `true` when the
    /// expression held for that sample.
    pub pass: Vec<bool>,
}
95
/// Borrowed, zero-copy view of the columns of one VCF data line that the
/// evaluator needs. Every `&str` points into the original line.
struct VcfLine<'a> {
    /// Raw text of the QUAL column (column 6).
    qual: &'a str,
    /// Raw text of the FILTER column (column 7).
    filter: &'a str,
    /// Raw text of the INFO column (column 8).
    info: &'a str,
    /// Keys of the FORMAT column split on `:`; empty when there is no FORMAT.
    fmt_keys: Vec<&'a str>,
    /// Byte offset into `line` at which each sample column starts.
    sample_starts: Vec<usize>,
    /// Byte offset into `line` one past the end of each sample column;
    /// parallel to `sample_starts`.
    sample_ends: Vec<usize>,
    /// The line that the offsets above index into.
    line: &'a str,
}
109
110impl<'a> VcfLine<'a> {
111 fn parse(line: &'a str) -> Result<Self, EvalError> {
112 let bytes = line.as_bytes();
114 let mut tab_pos = [0usize; 9];
115 let mut ntabs = 0usize;
116 for (i, &b) in bytes.iter().enumerate() {
117 if b == b'\t' {
118 if ntabs < 9 {
119 tab_pos[ntabs] = i;
120 }
121 ntabs += 1;
122 }
123 }
124 if ntabs < 7 {
125 return Err(EvalError::MalformedLine(format!(
126 "expected ≥8 tab-separated columns, got {}",
127 ntabs + 1
128 )));
129 }
130
131 let qual = &line[tab_pos[4] + 1..tab_pos[5]];
132 let filter = &line[tab_pos[5] + 1..tab_pos[6]];
133 let info = &line[tab_pos[6] + 1..tab_pos[7]];
134
135 let fmt_keys: Vec<&'a str> = if ntabs >= 8 {
136 let fmt_start = tab_pos[7] + 1;
137 let fmt_end = tab_pos[8];
138 line[fmt_start..fmt_end].split(':').collect()
139 } else {
140 vec![]
141 };
142
143 let (sample_starts, sample_ends) = if ntabs >= 8 && tab_pos[8] < line.len() {
146 let first_sample_start = tab_pos[8] + 1;
147 let mut starts = Vec::with_capacity(ntabs.saturating_sub(8));
148 let mut ends = Vec::with_capacity(ntabs.saturating_sub(8));
149 let mut cur = first_sample_start;
150 for (i, &b) in bytes[first_sample_start..].iter().enumerate() {
151 if b == b'\t' {
152 starts.push(cur);
153 ends.push(first_sample_start + i);
154 cur = first_sample_start + i + 1;
155 }
156 }
157 starts.push(cur);
159 ends.push(line.len());
160 (starts, ends)
161 } else {
162 (vec![], vec![])
163 };
164
165 Ok(VcfLine {
166 qual,
167 filter,
168 info,
169 fmt_keys,
170 sample_starts,
171 sample_ends,
172 line,
173 })
174 }
175
176 fn n_samples(&self) -> usize {
178 self.sample_starts.len()
179 }
180
181 fn sample_col(&self, sample_idx: usize) -> Option<&'a str> {
183 let start = *self.sample_starts.get(sample_idx)?;
184 let end = *self.sample_ends.get(sample_idx)?;
185 Some(&self.line[start..end])
186 }
187
188 fn fmt_value(&self, tag: &str, sample_idx: usize) -> Option<&'a str> {
191 let field_pos = self.fmt_keys.iter().position(|&k| k == tag)?;
192 let sample_col = self.sample_col(sample_idx)?;
193 nth_colon_field(sample_col, field_pos).filter(|&v| v != ".")
194 }
195
196 fn fmt_num(&self, tag: &str, sample_idx: usize) -> Result<Option<f64>, EvalError> {
198 match self.fmt_value(tag, sample_idx) {
199 None => Ok(None),
200 Some(v) => v
201 .parse::<f64>()
202 .map(Some)
203 .map_err(|_| EvalError::TypeMismatch {
204 field: format!("FMT/{tag}"),
205 expected: "numeric",
206 got: v.to_owned(),
207 }),
208 }
209 }
210
211 fn gt_str(&self, sample_idx: usize) -> Option<&'a str> {
213 let gt_pos = self.fmt_keys.iter().position(|&k| k == "GT")?;
214 let sample_col = self.sample_col(sample_idx)?;
215 nth_colon_field(sample_col, gt_pos)
216 }
217
218 fn info_num(&self, tag: &str) -> Result<Option<f64>, EvalError> {
221 for entry in self.info.split(';') {
222 if entry == "." {
223 continue;
224 }
225 if let Some((k, v)) = entry.split_once('=') {
226 if k.eq_ignore_ascii_case(tag) {
227 let first = v.split(',').next().unwrap_or(v);
228 return first
229 .parse::<f64>()
230 .map(Some)
231 .map_err(|_| EvalError::TypeMismatch {
232 field: format!("INFO/{tag}"),
233 expected: "numeric",
234 got: first.to_owned(),
235 });
236 }
237 } else if entry.eq_ignore_ascii_case(tag) {
238 return Ok(Some(1.0));
240 }
241 }
242 Ok(None)
243 }
244
245 fn qual_num(&self) -> Result<Option<f64>, EvalError> {
247 if self.qual == "." {
248 return Ok(None);
249 }
250 self.qual
251 .parse::<f64>()
252 .map(Some)
253 .map_err(|_| EvalError::TypeMismatch {
254 field: "QUAL".into(),
255 expected: "numeric",
256 got: self.qual.to_owned(),
257 })
258 }
259}
260
261fn gt_classify(gt: &str) -> GtClass {
265 let alleles: Vec<&str> = gt.split(['/', '|']).collect();
266 let n = alleles.len();
267 let n_miss = alleles.iter().filter(|&&a| a == ".").count();
268 if n_miss == n {
269 return GtClass::Missing;
270 }
271 if n_miss > 0 {
272 return GtClass::PartialMiss;
273 }
274 let all_ref = alleles.iter().all(|&a| a == "0");
275 let any_ref = alleles.contains(&"0");
276 let any_alt = alleles.iter().any(|&a| a != "0");
277 let all_same = alleles.windows(2).all(|w| w[0] == w[1]);
278
279 if all_ref {
280 GtClass::HomRef
281 } else if !any_ref && all_same {
282 GtClass::HomAlt
283 } else if n == 1 {
284 GtClass::Haploid
285 } else if any_ref && any_alt {
286 GtClass::Het
287 } else {
288 GtClass::Het
290 }
291}
292
/// Coarse genotype category assigned by `gt_classify`.
#[derive(Debug, Clone, Copy, PartialEq)]
enum GtClass {
    /// Every allele is the missing marker `.`.
    Missing,
    /// At least one, but not every, allele is `.`.
    PartialMiss,
    /// Every allele is the reference allele `0`.
    HomRef,
    /// No allele is `0` and all alleles are identical.
    HomAlt,
    /// Alleles differ from one another.
    Het,
    /// Single-allele genotype.
    /// NOTE(review): `gt_classify` as currently ordered reports single-allele
    /// calls as `HomRef`/`HomAlt`, so this variant is not produced; confirm.
    Haploid,
}
302
303fn eval_gt_str(gt: &str, op: &CmpOp, pattern: &str) -> bool {
314 let class = gt_classify(gt);
315 let matches_pattern = match pattern.to_ascii_lowercase().as_str() {
316 "." | "miss" | "missing" => {
317 matches!(class, GtClass::Missing | GtClass::PartialMiss)
318 }
319 "hom" => matches!(class, GtClass::HomRef | GtClass::HomAlt),
320 "het" => class == GtClass::Het,
321 "ref" => class == GtClass::HomRef,
322 "alt" => class == GtClass::HomAlt,
323 "hap" => class == GtClass::Haploid,
324 other => {
325 let norm_gt: String = gt.chars().map(|c| if c == '|' { '/' } else { c }).collect();
327 let norm_pat: String = other
328 .chars()
329 .map(|c| if c == '|' { '/' } else { c })
330 .collect();
331 norm_gt == norm_pat
332 }
333 };
334 match op {
335 CmpOp::Eq => matches_pattern,
336 CmpOp::Ne => !matches_pattern,
337 _ => false, }
339}
340
341fn cmp_num(lhs: f64, op: &CmpOp, rhs: f64) -> bool {
344 match op {
345 CmpOp::Lt => lhs < rhs,
346 CmpOp::Le => lhs <= rhs,
347 CmpOp::Gt => lhs > rhs,
348 CmpOp::Ge => lhs >= rhs,
349 CmpOp::Eq => (lhs - rhs).abs() < f64::EPSILON,
350 CmpOp::Ne => (lhs - rhs).abs() >= f64::EPSILON,
351 }
352}
353
354fn cmp_str(lhs: &str, op: &CmpOp, rhs: &str) -> bool {
355 match op {
356 CmpOp::Eq => lhs == rhs,
357 CmpOp::Ne => lhs != rhs,
358 _ => false,
359 }
360}
361
362fn eval_cmp_sample(
370 vcf: &VcfLine<'_>,
371 field: &FieldRef,
372 op: &CmpOp,
373 val: &Value,
374 sample_idx: usize,
375) -> Result<bool, EvalError> {
376 match field {
377 FieldRef::Qual => {
378 let q = vcf.qual_num()?;
379 match val {
380 Value::Num(threshold) => Ok(q.is_some_and(|v| cmp_num(v, op, *threshold))),
381 Value::Str(s) => Err(EvalError::TypeMismatch {
382 field: "QUAL".into(),
383 expected: "numeric",
384 got: s.clone(),
385 }),
386 }
387 }
388
389 FieldRef::Filter => {
390 let filter_val = vcf.filter;
391 match val {
392 Value::Str(s) => Ok(cmp_str(filter_val, op, s)),
393 Value::Num(n) => Err(EvalError::TypeMismatch {
394 field: "FILTER".into(),
395 expected: "string",
396 got: n.to_string(),
397 }),
398 }
399 }
400
401 FieldRef::Gt => {
402 let Some(gt) = vcf.gt_str(sample_idx) else {
403 return Ok(false);
404 };
405 match val {
406 Value::Str(pattern) => Ok(eval_gt_str(gt, op, pattern)),
407 Value::Num(_) => Ok(false),
408 }
409 }
410
411 FieldRef::Fmt(tag) => {
412 match val {
413 Value::Num(threshold) => {
414 let v = vcf.fmt_num(tag, sample_idx)?;
415 Ok(v.is_some_and(|n| cmp_num(n, op, *threshold)))
416 }
417 Value::Str(s) => {
418 let raw = vcf.fmt_value(tag, sample_idx);
421 Ok(raw.is_some_and(|v| cmp_str(v, op, s)))
422 }
423 }
424 }
425
426 FieldRef::Info(tag) => {
427 match val {
429 Value::Num(threshold) => {
430 if let Ok(Some(v)) = vcf.fmt_num(tag, sample_idx) {
432 return Ok(cmp_num(v, op, *threshold));
433 }
434 let v = vcf.info_num(tag)?;
436 Ok(v.is_some_and(|n| cmp_num(n, op, *threshold)))
437 }
438 Value::Str(s) => {
439 let raw = vcf.fmt_value(tag, sample_idx);
440 if let Some(v) = raw {
441 return Ok(cmp_str(v, op, s));
442 }
443 Ok(false)
445 }
446 }
447 }
448 }
449}
450
451fn eval_one(expr: &Expr, vcf: &VcfLine<'_>, sample_idx: usize) -> Result<bool, EvalError> {
453 match expr {
454 Expr::Cmp { field, op, val } => eval_cmp_sample(vcf, field, op, val, sample_idx),
455 Expr::Paren(inner) => eval_one(inner.as_ref(), vcf, sample_idx),
456 Expr::Logic { op, lhs, rhs } => {
457 let l = eval_one(lhs, vcf, sample_idx)?;
458 match op {
459 LogOp::And => {
461 if l {
462 eval_one(rhs, vcf, sample_idx)
463 } else {
464 Ok(false)
465 }
466 }
467 LogOp::AndVec | LogOp::Or | LogOp::OrVec => {
470 if l {
471 Ok(true)
472 } else {
473 eval_one(rhs, vcf, sample_idx)
474 }
475 }
476 }
477 }
478 }
479}
480
481pub fn eval_expr(expr: &Expr, line: &str, n_samples: usize) -> Result<SampleResult, EvalError> {
489 let vcf = VcfLine::parse(line)?;
490 let actual_n = vcf.n_samples();
491 let count = if actual_n > 0 {
493 actual_n
494 } else {
495 n_samples.max(1)
496 };
497 let mut pass = Vec::with_capacity(count);
498 for i in 0..count {
499 pass.push(eval_one(expr, &vcf, i)?);
500 }
501 Ok(SampleResult { pass })
502}
503
504pub struct EvalContext {
508 pub expr: Expr,
509 pub negate: bool,
510}
511
512impl EvalContext {
513 #[must_use]
515 pub fn new(expr: Expr, negate: bool) -> Self {
516 Self { expr, negate }
517 }
518
519 pub fn eval_line(&self, line: &str, n_samples: usize) -> Result<SampleResult, EvalError> {
521 let mut result = eval_expr(&self.expr, line, n_samples)?;
522 if self.negate {
523 for p in &mut result.pass {
524 *p = !*p;
525 }
526 }
527 Ok(result)
528 }
529}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::parse_expr;

    /// Builds a data line with fixed site columns (QUAL 50, FILTER PASS,
    /// empty INFO) followed by the given FORMAT and sample columns.
    fn make_line(format: &str, samples: &[&str]) -> String {
        let sample_part = samples.join("\t");
        format!("chr1\t100\t.\tA\tT\t50\tPASS\t.\t{format}\t{sample_part}")
    }

    /// Parses `filter`, evaluates it against `line` and returns the
    /// per-sample flags.
    fn passes(filter: &str, line: &str, n_samples: usize) -> Vec<bool> {
        let expr = parse_expr(filter).unwrap();
        eval_expr(&expr, line, n_samples).unwrap().pass
    }

    #[test]
    fn fmt_dp_lt_selects_low_dp() {
        let line = make_line("GT:DP", &["0/1:3", "0/0:10", "./.:0"]);
        assert_eq!(passes("FMT/DP<5", &line, 3), vec![true, false, true]);
    }

    #[test]
    fn fmt_dp_missing_returns_false() {
        let line = make_line("GT:DP", &["0/1:.", "0/0:10"]);
        assert_eq!(passes("FMT/DP<5", &line, 2), vec![false, false]);
    }

    #[test]
    fn qual_ge_site_level() {
        // QUAL is a site-level field, so every sample gets the same answer.
        let line = make_line("GT", &["0/1", "0/0"]);
        assert_eq!(passes("QUAL>=30", &line, 2), vec![true, true]);
    }

    #[test]
    fn gt_eq_missing() {
        let line = make_line("GT:DP", &["0/1:10", "./.:5", "0/0:20"]);
        assert_eq!(passes(r#"GT=".""#, &line, 3), vec![false, true, false]);
    }

    #[test]
    fn gt_eq_hom() {
        let line = make_line("GT", &["0/0", "0/1", "1/1"]);
        assert_eq!(passes(r#"GT="hom""#, &line, 3), vec![true, false, true]);
    }

    #[test]
    fn gt_eq_het() {
        let line = make_line("GT", &["0/0", "0/1", "1/1"]);
        assert_eq!(passes(r#"GT="het""#, &line, 3), vec![false, true, false]);
    }

    #[test]
    fn andvec_combination_is_per_sample_or() {
        // `&&` passes a sample when either side holds for it.
        let line = make_line("GT:DP:GQ", &["0/1:3:25", "0/1:10:15", "0/0:8:30"]);
        assert_eq!(
            passes("FMT/DP<5 && FMT/GQ>=20", &line, 3),
            vec![true, false, true]
        );
    }

    #[test]
    fn and_single_is_per_sample_and() {
        // `&` requires both sides to hold within the same sample.
        let line = make_line("GT:DP:GQ", &["0/1:3:25", "0/1:10:25", "0/0:8:30"]);
        assert_eq!(
            passes("FMT/DP<5 & FMT/GQ>=20", &line, 3),
            vec![true, false, false]
        );
    }

    #[test]
    fn negate_mode() {
        let line = make_line("GT:DP", &["0/1:3", "0/0:10"]);
        let ctx = EvalContext::new(parse_expr("FMT/DP<5").unwrap(), true);
        let res = ctx.eval_line(&line, 2).unwrap();
        assert_eq!(res.pass, vec![false, true]);
    }

    #[test]
    fn filter_string_eq() {
        let line = "chr1\t100\t.\tA\tT\t50\tPASS\t.\tGT\t0/1";
        assert_eq!(passes(r#"FILTER="PASS""#, line, 1), vec![true]);
    }

    #[test]
    fn missing_dp_returns_false() {
        let line = make_line("GT:DP", &["0/1:."]);
        assert_eq!(passes("FMT/DP<5", &line, 1), vec![false]);
    }

    #[test]
    fn info_field_numeric() {
        let line = "chr1\t100\t.\tA\tT\t50\tPASS\tDP=3\tGT\t0/1";
        assert_eq!(passes("INFO/DP<5", line, 1), vec![true]);
    }
}