use std::collections::HashSet;
use std::path::{Path, PathBuf};
use serde_json::Value;
use crate::commands::dry_sampling_classifier as clf;
use crate::error::{CliError, Result};
/// Entry point for `apr dry-sampling-lint`.
///
/// Loads `observation_file` as JSON and runs each DRY-sampling classifier whose
/// section can be read from it. It prints a report (pretty JSON when `json` is
/// true, plain text otherwise) and converts any failing outcome into an error.
///
/// Sections that are missing or unreadable are reported as skipped and do not
/// fail the lint. An observation containing none of them therefore passes.
///
/// # Errors
///
/// * `CliError::FileNotFound` if `observation_file` does not exist.
/// * The I/O error (converted by `?`) if the file cannot be read.
/// * `CliError::InvalidFormat` if the file contents are not valid JSON.
/// * `CliError::ValidationFailed` if any classifier reports a failure. It
///   carries every failure reason joined with `"; "`.
pub(crate) fn run(observation_file: &Path, json: bool) -> Result<()> {
    if !observation_file.exists() {
        return Err(CliError::FileNotFound(PathBuf::from(observation_file)));
    }

    let contents = std::fs::read_to_string(observation_file)?;
    let observation: Value = match serde_json::from_str(&contents) {
        Ok(value) => value,
        Err(e) => {
            return Err(CliError::InvalidFormat(format!(
                "apr dry-sampling-lint: failed to parse JSON from {}: {e}",
                observation_file.display()
            )));
        }
    };

    let params = classify_params(&observation);
    let identity = classify_identity(&observation);
    let match_len = classify_match_len(&observation);
    let penalty = classify_penalty(&observation);
    let monotone = classify_monotone(&observation);

    // Gather every failure instead of stopping at the first one, so a single
    // run reports all broken checks at once.
    let mut failures: Vec<String> = Vec::new();
    failures.extend(params.as_ref().and_then(params_fail_reason));
    failures.extend(identity.as_ref().and_then(identity_fail_reason));
    failures.extend(match_len.as_ref().and_then(match_len_fail_reason));
    failures.extend(penalty.as_ref().and_then(penalty_fail_reason));
    failures.extend(monotone.as_ref().and_then(monotone_fail_reason));

    print_report(
        observation_file,
        params.as_ref(),
        identity.as_ref(),
        match_len.as_ref(),
        penalty.as_ref(),
        monotone.as_ref(),
        json,
    );

    if failures.is_empty() {
        Ok(())
    } else {
        Err(CliError::ValidationFailed(failures.join("; ")))
    }
}
fn classify_params(obs: &Value) -> Option<clf::DryParamOutcome> {
let sec = obs.get("params")?.as_object()?;
let multiplier = sec.get("multiplier")?.as_f64()?;
let base = sec.get("base")?.as_f64()?;
let allowed_length = u32::try_from(sec.get("allowed_length")?.as_u64()?).ok()?;
Some(clf::classify_dry_params(multiplier, base, allowed_length))
}
fn classify_identity(obs: &Value) -> Option<clf::IdentityOutcome> {
let sec = obs.get("identity")?.as_object()?;
let before: Vec<f64> = sec
.get("logits_before")?
.as_array()?
.iter()
.map(|v| v.as_f64().unwrap_or(f64::NAN))
.collect();
let after: Vec<f64> = sec
.get("logits_after")?
.as_array()?
.iter()
.map(|v| v.as_f64().unwrap_or(f64::NAN))
.collect();
let multiplier = sec.get("multiplier")?.as_f64()?;
Some(clf::classify_dry_identity_zero_multiplier(
&before, &after, multiplier,
))
}
/// Result of comparing the DRY match length computed from an observation with
/// the `expected_match_len` recorded in that observation.
///
/// This is a small plain-data enum, so it derives `Copy` and equality to allow
/// cheap copies and direct comparison (e.g. in tests). `Debug` is kept because
/// the report renders outcomes with `{:?}`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum MatchLenOutcome {
    /// The computed match length equals the expected one.
    Ok { match_len: u32 },
    /// The computed match length (`actual`) differs from `expected`.
    Mismatch { expected: u32, actual: u32 },
}
fn classify_match_len(obs: &Value) -> Option<MatchLenOutcome> {
let sec = obs.get("match_len")?.as_object()?;
let ctx: Vec<u32> = sec
.get("ctx")?
.as_array()?
.iter()
.filter_map(|v| u32::try_from(v.as_u64()?).ok())
.collect();
let candidate = u32::try_from(sec.get("candidate")?.as_u64()?).ok()?;
let seq_breakers: HashSet<u32> = sec
.get("seq_breakers")
.and_then(|v| v.as_array())
.map(|arr| {
arr.iter()
.filter_map(|v| u32::try_from(v.as_u64()?).ok())
.collect()
})
.unwrap_or_default();
let expected = u32::try_from(sec.get("expected_match_len")?.as_u64()?).ok()?;
let actual = clf::classify_dry_match_len(&ctx, candidate, &seq_breakers);
if actual == expected {
Some(MatchLenOutcome::Ok { match_len: actual })
} else {
Some(MatchLenOutcome::Mismatch { expected, actual })
}
}
fn classify_penalty(obs: &Value) -> Option<clf::PenaltyOutcome> {
let sec = obs.get("penalty")?.as_object()?;
let match_len = u32::try_from(sec.get("match_len")?.as_u64()?).ok()?;
let allowed_length = u32::try_from(sec.get("allowed_length")?.as_u64()?).ok()?;
let multiplier = sec.get("multiplier")?.as_f64()?;
let base = sec.get("base")?.as_f64()?;
Some(clf::classify_dry_penalty(
match_len,
allowed_length,
multiplier,
base,
))
}
fn classify_monotone(obs: &Value) -> Option<clf::MonotonicityOutcome> {
let sec = obs.get("monotone")?.as_object()?;
let a = u32::try_from(sec.get("match_len_a")?.as_u64()?).ok()?;
let b = u32::try_from(sec.get("match_len_b")?.as_u64()?).ok()?;
let allowed_length = u32::try_from(sec.get("allowed_length")?.as_u64()?).ok()?;
let multiplier = sec.get("multiplier")?.as_f64()?;
let base = sec.get("base")?.as_f64()?;
Some(clf::classify_dry_penalty_monotone_in_match_len(
a,
b,
allowed_length,
multiplier,
base,
))
}
/// Maps a parameter-range outcome to its `FALSIFY-CRUX-C-23-001` failure
/// message, or `None` when the parameters are valid.
fn params_fail_reason(o: &clf::DryParamOutcome) -> Option<String> {
    let message = match o {
        clf::DryParamOutcome::Valid => return None,
        clf::DryParamOutcome::NotFinite { field } => {
            format!("FALSIFY-CRUX-C-23-001 params: {field} is not finite")
        }
        clf::DryParamOutcome::MultiplierNegative { multiplier } => {
            format!("FALSIFY-CRUX-C-23-001 params: multiplier={multiplier} < 0.0")
        }
        clf::DryParamOutcome::BaseBelowOne { base } => {
            format!("FALSIFY-CRUX-C-23-001 params: base={base} < 1.0")
        }
        clf::DryParamOutcome::AllowedLengthZero => {
            String::from("FALSIFY-CRUX-C-23-001 params: allowed_length == 0")
        }
    };
    Some(message)
}
/// Maps an identity outcome to its `FALSIFY-CRUX-C-23-001` failure message,
/// or `None` when the logits were left unchanged.
fn identity_fail_reason(o: &clf::IdentityOutcome) -> Option<String> {
    let message = match o {
        clf::IdentityOutcome::Ok => return None,
        clf::IdentityOutcome::InvalidInput { reason } => {
            format!("FALSIFY-CRUX-C-23-001 identity: invalid input: {reason}")
        }
        clf::IdentityOutcome::LogitsChanged {
            first_diff_index,
            before,
            after,
        } => format!(
            "FALSIFY-CRUX-C-23-001 identity: logit changed at idx {first_diff_index}: before={before} after={after}"
        ),
    };
    Some(message)
}
/// Maps a match-length outcome to its `FALSIFY-CRUX-C-23-002` failure
/// message, or `None` when the computed length matched the expectation.
fn match_len_fail_reason(o: &MatchLenOutcome) -> Option<String> {
    match *o {
        MatchLenOutcome::Mismatch { expected, actual } => Some(format!(
            "FALSIFY-CRUX-C-23-002 match_len: expected={expected} actual={actual}"
        )),
        MatchLenOutcome::Ok { match_len: _ } => None,
    }
}
/// Maps a penalty outcome to its `FALSIFY-CRUX-C-23-002` failure message,
/// or `None` for the `Ok` variant.
fn penalty_fail_reason(o: &clf::PenaltyOutcome) -> Option<String> {
    let message = match o {
        clf::PenaltyOutcome::Ok { .. } => return None,
        clf::PenaltyOutcome::InvalidInput { reason } => {
            format!("FALSIFY-CRUX-C-23-002 penalty: invalid input: {reason}")
        }
        clf::PenaltyOutcome::Negative { penalty } => {
            format!("FALSIFY-CRUX-C-23-002 penalty: penalty={penalty} < 0.0")
        }
    };
    Some(message)
}
/// Maps a monotonicity outcome to its `FALSIFY-CRUX-C-23-002` failure
/// message, or `None` for the `Ok` variant.
fn monotone_fail_reason(o: &clf::MonotonicityOutcome) -> Option<String> {
    let message = match o {
        clf::MonotonicityOutcome::Ok => return None,
        clf::MonotonicityOutcome::InvalidInput { reason } => {
            format!("FALSIFY-CRUX-C-23-002 monotone: invalid input: {reason}")
        }
        clf::MonotonicityOutcome::Violation {
            match_len_a,
            match_len_b,
            penalty_a,
            penalty_b,
        } => format!(
            "FALSIFY-CRUX-C-23-002 monotone: violation a={match_len_a}(p={penalty_a}) > b={match_len_b}(p={penalty_b})"
        ),
    };
    Some(message)
}
// One parameter per report section plus the path and output-format flag; the
// lint is allowed rather than bundling these into a struct.
#[allow(clippy::too_many_arguments)]
/// Prints the lint report to stdout.
///
/// When `json` is true the report is a pretty-printed JSON object. It falls
/// back to compact JSON if pretty-printing fails. Otherwise a plain-text report
/// is printed. Each outcome is rendered with its `Debug` representation;
/// skipped sections (`None`) are `null` in JSON and a placeholder in text.
fn print_report(
    path: &Path,
    params: Option<&clf::DryParamOutcome>,
    identity: Option<&clf::IdentityOutcome>,
    match_len: Option<&MatchLenOutcome>,
    penalty: Option<&clf::PenaltyOutcome>,
    monotone: Option<&clf::MonotonicityOutcome>,
    json: bool,
) {
    /// Debug-renders one outcome, keeping `None` for a skipped section.
    fn render<T: std::fmt::Debug>(outcome: Option<&T>) -> Option<String> {
        outcome.map(|o| format!("{o:?}"))
    }

    let params = render(params);
    let identity = render(identity);
    let match_len = render(match_len);
    let penalty = render(penalty);
    let monotone = render(monotone);

    if json {
        let report = serde_json::json!({
            "observation_path": path.display().to_string(),
            "params": params,
            "identity": identity,
            "match_len": match_len,
            "penalty": penalty,
            "monotone": monotone,
        });
        let text = serde_json::to_string_pretty(&report).unwrap_or_else(|_| report.to_string());
        println!("{text}");
        return;
    }

    println!("dry-sampling-lint report for {}", path.display());
    print_line(" params: ", params);
    print_line(" identity: ", identity);
    print_line(" match_len: ", match_len);
    print_line(" penalty: ", penalty);
    print_line(" monotone: ", monotone);
}
/// Prints one line of the plain-text report: `prefix` followed by the rendered
/// outcome, or a placeholder when that section's classifier was skipped.
fn print_line(prefix: &str, v: Option<String>) {
    let body = v.unwrap_or_else(|| String::from("(missing fields — classifier skipped)"));
    println!("{prefix}{body}");
}