use libc::c_char;
use crate::{backends, key::Key, Backend, KeyPair, PeerConfigBuilder};
use std::{
borrow::Cow,
ffi::CStr,
fmt, io,
net::{IpAddr, SocketAddr},
str::FromStr,
time::SystemTime,
};
/// An IP network (address plus prefix length) allowed for a peer, written as `ADDRESS/PREFIX`.
#[derive(PartialEq, Eq, Clone)]
pub struct AllowedIp {
    /// The network address.
    pub address: IpAddr,
    /// The prefix length in bits. Values produced by `FromStr` are at most 32 for IPv4 and at
    /// most 128 for IPv6.
    pub cidr: u8,
}

impl fmt::Debug for AllowedIp {
    /// Formats as `ADDRESS/PREFIX`, e.g. `10.0.0.0/24`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}/{}", self.address, self.cidr)
    }
}

impl std::str::FromStr for AllowedIp {
    type Err = ();

    /// Parses `ADDRESS/PREFIX` notation, e.g. `10.0.0.0/24` or `fd00::/64`.
    ///
    /// # Errors
    /// Returns `Err(())` when:
    /// - the input does not contain exactly one `/`;
    /// - the address is not a valid IPv4/IPv6 address;
    /// - the prefix is not an integer;
    /// - the prefix exceeds the address family's bit width (32 for IPv4, 128 for IPv6).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Require exactly two `/`-separated parts without allocating a Vec.
        let mut parts = s.split('/');
        let (address, cidr) = match (parts.next(), parts.next(), parts.next()) {
            (Some(address), Some(cidr), None) => (address, cidr),
            _ => return Err(()),
        };

        let address: IpAddr = address.parse().map_err(|_| ())?;
        let cidr: u8 = cidr.parse().map_err(|_| ())?;

        // A prefix longer than the address itself is not a meaningful network.
        let max_cidr: u8 = match address {
            IpAddr::V4(_) => 32,
            IpAddr::V6(_) => 128,
        };
        if cidr > max_cidr {
            return Err(());
        }

