use crate::config::PluginConfig;
use crate::plugin::{Context, Plugin};
use crate::{RegisterPlugin, Result};
use async_trait::async_trait;
use std::sync::Arc;
use reqwest::StatusCode;
use serde_json::json;
use std::fmt;
use std::net::IpAddr;
use std::time::Duration;
use tracing::{debug, warn};
/// Plugin that sends A/AAAA addresses found in DNS responses to a RouterOS
/// firewall address list through the router's REST API.
///
/// Built with [`RosAddrlistPlugin::new`] plus the `with_*` builders, or from
/// plugin configuration via `Plugin::init`.
// NOTE(review): `RegisterPlugin` is a project derive macro; presumably it registers this
// type with the plugin registry — confirm before reordering or renaming anything here.
#[derive(RegisterPlugin)]
pub struct RosAddrlistPlugin {
// Name of the RouterOS address list; sent as the `list` field of each REST payload
// (config key `addrlist`).
list_name: String,
// When false, `execute` returns immediately without inspecting the response.
track_responses: bool,
// Base URL of the RouterOS REST endpoint (config key `server`); `None` means no requests.
server: Option<String>,
// HTTP basic-auth user name (config key `user`); only used when `passwd` is also set.
user: Option<String>,
// HTTP basic-auth password (config key `passwd`); deliberately omitted from `Debug` output.
passwd: Option<String>,
// Optional prefix length for IPv4 addresses (config key `mask4`).
mask4: Option<u8>,
// Optional prefix length for IPv6 addresses (config key `mask6`).
mask6: Option<u8>,
}
impl RosAddrlistPlugin {
    /// Creates a plugin that adds resolved addresses to the RouterOS address list `list_name`.
    ///
    /// Response tracking starts enabled; server, credentials and masks start unset.
    pub fn new(list_name: impl Into<String>) -> Self {
        Self {
            list_name: list_name.into(),
            track_responses: true,
            server: None,
            user: None,
            passwd: None,
            mask4: None,
            mask6: None,
        }
    }

    /// Enables or disables processing of DNS responses in `execute`.
    pub fn track_responses(mut self, enabled: bool) -> Self {
        self.track_responses = enabled;
        self
    }

    /// Sets the RouterOS base URL (e.g. `https://192.168.88.1`); a trailing `/` is ignored.
    ///
    /// Without a server, responses are still inspected but no request is sent.
    pub fn with_server(mut self, server: impl Into<String>) -> Self {
        self.server = Some(server.into());
        self
    }

    /// Sets the HTTP basic-auth credentials sent with every REST request.
    pub fn with_auth(mut self, user: impl Into<String>, passwd: impl Into<String>) -> Self {
        self.user = Some(user.into());
        self.passwd = Some(passwd.into());
        self
    }

    /// Sets both prefix lengths at once; `None` sends addresses of that family unmasked.
    pub fn with_masks(mut self, mask4: Option<u8>, mask6: Option<u8>) -> Self {
        self.mask4 = mask4;
        self.mask6 = mask6;
        self
    }

    /// Sets the IPv4 prefix length; a value of 32 or more sends the bare address.
    pub fn with_mask4(mut self, mask4: u8) -> Self {
        self.mask4 = Some(mask4);
        self
    }

    /// Sets the IPv6 prefix length; a value of 128 or more sends the bare address.
    pub fn with_mask6(mut self, mask6: u8) -> Self {
        self.mask6 = Some(mask6);
        self
    }

    /// Collects every A/AAAA address in the answer section of the context's response,
    /// in answer order. Returns an empty vector when there is no response.
    fn extract_ips(&self, ctx: &Context) -> Vec<IpAddr> {
        let mut ips = Vec::new();
        if let Some(response) = ctx.response() {
            for answer in response.answers() {
                if let Some(ip) = self.extract_ip_from_rdata(answer.rdata()) {
                    ips.push(ip);
                }
            }
        }
        ips
    }

    /// Maps A/AAAA record data to an [`IpAddr`]; every other record type yields `None`.
    fn extract_ip_from_rdata(&self, rdata: &crate::dns::RData) -> Option<IpAddr> {
        use crate::dns::RData;
        match rdata {
            RData::A(addr) => Some(IpAddr::V4(*addr)),
            RData::AAAA(addr) => Some(IpAddr::V6(*addr)),
            _ => None,
        }
    }

