#![allow(dead_code)]
use std::{
fmt::{self, Display},
iter::Peekable,
net::SocketAddr,
str::FromStr,
};
use eyre::{bail, ensure, eyre};
use typed_builder::TypedBuilder;
use crate::{
device::{daita::api::DaitaSettings, peer::AllowedIP},
serialization::KeyBytes,
};
/// A parsed client request.
///
/// `Request::from_str` picks the variant from the first line of the text
/// (`get=1` or `set=1`).
#[derive(Debug)]
pub enum Request {
    /// A `get` command.
    Get(Get),
    /// A `set` command.
    Set(Set),
}
/// A reply to a [`Request`].
///
/// Its `Display` impl delegates to the wrapped response, which renders
/// `key=value` lines.
#[derive(Debug)]
pub enum Response {
    /// Reply carrying a [`GetResponse`].
    Get(GetResponse),
    /// Reply carrying a [`SetResponse`].
    Set(SetResponse),
}
/// The `get` command.
///
/// It carries no parameters. See its `FromStr` impl for the accepted text.
#[derive(Default, Debug)]
#[non_exhaustive]
pub struct Get;
/// One peer entry inside a [`GetResponse`].
///
/// `peer` is the only required builder field. Every other field defaults to
/// `None`. The `Display` impl renders the present fields as `key=value`
/// lines but does not render the four padding counters.
#[derive(Debug, TypedBuilder)]
#[non_exhaustive]
pub struct GetPeer {
    /// Identity and configuration of the peer.
    pub peer: Peer,
    /// Rendered as `last_handshake_time_sec=` when set.
    #[builder(default, setter(strip_option, into))]
    pub last_handshake_time_sec: Option<u64>,
    /// Rendered as `last_handshake_time_nsec=` when set.
    #[builder(default, setter(strip_option, into))]
    pub last_handshake_time_nsec: Option<u32>,
    /// Rendered as `rx_bytes=` when set.
    #[builder(default, setter(strip_option, into))]
    pub rx_bytes: Option<u64>,
    /// Rendered as `tx_bytes=` when set.
    #[builder(default, setter(strip_option, into))]
    pub tx_bytes: Option<u64>,
    /// Padding counter (presumably DAITA-related — TODO confirm); not rendered.
    #[builder(default, setter(strip_option, into))]
    pub tx_padding_bytes: Option<u64>,
    /// Padding counter (presumably DAITA-related — TODO confirm); not rendered.
    #[builder(default, setter(strip_option, into))]
    pub tx_padding_packet_bytes: Option<u64>,
    /// Padding counter (presumably DAITA-related — TODO confirm); not rendered.
    #[builder(default, setter(strip_option, into))]
    pub rx_padding_bytes: Option<u64>,
    /// Padding counter (presumably DAITA-related — TODO confirm); not rendered.
    #[builder(default, setter(strip_option, into))]
    pub rx_padding_packet_bytes: Option<u64>,
}
/// Reply payload for a `get` request.
///
/// `errno` is the only required builder field. `peers` is skipped by the
/// builder, so peers are appended afterwards with `GetResponse::peer`. The
/// `Display` impl writes the present interface fields, then each peer, then
/// a closing `errno=` line.
#[derive(TypedBuilder, Default, Debug)]
#[non_exhaustive]
pub struct GetResponse {
    /// Rendered as `private_key=` (hex) when set.
    #[builder(default, setter(strip_option, into))]
    pub private_key: Option<KeyBytes>,
    /// Rendered as `listen_port=` when set.
    #[builder(default, setter(strip_option, into))]
    pub listen_port: Option<u16>,
    /// Rendered as `fwmark=` when set.
    #[builder(default, setter(strip_option, into))]
    pub fwmark: Option<u32>,
    /// Peer entries, rendered in order after the interface fields.
    #[builder(default, setter(skip))]
    pub peers: Vec<GetPeer>,
    /// Rendered as the final `errno=` line.
    pub errno: i32,
}
/// The `set` command: interface-level settings plus zero or more peer sections.
///
/// No builder field is required. `replace_peers` defaults to `false`, and
/// `peers` is skipped by the builder, so peers are appended afterwards with
/// `Set::peer`.
#[derive(TypedBuilder, Default, Debug)]
#[non_exhaustive]
pub struct Set {
    /// From a `private_key=` line.
    #[builder(default, setter(strip_option, into))]
    pub private_key: Option<KeyBytes>,
    /// From a `listen_port=` line.
    #[builder(default, setter(strip_option, into))]
    pub listen_port: Option<u16>,
    /// From a `fwmark=` line.
    #[builder(default, setter(strip_option, into))]
    pub fwmark: Option<u32>,
    /// Set by `replace_peers=true`; `true` is the only value the parser accepts.
    #[builder(setter(strip_bool))]
    pub replace_peers: bool,
    /// From a `protocol_version=` line, kept as an unparsed string.
    #[builder(default, setter(strip_option, into))]
    pub protocol_version: Option<String>,
    /// Peer sections in the order they appeared in the command.
    #[builder(default, setter(skip))]
    pub peers: Vec<SetPeer>,
}
/// One peer section of a [`Set`] command.
///
/// `peer` is the only required builder field. The boolean flags default to
/// `false`, and each is set by a `<flag>=true` line when parsed.
#[derive(TypedBuilder, Debug)]
#[non_exhaustive]
pub struct SetPeer {
    /// Identity and configuration of the peer.
    pub peer: Peer,
    /// Set by `remove=true`.
    #[builder(setter(strip_bool))]
    pub remove: bool,
    /// Set by `update_only=true`.
    #[builder(setter(strip_bool))]
    pub update_only: bool,
    /// Set by `replace_allowed_ips=true`.
    #[builder(setter(strip_bool))]
    pub replace_allowed_ips: bool,
    /// Not populated by the text parser in this file (`SetPeer::from_lines` ignores it).
    #[builder(default, setter(strip_option, into))]
    pub daita_settings: Option<DaitaSettings>,
}
/// Reply payload for a `set` request, displayed as a single `errno=` line.
#[derive(Debug)]
#[non_exhaustive]
pub struct SetResponse {
    /// Error code reported to the client.
    pub errno: i32,
}
/// A value that is either given explicitly or explicitly cleared.
///
/// Parsing an empty string yields `Unset`. Any other string is parsed as `T`.
/// `Unset` displays as the empty string.
#[derive(Debug)]
pub enum SetUnset<T> {
    /// An explicit value.
    Set(T),
    /// An explicitly cleared value (empty text).
    Unset,
}
/// Fields shared by the `get` and `set` representations of a peer.
///
/// `public_key` is the only required builder field.
#[derive(TypedBuilder, Debug)]
#[non_exhaustive]
pub struct Peer {
    /// The key identifying this peer.
    #[builder(setter(into))]
    pub public_key: KeyBytes,
    /// `Some(SetUnset::Unset)` when a `set` command carries an empty `preshared_key=`.
    #[builder(default, setter(strip_option, into))]
    pub preshared_key: Option<SetUnset<KeyBytes>>,
    /// From an `endpoint=` line.
    #[builder(default, setter(strip_option, into))]
    pub endpoint: Option<SocketAddr>,
    /// From a `persistent_keepalive_interval=` line.
    #[builder(default, setter(strip_option, into))]
    pub persistent_keepalive_interval: Option<u16>,
    /// One entry per `allowed_ip=` line, in order.
    #[builder(default)]
    pub allowed_ip: Vec<AllowedIP>,
}
impl From<Set> for Request {
fn from(set: Set) -> Self {
Self::Set(set)
}
}
impl From<Get> for Request {
fn from(get: Get) -> Self {
Self::Get(get)
}
}
impl Set {
    /// Appends `peer` after any existing peer sections and returns the command.
    ///
    /// This allows chaining after `Set::builder()...build()`, since the builder
    /// cannot set `peers`.
    pub fn peer(self, peer: SetPeer) -> Self {
        let mut command = self;
        command.peers.extend(Some(peer));
        command
    }
}
impl Peer {
pub fn new(public_key: impl Into<KeyBytes>) -> Self {
Self {
public_key: public_key.into(),
preshared_key: None,
endpoint: None,
persistent_keepalive_interval: None,
allowed_ip: vec![],
}
}
pub fn with_endpoint(mut self, endpoint: impl Into<SocketAddr>) -> Self {
self.endpoint = Some(endpoint.into());
self
}
}
impl SetPeer {
pub fn new(public_key: impl Into<KeyBytes>) -> Self {
Self {
peer: Peer::new(public_key),
remove: false,
update_only: false,
replace_allowed_ips: false,
daita_settings: None,
}
}
pub fn with_endpoint(mut self, endpoint: impl Into<SocketAddr>) -> Self {
self.peer.endpoint = Some(endpoint.into());
self
}
}
impl GetPeer {
pub fn new(public_key: impl Into<KeyBytes>) -> Self {
Self {
peer: Peer::new(public_key),
last_handshake_time_sec: None,
last_handshake_time_nsec: None,
rx_bytes: None,
tx_bytes: None,
tx_padding_bytes: None,
tx_padding_packet_bytes: None,
rx_padding_bytes: None,
rx_padding_packet_bytes: None,
}
}
}
impl GetResponse {
    /// Appends `peer` after any existing peer entries and returns the response.
    ///
    /// This allows chaining after `GetResponse::builder()...build()`, since the
    /// builder cannot set `peers`.
    pub fn peer(self, peer: GetPeer) -> Self {
        let mut response = self;
        response.peers.extend(Some(peer));
        response
    }
}
impl Display for Response {
    /// Renders the wrapped `get` or `set` response unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &dyn Display = match self {
            Response::Get(response) => response,
            Response::Set(response) => response,
        };
        Display::fmt(inner, f)
    }
}
/// Turns an `Option` variable into `Some((variable_name, &value as &dyn Display))`,
/// or `None` when the variable is `None`.
///
/// The variable's identifier doubles as the protocol key, so locals must be
/// named exactly like the keys they render.
macro_rules! opt_to_key_and_display {
    ($i:ident) => {
        match &$i {
            Some(value) => Some((stringify!($i), value as &dyn Display)),
            None => None,
        }
    };
}
impl Display for GetResponse {
    /// Writes the set interface fields as `key=value` lines, then every peer,
    /// then a closing `errno=` line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: a new field must be handled here explicitly.
        let GetResponse {
            private_key,
            listen_port,
            fwmark,
            peers,
            errno,
        } = self;

