use crate::RegisterPlugin;
use crate::Result;
use crate::dns::ResponseCode;
use crate::plugin::{Context, Plugin};
use async_trait::async_trait;
use lru::LruCache;
use std::collections::HashSet;
use std::num::NonZeroUsize;
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// Outcome of validating a single domain name.
///
/// The type is a plain field-less enum, so it is `Copy` and also implements
/// `Eq` and `Hash`, which lets it be used as a set/map key (for example when
/// tallying results).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidationResult {
    /// The name passed every check.
    Valid,
    /// A label starts or ends with a non-alphanumeric byte, or contains a
    /// byte other than ASCII letters, digits and `-`.
    InvalidChars,
    /// The name is empty or longer than 253 bytes, or one of its labels is
    /// empty or longer than 63 bytes.
    InvalidLength,
    /// A label contains `--` while strict mode is enabled.
    InvalidFormat,
    /// The name matched an entry of the configured blacklist.
    Blacklisted,
}
/// Plugin that validates the name of the first question of a request and
/// answers with a REFUSED response when the name is rejected.
#[derive(Debug, RegisterPlugin)]
pub struct DomainValidatorPlugin {
/// When `true`, labels containing `--` are rejected with
/// `ValidationResult::InvalidFormat`.
strict_mode: bool,
/// Validation results keyed by the queried name as a string; filled by
/// `execute` after each cache miss.
cache: Arc<RwLock<LruCache<String, ValidationResult>>>,
/// Blacklist entries. An entry matches the listed name and every name below
/// it; a leading `*.` is stripped before matching.
blacklist: HashSet<String>,
}
impl DomainValidatorPlugin {
    /// Creates a validator.
    ///
    /// * `strict_mode` - when `true`, labels containing `--` are rejected.
    /// * `cache_size` - capacity of the result cache; `0` is treated as `1`
    ///   because the cache cannot be created with zero capacity.
    /// * `blacklist` - names to refuse. Each entry matches the name itself
    ///   and all of its subdomains; a leading `*.` is accepted and ignored.
    pub fn new(strict_mode: bool, cache_size: usize, blacklist: Vec<String>) -> Self {
        let capacity = NonZeroUsize::new(cache_size.max(1))
            .expect("capacity was clamped to at least 1");
        let cache = LruCache::new(capacity);

        #[cfg(feature = "metrics")]
        {
            crate::metrics::DNS_DOMAIN_VALIDATION_CACHE_SIZE.set(cache.len() as i64);
        }

        Self {
            strict_mode,
            cache: Arc::new(RwLock::new(cache)),
            blacklist: blacklist.into_iter().collect(),
        }
    }

    /// Returns `true` when `domain` equals, or is a subdomain of, any
    /// blacklist entry.
    fn is_blacklisted(&self, domain: &str) -> bool {
        self.blacklist.iter().any(|pattern| {
            // `*.example.com` and `example.com` are treated the same way.
            let suffix = pattern.strip_prefix("*.").unwrap_or(pattern.as_str());
            self.matches_suffix(domain, suffix)
        })
    }

    /// Returns `true` when `domain` is `suffix` itself or ends with
    /// `.suffix`.
    ///
    /// The comparison ignores ASCII case: DNS names are case-insensitive, so
    /// a query for `MALICIOUS.com` must not slip past a `malicious.com`
    /// entry.
    fn matches_suffix(&self, domain: &str, suffix: &str) -> bool {
        let (domain, suffix) = (domain.as_bytes(), suffix.as_bytes());
        match domain.len().checked_sub(suffix.len()) {
            // The domain is shorter than the suffix: it cannot match.
            None => false,
            // Same length: the whole name must be equal.
            Some(0) => domain.eq_ignore_ascii_case(suffix),
            // Longer: the suffix must start right after a label separator so
            // that `notexample.com` does not match `example.com`.
            Some(boundary) => {
                domain[boundary - 1] == b'.'
                    && domain[boundary..].eq_ignore_ascii_case(suffix)
            }
        }
    }

