wgctrl 0.1.0

wgctrl is a crate that enables control over WireGuard interfaces.
Documentation
use anyhow::Result;
use rsln::{
    core::message::Message,
    handle::{handle::SocketHandle, zero_terminated},
    netlink::Netlink,
    types::{
        link::Link,
        message::{Attribute, GenlMessage, RouteAttr},
    },
};

use crate::types::Device;
use crate::constants::WgCmd;

/// Handle for querying and configuring WireGuard interfaces over generic netlink.
///
/// Created with [`Client::new`], which resolves the WireGuard generic netlink
/// family once so later requests can address it by numeric id.
pub struct Client {
    /// Numeric id of the WireGuard generic netlink family, resolved at construction.
    family_id: u16,
    /// Underlying netlink connection, including its cache of per-protocol sockets.
    netlink: Netlink,
}

impl Client {
    pub fn new() -> Result<Self> {
        let mut netlink = Netlink::new();

        let family = netlink.genl_family_get(crate::constants::WG_GENL_NAME)?;

        Ok(Self {
            netlink,
            family_id: family.id,
        })
    }

    pub fn get_device(&mut self, name: &str) -> Result<Device> {
        if name.is_empty() {
            return Err(anyhow::anyhow!("Device name cannot be empty"));
        }

        let mut handle = self
            .netlink
            .sockets
            .entry(libc::NETLINK_GENERIC)
            .or_insert(SocketHandle::new(libc::NETLINK_GENERIC))
            .handle_generic();

        let mut req = Message::new(self.family_id, libc::NLM_F_REQUEST | libc::NLM_F_DUMP);

        let genl_hdr = GenlMessage {
            command: WgCmd::GetDevice as u8,
            version: 1,
            reserved: 0,
        };
        req.add(&genl_hdr.serialize()?);

        let name_attr = RouteAttr::new(2, &zero_terminated(name));
        req.add(&name_attr.serialize()?);

        let resp = handle.request(&mut req, 0)?;

        if resp.is_empty() {
            return Err(anyhow::anyhow!("Device not found"));
        }

        let mut device: Option<Device> = None;
        let mut known_peers: std::collections::HashMap<crate::types::Key, usize> =
            std::collections::HashMap::new();

        for payload in resp {
            let partial = Device::try_from(payload.as_slice())?;

            if let Some(dev) = &mut device {
                for peer in partial.peers {
                    if let Some(&idx) = known_peers.get(&peer.public_key) {
                        dev.peers[idx].allowed_ips.extend(peer.allowed_ips);
                    } else {
                        known_peers.insert(peer.public_key, dev.peers.len());
                        dev.peers.push(peer);
                    }
                }
            } else {
                for (i, peer) in partial.peers.iter().enumerate() {
                    known_peers.insert(peer.public_key, i);
                }
                device = Some(partial);
            }
        }

        device.ok_or_else(|| anyhow::anyhow!("Device not found"))
    }

    pub fn list_devices(&mut self) -> Result<Vec<Device>> {
        let links = self.netlink.link_list()?;
        let mut devices = Vec::new();

        for link in links {
            if link.link_type() == "wireguard" {
                let name = &link.attrs().name;
                let device = self.get_device(name)?;
                devices.push(device);
            }
        }

        Ok(devices)
    }

    pub fn configure_device(&mut self, name: &str, config: &crate::types::Config) -> Result<()> {
        use crate::config::{build_batches, config_attrs};

        let batches = build_batches(config);

        for batch in batches {
            let attrs = config_attrs(name, &batch)?;

            let mut handle = self
                .netlink
                .sockets
                .entry(libc::NETLINK_GENERIC)
                .or_insert(SocketHandle::new(libc::NETLINK_GENERIC))
                .handle_generic();

            let flags = libc::NLM_F_REQUEST | libc::NLM_F_ACK;
            let mut req = Message::new(self.family_id, flags);

            let genl_hdr = GenlMessage {
                command: WgCmd::SetDevice as u8,
                version: 1,
                reserved: 0,
            };
            req.add(&genl_hdr.serialize()?);
            req.add(&attrs);

            handle.request(&mut req, 0)?;
        }

        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Config, Key, PeerConfig};
    use ipnet::IpNet;
    use std::process::Command;
    use std::time::Duration;

