use std::collections::HashMap;
use std::net::SocketAddr;
use std::time::{Duration, SystemTime};
use super::types::{AllowedIp, WG_KEY_LEN, WgDevice, WgPeer};
use crate::netlink::protocol::Wireguard;
use crate::{Connection, Error, Result};
/// A change to a WireGuard peer observed between two polls of one interface.
///
/// Every variant carries the `ifname` of the interface on which the change was
/// observed. Peers are matched across polls by their public key.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum WireguardEvent {
    /// A peer whose public key was not in the previous snapshot. On the first
    /// poll of an interface (no previous snapshot) every peer is reported so.
    PeerAdded {
        /// Interface the peer was found on.
        ifname: String,
        /// The peer's state as fetched in the current snapshot.
        peer: WgPeer,
    },
    /// A peer from the previous snapshot that is no longer present.
    /// [`WireguardWatcher`] also emits this for every tracked peer of an
    /// interface whose poll failed.
    PeerRemoved {
        /// Interface the peer was removed from.
        ifname: String,
        /// Public key of the departed peer.
        public_key: [u8; WG_KEY_LEN],
    },
    /// A peer present in both snapshots whose endpoint differs.
    PeerEndpointChanged {
        /// Interface the peer belongs to.
        ifname: String,
        /// Public key of the peer.
        public_key: [u8; WG_KEY_LEN],
        /// Endpoint in the previous snapshot (`None` if it had none).
        from: Option<SocketAddr>,
        /// Endpoint in the current snapshot (`None` if it has none).
        to: Option<SocketAddr>,
    },
    /// A peer present in both snapshots whose last-handshake time moved
    /// forward, or was set where previously it was unset. A handshake time
    /// that moves backwards or becomes unset is not reported.
    PeerHandshakeRefreshed {
        /// Interface the peer belongs to.
        ifname: String,
        /// Public key of the peer.
        public_key: [u8; WG_KEY_LEN],
        /// The last-handshake time from the current snapshot.
        at: SystemTime,
    },
    /// A peer present in both snapshots whose allowed-IP list differs
    /// (compared as ordered lists).
    PeerAllowedIpsChanged {
        /// Interface the peer belongs to.
        ifname: String,
        /// Public key of the peer.
        public_key: [u8; WG_KEY_LEN],
        /// Allowed IPs in the previous snapshot.
        previous: Vec<AllowedIp>,
        /// Allowed IPs in the current snapshot.
        current: Vec<AllowedIp>,
    },
}
/// Configuration for [`WireguardWatcher`].
///
/// Start from [`Default`] and chain the setters, e.g.
/// `WireguardWatchOptions::default().interval(d).interface("wg0")`.
#[derive(Debug, Clone)]
#[non_exhaustive]
#[must_use = "options do nothing unless passed to WireguardWatcher::new"]
pub struct WireguardWatchOptions {
    /// Delay slept before every poll except the first. Defaults to one second.
    pub interval: Duration,
    /// Names of the interfaces to poll, visited in this order on each poll.
    /// Must be non-empty for [`WireguardWatcher::new`] to succeed.
    /// Defaults to empty.
    pub interfaces: Vec<String>,
}

impl Default for WireguardWatchOptions {
    /// A one-second interval and no interfaces.
    fn default() -> Self {
        Self {
            interfaces: Vec::new(),
            interval: Self::DEFAULT_INTERVAL,
        }
    }
}

impl WireguardWatchOptions {
    /// Poll interval used by [`Default`].
    const DEFAULT_INTERVAL: Duration = Duration::from_secs(1);

    /// Replaces the poll interval with `d`.
    pub fn interval(self, d: Duration) -> Self {
        Self { interval: d, ..self }
    }