    /// Validates `domain` and reports the first rule it violates.
    ///
    /// Checks run in this order: blacklist, total length (non-empty, at most
    /// 253 bytes), then every dot-separated label from left to right. The
    /// root name `"."` is accepted as a special case; any other name with an
    /// empty label (leading, trailing or doubled dots) is reported as
    /// `ValidationResult::InvalidLength`.
    pub fn validate_domain(&self, domain: &str) -> ValidationResult {
        if self.is_blacklisted(domain) {
            return ValidationResult::Blacklisted;
        }
        if domain.is_empty() || domain.len() > 253 {
            return ValidationResult::InvalidLength;
        }
        if domain == "." {
            return ValidationResult::Valid;
        }

        for label in domain.split('.') {
            match self.validate_label(label) {
                ValidationResult::Valid => {}
                rejection => return rejection,
            }
        }
        ValidationResult::Valid
    }

    /// Validates one label: 1 to 63 bytes, alphanumeric first and last byte,
    /// and only ASCII letters, digits and `-` in between.
    fn validate_label(&self, label: &str) -> ValidationResult {
        let bytes = label.as_bytes();
        if bytes.len() > 63 {
            return ValidationResult::InvalidLength;
        }

        let (first, interior, last) = match bytes {
            [] => return ValidationResult::InvalidLength,
            [only] => (only, &bytes[..0], only),
            [first, interior @ .., last] => (first, interior, last),
        };

        if !first.is_ascii_alphanumeric() || !last.is_ascii_alphanumeric() {
            return ValidationResult::InvalidChars;
        }
        if !interior
            .iter()
            .all(|b| b.is_ascii_alphanumeric() || *b == b'-')
        {
            return ValidationResult::InvalidChars;
        }

        // NOTE(review): this also rejects `xn--` (punycode) labels in strict
        // mode - confirm that refusing internationalized names is intended.
        if self.strict_mode && label.contains("--") {
            return ValidationResult::InvalidFormat;
        }
        ValidationResult::Valid
    }
}
#[async_trait]
impl Plugin for DomainValidatorPlugin {
    /// Validates the name of the first question of the request.
    ///
    /// A cached verdict is used when present; otherwise the name is
    /// validated and the verdict is stored in the cache. Rejected names get
    /// a REFUSED response set on `ctx`; the call itself always returns
    /// `Ok(())`. A request without questions is validated as the empty name.
    async fn execute(&self, ctx: &mut Context) -> Result<()> {
        #[cfg(feature = "metrics")]
        let start = std::time::Instant::now();

        let qname = ctx
            .request()
            .questions()
            .first()
            .map(|q| q.qname().to_string())
            .unwrap_or_default();

        // `peek` works through a shared reference, so a read lock is enough.
        // The verdict is copied out so the guard is released at the end of
        // this statement rather than held while the response is built.
        let cached = self.cache.read().await.peek(&qname).copied();
        if let Some(result) = cached {
            #[cfg(feature = "metrics")]
            {
                crate::metrics::DNS_DOMAIN_VALIDATION_CACHE_HITS_TOTAL.inc();
                let duration = start.elapsed().as_secs_f64();
                crate::metrics::DNS_DOMAIN_VALIDATION_DURATION_SECONDS.observe(duration);
            }
            return handle_result(result, &qname, ctx);
        }

        let result = self.validate_domain(&qname);

        #[cfg(feature = "metrics")]
        {
            let result_label = match &result {
                ValidationResult::Valid => "valid",
                ValidationResult::InvalidChars => "invalid_chars",
                ValidationResult::InvalidLength => "invalid_length",
                ValidationResult::InvalidFormat => "invalid_format",
                ValidationResult::Blacklisted => "blacklisted",
            };
            crate::metrics::DNS_DOMAIN_VALIDATION_TOTAL
                .with_label_values(&[result_label])
                .inc();
        }

        {
            let mut cache = self.cache.write().await;
            #[cfg(feature = "metrics")]
            {
                // `put` only returns the previous value stored under the
                // same key, so it can never signal an eviction. `push`
                // returns the displaced entry with its key: a different key
                // means the least recently used entry was evicted, the same
                // key means another task cached this name first.
                if let Some((displaced_key, _)) = cache.push(qname.clone(), result) {
                    if displaced_key != qname {
                        crate::metrics::DNS_DOMAIN_VALIDATION_CACHE_EVICTIONS_TOTAL.inc();
                    }
                }
                crate::metrics::DNS_DOMAIN_VALIDATION_CACHE_SIZE.set(cache.len() as i64);
            }
            #[cfg(not(feature = "metrics"))]
            {
                cache.put(qname.clone(), result);
            }
        }

        #[cfg(feature = "metrics")]
        {
            let duration = start.elapsed().as_secs_f64();
            crate::metrics::DNS_DOMAIN_VALIDATION_DURATION_SECONDS.observe(duration);
        }

        handle_result(result, &qname, ctx)
    }

