use std::{
fmt,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6},
ops::{Deref, DerefMut},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use anyhow::{anyhow, Context, Result};
use base64::{engine::general_purpose, Engine as _};
use ipnet::IpNet;
use rand::Rng;
use rsln::types::message::RouteAttrIter;
use x25519_dalek::{PublicKey, StaticSecret};
use crate::constants::{WgAllowedIpAttr, WgDeviceAttr, WgPeerAttr};
/// Length in bytes of every key handled here (private, public, and preshared).
const KEY_LEN: usize = 32;
/// Mask that strips the netlink `NLA_F_NESTED` (0x8000) and
/// `NLA_F_NET_BYTEORDER` (0x4000) flag bits from an attribute type, leaving
/// only the attribute number (0x3fff). The `MAST` spelling is kept because
/// other code in this file refers to the constant by that name.
const ATTR_TYPE_MAST: u16 = !(0x8000 | 0x4000);
/// Which WireGuard implementation backs a device.
///
/// `Display` renders the variants as `linux-kernel`, `userspace`, and
/// `unknown`, and honours width/alignment flags (e.g. `{:>12}`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeviceType {
    /// Linux kernel implementation; this is the `Default`.
    LinuxKernel,
    /// Userspace implementation.
    Userspace,
    /// Unknown or unrecognised implementation.
    Unknown,
}

impl fmt::Display for DeviceType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            DeviceType::LinuxKernel => "linux-kernel",
            DeviceType::Userspace => "userspace",
            DeviceType::Unknown => "unknown",
        };
        // `pad` (unlike `write!`) respects caller-supplied width/fill/alignment.
        f.pad(label)
    }
}

impl Default for DeviceType {
    fn default() -> Self {
        Self::LinuxKernel
    }
}
/// A raw `KEY_LEN`-byte key (private, public, or preshared).
///
/// `Debug` deliberately never prints the bytes so secrets do not leak into
/// logs; convert to `String` for the base64 form.
#[derive(Default, PartialEq, Eq, Hash, Clone, Copy)]
pub struct Key([u8; KEY_LEN]);

impl fmt::Debug for Key {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Key([REDACTED])")
    }
}
impl TryFrom<&[u8]> for Key {
    type Error = anyhow::Error;