    /// Appends `ifname` to the list of interfaces to watch.
    pub fn interface(mut self, ifname: impl Into<String>) -> Self {
        let name: String = ifname.into();
        self.interfaces.push(name);
        self
    }
}
/// Computes the events describing how interface `ifname` changed between two
/// snapshots.
///
/// `previous` is the snapshot from the last successful poll, or `None` when
/// there is none; in that case every peer in `current` is reported as
/// [`WireguardEvent::PeerAdded`].
///
/// Peers are matched by public key. Events are emitted in three groups, and
/// within each group in the order the peers appear in the snapshots, so the
/// output is deterministic for a given pair of snapshots:
///
/// 1. [`WireguardEvent::PeerAdded`] for keys only in `current`
///    (in `current.peers` order);
/// 2. [`WireguardEvent::PeerRemoved`] for keys only in `previous`
///    (in `previous.peers` order);
/// 3. for each key in both (in `current.peers` order), in this sequence:
///    [`WireguardEvent::PeerEndpointChanged`] if the endpoint differs,
///    [`WireguardEvent::PeerHandshakeRefreshed`] if the last-handshake time
///    moved forward or was set for the first time, and
///    [`WireguardEvent::PeerAllowedIpsChanged`] if the allowed-IP lists differ
///    (compared as ordered lists).
///
/// If a snapshot lists the same public key more than once, only its first
/// occurrence is considered.
pub fn diff_device_states(
    ifname: &str,
    previous: Option<&WgDevice>,
    current: &WgDevice,
) -> Vec<WireguardEvent> {
    let (prev_order, prev_by_key) =
        index_peers(previous.map(|d| d.peers.as_slice()).unwrap_or_default());
    let (curr_order, curr_by_key) = index_peers(&current.peers);

    let mut out = Vec::new();

    // 1. Newly appeared peers.
    for &peer in &curr_order {
        if !prev_by_key.contains_key(&peer.public_key) {
            out.push(WireguardEvent::PeerAdded {
                ifname: ifname.to_string(),
                peer: peer.clone(),
            });
        }
    }

    // 2. Peers that have disappeared.
    for &peer in &prev_order {
        if !curr_by_key.contains_key(&peer.public_key) {
            out.push(WireguardEvent::PeerRemoved {
                ifname: ifname.to_string(),
                public_key: peer.public_key,
            });
        }
    }

    // 3. Field-level changes for peers present in both snapshots.
    for &curr in &curr_order {
        let Some(&prev) = prev_by_key.get(&curr.public_key) else {
            continue;
        };
        let public_key = curr.public_key;

        if prev.endpoint != curr.endpoint {
            out.push(WireguardEvent::PeerEndpointChanged {
                ifname: ifname.to_string(),
                public_key,
                from: prev.endpoint,
                to: curr.endpoint,
            });
        }

        // Only forward progress (or a first-ever handshake) counts; a time
        // that goes backwards or becomes unset is deliberately ignored.
        if let Some(at) = curr.last_handshake {
            let advanced = match prev.last_handshake {
                Some(before) => at > before,
                None => true,
            };
            if advanced {
                out.push(WireguardEvent::PeerHandshakeRefreshed {
                    ifname: ifname.to_string(),
                    public_key,
                    at,
                });
            }
        }

        if prev.allowed_ips != curr.allowed_ips {
            out.push(WireguardEvent::PeerAllowedIpsChanged {
                ifname: ifname.to_string(),
                public_key,
                previous: prev.allowed_ips.clone(),
                current: curr.allowed_ips.clone(),
            });
        }
    }

    out
}

/// Indexes `peers` by public key, keeping only the first occurrence of each key.
///
/// Returns the retained peers in their original order (used to give events a
/// deterministic order) together with a key → peer map (used for lookups).
fn index_peers(peers: &[WgPeer]) -> (Vec<&WgPeer>, HashMap<&[u8; WG_KEY_LEN], &WgPeer>) {
    use std::collections::hash_map::Entry;

    let mut order: Vec<&WgPeer> = Vec::with_capacity(peers.len());
    let mut by_key: HashMap<&[u8; WG_KEY_LEN], &WgPeer> = HashMap::with_capacity(peers.len());
    for peer in peers {
        if let Entry::Vacant(slot) = by_key.entry(&peer.public_key) {
            slot.insert(peer);
            order.push(peer);
        }
    }
    (order, by_key)
}
#[must_use = "WireguardWatcher does nothing unless next_events() is called"]
pub struct WireguardWatcher {
conn: Connection<Wireguard>,
opts: WireguardWatchOptions,
previous: HashMap<String, WgDevice>,
first_poll: bool,
}
impl WireguardWatcher {
pub fn new(conn: Connection<Wireguard>, opts: WireguardWatchOptions) -> Result<Self> {
if opts.interfaces.is_empty() {
return Err(Error::InvalidMessage(
"WireguardWatchOptions::interfaces is empty — \
specify at least one interface to watch"
.to_string(),
));
}
Ok(Self {
conn,
opts,
previous: HashMap::new(),
first_poll: true,
})
}
pub async fn next_events(&mut self) -> Result<Vec<WireguardEvent>> {
if !self.first_poll {
tokio::time::sleep(self.opts.interval).await;
}
self.first_poll = false;
let mut all_events = Vec::new();
for ifname in self.opts.interfaces.clone() {
match self.conn.get_device_by_name(&ifname).await {
Ok(device) => {
let prev = self.previous.get(&ifname);
let events = diff_device_states(&ifname, prev, &device);
all_events.extend(events);
self.previous.insert(ifname, device);
}
Err(e) => {
tracing::warn!(
ifname = %ifname,
error = %e,
"WireguardWatcher: failed to poll interface; emitting PeerRemoved for any tracked peers and continuing",
);
if let Some(prev_device) = self.previous.remove(&ifname) {
for peer in &prev_device.peers {
all_events.push(WireguardEvent::PeerRemoved {
ifname: ifname.clone(),
public_key: peer.public_key,
});
}
}
}
}
}
Ok(all_events)
}
pub fn connection(&self) -> &Connection<Wireguard> {
&self.conn
}
pub fn into_connection(self) -> Connection<Wireguard> {
self.conn
}
}
#[cfg(test)]
mod tests {
    use std::net::{IpAddr, Ipv4Addr};
    use std::time::{Duration, UNIX_EPOCH};