        Ok(AllowedIp { address, cidr })
    }
}
/// Configuration of a single WireGuard peer, as carried inside [`PeerInfo`].
///
/// The crate-private `__cant_construct_me` field means code outside this crate cannot create a
/// `PeerConfig` with a struct literal.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct PeerConfig {
/// The peer's public key; this file matches peers by it (e.g. `DeviceUpdate::remove_peer_by_key`).
pub public_key: Key,
/// The peer's preshared key, if one is set.
pub preshared_key: Option<Key>,
/// The peer's endpoint socket address, if known.
pub endpoint: Option<SocketAddr>,
/// Persistent keepalive interval, if set (presumably in seconds — TODO confirm against the backends).
pub persistent_keepalive_interval: Option<u16>,
/// The allowed IP networks configured for this peer.
pub allowed_ips: Vec<AllowedIp>,
// Crate-private unit field: blocks struct-literal construction outside this crate.
pub(crate) __cant_construct_me: (),
}
/// Runtime statistics reported for a peer. `Default` yields no handshake and zeroed counters.
#[derive(Debug, PartialEq, Eq, Clone, Default)]
pub struct PeerStats {
/// Time of the latest handshake; `None` presumably means no handshake has been reported yet.
pub last_handshake_time: Option<SystemTime>,
/// Received-bytes counter (presumably cumulative for the interface's lifetime — TODO confirm).
pub rx_bytes: u64,
/// Transmitted-bytes counter (same caveat as `rx_bytes`).
pub tx_bytes: u64,
}
/// A peer as reported by a device: its configuration together with its runtime statistics.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct PeerInfo {
/// The peer's configuration.
pub config: PeerConfig,
/// The peer's runtime statistics.
pub stats: PeerStats,
}
/// The state of a WireGuard interface, as returned by [`Device::get`].
///
/// The crate-private `__cant_construct_me` field means code outside this crate cannot create a
/// `Device` with a struct literal.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct Device {
/// The interface's name.
pub name: InterfaceName,
/// The interface's public key, if set.
pub public_key: Option<Key>,
/// The interface's private key, if set.
pub private_key: Option<Key>,
/// The interface's fwmark, if set.
pub fwmark: Option<u32>,
/// The port the interface listens on, if known.
pub listen_port: Option<u16>,
/// All peers configured on the interface, with their statistics.
pub peers: Vec<PeerInfo>,
/// NOTE(review): presumably the name of a linked/underlying OS interface (e.g. for userspace
/// implementations) — confirm against the backends.
pub linked_name: Option<String>,
/// The backend this device was read through; `Device::delete` dispatches on it.
pub backend: Backend,
// Crate-private unit field: blocks struct-literal construction outside this crate.
pub(crate) __cant_construct_me: (),
}
/// Fixed-size C buffer (`IFNAMSIZ` bytes) holding a NUL-padded interface name.
type RawInterfaceName = [c_char; libc::IFNAMSIZ];
/// A validated network interface name stored in a C-compatible fixed-size buffer.
///
/// Names produced by `FromStr` meet these guarantees:
/// - they are non-empty and at most `IFNAMSIZ - 1` bytes long;
/// - they contain no NUL, `/`, or ASCII whitespace;
/// - they are always followed by at least one NUL byte in the buffer.
#[derive(PartialEq, Eq, Clone, Copy)]
pub struct InterfaceName(RawInterfaceName);
impl FromStr for InterfaceName {
type Err = InvalidInterfaceName;
fn from_str(name: &str) -> Result<Self, InvalidInterfaceName> {
let len = name.len();
if len == 0 {
return Err(InvalidInterfaceName::Empty);
}
if len > (libc::IFNAMSIZ - 1) {
return Err(InvalidInterfaceName::TooLong);
}
let mut buf = [c_char::default(); libc::IFNAMSIZ];
for (out, b) in buf.iter_mut().zip(name.as_bytes().iter()) {
if *b == 0 || *b == b'/' || b.is_ascii_whitespace() {
return Err(InvalidInterfaceName::InvalidChars);
}
*out = *b as c_char;
}
Ok(Self(buf))
}
}
impl InterfaceName {
pub fn as_str_lossy(&self) -> Cow<'_, str> {
unsafe { CStr::from_ptr(self.0.as_ptr()) }.to_string_lossy()
}
#[cfg(target_os = "linux")]
pub fn as_ptr(&self) -> *const c_char {
self.0.as_ptr()
}
}
impl fmt::Debug for InterfaceName {
    /// Debug output is the bare name: no quotes, no padding. It is identical to `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self, f)
    }
}
impl fmt::Display for InterfaceName {
    /// Writes the lossily-decoded name verbatim. Width and fill flags are not applied.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = self.as_str_lossy();
        f.write_str(text.as_ref())
    }
}
/// Reasons `InterfaceName::from_str` can reject an interface name.
#[derive(Debug, PartialEq, Eq)]
pub enum InvalidInterfaceName {
/// The name is `IFNAMSIZ` bytes or longer, leaving no room for the NUL terminator.
TooLong,
/// The name is the empty string.
Empty,
/// The name contains a NUL byte, `/`, or ASCII whitespace.
InvalidChars,
}
impl fmt::Display for InvalidInterfaceName {
    /// Human-readable description of why the interface name was rejected.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // `IFNAMSIZ` counts the trailing NUL terminator. `InterfaceName::from_str` therefore
            // accepts at most `IFNAMSIZ - 1` bytes, and that is the limit to report.
            Self::TooLong => write!(
                f,
                "interface name longer than system max of {} chars",
                libc::IFNAMSIZ - 1
            ),
            Self::Empty => f.write_str("an empty interface name was provided"),
            // The validator rejects NUL and all ASCII whitespace, not only `/` and spaces.
            Self::InvalidChars => {
                f.write_str("interface name contained slash, whitespace, or NUL characters")
            }
        }
    }
}
impl From<InvalidInterfaceName> for std::io::Error {
    /// Converts a name-validation failure into an `InvalidData` I/O error.
    /// The error's payload is the `Display` text of the failure.
    fn from(e: InvalidInterfaceName) -> Self {
        let message = e.to_string();
        Self::new(std::io::ErrorKind::InvalidData, message)
    }
}
// Uses the trait's default methods; the message comes from the `Display` impl for this type.
impl std::error::Error for InvalidInterfaceName {}
impl Device {
    /// Lists the names of the WireGuard interfaces visible to `backend`.
    ///
    /// # Errors
    /// Propagates any I/O error from the selected backend's `enumerate`.
    pub fn list(backend: Backend) -> Result<Vec<InterfaceName>, std::io::Error> {
        match backend {
            Backend::Userspace => backends::userspace::enumerate(),
            #[cfg(target_os = "linux")]
            Backend::Kernel => backends::kernel::enumerate(),
        }
    }

