use bytes::Bytes;
use std::collections::HashSet;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::unix::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::UnixStream;
use tracing::debug;
use super::{Error, Request, Response, SetDevice, SetPeer};
use crate::noise::crypto;
pub struct Connection {
reader: BufReader<OwnedReadHalf>,
writer: OwnedWriteHalf,
}
impl Connection {
pub(super) fn new(socket: UnixStream) -> Self {
let (rh, wh) = socket.into_split();
Self {
reader: BufReader::new(rh),
writer: wh,
}
}
pub async fn next(&mut self) -> Result<Request, Error> {
let mut op = vec![];
self.reader.read_until(b'\n', &mut op).await?;
match op.as_slice() {
b"get=1\n" => {
if self.reader.read_u8().await? != b'\n' {
return Err(Error::InvalidProtocol);
}
Ok(Request::Get)
}
b"set=1\n" => {
let mut buf = vec![];
while self.reader.read_until(b'\n', &mut buf).await? > 1 {}
let s = unsafe { String::from_utf8_unchecked(buf).trim_end().to_owned() };
Ok(Request::Set(parse_set_request(&s)?))
}
_ => Err(Error::InvalidProtocol),
}
}
pub async fn write(&mut self, resp: Response) {
match resp {
Response::Ok => {
debug!("UAPI: writing ok response");
self.writer.write_all(b"errno=0\n\n").await.unwrap();
}
Response::Get(info) => {
let buf: Bytes = info.into();
self.writer.write_all(buf.as_ref()).await.unwrap();
}
_ => {}
}
}
}
#[allow(clippy::too_many_lines)]
fn parse_set_request(s: &str) -> Result<SetDevice, Error> {
debug!("UAPI: parsing set request: {:?}", s);
let mut set_device = SetDevice {
private_key: None,
listen_port: None,
fwmark: None,
replace_peers: false,
peers: vec![],
};
for line in s.split('\n') {
let (k, v) = line.split_once('=').ok_or(Error::InvalidProtocol)?;
match k {
"private_key" => {
let mut private_key = [0u8; 32];
private_key.copy_from_slice(crypto::decode_from_hex(v).as_slice());
set_device.private_key = Some(private_key);
}
"listen_port" => {
set_device.listen_port = Some(v.parse().map_err(|_| Error::InvalidProtocol)?);
}
"fwmark" => {
set_device.fwmark = Some(v.parse().map_err(|_| Error::InvalidProtocol)?);
}
"replace_peers" => {
if v != "true" {
return Err(Error::InvalidProtocol);
}
set_device.replace_peers = true;
}
"public_key" => {
set_device.peers.push(SetPeer {
public_key: [0u8; 32],
remove: false,
update_only: false,
psk: None,
endpoint: None,
persistent_keepalive_interval: None,
replace_allowed_ips: false,
allowed_ips: HashSet::new(),
});
set_device
.peers
.last_mut()
.ok_or(Error::InvalidProtocol)?
.public_key = crypto::decode_from_hex(v)
.as_slice()
.try_into()
.map_err(|_| Error::InvalidProtocol)?;
}
"remove" => {
if v != "true" {
return Err(Error::InvalidProtocol);
}
set_device
.peers
.last_mut()
.ok_or(Error::InvalidProtocol)?
.remove = true;
}
"update_only" => {
if v != "true" {
return Err(Error::InvalidProtocol);
}
set_device
.peers
.last_mut()
.ok_or(Error::InvalidProtocol)?
.update_only = true;
}
"preshared_key" => {
set_device
.peers
.last_mut()
.ok_or(Error::InvalidProtocol)?
.psk = Some(
crypto::decode_from_hex(v)
.as_slice()
.try_into()
.map_err(|_| Error::InvalidProtocol)?,
);
}
"endpoint" => {
set_device
.peers
.last_mut()
.ok_or(Error::InvalidProtocol)?
.endpoint = Some(v.parse().map_err(|_| Error::InvalidProtocol)?);
}
"persistent_keepalive_interval" => {
set_device
.peers
.last_mut()
.ok_or(Error::InvalidProtocol)?
.persistent_keepalive_interval =
Some(v.parse().map_err(|_| Error::InvalidProtocol)?);
}
"replace_allowed_ips" => {
if v != "true" {
return Err(Error::InvalidProtocol);
}
set_device
.peers
.last_mut()
.ok_or(Error::InvalidProtocol)?
.replace_allowed_ips = true;
}
"allowed_ip" => {
set_device
.peers
.last_mut()
.ok_or(Error::InvalidProtocol)?
.allowed_ips
.insert(v.parse().map_err(|_| Error::InvalidProtocol)?);
}
_ => return Err(Error::InvalidProtocol),
}
}
Ok(set_device)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A `set=1` body exercising every device-level key and most peer-level keys across
    /// four peers, the last of which is marked for removal.
    const SET_REQUEST: &str = "\
        private_key=e84b5a6d2717c1003a13b431570353dbaca9146cf150c5f8575680feba52027a\n\
        fwmark=0\n\
        listen_port=12912\n\
        replace_peers=true\n\
        public_key=b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33\n\
        preshared_key=188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52\n\
        replace_allowed_ips=true\n\
        allowed_ip=192.168.4.4/32\n\
        endpoint=[abcd:23::33%2]:51820\n\
        public_key=58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376\n\
        replace_allowed_ips=true\n\
        allowed_ip=192.168.4.6/32\n\
        persistent_keepalive_interval=111\n\
        endpoint=182.122.22.19:3233\n\
        public_key=662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58\n\
        endpoint=5.152.198.39:51820\n\
        replace_allowed_ips=true\n\
        allowed_ip=192.168.4.10/32\n\
        allowed_ip=192.168.4.11/32\n\
        public_key=e818b58db5274087fcc1be5dc728cf53d3b5726b4cef6b9bab8f8f8c2452c25c\n\
        remove=true";

    /// Decodes a hex fixture into a 32-byte key; panics if the fixture is malformed.
    fn key(hex: &str) -> [u8; 32] {
        crypto::decode_from_hex(hex).try_into().unwrap()
    }

    #[test]
    fn test_parse_set_request() {
        let rv = parse_set_request(SET_REQUEST);
        assert!(rv.is_ok());

        let first_peer = SetPeer {
            public_key: key("b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33"),
            remove: false,
            update_only: false,
            psk: Some(key(
                "188515093e952f5f22e865cef3012e72f8b5f0b598ac0309d5dacce3b70fcf52",
            )),
            endpoint: Some("[abcd:23::33%2]:51820".parse().unwrap()),
            persistent_keepalive_interval: None,
            replace_allowed_ips: true,
            allowed_ips: ["192.168.4.4/32".parse().unwrap()].into_iter().collect(),
        };
        let second_peer = SetPeer {
            public_key: key("58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376"),
            remove: false,
            update_only: false,
            psk: None,
            endpoint: Some("182.122.22.19:3233".parse().unwrap()),
            persistent_keepalive_interval: Some(111),
            replace_allowed_ips: true,
            allowed_ips: ["192.168.4.6/32".parse().unwrap()].into_iter().collect(),
        };
        let third_peer = SetPeer {
            public_key: key("662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58"),
            remove: false,
            update_only: false,
            psk: None,
            endpoint: Some("5.152.198.39:51820".parse().unwrap()),
            persistent_keepalive_interval: None,
            replace_allowed_ips: true,
            allowed_ips: [
                "192.168.4.10/32".parse().unwrap(),
                "192.168.4.11/32".parse().unwrap(),
            ]
            .into_iter()
            .collect(),
        };
        let removed_peer = SetPeer {
            public_key: key("e818b58db5274087fcc1be5dc728cf53d3b5726b4cef6b9bab8f8f8c2452c25c"),
            remove: true,
            update_only: false,
            psk: None,
            endpoint: None,
            persistent_keepalive_interval: None,
            replace_allowed_ips: false,
            allowed_ips: Default::default(),
        };

        let expected = SetDevice {
            private_key: Some(key(
                "e84b5a6d2717c1003a13b431570353dbaca9146cf150c5f8575680feba52027a",
            )),
            listen_port: Some(12912),
            fwmark: Some(0),
            replace_peers: true,
            peers: vec![first_peer, second_peer, third_peer, removed_peer],
        };
        assert_eq!(rv.unwrap(), expected);
    }
}