        let present = [
            opt_to_key_and_display!(private_key),
            opt_to_key_and_display!(listen_port),
            opt_to_key_and_display!(fwmark),
        ];
        present
            .into_iter()
            .flatten()
            .try_for_each(|(key, value)| writeln!(f, "{key}={value}"))?;

        peers.iter().try_for_each(|peer| write!(f, "{peer}"))?;

        writeln!(f, "errno={errno}")
    }
}
impl Display for GetPeer {
    /// Writes one peer as `key=value` lines.
    ///
    /// The identity and statistics fields that are set come first, followed
    /// by one `allowed_ip=<addr>/<cidr>` line per allowed IP. The padding
    /// counters are not written.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field forces a decision here
        // about whether to render it.
        let GetPeer {
            peer,
            last_handshake_time_sec,
            last_handshake_time_nsec,
            rx_bytes,
            tx_bytes,
            tx_padding_bytes: _,
            tx_padding_packet_bytes: _,
            rx_padding_bytes: _,
            rx_padding_packet_bytes: _,
        } = self;
        let Peer {
            public_key,
            preshared_key,
            endpoint,
            persistent_keepalive_interval,
            allowed_ip,
        } = peer;

        // `public_key` is mandatory; wrap it so it fits the optional-field table.
        let public_key = Some(&public_key);
        let present = [
            opt_to_key_and_display!(public_key),
            opt_to_key_and_display!(preshared_key),
            opt_to_key_and_display!(endpoint),
            opt_to_key_and_display!(persistent_keepalive_interval),
            opt_to_key_and_display!(last_handshake_time_sec),
            opt_to_key_and_display!(last_handshake_time_nsec),
            opt_to_key_and_display!(rx_bytes),
            opt_to_key_and_display!(tx_bytes),
        ];
        for (key, value) in present.into_iter().flatten() {
            writeln!(f, "{key}={value}")?;
        }

