//! # wgctrl 0.1.0
//!
//! `wgctrl` is a crate that enables control over WireGuard interfaces.
//!
//! ## Documentation
use std::collections::HashSet;
use std::mem::size_of;
use std::net::SocketAddr;

use anyhow::Result;
use ipnet::IpNet;
use rsln::types::message::{Attribute, RouteAttr};

use crate::constants::*;
use crate::types::{Config, PeerConfig};

// IP_BATCH_CHUNK is a tunable allowed-IP batch limit per peer: once a batch
// accumulates this many allowed IPs it is flushed into its own message.
const IP_BATCH_CHUNK: usize = 256;

// PEER_BATCH_CHUNK specifies the number of peers that can appear in a
// configuration before we start splitting it into chunks.
const PEER_BATCH_CHUNK: usize = 32;

/// Serializes a device [`Config`] into a flat netlink attribute payload for
/// the WireGuard interface named `name`.
///
/// Only the fields that are actually set on `cfg` are emitted; peers are
/// nested under a `Peers` attribute whose children are indexed by position.
pub fn config_attrs(name: &str, cfg: &Config) -> Result<Vec<u8>> {
    let mut out = Vec::new();

    // Interface name, NUL-terminated as the kernel expects for C strings.
    out.extend(
        RouteAttr::new(WgDeviceAttr::IfName as u16, zero_terminated(name).as_slice())
            .serialize()?,
    );

    if let Some(key) = &cfg.private_key {
        out.extend(RouteAttr::new(WgDeviceAttr::PrivateKey as u16, key.as_ref()).serialize()?);
    }

    if let Some(port) = cfg.listen_port {
        out.extend(
            RouteAttr::new(WgDeviceAttr::ListenPort as u16, &port.to_ne_bytes()).serialize()?,
        );
    }

    if let Some(fwmark) = cfg.firewall_mark {
        out.extend(
            RouteAttr::new(WgDeviceAttr::Fwmark as u16, &(fwmark as u32).to_ne_bytes())
                .serialize()?,
        );
    }

    // Request that the kernel drop existing peers before applying ours.
    if cfg.replace_peers {
        out.extend(
            RouteAttr::new(
                WgDeviceAttr::Flags as u16,
                &WGDEVICE_F_REPLACE_PEERS.to_ne_bytes(),
            )
            .serialize()?,
        );
    }

    if !cfg.peers.is_empty() {
        let mut nested = Vec::new();
        for (idx, peer) in cfg.peers.iter().enumerate() {
            // Netlink arrays use the attribute type field as the array index.
            let payload = encode_peer(peer)?;
            nested.extend(RouteAttr::new(idx as u16 | NLA_F_NESTED, &payload).serialize()?);
        }

        out.extend(
            RouteAttr::new(WgDeviceAttr::Peers as u16 | NLA_F_NESTED, &nested).serialize()?,
        );
    }

    Ok(out)
}