    /// Deletes the named link when dropped, so the test interface is removed even
    /// if an assertion fails partway through the test.
    struct LinkGuard<'a>(&'a str);

    impl Drop for LinkGuard<'_> {
        fn drop(&mut self) {
            // Best-effort cleanup: a failure here must not mask the test's own outcome.
            let _ = Command::new("ip").args(["link", "del", self.0]).output();
        }
    }

    /// End-to-end check that `configure_device` applies a full configuration and
    /// that `get_device` reads it back unchanged.
    ///
    /// Needs root and the `ip` tool with WireGuard support. Marked `#[ignore]`; run
    /// with `cargo test -- --ignored`.
    #[test]
    #[ignore]
    fn test_configure_device() {
        // SAFETY: `geteuid` takes no arguments, has no preconditions and cannot fail.
        if unsafe { libc::geteuid() } != 0 {
            eprintln!("SKIPPING: Root privileges required for integration test");
            return;
        }

        let ifname = "wg_test_cfg";

        // Remove any leftover interface from a previous, aborted run.
        let _ = Command::new("ip").args(["link", "del", ifname]).output();

        let created = Command::new("ip")
            .args(["link", "add", ifname, "type", "wireguard"])
            .status()
            .map(|status| status.success())
            .unwrap_or(false);
        if !created {
            eprintln!("SKIPPING: Could not create wireguard interface (kernel module missing?)");
            return;
        }
        let _guard = LinkGuard(ifname);

        let mut client = Client::new().expect("Failed to create Netlink Client");

        // The freshly created interface must already be visible.
        let initial_dev = client
            .get_device(ifname)
            .expect("Failed to get device initially");
        assert_eq!(initial_dev.name, ifname);

        let private_key = Key::generate_private_key().expect("Failed to generate private key");
        let peer_key = Key::generate_private_key().expect("Failed to generate peer key");
        let peer_pub = peer_key.public_key();
        let psk = Key::generate_key().expect("Failed to generate PSK");

        let peer = PeerConfig {
            public_key: peer_pub,
            remove: false,
            update_only: false,
            preshared_key: Some(psk),
            endpoint: Some("127.0.0.1:51820".parse().unwrap()),
            persistent_keepalive_interval: Some(Duration::from_secs(25)),
            replace_allowed_ips: true,
            allowed_ips: vec!["10.0.0.2/32".parse::<IpNet>().unwrap()],
        };

        let config = Config {
            private_key: Some(private_key),
            listen_port: Some(51821),
            firewall_mark: Some(1234),
            replace_peers: true,
            peers: vec![peer],
        };

        client
            .configure_device(ifname, &config)
            .expect("Failed to configure device");

        let device = client
            .get_device(ifname)
            .expect("Failed to get device after config");

        assert_eq!(device.name, ifname, "Device name mismatch");
        assert_eq!(device.private_key, private_key, "Private key mismatch");
        assert_eq!(device.listen_port, 51821, "Listen port mismatch");
        assert_eq!(device.firewall_mark, 1234, "Firewall mark mismatch");
        assert_eq!(device.peers.len(), 1, "Expected 1 peer");

        let d_peer = &device.peers[0];
        assert_eq!(d_peer.public_key, peer_pub, "Peer public key mismatch");
        assert_eq!(d_peer.preshared_key, Some(psk), "PSK mismatch");
        assert_eq!(
            d_peer.persistent_keepalive_interval,
            Some(Duration::from_secs(25)),
            "Keepalive mismatch"
        );

        assert_eq!(d_peer.allowed_ips.len(), 1, "Expected 1 allowed IP");
        assert_eq!(
            d_peer.allowed_ips[0].to_string(),
            "10.0.0.2/32",
            "Allowed IP mismatch"
        );
    }
}