        allowed_ip
            .iter()
            .try_for_each(|AllowedIP { addr, cidr }| writeln!(f, "allowed_ip={addr}/{cidr}"))
    }
}
impl Display for SetResponse {
    /// Writes the reply as a single `errno=<code>` line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let errno = self.errno;
        writeln!(f, "errno={}", errno)
    }
}
impl<T: Display> Display for SetUnset<T> {
    /// Displays an explicit value as-is, passing the formatter options
    /// through, and `Unset` as the empty string.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let SetUnset::Set(value) = self {
            T::fmt(value, f)
        } else {
            Ok(())
        }
    }
}
impl Display for KeyBytes {
    /// Writes the key as lowercase hex, two digits per byte, with no separators.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.iter().try_for_each(|b| write!(f, "{b:02x}"))
    }
}
/// Parses `$value` into the optional field bound to `$field` (a `&mut Option<T>`)
/// for protocol key `$key`.
///
/// Returns early with an error if the field was already set (duplicate key)
/// or if `$value` does not parse as `T`. It must be expanded inside a function
/// returning `eyre::Result`.
macro_rules! parse_opt {
    ($key:expr, $value:expr, $field:ident) => {{
        ensure!(
            $field.is_none(),
            "Key {:?} may not be specified twice",
            $key
        );
        let value = $value;
        // The value comes straight from the client, so a malformed value must
        // become an error. `unwrap()` here would let bad input panic the parser.
        *$field = Some(
            value
                .parse()
                .map_err(|err| eyre!("Invalid value {:?} for key {:?}: {}", value, $key, err))?,
        );
    }};
}
/// Handles a flag key whose only accepted value is `true`, setting the `bool`
/// bound to `$field`.
///
/// Returns early with an error for any other value. It must be expanded inside
/// a function returning `eyre::Result`.
macro_rules! parse_bool {
    ($key:expr, $value:expr, $field:ident) => {{
        if $value != "true" {
            bail!("The only valid value for key {:?} is \"true\"", $key);
        }
        *$field = true;
    }};
}
impl FromStr for Get {
    type Err = eyre::Report;

