use crate::dns::types::RecordType;
use crate::plugin::{Context, Plugin};
use crate::{RegisterPlugin, Result};
use async_trait::async_trait;
use std::fmt;
use std::sync::Arc;
use tracing::debug;
/// Policy deciding which address-family records (A / AAAA) stay in the
/// answer section of a response.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum IpPreference {
    /// Strict IPv4: every AAAA answer is removed, whether or not any A
    /// answers are present.
    IPv4,
    /// Strict IPv6: every A answer is removed, whether or not any AAAA
    /// answers are present.
    IPv6,
    /// AAAA answers are removed only when at least one A answer exists;
    /// otherwise the answers are left untouched.
    IPv4PreferIPv6Fallback,
    /// A answers are removed only when at least one AAAA answer exists;
    /// otherwise the answers are left untouched.
    IPv6PreferIPv4Fallback,
    /// No filtering: A and AAAA answers are both kept.
    Both,
}
/// Plugin that filters the A / AAAA records in a response's answer section
/// according to a configured [`IpPreference`].
#[derive(RegisterPlugin)]
pub struct DualSelectorPlugin {
    /// Filtering policy consulted on every `execute` call.
    preference: IpPreference,
}
impl DualSelectorPlugin {
    /// Creates a selector that applies the given `preference`.
    pub fn new(preference: IpPreference) -> Self {
        Self { preference }
    }

    /// Creates a selector in [`IpPreference::IPv4`] mode.
    ///
    /// Despite the name this is the strict mode: AAAA answers are removed
    /// even when the response contains no A answers.
    pub fn prefer_ipv4() -> Self {
        Self::new(IpPreference::IPv4)
    }

    /// Creates a selector in [`IpPreference::IPv6`] mode.
    ///
    /// Despite the name this is the strict mode: A answers are removed
    /// even when the response contains no AAAA answers.
    pub fn prefer_ipv6() -> Self {
        Self::new(IpPreference::IPv6)
    }
}
impl fmt::Debug for DualSelectorPlugin {
    /// Formats the plugin as a struct named `DualSelectorPlugin` with its
    /// single `preference` field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("DualSelectorPlugin");
        builder.field("preference", &self.preference);
        builder.finish()
    }
}
#[async_trait]
impl Plugin for DualSelectorPlugin {
    /// Returns the plugin's name, `"dual_selector"`.
    fn name(&self) -> &str {
        "dual_selector"
    }

    /// Builds the plugin from its configuration arguments.
    ///
    /// The `preference` argument is lower-cased and every `-` is replaced by
    /// `_` before matching, so `IPv4-Prefer-IPv6` and `ipv4_prefer_ipv6` are
    /// equivalent. When the argument is absent, or `as_str()` yields `None`
    /// for it, `"both"` is used.
    ///
    /// # Errors
    ///
    /// Returns `crate::Error::Config` when the normalized value is not one of
    /// the recognized spellings.
    fn init(config: &crate::config::types::PluginConfig) -> Result<std::sync::Arc<dyn Plugin>> {
        let args = config.effective_args();

        let raw = args
            .get("preference")
            .and_then(|value| value.as_str())
            .unwrap_or("both");
        let normalized = raw.to_lowercase().replace('-', "_");

        let preference = match normalized.as_str() {
            "both" => IpPreference::Both,
            "ipv4" => IpPreference::IPv4,
            "ipv6" => IpPreference::IPv6,
            // The `_fallback` suffix is optional for the two "prefer" modes.
            "ipv4_prefer_ipv6" | "ipv4_prefer_ipv6_fallback" => {
                IpPreference::IPv4PreferIPv6Fallback
            }
            "ipv6_prefer_ipv4" | "ipv6_prefer_ipv4_fallback" => {
                IpPreference::IPv6PreferIPv4Fallback
            }
            other => {
                return Err(crate::Error::Config(format!(
                    "dual_selector: unknown preference '{}'",
                    other
                )));
            }
        };

        Ok(Arc::new(Self::new(preference)))
    }

