use crate::plugin::{Context, ExecPlugin, Plugin};
use crate::{RegisterExecPlugin, Result};
use async_trait::async_trait;
use serde::Deserialize;
use std::fmt;
use std::net::IpAddr;
use std::sync::Arc;
/// Raw, user-facing configuration for `EcsPlugin`.
///
/// Every field is optional; `EcsPlugin::new` applies defaults and validates the values.
#[derive(Debug, Default, Deserialize, Clone)]
pub struct EcsArgs {
    /// Reuse the client's EDNS0 options when present (default `false`).
    pub forward: Option<bool>,
    /// Derive the ECS address from the client's address (default `false`).
    pub send: Option<bool>,
    /// Fixed IP address to advertise; must parse as an IPv4 or IPv6 address.
    pub preset: Option<String>,
    /// IPv4 source prefix length, at most 32 (default 24).
    pub mask4: Option<u8>,
    /// IPv6 source prefix length, at most 128 (default 48).
    pub mask6: Option<u8>,
}
/// EDNS Client Subnet (ECS) plugin.
///
/// Decides which ECS option, if any, is attached to a query by writing `edns0_options`
/// context metadata. Construct it via `EcsPlugin::new` or `ExecPlugin::quick_setup`.
#[derive(Clone, RegisterExecPlugin)]
pub struct EcsPlugin {
    /// Reuse the client's EDNS0 options from `client_edns0_options` metadata when present.
    /// This is checked before `preset` and `send`.
    forward: bool,
    /// Derive the ECS address from `client_addr` metadata; consulted after `forward` and `preset`.
    send: bool,
    /// Fixed address to advertise instead of the client's; takes priority over `send`.
    preset: Option<IpAddr>,
    /// Source prefix length for IPv4 addresses (validated to be <= 32; default 24).
    mask4: u8,
    /// Source prefix length for IPv6 addresses (validated to be <= 128; default 48).
    mask6: u8,
}
impl EcsPlugin {
pub fn new(args: EcsArgs) -> Result<Self> {
let forward = args.forward.unwrap_or(false);
let send = args.send.unwrap_or(false);
let mask4 = args.mask4.unwrap_or(24);
let mask6 = args.mask6.unwrap_or(48);
if mask4 > 32 || mask6 > 128 {
return Err(crate::Error::Other("invalid mask".into()));
}
let preset = if let Some(p) = args.preset {
match p.parse::<IpAddr>() {
Ok(ip) => Some(ip),
Err(e) => return Err(crate::Error::Other(format!("invalid preset addr: {}", e))),
}
} else {
None
};
Ok(Self {
forward,
send,
preset,
mask4,
mask6,
})
}
#[allow(clippy::manual_div_ceil)]
fn make_ecs_option(ip: IpAddr, mask4: u8, mask6: u8) -> Option<(u16, Vec<u8>)> {
let code = 8u16;
match ip {
IpAddr::V4(v4) => {
let family = 1u16.to_be_bytes();
let src_mask = mask4.min(32);
let scope = 0u8;
let octets = v4.octets();
let nbytes = ((src_mask as usize) + 7) / 8;
let mut data = Vec::with_capacity(4 + nbytes);
data.extend_from_slice(&family);
data.push(src_mask);
data.push(scope);
data.extend_from_slice(&octets[..nbytes]);
Some((code, data))
}
IpAddr::V6(v6) => {
let family = 2u16.to_be_bytes();
let src_mask = mask6.min(128);
let scope = 0u8;
let octets = v6.octets();
let nbytes = ((src_mask as usize) + 7) / 8;
let mut data = Vec::with_capacity(4 + nbytes);
data.extend_from_slice(&family);
data.push(src_mask);
data.push(scope);
data.extend_from_slice(&octets[..nbytes]);
Some((code, data))
}
}
}
}
/// Diagnostic formatting that lists every configuration field of the plugin, including
/// the prefix lengths.
impl fmt::Debug for EcsPlugin {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("EcsPlugin")
            .field("forward", &self.forward)
            .field("send", &self.send)
            .field("preset", &self.preset)
            .field("mask4", &self.mask4)
            .field("mask6", &self.mask6)
            .finish()
    }
}
#[async_trait]
impl Plugin for EcsPlugin {
    fn name(&self) -> &str {
        "ecs"
    }

