use std::net::IpAddr;
use crate::netlink::{
MsgBuffer, NFNL_SUBSYS_IPSET, NLA_F_NESTED, NLM_F_ACK, NLM_F_DUMP, NLM_F_REQUEST,
NetlinkSocket, NfGenMsg, NlAttr, NlMsgHdr, is_nlmsg_done, nla_align, parse_nlmsg_error,
};
use crate::{IpEntry, IpSetError, Result};
// ipset netlink protocol revision sent with every request, and the limit that
// set names must stay strictly below (checked before any request is built).
const IPSET_PROTOCOL: u8 = 7;
const IPSET_MAXNAMELEN: usize = 32;

// ipset command codes; `ipset_msg_type` places them in the low byte of the
// netlink message type.
const IPSET_CMD_CREATE: u8 = 2;
const IPSET_CMD_DESTROY: u8 = 3;
const IPSET_CMD_FLUSH: u8 = 4;
const IPSET_CMD_LIST: u8 = 7;
const IPSET_CMD_ADD: u8 = 9;
const IPSET_CMD_DEL: u8 = 10;
const IPSET_CMD_TEST: u8 = 11;

// Attributes that appear directly after the nfgenmsg header.
const IPSET_ATTR_PROTOCOL: u16 = 1;
const IPSET_ATTR_SETNAME: u16 = 2;
const IPSET_ATTR_TYPENAME: u16 = 3;
const IPSET_ATTR_REVISION: u16 = 4;
const IPSET_ATTR_FAMILY: u16 = 5;
const IPSET_ATTR_DATA: u16 = 7;
// Container of element entries in list replies.
const IPSET_ATTR_ADT: u16 = 8;
// Line-number attribute; this module sends it (as 0) inside IPSET_ATTR_DATA.
const IPSET_ATTR_LINENO: u16 = 9;

// Attributes carried inside an IPSET_ATTR_DATA nest.
const IPSET_ATTR_IP: u16 = 1;
const IPSET_ATTR_TIMEOUT: u16 = 6;
const IPSET_ATTR_CADT_MAX: u16 = 16;
// Create-time attributes are numbered relative to IPSET_ATTR_CADT_MAX.
const IPSET_ATTR_HASHSIZE: u16 = IPSET_ATTR_CADT_MAX + 2;
const IPSET_ATTR_MAXELEM: u16 = IPSET_ATTR_CADT_MAX + 3;

// Address attributes carried inside an IPSET_ATTR_IP nest.
const IPSET_ATTR_IPADDR_IPV4: u16 = 1;
const IPSET_ATTR_IPADDR_IPV6: u16 = 2;

