use crate::{Backend, Device, DeviceUpdate, InterfaceName, Key, PeerInfo};
use std::{
fmt::Write as _,
fs,
io::{self, prelude::*, BufReader},
os::unix::net::UnixStream,
path::{Path, PathBuf},
process::{Command, Output},
time::{Duration, SystemTime},
};
static VAR_RUN_PATH: &str = "/var/run/wireguard";
static RUN_PATH: &str = "/run/wireguard";
fn get_base_folder() -> io::Result<PathBuf> {
if Path::new(VAR_RUN_PATH).exists() {
Ok(Path::new(VAR_RUN_PATH).to_path_buf())
} else if Path::new(RUN_PATH).exists() {
Ok(Path::new(RUN_PATH).to_path_buf())
} else {
Err(io::Error::new(
io::ErrorKind::NotFound,
"WireGuard socket directory not found.",
))
}
}
fn get_namefile(name: &InterfaceName) -> io::Result<PathBuf> {
Ok(get_base_folder()?.join(format!("{}.name", name.as_str_lossy())))
}
fn get_socketfile(name: &InterfaceName) -> io::Result<PathBuf> {
if cfg!(target_os = "linux") {
Ok(get_base_folder()?.join(format!("{name}.sock")))
} else {
Ok(get_base_folder()?.join(format!("{}.sock", resolve_tun(name)?)))
}
}
fn open_socket(name: &InterfaceName) -> io::Result<UnixStream> {
UnixStream::connect(get_socketfile(name)?)
}
pub fn resolve_tun(name: &InterfaceName) -> io::Result<String> {
let namefile = get_namefile(name)?;
Ok(fs::read_to_string(namefile)
.map_err(|_| io::Error::new(io::ErrorKind::NotFound, "WireGuard name file can't be read"))?
.trim()
.to_string())
}
pub fn delete_interface(name: &InterfaceName) -> io::Result<()> {
fs::remove_file(get_socketfile(name)?).ok();
fs::remove_file(get_namefile(name)?).ok();
Ok(())
}
pub fn enumerate() -> Result<Vec<InterfaceName>, io::Error> {
use std::ffi::OsStr;
let mut interfaces = vec![];
for entry in fs::read_dir(get_base_folder()?)? {
let path = entry?.path();
if path.extension() == Some(OsStr::new("name")) {
let stem = path
.file_stem()
.and_then(|stem| stem.to_str())
.and_then(|name| name.parse::<InterfaceName>().ok())
.filter(|iface| open_socket(iface).is_ok());
if let Some(iface) = stem {
interfaces.push(iface);
}
}
}
Ok(interfaces)
}
struct ConfigParser {
device_info: Device,
current_peer: Option<PeerInfo>,
}
impl From<ConfigParser> for Device {
fn from(parser: ConfigParser) -> Self {
parser.device_info
}
}
impl ConfigParser {
fn new(name: &InterfaceName) -> Self {
let device_info = Device {
name: *name,
public_key: None,
private_key: None,
fwmark: None,
listen_port: None,
peers: vec![],
linked_name: resolve_tun(name).ok(),
backend: Backend::Userspace,
};
Self {
device_info,
current_peer: None,
}
}
fn add_line(&mut self, line: &str) -> Result<(), std::io::Error> {
use io::ErrorKind::InvalidData;
let split: Vec<&str> = line.splitn(2, '=').collect();
match &split[..] {
[key, value] => self.add_pair(key, value),
_ => Err(InvalidData.into()),
}
}
fn add_pair(&mut self, key: &str, value: &str) -> Result<(), std::io::Error> {
use io::ErrorKind::InvalidData;
match key {
"private_key" => {
self.device_info.private_key = Some(Key::from_hex(value).map_err(|_| InvalidData)?);
self.device_info.public_key = self
.device_info
.private_key
.as_ref()
.map(|k| k.get_public());
},
"listen_port" => {
self.device_info.listen_port = Some(value.parse().map_err(|_| InvalidData)?)
},
"fwmark" => self.device_info.fwmark = Some(value.parse().map_err(|_| InvalidData)?),
"public_key" => {
let new_peer =
PeerInfo::from_public_key(Key::from_hex(value).map_err(|_| InvalidData)?);
if let Some(finished_peer) = self.current_peer.replace(new_peer) {
self.device_info.peers.push(finished_peer);
}
},
"preshared_key" => {
self.current_peer
.as_mut()
.ok_or(InvalidData)?
.config
.preshared_key = Some(Key::from_hex(value).map_err(|_| InvalidData)?);
},
"tx_bytes" => {
self.current_peer
.as_mut()
.ok_or(InvalidData)?
.stats
.tx_bytes = value.parse().map_err(|_| InvalidData)?
},
"rx_bytes" => {
self.current_peer
.as_mut()
.ok_or(InvalidData)?
.stats
.rx_bytes = value.parse().map_err(|_| InvalidData)?
},
"last_handshake_time_sec" => {
let handshake_seconds: u64 = value.parse().map_err(|_| InvalidData)?;
if handshake_seconds > 0 {
self.current_peer
.as_mut()
.ok_or(InvalidData)?
.stats
.last_handshake_time =
Some(SystemTime::UNIX_EPOCH + Duration::from_secs(handshake_seconds));
}
},
"allowed_ip" => {
self.current_peer
.as_mut()
.ok_or(InvalidData)?
.config
.allowed_ips
.push(value.parse().map_err(|_| InvalidData)?);
},
"persistent_keepalive_interval" => {
self.current_peer
.as_mut()
.ok_or(InvalidData)?
.config
.persistent_keepalive_interval = Some(value.parse().map_err(|_| InvalidData)?);
},
"endpoint" => {
self.current_peer
.as_mut()
.ok_or(InvalidData)?
.config
.endpoint = Some(value.parse().map_err(|_| InvalidData)?);
},
"errno" => {
if value != "0" {
return Err(std::io::Error::from_raw_os_error(
value
.parse()
.expect("Unable to parse userspace wg errno return code"),
));
}
if let Some(finished_peer) = self.current_peer.take() {
self.device_info.peers.push(finished_peer);
}
},
"protocol_version" | "last_handshake_time_nsec" => {},
_ => println!("got unsupported info: {key}={value}"),
}
Ok(())
}
}
pub fn get_by_name(name: &InterfaceName) -> Result<Device, io::Error> {
let mut sock = open_socket(name)?;
sock.write_all(b"get=1\n\n")?;
let mut reader = BufReader::new(sock);
let mut buf = String::new();
let mut parser = ConfigParser::new(name);
loop {
match reader.read_line(&mut buf)? {
0 | 1 if buf == "\n" => break,
_ => {
parser.add_line(buf.trim_end())?;
buf.clear();
},
};
}
Ok(parser.into())
}
fn get_userspace_implementation() -> String {
std::env::var("WG_USERSPACE_IMPLEMENTATION")
.or_else(|_| std::env::var("WG_QUICK_USERSPACE_IMPLEMENTATION"))
.unwrap_or_else(|_| "wireguard-go".to_string())
}
fn start_userspace_wireguard(iface: &InterfaceName) -> io::Result<Output> {
let userspace_implementation = get_userspace_implementation();
let mut command = Command::new(&userspace_implementation);
let output_res = if cfg!(target_os = "linux") {
command.args(&[iface.to_string()]).output()
} else {
command
.env("WG_TUN_NAME_FILE", format!("{VAR_RUN_PATH}/{iface}.name"))
.args(["utun"])
.output()
};
match output_res {
Ok(output) => {
if output.status.success() {
Ok(output)
} else {
Err(io::ErrorKind::AddrNotAvailable.into())
}
},
Err(e) if e.kind() == io::ErrorKind::NotFound => Err(io::Error::new(
io::ErrorKind::NotFound,
format!("Cannot find \"{userspace_implementation}\". Specify a custom path with WG_USERSPACE_IMPLEMENTATION."),
)),
Err(e) => Err(e),
}
}
pub fn apply(builder: &DeviceUpdate, iface: &InterfaceName) -> io::Result<()> {
let mut sock = match open_socket(iface) {
Err(_) => {
fs::create_dir_all(VAR_RUN_PATH)?;
let _ = fs::remove_file(get_namefile(iface)?);
start_userspace_wireguard(iface)?;
std::thread::sleep(Duration::from_millis(100));
open_socket(iface)
.map_err(|e| io::Error::new(e.kind(), format!("failed to open socket ({e})")))?
},
Ok(sock) => sock,
};
let mut request = String::from("set=1\n");
if let Some(ref k) = builder.private_key {
writeln!(request, "private_key={}", hex::encode(k.as_bytes())).ok();
}
if let Some(f) = builder.fwmark {
writeln!(request, "fwmark={f}").ok();
}
if let Some(f) = builder.listen_port {
writeln!(request, "listen_port={f}").ok();
}
if builder.replace_peers {
writeln!(request, "replace_peers=true").ok();
}
for peer in &builder.peers {
writeln!(
request,
"public_key={}",
hex::encode(peer.public_key.as_bytes())
)
.ok();
if peer.replace_allowed_ips {
writeln!(request, "replace_allowed_ips=true").ok();
}
if peer.remove_me {
writeln!(request, "remove=true").ok();
}
if let Some(ref k) = peer.preshared_key {
writeln!(request, "preshared_key={}", hex::encode(k.as_bytes())).ok();
}
if let Some(endpoint) = peer.endpoint {
writeln!(request, "endpoint={endpoint}").ok();
}
if let Some(keepalive_interval) = peer.persistent_keepalive_interval {
writeln!(
request,
"persistent_keepalive_interval={keepalive_interval}"
)
.ok();
}
for allowed_ip in &peer.allowed_ips {
writeln!(
request,
"allowed_ip={}/{}",
allowed_ip.address, allowed_ip.cidr
)
.ok();
}
}
request.push('\n');
sock.write_all(request.as_bytes())?;
let mut reader = BufReader::new(sock);
let mut line = String::new();
reader.read_line(&mut line)?;
let split: Vec<&str> = line.trim_end().splitn(2, '=').collect();
match &split[..] {
["errno", "0"] => Ok(()),
["errno", val] => {
println!("ERROR {val}");
Err(io::ErrorKind::InvalidInput.into())
},
_ => Err(io::ErrorKind::Other.into()),
}
}