    /// Copies exactly `KEY_LEN` raw bytes into a key.
    ///
    /// # Errors
    /// Any other slice length is rejected with "Incorrect key size: <len>".
    fn try_from(bytes: &[u8]) -> Result<Self> {
        <[u8; KEY_LEN]>::try_from(bytes)
            .map(Self)
            .map_err(|_| anyhow!("Incorrect key size: {}", bytes.len()))
    }
}
impl TryFrom<&str> for Key {
type Error = anyhow::Error;
fn try_from(s: &str) -> Result<Self> {
let bytes = general_purpose::STANDARD.decode(s)?;
Self::try_from(bytes.as_slice())
}
}
impl Deref for Key {
type Target = [u8; KEY_LEN];
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for Key {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
impl From<Key> for String {
fn from(val: Key) -> Self {
general_purpose::STANDARD.encode(*val)
}
}
impl Key {
pub fn generate_key() -> Result<Self> {
let mut key = [0; KEY_LEN];
let mut rng = rand::thread_rng();
rng.fill(&mut key);
Key::try_from(key.as_ref())
}
pub fn generate_private_key() -> Result<Self> {
let mut key = Key::generate_key()?;
key[0] &= 248;
key[31] &= 127;
key[31] |= 64;
Ok(key)
}
pub fn public_key(&self) -> Key {
let secret = StaticSecret::from(self.0);
let public: PublicKey = (&secret).into();
Self(*public.as_bytes())
}
pub fn exchange(&self, public_key: &Key) -> Key {
let secret = StaticSecret::from(self.0);
let public: PublicKey = (public_key.0).into();
let shared = secret.diffie_hellman(&public);
Self(*shared.as_bytes())
}
}
/// State of a single WireGuard peer, as decoded from a netlink peer attribute
/// set by `TryFrom<&[u8]> for Peer`.
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub struct Peer {
    /// The peer's public key (`WgPeerAttr::PublicKey`).
    pub public_key: Key,
    /// Preshared key, if the attribute was present.
    pub preshared_key: Option<Key>,
    /// Remote endpoint, if the attribute was present.
    pub endpoint: Option<SocketAddr>,
    /// Persistent keepalive interval; `None` when absent or reported as 0.
    pub persistent_keepalive_interval: Option<Duration>,
    /// Last handshake time; `None` when absent or when the timestamp decoder
    /// yields nothing (e.g. an all-zero value).
    pub last_handshake_time: Option<SystemTime>,
    /// `RxBytes` counter (a u64 on the wire, stored as i64).
    pub rx_bytes: i64,
    /// `TxBytes` counter (a u64 on the wire, stored as i64).
    pub tx_bytes: i64,
    /// Allowed IP networks for this peer.
    pub allowed_ips: Vec<IpNet>,
    /// Protocol version, if reported.
    pub protocol_version: Option<u16>,
}
impl TryFrom<&[u8]> for Peer {
type Error = anyhow::Error;
fn try_from(data: &[u8]) -> Result<Self> {
let mut peer = Peer::default();
for attr in RouteAttrIter::new(data) {
let (kind, value) = attr?;
let kind = kind & ATTR_TYPE_MAST;
match kind {
k if k == WgPeerAttr::PublicKey as u16 => {
peer.public_key = Key::try_from(value)?;
}
k if k == WgPeerAttr::PresharedKey as u16 => {
peer.preshared_key = Some(Key::try_from(value)?);
}
k if k == WgPeerAttr::Endpoint as u16 => {
peer.endpoint = Some(parse_sockaddr(value)?);
}
k if k == WgPeerAttr::PersistentKeepalive as u16 => {
let secs = parse_u16(value)?;
peer.persistent_keepalive_interval = if secs > 0 {
Some(Duration::from_secs(secs as u64))
} else {
None
};
}
k if k == WgPeerAttr::LastHandshakeTime as u16 => {
peer.last_handshake_time = parse_timespec(value);
}
k if k == WgPeerAttr::RxBytes as u16 => {
peer.rx_bytes = parse_u64(value)? as i64;
}
k if k == WgPeerAttr::TxBytes as u16 => {
peer.tx_bytes = parse_u64(value)? as i64;
}
k if k == WgPeerAttr::AllowedIps as u16 => {
for ip_attr in RouteAttrIter::new(value) {
let (_, ip_payload) = ip_attr?;
let allowed_ip = parse_allowed_ip(ip_payload)?;
peer.allowed_ips.push(allowed_ip);
}
}
k if k == WgPeerAttr::ProtocolVersion as u16 => {
peer.protocol_version = Some(parse_u32(value)? as u16);
}
_ => {}
}
}
Ok(peer)
}
}
/// Desired configuration for a single peer. [`Peer`] holds the parsed state;
/// this type holds the settings to apply.
///
/// NOTE(review): the flag semantics documented below are inferred from the
/// field names (and resemble WireGuard's `WGPEER_F_*` flags); confirm them
/// against the code that serializes this struct.
#[derive(Debug, Clone)]
pub struct PeerConfig {
    /// Public key identifying the peer.
    pub public_key: Key,
    /// Presumably requests removal of the peer — confirm.
    pub remove: bool,
    /// Presumably only updates an already-existing peer — confirm.
    pub update_only: bool,
    /// Optional preshared key.
    pub preshared_key: Option<Key>,
    /// Optional endpoint address.
    pub endpoint: Option<SocketAddr>,
    /// Optional persistent keepalive interval.
    pub persistent_keepalive_interval: Option<Duration>,
    /// Presumably replaces (rather than appends to) the peer's existing
    /// allowed IPs — confirm.
    pub replace_allowed_ips: bool,
    /// Allowed IP networks for the peer.
    pub allowed_ips: Vec<IpNet>,
}
/// A WireGuard device and its peers, as decoded by `TryFrom<&[u8]> for Device`.
#[derive(Default, Debug)]
pub struct Device {
    /// Interface name (`WgDeviceAttr::IfName`).
    pub name: String,
    /// Implementation type. The `TryFrom<&[u8]>` parser in this file does not
    /// set it, so it stays at `DeviceType::default()` unless set elsewhere.
    pub device_type: DeviceType,
    /// Device private key (`WgDeviceAttr::PrivateKey`).
    pub private_key: Key,
    /// Device public key (`WgDeviceAttr::PublicKey`).
    pub public_key: Key,
    /// Listen port (`WgDeviceAttr::ListenPort`).
    pub listen_port: u16,
    /// Firewall mark (`WgDeviceAttr::Fwmark`).
    pub firewall_mark: u32,
    /// Peers decoded from the `WgDeviceAttr::Peers` nested attribute.
    pub peers: Vec<Peer>,
}
impl TryFrom<&[u8]> for Device {
    type Error = anyhow::Error;

    /// Builds a [`Device`] from a message payload.
    ///
    /// The first 4 bytes are skipped (presumably the generic-netlink header —
    /// confirm against the client); the rest is a flat list of
    /// `WgDeviceAttr` attributes. Unrecognised attributes are ignored and
    /// `device_type` is left at its default.
    ///
    /// # Errors
    /// Fails with "Short payload" when fewer than 4 bytes are supplied, or when
    /// any attribute, key, integer, or nested peer fails to parse.
    fn try_from(payload: &[u8]) -> Result<Self> {
        let attrs = payload
            .get(4..)
            .ok_or_else(|| anyhow!("Short payload"))?;

        let mut device = Device::default();
        for entry in RouteAttrIter::new(attrs) {
            let (raw_kind, data) = entry?;
            match raw_kind & ATTR_TYPE_MAST {
                k if k == WgDeviceAttr::IfName as u16 => device.name = parse_string(data),
                k if k == WgDeviceAttr::PrivateKey as u16 => {
                    device.private_key = Key::try_from(data)?
                }
                k if k == WgDeviceAttr::PublicKey as u16 => {
                    device.public_key = Key::try_from(data)?
                }
                k if k == WgDeviceAttr::ListenPort as u16 => device.listen_port = parse_u16(data)?,
                k if k == WgDeviceAttr::Fwmark as u16 => device.firewall_mark = parse_u32(data)?,
                k if k == WgDeviceAttr::Peers as u16 => {
                    // Each nested attribute holds one peer's attribute set.
                    for nested in RouteAttrIter::new(data) {
                        let (_, peer_data) = nested?;
                        device.peers.push(Peer::try_from(peer_data)?);
                    }
                }
                _ => {}
            }
        }
        Ok(device)
    }
}
/// Desired device-level configuration. Every `Option` field is optional; the
/// meaning of `None` depends on the code that applies this config
/// (presumably "leave unchanged" — confirm).
#[derive(Debug, Clone)]
pub struct Config {
    /// Optional new private key.
    pub private_key: Option<Key>,
    /// Optional listen port.
    pub listen_port: Option<u16>,
    /// Optional firewall mark.
    pub firewall_mark: Option<u32>,
    /// Presumably replaces the existing peer list rather than merging into it
    /// — confirm against the serializer.
    pub replace_peers: bool,
    /// Per-peer configuration entries.
    pub peers: Vec<PeerConfig>,
}
/// Decodes a NUL-terminated (C-style) byte string.
///
/// Everything from the first `0` byte onward is discarded; a buffer without a
/// NUL is used whole. Invalid UTF-8 is replaced with U+FFFD, not rejected.
fn parse_string(data: &[u8]) -> String {
    // `split` always yields at least one (possibly empty) piece.
    let head = data.split(|&byte| byte == 0).next().unwrap_or(data);
    String::from_utf8_lossy(head).into_owned()
}
/// Decodes a raw `sockaddr_in` / `sockaddr_in6` byte buffer into a [`SocketAddr`].
///
/// Layout: the address family is read in host byte order at offset 0 and the
/// port in network byte order at offset 2. IPv4 has its address at 4..8.
/// IPv6 has flow info at 4..8, the address at 8..24, and the scope id at
/// 24..28. Flow info and scope id are both read as native-endian `u32`s.
///
/// # Errors
/// Fails when the buffer is too short for its family, or when the family is
/// neither `AF_INET` nor `AF_INET6`.
fn parse_sockaddr(data: &[u8]) -> Result<SocketAddr> {
    if data.len() < 2 {
        return Err(anyhow!("Sockaddr data too short"));
    }
    let family = u16::from_ne_bytes([data[0], data[1]]);
    match family as i32 {
        libc::AF_INET => {
            if data.len() < 8 {
                return Err(anyhow!("IPv4 sockaddr too short"));
            }
            let port = u16::from_be_bytes([data[2], data[3]]);
            let ip_bytes: [u8; 4] = data[4..8]
                .try_into()
                .map_err(|_| anyhow!("Failed to convert IPv4 address bytes"))?;
            Ok(SocketAddr::V4(SocketAddrV4::new(
                Ipv4Addr::from(ip_bytes),
                port,
            )))
        }
        libc::AF_INET6 => {
            if data.len() < 28 {
                return Err(anyhow!("IPv6 sockaddr too short"));
            }
            let port = u16::from_be_bytes([data[2], data[3]]);
            // Keep sin6_flowinfo rather than discarding it, so the endpoint
            // round-trips unchanged.
            let flowinfo = u32::from_ne_bytes([data[4], data[5], data[6], data[7]]);
            let ip_bytes: [u8; 16] = data[8..24]
                .try_into()
                .map_err(|_| anyhow!("Failed to convert IPv6 address bytes"))?;
            let scope_id = u32::from_ne_bytes([data[24], data[25], data[26], data[27]]);
            Ok(SocketAddr::V6(SocketAddrV6::new(
                Ipv6Addr::from(ip_bytes),
                port,
                flowinfo,
                scope_id,
            )))
        }
        _ => Err(anyhow!("Unsupported address family: {}", family)),
    }
}
fn parse_allowed_ip(data: &[u8]) -> Result<IpNet> {
let mut ip_addr = None;
let mut cidr = 0u8;
for attr in RouteAttrIter::new(data) {
let (raw_kind, value) = attr?;
let kind = raw_kind & 0x3fff;
match kind {
k if k == WgAllowedIpAttr::IpAddr as u16 => {
if value.len() == 4 {
let b: [u8; 4] = value
.try_into()
.map_err(|_| anyhow!("Invalid IPv4 length inside AllowedIPs"))?;
ip_addr = Some(std::net::IpAddr::V4(Ipv4Addr::from(b)));
} else if value.len() == 16 {
let b: [u8; 16] = value
.try_into()
.map_err(|_| anyhow!("Invalid IPv6 length inside AllowedIPs"))?;
ip_addr = Some(std::net::IpAddr::V6(Ipv6Addr::from(b)));
}
}
k if k == WgAllowedIpAttr::CidrMask as u16 => {
if !value.is_empty() {
cidr = value[0];
}
}
_ => {}
}
}
let ip = ip_addr.ok_or_else(|| anyhow!("Missing IP address in AllowedIPs"))?;
IpNet::new(ip, cidr).context("Invalid CIDR")
}
/// Decodes a timespec-style buffer (two native-endian `i64`s: seconds, then
/// nanoseconds) into a [`SystemTime`] relative to the Unix epoch.
///
/// Returns `None` when:
/// - the buffer is shorter than 16 bytes;
/// - both fields are zero, which is treated as "no timestamp";
/// - the value cannot be represented as a non-negative offset from the epoch
///   (negative seconds or nanoseconds, or nanoseconds above `u32::MAX`);
/// - the result overflows `SystemTime`.
///
/// Malformed input never panics. Nanosecond values of one second or more are
/// normalised into the seconds, as `Duration::new` does.
fn parse_timespec(data: &[u8]) -> Option<SystemTime> {
    let sec_bytes: [u8; 8] = data.get(0..8)?.try_into().ok()?;
    let nsec_bytes: [u8; 8] = data.get(8..16)?.try_into().ok()?;
    let sec = i64::from_ne_bytes(sec_bytes);
    let nsec = i64::from_ne_bytes(nsec_bytes);
    if sec == 0 && nsec == 0 {
        return None;
    }
    // `as` casts would wrap negative values into huge ones, and the later
    // `SystemTime + Duration` would then panic on overflow.
    let sec = u64::try_from(sec).ok()?;
    let nsec = u32::try_from(nsec).ok()?;
    // sec <= i64::MAX, so the nanosecond carry inside Duration::new cannot overflow.
    UNIX_EPOCH.checked_add(Duration::new(sec, nsec))
}
/// Copies the leading `N` bytes of `data` into an array. If fewer than `N`
/// bytes are available, it fails with `what` as the error message.
fn leading_ne_bytes<const N: usize>(data: &[u8], what: &str) -> Result<[u8; N]> {
    data.get(..N)
        .and_then(|head| <[u8; N]>::try_from(head).ok())
        .ok_or_else(|| anyhow!("{}", what))
}

/// Reads a native-endian `u16` from the start of `data`; trailing bytes are ignored.
fn parse_u16(data: &[u8]) -> Result<u16> {
    leading_ne_bytes::<2>(data, "u16 data too short").map(u16::from_ne_bytes)
}

/// Reads a native-endian `u32` from the start of `data`; trailing bytes are ignored.
fn parse_u32(data: &[u8]) -> Result<u32> {
    leading_ne_bytes::<4>(data, "u32 data too short").map(u32::from_ne_bytes)
}

/// Reads a native-endian `u64` from the start of `data`; trailing bytes are ignored.
fn parse_u64(data: &[u8]) -> Result<u64> {
    leading_ne_bytes::<8>(data, "u64 data too short").map(u64::from_ne_bytes)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prepared_keys() {
        let private = "GHuMwljFfqd2a7cs6BaUOmHflK23zME8VNvC5B37S3k=";
        let public = "aPxGwq8zERHQ3Q1cOZFdJ+cvJX5Ka4mLN38AyYKYF10=";
        let priv_key = Key::try_from(private).unwrap();
        let public_key = priv_key.public_key();
        assert_eq!(private, Into::<String>::into(priv_key));
        assert_eq!(public, Into::<String>::into(public_key));
    }

    #[test]
    fn test_key_exchange() {
        let alice = Key::generate_private_key().unwrap();
        let bob = Key::generate_private_key().unwrap();
        let alice_pub = alice.public_key();
        let bob_pub = bob.public_key();
        let alice_shared = alice.exchange(&bob_pub);
        let bob_shared = bob.exchange(&alice_pub);
        assert_eq!(*alice_shared, *bob_shared);
    }

    #[test]
    fn test_key_rejects_wrong_length() {
        assert!(Key::try_from(&[0u8; KEY_LEN - 1][..]).is_err());
        // "AAAA" is valid base64 but decodes to only 3 bytes.
        assert!(Key::try_from("AAAA").is_err());
    }

    #[test]
    fn test_parse_string_stops_at_nul() {
        assert_eq!(parse_string(b"wg0\0\0\0"), "wg0");
        assert_eq!(parse_string(b"wg1"), "wg1");
    }

    #[test]
    fn test_parse_sockaddr_v4() {
        let mut raw = [0u8; 16];
        raw[..2].copy_from_slice(&(libc::AF_INET as u16).to_ne_bytes());
        raw[2..4].copy_from_slice(&51820u16.to_be_bytes());
        raw[4..8].copy_from_slice(&[192, 0, 2, 1]);
        let expected: SocketAddr = "192.0.2.1:51820".parse().unwrap();
        assert_eq!(parse_sockaddr(&raw).unwrap(), expected);
        assert!(parse_sockaddr(&raw[..4]).is_err());
    }

    #[test]
    fn test_parse_timespec() {
        assert_eq!(parse_timespec(&[0u8; 16]), None);
        assert_eq!(parse_timespec(&[0u8; 8]), None);
        let mut raw = [0u8; 16];
        raw[..8].copy_from_slice(&10i64.to_ne_bytes());
        raw[8..].copy_from_slice(&5i64.to_ne_bytes());
        assert_eq!(parse_timespec(&raw), Some(UNIX_EPOCH + Duration::new(10, 5)));
    }

    #[cfg(target_os = "linux")]
    #[test]
    #[ignore]
    fn test_linux_integration_full() {
        use std::process::Command;
        // SAFETY: geteuid takes no arguments, touches no memory we own, and cannot fail.
        let euid = unsafe { libc::geteuid() };
        if euid != 0 {
            eprintln!("SKIPPING: Root privileges required for integration test");
            return;
        }
        let test_ifname = "wg_test_tmp";
        // Best-effort removal of a leftover interface from an earlier aborted run.
        let _ = Command::new("ip").args(["link", "del", test_ifname]).output();
        let created = Command::new("ip")
            .args(["link", "add", test_ifname, "type", "wireguard"])
            .status()
            .map(|status| status.success())
            .unwrap_or(false);
        if !created {
            eprintln!("SKIPPING: Could not create wireguard interface (kernel module missing?)");
            return;
        }
        /// Deletes the test interface when dropped, even if an assertion panics.
        struct InterfaceGuard<'a>(&'a str);
        impl Drop for InterfaceGuard<'_> {
            fn drop(&mut self) {
                let _ = Command::new("ip")
                    .args(["link", "del", self.0])
                    .output();
            }
        }
        let _guard = InterfaceGuard(test_ifname);
        let mut client = crate::client::Client::new().expect("Failed to create Netlink Client");
        let device = client
            .get_device(test_ifname)
            .expect("Should find created device");
        assert_eq!(device.name, test_ifname);
        println!("Success: Found existing device '{}'", device.name);
        let err = client.get_device("wg_imaginary_99");
        assert!(err.is_err());
        println!("Success: Correctly failed to find non-existent device");
        let err_loopback = client.get_device("lo");
        assert!(err_loopback.is_err());
        println!("Success: Correctly refused non-WireGuard interface 'lo'");
    }