    use super::*;

    /// A public key made of `byte` repeated.
    fn key(byte: u8) -> [u8; WG_KEY_LEN] {
        [byte; WG_KEY_LEN]
    }

    /// A peer with default state and the given key.
    fn peer(pk: [u8; WG_KEY_LEN]) -> WgPeer {
        WgPeer::new(pk)
    }

    /// A `wg0` device snapshot containing `peers`.
    fn device(peers: Vec<WgPeer>) -> WgDevice {
        let mut d = WgDevice::new();
        d.ifname = Some("wg0".to_string());
        d.peers = peers;
        d
    }

    #[test]
    fn first_poll_emits_peer_added_for_every_existing_peer() {
        let current = device(vec![peer(key(1)), peer(key(2))]);
        let events = diff_device_states("wg0", None, &current);
        assert_eq!(events.len(), 2);
        // Compare keys as a sorted set so the test does not depend on event order.
        let mut added: Vec<[u8; WG_KEY_LEN]> = events
            .iter()
            .map(|e| match e {
                WireguardEvent::PeerAdded { ifname, peer } => {
                    assert_eq!(ifname, "wg0");
                    peer.public_key
                }
                other => panic!("expected PeerAdded, got {other:?}"),
            })
            .collect();
        added.sort_unstable();
        assert_eq!(added, vec![key(1), key(2)]);
    }

    #[test]
    fn empty_to_empty_emits_nothing() {
        let current = device(Vec::new());
        assert!(diff_device_states("wg0", None, &current).is_empty());
    }

    #[test]
    fn no_change_emits_empty() {
        let snap = device(vec![peer(key(1))]);
        let events = diff_device_states("wg0", Some(&snap), &snap);
        assert!(events.is_empty());
    }

    #[test]
    fn new_peer_emits_peer_added() {
        let prev = device(vec![peer(key(1))]);
        let curr = device(vec![peer(key(1)), peer(key(2))]);
        let events = diff_device_states("wg0", Some(&prev), &curr);
        assert_eq!(events.len(), 1);
        match &events[0] {
            WireguardEvent::PeerAdded { peer, .. } => assert_eq!(peer.public_key, key(2)),
            other => panic!("expected PeerAdded, got {other:?}"),
        }
    }

    #[test]
    fn removed_peer_emits_peer_removed() {
        let prev = device(vec![peer(key(1)), peer(key(2))]);
        let curr = device(vec![peer(key(1))]);
        let events = diff_device_states("wg0", Some(&prev), &curr);
        assert_eq!(events.len(), 1);
        match &events[0] {
            WireguardEvent::PeerRemoved { public_key, .. } => assert_eq!(*public_key, key(2)),
            other => panic!("expected PeerRemoved, got {other:?}"),
        }
    }

