use crate::Result;
use crate::plugin::{Context, builder::PluginBuilder};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::{debug, info, warn};
/// A compiled condition: a predicate evaluated against a plugin [`Context`].
///
/// Stored behind an `Arc` so a compiled condition can be cloned cheaply; the
/// `Send + Sync` bounds let it be shared across threads.
pub type Condition = Arc<dyn Fn(&Context) -> bool + Send + Sync>;

/// Factory that turns a textual condition into a [`Condition`].
///
/// A condition string starts with a builder name as its first whitespace-separated
/// token; anything after it is interpreted by the concrete builder.
pub trait ConditionBuilder: Send + Sync {
    /// Name this builder is registered under.
    fn name(&self) -> &str;

    /// Returns `true` if this builder accepts `condition_str`.
    ///
    /// The default implementation accepts when any of the following holds:
    /// - the first whitespace-separated token equals [`name`](Self::name);
    /// - the first token is `name` followed by `.` and a suffix (e.g. `name.sub`);
    /// - `name` starts with `!` and `condition_str` starts with `name`
    ///   (a plain prefix test on the whole string).
    fn can_handle(&self, condition_str: &str) -> bool {
        let name = self.name();
        let token = condition_str.split_whitespace().next().unwrap_or("");
        // `strip_prefix` + `.` check is equivalent to `starts_with("{name}.")`
        // without allocating a new String on every call.
        let is_dotted = token
            .strip_prefix(name)
            .is_some_and(|rest| rest.starts_with('.'));
        token == name || is_dotted || (name.starts_with('!') && condition_str.starts_with(name))
    }

    /// Compiles `condition_str` into a [`Condition`].
    ///
    /// # Errors
    /// Returns an error when the condition string cannot be built by this builder.
    fn build(&self, condition_str: &str, builder: &PluginBuilder) -> Result<Condition>;
}
/// Collection of [`ConditionBuilder`]s keyed by name, used to resolve condition strings
/// to the builder that can compile them.
#[derive(Default)]
pub struct ConditionBuilderRegistry {
    /// Registered builders, keyed by name.
    builders: HashMap<String, Arc<dyn ConditionBuilder>>,
}

impl ConditionBuilderRegistry {
    /// Creates a registry with no builders registered.
    pub fn new() -> Self {
        Self::default()
    }
}
impl ConditionBuilderRegistry {
    /// Registers `builder` under its [`ConditionBuilder::name`].
    ///
    /// A builder previously registered under the same name is replaced, and a
    /// warning is logged.
    pub fn register(&mut self, builder: Arc<dyn ConditionBuilder>) {
        let name = builder.name().to_string();
        if self.builders.insert(name.clone(), builder).is_some() {
            warn!("Overwriting existing condition builder: {}", name);
        } else {
            debug!("Registered condition builder: {}", name);
        }
    }

    /// Finds a builder able to handle `condition_str`.
    ///
    /// First tries an exact lookup on the first whitespace-separated token. If that
    /// fails, every builder is asked via [`ConditionBuilder::can_handle`]. When more
    /// than one accepts, the builder with the longest (most specific) name wins, and
    /// ties go to the lexicographically smallest name. The result therefore never
    /// depends on `HashMap` iteration order, which is randomized per process.
    ///
    /// Returns `None` when no builder accepts the string.
    pub fn get_builder(&self, condition_str: &str) -> Option<Arc<dyn ConditionBuilder>> {
        let condition_name = condition_str.split_whitespace().next().unwrap_or("");
        if let Some(builder) = self.builders.get(condition_name) {
            return Some(Arc::clone(builder));
        }
        self.builders
            .iter()
            .filter(|(_, builder)| builder.can_handle(condition_str))
            // Longer name is "greater"; for equal lengths the smaller name is "greater",
            // so `max_by` picks it. Keys are unique, so there are never full ties.
            .max_by(|(a, _), (b, _)| a.len().cmp(&b.len()).then_with(|| b.cmp(a)))
            .map(|(_, builder)| Arc::clone(builder))
    }

    /// Returns the names of all registered builders in sorted order.
    pub fn builder_names(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.builders.keys().map(String::as_str).collect();
        names.sort_unstable();
        names
    }
}
use std::sync::OnceLock;
/// Process-wide registry of the built-in condition builders, created lazily on first use.
static CONDITION_BUILDER_REGISTRY: OnceLock<ConditionBuilderRegistry> = OnceLock::new();

/// Builds a registry pre-populated with every built-in condition builder.
fn build_default_registry() -> ConditionBuilderRegistry {
    use crate::plugin::condition::*;

    let mut registry = ConditionBuilderRegistry::new();
    registry.register(Arc::new(HasRespBuilder));
    registry.register(Arc::new(RespIpBuilder));
    registry.register(Arc::new(RespIpNegBuilder));
    registry.register(Arc::new(QnameBuilder));
    registry.register(Arc::new(QnameNegBuilder));
    registry.register(Arc::new(QtypeBuilder));
    registry.register(Arc::new(QclassBuilder));
    registry.register(Arc::new(RcodeBuilder));
    registry.register(Arc::new(HasCnameBuilder));
    info!("Condition builder registry initialized");
    registry
}

/// Initializes the global condition builder registry if it is not initialized yet.
///
/// Safe to call repeatedly: `OnceLock` runs the initializer only once.
pub fn init_condition_builders() {
    let _ = CONDITION_BUILDER_REGISTRY.get_or_init(build_default_registry);
}

/// Returns the global condition builder registry, initializing it on first access.
pub fn get_condition_builder_registry() -> &'static ConditionBuilderRegistry {
    // `get_or_init` hands back the reference directly, so no `expect` panic path is needed.
    CONDITION_BUILDER_REGISTRY.get_or_init(build_default_registry)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder registered as `test_condition`; `build` succeeds only for that exact string.
    struct TestBuilder;

    impl ConditionBuilder for TestBuilder {
        fn name(&self) -> &str {
            "test_condition"
        }

        fn build(&self, condition_str: &str, _builder: &PluginBuilder) -> Result<Condition> {
            match condition_str {
                "test_condition" => Ok(Arc::new(|_: &Context| true)),
                _ => Err(crate::Error::Config("Invalid test condition".to_string())),
            }
        }
    }

    /// Creates a fresh registry containing only a [`TestBuilder`].
    fn registry_with_test_builder() -> ConditionBuilderRegistry {
        let mut registry = ConditionBuilderRegistry::new();
        let test_builder: Arc<dyn ConditionBuilder> = Arc::new(TestBuilder);
        registry.register(test_builder);
        registry
    }

    #[test]
    fn test_registry_registration() {
        let registry = registry_with_test_builder();
        assert!(registry.get_builder("test_condition").is_some());
    }

    #[test]
    fn test_registry_lookup() {
        let registry = registry_with_test_builder();
        assert!(registry.get_builder("test_condition").is_some());
        assert!(registry.get_builder("unknown").is_none());
    }

    #[test]
    fn test_builder_names() {
        let registry = registry_with_test_builder();
        assert!(registry.builder_names().contains(&"test_condition"));
    }

    #[test]
    fn test_default_builders_registered() {
        init_condition_builders();
        let registered = get_condition_builder_registry().builder_names();
        const EXPECTED: [&str; 9] = [
            "has_resp",
            "resp_ip",
            "!resp_ip",
            "qname",
            "!qname",
            "qtype",
            "qclass",
            "rcode",
            "has_cname",
        ];
        for name in EXPECTED.iter().copied() {
            assert!(
                registered.contains(&name),
                "Expected condition builder '{}' to be registered",
                name
            );
        }
    }
}