use std::collections::{HashMap, HashSet};
use std::net::IpAddr;
use std::str::FromStr;
use ipnet::IpNet;
use cellos_core::DnsRebindingPolicy;
/// Per-hostname record of the resolved targets that have been committed.
///
/// `evaluate` reads this record to decide which targets in a new response
/// are novel; `commit` is the only method that adds to it.
#[derive(Debug, Default, Clone)]
pub struct RebindingState {
// hostname -> distinct committed targets, kept in first-seen order
// (`commit` only appends a target that is not already present).
histories: HashMap<String, Vec<String>>,
}
/// Outcome of evaluating one DNS response against the rebinding policy.
///
/// The borrowed string slices point into the `new_targets` slice that was
/// passed to `RebindingState::evaluate`, hence the `'a` lifetime.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RebindingDecision<'a> {
    /// Targets from the response that are absent from the hostname's
    /// committed history, de-duplicated and in response order.
    pub novel_ips: Vec<&'a str>,
    /// `true` when at least one novel target was seen and the number of
    /// previously committed targets plus novel targets exceeds the policy cap.
    pub threshold_exceeded: bool,
    /// Targets that match no allowlist entry for the hostname. Always empty
    /// when the policy's allowlist itself is empty.
    pub allowlist_violations: Vec<&'a str>,
    /// Targets that remain after enforcement: the filtered response when the
    /// policy rejects on rebind, otherwise every target unchanged.
    pub effective_targets: Vec<String>,
}
impl RebindingState {
    /// Creates an empty state with no recorded hostnames.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the number of hostnames that currently have a history entry.
    #[must_use]
    pub fn hostname_count(&self) -> usize {
        self.histories.len()
    }

    /// Returns the targets committed so far for `hostname` in first-seen
    /// order, or an empty slice when the hostname has no history entry.
    #[must_use]
    pub fn history(&self, hostname: &str) -> &[String] {
        match self.histories.get(hostname) {
            Some(targets) => targets.as_slice(),
            None => &[],
        }
    }