    #[test]
    fn endpoint_change_emits_endpoint_changed() {
        let mut prev_peer = peer(key(1));
        prev_peer.endpoint = Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1)), 51820));
        let mut curr_peer = peer(key(1));
        curr_peer.endpoint = Some(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(2, 2, 2, 2)), 51820));
        let prev = device(vec![prev_peer]);
        let curr = device(vec![curr_peer]);
        let events = diff_device_states("wg0", Some(&prev), &curr);
        assert_eq!(events.len(), 1);
        assert!(matches!(
            events[0],
            WireguardEvent::PeerEndpointChanged { .. }
        ));
    }

    #[test]
    fn handshake_advance_emits_refresh() {
        let t0 = UNIX_EPOCH + Duration::from_secs(1_000);
        let t1 = UNIX_EPOCH + Duration::from_secs(2_000);
        let mut prev_peer = peer(key(1));
        prev_peer.last_handshake = Some(t0);
        let mut curr_peer = peer(key(1));
        curr_peer.last_handshake = Some(t1);
        let prev = device(vec![prev_peer]);
        let curr = device(vec![curr_peer]);
        let events = diff_device_states("wg0", Some(&prev), &curr);
        assert_eq!(events.len(), 1);
        match &events[0] {
            WireguardEvent::PeerHandshakeRefreshed { at, .. } => assert_eq!(*at, t1),
            other => panic!("expected PeerHandshakeRefreshed, got {other:?}"),
        }
    }

    #[test]
    fn handshake_same_emits_nothing() {
        let t = UNIX_EPOCH + Duration::from_secs(1_000);
        let mut p = peer(key(1));
        p.last_handshake = Some(t);
        let snap = device(vec![p]);
        let events = diff_device_states("wg0", Some(&snap), &snap);
        assert!(events.is_empty());
    }

    #[test]
    fn handshake_regression_emits_nothing() {
        let t0 = UNIX_EPOCH + Duration::from_secs(1_000);
        let t1 = UNIX_EPOCH + Duration::from_secs(2_000);
        let mut prev_peer = peer(key(1));
        prev_peer.last_handshake = Some(t1);
        let mut curr_peer = peer(key(1));
        curr_peer.last_handshake = Some(t0);
        let prev = device(vec![prev_peer]);
        let curr = device(vec![curr_peer]);
        assert!(diff_device_states("wg0", Some(&prev), &curr).is_empty());
    }

    #[test]
    fn handshake_cleared_emits_nothing() {
        let mut prev_peer = peer(key(1));
        prev_peer.last_handshake = Some(UNIX_EPOCH + Duration::from_secs(1_000));
        let prev = device(vec![prev_peer]);
        let curr = device(vec![peer(key(1))]);
        assert!(diff_device_states("wg0", Some(&prev), &curr).is_empty());
    }

    #[test]
    fn first_ever_handshake_emits_refresh() {
        let t = UNIX_EPOCH + Duration::from_secs(1_000);
        let prev_peer = peer(key(1));
        let mut curr_peer = peer(key(1));
        curr_peer.last_handshake = Some(t);
        let prev = device(vec![prev_peer]);
        let curr = device(vec![curr_peer]);
        let events = diff_device_states("wg0", Some(&prev), &curr);
        assert_eq!(events.len(), 1);
        assert!(matches!(
            events[0],
            WireguardEvent::PeerHandshakeRefreshed { .. }
        ));
    }

    #[test]
    fn allowed_ips_change_emits_event() {
        let mut prev_peer = peer(key(1));
        prev_peer.allowed_ips = vec![AllowedIp::v4(Ipv4Addr::new(10, 0, 0, 0), 24)];
        let mut curr_peer = peer(key(1));
        curr_peer.allowed_ips = vec![
            AllowedIp::v4(Ipv4Addr::new(10, 0, 0, 0), 24),
            AllowedIp::v4(Ipv4Addr::new(10, 0, 1, 0), 24),
        ];
        let prev = device(vec![prev_peer]);
        let curr = device(vec![curr_peer]);
        let events = diff_device_states("wg0", Some(&prev), &curr);
        assert_eq!(events.len(), 1);
        assert!(matches!(
            events[0],
            WireguardEvent::PeerAllowedIpsChanged { .. }
        ));
    }

    #[test]
    fn watch_options_defaults() {
        let opts = WireguardWatchOptions::default();
        assert_eq!(opts.interval, Duration::from_secs(1));
        assert!(opts.interfaces.is_empty());
    }

    #[test]
    fn watch_options_builder() {
        let opts = WireguardWatchOptions::default()
            .interval(Duration::from_secs(30))
            .interface("wg0")
            .interface("wg1");
        assert_eq!(opts.interval, Duration::from_secs(30));
        assert_eq!(opts.interfaces, vec!["wg0", "wg1"]);
    }
}