    /// Parses a `get` command, which takes no parameters.
    ///
    /// Accepts `get=1\n`, with or without the empty line that terminates a
    /// command (`get=1\n\n`). `Set::from_str` likewise accepts its command
    /// with or without that trailing empty line, so `Request::from_str` now
    /// behaves the same for both commands.
    ///
    /// # Errors
    /// Returns an error for any other input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if !matches!(s, "get=1\n" | "get=1\n\n") {
            bail!("Not a valid `get` command. Expected `get=1\\n` or `get=1\\n\\n`");
        }
        Ok(Get {})
    }
}
impl FromStr for Set {
    type Err = eyre::Report;

    /// Parses a `set` command.
    ///
    /// The command is a `set=1` header line, then interface-level
    /// `key=value` lines, then zero or more peer sections, each introduced by
    /// a `public_key=` line. Parsing stops at the first empty line or at the
    /// end of input.
    ///
    /// # Errors
    /// Fails on a missing or wrong header, a line without `=`, a key not
    /// allowed here, or an error reported by the per-key or peer parsers.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut input = s.lines().peekable();
        let header = input.next();
        ensure!(header == Some("set=1"), "Set commands must start with 'set=1'");

        let mut set = Set::default();
        let Set {
            private_key,
            listen_port,
            fwmark,
            replace_peers,
            protocol_version,
            peers,
        } = &mut set;

        while let Some(line) = input.next() {
            // An empty line terminates the command.
            if line.is_empty() {
                break;
            }
            let (k, v) = to_key_value(line)?;
            match k {
                "public_key" => {
                    // Start of a peer section. The peer parser consumes lines
                    // up to (not including) the next `public_key=` or empty line.
                    let public_key = KeyBytes::from_str(v).map_err(|err| eyre!("{err}"))?;
                    let peer = SetPeer::from_lines(public_key, &mut input)?;
                    peers.push(peer);
                }
                "private_key" => parse_opt!(k, v, private_key),
                "listen_port" => parse_opt!(k, v, listen_port),
                "fwmark" => parse_opt!(k, v, fwmark),
                "replace_peers" => parse_bool!(k, v, replace_peers),
                "protocol_version" => parse_opt!(k, v, protocol_version),
                _ => bail!("Key {k:?} in {line:?} is not allowed in command set"),
            }
        }
        Ok(set)
    }
}
impl<T: FromStr> FromStr for SetUnset<T> {
    type Err = T::Err;