    /// Attaches an EDNS Client Subnet option to the outgoing query via context metadata.
    ///
    /// Sources are tried in priority order:
    /// 1. `forward`: reuse the ECS options (code 8) found in `client_edns0_options`.
    ///    If the client sent none, the next sources are tried.
    /// 2. `preset`: the fixed address from configuration.
    /// 3. `send`: the client's own address from `client_addr`. This may be a bare IP or an
    ///    `ip:port` socket address; IPv4-mapped IPv6 addresses are encoded as IPv4.
    ///
    /// When an option is produced, it is stored under `edns0_options` and
    /// `edns0_preserve_existing` is set to `true`. Missing or unparsable inputs are not
    /// errors; the query is simply left without ECS.
    async fn execute(&self, ctx: &mut Context) -> Result<()> {
        if self.forward
            && let Some(options) = ctx.get_metadata::<Vec<(u16, Vec<u8>)>>("client_edns0_options")
        {
            // Keep only ECS (option code 8). Other client options (e.g. cookies, padding)
            // are hop-by-hop and must not be relayed upstream by the ECS plugin.
            let ecs: Vec<(u16, Vec<u8>)> =
                options.iter().filter(|opt| opt.0 == 8).cloned().collect();
            if !ecs.is_empty() {
                ctx.set_metadata("edns0_options", ecs);
                ctx.set_metadata("edns0_preserve_existing", true);
                return Ok(());
            }
            // No ECS from the client: fall through to the preset/send sources.
        }

        if let Some(ip) = self.preset {
            if let Some((code, data)) = EcsPlugin::make_ecs_option(ip, self.mask4, self.mask6) {
                ctx.set_metadata("edns0_options", vec![(code, data)]);
                ctx.set_metadata("edns0_preserve_existing", true);
            }
            return Ok(());
        }

        // The client address may arrive as "ip" or "ip:port". Dual-stack sockets report
        // IPv4 clients as ::ffff:a.b.c.d, so canonicalize them. Otherwise they would be
        // encoded as a meaningless, mostly-zero IPv6 prefix instead of an IPv4 /mask4.
        if self.send
            && let Some(addr) = ctx.get_metadata::<String>("client_addr")
            && let Some(ip) = addr
                .parse::<IpAddr>()
                .or_else(|_| addr.parse::<std::net::SocketAddr>().map(|sa| sa.ip()))
                .ok()
            && let Some((code, data)) =
                EcsPlugin::make_ecs_option(ip.to_canonical(), self.mask4, self.mask6)
        {
            ctx.set_metadata("edns0_options", vec![(code, data)]);
            ctx.set_metadata("edns0_preserve_existing", true);
        }
        Ok(())
    }
}
impl ExecPlugin for EcsPlugin {
fn quick_setup(prefix: &str, exec_str: &str) -> Result<Arc<dyn Plugin>> {
if prefix != "ecs" {
return Err(crate::Error::Config(format!(
"ExecPlugin quick_setup: unsupported prefix '{}', expected 'ecs'",
prefix
)));
}
let mut forward = None;
let mut send = None;
let mut preset = None;
let mut mask4 = None;
let mut mask6 = None;
for part in exec_str.split(',') {
let part = part.trim();
if part.is_empty() {
continue;
}
let kv: Vec<&str> = part.splitn(2, '=').collect();
if kv.len() != 2 {
return Err(crate::Error::Config(format!(
"Invalid key=value pair: '{}'",
part
)));
}
let key = kv[0].trim();
let value = kv[1].trim();
match key {
"forward" => {
forward = Some(value.parse::<bool>().map_err(|_| {
crate::Error::Config(format!("Invalid boolean for forward: '{}'", value))
})?);
}
"send" => {
send = Some(value.parse::<bool>().map_err(|_| {
crate::Error::Config(format!("Invalid boolean for send: '{}'", value))
})?);
}
"preset" => {
preset = Some(value.to_string());
}
"mask4" => {
mask4 = Some(value.parse::<u8>().map_err(|_| {
crate::Error::Config(format!("Invalid u8 for mask4: '{}'", value))
})?);
}
"mask6" => {
mask6 = Some(value.parse::<u8>().map_err(|_| {
crate::Error::Config(format!("Invalid u8 for mask6: '{}'", value))
})?);
}
_ => {
return Err(crate::Error::Config(format!("Unknown option: '{}'", key)));
}
}
}
let args = EcsArgs {
forward,
send,
preset,
mask4,
mask6,
};
let plugin = EcsPlugin::new(args)?;
Ok(Arc::new(plugin))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::Message;
    use std::net::Ipv6Addr;

    /// Shape of the option list stored in context metadata.
    type Opts = Vec<(u16, Vec<u8>)>;

    /// Args with every field unset; tests override what they need.
    fn empty_args() -> EcsArgs {
        EcsArgs {
            forward: None,
            send: None,
            preset: None,
            mask4: None,
            mask6: None,
        }
    }

    #[tokio::test]
    async fn test_ecs_preset_v4() {
        let args = EcsArgs {
            preset: Some("192.0.2.5".to_string()),
            mask4: Some(24),
            ..empty_args()
        };
        let plugin = EcsPlugin::new(args).unwrap();
        let req = Message::new();
        let mut ctx = Context::new(req);
        plugin.execute(&mut ctx).await.unwrap();
        // FAMILY=1, SOURCE=24, SCOPE=0, then the first three address octets.
        let expected: Opts = vec![(8, vec![0, 1, 24, 0, 192, 0, 2])];
        assert_eq!(ctx.get_metadata::<Opts>("edns0_options"), Some(&expected));
        assert_eq!(
            ctx.get_metadata::<bool>("edns0_preserve_existing"),
            Some(&true)
        );
    }

    #[tokio::test]
    async fn test_ecs_forward_copies_client_options() {
        let args = EcsArgs {
            forward: Some(true),
            ..empty_args()
        };
        let plugin = EcsPlugin::new(args).unwrap();
        let req = Message::new();
        let mut ctx = Context::new(req);
        let client_opts: Opts = vec![(8u16, vec![0, 1, 24, 0, 192, 0, 2])];
        ctx.set_metadata("client_edns0_options", client_opts.clone());
        plugin.execute(&mut ctx).await.unwrap();
        let got = ctx.get_metadata::<Opts>("edns0_options");
        assert!(got.is_some());
        assert_eq!(got.unwrap(), &client_opts);
    }

    #[tokio::test]
    async fn test_ecs_send_derives_from_client_addr() {
        let args = EcsArgs {
            send: Some(true),
            mask4: Some(24),
            mask6: Some(56),
            ..empty_args()
        };
        let plugin = EcsPlugin::new(args).unwrap();
        let req = Message::new();
        let mut ctx = Context::new(req);
        ctx.set_metadata("client_addr", "192.0.2.7".to_string());
        plugin.execute(&mut ctx).await.unwrap();
        let expected: Opts = vec![(8, vec![0, 1, 24, 0, 192, 0, 2])];
        assert_eq!(ctx.get_metadata::<Opts>("edns0_options"), Some(&expected));
    }

    #[tokio::test]
    async fn test_ecs_disabled_sets_nothing() {
        let plugin = EcsPlugin::new(empty_args()).unwrap();
        let mut ctx = Context::new(Message::new());
        ctx.set_metadata("client_addr", "192.0.2.7".to_string());
        plugin.execute(&mut ctx).await.unwrap();
        assert!(ctx.get_metadata::<Opts>("edns0_options").is_none());
    }

    #[test]
    fn test_make_ecs_option_v6() {
        // ::1 has only zero bits in its first 56 bits: seven zero address octets.
        assert_eq!(
            EcsPlugin::make_ecs_option(IpAddr::V6(Ipv6Addr::LOCALHOST), 24, 56),
            Some((8, vec![0, 2, 56, 0, 0, 0, 0, 0, 0, 0, 0]))
        );
        let doc: IpAddr = "2001:db8::1".parse().unwrap();
        assert_eq!(
            EcsPlugin::make_ecs_option(doc, 24, 48),
            Some((8, vec![0, 2, 48, 0, 0x20, 0x01, 0x0d, 0xb8, 0, 0]))
        );
    }

    #[test]
    fn test_make_ecs_option_zero_prefix_has_no_address() {
        let ip: IpAddr = "192.0.2.7".parse().unwrap();
        assert_eq!(
            EcsPlugin::make_ecs_option(ip, 0, 0),
            Some((8, vec![0, 1, 0, 0]))
        );
    }

    #[test]
    fn test_ecs_new_rejects_bad_config() {
        let bad_v4 = EcsArgs {
            mask4: Some(33),
            ..empty_args()
        };
        assert!(EcsPlugin::new(bad_v4).is_err());
        let bad_v6 = EcsArgs {
            mask6: Some(129),
            ..empty_args()
        };
        assert!(EcsPlugin::new(bad_v6).is_err());
        let bad_preset = EcsArgs {
            preset: Some("not-an-ip".to_string()),
            ..empty_args()
        };
        assert!(EcsPlugin::new(bad_preset).is_err());
    }

    #[test]
    fn test_ecs_quick_setup() {
        let plugin = EcsPlugin::quick_setup("ecs", "forward=true").unwrap();
        assert_eq!(plugin.name(), "ecs");
        let plugin = EcsPlugin::quick_setup("ecs", "send=true,mask4=20,mask6=40").unwrap();
        assert_eq!(plugin.name(), "ecs");
        let plugin = EcsPlugin::quick_setup("ecs", "preset=192.0.2.1").unwrap();
        assert_eq!(plugin.name(), "ecs");
        // Blank segments and surrounding whitespace are tolerated.
        assert!(EcsPlugin::quick_setup("ecs", " , send = true ,").is_ok());
        assert!(EcsPlugin::quick_setup("invalid", "forward=true").is_err());
        assert!(EcsPlugin::quick_setup("ecs", "invalid=true").is_err());
        assert!(EcsPlugin::quick_setup("ecs", "forward").is_err());
        assert!(EcsPlugin::quick_setup("ecs", "forward=maybe").is_err());
        assert!(EcsPlugin::quick_setup("ecs", "mask4=33").is_err());
        assert!(EcsPlugin::quick_setup("ecs", "mask6=300").is_err());
    }
}