// Size of the request buffer and of the reply buffer for single request/ack exchanges.
const BUFF_SZ: usize = 1024;
/// Builds the netlink message type for an ipset command: the nfnetlink
/// subsystem id in the high byte, the ipset command code in the low byte.
fn ipset_msg_type(cmd: u8) -> u16 {
    let subsystem = NFNL_SUBSYS_IPSET as u16;
    (subsystem << 8) | u16::from(cmd)
}
fn ipset_operate(setname: &str, entry: &IpEntry, cmd: u8) -> Result<()> {
if setname.is_empty() || setname.len() >= IPSET_MAXNAMELEN {
return Err(IpSetError::InvalidSetName(setname.to_string()));
}
let (family, addr_type, addr_bytes): (u8, u16, Vec<u8>) = match entry.addr {
IpAddr::V4(v4) => (
libc::AF_INET as u8,
IPSET_ATTR_IPADDR_IPV4,
v4.octets().to_vec(),
),
IpAddr::V6(v6) => (
libc::AF_INET6 as u8,
IPSET_ATTR_IPADDR_IPV6,
v6.octets().to_vec(),
),
};
let mut buf = MsgBuffer::new(BUFF_SZ);
buf.put_nlmsghdr(ipset_msg_type(cmd), NLM_F_REQUEST | NLM_F_ACK, 0);
buf.put_nfgenmsg(family, 0, 0);
buf.put_attr_u8(IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
buf.put_attr_str(IPSET_ATTR_SETNAME, setname);
let data_offset = buf.start_nested(IPSET_ATTR_DATA);
let ip_offset = buf.start_nested(IPSET_ATTR_IP);
let len = crate::netlink::NlAttr::SIZE + addr_bytes.len();
buf.put_u16(len as u16);
buf.put_u16(addr_type | crate::netlink::NLA_F_NET_BYTEORDER);
buf.put_bytes(&addr_bytes);
buf.align();
buf.end_nested(ip_offset);
if let Some(timeout) = entry.timeout {
buf.put_attr_u32_be(IPSET_ATTR_TIMEOUT, timeout);
}
buf.put_attr_u32(IPSET_ATTR_LINENO, 0);
buf.end_nested(data_offset);
buf.finalize_nlmsg();
let socket = NetlinkSocket::new()?;
let mut recv_buf = [0u8; BUFF_SZ];
let recv_len = socket.send_recv(buf.as_slice(), &mut recv_buf)?;
if recv_len < NlMsgHdr::SIZE {
return Err(IpSetError::ProtocolError);
}
if let Some(error) = parse_nlmsg_error(&recv_buf[..recv_len]) {
if error == 0 {
return Ok(());
}
match -error {
libc::ENOENT => {
if cmd == IPSET_CMD_TEST {
return Err(IpSetError::ElementNotFound);
}
return Err(IpSetError::SetNotFound(setname.to_string()));
}
libc::EEXIST => return Err(IpSetError::ElementExists),
libc::IPSET_ERR_EXIST => {
if cmd == IPSET_CMD_TEST {
return Err(IpSetError::ElementNotFound);
}
return Err(IpSetError::ElementExists);
}
_ => return Err(IpSetError::NetlinkError(-error)),
}
}
Err(IpSetError::ProtocolError)
}
/// Re-exports the `libc` crate and adds the ipset-specific errno values this
/// module matches on.
mod libc {
    pub use ::libc::*;

    /// Base value of the ipset-private error numbers.
    const IPSET_ERR_PRIVATE: i32 = 4096;

    /// ipset errno 4103 (`IPSET_ERR_EXIST`).
    pub const IPSET_ERR_EXIST: i32 = IPSET_ERR_PRIVATE + 7;
}
/// Kind of set created by `ipset_create`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IpSetType {
    /// The kernel `hash:ip` set type.
    HashIp,
    /// The kernel `hash:net` set type.
    HashNet,
}

impl IpSetType {
    /// Type name sent as `IPSET_ATTR_TYPENAME`.
    fn as_str(&self) -> &'static str {
        match self {
            IpSetType::HashIp => "hash:ip",
            IpSetType::HashNet => "hash:net",
        }
    }

    /// Type revision sent as `IPSET_ATTR_REVISION`; both types currently
    /// request revision 4.
    fn revision(&self) -> u8 {
        match self {
            IpSetType::HashIp | IpSetType::HashNet => 4,
        }
    }
}
/// Address family of a set created by `ipset_create`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum IpSetFamily {
    /// IPv4 (`AF_INET`).
    Inet,
    /// IPv6 (`AF_INET6`).
    Inet6,
}

impl IpSetFamily {
    /// Family code sent in the nfgenmsg header and as `IPSET_ATTR_FAMILY`.
    fn as_u8(&self) -> u8 {
        match self {
            IpSetFamily::Inet => libc::AF_INET as u8,
            IpSetFamily::Inet6 => libc::AF_INET6 as u8,
        }
    }
}
/// Parameters for `ipset_create`.
///
/// Each `None` field is simply left out of the create request.
#[derive(Debug, Clone)]
pub struct IpSetCreateOptions {
    /// Set type to create (type name and revision).
    pub set_type: IpSetType,
    /// Address family of the set.
    pub family: IpSetFamily,
    /// Optional hash size, sent as `IPSET_ATTR_HASHSIZE`.
    pub hashsize: Option<u32>,
    /// Optional element limit, sent as `IPSET_ATTR_MAXELEM`.
    pub maxelem: Option<u32>,
    /// Optional timeout, sent as `IPSET_ATTR_TIMEOUT` (presumably seconds — confirm).
    pub timeout: Option<u32>,
}