    /// Reads the configuration and peers of interface `name` through `backend`.
    ///
    /// # Errors
    /// Propagates any I/O error from the selected backend's `get_by_name`.
    pub fn get(name: &InterfaceName, backend: Backend) -> Result<Self, std::io::Error> {
        match backend {
            Backend::Userspace => backends::userspace::get_by_name(name),
            #[cfg(target_os = "linux")]
            Backend::Kernel => backends::kernel::get_by_name(name),
        }
    }

    /// Deletes this interface through the backend stored in `self.backend`, consuming `self`.
    ///
    /// # Errors
    /// Propagates any I/O error from the selected backend's `delete_interface`.
    pub fn delete(self) -> Result<(), std::io::Error> {
        let Self { name, backend, .. } = &self;
        match backend {
            Backend::Userspace => backends::userspace::delete_interface(name),
            #[cfg(target_os = "linux")]
            Backend::Kernel => backends::kernel::delete_interface(name),
        }
    }
}
/// A pending set of changes to a WireGuard interface.
///
/// Build it with the methods on `DeviceUpdate` (starting from `DeviceUpdate::new`) and send it
/// with `DeviceUpdate::apply`. Fields are crate-private; only the builder methods set them.
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct DeviceUpdate {
// Public key to set; `None` presumably leaves it unchanged (behavior lives in the backends).
pub(crate) public_key: Option<Key>,
// Private key to set; `None` presumably leaves it unchanged.
pub(crate) private_key: Option<Key>,
// Fwmark to set; `None` presumably leaves it unchanged.
pub(crate) fwmark: Option<u32>,
// Listen port to set; `None` presumably leaves it unchanged.
pub(crate) listen_port: Option<u16>,
// Peer changes to send, in the order they were added (including removal requests).
pub(crate) peers: Vec<PeerConfigBuilder>,
// Set by `replace_peers()`; presumably asks the backend to replace the existing peer list
// rather than merge into it — confirm in `backends::*::apply`.
pub(crate) replace_peers: bool,
}
impl DeviceUpdate {
    /// Creates an empty update: no fields set, no peers queued, `replace_peers` off.
    #[must_use]
    pub fn new() -> Self {
        DeviceUpdate {
            public_key: None,
            private_key: None,
            fwmark: None,
            listen_port: None,
            peers: vec![],
            replace_peers: false,
        }
    }

    /// Sets both the public and private key from `keypair`.
    #[must_use]
    pub fn set_keypair(self, keypair: KeyPair) -> Self {
        self.set_public_key(keypair.public)
            .set_private_key(keypair.private)
    }

    /// Sets the interface's public key.
    #[must_use]
    pub fn set_public_key(mut self, key: Key) -> Self {
        self.public_key = Some(key);
        self
    }

    /// Sets the public key to `Key::zero()`. The backend presumably treats that as "unset".
    #[must_use]
    pub fn unset_public_key(self) -> Self {
        self.set_public_key(Key::zero())
    }

    /// Sets the interface's private key.
    #[must_use]
    pub fn set_private_key(mut self, key: Key) -> Self {
        self.private_key = Some(key);
        self
    }

    /// Sets the private key to `Key::zero()`. The backend presumably treats that as "unset".
    #[must_use]
    pub fn unset_private_key(self) -> Self {
        self.set_private_key(Key::zero())
    }

    /// Sets the interface's fwmark.
    #[must_use]
    pub fn set_fwmark(mut self, fwmark: u32) -> Self {
        self.fwmark = Some(fwmark);
        self
    }

    /// Sets the fwmark to `0`. The backend presumably treats that as "no fwmark".
    #[must_use]
    pub fn unset_fwmark(self) -> Self {
        self.set_fwmark(0)
    }

    /// Sets the port the interface listens on.
    #[must_use]
    pub fn set_listen_port(mut self, port: u16) -> Self {
        self.listen_port = Some(port);
        self
    }

    /// Sets the listen port to `0`. Presumably the backend/OS then picks a port — confirm.
    #[must_use]
    pub fn randomize_listen_port(self) -> Self {
        self.set_listen_port(0)
    }

    /// Queues one peer change.
    #[must_use]
    pub fn add_peer(mut self, peer: PeerConfigBuilder) -> Self {
        self.peers.push(peer);
        self
    }

