use crate::plugin::{Context, Plugin};
use crate::{RegisterPlugin, Result};
use async_trait::async_trait;
use std::fmt;
use tracing::debug;
/// Plugin that rewrites the query name of a request's first question from one
/// domain to another.
///
/// Both values are stored exactly as supplied; the matching and rewriting
/// methods lowercase them before comparing against a query name.
#[derive(RegisterPlugin)]
pub struct RedirectPlugin {
/// Domain to match against the query name. A leading `*.` marks it as a
/// wildcard pattern rather than an exact name.
from_domain: String,
/// Domain the query name is rewritten to. When both this and `from_domain`
/// start with `*.`, the part of the query name in front of the matched
/// suffix is carried over to the new name.
to_domain: String,
}
impl RedirectPlugin {
    /// Creates a plugin that redirects queries for `from_domain` to `to_domain`.
    ///
    /// A `from_domain` starting with `*.` is a wildcard pattern that matches
    /// the bare domain and every name below it.
    pub fn new(from_domain: impl Into<String>, to_domain: impl Into<String>) -> Self {
        Self {
            from_domain: from_domain.into(),
            to_domain: to_domain.into(),
        }
    }

    /// Returns the part of `qname` that precedes `suffix`, but only when
    /// `suffix` ends `qname` on a label boundary.
    ///
    /// The returned prefix is either empty (`qname == suffix`) or ends with
    /// the `.` separating it from `suffix`. A plain `ends_with` check is not
    /// enough: it would treat `notold.com` as being inside `old.com`.
    ///
    /// An empty `suffix` (pattern `*.`) keeps its historical meaning of
    /// "matches every name" and yields the whole `qname` as prefix.
    fn strip_domain_suffix<'a>(qname: &'a str, suffix: &str) -> Option<&'a str> {
        let prefix = qname.strip_suffix(suffix)?;
        if suffix.is_empty() || prefix.is_empty() || prefix.ends_with('.') {
            Some(prefix)
        } else {
            None
        }
    }

    /// Reports whether `qname` is covered by `from_domain`.
    ///
    /// The comparison is case-insensitive. For a wildcard pattern `*.suffix`
    /// the name matches when it equals `suffix` or is a subdomain of it;
    /// otherwise the name must equal `from_domain` exactly.
    fn matches(&self, qname: &str) -> bool {
        let from_lower = self.from_domain.to_lowercase();
        let qname_lower = qname.to_lowercase();
        match from_lower.strip_prefix("*.") {
            Some(suffix) => Self::strip_domain_suffix(&qname_lower, suffix).is_some(),
            None => qname_lower == from_lower,
        }
    }

    /// Computes the name `qname` should be rewritten to.
    ///
    /// When both `from_domain` and `to_domain` are wildcard patterns, the
    /// labels of `qname` in front of the matched suffix are kept and the
    /// suffix is swapped (the result is lowercased). In every other case,
    /// including a `qname` that does not sit below the `from` suffix,
    /// `to_domain` is returned as configured.
    fn redirect(&self, qname: &str) -> String {
        let from_lower = self.from_domain.to_lowercase();
        let qname_lower = qname.to_lowercase();
        let to_lower = self.to_domain.to_lowercase();
        if let (Some(from_suffix), Some(to_suffix)) =
            (from_lower.strip_prefix("*."), to_lower.strip_prefix("*."))
        {
            if let Some(mut prefix) = Self::strip_domain_suffix(&qname_lower, from_suffix) {
                // Avoid producing ".." when the target suffix already starts
                // with a separator.
                if to_suffix.starts_with('.') {
                    prefix = prefix.strip_suffix('.').unwrap_or(prefix);
                }
                return format!("{}{}", prefix, to_suffix);
            }
        }
        self.to_domain.clone()
    }
}
/// Debug output lists both configured domains under the `RedirectPlugin` name.
impl fmt::Debug for RedirectPlugin {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Destructure so that adding a field to the struct forces this impl
        // to be revisited.
        let Self {
            from_domain,
            to_domain,
        } = self;
        let mut builder = f.debug_struct("RedirectPlugin");
        builder.field("from_domain", from_domain);
        builder.field("to_domain", to_domain);
        builder.finish()
    }
}
#[async_trait]
impl Plugin for RedirectPlugin {
fn name(&self) -> &str {
"redirect"
}
fn init(config: &crate::config::types::PluginConfig) -> Result<std::sync::Arc<dyn Plugin>> {
use serde_yaml::Value;
use std::sync::Arc;
let args = config.effective_args();
if let Some(Value::Sequence(seq)) = args.get("rules") {
if seq.is_empty() {
return Err(crate::Error::Config(
"redirect requires at least one rule".to_string(),
));
}
let first = &seq[0];
if let Value::String(s) = first {
let parts: Vec<&str> = s.split_whitespace().collect();
if parts.len() == 2 {
Ok(Arc::new(RedirectPlugin::new(
parts[0].to_string(),
parts[1].to_string(),
)))
} else {
Err(crate::Error::Config(
"redirect rule must be 'from to'".to_string(),
))
}
} else if let Value::Mapping(map) = first {
let from = map
.get(Value::String("from".to_string()))
.and_then(|v| v.as_str())
.ok_or_else(|| {
crate::Error::Config("redirect rule mapping missing 'from'".to_string())
})?;
let to = map
.get(Value::String("to".to_string()))
.and_then(|v| v.as_str())
.ok_or_else(|| {
crate::Error::Config("redirect rule mapping missing 'to'".to_string())
})?;
Ok(Arc::new(RedirectPlugin::new(
from.to_string(),
to.to_string(),
)))
} else {
Err(crate::Error::Config(
"unsupported redirect rule format".to_string(),
))
}
} else {
Err(crate::Error::Config(
"redirect plugin requires 'rules' array".to_string(),
))
}
}
async fn execute(&self, ctx: &mut Context) -> Result<()> {
let request = ctx.request_mut();
if let Some(question) = request.questions_mut().first_mut() {
let qname = question.qname().to_string();
if self.matches(&qname) {
let new_qname = self.redirect(&qname);
debug!("Redirecting query from {} to {}", qname, new_qname);
question.set_qname(new_qname);
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::types::{RecordClass, RecordType};
    use crate::dns::{Message, Question};

    /// Runs a `from` -> `to` redirect plugin over a single `A`/`IN` question
    /// for `query` and asserts that the question name afterwards equals
    /// `expected`.
    async fn assert_redirect(from: &str, to: &str, query: &str, expected: &str) {
        let plugin = RedirectPlugin::new(from, to);

        let mut message = Message::new();
        let question = Question::new(query.to_string(), RecordType::A, RecordClass::IN);
        message.add_question(question);

        let mut ctx = Context::new(message);
        plugin.execute(&mut ctx).await.unwrap();

        let rewritten = ctx.request();
        assert_eq!(rewritten.questions().first().unwrap().qname(), expected);
    }

    /// An exact rule replaces the whole name.
    #[tokio::test]
    async fn test_redirect_exact() {
        assert_redirect("example.com", "example.net", "example.com", "example.net").await;
    }

    /// A wildcard-to-wildcard rule keeps the leading labels.
    #[tokio::test]
    async fn test_redirect_wildcard() {
        assert_redirect("*.old.com", "*.new.com", "www.old.com", "www.new.com").await;
    }

    /// A name that does not match the rule is left as it was.
    #[tokio::test]
    async fn test_redirect_no_match() {
        assert_redirect("example.com", "example.net", "different.com", "different.com").await;
    }

    /// Matching ignores letter case on both the rule and the query.
    #[tokio::test]
    async fn test_redirect_case_insensitive() {
        assert_redirect("Example.COM", "example.net", "EXAMPLE.com", "example.net").await;
    }
}