    /// Filters the answer section of the response held by `ctx`, if any.
    ///
    /// Only records of type A and AAAA are ever removed; all other record
    /// types are kept. When `ctx` carries no response nothing happens.
    /// Always returns `Ok(())`.
    async fn execute(&self, ctx: &mut Context) -> Result<()> {
        let response = match ctx.response_mut() {
            Some(response) => response,
            None => return Ok(()),
        };

        let records = response.answers_mut();
        // Presence flags are taken before any record is removed.
        let has_ipv4 = records.iter().any(|rec| rec.rtype() == RecordType::A);
        let has_ipv6 = records.iter().any(|rec| rec.rtype() == RecordType::AAAA);

        match self.preference {
            IpPreference::Both => {
                debug!("Dual selector: Keeping both IPv4 and IPv6");
            }
            IpPreference::IPv4 => {
                records.retain(|rec| rec.rtype() != RecordType::AAAA);
                debug!("Dual selector: Keeping only IPv4 records");
            }
            IpPreference::IPv6 => {
                records.retain(|rec| rec.rtype() != RecordType::A);
                debug!("Dual selector: Keeping only IPv6 records");
            }
            IpPreference::IPv4PreferIPv6Fallback if has_ipv4 => {
                records.retain(|rec| rec.rtype() != RecordType::AAAA);
                debug!("Dual selector: Preferring IPv4, removing IPv6");
            }
            IpPreference::IPv4PreferIPv6Fallback => {
                debug!("Dual selector: No IPv4, keeping IPv6");
            }
            IpPreference::IPv6PreferIPv4Fallback if has_ipv6 => {
                records.retain(|rec| rec.rtype() != RecordType::A);
                debug!("Dual selector: Preferring IPv6, removing IPv4");
            }
            IpPreference::IPv6PreferIPv4Fallback => {
                debug!("Dual selector: No IPv6, keeping IPv4");
            }
        }

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::types::RecordClass;
    use crate::dns::{Message, RData, ResourceRecord};

    /// Builds the A record used as the IPv4 fixture in every test.
    fn a_record() -> ResourceRecord {
        ResourceRecord::new(
            "example.com".to_string(),
            RecordType::A,
            RecordClass::IN,
            300,
            RData::A("192.0.2.1".parse().unwrap()),
        )
    }

    /// Builds the AAAA record used as the IPv6 fixture in every test.
    fn aaaa_record() -> ResourceRecord {
        ResourceRecord::new(
            "example.com".to_string(),
            RecordType::AAAA,
            RecordClass::IN,
            300,
            RData::AAAA("2001:db8::1".parse().unwrap()),
        )
    }

    /// Runs `plugin` against a response whose answers are `records` and
    /// returns the record types left in the answer section, in order.
    async fn answer_types(
        plugin: &DualSelectorPlugin,
        records: Vec<ResourceRecord>,
    ) -> Vec<RecordType> {
        let mut ctx = Context::new(Message::new());
        let mut response = Message::new();
        for record in records {
            response.add_answer(record);
        }
        ctx.set_response(Some(response));

        plugin.execute(&mut ctx).await.unwrap();

        ctx.response()
            .unwrap()
            .answers()
            .iter()
            .map(|record| record.rtype())
            .collect()
    }

    #[tokio::test]
    async fn test_dual_selector_ipv4_only() {
        let plugin = DualSelectorPlugin::prefer_ipv4();
        let kept = answer_types(&plugin, vec![a_record(), aaaa_record()]).await;
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], RecordType::A);
    }

    #[tokio::test]
    async fn test_dual_selector_ipv6_only() {
        let plugin = DualSelectorPlugin::prefer_ipv6();
        let kept = answer_types(&plugin, vec![a_record(), aaaa_record()]).await;
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], RecordType::AAAA);
    }

    #[tokio::test]
    async fn test_dual_selector_ipv4_prefer_ipv6_fallback() {
        // Only AAAA present: the fallback branch must keep it.
        let plugin = DualSelectorPlugin::new(IpPreference::IPv4PreferIPv6Fallback);
        let kept = answer_types(&plugin, vec![aaaa_record()]).await;
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], RecordType::AAAA);
    }

    #[tokio::test]
    async fn test_dual_selector_ipv4_prefer_drops_ipv6_when_both_present() {
        // Both families present: the preferred family wins.
        let plugin = DualSelectorPlugin::new(IpPreference::IPv4PreferIPv6Fallback);
        let kept = answer_types(&plugin, vec![a_record(), aaaa_record()]).await;
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], RecordType::A);
    }

    #[tokio::test]
    async fn test_dual_selector_ipv6_prefer_ipv4_fallback() {
        // Only A present: the fallback branch must keep it.
        let plugin = DualSelectorPlugin::new(IpPreference::IPv6PreferIPv4Fallback);
        let kept = answer_types(&plugin, vec![a_record()]).await;
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], RecordType::A);
    }

    #[tokio::test]
    async fn test_dual_selector_ipv6_prefer_drops_ipv4_when_both_present() {
        // Both families present: the preferred family wins.
        let plugin = DualSelectorPlugin::new(IpPreference::IPv6PreferIPv4Fallback);
        let kept = answer_types(&plugin, vec![a_record(), aaaa_record()]).await;
        assert_eq!(kept.len(), 1);
        assert_eq!(kept[0], RecordType::AAAA);
    }

    #[tokio::test]
    async fn test_dual_selector_both() {
        let plugin = DualSelectorPlugin::new(IpPreference::Both);
        let kept = answer_types(&plugin, vec![a_record(), aaaa_record()]).await;
        assert_eq!(kept.len(), 2);
    }
}