/// Serializes a single [`PeerConfig`] into the nested attribute payload that
/// goes inside one element of the device's `Peers` array.
fn encode_peer(p: &PeerConfig) -> Result<Vec<u8>> {
    let mut out = Vec::new();

    out.extend(RouteAttr::new(WgPeerAttr::PublicKey as u16, p.public_key.as_ref()).serialize()?);

    // Collapse the boolean peer options into the kernel's single flag word;
    // the attribute is omitted entirely when no flag is set.
    let mut flags = 0u32;
    for &(set, bit) in &[
        (p.remove, WGPEER_F_REMOVE_ME),
        (p.replace_allowed_ips, WGPEER_F_REPLACE_ALLOWEDIPS),
        (p.update_only, WGPEER_F_UPDATE_ONLY),
    ] {
        if set {
            flags |= bit;
        }
    }
    if flags != 0 {
        out.extend(RouteAttr::new(WgPeerAttr::Flags as u16, &flags.to_ne_bytes()).serialize()?);
    }

    if let Some(psk) = &p.preshared_key {
        out.extend(RouteAttr::new(WgPeerAttr::PresharedKey as u16, psk.as_ref()).serialize()?);
    }

    if let Some(endpoint) = &p.endpoint {
        let sockaddr = encode_sockaddr(endpoint)?;
        out.extend(RouteAttr::new(WgPeerAttr::Endpoint as u16, &sockaddr).serialize()?);
    }

    if let Some(interval) = p.persistent_keepalive_interval {
        // The kernel takes the keepalive interval as whole seconds in a u16.
        let secs = interval.as_secs() as u16;
        out.extend(
            RouteAttr::new(WgPeerAttr::PersistentKeepalive as u16, &secs.to_ne_bytes())
                .serialize()?,
        );
    }

    if !p.allowed_ips.is_empty() {
        let mut nested = Vec::new();
        for (idx, net) in p.allowed_ips.iter().enumerate() {
            // As with peers, array elements are typed by their index.
            let payload = encode_allowed_ip(net)?;
            nested.extend(RouteAttr::new(idx as u16 | NLA_F_NESTED, &payload).serialize()?);
        }
        out.extend(
            RouteAttr::new(WgPeerAttr::AllowedIps as u16 | NLA_F_NESTED, &nested).serialize()?,
        );
    }

    Ok(out)
}

/// Encodes a socket address as the raw bytes of a Linux `sockaddr_in` /
/// `sockaddr_in6`, the representation the WireGuard netlink interface
/// expects for a peer endpoint.
///
/// The layout is assembled field by field instead of reinterpreting a libc
/// struct through a raw pointer: the resulting byte stream is identical on
/// Linux (both structs are packed with no padding there), but this version
/// needs no `unsafe` and cannot ever expose uninitialized padding bytes.
fn encode_sockaddr(endpoint: &SocketAddr) -> Result<Vec<u8>> {
    match endpoint {
        SocketAddr::V4(addr) => {
            // struct sockaddr_in: family(u16 NE) | port(u16 BE) | addr(4) | sin_zero(8)
            let mut out = Vec::with_capacity(size_of::<libc::sockaddr_in>());
            out.extend_from_slice(&(libc::AF_INET as u16).to_ne_bytes());
            out.extend_from_slice(&addr.port().to_be_bytes()); // network byte order
            out.extend_from_slice(&addr.ip().octets()); // already network order
            out.extend_from_slice(&[0u8; 8]); // sin_zero padding
            Ok(out)
        }
        SocketAddr::V6(addr) => {
            // struct sockaddr_in6:
            // family(u16 NE) | port(u16 BE) | flowinfo(u32 BE) | addr(16) | scope_id(u32 NE)
            let mut out = Vec::with_capacity(size_of::<libc::sockaddr_in6>());
            out.extend_from_slice(&(libc::AF_INET6 as u16).to_ne_bytes());
            out.extend_from_slice(&addr.port().to_be_bytes());
            out.extend_from_slice(&addr.flowinfo().to_be_bytes());
            out.extend_from_slice(&addr.ip().octets());
            out.extend_from_slice(&addr.scope_id().to_ne_bytes());
            Ok(out)
        }
    }
}

/// Serializes one allowed-IP network as the nested attribute triple the
/// kernel expects: address family, raw address bytes, and prefix length.
fn encode_allowed_ip(ipn: &IpNet) -> Result<Vec<u8>> {
    let mut out = Vec::new();

    // Derive both the family constant and the raw address bytes from the
    // IP version in a single match.
    let (family, addr_bytes): (u16, Vec<u8>) = match ipn {
        IpNet::V4(net) => (libc::AF_INET as u16, net.addr().octets().to_vec()),
        IpNet::V6(net) => (libc::AF_INET6 as u16, net.addr().octets().to_vec()),
    };

    out.extend(RouteAttr::new(WgAllowedIpAttr::Family as u16, &family.to_ne_bytes()).serialize()?);
    out.extend(RouteAttr::new(WgAllowedIpAttr::IpAddr as u16, &addr_bytes).serialize()?);
    // Prefix length fits in a single byte (0..=128).
    out.extend(RouteAttr::new(WgAllowedIpAttr::CidrMask as u16, &[ipn.prefix_len()]).serialize()?);

    Ok(out)
}