    /// Name under which the plugin is known.
    fn name(&self) -> &str {
        "domain_validator"
    }

    /// Priority value reported for this plugin.
    fn priority(&self) -> i32 {
        2100
    }

    /// Builds the plugin from its configuration arguments.
    ///
    /// Recognised arguments:
    /// * `strict_mode` (bool, default `true`)
    /// * `cache_size` (unsigned integer, default `1000`)
    /// * `blacklist` (sequence of strings, default empty); entries that are
    ///   not strings are skipped.
    fn init(config: &crate::config::PluginConfig) -> Result<Arc<dyn Plugin>> {
        let args = config.effective_args();

        let strict_mode = args
            .get("strict_mode")
            .and_then(|v| v.as_bool())
            .unwrap_or(true);

        let cache_size = args
            .get("cache_size")
            .and_then(|v| v.as_u64())
            .unwrap_or(1000) as usize;

        let blacklist = args
            .get("blacklist")
            .and_then(|v| v.as_sequence())
            .map(|seq| {
                seq.iter()
                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
                    .collect()
            })
            .unwrap_or_default();

        Ok(Arc::new(Self::new(strict_mode, cache_size, blacklist)))
    }
}
/// Applies a validation verdict to the request context.
///
/// A valid name leaves `ctx` untouched. Every rejection is logged
/// (blacklist hits at `warn`, the rest at `debug`) and answered with a
/// REFUSED response. The function always returns `Ok(())`.
fn handle_result(result: ValidationResult, qname: &str, ctx: &mut Context) -> Result<()> {
    match result {
        ValidationResult::Valid => return Ok(()),
        ValidationResult::Blacklisted => {
            warn!("Rejected blacklisted domain: {}", qname)
        }
        ValidationResult::InvalidChars => {
            debug!("Rejected domain with invalid characters: {}", qname)
        }
        ValidationResult::InvalidLength => {
            debug!("Rejected domain with invalid length: {}", qname)
        }
        ValidationResult::InvalidFormat => {
            debug!("Rejected domain with invalid format: {}", qname)
        }
    }

    // Every rejection ends the same way: refuse the query.
    set_refused_response(ctx);
    Ok(())
}
/// Stores a REFUSED reply on `ctx`.
///
/// The reply is a fresh message carrying the request's id, the response
/// flag and the `Refused` response code; nothing else is set here.
fn set_refused_response(ctx: &mut Context) {
    let mut refusal = crate::dns::Message::new();
    let request_id = ctx.request().id();
    refusal.set_id(request_id);
    refusal.set_response(true);
    refusal.set_response_code(ResponseCode::Refused);
    ctx.set_response(Some(refusal));
}
impl Default for DomainValidatorPlugin {
    /// Builds a validator with strict mode enabled, a cache size of 1000 and
    /// an empty blacklist.
    fn default() -> Self {
        let (strict_mode, cache_size) = (true, 1000);
        Self::new(strict_mode, cache_size, Vec::new())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `plugin` classifies every name in `domains` as `expected`.
    fn assert_all(plugin: &DomainValidatorPlugin, domains: &[&str], expected: ValidationResult) {
        for domain in domains {
            assert_eq!(
                plugin.validate_domain(domain),
                expected,
                "unexpected result for {:?}",
                domain
            );
        }
    }

    /// Builds a strict validator with a 1000-entry cache and the given blacklist.
    fn with_blacklist(entries: &[&str]) -> DomainValidatorPlugin {
        let blacklist = entries.iter().map(|entry| entry.to_string()).collect();
        DomainValidatorPlugin::new(true, 1000, blacklist)
    }

    #[tokio::test]
    async fn test_valid_domains() {
        let plugin = DomainValidatorPlugin::default();
        assert_all(
            &plugin,
            &["example.com", "sub.example.co.uk", "localhost", "."],
            ValidationResult::Valid,
        );
    }

    #[tokio::test]
    async fn test_invalid_chars() {
        let plugin = DomainValidatorPlugin::default();
        assert_all(
            &plugin,
            &[
                "test space.com",
                "test@domain.com",
                "-test.com",
                "test-.com",
            ],
            ValidationResult::InvalidChars,
        );
    }

    #[tokio::test]
    async fn test_single_char_labels() {
        let plugin = DomainValidatorPlugin::default();
        assert_all(
            &plugin,
            &["a.com", "a.b.com", "x.y.z"],
            ValidationResult::Valid,
        );
    }

    #[tokio::test]
    async fn test_invalid_length() {
        let plugin = DomainValidatorPlugin::default();
        // A 64-byte label exceeds the 63-byte label limit.
        let long_label = format!("{}.com", "a".repeat(64));
        // 126 two-byte "a." pieces plus "com" exceed the 253-byte name limit.
        let long_domain = format!("{}com", "a.".repeat(126));
        assert_all(
            &plugin,
            &[long_label.as_str(), long_domain.as_str()],
            ValidationResult::InvalidLength,
        );
    }

    #[tokio::test]
    async fn test_strict_mode() {
        let strict_plugin = DomainValidatorPlugin::new(true, 1000, Vec::new());
        assert_all(
            &strict_plugin,
            &["te--st.com"],
            ValidationResult::InvalidFormat,
        );

        let lenient_plugin = DomainValidatorPlugin::new(false, 1000, Vec::new());
        assert_all(&lenient_plugin, &["te--st.com"], ValidationResult::Valid);
    }

    #[tokio::test]
    async fn test_blacklist() {
        let plugin = with_blacklist(&["malicious.com"]);
        assert_all(
            &plugin,
            &["malicious.com", "sub.malicious.com"],
            ValidationResult::Blacklisted,
        );
    }

    #[tokio::test]
    async fn test_wildcard_blacklist() {
        let plugin = with_blacklist(&["*.blocked.org", "*.test.invalid"]);
        assert_all(
            &plugin,
            &[
                "blocked.org",
                "sub.blocked.org",
                "deep.sub.blocked.org",
                "test.invalid",
                "any.test.invalid",
            ],
            ValidationResult::Blacklisted,
        );
        assert_all(
            &plugin,
            &["example.com", "blocked.com"],
            ValidationResult::Valid,
        );
    }

    #[tokio::test]
    async fn test_mixed_blacklist() {
        let plugin = with_blacklist(&["exact.example.com", "*.wildcard.com", "suffix.org"]);
        assert_all(
            &plugin,
            &[
                "exact.example.com",
                "wildcard.com",
                "sub.wildcard.com",
                "suffix.org",
                "sub.suffix.org",
            ],
            ValidationResult::Blacklisted,
        );
        assert_all(&plugin, &["example.com"], ValidationResult::Valid);
    }

    #[tokio::test]
    async fn test_cache() {
        use crate::dns::{Message, Question, RecordClass, RecordType};

        let plugin = DomainValidatorPlugin::new(true, 10, Vec::new());

        let question = Question::new(
            "example.com".parse().unwrap(),
            RecordType::A,
            RecordClass::IN,
        );
        let mut request = Message::new();
        request.add_question(question);
        let mut ctx = Context::new(request);

        let outcome = plugin.execute(&mut ctx).await;
        assert!(outcome.is_ok());
        // A valid name must pass through without a response being set.
        assert!(ctx.response().is_none());

        {
            let cache = plugin.cache.write().await;
            assert!(cache.contains("example.com"));
        }
    }

    #[tokio::test]
    async fn test_consecutive_dots() {
        let plugin = DomainValidatorPlugin::default();
        assert_all(
            &plugin,
            &["example..com", "sub..domain.example.com", "..."],
            ValidationResult::InvalidLength,
        );
    }

    #[tokio::test]
    async fn test_domains_starting_with_dot() {
        let plugin = DomainValidatorPlugin::default();
        assert_all(
            &plugin,
            &[".example.com", ".com"],
            ValidationResult::InvalidLength,
        );
    }

    #[tokio::test]
    async fn test_domains_ending_with_dot() {
        let plugin = DomainValidatorPlugin::default();
        assert_all(
            &plugin,
            &["example.com.", "localhost."],
            ValidationResult::InvalidLength,
        );
    }

    #[tokio::test]
    async fn test_empty_string() {
        let plugin = DomainValidatorPlugin::default();
        assert_all(&plugin, &[""], ValidationResult::InvalidLength);
    }
}