use crate::interface::platform_ops;
use crate::{config::OverlayConfig, PeerInfo};
#[cfg(not(windows))]
use boringtun::device::{DeviceConfig, DeviceHandle};
use std::fmt::Write;
#[cfg(not(windows))]
use tokio::io::{AsyncReadExt, AsyncWriteExt};
#[cfg(not(windows))]
use tokio::net::UnixStream;
#[cfg(windows)]
use crate::tun::WindowsTun;
#[cfg(windows)]
use boringtun::noise::{Tunn, TunnResult};
#[cfg(windows)]
use dashmap::DashMap;
#[cfg(windows)]
use parking_lot::RwLock;
#[cfg(windows)]
use std::net::{IpAddr, SocketAddr};
#[cfg(windows)]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(windows)]
use std::sync::Arc;
#[cfg(windows)]
use std::time::{Duration, SystemTime, UNIX_EPOCH};
#[cfg(windows)]
use tokio::net::UdpSocket;
#[cfg(windows)]
use tokio::sync::Mutex as AsyncMutex;
#[cfg(windows)]
use tokio::task::JoinHandle;
/// Converts a base64-encoded 32-byte WireGuard key into the lowercase hex
/// form used by the boringtun UAPI protocol.
///
/// # Errors
///
/// Fails if `base64_key` is not valid standard base64, or if it does not
/// decode to exactly 32 bytes.
#[cfg(not(windows))]
fn key_to_hex(base64_key: &str) -> Result<String, Box<dyn std::error::Error>> {
    use base64::{engine::general_purpose::STANDARD, Engine as _};

    let raw = STANDARD.decode(base64_key)?;
    match raw.len() {
        32 => Ok(hex::encode(raw)),
        other => Err(format!("Invalid key length: expected 32 bytes, got {other}").into()),
    }
}
/// Sends one UAPI `set=1` transaction to the device socket at `sock_path`.
///
/// Callers pass `body` as `key=value` lines ending in `\n`. The extra `\n`
/// appended here produces the blank line that terminates the request. The
/// write half is then shut down, and the whole reply is read back.
///
/// # Errors
///
/// Socket and I/O errors are propagated. A reply that does not contain
/// `errno=0` is returned as an error carrying the trimmed reply text.
#[cfg(not(windows))]
async fn uapi_set(sock_path: &str, body: &str) -> Result<(), Box<dyn std::error::Error>> {
    let request = format!("set=1\n{body}\n");
    let mut conn = UnixStream::connect(sock_path).await?;
    conn.write_all(request.as_bytes()).await?;
    conn.shutdown().await?;

    let mut reply = String::new();
    conn.read_to_string(&mut reply).await?;
    if !reply.contains("errno=0") {
        return Err(format!("UAPI set failed: {}", reply.trim()).into());
    }
    Ok(())
}
/// Issues a UAPI `get=1` request on `sock_path` and returns the raw reply.
///
/// # Errors
///
/// Propagates socket and I/O errors. The reply text itself is not inspected.
#[cfg(not(windows))]
async fn uapi_get(sock_path: &str) -> Result<String, Box<dyn std::error::Error>> {
    const GET_REQUEST: &[u8] = b"get=1\n\n";

    let mut conn = UnixStream::connect(sock_path).await?;
    conn.write_all(GET_REQUEST).await?;
    // Closing the write half signals end-of-request before reading the dump.
    conn.shutdown().await?;

    let mut reply = String::new();
    conn.read_to_string(&mut reply).await?;
    Ok(reply)
}
/// Per-peer state shared by the Windows ingress, egress and timer tasks.
///
/// Every field is an `Arc` or a `Copy` value, so cloning is cheap. The
/// ingress and timer loops clone snapshots of these states before awaiting.
#[cfg(windows)]
#[derive(Clone)]
struct WindowsPeerState {
    /// Networks routed to (and accepted from) this peer.
    allowed_ips: Arc<Vec<ipnet::IpNet>>,
    /// Persistent keepalive interval in seconds; `None` when disabled.
    persistent_keepalive: Option<u16>,
    /// boringtun noise state machine for this peer.
    tunn: Arc<AsyncMutex<Tunn>>,
    /// Last known UDP endpoint of the peer, if any.
    endpoint: Arc<RwLock<Option<SocketAddr>>>,
    /// Unix time in seconds at which a packet from this peer was last
    /// decrypted for the tunnel; `0` means never.
    last_handshake_sec: Arc<AtomicU64>,
}
/// Decodes a standard-base64 WireGuard key into its raw 32 bytes.
///
/// # Errors
///
/// Fails if `b64` is not valid standard base64, or if it does not decode to
/// exactly 32 bytes.
#[cfg(windows)]
fn decode_key_b64(b64: &str) -> Result<[u8; 32], Box<dyn std::error::Error>> {
    use base64::{engine::general_purpose::STANDARD, Engine as _};

    let decoded = STANDARD.decode(b64)?;
    decoded.try_into().map_err(|rejected: Vec<u8>| {
        format!(
            "invalid WireGuard key length: expected 32 bytes, got {}",
            rejected.len()
        )
        .into()
    })
}
#[cfg(windows)]
/// Extracts the destination address from a raw IPv4 or IPv6 packet.
///
/// The IP version comes from the high nibble of the first byte. The
/// destination is read at bytes 16..20 (IPv4) or 24..40 (IPv6). Returns
/// `None` for empty or truncated packets and for any other version.
fn parse_dst_ip(packet: &[u8]) -> Option<IpAddr> {
    let version = *packet.first()? >> 4;
    match version {
        4 => {
            let dst: [u8; 4] = packet.get(16..20)?.try_into().ok()?;
            Some(IpAddr::from(dst))
        }
        6 => {
            let dst: [u8; 16] = packet.get(24..40)?.try_into().ok()?;
            Some(IpAddr::from(dst))
        }
        _ => None,
    }
}
/// Builds a boringtun `Tunn` (noise state machine) for one peer.
///
/// Every call gets its own tunnel `index`, taken from a process-wide counter.
/// boringtun's `Tunn::new` derives the tunnel's session/receiver indices from
/// this value (as `index << 8` in the versions reviewed; verify against the
/// pinned boringtun). It therefore has to differ between peers. The previous
/// constant `0` made every peer's session indices collide.
///
/// * `our_priv` - this node's raw X25519 private key.
/// * `peer_pub` - the peer's raw X25519 public key.
/// * `preshared` - optional WireGuard preshared key.
/// * `persistent_keepalive` - keepalive interval in seconds, `None` to disable.
#[cfg(windows)]
fn build_tunn(
    our_priv: &[u8; 32],
    peer_pub: &[u8; 32],
    preshared: Option<[u8; 32]>,
    persistent_keepalive: Option<u16>,
) -> Tunn {
    use std::sync::atomic::{AtomicU32, Ordering};

    // Only 24 bits survive the `<< 8` shift, so wrap the counter to that range.
    const INDEX_MASK: u32 = 0x00FF_FFFF;
    static NEXT_TUNN_INDEX: AtomicU32 = AtomicU32::new(1);

    let index = NEXT_TUNN_INDEX.fetch_add(1, Ordering::Relaxed) & INDEX_MASK;
    let priv_secret = boringtun::x25519::StaticSecret::from(*our_priv);
    let peer_pub_key = boringtun::x25519::PublicKey::from(*peer_pub);
    Tunn::new(
        priv_secret,
        peer_pub_key,
        preshared,
        persistent_keepalive,
        index,
        None,
    )
}
/// WireGuard overlay transport for a single interface.
///
/// On non-Windows targets this owns a boringtun userspace device that is
/// configured over its UAPI socket. On Windows it owns a Wintun adapter, the
/// WireGuard UDP socket, a per-peer `Tunn` map and the three background data
/// path tasks spawned by `configure`.
pub struct OverlayTransport {
/// Overlay settings: keys, overlay/cluster CIDRs and the UDP listen endpoint.
config: OverlayConfig,
/// Interface name. On macOS it is replaced by the kernel-assigned `utunN`
/// name once `create_interface` has run.
interface_name: String,
/// boringtun device handle; set by `create_interface`, taken by `shutdown`.
#[cfg(not(windows))]
device: Option<DeviceHandle>,
/// Wintun adapter; set by `create_interface`, taken by `shutdown`.
#[cfg(windows)]
wintun_dev: Option<Arc<WindowsTun>>,
/// WireGuard UDP socket bound by `configure`.
#[cfg(windows)]
udp: Option<Arc<UdpSocket>>,
/// Peer state keyed by the peer's raw 32-byte public key. The map is
/// shared (via `Arc`) with the background tasks.
#[cfg(windows)]
peers: Arc<DashMap<[u8; 32], WindowsPeerState>>,
/// UDP -> Wintun decryption task.
#[cfg(windows)]
ingress_task: Option<JoinHandle<()>>,
/// Wintun -> UDP encryption task.
#[cfg(windows)]
egress_task: Option<JoinHandle<()>>,
/// Periodic `update_timers` task (handshakes / keepalives).
#[cfg(windows)]
timers_task: Option<JoinHandle<()>>,
}
impl OverlayTransport {
/// Creates an unconfigured transport for `interface_name`.
///
/// No OS resources are allocated here. Call
/// [`create_interface`](Self::create_interface) and then
/// [`configure`](Self::configure).
#[must_use]
pub fn new(config: OverlayConfig, interface_name: String) -> Self {
    Self {
        interface_name,
        config,
        #[cfg(windows)]
        peers: Arc::default(),
        #[cfg(windows)]
        wintun_dev: None,
        #[cfg(windows)]
        udp: None,
        #[cfg(windows)]
        ingress_task: None,
        #[cfg(windows)]
        egress_task: None,
        #[cfg(windows)]
        timers_task: None,
        #[cfg(not(windows))]
        device: None,
    }
}
/// Returns the overlay interface name.
///
/// On macOS this becomes the kernel-assigned `utunN` name after
/// `create_interface` succeeds.
#[must_use]
pub fn interface_name(&self) -> &str {
    self.interface_name.as_str()
}
/// Path of the boringtun UAPI control socket for this interface.
#[cfg(not(windows))]
fn uapi_sock_path(&self) -> String {
    let mut path = String::from("/var/run/wireguard/");
    path.push_str(&self.interface_name);
    path.push_str(".sock");
    path
}
/// Creates the OS-level overlay interface.
///
/// Non-Windows targets get a boringtun userspace device; Windows gets a
/// Wintun adapter.
///
/// # Errors
///
/// Propagates any error from the platform-specific creation path.
pub async fn create_interface(&mut self) -> Result<(), Box<dyn std::error::Error>> {
    #[cfg(not(windows))]
    {
        self.create_interface_unix().await
    }
    #[cfg(windows)]
    {
        self.create_interface_windows().await
    }
}
/// Creates the boringtun userspace WireGuard device (Linux / macOS).
///
/// Steps:
/// 1. (non-macOS) reject names longer than 15 characters.
/// 2. Ensure `/var/run/wireguard` (the UAPI socket directory) exists.
/// 3. (Linux) refuse to continue if a kernel link with this name already
///    exists; a failed probe is only logged.
/// 4. Remove a stale UAPI socket at [`Self::uapi_sock_path`].
/// 5. Create the device on a blocking thread and store its handle.
/// 6. (macOS) find the `utun*.sock` socket that appeared for the new device
///    and adopt its stem as `interface_name`. If none is found, a warning is
///    logged, because later UAPI calls would target the wrong socket.
///
/// # Errors
///
/// Fails on an over-long name, an existing kernel link (Linux), failure to
/// create the socket directory, or failure to create the device. The device
/// error includes a privilege hint.
#[cfg(not(windows))]
async fn create_interface_unix(&mut self) -> Result<(), Box<dyn std::error::Error>> {
    // The requested name is not checked on macOS, where the kernel assigns utunN.
    #[cfg(not(target_os = "macos"))]
    if self.interface_name.len() > 15 {
        return Err(format!(
            "Interface name '{}' exceeds 15 character limit",
            self.interface_name
        )
        .into());
    }

    tokio::fs::create_dir_all("/var/run/wireguard").await?;

    #[cfg(target_os = "linux")]
    {
        let iface_ops = platform_ops();
        match iface_ops.link_exists(&self.interface_name).await {
            Ok(true) => {
                return Err(format!(
                    "Kernel link '{}' already exists; refusing to delete it. \
                     If this is a stale interface from a previous crash, restart \
                     the daemon (its boot-time sweep clears stale zl-* / veth-* \
                     links). If this fires during normal operation, there is a \
                     duplicate-name bug somewhere in the overlay setup path.",
                    self.interface_name
                )
                .into());
            }
            Ok(false) => {}
            Err(e) => {
                tracing::warn!(
                    interface = %self.interface_name,
                    error = %e,
                    "failed to probe for existing overlay interface; proceeding"
                );
            }
        }
    }

    // Use the same helper as the UAPI calls so the cleanup target cannot
    // drift from the path those calls connect to.
    let sock_path = self.uapi_sock_path();
    if tokio::fs::try_exists(&sock_path).await.unwrap_or(false) {
        tracing::warn!(path = %sock_path, "removing stale UAPI socket");
        let _ = tokio::fs::remove_file(&sock_path).await;
    }

    // macOS: snapshot existing sockets so the new device's socket can be
    // identified after creation.
    #[cfg(target_os = "macos")]
    let existing_socks = {
        let mut set = std::collections::HashSet::new();
        if let Ok(mut entries) = tokio::fs::read_dir("/var/run/wireguard").await {
            while let Ok(Some(entry)) = entries.next_entry().await {
                set.insert(entry.file_name().to_string_lossy().to_string());
            }
        }
        set
    };

    #[cfg(target_os = "macos")]
    let name = "utun".to_string();
    #[cfg(not(target_os = "macos"))]
    let name = self.interface_name.clone();

    let cfg = DeviceConfig {
        n_threads: 2,
        use_connected_socket: true,
        #[cfg(target_os = "linux")]
        use_multi_queue: false,
        #[cfg(target_os = "linux")]
        uapi_fd: -1,
    };

    let iface_name_for_err = self.interface_name.clone();
    // DeviceHandle::new is synchronous; keep it off the async executor.
    let handle = tokio::task::spawn_blocking(move || DeviceHandle::new(&name, cfg))
        .await
        .map_err(|e| format!("spawn_blocking join error: {e}"))?
        .map_err(|e| {
            #[cfg(target_os = "macos")]
            let hint = "Requires root. Run with sudo or install as a system service (zlayer daemon install).";
            #[cfg(not(target_os = "macos"))]
            let hint = "Ensure CAP_NET_ADMIN capability is available.";
            format!("Failed to create boringtun device '{iface_name_for_err}': {e}. {hint}")
        })?;
    self.device = Some(handle);

    #[cfg(target_os = "macos")]
    {
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let mut discovered = false;
        if let Ok(mut entries) = tokio::fs::read_dir("/var/run/wireguard").await {
            while let Ok(Some(entry)) = entries.next_entry().await {
                let fname = entry.file_name().to_string_lossy().to_string();
                if !existing_socks.contains(&fname)
                    && fname.starts_with("utun")
                    && std::path::Path::new(&fname)
                        .extension()
                        .is_some_and(|ext| ext.eq_ignore_ascii_case("sock"))
                {
                    self.interface_name = fname.trim_end_matches(".sock").to_string();
                    discovered = true;
                    break;
                }
            }
        }
        if !discovered {
            tracing::warn!(
                interface = %self.interface_name,
                "could not locate the UAPI socket of the new utun device; UAPI calls will likely fail"
            );
        }
    }

    tracing::info!(
        interface = %self.interface_name,
        "Created boringtun overlay transport"
    );
    Ok(())
}
/// Creates the Wintun adapter backing the overlay on Windows.
///
/// # Errors
///
/// Fails if the adapter name is longer than 64 characters, if the blocking
/// creation task cannot be joined, or if `WindowsTun::new` fails.
#[cfg(windows)]
async fn create_interface_windows(&mut self) -> Result<(), Box<dyn std::error::Error>> {
    const MAX_ADAPTER_NAME_LEN: usize = 64;

    if self.interface_name.len() > MAX_ADAPTER_NAME_LEN {
        return Err(format!(
            "Wintun adapter name '{}' exceeds 64 character limit",
            self.interface_name
        )
        .into());
    }

    let requested = self.interface_name.clone();
    // Left untyped so it takes whatever integer type `WindowsTun::new` expects.
    let mtu = 1420;
    let joined = tokio::task::spawn_blocking(move || WindowsTun::new(&requested, mtu)).await;
    let adapter = joined.map_err(|e| format!("spawn_blocking join error: {e}"))??;

    tracing::info!(
        interface = %self.interface_name,
        luid = adapter.luid_value(),
        "Created Wintun overlay adapter"
    );
    self.wintun_dev = Some(Arc::new(adapter));
    Ok(())
}
/// Applies the node's WireGuard configuration and initial peer set, then
/// brings the overlay interface up.
///
/// On non-Windows targets this sends one UAPI `set` containing the private
/// key, listen port and every peer, then assigns the address and route. On
/// Windows it delegates to `configure_windows`.
///
/// # Errors
///
/// Fails on an invalid key, a rejected UAPI transaction, or an
/// address/link/route setup failure.
pub async fn configure(
    &mut self,
    peers: &[PeerInfo],
) -> Result<(), Box<dyn std::error::Error>> {
    #[cfg(windows)]
    {
        self.configure_windows(peers).await
    }
    #[cfg(not(windows))]
    {
        let private_key_hex = key_to_hex(&self.config.private_key)?;
        let mut body = format!(
            "private_key={}\nlisten_port={}\n",
            private_key_hex,
            self.config.local_endpoint.port(),
        );
        for peer in peers {
            let pub_hex = key_to_hex(&peer.public_key)?;
            let keepalive_secs = peer.persistent_keepalive_interval.as_secs();
            let _ = write!(
                body,
                "public_key={pub_hex}\nendpoint={}\nallowed_ip={}\npersistent_keepalive_interval={keepalive_secs}\n",
                peer.endpoint, peer.allowed_ips,
            );
        }

        uapi_set(&self.uapi_sock_path(), &body).await?;
        tracing::debug!(interface = %self.interface_name, "Applied UAPI configuration");

        self.configure_interface().await?;
        tracing::info!(interface = %self.interface_name, "Overlay transport configured and up");
        Ok(())
    }
}
/// Windows implementation of `configure`.
///
/// Steps: assign the overlay address and route, best-effort install the
/// cluster-CIDR route, bind the WireGuard UDP socket, register every peer,
/// then spawn the ingress, egress and timer tasks.
///
/// # Errors
///
/// Fails if interface setup fails, if `create_interface` has not run, if the
/// UDP socket cannot be bound, or if a key or peer entry is invalid.
#[cfg(windows)]
async fn configure_windows(
    &mut self,
    peers: &[PeerInfo],
) -> Result<(), Box<dyn std::error::Error>> {
    self.configure_interface().await?;
    self.install_cluster_route_windows().await;

    let Some(tun) = self.wintun_dev.clone() else {
        return Err("Wintun adapter not initialized — call create_interface first".into());
    };

    let listen = self.config.local_endpoint;
    let socket = UdpSocket::bind(listen)
        .await
        .map_err(|e| format!("failed to bind WireGuard UDP socket on {listen}: {e}"))?;
    let udp = Arc::new(socket);
    self.udp = Some(Arc::clone(&udp));

    let priv_bytes = decode_key_b64(&self.config.private_key)?;
    peers
        .iter()
        .try_for_each(|peer| self.add_peer_windows(&priv_bytes, peer))?;

    self.ingress_task = Some(tokio::spawn(Self::ingress_loop(
        Arc::clone(&udp),
        Arc::clone(&tun),
        Arc::clone(&self.peers),
    )));
    self.egress_task = Some(tokio::spawn(Self::egress_loop(
        Arc::clone(&tun),
        Arc::clone(&udp),
        Arc::clone(&self.peers),
    )));
    self.timers_task = Some(tokio::spawn(Self::timers_loop(
        Arc::clone(&udp),
        Arc::clone(&self.peers),
    )));

    tracing::info!(
        interface = %self.interface_name,
        peer_count = peers.len(),
        listen = %listen,
        "Windows overlay transport configured (Tunn pipeline online)"
    );
    Ok(())
}

/// Best-effort installation of a route for `config.cluster_cidr` through the
/// Wintun adapter. A missing or unparseable CIDR, or a failed route install,
/// is logged as a warning and otherwise ignored.
#[cfg(windows)]
async fn install_cluster_route_windows(&self) {
    use crate::interface::windows::WindowsIpHelperOps;
    use crate::interface::InterfaceOps;

    let Some(cluster_cidr_str) = self.config.cluster_cidr.as_ref() else {
        tracing::warn!(
            "cluster_cidr not set in OverlayConfig; skipping Wintun route install (cross-node overlay traffic may not route)"
        );
        return;
    };

    let net = match cluster_cidr_str.parse::<ipnet::IpNet>() {
        Ok(net) => net,
        Err(e) => {
            tracing::warn!(
                error = %e,
                cidr = %cluster_cidr_str,
                "cluster_cidr unparseable; skipping Wintun route install"
            );
            return;
        }
    };

    let ops = WindowsIpHelperOps::new();
    let adapter_name = self.interface_name.clone();
    match ops
        .add_route_via_dev(net.network(), net.prefix_len(), &adapter_name)
        .await
    {
        Ok(()) => tracing::info!(
            cidr = %net,
            adapter = %adapter_name,
            "Installed cluster-CIDR host route via Wintun adapter"
        ),
        Err(e) => tracing::warn!(
            error = %e,
            cidr = %net,
            adapter = %adapter_name,
            "Failed to install cluster-CIDR host route via Wintun (overlay traffic may not route across nodes); route may already exist"
        ),
    }
}
/// Registers `peer` in the Windows peer map with a freshly built `Tunn`.
///
/// An existing entry with the same public key is replaced (`DashMap::insert`).
/// A keepalive interval of 0 disables keepalives. So does an interval that
/// does not fit in `u16`.
///
/// # Errors
///
/// Fails if the peer public key is invalid, or if `allowed_ips` is not a
/// single parseable CIDR.
#[cfg(windows)]
fn add_peer_windows(
    &self,
    our_priv: &[u8; 32],
    peer: &PeerInfo,
) -> Result<(), Box<dyn std::error::Error>> {
    let peer_pub = decode_key_b64(&peer.public_key)?;
    let allowed = peer
        .allowed_ips
        .parse::<ipnet::IpNet>()
        .map_err(|e| format!("invalid allowed_ips '{}': {e}", peer.allowed_ips))?;
    let keepalive = match peer.persistent_keepalive_interval.as_secs() {
        0 => None,
        secs => u16::try_from(secs).ok(),
    };

    let state = WindowsPeerState {
        tunn: Arc::new(AsyncMutex::new(build_tunn(
            our_priv, &peer_pub, None, keepalive,
        ))),
        endpoint: Arc::new(RwLock::new(Some(peer.endpoint))),
        last_handshake_sec: Arc::new(AtomicU64::new(0)),
        allowed_ips: Arc::new(vec![allowed]),
        persistent_keepalive: keepalive,
    };
    self.peers.insert(peer_pub, state);

    tracing::debug!(
        peer_key = %peer.public_key,
        endpoint = %peer.endpoint,
        allowed = %peer.allowed_ips,
        "Added peer to Windows overlay peer map"
    );
    Ok(())
}
/// Receive side of the Windows data path. It reads WireGuard datagrams from
/// `udp`, decrypts them, and writes the inner IP packets to the Wintun adapter.
///
/// The owning peer of a datagram is found by trial decapsulation. Each
/// peer's `Tunn` is offered the datagram until one accepts it, meaning any
/// result other than `TunnResult::Err`. For the accepting peer:
/// * decrypted packets reach the adapter **only if** their inner source
///   address lies in the peer's `allowed_ips` (WireGuard cryptokey routing).
///   Otherwise an authenticated peer could inject spoofed-source traffic;
/// * handshake replies are sent back to the datagram's source address;
/// * the peer's endpoint is updated to that source address (roaming);
/// * packets boringtun queued for the peer are flushed.
///
/// The datagram is not offered to further peers once one accepts it. The
/// loop runs until `recv_from` fails.
#[cfg(windows)]
async fn ingress_loop(
    udp: Arc<UdpSocket>,
    tun: Arc<WindowsTun>,
    peers: Arc<DashMap<[u8; 32], WindowsPeerState>>,
) {
    // Owned form of a `TunnResult`, so the `Tunn` lock and the borrow of the
    // shared output buffer both end before any `.await`.
    enum Decap {
        // Decrypted IP packet plus the source address boringtun reported for it.
        Tunnel(Vec<u8>, IpAddr),
        // Datagram to send back to the peer.
        Network(Vec<u8>),
        // Accepted by this peer with nothing to emit (e.g. a keepalive).
        Done,
        // Not accepted by this peer.
        Rejected,
    }

    // Copies the borrowed payload of `result` into an owned `Decap`.
    fn own(result: TunnResult<'_>) -> Decap {
        match result {
            TunnResult::WriteToTunnelV4(pkt, src) => Decap::Tunnel(pkt.to_vec(), IpAddr::from(src)),
            TunnResult::WriteToTunnelV6(pkt, src) => Decap::Tunnel(pkt.to_vec(), IpAddr::from(src)),
            TunnResult::WriteToNetwork(pkt) => Decap::Network(pkt.to_vec()),
            TunnResult::Done => Decap::Done,
            TunnResult::Err(_) => Decap::Rejected,
        }
    }

    // Cryptokey-routing check: may `peer` originate packets from `src`?
    fn source_allowed(peer: &WindowsPeerState, src: IpAddr) -> bool {
        peer.allowed_ips.iter().any(|net| net.contains(&src))
    }

    // Current Unix time in whole seconds (0 if the clock is before the epoch).
    fn unix_now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    let mut inbuf = vec![0u8; 65536];
    // A single output buffer for the whole loop, instead of a fresh 64 KiB
    // allocation per peer per datagram.
    let mut out = vec![0u8; 65536];
    loop {
        let (n, src) = match udp.recv_from(&mut inbuf).await {
            Ok(p) => p,
            Err(e) => {
                tracing::error!(error = %e, "UDP recv failed; ingress loop exiting");
                break;
            }
        };

        // Clone the Arc-backed states so no DashMap guard is held across `.await`.
        let snapshot: Vec<WindowsPeerState> =
            peers.iter().map(|e| e.value().clone()).collect();

        for state in snapshot {
            let outcome = {
                let mut tunn = state.tunn.lock().await;
                own(tunn.decapsulate(Some(src.ip()), &inbuf[..n], &mut out))
            };

            match outcome {
                Decap::Rejected => continue,
                Decap::Done => {}
                Decap::Tunnel(pkt, inner_src) => {
                    if source_allowed(&state, inner_src) {
                        if let Err(e) = tun.send(&pkt).await {
                            tracing::warn!(error = %e, "Wintun send failed");
                        }
                    } else {
                        tracing::debug!(
                            inner_src = %inner_src,
                            peer_endpoint = %src,
                            "dropping decrypted packet: inner source not in peer allowed_ips"
                        );
                    }
                    state
                        .last_handshake_sec
                        .store(unix_now_secs(), Ordering::Relaxed);
                }
                Decap::Network(resp) => {
                    if let Err(e) = udp.send_to(&resp, src).await {
                        tracing::warn!(error = %e, "UDP reply send failed");
                    }
                }
            }

            // The datagram authenticated against this peer, so remember
            // where the peer is reachable now.
            *state.endpoint.write() = Some(src);

            // Flush anything boringtun queued for this peer (for example,
            // packets held back until a handshake completed).
            loop {
                let queued = {
                    let mut tunn = state.tunn.lock().await;
                    own(tunn.decapsulate(None, &[], &mut out))
                };
                match queued {
                    Decap::Network(resp) => {
                        if let Err(e) = udp.send_to(&resp, src).await {
                            tracing::warn!(error = %e, "UDP drain send failed");
                        }
                    }
                    Decap::Tunnel(pkt, inner_src) => {
                        if source_allowed(&state, inner_src) {
                            if let Err(e) = tun.send(&pkt).await {
                                tracing::warn!(error = %e, "Wintun drain send failed");
                            }
                        }
                    }
                    Decap::Done | Decap::Rejected => break,
                }
            }

            // This peer owned the datagram; do not offer it to the others.
            break;
        }
    }
}
/// Transmit side of the Windows data path. It reads IP packets from the
/// Wintun adapter, selects the peer for each destination, encrypts the packet
/// and sends it to that peer's current endpoint.
///
/// The peer is chosen by longest-prefix match over all peers' `allowed_ips`,
/// following WireGuard cryptokey routing. DashMap iteration order is
/// arbitrary, so taking the first match could pick a broad route over a more
/// specific one. Packets with no parseable destination, no matching peer, or
/// a peer without a known endpoint are dropped. The loop runs until `recv` on
/// the adapter fails.
#[cfg(windows)]
async fn egress_loop(
    tun: Arc<WindowsTun>,
    udp: Arc<UdpSocket>,
    peers: Arc<DashMap<[u8; 32], WindowsPeerState>>,
) {
    let mut buf = vec![0u8; 65536];
    // Sized for the largest packet plus 32 bytes of encapsulation overhead;
    // allocated once and reused for every packet.
    let mut out = vec![0u8; 65536 + 32];
    loop {
        let n = match tun.recv(&mut buf).await {
            Ok(n) => n,
            Err(e) => {
                tracing::error!(error = %e, "Wintun recv failed; egress loop exiting");
                break;
            }
        };
        let Some(dst_ip) = parse_dst_ip(&buf[..n]) else {
            continue;
        };

        // Most specific (longest) matching prefix across all peers. The
        // DashMap guards are released at the end of this statement, before
        // any `.await`.
        let state = peers
            .iter()
            .filter_map(|entry| {
                entry
                    .value()
                    .allowed_ips
                    .iter()
                    .filter(|net| net.contains(&dst_ip))
                    .map(ipnet::IpNet::prefix_len)
                    .max()
                    .map(|prefix| (prefix, entry.value().clone()))
            })
            .max_by_key(|(prefix, _)| *prefix)
            .map(|(_, state)| state);

        let Some(state) = state else {
            tracing::trace!(%dst_ip, "no matching overlay peer");
            continue;
        };
        let endpoint = *state.endpoint.read();
        let Some(endpoint) = endpoint else {
            tracing::trace!(%dst_ip, "peer has no endpoint yet; dropping");
            continue;
        };

        let mut tunn = state.tunn.lock().await;
        match tunn.encapsulate(&buf[..n], &mut out) {
            TunnResult::WriteToNetwork(pkt) => {
                let pkt_owned = pkt.to_vec();
                // Release the Tunn lock before awaiting the network send.
                drop(tunn);
                if let Err(e) = udp.send_to(&pkt_owned, endpoint).await {
                    tracing::warn!(error = %e, "UDP send failed");
                }
            }
            TunnResult::Done
            | TunnResult::WriteToTunnelV4(_, _)
            | TunnResult::WriteToTunnelV6(_, _) => {}
            TunnResult::Err(e) => {
                tracing::warn!(?e, "encapsulate error");
            }
        }
    }
}
/// Calls `Tunn::update_timers` for every peer every 250 ms.
///
/// Any packet the timers produce is sent to the peer's endpoint if one is
/// known; otherwise it is discarded. Send and timer errors are logged at
/// debug level. Never returns.
#[cfg(windows)]
async fn timers_loop(udp: Arc<UdpSocket>, peers: Arc<DashMap<[u8; 32], WindowsPeerState>>) {
    let mut ticker = tokio::time::interval(Duration::from_millis(250));
    ticker.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
    loop {
        ticker.tick().await;
        // Snapshot so no DashMap guard is held while awaiting the Tunn locks.
        let states: Vec<WindowsPeerState> =
            peers.iter().map(|entry| entry.value().clone()).collect();
        for state in &states {
            let endpoint = *state.endpoint.read();
            let mut scratch = vec![0u8; 148];
            let mut tunn = state.tunn.lock().await;
            let outgoing = match tunn.update_timers(&mut scratch) {
                TunnResult::WriteToNetwork(pkt) => Some(pkt.to_vec()),
                TunnResult::Err(e) => {
                    tracing::debug!(?e, "update_timers error");
                    None
                }
                TunnResult::Done
                | TunnResult::WriteToTunnelV4(_, _)
                | TunnResult::WriteToTunnelV6(_, _) => None,
            };
            drop(tunn);
            if let (Some(pkt), Some(ep)) = (outgoing, endpoint) {
                if let Err(e) = udp.send_to(&pkt, ep).await {
                    tracing::debug!(error = %e, "timers UDP send failed");
                }
            }
        }
    }
}
/// Assigns the overlay address to the interface, brings the link up, and
/// routes the overlay network through the interface.
///
/// An "already exists" failure (`File exists` / `EEXIST`, plus
/// `already in table` for the route) is treated as success, so repeating the
/// call is harmless.
///
/// # Errors
///
/// Fails if `overlay_cidr` does not parse, or if any of the three steps fails
/// for another reason.
async fn configure_interface(&self) -> Result<(), Box<dyn std::error::Error>> {
    fn already_exists(msg: &str) -> bool {
        msg.contains("File exists") || msg.contains("EEXIST")
    }

    let cidr = self
        .config
        .overlay_cidr
        .parse::<ipnet::IpNet>()
        .map_err(|e| {
            format!(
                "Failed to parse overlay CIDR '{}': {e}",
                self.config.overlay_cidr
            )
        })?;
    let ops = platform_ops();
    let name = &self.interface_name;

    if let Err(err) = ops.add_address(name, cidr.addr(), cidr.prefix_len()).await {
        let msg = err.to_string();
        if !already_exists(&msg) {
            return Err(format!("Failed to assign IP: {msg}").into());
        }
    }

    ops.set_link_up(name)
        .await
        .map_err(|e| format!("Failed to bring up interface: {e}"))?;

    if let Err(err) = ops
        .add_route_via_dev(cidr.network(), cidr.prefix_len(), name)
        .await
    {
        let msg = err.to_string();
        if !(already_exists(&msg) || msg.contains("already in table")) {
            return Err(format!("Failed to add route: {msg}").into());
        }
    }
    Ok(())
}
/// Adds `peer` to the running transport.
///
/// On Windows an existing entry with the same key is replaced by fresh state.
/// Elsewhere, the peer is sent to the device as a UAPI `set`.
///
/// # Errors
///
/// Fails on an invalid key or `allowed_ips`, or on a rejected UAPI transaction.
#[cfg_attr(windows, allow(clippy::unused_async))]
pub async fn add_peer(&self, peer: &PeerInfo) -> Result<(), Box<dyn std::error::Error>> {
    #[cfg(windows)]
    {
        let our_priv = decode_key_b64(&self.config.private_key)?;
        self.add_peer_windows(&our_priv, peer)
    }
    #[cfg(not(windows))]
    {
        let pub_hex = key_to_hex(&peer.public_key)?;
        let keepalive_secs = peer.persistent_keepalive_interval.as_secs();
        let body = format!(
            "public_key={pub_hex}\nendpoint={}\nallowed_ip={}\npersistent_keepalive_interval={keepalive_secs}\n",
            peer.endpoint, peer.allowed_ips,
        );
        uapi_set(&self.uapi_sock_path(), &body).await?;
        tracing::debug!(
            peer_key = %peer.public_key,
            interface = %self.interface_name,
            "Added peer via UAPI"
        );
        Ok(())
    }
}
/// Removes the peer identified by its base64 `public_key`.
///
/// On Windows, removing an unknown key is a no-op.
///
/// # Errors
///
/// Fails on an invalid key, or on a rejected UAPI transaction.
#[cfg_attr(windows, allow(clippy::unused_async))]
pub async fn remove_peer(&self, public_key: &str) -> Result<(), Box<dyn std::error::Error>> {
    #[cfg(windows)]
    {
        let raw_key = decode_key_b64(public_key)?;
        self.peers.remove(&raw_key);
        tracing::debug!(
            peer_key = %public_key,
            interface = %self.interface_name,
            "Removed peer from Windows overlay"
        );
        Ok(())
    }
    #[cfg(not(windows))]
    {
        let pub_hex = key_to_hex(public_key)?;
        let request = format!("public_key={pub_hex}\nremove=true\n");
        uapi_set(&self.uapi_sock_path(), &request).await?;
        tracing::debug!(
            peer_key = %public_key,
            interface = %self.interface_name,
            "Removed peer via UAPI"
        );
        Ok(())
    }
}
/// Returns a UAPI-style `key=value` status dump.
///
/// Off Windows this is the device's raw `get=1` reply. On Windows an
/// equivalent dump is assembled from the peer map. It also includes a
/// non-standard `public_key_b64=` line per peer, and it prints an all-zero
/// private key if the configured key cannot be decoded.
///
/// # Errors
///
/// Off Windows, fails if the UAPI socket cannot be queried.
#[cfg_attr(windows, allow(clippy::unused_async))]
pub async fn status(&self) -> Result<String, Box<dyn std::error::Error>> {
    #[cfg(windows)]
    {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let priv_bytes = decode_key_b64(&self.config.private_key).unwrap_or([0u8; 32]);
        let mut out = format!(
            "private_key={}\nlisten_port={}\n",
            hex::encode(priv_bytes),
            self.config.local_endpoint.port()
        );
        for entry in self.peers.iter() {
            let key = entry.key();
            let peer = entry.value();
            let _ = writeln!(out, "public_key={}", hex::encode(key));
            let _ = writeln!(out, "public_key_b64={}", STANDARD.encode(key));
            if let Some(ep) = *peer.endpoint.read() {
                let _ = writeln!(out, "endpoint={ep}");
            }
            for net in peer.allowed_ips.iter() {
                let _ = writeln!(out, "allowed_ip={net}");
            }
            if let Some(k) = peer.persistent_keepalive {
                let _ = writeln!(out, "persistent_keepalive_interval={k}");
            }
            let _ = writeln!(
                out,
                "last_handshake_time_sec={}",
                peer.last_handshake_sec.load(Ordering::Relaxed)
            );
        }
        out.push_str("errno=0\n");
        Ok(out)
    }
    #[cfg(not(windows))]
    {
        uapi_get(&self.uapi_sock_path()).await
    }
}
/// Generates a fresh X25519 key pair for WireGuard.
///
/// Returns `(private_key, public_key)`, each the standard base64 encoding of
/// 32 bytes.
///
/// # Errors
///
/// The current implementation never returns `Err`.
// Nothing is awaited here. Presumably `async` is kept for API compatibility.
#[allow(clippy::unused_async)]
pub async fn generate_keys() -> Result<(String, String), Box<dyn std::error::Error>> {
    use base64::{engine::general_purpose::STANDARD, Engine as _};
    use x25519_dalek::{PublicKey, StaticSecret};

    let secret = StaticSecret::random();
    let encoded_public = STANDARD.encode(PublicKey::from(&secret).as_bytes());
    let encoded_private = STANDARD.encode(secret.to_bytes());
    Ok((encoded_private, encoded_public))
}
/// Points an existing peer at `new_endpoint`, e.g. after NAT traversal.
///
/// # Errors
///
/// Fails on an invalid key, a rejected UAPI transaction, or (on Windows) an
/// unknown peer.
#[cfg(feature = "nat")]
#[cfg_attr(windows, allow(clippy::unused_async))]
pub async fn update_peer_endpoint(
    &self,
    public_key: &str,
    new_endpoint: std::net::SocketAddr,
) -> Result<(), Box<dyn std::error::Error>> {
    #[cfg(windows)]
    {
        let raw_key = decode_key_b64(public_key)?;
        let Some(entry) = self.peers.get(&raw_key) else {
            return Err(format!("peer not found: {public_key}").into());
        };
        *entry.value().endpoint.write() = Some(new_endpoint);
        tracing::debug!(
            peer_key = %public_key,
            endpoint = %new_endpoint,
            "Updated peer endpoint (Windows)"
        );
        Ok(())
    }
    #[cfg(not(windows))]
    {
        let pub_hex = key_to_hex(public_key)?;
        let request = format!("public_key={pub_hex}\nendpoint={new_endpoint}\n");
        uapi_set(&self.uapi_sock_path(), &request).await?;
        tracing::debug!(
            peer_key = %public_key,
            endpoint = %new_endpoint,
            "Updated peer endpoint"
        );
        Ok(())
    }
}
/// Reports whether the peer has a recorded handshake time that is non-zero
/// and not earlier than `since`.
///
/// Off Windows, the `last_handshake_time_sec` value following the peer's
/// `public_key` line in the UAPI dump is compared. On Windows, the peer's
/// `last_handshake_sec` counter is compared.
///
/// # Errors
///
/// Fails on an invalid key, a UAPI query failure, or (on Windows) an unknown
/// peer.
#[cfg(feature = "nat")]
#[cfg_attr(windows, allow(clippy::unused_async))]
pub async fn check_peer_handshake(
    &self,
    public_key: &str,
    since: u64,
) -> Result<bool, Box<dyn std::error::Error>> {
    #[cfg(windows)]
    {
        let raw_key = decode_key_b64(public_key)?;
        let Some(entry) = self.peers.get(&raw_key) else {
            return Err(format!("peer not found: {public_key}").into());
        };
        let last = entry.value().last_handshake_sec.load(Ordering::Relaxed);
        Ok(last > 0 && last >= since)
    }
    #[cfg(not(windows))]
    {
        let dump = uapi_get(&self.uapi_sock_path()).await?;
        let target_hex = key_to_hex(public_key)?;
        let mut current_is_target = false;
        let pairs = dump
            .lines()
            .map(str::trim)
            .filter(|line| !line.is_empty() && !line.starts_with("errno="))
            .filter_map(|line| line.split_once('='));
        for (key, value) in pairs {
            if key == "public_key" {
                current_is_target = value == target_hex;
            } else if current_is_target && key == "last_handshake_time_sec" {
                if let Ok(secs) = value.parse::<u64>() {
                    return Ok(secs > 0 && secs >= since);
                }
            }
        }
        Ok(false)
    }
}
/// Releases every OS resource held by the transport.
///
/// Windows: aborts the three data-path tasks, drops the UDP socket, clears
/// the peer map and drops the Wintun adapter. Elsewhere: drops the boringtun
/// device. Each resource is `take`n, so calling this again is a no-op.
pub fn shutdown(&mut self) {
    #[cfg(windows)]
    {
        let tasks = [
            self.ingress_task.take(),
            self.egress_task.take(),
            self.timers_task.take(),
        ];
        for task in tasks.into_iter().flatten() {
            task.abort();
        }
        self.udp = None;
        self.peers.clear();
        if let Some(adapter) = self.wintun_dev.take() {
            tracing::info!(
                interface = %self.interface_name,
                "Shutting down Wintun overlay transport"
            );
            drop(adapter);
        }
    }
    #[cfg(not(windows))]
    if let Some(handle) = self.device.take() {
        tracing::info!(
            interface = %self.interface_name,
            "Shutting down overlay transport"
        );
        drop(handle);
    }
}
}
impl Drop for OverlayTransport {
fn drop(&mut self) {
self.shutdown();
}
}
#[cfg(test)]
// Unit tests for the overlay transport. Tests gated on `windows` or
// `not(windows)` cover that platform's helpers only. Tests marked `#[ignore]`
// create a real device and need root / CAP_NET_ADMIN.
mod tests {
use super::*;
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use std::time::Duration;
// `PeerInfo::to_peer_config` renders the public key and endpoint for an IPv4 peer.
#[test]
fn test_peer_info_to_config() {
let peer = PeerInfo::new(
"test_public_key".to_string(),
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1)), 51820),
"10.0.0.2/32",
Duration::from_secs(25),
);
let config = peer.to_peer_config();
assert!(config.contains("PublicKey = test_public_key"));
assert!(config.contains("Endpoint = 10.0.0.1:51820"));
}
// IPv4 header (version nibble 4): destination read from bytes 16..20.
#[cfg(windows)]
#[test]
fn test_parse_dst_ip_v4() {
let mut pkt = vec![0u8; 20];
pkt[0] = 0x45;
pkt[16..20].copy_from_slice(&[10, 0, 0, 7]);
assert_eq!(
super::parse_dst_ip(&pkt),
Some(IpAddr::V4(Ipv4Addr::new(10, 0, 0, 7)))
);
}
// IPv6 header (version nibble 6): destination read from bytes 24..40.
#[cfg(windows)]
#[test]
fn test_parse_dst_ip_v6() {
let mut pkt = vec![0u8; 40];
pkt[0] = 0x60;
pkt[24] = 0xfd;
pkt[25] = 0x00;
pkt[39] = 0x01;
let expected = IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1));
assert_eq!(super::parse_dst_ip(&pkt), Some(expected));
}
// Packets shorter than the version's minimum header yield `None`.
#[cfg(windows)]
#[test]
fn test_parse_dst_ip_truncated_returns_none() {
let pkt = vec![0x45u8; 10];
assert_eq!(super::parse_dst_ip(&pkt), None);
assert_eq!(super::parse_dst_ip(&[]), None);
}
// Version nibbles other than 4/6 yield `None`.
#[cfg(windows)]
#[test]
fn test_parse_dst_ip_unknown_version_returns_none() {
let pkt = vec![0x70u8; 64];
assert_eq!(super::parse_dst_ip(&pkt), None);
}
// A 32-byte key survives a base64 encode/decode round trip.
#[cfg(windows)]
#[test]
fn test_decode_key_b64_roundtrip() {
use base64::{engine::general_purpose::STANDARD, Engine as _};
let raw = [0x42u8; 32];
let b64 = STANDARD.encode(raw);
let decoded = super::decode_key_b64(&b64).expect("decode");
assert_eq!(decoded, raw);
}
// Keys that decode to anything other than 32 bytes are rejected.
#[cfg(windows)]
#[test]
fn test_decode_key_b64_wrong_length_errors() {
use base64::{engine::general_purpose::STANDARD, Engine as _};
let short = STANDARD.encode([0u8; 16]);
assert!(super::decode_key_b64(&short).is_err());
}
// IPv6 endpoints use bracket notation in the rendered peer config.
#[test]
fn test_peer_info_ipv6_to_config() {
let peer = PeerInfo::new(
"test_public_key_v6".to_string(),
SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1)),
51820,
),
"fd00::2/128",
Duration::from_secs(25),
);
let config = peer.to_peer_config();
assert!(config.contains("PublicKey = test_public_key_v6"));
assert!(
config.contains("Endpoint = [fd00::1]:51820"),
"IPv6 endpoint should use bracket notation, got: {config}"
);
assert!(config.contains("AllowedIPs = fd00::2/128"));
}
// The `addr` / `prefix_len` / `network` split that `configure_interface` relies on (IPv4).
#[test]
fn test_overlay_cidr_parses_ipv4() {
let cidr: ipnet::IpNet = "10.200.0.1/24".parse().unwrap();
assert!(cidr.addr().is_ipv4());
assert_eq!(cidr.prefix_len(), 24);
assert_eq!(cidr.network().to_string(), "10.200.0.0");
}
// Same split for an IPv6 overlay CIDR.
#[test]
fn test_overlay_cidr_parses_ipv6() {
let cidr: ipnet::IpNet = "fd00::1/48".parse().unwrap();
assert!(cidr.addr().is_ipv6());
assert_eq!(cidr.prefix_len(), 48);
assert_eq!(cidr.network().to_string(), "fd00::");
}
// A /128 host CIDR keeps the host address itself.
#[test]
fn test_overlay_cidr_ipv6_host_address() {
let cidr: ipnet::IpNet = "fd00::5/128".parse().unwrap();
assert!(cidr.addr().is_ipv6());
assert_eq!(cidr.prefix_len(), 128);
assert_eq!(cidr.addr().to_string(), "fd00::5");
}
// `allowed_ips` is stored verbatim for both address families.
#[test]
fn test_peer_info_ipv6_allowed_ips_format() {
let peer_v4 = PeerInfo::new(
"key_v4".to_string(),
SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 51820),
"10.200.0.5/32",
Duration::from_secs(25),
);
assert_eq!(peer_v4.allowed_ips, "10.200.0.5/32");
let peer_v6 = PeerInfo::new(
"key_v6".to_string(),
SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 5)),
51820,
),
"fd00::5/128",
Duration::from_secs(25),
);
assert_eq!(peer_v6.allowed_ips, "fd00::5/128");
}
// `SocketAddr`'s Display (used in the UAPI `endpoint=` lines) brackets IPv6 hosts.
#[test]
fn test_uapi_body_format_ipv6_peer() {
let endpoint = SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0xfd00, 0, 0, 0, 0, 0, 0, 1)),
51820,
);
let formatted = format!("endpoint={endpoint}");
assert_eq!(formatted, "endpoint=[fd00::1]:51820");
}
// Generated keys are 44-char base64 of 32 bytes, and the public key matches the private key.
#[tokio::test]
async fn test_generate_keys_native() {
use base64::{engine::general_purpose::STANDARD, Engine as _};
use x25519_dalek::{PublicKey, StaticSecret};
let (private_key, public_key) = OverlayTransport::generate_keys().await.unwrap();
assert_eq!(
private_key.len(),
44,
"Private key should be 44 chars base64"
);
assert_eq!(public_key.len(), 44, "Public key should be 44 chars base64");
let priv_bytes = STANDARD.decode(&private_key).unwrap();
let pub_bytes = STANDARD.decode(&public_key).unwrap();
assert_eq!(priv_bytes.len(), 32);
assert_eq!(pub_bytes.len(), 32);
let secret = StaticSecret::from(<[u8; 32]>::try_from(priv_bytes.as_slice()).unwrap());
let expected_public = PublicKey::from(&secret);
assert_eq!(pub_bytes.as_slice(), expected_public.as_bytes());
}
// Two consecutive generations produce different private keys.
#[tokio::test]
async fn test_generate_keys_unique() {
let (key1, _) = OverlayTransport::generate_keys().await.unwrap();
let (key2, _) = OverlayTransport::generate_keys().await.unwrap();
assert_ne!(
key1, key2,
"Sequential key generation should produce unique keys"
);
}
// `key_to_hex` produces 64 lowercase hex characters for a 32-byte key.
#[cfg(not(windows))]
#[test]
fn test_key_to_hex() {
use base64::{engine::general_purpose::STANDARD, Engine as _};
let key_bytes = [0xABu8; 32];
let base64_key = STANDARD.encode(key_bytes);
let hex_key = key_to_hex(&base64_key).unwrap();
assert_eq!(hex_key, "ab".repeat(32));
assert_eq!(hex_key.len(), 64, "Hex key should be 64 chars");
}
// `key_to_hex` rejects keys that decode to the wrong length.
#[cfg(not(windows))]
#[test]
fn test_key_to_hex_invalid_length() {
use base64::{engine::general_purpose::STANDARD, Engine as _};
let short_bytes = [0xABu8; 16];
let base64_key = STANDARD.encode(short_bytes);
let result = key_to_hex(&base64_key);
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Invalid key length"));
}
// Real device creation (IPv4 overlay). Without privileges, the error must
// come from boringtun / privilege checks rather than kernel WireGuard.
#[tokio::test]
#[ignore = "Requires root/CAP_NET_ADMIN"]
async fn test_create_interface_boringtun() {
let config = OverlayConfig {
overlay_cidr: "10.42.0.1/24".to_string(),
cluster_cidr: None,
private_key: "test_key".to_string(),
public_key: "test_pub".to_string(),
local_endpoint: SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 51820),
peer_discovery_interval: Duration::from_secs(30),
#[cfg(feature = "nat")]
nat: crate::nat::NatConfig::default(),
};
#[cfg(target_os = "macos")]
let iface_name = "utun".to_string();
#[cfg(not(target_os = "macos"))]
let iface_name = "zl-bt-test0".to_string();
let mut transport = OverlayTransport::new(config, iface_name);
let result = transport.create_interface().await;
match result {
Ok(()) => {
#[cfg(target_os = "macos")]
assert!(
transport.interface_name().starts_with("utun"),
"macOS interface should be utunN, got: {}",
transport.interface_name()
);
transport.shutdown();
}
Err(e) => {
let msg = e.to_string();
assert!(
!msg.contains("Attribute failed policy validation"),
"create_interface should not produce kernel WireGuard errors. Got: {msg}",
);
assert!(
msg.contains("boringtun")
|| msg.contains("CAP_NET_ADMIN")
|| msg.contains("sudo"),
"Error should mention boringtun, CAP_NET_ADMIN, or sudo. Got: {msg}",
);
}
}
}
// Same as above with an IPv6 overlay CIDR and listen endpoint.
#[tokio::test]
#[ignore = "Requires root/CAP_NET_ADMIN"]
async fn test_create_interface_boringtun_ipv6() {
let config = OverlayConfig {
overlay_cidr: "fd00::1/48".to_string(),
cluster_cidr: None,
private_key: "test_key".to_string(),
public_key: "test_pub".to_string(),
local_endpoint: SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 51820),
peer_discovery_interval: Duration::from_secs(30),
#[cfg(feature = "nat")]
nat: crate::nat::NatConfig::default(),
};
#[cfg(target_os = "macos")]
let iface_name = "utun".to_string();
#[cfg(not(target_os = "macos"))]
let iface_name = "zl-bt6-test0".to_string();
let mut transport = OverlayTransport::new(config, iface_name);
let result = transport.create_interface().await;
match result {
Ok(()) => {
#[cfg(target_os = "macos")]
assert!(
transport.interface_name().starts_with("utun"),
"macOS interface should be utunN, got: {}",
transport.interface_name()
);
transport.shutdown();
}
Err(e) => {
let msg = e.to_string();
assert!(
!msg.contains("Attribute failed policy validation"),
"create_interface should not produce kernel WireGuard errors. Got: {msg}",
);
assert!(
msg.contains("boringtun")
|| msg.contains("CAP_NET_ADMIN")
|| msg.contains("sudo"),
"Error should mention boringtun, CAP_NET_ADMIN, or sudo. Got: {msg}",
);
}
}
}
}