use crate::Error;
use crate::Result;
use crate::config::types::PluginConfig;
use crate::dns_type_match;
use crate::plugin::traits::Matcher;
use crate::plugin::{Context, Plugin};
use crate::plugins::executable::SequenceStep;
use crate::plugins::*;
use serde_yaml::Value;
use std::collections::HashMap;
use std::sync::Arc;
use tracing::trace;
use tracing::{debug, error, info, warn};
pub struct PluginBuilder {
plugins: HashMap<String, Arc<dyn Plugin>>,
server_plugin_tags: Vec<String>,
}
impl PluginBuilder {
pub fn new() -> Self {
info!("Initializing plugin builder system...");
crate::plugin::factory::init();
Self {
plugins: HashMap::new(),
server_plugin_tags: Vec::new(),
}
}
pub fn build(&mut self, config: &PluginConfig) -> Result<Arc<dyn Plugin>> {
let plugin_type = config.plugin_type.trim().to_lowercase();
if let Some(builder) = crate::plugin::factory::get_plugin_factory(&plugin_type) {
trace!(
name = %config.effective_name(),
plugin_type = %plugin_type,
"Creating plugin using registered builder"
);
let plugin = builder.create(config)?;
let effective_name = config.effective_name().to_string();
self.plugins.insert(effective_name, Arc::clone(&plugin));
return Ok(plugin);
}
trace!(
name = %config.effective_name(),
plugin_type = %plugin_type,
"Plugin type not found in builder registry, using legacy match",
);
let plugin: Arc<dyn Plugin> = match plugin_type.as_str() {
"sequence" => {
if let Value::Mapping(map) = &config.args {
if let Some(plugins_value) = map.get(Value::String("plugins".to_string())) {
if let Value::Sequence(plugin_names) = plugins_value {
let plugins = Vec::new();
for name_value in plugin_names {
if let Value::String(_name) = name_value {
warn!(
"Sequence plugin with 'plugins' key not fully implemented yet"
);
}
}
Arc::new(SequencePlugin::new(plugins))
} else {
return Err(Error::Config(
"sequence 'plugins' must be an array".to_string(),
));
}
} else {
warn!(
"Sequence plugin mapping format not implemented yet, using empty sequence"
);
Arc::new(SequencePlugin::new(Vec::new()))
}
} else if let Value::Sequence(sequence) = &config.args {
match parse_sequence_steps(self, sequence) {
Ok(steps) => Arc::new(SequencePlugin::with_steps(steps)),
Err(e) => {
warn!(
"Failed to parse complex sequence: {}, using empty sequence",
e
);
Arc::new(SequencePlugin::new(Vec::new()))
}
}
} else {
warn!("Sequence plugin args format not recognized, using empty sequence");
Arc::new(SequencePlugin::new(Vec::new()))
}
}
"tcp_server" | "udp_server" | "doh_server" | "dot_server" | "doq_server" => {
let tag = config.effective_name().to_string();
self.server_plugin_tags.push(tag.clone());
Arc::new(crate::plugins::AcceptPlugin::new())
}
_ => {
return Err(Error::Config(format!(
"Unknown plugin type: {}",
plugin_type
)));
}
};
let effective_name = config.effective_name().to_string();
self.plugins.insert(effective_name, Arc::clone(&plugin));
Ok(plugin)
}
pub fn resolve_references(&mut self, configs: &[PluginConfig]) -> Result<()> {
for config in configs {
if config.plugin_type == "sequence"
&& let Value::Sequence(sequence) = &config.args
{
match parse_sequence_steps(self, sequence) {
Ok(steps) => {
let sequence_plugin = Arc::new(SequencePlugin::with_steps_and_tag(
steps,
config.tag.clone(),
));
let name = config.effective_name().to_string();
let dname = sequence_plugin.display_name().to_string();
self.plugins.insert(name.clone(), sequence_plugin);
trace!(
"Updated sequence plugin '{}' with resolved references (display={})",
name, dname
);
}
Err(e) => {
warn!(
"Failed to update sequence '{}': {}",
config.effective_name(),
e
);
}
}
}
}
for config in configs {
if config.plugin_type == "fallback" {
let name = config.effective_name().to_string();
debug!("Resolving fallback plugin: {}", name);
if let Some(plugin) = self.plugins.get(&name).cloned() {
if let Some(fp) = plugin.as_ref().as_any().downcast_ref::<FallbackPlugin>() {
fp.resolve_children(&self.plugins);
} else {
warn!(plugin = %name, "Plugin registered under name is not a FallbackPlugin");
}
} else {
warn!(plugin = %name, "Fallback plugin not found in registry");
}
}
}
Ok(())
}
pub fn get_plugin(&self, name: &str) -> Option<Arc<dyn Plugin>> {
self.plugins.get(name).cloned()
}
pub fn get_all_plugins(&self) -> Vec<Arc<dyn Plugin>> {
self.plugins.values().cloned().collect()
}
pub fn into_registry(self) -> crate::plugin::Registry {
let mut registry = crate::plugin::Registry::new();
for (name, plugin) in self.plugins {
registry.register_replace_with_name(&name, plugin);
}
registry
}
pub fn get_registry(&self) -> crate::plugin::Registry {
let mut registry = crate::plugin::Registry::new();
for (name, plugin) in &self.plugins {
registry.register_replace_with_name(name, Arc::clone(plugin));
}
registry
}
pub fn get_server_plugin_tags(&self) -> &[String] {
&self.server_plugin_tags
}
pub async fn shutdown_all(&self) -> Result<()> {
for (name, plugin) in &self.plugins {
if let Some(sh) = plugin.as_shutdown() {
info!("Shutting down plugin: {}", name);
if let Err(e) = sh.shutdown().await {
error!("Error shutting down plugin {}: {}", name, e);
return Err(e);
}
}
}
Ok(())
}
pub fn start_background_tasks(&self) -> Vec<tokio::task::JoinHandle<()>> {
let mut background_tasks = Vec::new();
for (plugin_name, plugin) in &self.plugins {
if plugin
.as_any()
.downcast_ref::<CachePlugin>()
.is_some_and(|cache_plugin| cache_plugin.is_cleanup_enabled())
{
let cache_plugin = plugin.as_any().downcast_ref::<CachePlugin>().unwrap();
info!(
"Starting background cleanup task for cache plugin '{}'",
plugin_name
);
let cache_arc = Arc::new((*cache_plugin).clone());
let cleanup_handle = cache_arc.spawn_cleanup_task();
background_tasks.push(cleanup_handle);
}
}
background_tasks
}
}
impl Default for PluginBuilder {
fn default() -> Self {
Self::new()
}
}
fn parse_sequence_steps(builder: &PluginBuilder, sequence: &[Value]) -> Result<Vec<SequenceStep>> {
use crate::plugins::executable::SequenceStep;
trace!("Parsing {} sequence steps", sequence.len());
let mut steps = Vec::new();
for step_value in sequence {
match step_value {
Value::Mapping(map) => {
if let Some(matches_value) = map.get(Value::String("matches".to_string())) {
if let Value::String(condition_str) = matches_value {
let condition = parse_condition(builder, condition_str)?;
if let Some(exec_value) = map.get(Value::String("exec".to_string())) {
let action = parse_exec_action(builder, exec_value)?;
steps.push(SequenceStep::If {
condition,
action,
desc: condition_str.to_string(),
});
} else {
return Err(Error::Config("matches step must have exec".to_string()));
}
} else {
return Err(Error::Config("matches value must be string".to_string()));
}
} else if let Some(exec_value) = map.get(Value::String("exec".to_string())) {
let plugin = parse_exec_action(builder, exec_value)?;
steps.push(SequenceStep::Exec(plugin));
} else {
return Err(Error::Config(
"sequence step must have exec or matches".to_string(),
));
}
}
_ => return Err(Error::Config("sequence step must be a mapping".to_string())),
}
}
Ok(steps)
}
fn parse_exec_action(builder: &PluginBuilder, exec_value: &Value) -> Result<Arc<dyn Plugin>> {
match exec_value {
Value::String(exec_str) => {
let (prefix, exec_args) = if let Some(space_pos) = exec_str.find(' ') {
let (p, rest) = exec_str.split_at(space_pos);
(p.trim(), rest.trim())
} else {
(exec_str.as_str(), "")
};
if let Some(factory) = crate::plugin::factory::get_exec_plugin_factory(prefix) {
return factory.create(prefix, exec_args);
}
if let Some(plugin_name) = exec_str.strip_prefix('$') {
if let Some(plugin) = builder.get_plugin(plugin_name) {
Ok(plugin)
} else {
Err(Error::Config(format!(
"Referenced plugin '{}' not found",
plugin_name
)))
}
} else {
Err(Error::Config(format!("Unknown exec action: {}", exec_str)))
}
}
_ => Err(Error::Config("exec value must be string".to_string())),
}
}
#[allow(clippy::type_complexity)]
fn parse_condition(
builder: &PluginBuilder,
condition_str: &str,
) -> Result<Arc<dyn Fn(&Context) -> bool + Send + Sync>> {
use crate::plugin::condition::builder::get_condition_builder_registry;
let registry = get_condition_builder_registry();
if let Some(condition_builder) = registry.get_builder(condition_str) {
condition_builder.build(condition_str, builder)
} else {
Err(Error::Config(format!(
"Unknown condition: {}",
condition_str
)))
}
}
/// Legacy hard-coded condition parser, superseded by the condition builder
/// registry that `parse_condition` uses. It is kept for reference, and nothing
/// in the visible code calls it.
///
/// Forms are checked in this order:
/// - `has_resp`: true when the context has a response.
/// - `resp_ip <set>` / `!resp_ip <set>`: matches (or negates a match) using
///   the plugin named `<set>`, with an optional leading `$`, provided its
///   `name()` is `"ip_set"`.
/// - `qname <set>`: matches using the plugin named `<set>` (optional `$`),
///   provided its `name()` is `"domain_set"`.
/// - `!qname <domain>`: NOTE, unlike `qname` this does not use a domain set.
///   It compares the first question's name, lowercased, to a literal domain.
/// - `qtype <n>...`: numeric query types, as `u16`.
/// - `qclass <c>...`: the names listed in the `dns_type_match!` call (`IN`,
///   `CH`, `HS`).
/// - `rcode <c>...`: response code names listed in the `dns_type_match!` call.
/// - `has_cname`: true when the response has an answer of type CNAME.
///
/// A missing or wrong-type referenced plugin is not an error. A warning is
/// logged and a constant predicate is returned: `false` for the positive
/// forms, `true` for the negated `!resp_ip`.
///
/// # Errors
///
/// Returns `Error::Config` for an unknown condition, for an empty
/// qtype/qclass/rcode list, or for a list entry that fails to parse.
// `dead_code`: retained but uncalled. `type_complexity`: the boxed-closure
// return type is spelled out in full.
#[allow(clippy::type_complexity, dead_code)]
#[deprecated(
    since = "0.2.43",
    note = "Legacy condition parsing is deprecated. Please use the new condition builder framework."
)]
fn legacy_parse_condition(
    builder: &PluginBuilder,
    condition_str: &str,
) -> Result<Arc<dyn Fn(&Context) -> bool + Send + Sync>> {
    if condition_str == "has_resp" {
        Ok(Arc::new(|ctx: &crate::plugin::Context| ctx.has_response()))
    } else if let Some(ip_set_ref) = condition_str.strip_prefix("resp_ip ") {
        // Accept both `$name` and bare `name` references.
        let ip_set_name = if let Some(name) = ip_set_ref.strip_prefix('$') {
            name
        } else {
            ip_set_ref
        };
        if let Some(plugin) = builder.get_plugin(ip_set_name) {
            if plugin.name() == "ip_set" {
                let plugin_clone = Arc::clone(&plugin);
                // The downcast runs on every evaluation. If it fails, the
                // condition is false.
                Ok(Arc::new(move |ctx: &crate::plugin::Context| {
                    if let Some(matcher) = plugin_clone
                        .as_ref()
                        .as_any()
                        .downcast_ref::<crate::plugins::dataset::IpSetPlugin>()
                    {
                        matcher.matches_context(ctx)
                    } else {
                        false
                    }
                }))
            } else {
                warn!("Plugin '{}' is not an IP set plugin", ip_set_name);
                Ok(Arc::new(|_ctx: &crate::plugin::Context| false))
            }
        } else {
            warn!("IP set plugin '{}' not found", ip_set_name);
            Ok(Arc::new(|_ctx: &crate::plugin::Context| false))
        }
    } else if let Some(ip_set_ref) = condition_str.strip_prefix("!resp_ip ") {
        // Negated form of `resp_ip`. Every fallback below returns `true`.
        let ip_set_name = if let Some(name) = ip_set_ref.strip_prefix('$') {
            name
        } else {
            ip_set_ref
        };
        if let Some(plugin) = builder.get_plugin(ip_set_name) {
            if plugin.name() == "ip_set" {
                let plugin_clone = Arc::clone(&plugin);
                Ok(Arc::new(move |ctx: &crate::plugin::Context| {
                    if let Some(matcher) = plugin_clone
                        .as_ref()
                        .as_any()
                        .downcast_ref::<crate::plugins::dataset::IpSetPlugin>()
                    {
                        !matcher.matches_context(ctx)
                    } else {
                        true
                    }
                }))
            } else {
                warn!("Plugin '{}' is not an IP set plugin", ip_set_name);
                Ok(Arc::new(|_ctx: &crate::plugin::Context| true))
            }
        } else {
            warn!("IP set plugin '{}' not found", ip_set_name);
            Ok(Arc::new(|_ctx: &crate::plugin::Context| true))
        }
    } else if let Some(domain_set_ref) = condition_str.strip_prefix("qname ") {
        let domain_set_name = if let Some(name) = domain_set_ref.strip_prefix('$') {
            name
        } else {
            domain_set_ref
        };
        if let Some(plugin) = builder.get_plugin(domain_set_name) {
            if plugin.name() == "domain_set" {
                let plugin_clone = Arc::clone(&plugin);
                Ok(Arc::new(move |ctx: &crate::plugin::Context| {
                    if let Some(matcher) = plugin_clone
                        .as_ref()
                        .as_any()
                        .downcast_ref::<crate::plugins::dataset::DomainSetPlugin>()
                    {
                        matcher.matches_context(ctx)
                    } else {
                        false
                    }
                }))
            } else {
                warn!("Plugin '{}' is not a domain set plugin", domain_set_name);
                Ok(Arc::new(|_ctx: &crate::plugin::Context| false))
            }
        } else {
            warn!("Domain set plugin '{}' not found", domain_set_name);
            Ok(Arc::new(|_ctx: &crate::plugin::Context| false))
        }
    } else if let Some(domain) = condition_str.strip_prefix("!qname ") {
        // Literal, case-insensitive comparison against the first question only.
        // A request with no question counts as "not equal" (true).
        let domain_lower = domain.to_lowercase();
        Ok(Arc::new(move |ctx: &crate::plugin::Context| {
            if let Some(question) = ctx.request().questions().first() {
                let qname = question.qname().to_string().to_lowercase();
                !qname.eq(&domain_lower)
            } else {
                true
            }
        }))
    } else if condition_str.starts_with("qtype ") {
        let type_str = condition_str.strip_prefix("qtype ").unwrap_or_default();
        let mut qtypes = Vec::new();
        // Every whitespace-separated entry must be a numeric `u16` type code.
        for type_part in type_str.split_whitespace() {
            match type_part.parse::<u16>() {
                Ok(qtype_num) => {
                    qtypes.push(qtype_num);
                }
                Err(_) => {
                    return Err(Error::Config(format!(
                        "Invalid query type number '{}': {}",
                        type_part, condition_str
                    )));
                }
            }
        }
        if qtypes.is_empty() {
            return Err(Error::Config(format!(
                "No query types specified: {}",
                condition_str
            )));
        }
        Ok(Arc::new(move |ctx: &crate::plugin::Context| {
            if let Some(question) = ctx.request().questions().first() {
                let qtype = question.qtype().to_u16();
                qtypes.contains(&qtype)
            } else {
                false
            }
        }))
    } else if condition_str.starts_with("qclass ") {
        let class_str = condition_str.strip_prefix("qclass ").unwrap_or_default();
        let mut qclasses = Vec::new();
        for class_part in class_str.split_whitespace() {
            // Names map to class codes as listed. Any other inputs that
            // `dns_type_match!` may accept are defined by that macro, which is
            // not visible here.
            let class_val =
                dns_type_match!(class_part, u16, "IN" => 1u16, "CH" => 3u16, "HS" => 4u16)
                    .map_err(|_| {
                        Error::Config(format!(
                            "Invalid query class '{}': {}",
                            class_part, condition_str
                        ))
                    })?;
            qclasses.push(class_val);
        }
        if qclasses.is_empty() {
            return Err(Error::Config(format!(
                "No query classes specified: {}",
                condition_str
            )));
        }
        Ok(Arc::new(move |ctx: &crate::plugin::Context| {
            if let Some(question) = ctx.request().questions().first() {
                let qclass = question.qclass().to_u16();
                qclasses.contains(&qclass)
            } else {
                false
            }
        }))
    } else if condition_str.starts_with("rcode ") {
        let rcode_str = condition_str.strip_prefix("rcode ").unwrap_or_default();
        let mut rcodes = Vec::new();
        for rcode_part in rcode_str.split_whitespace() {
            // Several abbreviations/aliases map to the same response code value.
            let rcode_val = dns_type_match!(rcode_part, u8,
                "NOERROR" => 0u8,
                "FORMERR" | "FORMDERR" => 1u8,
                "SERVFAIL" => 2u8,
                "NXDOMAIN" | "NXDOM" => 3u8,
                "NOTIMP" | "NOTIMPL" => 4u8,
                "REFUSED" | "REFUSE" => 5u8,
                "YXDOMAIN" | "YXDOM" => 6u8,
                "YXRRSET" => 7u8,
                "NXRRSET" => 8u8,
                "NOTAUTH" | "NOTAUTHZ" => 9u8,
                "NOTZONE" => 10u8
            )
            .map_err(|_| {
                Error::Config(format!(
                    "Invalid response code '{}': {}",
                    rcode_part, condition_str
                ))
            })?;
            rcodes.push(rcode_val);
        }
        if rcodes.is_empty() {
            return Err(Error::Config(format!(
                "No response codes specified: {}",
                condition_str
            )));
        }
        // With no response yet, the condition is false.
        Ok(Arc::new(move |ctx: &crate::plugin::Context| {
            if let Some(response) = ctx.response() {
                let rcode = response.response_code().to_u8();
                rcodes.contains(&rcode)
            } else {
                false
            }
        }))
    } else if condition_str == "has_cname" {
        // Only the answer section is inspected.
        Ok(Arc::new(|ctx: &crate::plugin::Context| {
            if let Some(response) = ctx.response() {
                response
                    .answers()
                    .iter()
                    .any(|rr| rr.rtype() == crate::dns::types::RecordType::CNAME)
            } else {
                false
            }
        }))
    } else {
        Err(Error::Config(format!(
            "Unknown condition: {}",
            condition_str
        )))
    }
}
#[allow(clippy::items_after_test_module)]
#[cfg(test)]
mod tests {
use super::*;
use crate::dns::types::{RecordClass, RecordType, ResponseCode};
use crate::dns::{Message, Question, RData, ResourceRecord};
use crate::plugin::Context;
use serde_yaml::Mapping;
#[test]
fn test_plugin_builder_get_plugin() {
let mut builder = PluginBuilder::new();
assert!(builder.get_plugin("nonexistent").is_none());
let plugin: Arc<dyn Plugin> = Arc::new(crate::plugins::AcceptPlugin::new());
builder
.plugins
.insert("test_plugin".to_string(), plugin.clone());
let retrieved = builder.get_plugin("test_plugin");
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().name(), "accept");
}
#[test]
fn test_plugin_builder_get_all_plugins() {
let mut builder = PluginBuilder::new();
assert_eq!(builder.get_all_plugins().len(), 0);
let plugin1: Arc<dyn Plugin> = Arc::new(crate::plugins::AcceptPlugin::new());
let plugin2: Arc<dyn Plugin> =
Arc::new(crate::plugins::flow::return_plugin::ReturnPlugin::new());
builder.plugins.insert("plugin1".to_string(), plugin1);
builder.plugins.insert("plugin2".to_string(), plugin2);
let all_plugins = builder.get_all_plugins();
assert_eq!(all_plugins.len(), 2);
let names: std::collections::HashSet<_> = all_plugins.iter().map(|p| p.name()).collect();
assert!(names.contains("accept"));
assert!(names.contains("return"));
}
#[test]
fn test_plugin_builder_into_registry() {
let mut builder = PluginBuilder::new();
let plugin: Arc<dyn Plugin> = Arc::new(crate::plugins::AcceptPlugin::new());
builder
.plugins
.insert("test_plugin".to_string(), plugin.clone());
let registry = builder.into_registry();
let retrieved = registry.get("test_plugin");
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().name(), "accept");
}
#[test]
fn test_plugin_builder_get_registry() {
let mut builder = PluginBuilder::new();
let plugin: Arc<dyn Plugin> = Arc::new(crate::plugins::AcceptPlugin::new());
builder
.plugins
.insert("test_plugin".to_string(), plugin.clone());
let registry = builder.get_registry();
let retrieved = registry.get("test_plugin");
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().name(), "accept");
assert!(builder.get_plugin("test_plugin").is_some());
}
#[test]
fn test_plugin_builder_get_server_plugin_tags() {
let mut builder = PluginBuilder::new();
assert_eq!(builder.get_server_plugin_tags().len(), 0);
builder.server_plugin_tags.push("doh_server".to_string());
builder.server_plugin_tags.push("dot_server".to_string());
let tags = builder.get_server_plugin_tags();
assert_eq!(tags.len(), 2);
assert!(tags.contains(&"doh_server".to_string()));
assert!(tags.contains(&"dot_server".to_string()));
}
#[tokio::test]
async fn test_plugin_builder_shutdown_all() {
let mut builder = PluginBuilder::new();
let plugin: Arc<dyn Plugin> = Arc::new(crate::plugins::AcceptPlugin::new());
builder
.plugins
.insert("test_plugin".to_string(), plugin.clone());
let result = builder.shutdown_all().await;
assert!(result.is_ok());
}
#[test]
fn test_build_cache_plugin() {
let mut builder = PluginBuilder::new();
let mut config_map = HashMap::new();
config_map.insert("size".to_string(), Value::Number(2048.into()));
let config = PluginConfig {
tag: Some("my_cache".to_string()),
plugin_type: "cache".to_string(),
args: Value::Mapping(Mapping::new()),
priority: 100,
config: config_map,
};
let plugin = builder.build(&config).unwrap();
assert_eq!(plugin.name(), "cache");
}
#[test]
fn test_build_forward_plugin() {
let mut builder = PluginBuilder::new();
let mut config_map = HashMap::new();
let upstreams = vec![
Value::String("udp://8.8.8.8:53".to_string()),
Value::String("tcp://1.1.1.1:53".to_string()),
];
config_map.insert("upstreams".to_string(), Value::Sequence(upstreams));
let config = PluginConfig {
tag: None,
plugin_type: "forward".to_string(),
args: Value::Mapping(Mapping::new()),
priority: 100,
config: config_map,
};
let plugin = builder.build(&config).unwrap();
assert_eq!(plugin.name(), "forward");
}
#[test]
fn test_build_forward_plugin_with_default_port() {
let mut builder = PluginBuilder::new();
let mut config_map = HashMap::new();
let upstreams = vec![Value::String("udp://119.29.29.29".to_string())];
config_map.insert("upstreams".to_string(), Value::Sequence(upstreams));
let config = PluginConfig {
tag: None,
plugin_type: "forward".to_string(),
args: Value::Mapping(Mapping::new()),
priority: 100,
config: config_map,
};
let plugin = builder.build(&config).unwrap();
assert_eq!(plugin.name(), "forward");
if let Some(fp) = plugin.as_ref().as_any().downcast_ref::<ForwardPlugin>() {
let addrs = fp.upstream_addrs();
assert_eq!(addrs.len(), 1);
assert_eq!(addrs[0], "119.29.29.29:53");
} else {
panic!("Failed to downcast plugin to ForwardPlugin");
}
}
#[test]
fn test_parse_condition_qtype_single_and_multiple() {
let builder = PluginBuilder::new();
let cond = parse_condition(&builder, "qtype 1").unwrap();
let mut req = Message::new();
req.add_question(Question::new(
"example.com".to_string(),
RecordType::A,
RecordClass::IN,
));
let ctx = Context::new(req);
assert!(cond(&ctx));
let cond2 = parse_condition(&builder, "qtype 1 28").unwrap();
let mut req2 = Message::new();
req2.add_question(Question::new(
"example.com".to_string(),
RecordType::AAAA,
RecordClass::IN,
));
let ctx2 = Context::new(req2);
assert!(cond2(&ctx2));
}
#[test]
fn test_parse_sequence_steps_simple_exec() {
let builder = PluginBuilder::new();
let sequence = vec![Value::Mapping({
let mut map = serde_yaml::Mapping::new();
map.insert(
Value::String("exec".to_string()),
Value::String("accept".to_string()),
);
map
})];
let steps = parse_sequence_steps(&builder, &sequence).unwrap();
assert_eq!(steps.len(), 1);
match &steps[0] {
crate::plugins::executable::SequenceStep::Exec(plugin) => {
assert_eq!(plugin.name(), "accept");
}
_ => panic!("Expected Exec step"),
}
}
#[test]
fn test_parse_sequence_steps_conditional() {
let builder = PluginBuilder::new();
let sequence = vec![Value::Mapping({
let mut map = serde_yaml::Mapping::new();
map.insert(
Value::String("matches".to_string()),
Value::String("has_resp".to_string()),
);
map.insert(
Value::String("exec".to_string()),
Value::String("accept".to_string()),
);
map
})];
let steps = parse_sequence_steps(&builder, &sequence).unwrap();
assert_eq!(steps.len(), 1);
match &steps[0] {
crate::plugins::executable::SequenceStep::If {
condition: _,
action,
desc,
} => {
assert_eq!(action.name(), "accept");
assert_eq!(desc, "has_resp");
}
_ => panic!("Expected If step"),
}
}
#[test]
fn test_parse_sequence_steps_invalid() {
let builder = PluginBuilder::new();
let sequence = vec![Value::Mapping(serde_yaml::Mapping::new())];
let result = parse_sequence_steps(&builder, &sequence);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("sequence step must have exec or matches")
);
}
#[test]
fn test_parse_exec_action_plugin_reference() {
let mut builder = PluginBuilder::new();
let plugin: Arc<dyn Plugin> = Arc::new(crate::plugins::AcceptPlugin::new());
builder
.plugins
.insert("test_plugin".to_string(), plugin.clone());
let exec_value = Value::String("$test_plugin".to_string());
let result = parse_exec_action(&builder, &exec_value).unwrap();
assert_eq!(result.name(), "accept");
}
#[test]
fn test_parse_exec_action_exec_plugin() {
let builder = PluginBuilder::new();
let exec_value = Value::String("accept".to_string());
let result = parse_exec_action(&builder, &exec_value).unwrap();
assert_eq!(result.name(), "accept");
}
#[test]
fn test_parse_exec_action_unknown() {
let builder = PluginBuilder::new();
let exec_value = Value::String("unknown_action".to_string());
let result = parse_exec_action(&builder, &exec_value);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("Unknown exec action")
);
}
#[test]
fn test_parse_exec_action_invalid_value_type() {
let builder = PluginBuilder::new();
let exec_value = Value::Number(42.into());
let result = parse_exec_action(&builder, &exec_value);
assert!(result.is_err());
assert!(
result
.unwrap_err()
.to_string()
.contains("exec value must be string")
);
}
#[test]
fn test_derived_plugin_type_names() {
crate::plugin::factory::init();
let types = crate::plugin::factory::get_all_plugin_types();
assert!(types.contains(&"query_acl".to_string()));
assert!(types.contains(&"cache".to_string()));
#[cfg(feature = "cron")]
assert!(types.contains(&"cron".to_string()));
assert!(types.contains(&"domain_validator".to_string()));
assert!(types.contains(&"forward".to_string()));
assert!(types.contains(&"geo_ip".to_string()));
assert!(types.contains(&"geo_site".to_string()));
assert!(types.contains(&"arbitrary".to_string()));
assert!(types.contains(&"hosts".to_string()));
assert!(types.contains(&"domain_set".to_string()));
assert!(types.contains(&"ip_set".to_string()));
assert!(types.contains(&"dual_selector".to_string()));
assert!(types.contains(&"edns0_opt".to_string()));
assert!(types.contains(&"rate_limit".to_string()));
assert!(types.contains(&"redirect".to_string()));
assert!(types.contains(&"reverse_lookup".to_string()));
assert!(types.contains(&"ros_addrlist".to_string()));
assert!(types.contains(&"blackhole".to_string()));
let types = crate::plugin::factory::get_all_exec_plugin_types();
assert!(types.contains(&"blackhole".to_string()));
assert!(types.contains(&"debug_print".to_string()));
assert!(types.contains(&"drop_resp".to_string()));
assert!(types.contains(&"ecs".to_string()));
assert!(types.contains(&"fallback".to_string()));
assert!(types.contains(&"ipset".to_string()));
assert!(types.contains(&"mark".to_string()));
assert!(types.contains(&"nftset".to_string()));
assert!(types.contains(&"query_summary".to_string()));
assert!(types.contains(&"sleep".to_string()));
assert!(types.contains(&"ttl".to_string()));
assert!(types.contains(&"prefer_ipv4".to_string()));
assert!(types.contains(&"prefer_ipv6".to_string()));
#[cfg(feature = "metrics")]
{
assert!(types.contains(&"prom_metrics_collector".to_string()));
assert!(types.contains(&"metrics_collector".to_string()));
}
}
#[test]
fn test_derived_plugin_name_mapping() {
fn derive<T: 'static>() -> String {
let t = std::any::type_name::<T>();
let last = t.rsplit("::").next().unwrap_or(t);
let base = last.strip_suffix("Plugin").unwrap_or(last);
let mut s = String::new();
for (i, ch) in base.chars().enumerate() {
if ch.is_uppercase() {
if i != 0 {
s.push('_');
}
for lc in ch.to_lowercase() {
s.push(lc);
}
} else {
s.push(ch);
}
}
s
}
assert_eq!(
derive::<crate::plugins::executable::DropRespPlugin>(),
"drop_resp"
);
assert_eq!(derive::<crate::plugins::ForwardPlugin>(), "forward");
assert_eq!(derive::<crate::plugins::AcceptPlugin>(), "accept");
assert_eq!(
derive::<crate::plugins::flow::return_plugin::ReturnPlugin>(),
"return"
);
assert_eq!(derive::<crate::plugins::flow::jump::JumpPlugin>(), "jump");
assert_eq!(
derive::<crate::plugins::flow::reject::RejectPlugin>(),
"reject"
);
assert_eq!(
derive::<crate::plugins::flow::prefer_ipv4::PreferIpv4Plugin>(),
"prefer_ipv4"
);
assert_eq!(
derive::<crate::plugins::flow::prefer_ipv6::PreferIpv6Plugin>(),
"prefer_ipv6"
);
assert_eq!(derive::<crate::plugins::CachePlugin>(), "cache");
assert_eq!(
derive::<crate::plugins::dataset::DomainSetPlugin>(),
"domain_set"
);
}
#[test]
fn test_no_derived_name_collisions() {
fn derive<T: 'static>() -> String {
let t = std::any::type_name::<T>();
let last = t.rsplit("::").next().unwrap_or(t);
let base = last.strip_suffix("Plugin").unwrap_or(last);
let mut s = String::new();
for (i, ch) in base.chars().enumerate() {
if ch.is_uppercase() {
if i != 0 {
s.push('_');
}
for lc in ch.to_lowercase() {
s.push(lc);
}
} else {
s.push(ch);
}
}
s
}
let derived = vec![
derive::<crate::plugins::executable::DropRespPlugin>(),
derive::<crate::plugins::ForwardPlugin>(),
derive::<crate::plugins::AcceptPlugin>(),
derive::<crate::plugins::flow::return_plugin::ReturnPlugin>(),
derive::<crate::plugins::flow::jump::JumpPlugin>(),
derive::<crate::plugins::flow::reject::RejectPlugin>(),
derive::<crate::plugins::flow::prefer_ipv4::PreferIpv4Plugin>(),
derive::<crate::plugins::flow::prefer_ipv6::PreferIpv6Plugin>(),
derive::<crate::plugins::CachePlugin>(),
derive::<crate::plugins::dataset::DomainSetPlugin>(),
derive::<crate::plugins::geoip::GeoIpPlugin>(),
derive::<crate::plugins::geosite::GeoSitePlugin>(),
derive::<crate::plugins::HostsPlugin>(),
];
let set: std::collections::HashSet<_> = derived.iter().cloned().collect();
assert_eq!(
set.len(),
derived.len(),
"Derived plugin names collided: {:?}",
{
let mut counts = std::collections::HashMap::new();
for name in &derived {
*counts.entry(name.clone()).or_insert(0usize) += 1;
}
counts
.into_iter()
.filter_map(|(k, v)| if v > 1 { Some(k) } else { None })
.collect::<Vec<_>>()
}
);
}
#[tokio::test]
async fn test_build_redirect_from_string_rule_executes() {
let mut builder = PluginBuilder::new();
let mut args_map = Mapping::new();
args_map.insert(
Value::String("rules".to_string()),
Value::Sequence(vec![Value::String("example.com example.net".to_string())]),
);
let config = PluginConfig {
tag: Some("redirect_str".to_string()),
plugin_type: "redirect".to_string(),
args: Value::Mapping(args_map),
priority: 100,
config: HashMap::new(),
};
let plugin = builder.build(&config).expect("build redirect plugin");
assert_eq!(plugin.name(), "redirect");
let mut request = Message::new();
request.add_question(Question::new(
"example.com".to_string(),
RecordType::A,
RecordClass::IN,
));
let mut ctx = Context::new(request);
plugin.execute(&mut ctx).await.expect("execute");
let got = ctx
.request()
.questions()
.first()
.unwrap()
.qname()
.to_string();
assert_eq!(got, "example.net");
}
#[tokio::test]
async fn test_build_redirect_from_mapping_rule_executes() {
let mut builder = PluginBuilder::new();
let mut args_map = Mapping::new();
let mut rule_map = Mapping::new();
rule_map.insert(
Value::String("from".to_string()),
Value::String("foo.example".to_string()),
);
rule_map.insert(
Value::String("to".to_string()),
Value::String("bar.example".to_string()),
);
args_map.insert(
Value::String("rules".to_string()),
Value::Sequence(vec![Value::Mapping(rule_map)]),
);
let config = PluginConfig {
tag: Some("redirect_map".to_string()),
plugin_type: "redirect".to_string(),
args: Value::Mapping(args_map),
priority: 100,
config: HashMap::new(),
};
let plugin = builder.build(&config).expect("build redirect plugin");
assert_eq!(plugin.name(), "redirect");
let mut request = Message::new();
request.add_question(Question::new(
"foo.example".to_string(),
RecordType::A,
RecordClass::IN,
));
let mut ctx = Context::new(request);
plugin.execute(&mut ctx).await.expect("execute");
let got = ctx
.request()
.questions()
.first()
.unwrap()
.qname()
.to_string();
assert_eq!(got, "bar.example");
}
#[test]
fn test_build_plugin_type_case_insensitive() {
let mut builder = PluginBuilder::new();
let mut args_map = Mapping::new();
args_map.insert(
Value::String("rules".to_string()),
Value::Sequence(vec![Value::String("a b".to_string())]),
);
let config = PluginConfig {
tag: Some("redirect_upper".to_string()),
plugin_type: "Redirect".to_string(),
args: Value::Mapping(args_map),
priority: 100,
config: HashMap::new(),
};
let plugin = builder.build(&config).expect("build redirect plugin");
assert_eq!(plugin.name(), "redirect");
}
#[test]
fn test_fallback_resolves_children() {
let mut builder = PluginBuilder::new();
let primary_plugin = Arc::new(crate::plugins::flow::AcceptPlugin::new());
builder
.plugins
.insert("primary".to_string(), primary_plugin);
let secondary_plugin = Arc::new(crate::plugins::flow::AcceptPlugin::new());
builder
.plugins
.insert("secondary".to_string(), secondary_plugin);
let primary_cfg = PluginConfig {
tag: None,
plugin_type: "accept".to_string(),
args: Value::Mapping(Mapping::new()),
priority: 100,
config: HashMap::new(),
};
let secondary_cfg = PluginConfig {
tag: None,
plugin_type: "accept".to_string(),
args: Value::Mapping(Mapping::new()),
priority: 100,
config: HashMap::new(),
};
let mut args_map = Mapping::new();
args_map.insert(
Value::String("primary".to_string()),
Value::String("primary".to_string()),
);
args_map.insert(
Value::String("secondary".to_string()),
Value::String("secondary".to_string()),
);
let fb_cfg = PluginConfig {
tag: None,
plugin_type: "fallback".to_string(),
args: Value::Mapping(args_map),
priority: 100,
config: HashMap::new(),
};
builder.build(&fb_cfg).unwrap();
builder
.resolve_references(&[primary_cfg, secondary_cfg, fb_cfg])
.unwrap();
let plugin = builder
.get_plugin("fallback")
.expect("fallback plugin present");
if let Some(fp) = plugin
.as_ref()
.as_any()
.downcast_ref::<crate::plugins::executable::FallbackPlugin>()
{
assert_eq!(fp.resolved_child_count(), 2);
assert_eq!(fp.pending_child_count(), 0);
} else {
panic!("fallback plugin is wrong type");
}
}
#[test]
fn test_parse_condition_has_resp() {
let builder = PluginBuilder::new();
let condition = parse_condition(&builder, "has_resp").unwrap();
let ctx = Context::new(Message::new());
assert!(!condition(&ctx));
let mut ctx_with_resp = Context::new(Message::new());
ctx_with_resp.set_response(Some(Message::new()));
assert!(condition(&ctx_with_resp));
}
#[test]
fn test_parse_condition_resp_ip() {
let mut builder = PluginBuilder::new();
let ip_set_plugin = Arc::new(crate::plugins::dataset::IpSetPlugin::new("test_ip_set"));
builder
.plugins
.insert("test_ip_set".to_string(), ip_set_plugin);
let condition = parse_condition(&builder, "resp_ip $test_ip_set").unwrap();
let mut ctx = Context::new(Message::new());
ctx.set_response(Some(Message::new()));
assert!(!condition(&ctx));
}
#[test]
fn test_parse_condition_negated_resp_ip() {
let mut builder = PluginBuilder::new();
let ip_set_plugin = Arc::new(crate::plugins::dataset::IpSetPlugin::new("test_ip_set"));
builder
.plugins
.insert("test_ip_set".to_string(), ip_set_plugin);
let condition = parse_condition(&builder, "!resp_ip $test_ip_set").unwrap();
let mut ctx = Context::new(Message::new());
ctx.set_response(Some(Message::new()));
assert!(condition(&ctx));
}
#[test]
fn test_parse_condition_qname() {
    // Register a freshly constructed domain set (no domains are loaded here).
    let mut builder = PluginBuilder::new();
    let domains = Arc::new(crate::plugins::dataset::DomainSetPlugin::new(
        "test_domain_set",
    ));
    builder.plugins.insert("test_domain_set".to_string(), domains);

    let qname = parse_condition(&builder, "qname $test_domain_set").unwrap();

    let mut request = Message::new();
    request.add_question(Question::new(
        "example.com".to_string(),
        RecordType::A,
        RecordClass::IN,
    ));

    // The set referenced by the condition holds nothing, so the query name is not matched.
    let ctx = Context::new(request);
    assert!(!qname(&ctx));
}
#[test]
fn test_parse_condition_negated_qname() {
    let builder = PluginBuilder::new();
    let not_example = parse_condition(&builder, "!qname example.com").unwrap();

    // Build a context whose request carries a single A/IN question for `name`.
    let query_for = |name: &str| {
        let mut request = Message::new();
        request.add_question(Question::new(
            name.to_string(),
            RecordType::A,
            RecordClass::IN,
        ));
        Context::new(request)
    };

    // The negated condition rejects the listed domain ...
    let listed = query_for("example.com");
    assert!(!not_example(&listed));

    // ... and accepts a name that is not listed.
    let unlisted = query_for("other.com");
    assert!(not_example(&unlisted));
}
#[test]
fn test_parse_condition_qtype_invalid() {
    let builder = PluginBuilder::new();
    // A missing operand, a non-numeric operand, and a list that mixes a numeric
    // type with a non-numeric one must all be rejected by the parser.
    for expr in ["qtype", "qtype abc", "qtype 1 abc"] {
        assert!(
            parse_condition(&builder, expr).is_err(),
            "expected parse error for {:?}",
            expr
        );
    }
}
#[test]
fn test_parse_condition_qclass() {
    let builder = PluginBuilder::new();

    // Build a context whose request holds one `example.com` A question of `class`.
    let query_with_class = |class: RecordClass| {
        let mut request = Message::new();
        request.add_question(Question::new(
            "example.com".to_string(),
            RecordType::A,
            class,
        ));
        Context::new(request)
    };

    // A single mnemonic matches only that class. The parsed condition is reused
    // for both the positive and the negative case rather than re-parsing the
    // identical "qclass IN" expression.
    let only_in = parse_condition(&builder, "qclass IN").unwrap();
    assert!(only_in(&query_with_class(RecordClass::IN)));
    assert!(!only_in(&query_with_class(RecordClass::CH)));

    // Several mnemonics: a question in any listed class matches.
    let in_or_ch = parse_condition(&builder, "qclass IN CH").unwrap();
    assert!(in_or_ch(&query_with_class(RecordClass::CH)));

    // Numeric form: DNS class 1 is IN.
    let numeric_in = parse_condition(&builder, "qclass 1").unwrap();
    assert!(numeric_in(&query_with_class(RecordClass::IN)));
}
#[test]
fn test_parse_condition_rcode() {
    let builder = PluginBuilder::new();

    // Each row: (condition expression, rcode set on the response, expected match).
    let cases = [
        ("rcode NOERROR", ResponseCode::NoError, true),
        ("rcode NXDOMAIN", ResponseCode::NXDomain, true),
        // Several codes listed: matching any one of them is enough.
        ("rcode NOERROR NXDOMAIN", ResponseCode::NXDomain, true),
        // A code that is not listed must not match.
        ("rcode NOERROR", ResponseCode::ServFail, false),
        // Numeric form: rcode 3 is NXDOMAIN.
        ("rcode 3", ResponseCode::NXDomain, true),
    ];

    for (expr, code, expected) in cases {
        let condition = parse_condition(&builder, expr).unwrap();

        let mut response = Message::new();
        response.set_response_code(code);
        let mut ctx = Context::new(Message::new());
        ctx.set_response(Some(response));

        assert_eq!(condition(&ctx), expected, "unexpected result for {:?}", expr);
    }
}
#[test]
fn test_parse_condition_has_cname() {
    // Parse the expression once and reuse the resulting condition for every
    // scenario; the original re-parsed the identical "has_cname" string three times.
    let builder = PluginBuilder::new();
    let has_cname = parse_condition(&builder, "has_cname").unwrap();

    // Wrap a response containing exactly one answer record in a fresh context.
    let ctx_with_answer = |record: ResourceRecord| {
        let mut response = Message::new();
        response.add_answer(record);
        let mut ctx = Context::new(Message::new());
        ctx.set_response(Some(response));
        ctx
    };

    // A CNAME record in the answer section satisfies the condition.
    let cname_ctx = ctx_with_answer(ResourceRecord::new(
        "example.com".to_string(),
        RecordType::CNAME,
        RecordClass::IN,
        300,
        RData::CNAME("target.example.com".to_string()),
    ));
    assert!(has_cname(&cname_ctx));

    // An answer section holding only an A record does not.
    let a_ctx = ctx_with_answer(ResourceRecord::new(
        "example.com".to_string(),
        RecordType::A,
        RecordClass::IN,
        300,
        RData::A("192.168.1.1".parse().unwrap()),
    ));
    assert!(!has_cname(&a_ctx));

    // With no response at all there is nothing to inspect.
    let no_response = Context::new(Message::new());
    assert!(!has_cname(&no_response));
}
#[test]
fn test_parse_condition_qclass_invalid() {
    let builder = PluginBuilder::new();
    // An unrecognised class mnemonic and a missing operand are both parse errors.
    for expr in ["qclass INVALID", "qclass"] {
        assert!(
            parse_condition(&builder, expr).is_err(),
            "expected parse error for {:?}",
            expr
        );
    }
}
#[test]
fn test_parse_condition_rcode_invalid() {
    let builder = PluginBuilder::new();
    // An unrecognised rcode mnemonic and a missing operand are both parse errors.
    for expr in ["rcode INVALID", "rcode"] {
        assert!(
            parse_condition(&builder, expr).is_err(),
            "expected parse error for {:?}",
            expr
        );
    }
}
#[test]
fn test_parse_condition_unknown() {
    let builder = PluginBuilder::new();
    // An expression whose keyword the parser does not handle must yield an error.
    let parsed = parse_condition(&builder, "unknown_condition");
    assert!(parsed.is_err());
}
}