use std::collections::HashSet;
use std::mem::size_of;
use std::net::SocketAddr;
use anyhow::Result;
use ipnet::IpNet;
use rsln::types::message::{Attribute, RouteAttr};
use crate::constants::*;
use crate::types::{Config, PeerConfig};
/// Maximum number of allowed IPs placed into a single batch by `build_batches`;
/// `should_batch` requests batching once the total allowed-IP count exceeds it.
const IP_BATCH_CHUNK: usize = 256;
/// Maximum number of peer entries placed into a single batch by `build_batches`;
/// `should_batch` requests batching once the peer count exceeds it.
const PEER_BATCH_CHUNK: usize = 32;
/// Builds the serialized netlink attribute payload for device `name` with the
/// settings in `cfg`.
///
/// Attributes are emitted in this order: the NUL-terminated interface name;
/// then, when set in `cfg`, the private key, listen port, firewall mark and the
/// replace-peers flag; finally, when `cfg.peers` is non-empty, a nested peer
/// list containing one nested entry per peer, typed by its position.
///
/// # Errors
/// Propagates any error from attribute serialization or from `encode_peer`.
pub fn config_attrs(name: &str, cfg: &Config) -> Result<Vec<u8>> {
    /// Serializes one attribute of type `kind` and appends it to `buf`.
    fn append(buf: &mut Vec<u8>, kind: u16, payload: &[u8]) -> Result<()> {
        buf.extend(RouteAttr::new(kind, payload).serialize()?);
        Ok(())
    }

    let mut out = Vec::new();
    append(&mut out, WgDeviceAttr::IfName as u16, &zero_terminated(name))?;

    if let Some(key) = &cfg.private_key {
        append(&mut out, WgDeviceAttr::PrivateKey as u16, key.as_ref())?;
    }
    if let Some(port) = cfg.listen_port {
        append(&mut out, WgDeviceAttr::ListenPort as u16, &port.to_ne_bytes())?;
    }
    if let Some(fwmark) = cfg.firewall_mark {
        append(&mut out, WgDeviceAttr::Fwmark as u16, &(fwmark as u32).to_ne_bytes())?;
    }
    if cfg.replace_peers {
        append(
            &mut out,
            WgDeviceAttr::Flags as u16,
            &WGDEVICE_F_REPLACE_PEERS.to_ne_bytes(),
        )?;
    }

    if !cfg.peers.is_empty() {
        // Each peer becomes a nested attribute whose type is its list index.
        let mut nested = Vec::new();
        for (idx, peer) in cfg.peers.iter().enumerate() {
            let body = encode_peer(peer)?;
            append(&mut nested, idx as u16 | NLA_F_NESTED, &body)?;
        }
        append(&mut out, WgDeviceAttr::Peers as u16 | NLA_F_NESTED, &nested)?;
    }

    Ok(out)
}
fn encode_peer(p: &PeerConfig) -> Result<Vec<u8>> {
let mut attrs = Vec::new();
let pk_attr = RouteAttr::new(WgPeerAttr::PublicKey as u16, p.public_key.as_ref());
attrs.extend(pk_attr.serialize()?);
let mut flags = 0u32;
if p.remove {
flags |= WGPEER_F_REMOVE_ME;
}
if p.replace_allowed_ips {
flags |= WGPEER_F_REPLACE_ALLOWEDIPS;
}
if p.update_only {
flags |= WGPEER_F_UPDATE_ONLY;
}
if flags != 0 {
let attr = RouteAttr::new(WgPeerAttr::Flags as u16, &flags.to_ne_bytes());
attrs.extend(attr.serialize()?);
}
if let Some(psk) = &p.preshared_key {
let attr = RouteAttr::new(WgPeerAttr::PresharedKey as u16, psk.as_ref());
attrs.extend(attr.serialize()?);
}
if let Some(endpoint) = &p.endpoint {
let endpoint_bytes = encode_sockaddr(endpoint)?;
let attr = RouteAttr::new(WgPeerAttr::Endpoint as u16, &endpoint_bytes);
attrs.extend(attr.serialize()?);
}
if let Some(interval) = p.persistent_keepalive_interval {
let secs = interval.as_secs() as u16;
let attr = RouteAttr::new(WgPeerAttr::PersistentKeepalive as u16, &secs.to_ne_bytes());
attrs.extend(attr.serialize()?);
}
if !p.allowed_ips.is_empty() {
let mut ips_payload = Vec::new();
for (i, ip) in p.allowed_ips.iter().enumerate() {
let ip_attr_payload = encode_allowed_ip(ip)?;
let ip_attr = RouteAttr::new(i as u16 | NLA_F_NESTED, &ip_attr_payload);
ips_payload.extend(ip_attr.serialize()?);
}
let attr = RouteAttr::new(WgPeerAttr::AllowedIps as u16 | NLA_F_NESTED, &ips_payload);
attrs.extend(attr.serialize()?);
}
Ok(attrs)
}
/// Serializes `endpoint` into the byte layout of a `libc::sockaddr_in` (IPv4)
/// or `libc::sockaddr_in6` (IPv6), as used for the peer endpoint attribute.
///
/// The fields are written one by one, so no `unsafe` reinterpretation of a
/// struct is needed.
/// - IPv4: family (host order), port (network order), address octets,
///   8 zero bytes.
/// - IPv6: family, port, flowinfo (network order), address octets, scope id
///   (host order).
///
/// Neither layout contains padding. A debug assertion checks each length
/// against the libc struct size.
/// NOTE(review): assumes the Linux layout (no BSD-style `sin_len` field).
///
/// Always returns `Ok`; the `Result` is kept for the callers' `?` usage.
fn encode_sockaddr(endpoint: &SocketAddr) -> Result<Vec<u8>> {
    let bytes = match endpoint {
        SocketAddr::V4(addr) => {
            let mut buf = Vec::with_capacity(size_of::<libc::sockaddr_in>());
            buf.extend_from_slice(&(libc::AF_INET as u16).to_ne_bytes()); // sin_family
            buf.extend_from_slice(&addr.port().to_be_bytes()); // sin_port
            buf.extend_from_slice(&addr.ip().octets()); // sin_addr
            buf.extend_from_slice(&[0u8; 8]); // sin_zero
            debug_assert_eq!(buf.len(), size_of::<libc::sockaddr_in>());
            buf
        }
        SocketAddr::V6(addr) => {
            let mut buf = Vec::with_capacity(size_of::<libc::sockaddr_in6>());
            buf.extend_from_slice(&(libc::AF_INET6 as u16).to_ne_bytes()); // sin6_family
            buf.extend_from_slice(&addr.port().to_be_bytes()); // sin6_port
            buf.extend_from_slice(&addr.flowinfo().to_be_bytes()); // sin6_flowinfo
            buf.extend_from_slice(&addr.ip().octets()); // sin6_addr
            buf.extend_from_slice(&addr.scope_id().to_ne_bytes()); // sin6_scope_id
            debug_assert_eq!(buf.len(), size_of::<libc::sockaddr_in6>());
            buf
        }
    };
    Ok(bytes)
}
/// Encodes one allowed-IP network as a nested allowed-IP entry payload.
///
/// Attributes are emitted in this order: the address family (`AF_INET` or
/// `AF_INET6`, host byte order), the raw network address octets, and the
/// prefix length as a single byte.
///
/// # Errors
/// Propagates any attribute serialization error.
fn encode_allowed_ip(ipn: &IpNet) -> Result<Vec<u8>> {
    // Family and address bytes come from the same variant, so pick them together.
    let (family, addr_bytes): (u16, Vec<u8>) = match ipn {
        IpNet::V4(v4) => (libc::AF_INET as u16, v4.addr().octets().to_vec()),
        IpNet::V6(v6) => (libc::AF_INET6 as u16, v6.addr().octets().to_vec()),
    };

    let mut buf = Vec::new();
    buf.extend(RouteAttr::new(WgAllowedIpAttr::Family as u16, &family.to_ne_bytes()).serialize()?);
    buf.extend(RouteAttr::new(WgAllowedIpAttr::IpAddr as u16, &addr_bytes).serialize()?);
    buf.extend(RouteAttr::new(WgAllowedIpAttr::CidrMask as u16, &[ipn.prefix_len()]).serialize()?);
    Ok(buf)
}
/// Returns the UTF-8 bytes of `s` followed by one trailing NUL byte.
fn zero_terminated(s: &str) -> Vec<u8> {
    s.bytes().chain(std::iter::once(0u8)).collect()
}
pub fn build_batches(cfg: &Config) -> Vec<Config> {
if !should_batch(cfg) {
return vec![cfg.clone()];
}
let mut base = cfg.clone();
base.peers.clear();
let mut batches = Vec::new();
let mut current_batch = base.clone();
let mut current_ip_count = 0;
let mut known_peers = HashSet::new();
for p in &cfg.peers {
let mut current_allowed_ips = p.allowed_ips.clone();
if current_allowed_ips.is_empty() {
let mut pcfg = p.clone();
if known_peers.contains(&p.public_key) {
pcfg.preshared_key = None;
pcfg.endpoint = None;
pcfg.persistent_keepalive_interval = None;
pcfg.replace_allowed_ips = false;
} else {
known_peers.insert(p.public_key);
}
if current_batch.peers.len() >= PEER_BATCH_CHUNK {
batches.push(current_batch);
current_batch = base.clone();
current_ip_count = 0;
}
current_batch.peers.push(pcfg);
continue;
}
while !current_allowed_ips.is_empty() {
let space = IP_BATCH_CHUNK.saturating_sub(current_ip_count);
if space == 0 || current_batch.peers.len() >= PEER_BATCH_CHUNK {
batches.push(current_batch);
current_batch = base.clone();
current_ip_count = 0;
}
let space = IP_BATCH_CHUNK.saturating_sub(current_ip_count);
let take = std::cmp::min(current_allowed_ips.len(), space);
let chunk: Vec<IpNet> = current_allowed_ips.drain(0..take).collect();
let mut pcfg = p.clone();
pcfg.allowed_ips = chunk;
if known_peers.contains(&p.public_key) {
pcfg.preshared_key = None;
pcfg.endpoint = None;
pcfg.persistent_keepalive_interval = None;
pcfg.replace_allowed_ips = false;
} else {
known_peers.insert(p.public_key);
}
current_batch.peers.push(pcfg);
current_ip_count += take;
}
}
if !current_batch.peers.is_empty() {
batches.push(current_batch);
}
for (i, batch) in batches.iter_mut().enumerate() {
if i > 0 {
batch.replace_peers = false;
}
}
batches
}
/// Reports whether `cfg` must be split by `build_batches`.
///
/// Returns true when it has more than `PEER_BATCH_CHUNK` peers, or when the
/// allowed IPs across all peers number more than `IP_BATCH_CHUNK`.
fn should_batch(cfg: &Config) -> bool {
    let too_many_peers = cfg.peers.len() > PEER_BATCH_CHUNK;
    too_many_peers
        || cfg
            .peers
            .iter()
            .map(|peer| peer.allowed_ips.len())
            .sum::<usize>()
            > IP_BATCH_CHUNK
}