    /// Evaluates the targets of one DNS response for `hostname` against
    /// `policy` without modifying the recorded history.
    ///
    /// The returned decision reports:
    /// * `novel_ips` - targets not yet in the hostname's history
    ///   (de-duplicated, response order preserved);
    /// * `threshold_exceeded` - whether history size plus novel count exceeds
    ///   `policy.max_novel_ips_per_hostname` (only when something is novel);
    /// * `allowlist_violations` - targets matching no allowlist entry for the
    ///   hostname, checked only when `policy.response_ip_allowlist` is
    ///   non-empty;
    /// * `effective_targets` - with `policy.reject_on_rebind`, the response
    ///   minus allowlist violations and minus the novel targets that do not
    ///   fit under the cap; otherwise the response unchanged.
    pub fn evaluate<'a>(
        &self,
        hostname: &str,
        new_targets: &'a [String],
        policy: &DnsRebindingPolicy,
    ) -> RebindingDecision<'a> {
        // Borrow the history as `&str` so each membership check is a hash
        // lookup instead of a scan over every previously committed target.
        let prior: HashSet<&str> = self
            .histories
            .get(hostname)
            .map(|seen| seen.iter().map(String::as_str).collect())
            .unwrap_or_default();

        // `novel_seen` de-duplicates while `novel_ips` keeps response order.
        let mut novel_ips: Vec<&'a str> = Vec::new();
        let mut novel_seen: HashSet<&'a str> = HashSet::new();
        for target in new_targets {
            let ip = target.as_str();
            if !prior.contains(ip) && novel_seen.insert(ip) {
                novel_ips.push(ip);
            }
        }

        // Widening casts: `usize` is at most 64 bits on supported targets.
        let prior_len = prior.len() as u64;
        let novel_len = novel_ips.len() as u64;
        let cap = u64::from(policy.max_novel_ips_per_hostname);
        let threshold_exceeded = novel_len > 0 && prior_len.saturating_add(novel_len) > cap;

        // An empty allowlist disables the check entirely. A non-empty one is
        // narrowed to this hostname first, so a hostname with no entries of
        // its own has every target reported as a violation.
        let mut allowlist_violations: Vec<&'a str> = Vec::new();
        if !policy.response_ip_allowlist.is_empty() {
            let entries = parse_allowlist_for_hostname(&policy.response_ip_allowlist, hostname);
            for target in new_targets {
                let ip = target.as_str();
                if !ip_in_allowlist(ip, &entries) {
                    allowlist_violations.push(ip);
                }
            }
        }

        let effective_targets: Vec<String> = if policy.reject_on_rebind {
            let mut dropped_novel: HashSet<&str> = HashSet::new();
            if threshold_exceeded {
                // Keep only as many novel targets as still fit under the cap;
                // everything past that point in response order is dropped.
                let keep_novel =
                    usize::try_from(cap.saturating_sub(prior_len)).unwrap_or(usize::MAX);
                dropped_novel.extend(novel_ips.iter().skip(keep_novel).copied());
            }
            let dropped_allowlist: HashSet<&str> = allowlist_violations.iter().copied().collect();
            new_targets
                .iter()
                .filter(|target| {
                    let ip = target.as_str();
                    !dropped_novel.contains(ip) && !dropped_allowlist.contains(ip)
                })
                .cloned()
                .collect()
        } else {
            // Audit-only mode: report findings but pass every target through.
            new_targets.to_vec()
        };

        RebindingDecision {
            novel_ips,
            threshold_exceeded,
            allowlist_violations,
            effective_targets,
        }
    }

    /// Records `effective_targets` in the history of `hostname`.
    ///
    /// Targets already present are skipped, so the history stays free of
    /// duplicates and keeps first-seen order. A history entry for `hostname`
    /// is created even when `effective_targets` is empty.
    pub fn commit(&mut self, hostname: &str, effective_targets: &[String]) {
        let history = self.histories.entry(hostname.to_string()).or_default();
        for target in effective_targets {
            if !history.contains(target) {
                history.push(target.clone());
            }
        }
    }
}
/// One parsed allowlist entry: either a single address or a CIDR network.
#[derive(Debug)]
enum AllowlistMatcher {
// Matches exactly one IP address.
Ip(IpAddr),
// Matches every address contained in the network.
Net(IpNet),
}
/// Builds the matchers that apply to `hostname` from raw allowlist entries.
///
/// Each entry is expected to look like `<hostname>:<ip-or-cidr>`; the split
/// happens at the first `:` so IPv6 values keep their own colons. Entries
/// are skipped silently when they have no `:`, name a different hostname
/// (exact string comparison), have an empty value after trimming, or fail
/// to parse as an address / network.
fn parse_allowlist_for_hostname(entries: &[String], hostname: &str) -> Vec<AllowlistMatcher> {
    entries
        .iter()
        .filter_map(|entry| entry.split_once(':'))
        .filter(|(host, _)| *host == hostname)
        .map(|(_, value)| value.trim())
        .filter(|value| !value.is_empty())
        .filter_map(|value| {
            // A slash marks CIDR notation; anything else must be a bare address.
            if value.contains('/') {
                IpNet::from_str(value).ok().map(AllowlistMatcher::Net)
            } else {
                IpAddr::from_str(value).ok().map(AllowlistMatcher::Ip)
            }
        })
        .collect()
}
/// Reports whether `ip_str` is covered by any matcher in `entries`.
///
/// A string that does not parse as an IP address is never considered
/// allowlisted, and an empty `entries` slice matches nothing.
fn ip_in_allowlist(ip_str: &str, entries: &[AllowlistMatcher]) -> bool {
    match IpAddr::from_str(ip_str) {
        Err(_) => false,
        Ok(addr) => {
            for entry in entries {
                let hit = match entry {
                    AllowlistMatcher::Ip(expected) => *expected == addr,
                    AllowlistMatcher::Net(network) => network.contains(&addr),
                };
                if hit {
                    return true;
                }
            }
            false
        }
    }
}
// Unit tests for `RebindingState::evaluate` / `commit` and the allowlist helpers.
#[cfg(test)]
mod tests {
use super::*;
// Policy with whatever `DnsRebindingPolicy::default()` provides.
fn policy_default() -> DnsRebindingPolicy {
DnsRebindingPolicy::default()
}
// Policy with an explicit cap, enforcement flag and raw allowlist entries.
fn policy_with(max: u32, reject: bool, allowlist: Vec<&str>) -> DnsRebindingPolicy {
DnsRebindingPolicy {
response_ip_allowlist: allowlist.into_iter().map(String::from).collect(),
max_novel_ips_per_hostname: max,
reject_on_rebind: reject,
}
}
// Converts string literals into the owned `Vec<String>` the API takes.
fn s(items: &[&str]) -> Vec<String> {
items.iter().map(|s| (*s).to_string()).collect()
}
// With no history, every target in the response is novel.
#[test]
fn evaluate_returns_all_novel_when_first_observation() {
let state = RebindingState::new();
let new_targets = s(&["1.1.1.1", "1.0.0.1"]);
let policy = policy_default();
let decision = state.evaluate("api.example.com", &new_targets, &policy);
assert_eq!(decision.novel_ips, vec!["1.1.1.1", "1.0.0.1"]);
}
// Targets that were already committed are not reported as novel.
#[test]
fn evaluate_returns_no_novel_when_repeat_observation() {
let mut state = RebindingState::new();
let first = s(&["1.1.1.1", "1.0.0.1"]);
state.commit("api.example.com", &first);
let new_targets = s(&["1.1.1.1", "1.0.0.1"]);
let policy = policy_default();
let decision = state.evaluate("api.example.com", &new_targets, &policy);
assert!(decision.novel_ips.is_empty());
assert!(!decision.threshold_exceeded);
}
// Three known targets plus two novel ones exceed a cap of four.
#[test]
fn evaluate_threshold_exceeded_above_max_novel_ips() {
let mut state = RebindingState::new();
state.commit("h", &s(&["1.0.0.1", "1.0.0.2", "1.0.0.3"]));
let policy = policy_with(4, false, vec![]);
let new_targets = s(&["1.0.0.4", "1.0.0.5"]);
let decision = state.evaluate("h", &new_targets, &policy);
assert!(decision.threshold_exceeded);
assert_eq!(decision.novel_ips, vec!["1.0.0.4", "1.0.0.5"]);
}
// Reaching the cap exactly (three known plus one novel, cap four) is allowed.
#[test]
fn evaluate_threshold_not_exceeded_at_exact_max() {
let mut state = RebindingState::new();
state.commit("h", &s(&["1.0.0.1", "1.0.0.2", "1.0.0.3"]));
let policy = policy_with(4, false, vec![]);
let new_targets = s(&["1.0.0.4"]);
let decision = state.evaluate("h", &new_targets, &policy);
assert!(!decision.threshold_exceeded);
assert_eq!(decision.novel_ips, vec!["1.0.0.4"]);
}
// A target outside the hostname's allowlist entries is reported as a violation.
#[test]
fn evaluate_allowlist_violations_when_set() {
let state = RebindingState::new();
let policy = policy_with(
10,
false,
vec!["api.example.com:1.1.1.1", "api.example.com:1.0.0.1"],
);
let new_targets = s(&["1.1.1.1", "198.51.100.7"]);
let decision = state.evaluate("api.example.com", &new_targets, &policy);
assert_eq!(decision.allowlist_violations, vec!["198.51.100.7"]);
}
// An empty allowlist disables the allowlist check.
#[test]
fn evaluate_no_allowlist_violations_when_unset() {
let state = RebindingState::new();
let policy = policy_with(10, false, vec![]);
let new_targets = s(&["198.51.100.7"]);
let decision = state.evaluate("api.example.com", &new_targets, &policy);
assert!(decision.allowlist_violations.is_empty());
}
// History already fills the cap, so in reject mode both novel targets are
// dropped and only the previously known target survives.
#[test]
fn evaluate_reject_on_rebind_filters_novel_above_threshold() {
let mut state = RebindingState::new();
state.commit("h", &s(&["1.0.0.1", "1.0.0.2", "1.0.0.3", "1.0.0.4"])); let policy = policy_with(4, true, vec![]);
let new_targets = s(&["1.0.0.4", "1.0.0.5", "1.0.0.6"]);
let decision = state.evaluate("h", &new_targets, &policy);
assert!(decision.threshold_exceeded);
assert_eq!(decision.effective_targets, vec!["1.0.0.4".to_string()]);
}
// In reject mode, allowlist violations are removed from the effective targets.
#[test]
fn evaluate_reject_on_rebind_filters_allowlist_violations() {
let state = RebindingState::new();
let policy = policy_with(10, true, vec!["api.example.com:1.1.1.1"]);
let new_targets = s(&["1.1.1.1", "198.51.100.7"]);
let decision = state.evaluate("api.example.com", &new_targets, &policy);
assert_eq!(decision.allowlist_violations, vec!["198.51.100.7"]);
assert_eq!(decision.effective_targets, vec!["1.1.1.1".to_string()]);
}
// Without reject mode, violations are reported but every target is kept.
#[test]
fn evaluate_audit_only_keeps_violations_in_effective_targets() {
let state = RebindingState::new();
let policy = policy_with(10, false, vec!["api.example.com:1.1.1.1"]);
let new_targets = s(&["1.1.1.1", "198.51.100.7"]);
let decision = state.evaluate("api.example.com", &new_targets, &policy);
assert_eq!(decision.allowlist_violations, vec!["198.51.100.7"]);
assert_eq!(decision.effective_targets, new_targets);
}
// `commit` stores targets per hostname, skips duplicates and keeps order.
#[test]
fn commit_persists_observation() {
let mut state = RebindingState::new();
assert_eq!(state.hostname_count(), 0);
state.commit("h", &s(&["1.1.1.1"]));
assert_eq!(state.hostname_count(), 1);
assert_eq!(state.history("h"), &["1.1.1.1".to_string()]);
state.commit("h", &s(&["1.1.1.1", "1.0.0.1"]));
assert_eq!(
state.history("h"),
&["1.1.1.1".to_string(), "1.0.0.1".to_string()]
);
}
// Entries with no colon, an empty value, a non-IP value, an out-of-range
// address or a bad prefix length are all skipped; only `h:1.1.1.1` parses.
#[test]
fn parse_allowlist_skips_malformed_entries() {
let entries: Vec<String> = vec![
"no-colon-here".into(), "h:".into(), "h:not-an-ip".into(), "h:999.999.999.999".into(), "h:1.1.1.1/notanumber".into(), "h:1.1.1.1".into(), ];
let parsed = parse_allowlist_for_hostname(&entries, "h");
assert_eq!(parsed.len(), 1, "only the well-formed entry survives");
}
// A CIDR entry matches addresses inside the network and rejects those outside.
#[test]
fn parse_allowlist_supports_cidr() {
let entries: Vec<String> = vec!["h:203.0.113.0/24".into()];
let parsed = parse_allowlist_for_hostname(&entries, "h");
assert_eq!(parsed.len(), 1);
assert!(ip_in_allowlist("203.0.113.42", &parsed));
assert!(!ip_in_allowlist("203.0.114.42", &parsed));
}
// A non-empty allowlist with no entry for this hostname flags every target.
#[test]
fn allowlist_entries_for_other_hostname_are_ignored() {
let state = RebindingState::new();
let policy = policy_with(10, false, vec!["other.example.com:1.1.1.1"]);
let new_targets = s(&["1.1.1.1"]);
let decision = state.evaluate("h", &new_targets, &policy);
assert_eq!(decision.allowlist_violations, vec!["1.1.1.1"]);
}
}