    #[cfg(target_os = "linux")]
    #[test]
    #[ignore]
    fn test_linux_client_is_permission() {
        // SAFETY: geteuid takes no arguments, touches no memory we own, and cannot fail.
        let euid = unsafe { libc::geteuid() };
        if euid == 0 {
            println!("SKIPPING: Test must be run without elevated privileges (uid != 0)");
            return;
        }
        let mut client = match crate::client::Client::new() {
            Ok(c) => c,
            Err(e) => {
                println!("Skipping: Failed to create client (generic netlink not available?): {}", e);
                return;
            }
        };
        let err = client.get_device("wgnotexist0");
        match err {
            Ok(_) => panic!("Expected error, got success"),
            Err(e) => {
                let msg = e.to_string();
                let lowered = msg.to_lowercase();
                if lowered.contains("permission denied") || lowered.contains("operation not permitted") {
                    println!("Success: Got permission error: {}", msg);
                } else {
                    panic!("expected permission denied, but got: {}", msg);
                }
            }
        }
    }

    #[cfg(target_os = "linux")]
    #[test]
    #[ignore]
    fn test_linux_client_devices_empty() {
        use std::process::Command;
        // SAFETY: geteuid takes no arguments, touches no memory we own, and cannot fail.
        let euid = unsafe { libc::geteuid() };
        if euid != 0 {
            eprintln!("SKIPPING: Root privileges required");
            return;
        }
        let output = Command::new("ip")
            .args(["-o", "link", "show", "type", "wireguard"])
            .output()
            .expect("Failed to run ip link");
        if !output.stdout.is_empty() {
            eprintln!("SKIPPING: Existing WireGuard interfaces found. Cannot strictly test 'Empty' case without potentially affecting system state.");
            return;
        }
        let mut client = match crate::client::Client::new() {
            Ok(c) => c,
            Err(e) => {
                eprintln!("Skipping: Failed to create client: {}", e);
                return;
            }
        };
        let devices = client.list_devices().expect("Failed to list devices");
        assert!(devices.is_empty(), "Expected no devices, got {}", devices.len());
        println!("Success: list_devices returned empty list as expected");
    }
}