    /// Queues a peer for `pubkey`, configured by `builder`.
    ///
    /// `builder` receives `PeerConfigBuilder::new(pubkey)` and returns the configured builder.
    /// It is called exactly once, so `FnOnce` suffices. That lets callers move captured
    /// values (e.g. a `Vec` of allowed IPs) into the closure; every `Fn` closure still works.
    #[must_use]
    pub fn add_peer_with(
        self,
        pubkey: &Key,
        builder: impl FnOnce(PeerConfigBuilder) -> PeerConfigBuilder,
    ) -> Self {
        self.add_peer(builder(PeerConfigBuilder::new(pubkey)))
    }

    /// Queues clones of every peer change in `peers`, preserving order.
    #[must_use]
    pub fn add_peers(mut self, peers: &[PeerConfigBuilder]) -> Self {
        self.peers.extend_from_slice(peers);
        self
    }

    /// Sets the `replace_peers` flag. It presumably makes the backend replace the existing
    /// peer list instead of merging into it — confirm in `backends::*::apply`.
    #[must_use]
    pub fn replace_peers(mut self) -> Self {
        self.replace_peers = true;
        self
    }

    /// Queues removal of the peer with `public_key`.
    ///
    /// This adds a builder with its `remove_me` flag set.
    #[must_use]
    pub fn remove_peer_by_key(self, public_key: &Key) -> Self {
        let mut peer = PeerConfigBuilder::new(public_key);
        peer.remove_me = true;
        self.add_peer(peer)
    }

    /// Sends this update to interface `iface` through `backend`.
    ///
    /// # Errors
    /// Propagates any I/O error from the selected backend's `apply`.
    pub fn apply(self, iface: &InterfaceName, backend: Backend) -> io::Result<()> {
        match backend {
            #[cfg(target_os = "linux")]
            Backend::Kernel => backends::kernel::apply(&self, iface),
            Backend::Userspace => backends::userspace::apply(&self, iface),
        }
    }
}
impl Default for DeviceUpdate {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    // `super::*` already brings in every type used here (including the parent's imports).
    use super::*;

    const TEST_INTERFACE: &str = "wgctrl-test";

    /// Applies ten freshly generated peers to a userspace test interface, then checks that
    /// `Device::get` reports all of them.
    ///
    /// Returns early (silently passes) unless the effective uid is 0.
    #[test]
    fn test_add_peers() {
        if unsafe { libc::getuid() } != 0 {
            return;
        }
        let keypairs: Vec<_> = (0..10).map(|_| KeyPair::generate()).collect();
        let mut builder = DeviceUpdate::new();
        for keypair in &keypairs {
            builder = builder.add_peer(PeerConfigBuilder::new(&keypair.public))
        }
        let interface = TEST_INTERFACE.parse().unwrap();
        builder.apply(&interface, Backend::Userspace).unwrap();
        let device = Device::get(&interface, Backend::Userspace).unwrap();

        // Collect failures before asserting, so the test interface is deleted even when a
        // peer is missing.
        let missing: Vec<_> = keypairs
            .iter()
            .map(|keypair| &keypair.public)
            .filter(|key| !device.peers.iter().any(|p| p.config.public_key == **key))
            .collect();
        device.delete().unwrap();
        assert!(missing.is_empty(), "peers missing after apply: {:?}", missing);
    }

    /// Checks that interface-name validation accepts names up to `IFNAMSIZ - 1` bytes and
    /// rejects every invalid category with the right error variant.
    #[test]
    fn test_interface_names() {
        assert_eq!(
            "wg-01".parse::<InterfaceName>().unwrap().as_str_lossy(),
            "wg-01"
        );

        // The longest accepted name leaves exactly one byte for the NUL terminator.
        let longest = "a".repeat(libc::IFNAMSIZ - 1);
        assert_eq!(
            longest.parse::<InterfaceName>().unwrap().as_str_lossy(),
            longest
        );

        assert!("longer-nul\0".parse::<InterfaceName>().is_err());

        let one_too_long = "a".repeat(libc::IFNAMSIZ);
        let invalid_names = [
            ("", InvalidInterfaceName::Empty),
            ("\0", InvalidInterfaceName::InvalidChars),
            ("ifname\0nul", InvalidInterfaceName::InvalidChars),
            ("if name", InvalidInterfaceName::InvalidChars),
            ("ifna/me", InvalidInterfaceName::InvalidChars),
            ("if na/me", InvalidInterfaceName::InvalidChars),
            ("interfacelongname", InvalidInterfaceName::TooLong),
            (one_too_long.as_str(), InvalidInterfaceName::TooLong),
        ];
        for (name, expected) in &invalid_names {
            assert_eq!(
                name.parse::<InterfaceName>().as_ref(),
                Err(expected),
                "unexpected result for {:?}",
                name
            );
        }
    }
}