use std::collections::HashMap;
use veilus_fingerprint_core::FingerprintError;
use veilus_fingerprint_data::network::BayesianNetwork;
use rand::Rng;
use super::sampler::sample_ancestral;
/// Per-node value restrictions used by constrained sampling.
///
/// Maps a node name to the list of values that node is allowed to take in a
/// sampled assignment. An entry with an empty list places no restriction on
/// that node.
pub type Constraints = HashMap<String, Vec<String>>;
/// Samples a complete assignment from `network` that satisfies `constraints`.
///
/// This is rejection sampling: a full assignment is drawn with
/// [`sample_ancestral`] and accepted only if every constrained node received
/// one of its allowed values. Otherwise the draw is discarded and a new one is
/// made, up to `10 * network.nodes.len()` retries after the initial attempt.
///
/// Constraint entries with an empty value list are ignored. A constraint that
/// names a node missing from the sampled assignment can never be satisfied.
///
/// # Errors
///
/// * Propagates any error returned by [`sample_ancestral`].
/// * Returns [`FingerprintError::ConstraintsTooRestrictive`] when no attempt
///   within the budget satisfied all constraints. The message lists the active
///   constraints in sorted order so that it is stable across runs.
pub fn sample_constrained(
    network: &BayesianNetwork,
    constraints: &Constraints,
    rng: &mut impl Rng,
) -> Result<HashMap<String, String>, FingerprintError> {
    let node_count = network.nodes.len();
    let max_backtracks = node_count * 10;
    // The first draw is not a backtrack, so the total number of attempts is
    // one more than the backtrack budget.
    let max_attempts = max_backtracks + 1;

    // Empty allow-lists mean "unconstrained"; drop them once up front instead
    // of re-checking on every attempt.
    let active: Vec<(&String, &Vec<String>)> = constraints
        .iter()
        .filter(|(_, allowed_values)| !allowed_values.is_empty())
        .collect();

    for attempt in 0..max_attempts {
        let assignment = sample_ancestral(network, rng)?;

        let satisfied = active.iter().all(|(node_name, allowed_values)| {
            assignment
                .get(node_name.as_str())
                .map_or(false, |value| allowed_values.contains(value))
        });

        if satisfied {
            tracing::debug!(
                attempt,
                "Constrained sampling succeeded after {} attempt(s)",
                attempt + 1
            );
            return Ok(assignment);
        }

        tracing::debug!(attempt, "Constrained sampling attempt failed, retrying");
    }

    // HashMap iteration order is unspecified; sort so the error text is
    // deterministic for a given set of constraints.
    let mut described: Vec<String> = active
        .iter()
        .map(|(k, v)| format!("{k}=[{}]", v.join(",")))
        .collect();
    described.sort_unstable();
    let constraint_desc = described.join(", ");

    Err(FingerprintError::ConstraintsTooRestrictive(format!(
        "Could not satisfy constraints after {max_attempts} attempts: {constraint_desc}"
    )))
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use veilus_fingerprint_data::loader::{get_fingerprint_network, get_header_network};

    /// Pinning `*BROWSER` to its first listed value must yield that value on
    /// every one of several consecutive constrained draws.
    #[test]
    fn header_browser_prefix_constraint_always_satisfied() {
        let header_net = get_header_network().expect("must load");
        let mut rng = ChaCha8Rng::seed_from_u64(12345);

        let pinned = header_net
            .nodes
            .iter()
            .find(|node| node.name == "*BROWSER")
            .expect("*BROWSER node must exist")
            .possible_values[0]
            .clone();

        let constraints = Constraints::from([("*BROWSER".to_string(), vec![pinned.clone()])]);

        for _round in 0..5 {
            let drawn = sample_constrained(header_net, &constraints, &mut rng)
                .expect("constraint must be satisfiable");
            let observed = drawn.get("*BROWSER").map(String::as_str);
            assert_eq!(
                observed,
                Some(pinned.as_str()),
                "*BROWSER constraint must be respected"
            );
        }
    }

    /// Pinning `userAgent` to its first listed value must be honoured by a
    /// constrained draw from the fingerprint network.
    #[test]
    fn fingerprint_ua_constraint_satisfied() {
        let fp_net = get_fingerprint_network().expect("must load");
        let mut rng = ChaCha8Rng::seed_from_u64(999);

        let wanted_ua = fp_net
            .nodes
            .iter()
            .find(|node| node.name == "userAgent")
            .expect("userAgent node must exist")
            .possible_values[0]
            .clone();

        let constraints = Constraints::from([("userAgent".to_string(), vec![wanted_ua.clone()])]);

        let drawn = sample_constrained(fp_net, &constraints, &mut rng)
            .expect("userAgent constraint must be satisfiable");
        let observed = drawn.get("userAgent").map(|s| s.as_str());
        assert_eq!(observed, Some(wanted_ua.as_str()));
    }

    /// A `*BROWSER` value the test expects never to be sampled must end in
    /// `ConstraintsTooRestrictive`.
    #[test]
    fn impossible_constraint_returns_error() {
        let header_net = get_header_network().expect("must load");
        let mut rng = ChaCha8Rng::seed_from_u64(99999);

        let constraints = Constraints::from([(
            "*BROWSER".to_string(),
            vec!["netscape99/1.0".to_string()],
        )]);

        let outcome = sample_constrained(header_net, &constraints, &mut rng);
        let too_restrictive = matches!(
            outcome,
            Err(FingerprintError::ConstraintsTooRestrictive(_))
        );
        assert!(
            too_restrictive,
            "Impossible constraint must return ConstraintsTooRestrictive, got: {:?}",
            outcome.err()
        );
    }

    /// With no constraints at all, the first draw is accepted.
    #[test]
    fn empty_constraints_always_succeeds() {
        let header_net = get_header_network().expect("must load");
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let no_constraints = Constraints::new();

        let outcome = sample_constrained(header_net, &no_constraints, &mut rng);
        assert!(outcome.is_ok(), "Empty constraints must always succeed");
    }
}