    /// Formats `ip` as the `address` value sent to RouterOS, applying the configured mask.
    ///
    /// With a prefix length shorter than the address width, host bits are cleared and
    /// CIDR notation is returned (`192.0.2.0/24`); otherwise the bare address is returned.
    fn format_address(&self, ip: IpAddr) -> String {
        match ip {
            IpAddr::V4(addr) => match self.mask4 {
                Some(len) if len < 32 => {
                    // checked_shl is None for len == 0 (shift by 32): an all-zero netmask.
                    let netmask = u32::MAX.checked_shl(32 - u32::from(len)).unwrap_or(0);
                    let network = std::net::Ipv4Addr::from(u32::from(addr) & netmask);
                    format!("{}/{}", network, len)
                }
                _ => addr.to_string(),
            },
            IpAddr::V6(addr) => match self.mask6 {
                Some(len) if len < 128 => {
                    // checked_shl is None for len == 0 (shift by 128): an all-zero netmask.
                    let netmask = u128::MAX.checked_shl(128 - u32::from(len)).unwrap_or(0);
                    let network = std::net::Ipv6Addr::from(u128::from(addr) & netmask);
                    format!("{}/{}", network, len)
                }
                _ => addr.to_string(),
            },
        }
    }

    /// Returns the shared HTTP client, building it on first use.
    ///
    /// reqwest recommends reusing one `Client` because it owns the connection pool. The
    /// configuration is identical for every plugin instance, so one process-wide client
    /// suffices.
    fn http_client() -> Result<&'static reqwest::Client> {
        static CLIENT: std::sync::OnceLock<reqwest::Client> = std::sync::OnceLock::new();
        if let Some(client) = CLIENT.get() {
            return Ok(client);
        }
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(2))
            // NOTE(review): certificate verification is disabled, presumably because routers
            // often serve self-signed certificates. This exposes the basic-auth credentials
            // to any on-path attacker; consider making it configurable.
            .danger_accept_invalid_certs(true)
            .build()
            .map_err(|e| crate::Error::Other(format!("failed build http client: {}", e)))?;
        // If another task initialised the cell concurrently, its client wins and ours is dropped.
        Ok(CLIENT.get_or_init(|| client))
    }

    /// Adds each IP (masked per `mask4`/`mask6`) to the configured RouterOS address list.
    ///
    /// Does nothing when no server is configured or `ips` is empty. IPv4 addresses are
    /// posted to `/rest/ip/firewall/address-list/add`, IPv6 to `/rest/ipv6/...`, with a
    /// comment naming `domain`.
    ///
    /// # Errors
    ///
    /// Returns on the first IP whose request cannot be sent or whose status is neither
    /// 2xx nor 400; the remaining IPs are not attempted.
    async fn notify_server(&self, ips: &[IpAddr], domain: &str) -> Result<()> {
        // Without a router there is nothing to do; return before touching the HTTP client.
        let Some(server) = self.server.as_deref() else {
            return Ok(());
        };
        if ips.is_empty() {
            return Ok(());
        }
        let client = Self::http_client()?;
        let base = server.trim_end_matches('/');
        let comment = format!("[lazydns] domain: {}", domain);
        for ip in ips {
            // RouterOS keeps IPv4 and IPv6 address lists under separate menus.
            let kind = if ip.is_ipv6() { "ipv6" } else { "ip" };
            let router_url = format!("{}/rest/{}/firewall/address-list/add", base, kind);
            let address = self.format_address(*ip);
            let payload = json!({
                "address": address,
                "list": self.list_name,
                "comment": comment,
            });
            let mut req = client.post(&router_url).json(&payload);
            if let (Some(user), Some(pass)) = (&self.user, &self.passwd) {
                req = req.basic_auth(user, Some(pass));
            }
            let resp = req
                .send()
                .await
                .map_err(|e| crate::Error::Other(format!("http request failed: {}", e)))?;
            match resp.status() {
                status if status.is_success() => {
                    debug!(ip = %ip, address = %address, list = %self.list_name, domain = %domain, "added ip to ros addrlist")
                }
                StatusCode::BAD_REQUEST => {
                    // Most likely a duplicate entry; other validation failures also land here.
                    debug!(ip = %ip, address = %address, list = %self.list_name, domain = %domain, "likely ip already exists")
                }
                StatusCode::UNAUTHORIZED => {
                    return Err(crate::Error::Other(format!(
                        "unauthorized when adding {}",
                        ip
                    )));
                }
                StatusCode::INTERNAL_SERVER_ERROR => {
                    return Err(crate::Error::Other(format!(
                        "internal server error when adding {}",
                        ip
                    )));
                }
                s => {
                    return Err(crate::Error::Other(format!(
                        "unexpected status code {} when adding {}",
                        s, ip
                    )));
                }
            }
        }
        Ok(())
    }
}
/// Diagnostic view of the plugin configuration.
///
/// The password itself is never printed; only whether one is set is shown.
impl fmt::Debug for RosAddrlistPlugin {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Use the real type name (the previous "RosAddrListPlugin" did not match the type).
        f.debug_struct("RosAddrlistPlugin")
            .field("list_name", &self.list_name)
            .field("track_responses", &self.track_responses)
            .field("server", &self.server)
            .field("user", &self.user)
            .field("passwd", &self.passwd.as_ref().map(|_| "<redacted>"))
            .field("mask4", &self.mask4)
            .field("mask6", &self.mask6)
            .finish()
    }
}
#[async_trait]
impl Plugin for RosAddrlistPlugin {
fn name(&self) -> &str {
"ros_addrlist"
}
async fn execute(&self, ctx: &mut Context) -> Result<()> {
if !self.track_responses {
return Ok(());
}
let ips = self.extract_ips(ctx);
if !ips.is_empty() {
let domain = if let Some(question) = ctx.request().questions().first() {
question.qname().trim_end_matches('.').to_string()
} else {
"".to_string()
};
debug!(
list_name = %self.list_name,
domain = %domain,
ip_count = ips.len(),
ips = ?ips,
"RouterOS address list: add IPs"
);
if let Err(e) = self.notify_server(&ips, &domain).await {
warn!(error = %e, domain = %domain, "Failed to notify RouterOS helper server");
}
}
Ok(())
}
fn init(config: &PluginConfig) -> Result<Arc<dyn Plugin>> {
let args = config.effective_args();
let addrlist = args
.get("addrlist")
.and_then(|v| v.as_str())
.unwrap_or("default");
let mut plugin = RosAddrlistPlugin::new(addrlist);
if let Some(server) = args.get("server").and_then(|v| v.as_str()) {
plugin = plugin.with_server(server.to_string());
}
if let Some(user) = args.get("user").and_then(|v| v.as_str())
&& let Some(pass) = args.get("passwd").and_then(|v| v.as_str())
{
plugin = plugin.with_auth(user.to_string(), pass.to_string());
}
if let Some(mask4) = args.get("mask4").and_then(|v| v.as_i64()) {
plugin = plugin.with_mask4(mask4 as u8);
}
if let Some(mask6) = args.get("mask6").and_then(|v| v.as_i64()) {
plugin = plugin.with_mask6(mask6 as u8);
}
Ok(Arc::new(plugin))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dns::types::{RecordClass, RecordType};
    use crate::dns::{Message, RData, ResourceRecord};

    /// Builds an `example.com` IN answer record with a 300s TTL.
    fn answer(rtype: RecordType, rdata: RData) -> ResourceRecord {
        ResourceRecord::new("example.com".to_string(), rtype, RecordClass::IN, 300, rdata)
    }

    /// Builds a context whose response carries `answers` in the given order.
    fn ctx_with_answers(answers: Vec<ResourceRecord>) -> Context {
        let mut response = Message::new();
        for record in answers {
            response.add_answer(record);
        }
        let mut ctx = Context::new(Message::new());
        ctx.set_response(Some(response));
        ctx
    }

    #[tokio::test]
    async fn test_ros_addrlist_extract_ips() {
        let plugin = RosAddrlistPlugin::new("test_list");
        let mut ctx = ctx_with_answers(vec![
            answer(RecordType::A, RData::A("192.0.2.1".parse().unwrap())),
            answer(RecordType::AAAA, RData::AAAA("2001:db8::1".parse().unwrap())),
        ]);
        let expected: Vec<IpAddr> = vec![
            "192.0.2.1".parse().unwrap(),
            "2001:db8::1".parse().unwrap(),
        ];
        assert_eq!(plugin.extract_ips(&ctx), expected);
        // No server is configured, so execute must succeed without sending anything.
        plugin.execute(&mut ctx).await.unwrap();
    }

    #[tokio::test]
    async fn test_ros_addrlist_disabled() {
        let plugin = RosAddrlistPlugin::new("test_list").track_responses(false);
        let mut ctx = ctx_with_answers(vec![answer(
            RecordType::A,
            RData::A("192.0.2.1".parse().unwrap()),
        )]);
        plugin.execute(&mut ctx).await.unwrap();
    }

    #[tokio::test]
    async fn test_ros_addrlist_no_ips() {
        let plugin = RosAddrlistPlugin::new("test_list");
        let mut ctx = ctx_with_answers(vec![answer(
            RecordType::CNAME,
            RData::CNAME("target.example.com".to_string()),
        )]);
        assert!(plugin.extract_ips(&ctx).is_empty());
        plugin.execute(&mut ctx).await.unwrap();
    }

    #[tokio::test]
    async fn test_ros_addrlist_no_response() {
        let plugin = RosAddrlistPlugin::new("test_list");
        let mut ctx = Context::new(Message::new());
        assert!(plugin.extract_ips(&ctx).is_empty());
        plugin.execute(&mut ctx).await.unwrap();
    }
}