impl Default for IpSetCreateOptions {
    /// An IPv4 `hash:ip` set with every optional field unset.
    fn default() -> Self {
        IpSetCreateOptions {
            hashsize: None,
            maxelem: None,
            timeout: None,
            set_type: IpSetType::HashIp,
            family: IpSetFamily::Inet,
        }
    }
}
pub fn ipset_create(setname: &str, options: &IpSetCreateOptions) -> Result<()> {
if setname.is_empty() || setname.len() >= IPSET_MAXNAMELEN {
return Err(IpSetError::InvalidSetName(setname.to_string()));
}
let mut buf = MsgBuffer::new(BUFF_SZ);
buf.put_nlmsghdr(
ipset_msg_type(IPSET_CMD_CREATE),
NLM_F_REQUEST | NLM_F_ACK,
0,
);
buf.put_nfgenmsg(options.family.as_u8(), 0, 0);
buf.put_attr_u8(IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
buf.put_attr_str(IPSET_ATTR_SETNAME, setname);
buf.put_attr_str(IPSET_ATTR_TYPENAME, options.set_type.as_str());
buf.put_attr_u8(IPSET_ATTR_REVISION, options.set_type.revision());
buf.put_attr_u8(IPSET_ATTR_FAMILY, options.family.as_u8());
let data_offset = buf.start_nested(IPSET_ATTR_DATA);
if let Some(hashsize) = options.hashsize {
buf.put_attr_u32(IPSET_ATTR_HASHSIZE, hashsize);
}
if let Some(maxelem) = options.maxelem {
buf.put_attr_u32(IPSET_ATTR_MAXELEM, maxelem);
}
if let Some(timeout) = options.timeout {
buf.put_attr_u32_be(IPSET_ATTR_TIMEOUT, timeout);
}
buf.end_nested(data_offset);
buf.finalize_nlmsg();
let socket = NetlinkSocket::new()?;
let mut recv_buf = [0u8; BUFF_SZ];
let recv_len = socket.send_recv(buf.as_slice(), &mut recv_buf)?;
if recv_len < NlMsgHdr::SIZE {
return Err(IpSetError::ProtocolError);
}
if let Some(error) = parse_nlmsg_error(&recv_buf[..recv_len]) {
if error == 0 {
return Ok(());
}
match -error {
libc::EEXIST => return Err(IpSetError::ElementExists),
_ => return Err(IpSetError::NetlinkError(-error)),
}
}
Err(IpSetError::ProtocolError)
}
pub fn ipset_destroy(setname: &str) -> Result<()> {
if setname.is_empty() || setname.len() >= IPSET_MAXNAMELEN {
return Err(IpSetError::InvalidSetName(setname.to_string()));
}
let mut buf = MsgBuffer::new(BUFF_SZ);
buf.put_nlmsghdr(
ipset_msg_type(IPSET_CMD_DESTROY),
NLM_F_REQUEST | NLM_F_ACK,
0,
);
buf.put_nfgenmsg(libc::AF_INET as u8, 0, 0);
buf.put_attr_u8(IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
buf.put_attr_str(IPSET_ATTR_SETNAME, setname);
buf.finalize_nlmsg();
let socket = NetlinkSocket::new()?;
let mut recv_buf = [0u8; BUFF_SZ];
let recv_len = socket.send_recv(buf.as_slice(), &mut recv_buf)?;
if recv_len < NlMsgHdr::SIZE {
return Err(IpSetError::ProtocolError);
}
if let Some(error) = parse_nlmsg_error(&recv_buf[..recv_len]) {
if error == 0 {
return Ok(());
}
match -error {
libc::ENOENT => return Err(IpSetError::SetNotFound(setname.to_string())),
libc::EBUSY => return Err(IpSetError::NetlinkError(-error)), _ => return Err(IpSetError::NetlinkError(-error)),
}
}
Err(IpSetError::ProtocolError)
}
pub fn ipset_flush(setname: &str) -> Result<()> {
if setname.is_empty() || setname.len() >= IPSET_MAXNAMELEN {
return Err(IpSetError::InvalidSetName(setname.to_string()));
}
let mut buf = MsgBuffer::new(BUFF_SZ);
buf.put_nlmsghdr(
ipset_msg_type(IPSET_CMD_FLUSH),
NLM_F_REQUEST | NLM_F_ACK,
0,
);
buf.put_nfgenmsg(libc::AF_INET as u8, 0, 0);
buf.put_attr_u8(IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
buf.put_attr_str(IPSET_ATTR_SETNAME, setname);
buf.finalize_nlmsg();
let socket = NetlinkSocket::new()?;
let mut recv_buf = [0u8; BUFF_SZ];
let recv_len = socket.send_recv(buf.as_slice(), &mut recv_buf)?;
if recv_len < NlMsgHdr::SIZE {
return Err(IpSetError::ProtocolError);
}
if let Some(error) = parse_nlmsg_error(&recv_buf[..recv_len]) {
if error == 0 {
return Ok(());
}
match -error {
libc::ENOENT => return Err(IpSetError::SetNotFound(setname.to_string())),
_ => return Err(IpSetError::NetlinkError(-error)),
}
}
Err(IpSetError::ProtocolError)
}
/// Adds `entry` (an address, optionally with a timeout) to the set `setname`.
///
/// # Errors
///
/// Returns the errors reported by `ipset_operate` for `IPSET_CMD_ADD`, unchanged.
pub fn ipset_add<E: Into<IpEntry>>(setname: &str, entry: E) -> Result<()> {
    let entry: IpEntry = entry.into();
    ipset_operate(setname, &entry, IPSET_CMD_ADD)
}
/// Removes `entry` from the set `setname`.
///
/// # Errors
///
/// Returns the errors reported by `ipset_operate` for `IPSET_CMD_DEL`, unchanged.
pub fn ipset_del<E: Into<IpEntry>>(setname: &str, entry: E) -> Result<()> {
    let entry: IpEntry = entry.into();
    ipset_operate(setname, &entry, IPSET_CMD_DEL)
}
/// Reports whether `entry` is a member of the set `setname`.
///
/// Returns `Ok(true)` when the kernel acknowledges the test, and `Ok(false)`
/// when `ipset_operate` reports [`IpSetError::ElementNotFound`].
///
/// # Errors
///
/// Every other error from `ipset_operate` is propagated unchanged.
pub fn ipset_test<E: Into<IpEntry>>(setname: &str, entry: E) -> Result<bool> {
    let entry: IpEntry = entry.into();
    let outcome = ipset_operate(setname, &entry, IPSET_CMD_TEST);
    if matches!(outcome, Err(IpSetError::ElementNotFound)) {
        return Ok(false);
    }
    outcome.map(|()| true)
}
/// Lists the IP addresses stored in the set `setname`.
///
/// Sends an `IPSET_CMD_LIST` dump request and reads replies until an
/// `NLMSG_DONE` message arrives. It collects every address found in the
/// `IPSET_ATTR_ADT` containers of the replies, in the order received. If a
/// receive returns fewer bytes than a netlink header, reading stops and the
/// addresses gathered so far are returned.
///
/// # Errors
///
/// * [`IpSetError::InvalidSetName`] if `setname` is empty or not shorter than
///   `IPSET_MAXNAMELEN` bytes.
/// * [`IpSetError::SetNotFound`] if the kernel answers `ENOENT`.
/// * [`IpSetError::NetlinkError`] with the positive errno for any other failure.
/// * [`IpSetError::ProtocolError`] if a reply declares a message length
///   shorter than the netlink header.
/// * Any error returned by the socket layer.
pub fn ipset_list(setname: &str) -> Result<Vec<IpAddr>> {
    if setname.is_empty() || setname.len() >= IPSET_MAXNAMELEN {
        return Err(IpSetError::InvalidSetName(setname.to_string()));
    }
    let mut buf = MsgBuffer::new(BUFF_SZ);
    buf.put_nlmsghdr(
        ipset_msg_type(IPSET_CMD_LIST),
        NLM_F_REQUEST | NLM_F_ACK | NLM_F_DUMP,
        0,
    );
    buf.put_nfgenmsg(libc::AF_INET as u8, 0, 0);
    buf.put_attr_u8(IPSET_ATTR_PROTOCOL, IPSET_PROTOCOL);
    buf.put_attr_str(IPSET_ATTR_SETNAME, setname);
    buf.finalize_nlmsg();

    let socket = NetlinkSocket::new()?;
    socket.send(buf.as_slice())?;
    let mut result = Vec::new();
    let mut recv_buf = [0u8; 8192];
    loop {
        let recv_len = socket.recv(&mut recv_buf)?;
        if recv_len < NlMsgHdr::SIZE {
            break;
        }
        let mut offset = 0;
        while offset + NlMsgHdr::SIZE <= recv_len {
            // SAFETY: the loop condition guarantees `NlMsgHdr::SIZE` readable
            // bytes at `offset` (assuming `NlMsgHdr::SIZE` equals
            // `size_of::<NlMsgHdr>()` and the header is plain integer fields),
            // and `read_unaligned` imposes no alignment requirement.
            let hdr: NlMsgHdr = unsafe {
                std::ptr::read_unaligned(recv_buf[offset..].as_ptr() as *const NlMsgHdr)
            };
            let msg_len = hdr.nlmsg_len as usize;
            // A length shorter than the header would never advance `offset`
            // (looping forever), or would yield an inverted slice below.
            if msg_len < NlMsgHdr::SIZE {
                return Err(IpSetError::ProtocolError);
            }
            if msg_len > recv_len - offset {
                break;
            }
            if is_nlmsg_done(&recv_buf[offset..]) {
                return Ok(result);
            }
            let msg = &recv_buf[offset..offset + msg_len];
            match parse_nlmsg_error(msg) {
                // A zero error is a plain ack; keep reading.
                Some(0) => {}
                Some(error) => {
                    return Err(match -error {
                        libc::ENOENT => IpSetError::SetNotFound(setname.to_string()),
                        errno => IpSetError::NetlinkError(errno),
                    });
                }
                None => {
                    // A message too short to hold the nfgenmsg header carries no attributes.
                    if let Some(attrs) = msg.get(NlMsgHdr::SIZE + NfGenMsg::SIZE..) {
                        parse_ipset_list_attrs(attrs, &mut result);
                    }
                }
            }
            offset += nla_align(msg_len);
        }
    }
    Ok(result)
}
/// Iterates over the netlink attributes packed back to back in `data`,
/// yielding each one's raw type (flag bits included) and its payload.
///
/// Stops at the first attribute whose header is truncated, whose declared
/// length is shorter than the header, or whose declared length runs past the
/// end of `data`. Consecutive attributes are spaced by `nla_align(len)`.
fn nla_attrs(data: &[u8]) -> impl Iterator<Item = (u16, &[u8])> + '_ {
    let mut offset = 0;
    std::iter::from_fn(move || {
        if offset + NlAttr::SIZE > data.len() {
            return None;
        }
        let attr_len = usize::from(u16::from_ne_bytes([data[offset], data[offset + 1]]));
        let attr_type = u16::from_ne_bytes([data[offset + 2], data[offset + 3]]);
        if attr_len < NlAttr::SIZE || offset + attr_len > data.len() {
            return None;
        }
        let payload = &data[offset + NlAttr::SIZE..offset + attr_len];
        offset += nla_align(attr_len);
        Some((attr_type, payload))
    })
}

/// Returns `true` if `attr_type` has `NLA_F_NESTED` set and, with that flag
/// cleared, equals `expected`.
fn is_nested_attr(attr_type: u16, expected: u16) -> bool {
    (attr_type & NLA_F_NESTED) != 0 && (attr_type & !NLA_F_NESTED) == expected
}

/// Collects the addresses carried by one `IPSET_CMD_LIST` reply.
///
/// `data` is the attribute stream that follows the netlink and nfgenmsg
/// headers. Every nested `IPSET_ATTR_ADT` container in it is scanned for
/// element entries, and each decoded address is appended to `result` in
/// message order. Malformed attributes silently end the scan of their level.
fn parse_ipset_list_attrs(data: &[u8], result: &mut Vec<IpAddr>) {
    for (attr_type, payload) in nla_attrs(data) {
        if is_nested_attr(attr_type, IPSET_ATTR_ADT) {
            parse_ipset_adt_attrs(payload, result);
        }
    }
}

/// Walks the entries of an `IPSET_ATTR_ADT` container. Every nested
/// attribute is treated as an element entry, whatever its type.
fn parse_ipset_adt_attrs(data: &[u8], result: &mut Vec<IpAddr>) {
    for (attr_type, payload) in nla_attrs(data) {
        if (attr_type & NLA_F_NESTED) != 0 {
            parse_ipset_data_attrs(payload, result);
        }
    }
}

/// Appends the address of each nested `IPSET_ATTR_IP` attribute found in one
/// element entry. Entries whose address cannot be decoded are skipped.
fn parse_ipset_data_attrs(data: &[u8], result: &mut Vec<IpAddr>) {
    result.extend(
        nla_attrs(data)
            .filter(|&(attr_type, _)| is_nested_attr(attr_type, IPSET_ATTR_IP))
            .filter_map(|(_, payload)| parse_ipset_ip_attr(payload)),
    );
}

/// Decodes the first attribute inside an `IPSET_ATTR_IP` nest as an address.
///
/// The nested and net-byte-order flag bits are ignored when matching the type.
/// A declared length that overruns `data` is clamped to the bytes present.
/// Returns `None` in these cases:
/// * the header is truncated;
/// * the declared length is shorter than the header;
/// * the type is neither IPv4 nor IPv6;
/// * the payload is too short for the address type.
fn parse_ipset_ip_attr(data: &[u8]) -> Option<IpAddr> {
    if data.len() < NlAttr::SIZE {
        return None;
    }
    let attr_len = usize::from(u16::from_ne_bytes([data[0], data[1]]));
    let attr_type = u16::from_ne_bytes([data[2], data[3]])
        & !NLA_F_NESTED
        & !crate::netlink::NLA_F_NET_BYTEORDER;
    if attr_len < NlAttr::SIZE {
        return None;
    }
    let payload = &data[NlAttr::SIZE..attr_len.min(data.len())];
    match attr_type {
        IPSET_ATTR_IPADDR_IPV4 => {
            let octets: [u8; 4] = payload.get(..4)?.try_into().ok()?;
            Some(IpAddr::from(octets))
        }
        IPSET_ATTR_IPADDR_IPV6 => {
            let octets: [u8; 16] = payload.get(..16)?.try_into().ok()?;
            Some(IpAddr::from(octets))
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes one netlink attribute (length, type, payload), zero-padded to
    /// the alignment `nla_align` applies.
    fn nla(attr_type: u16, payload: &[u8]) -> Vec<u8> {
        let len = u16::try_from(NlAttr::SIZE + payload.len()).expect("test attribute fits in u16");
        let mut out = Vec::new();
        out.extend_from_slice(&len.to_ne_bytes());
        out.extend_from_slice(&attr_type.to_ne_bytes());
        out.extend_from_slice(payload);
        out.resize(nla_align(usize::from(len)), 0);
        out
    }

    /// Encodes one list element: `IPSET_ATTR_DATA { IPSET_ATTR_IP { <address> } }`.
    fn list_entry(addr_type: u16, octets: &[u8]) -> Vec<u8> {
        let addr = nla(addr_type | crate::netlink::NLA_F_NET_BYTEORDER, octets);
        let ip = nla(IPSET_ATTR_IP | NLA_F_NESTED, &addr);
        nla(IPSET_ATTR_DATA | NLA_F_NESTED, &ip)
    }

    #[test]
    fn test_ipset_msg_type() {
        assert_eq!(ipset_msg_type(IPSET_CMD_ADD), (6 << 8) | 9);
        assert_eq!(ipset_msg_type(IPSET_CMD_DEL), (6 << 8) | 10);
        assert_eq!(ipset_msg_type(IPSET_CMD_TEST), (6 << 8) | 11);
    }

    #[test]
    fn test_invalid_setname() {
        let addr: IpAddr = "192.168.1.1".parse().unwrap();
        assert!(matches!(
            ipset_add("", addr),
            Err(IpSetError::InvalidSetName(_))
        ));
        let long_name = "a".repeat(IPSET_MAXNAMELEN);
        assert!(matches!(
            ipset_add(&long_name, addr),
            Err(IpSetError::InvalidSetName(_))
        ));
    }

    #[test]
    fn test_parse_ipset_ip_attr() {
        let v4 = nla(
            IPSET_ATTR_IPADDR_IPV4 | crate::netlink::NLA_F_NET_BYTEORDER,
            &[10, 0, 0, 1],
        );
        assert_eq!(
            parse_ipset_ip_attr(&v4),
            Some("10.0.0.1".parse::<IpAddr>().unwrap())
        );
        let v6_addr: std::net::Ipv6Addr = "2001:db8::1".parse().unwrap();
        let v6 = nla(
            IPSET_ATTR_IPADDR_IPV6 | crate::netlink::NLA_F_NET_BYTEORDER,
            &v6_addr.octets(),
        );
        assert_eq!(parse_ipset_ip_attr(&v6), Some(IpAddr::V6(v6_addr)));
    }

    #[test]
    fn test_parse_ipset_ip_attr_rejects_malformed() {
        assert_eq!(parse_ipset_ip_attr(&[]), None);
        // Payload too short for an IPv4 address.
        assert_eq!(parse_ipset_ip_attr(&nla(IPSET_ATTR_IPADDR_IPV4, &[10, 0, 0])), None);
        // Unknown address type.
        assert_eq!(parse_ipset_ip_attr(&nla(7, &[10, 0, 0, 1])), None);
        // Declared length shorter than the attribute header.
        let mut short = 2u16.to_ne_bytes().to_vec();
        short.extend_from_slice(&IPSET_ATTR_IPADDR_IPV4.to_ne_bytes());
        assert_eq!(parse_ipset_ip_attr(&short), None);
    }

    #[test]
    fn test_parse_ipset_list_attrs_collects_entries() {
        let v6: std::net::Ipv6Addr = "2001:db8::1".parse().unwrap();
        let mut adt = list_entry(IPSET_ATTR_IPADDR_IPV4, &[192, 168, 1, 1]);
        adt.extend(list_entry(IPSET_ATTR_IPADDR_IPV6, &v6.octets()));
        // A non-ADT attribute ahead of the container must be skipped.
        let mut attrs = nla(IPSET_ATTR_SETNAME, b"test\0");
        attrs.extend(nla(IPSET_ATTR_ADT | NLA_F_NESTED, &adt));

        let mut result = Vec::new();
        parse_ipset_list_attrs(&attrs, &mut result);
        assert_eq!(
            result,
            vec!["192.168.1.1".parse::<IpAddr>().unwrap(), IpAddr::V6(v6)]
        );
    }

    #[test]
    fn test_parse_ipset_list_attrs_ignores_truncated_input() {
        let full = nla(
            IPSET_ATTR_ADT | NLA_F_NESTED,
            &list_entry(IPSET_ATTR_IPADDR_IPV4, &[10, 0, 0, 1]),
        );
        let mut result = Vec::new();
        parse_ipset_list_attrs(&full[..full.len() - 1], &mut result);
        assert!(result.is_empty());
    }

    #[test]
    #[ignore]
    fn test_ipset_add_ipv4() {
        let addr: IpAddr = "10.0.0.1".parse().unwrap();
        ipset_add("test_set", addr).expect("Failed to add IP to ipset");
    }

    #[test]
    #[ignore]
    fn test_ipset_test_ipv4() {
        let addr: IpAddr = "10.0.0.1".parse().unwrap();
        let exists = ipset_test("test_set", addr).expect("Failed to test IP in ipset");
        println!("IP exists in set: {}", exists);
    }

    #[test]
    #[ignore]
    fn test_ipset_del_ipv4() {
        let addr: IpAddr = "10.0.0.1".parse().unwrap();
        ipset_del("test_set", addr).expect("Failed to delete IP from ipset");
    }

    #[test]
    #[ignore]
    fn test_ipset_add_ipv6() {
        let addr: IpAddr = "2001:db8::1".parse().unwrap();
        ipset_add("test_set6", addr).expect("Failed to add IPv6 to ipset");
    }

    #[test]
    #[ignore]
    fn test_ipset_with_timeout() {
        let addr: IpAddr = "10.0.0.2".parse().unwrap();
        let entry = IpEntry::with_timeout(addr, 60);
        ipset_add("test_set_timeout", entry).expect("Failed to add IP with timeout");
    }
}