use parlov_core::{Error, Vector};
use parlov_elicit::{ProbeSpec, RiskLevel};
/// One vector selection together with an optional risk ceiling.
///
/// Values of this type are produced by `parse_vector_flag` and consumed by
/// `apply_vector_filters` and `max_risk_from_filters`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct VectorFilter {
/// Vector that a probe spec's technique must carry to match this filter.
pub vector: Vector,
/// Highest risk level this filter admits. `None` means no level was given;
/// the filtering helpers in this file then fall back to `RiskLevel::Safe`.
pub risk: Option<RiskLevel>,
}
pub(crate) fn parse_vector_flag(s: &str) -> Result<VectorFilter, Error> {
let (vector_str, risk_str) = match s.split_once(':') {
Some((v, r)) => (v, Some(r)),
None => (s, None),
};
let vector = parse_vector_name(vector_str)?;
let risk = risk_str.map(parse_risk_name).transpose()?;
Ok(VectorFilter { vector, risk })
}
fn parse_vector_name(s: &str) -> Result<Vector, Error> {
match s {
"cache-probing" => Ok(Vector::CacheProbing),
"error-message-granularity" => Ok(Vector::ErrorMessageGranularity),
"redirect-diff" => Ok(Vector::RedirectDiff),
"status-code-diff" => Ok(Vector::StatusCodeDiff),
other => Err(Error::Http(format!(
"unknown vector '{other}'; expected cache-probing | error-message-granularity | redirect-diff | status-code-diff"
))),
}
}
fn parse_risk_name(s: &str) -> Result<RiskLevel, Error> {
match s {
"safe" => Ok(RiskLevel::Safe),
"method-destructive" => Ok(RiskLevel::MethodDestructive),
"operation-destructive" => Ok(RiskLevel::OperationDestructive),
other => Err(Error::Http(format!(
"invalid risk level '{other}'; expected safe | method-destructive | operation-destructive"
))),
}
}
/// Reads the risk level stored in a probe spec's strategy metadata.
///
/// Every `ProbeSpec` variant handled here carries a payload with a
/// `metadata.risk` field; this returns that field whatever the variant.
fn spec_risk(spec: &ProbeSpec) -> RiskLevel {
    match spec {
        ProbeSpec::Pair(pair) => pair.metadata.risk,
        ProbeSpec::HeaderDiff(pair) => pair.metadata.risk,
        ProbeSpec::Burst(burst) => burst.metadata.risk,
    }
}
/// Keeps only the probe specs that at least one filter admits.
///
/// A filter admits a spec when the spec's technique vector equals the filter's
/// vector and the spec's risk is at or below the filter's ceiling. A filter
/// with no explicit risk uses `RiskLevel::Safe` as its ceiling.
///
/// The relative order of the surviving specs is unchanged. With an empty
/// `filters` slice nothing is admitted, so the result is empty.
pub(crate) fn apply_vector_filters(plan: Vec<ProbeSpec>, filters: &[VectorFilter]) -> Vec<ProbeSpec> {
    let mut kept = plan;

    kept.retain(|spec| {
        let vector = spec.technique().vector;
        let risk = spec_risk(spec);

        for filter in filters {
            let ceiling = match filter.risk {
                Some(level) => level,
                None => RiskLevel::Safe,
            };
            if vector == filter.vector && risk <= ceiling {
                return true;
            }
        }

        false
    });

    kept
}
/// Returns the highest risk ceiling requested by any filter.
///
/// A filter without an explicit risk counts as `RiskLevel::Safe`. When
/// `filters` is empty the result is `RiskLevel::Safe`.
pub(crate) fn max_risk_from_filters(filters: &[VectorFilter]) -> RiskLevel {
    let mut ceilings = filters.iter().map(|filter| match filter.risk {
        Some(level) => level,
        None => RiskLevel::Safe,
    });

    // Seed the fold with the first ceiling so the result is always one of the
    // filters' own ceilings whenever at least one filter exists.
    match ceilings.next() {
        Some(first) => ceilings.fold(first, std::cmp::max),
        None => RiskLevel::Safe,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare vector name parses with no risk attached.
    #[test]
    fn parse_vector_name_only() {
        let parsed = parse_vector_flag("cache-probing").unwrap();
        let expected = VectorFilter {
            vector: Vector::CacheProbing,
            risk: None,
        };
        assert_eq!(parsed, expected);
    }

    /// `name:risk` yields both the vector and the requested risk level.
    #[test]
    fn parse_vector_with_risk() {
        let parsed = parse_vector_flag("status-code-diff:method-destructive").unwrap();
        let expected = VectorFilter {
            vector: Vector::StatusCodeDiff,
            risk: Some(RiskLevel::MethodDestructive),
        };
        assert_eq!(parsed, expected);
    }

    /// An explicit `safe` is kept as `Some(Safe)` rather than collapsed to `None`.
    #[test]
    fn parse_vector_with_safe_risk() {
        let parsed = parse_vector_flag("cache-probing:safe").unwrap();
        let expected = VectorFilter {
            vector: Vector::CacheProbing,
            risk: Some(RiskLevel::Safe),
        };
        assert_eq!(parsed, expected);
    }

    /// An unrecognised vector name is rejected.
    #[test]
    fn parse_vector_unknown_name_returns_err() {
        let outcome = parse_vector_flag("unknown-vector");
        assert!(outcome.is_err());
    }

    /// A valid vector name with an unrecognised risk level is rejected.
    #[test]
    fn parse_vector_unknown_risk_returns_err() {
        let outcome = parse_vector_flag("cache-probing:unknown");
        assert!(outcome.is_err());
    }

    /// A filter without an explicit risk counts as `Safe`.
    #[test]
    fn max_risk_defaults_to_safe_when_no_risk_specified() {
        let filters = [VectorFilter {
            vector: Vector::CacheProbing,
            risk: None,
        }];
        assert_eq!(max_risk_from_filters(&filters), RiskLevel::Safe);
    }

    /// The highest requested level across all filters is reported.
    #[test]
    fn max_risk_picks_highest_across_filters() {
        let filters = [
            VectorFilter {
                vector: Vector::CacheProbing,
                risk: Some(RiskLevel::Safe),
            },
            VectorFilter {
                vector: Vector::StatusCodeDiff,
                risk: Some(RiskLevel::MethodDestructive),
            },
        ];
        assert_eq!(max_risk_from_filters(&filters), RiskLevel::MethodDestructive);
    }

    /// No filters at all still yields `Safe`.
    #[test]
    fn max_risk_empty_filters_returns_safe() {
        let highest = max_risk_from_filters(&[]);
        assert_eq!(highest, RiskLevel::Safe);
    }

    /// `error-message-granularity` parses on its own.
    #[test]
    fn parse_vector_emg_name_only() {
        let parsed = parse_vector_flag("error-message-granularity").unwrap();
        let expected = VectorFilter {
            vector: Vector::ErrorMessageGranularity,
            risk: None,
        };
        assert_eq!(parsed, expected);
    }

    /// `error-message-granularity` accepts a destructive risk suffix.
    #[test]
    fn parse_vector_emg_with_risk() {
        let parsed = parse_vector_flag("error-message-granularity:method-destructive").unwrap();
        let expected = VectorFilter {
            vector: Vector::ErrorMessageGranularity,
            risk: Some(RiskLevel::MethodDestructive),
        };
        assert_eq!(parsed, expected);
    }

    /// `error-message-granularity` accepts an explicit `safe` suffix.
    #[test]
    fn parse_vector_emg_with_safe_risk() {
        let parsed = parse_vector_flag("error-message-granularity:safe").unwrap();
        let expected = VectorFilter {
            vector: Vector::ErrorMessageGranularity,
            risk: Some(RiskLevel::Safe),
        };
        assert_eq!(parsed, expected);
    }

    /// `redirect-diff` parses on its own.
    #[test]
    fn parse_vector_redirect_diff_name_only() {
        let parsed = parse_vector_flag("redirect-diff").unwrap();
        let expected = VectorFilter {
            vector: Vector::RedirectDiff,
            risk: None,
        };
        assert_eq!(parsed, expected);
    }

    /// `redirect-diff` accepts a destructive risk suffix.
    #[test]
    fn parse_vector_redirect_diff_with_risk() {
        let parsed = parse_vector_flag("redirect-diff:method-destructive").unwrap();
        let expected = VectorFilter {
            vector: Vector::RedirectDiff,
            risk: Some(RiskLevel::MethodDestructive),
        };
        assert_eq!(parsed, expected);
    }

    /// `redirect-diff` accepts an explicit `safe` suffix.
    #[test]
    fn parse_vector_redirect_diff_with_safe_risk() {
        let parsed = parse_vector_flag("redirect-diff:safe").unwrap();
        let expected = VectorFilter {
            vector: Vector::RedirectDiff,
            risk: Some(RiskLevel::Safe),
        };
        assert_eq!(parsed, expected);
    }

    /// Filtering on one vector keeps exactly the specs whose technique uses it.
    #[test]
    fn filter_emg_specs_only() {
        use http::{HeaderMap, Method};
        use parlov_core::{NormativeStrength, OracleClass, ProbeDefinition, Technique};
        use parlov_elicit::{ProbePair, StrategyMetadata};

        // Identical request definition reused for both halves of every pair.
        let definition = || ProbeDefinition {
            url: String::from("https://example.com/1"),
            method: Method::GET,
            headers: HeaderMap::new(),
            body: None,
        };

        // Builds a `Safe` pair spec whose technique targets the given vector.
        let safe_pair_for = |vector: Vector| {
            let metadata = StrategyMetadata {
                strategy_id: "test",
                strategy_name: "Test",
                risk: RiskLevel::Safe,
            };
            let technique = Technique {
                id: "test",
                name: "Test",
                oracle_class: OracleClass::Existence,
                vector,
                strength: NormativeStrength::Should,
            };
            ProbeSpec::Pair(ProbePair {
                baseline: definition(),
                probe: definition(),
                metadata,
                technique,
            })
        };

        let plan: Vec<ProbeSpec> = vec![
            Vector::StatusCodeDiff,
            Vector::ErrorMessageGranularity,
            Vector::CacheProbing,
            Vector::ErrorMessageGranularity,
        ]
        .into_iter()
        .map(safe_pair_for)
        .collect();

        let emg_only = [VectorFilter {
            vector: Vector::ErrorMessageGranularity,
            risk: None,
        }];

        let filtered = apply_vector_filters(plan, &emg_only);

        assert_eq!(filtered.len(), 2);
        assert!(filtered
            .iter()
            .all(|spec| spec.technique().vector == Vector::ErrorMessageGranularity));
    }
}