/// Returns `s` as a byte vector with a trailing NUL byte, as required for
/// C-string netlink attributes such as the interface name.
fn zero_terminated(s: &str) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(s.len() + 1);
    bytes.extend_from_slice(s.as_bytes());
    bytes.push(0);
    bytes
}

// buildBatches produces a batch of configs from a single config, if needed.
pub fn build_batches(cfg: &Config) -> Vec<Config> {
    if !should_batch(cfg) {
        // We can't clone Config directly if it doesn't derive Clone.
        // Assuming Config implies Clone or we implement manual cloning.
        // For now, let's assume we need to clone manually or verify if Config is Clone.
        // Since Config owns data (Vec, String), it should be Clone.
        // I will add #[derive(Clone)] to Config in types.rs if needed.
        return vec![cfg.clone()];
    }

    let mut base = cfg.clone();
    base.peers.clear();

    let mut batches = Vec::new();
    let mut current_batch = base.clone();
    let mut current_ip_count = 0;
    
    let mut known_peers = HashSet::new();

    for p in &cfg.peers {
        let mut current_allowed_ips = p.allowed_ips.clone();

        // Handle peer with no AllowedIPs
        if current_allowed_ips.is_empty() {
             let mut pcfg = p.clone();
             if known_peers.contains(&p.public_key) {
                 pcfg.preshared_key = None;
                 pcfg.endpoint = None;
                 pcfg.persistent_keepalive_interval = None;
                 pcfg.replace_allowed_ips = false;
             } else {
                 known_peers.insert(p.public_key);
             }
             
             if current_batch.peers.len() >= PEER_BATCH_CHUNK {
                 batches.push(current_batch);
                 current_batch = base.clone();
                 current_ip_count = 0;
             }
             
             current_batch.peers.push(pcfg);
             continue;
        }

        while !current_allowed_ips.is_empty() {
            let space = IP_BATCH_CHUNK.saturating_sub(current_ip_count);
            // If no space or max peers, flush
            if space == 0 || current_batch.peers.len() >= PEER_BATCH_CHUNK {
                 batches.push(current_batch);
                 current_batch = base.clone();
                 current_ip_count = 0;
            }
            
            // Recalculate space after flush
            let space = IP_BATCH_CHUNK.saturating_sub(current_ip_count);
            let take = std::cmp::min(current_allowed_ips.len(), space);
            let chunk: Vec<IpNet> = current_allowed_ips.drain(0..take).collect();

            let mut pcfg = p.clone();
            pcfg.allowed_ips = chunk;

            if known_peers.contains(&p.public_key) {
                 pcfg.preshared_key = None;
                 pcfg.endpoint = None;
                 pcfg.persistent_keepalive_interval = None;
                 pcfg.replace_allowed_ips = false;
            } else {
                 known_peers.insert(p.public_key);
            }

            current_batch.peers.push(pcfg);
            current_ip_count += take;
        }
    }

    if !current_batch.peers.is_empty() {
        batches.push(current_batch);
    }

    // Do not allow peer replacement beyond the first message in a batch
    for (i, batch) in batches.iter_mut().enumerate() {
        if i > 0 {
            batch.replace_peers = false;
        }
    }

    batches
}

// should_batch reports whether cfg must be split across multiple netlink
// messages: either the peer count or the total allowed-IP count exceeds
// its per-message chunk limit.
fn should_batch(cfg: &Config) -> bool {
    if cfg.peers.len() > PEER_BATCH_CHUNK {
        return true;
    }

    let total_ips: usize = cfg.peers.iter().map(|p| p.allowed_ips.len()).sum();
    total_ips > IP_BATCH_CHUNK
}