use anyhow::Result;
use rsln::{
core::message::Message,
handle::{handle::SocketHandle, zero_terminated},
netlink::Netlink,
types::{
link::Link,
message::{Attribute, GenlMessage, RouteAttr},
},
};
use crate::types::Device;
use crate::constants::WgCmd;
pub struct Client {
netlink: Netlink,
family_id: u16,
}
impl Client {
pub fn new() -> Result<Self> {
let mut netlink = Netlink::new();
let family = netlink.genl_family_get(crate::constants::WG_GENL_NAME)?;
Ok(Self {
netlink,
family_id: family.id,
})
}
pub fn get_device(&mut self, name: &str) -> Result<Device> {
if name.is_empty() {
return Err(anyhow::anyhow!("Device name cannot be empty"));
}
let mut handle = self
.netlink
.sockets
.entry(libc::NETLINK_GENERIC)
.or_insert(SocketHandle::new(libc::NETLINK_GENERIC))
.handle_generic();
let mut req = Message::new(self.family_id, libc::NLM_F_REQUEST | libc::NLM_F_DUMP);
let genl_hdr = GenlMessage {
command: WgCmd::GetDevice as u8,
version: 1,
reserved: 0,
};
req.add(&genl_hdr.serialize()?);
let name_attr = RouteAttr::new(2, &zero_terminated(name));
req.add(&name_attr.serialize()?);
let resp = handle.request(&mut req, 0)?;
if resp.is_empty() {
return Err(anyhow::anyhow!("Device not found"));
}
let mut device: Option<Device> = None;
let mut known_peers: std::collections::HashMap<crate::types::Key, usize> =
std::collections::HashMap::new();
for payload in resp {
let partial = Device::try_from(payload.as_slice())?;
if let Some(dev) = &mut device {
for peer in partial.peers {
if let Some(&idx) = known_peers.get(&peer.public_key) {
dev.peers[idx].allowed_ips.extend(peer.allowed_ips);
} else {
known_peers.insert(peer.public_key, dev.peers.len());
dev.peers.push(peer);
}
}
} else {
for (i, peer) in partial.peers.iter().enumerate() {
known_peers.insert(peer.public_key, i);
}
device = Some(partial);
}
}
device.ok_or_else(|| anyhow::anyhow!("Device not found"))
}
pub fn list_devices(&mut self) -> Result<Vec<Device>> {
let links = self.netlink.link_list()?;
let mut devices = Vec::new();
for link in links {
if link.link_type() == "wireguard" {
let name = &link.attrs().name;
let device = self.get_device(name)?;
devices.push(device);
}
}
Ok(devices)
}
pub fn configure_device(&mut self, name: &str, config: &crate::types::Config) -> Result<()> {
use crate::config::{build_batches, config_attrs};
let batches = build_batches(config);
for batch in batches {
let attrs = config_attrs(name, &batch)?;
let mut handle = self
.netlink
.sockets
.entry(libc::NETLINK_GENERIC)
.or_insert(SocketHandle::new(libc::NETLINK_GENERIC))
.handle_generic();
let flags = libc::NLM_F_REQUEST | libc::NLM_F_ACK;
let mut req = Message::new(self.family_id, flags);
let genl_hdr = GenlMessage {
command: WgCmd::SetDevice as u8,
version: 1,
reserved: 0,
};
req.add(&genl_hdr.serialize()?);
req.add(&attrs);
handle.request(&mut req, 0)?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Config, Key, PeerConfig};
    use ipnet::IpNet;
    use std::process::Command;
    use std::str::FromStr;
    use std::time::Duration;

    /// Deletes the named network interface when dropped, so the test cleans
    /// up after itself even when an assertion panics.
    struct LinkGuard<'a>(&'a str);

    impl Drop for LinkGuard<'_> {
        fn drop(&mut self) {
            // Best effort: the interface may already be gone.
            let _ = Command::new("ip").args(["link", "del", self.0]).output();
        }
    }

    /// Round-trips a full device configuration through the kernel and reads it back.
    ///
    /// Needs root and a kernel that can create `wireguard` links; run with
    /// `cargo test -- --ignored`. Returns early (passing) when either is missing.
    #[test]
    #[ignore]
    fn test_configure_device() {
        // SAFETY: geteuid() takes no arguments, touches no caller memory and
        // always succeeds (POSIX).
        if unsafe { libc::geteuid() } != 0 {
            eprintln!("SKIPPING: Root privileges required for integration test");
            return;
        }

        let ifname = "wg_test_cfg";
        // Remove a leftover interface from an earlier aborted run; failure here
        // just means there was none.
        let _ = Command::new("ip").args(["link", "del", ifname]).output();
        let created = Command::new("ip")
            .args(["link", "add", ifname, "type", "wireguard"])
            .status();
        if !matches!(created, Ok(status) if status.success()) {
            eprintln!("SKIPPING: Could not create wireguard interface (kernel module missing?)");
            return;
        }
        let _guard = LinkGuard(ifname);

        let mut client = Client::new().expect("Failed to create Netlink Client");
        let initial_dev = client.get_device(ifname).expect("Failed to get device initially");
        assert_eq!(initial_dev.name, ifname);

        let private_key = Key::generate_private_key().expect("Failed to generate private key");
        let peer_key = Key::generate_private_key().expect("Failed to generate peer key");
        let peer_pub = peer_key.public_key();
        let psk = Key::generate_key().expect("Failed to generate PSK");

        let peer = PeerConfig {
            public_key: peer_pub,
            remove: false,
            update_only: false,
            preshared_key: Some(psk),
            endpoint: Some("127.0.0.1:51820".parse().unwrap()),
            persistent_keepalive_interval: Some(Duration::from_secs(25)),
            replace_allowed_ips: true,
            allowed_ips: vec![IpNet::from_str("10.0.0.2/32").unwrap()],
        };
        let config = Config {
            private_key: Some(private_key),
            listen_port: Some(51821),
            firewall_mark: Some(1234),
            replace_peers: true,
            peers: vec![peer],
        };
        client.configure_device(ifname, &config).expect("Failed to configure device");

        let device = client.get_device(ifname).expect("Failed to get device after config");
        assert_eq!(device.private_key, private_key, "Private key mismatch");
        assert_eq!(device.listen_port, 51821, "Listen port mismatch");
        assert_eq!(device.firewall_mark, 1234, "Firewall mark mismatch");
        assert_eq!(device.peers.len(), 1, "Expected 1 peer");

        let d_peer = &device.peers[0];
        assert_eq!(d_peer.public_key, peer_pub, "Peer public key mismatch");
        assert_eq!(d_peer.preshared_key, Some(psk), "PSK mismatch");
        assert_eq!(d_peer.persistent_keepalive_interval, Some(Duration::from_secs(25)), "Keepalive mismatch");
        assert_eq!(d_peer.allowed_ips.len(), 1, "Expected 1 allowed IP");
        assert_eq!(d_peer.allowed_ips[0].to_string(), "10.0.0.2/32", "Allowed IP mismatch");
    }
}