    /// Parses an empty string as `Unset` and anything else as `Set(T)`.
    ///
    /// # Errors
    /// Propagates `T`'s parse error for non-empty input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.is_empty() {
            return Ok(SetUnset::Unset);
        }
        T::from_str(s).map(SetUnset::Set)
    }
}
impl SetPeer {
    /// Parses the body of one peer section of a `set` command.
    ///
    /// `public_key` comes from the `public_key=` line that opened the section.
    /// Lines are consumed from `lines` until one of three things happens:
    /// - the next `public_key=` line is reached (left unconsumed, so the
    ///   caller can start the next peer);
    /// - an empty line is reached (also left unconsumed);
    /// - the input ends.
    ///
    /// # Errors
    /// Fails on a line without `=`, a key not valid in a peer section, or an
    /// error reported by the per-key parsers.
    fn from_lines<'a>(
        public_key: impl Into<KeyBytes>,
        lines: &mut Peekable<impl Iterator<Item = &'a str>>,
    ) -> eyre::Result<Self> {
        let mut set_peer = SetPeer::new(public_key);
        let SetPeer {
            peer:
                Peer {
                    public_key: _,
                    preshared_key,
                    endpoint,
                    persistent_keepalive_interval,
                    allowed_ip,
                },
            remove,
            update_only,
            replace_allowed_ips,
            // Not settable through this text parser.
            daita_settings: _,
        } = &mut set_peer;

        // Peek rather than consume, so a section-ending line stays in `lines`
        // for the caller.
        while let Some(&line) = lines.peek() {
            if line.is_empty() {
                break;
            }
            let (k, v) = to_key_value(line)?;
            match k {
                "public_key" => break,
                "preshared_key" => parse_opt!(k, v, preshared_key),
                "endpoint" => parse_opt!(k, v, endpoint),
                "persistent_keepalive_interval" => parse_opt!(k, v, persistent_keepalive_interval),
                "remove" => parse_bool!(k, v, remove),
                "update_only" => parse_bool!(k, v, update_only),
                "replace_allowed_ips" => parse_bool!(k, v, replace_allowed_ips),
                "allowed_ip" => allowed_ip.push(v.parse().map_err(|err| eyre!("{err}"))?),
                _ => bail!("Key {k:?} in {line:?} is not allowed in command set/peer"),
            }
            // The peeked line belonged to this peer; consume it.
            lines.next();
        }
        Ok(set_peer)
    }
}
impl FromStr for Request {
type Err = eyre::Report;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let Some((first_line, ..)) = s.split_once('\n') else {
bail!("Missing newline: {s:?}");
};
Ok(match first_line {
"set=1" => Set::from_str(s)?.into(),
"get=1" => Get::from_str(s)?.into(),
_ => bail!("Unknown command: {s:?}"),
})
}
}
/// Splits a protocol line of the form `<key>=<value>` at the first `=`.
///
/// The value may itself contain `=`; only the first one separates the key.
///
/// # Errors
/// Returns an error if `line` contains no `=`.
fn to_key_value(line: &str) -> eyre::Result<(&str, &str)> {
    // `ok_or_else` builds the `eyre::Report` (an allocation plus a call into
    // the installed error handler) only on failure, not for every valid line.
    line.split_once('=')
        .ok_or_else(|| eyre!("expected {line:?} to be `<key>=<value>`"))
}
/// Smoke check that the derived builders accept the argument types used below
/// (`[u8; 32]` keys, `(ip, port)` tuple endpoints, flag setters).
///
/// It only constructs values and drops them. Nothing in this file calls it;
/// presumably it is kept as a compile-time check — confirm before removing.
fn testy() {
    let key = [0x77u8; 32];

    let get_peer = GetPeer::builder()
        .peer(Peer::builder().public_key(key).build())
        .build();
    let _get = GetResponse::builder()
        .listen_port(18u16)
        .fwmark(123u32)
        .errno(0)
        .build()
        .peer(get_peer);

    let removed_peer = SetPeer::builder()
        .peer(Peer::builder().public_key(key).build())
        .update_only()
        .remove()
        .build();
    let endpoint_peer = SetPeer::builder()
        .peer(
            Peer::builder()
                .endpoint(([127, 0, 0, 1], 1234u16))
                .public_key(key)
                .build(),
        )
        .build();
    let _set = Set::builder()
        .private_key(key)
        .fwmark(1234u32)
        .build()
        .peer(removed_peer)
        .peer(endpoint_peer);
}