use std::fmt;
use std::str::FromStr;
use data_encoding::BASE64;
use serde::{
Serialize,
Deserialize,
de::{self, Visitor}
};
use rand::Rng;
use crate::peer::NodeID;
/// A 256-bit key given as an argument.
///
/// Its textual form (used by the `FromStr`, `Display` and serde impls in this
/// file) is the base64 encoding of the 32 raw bytes.
#[derive(PartialEq, Eq, Clone)]
pub struct KeyArg {
    /// Raw key bytes.
    key: [u8; 32],
}
/// A node ID given as an argument.
///
/// Parsing and formatting go through `NodeID`'s own `FromStr`/`Display`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct NodeIDArg {
    /// The wrapped node ID.
    id: NodeID,
}
/// A reference to a node, either by ID or by name.
///
/// In text form an ID is written inside square brackets (`[ID]`). Any other
/// non-empty string is taken as a name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NodeArg {
    /// Node identified by its ID.
    ID(NodeIDArg),
    /// Node identified by name.
    Name(String),
}
impl KeyArg {
    /// Consumes the argument and returns the raw 32-byte key.
    pub fn get(self) -> [u8; 32] {
        let Self { key } = self;
        key
    }

    /// Creates a new key filled with random bytes drawn from `rand::rng()`.
    pub fn random() -> Self {
        let mut rng = rand::rng();
        let key: [u8; 32] = rng.random();
        Self { key }
    }
}
impl AsRef<[u8; 32]> for KeyArg {
    /// Borrows the raw key bytes without consuming the argument.
    fn as_ref(&self) -> &[u8; 32] {
        let Self { key } = self;
        key
    }
}
impl From<[u8; 32]> for KeyArg {
    /// Wraps raw key bytes. Every 32-byte array is accepted as-is.
    fn from(bytes: [u8; 32]) -> Self {
        Self { key: bytes }
    }
}
impl FromStr for KeyArg {
    type Err = String;

    /// Parses a key from its base64 text form.
    ///
    /// # Errors
    /// Returns a message if the input is not valid base64, or if it decodes
    /// to anything other than exactly 32 bytes.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let decoded = BASE64
            .decode(s.as_bytes())
            .map_err(|err| format!("Failed to decode base64 key: {err}!"))?;
        // The conversion only fails on a length mismatch, which is exactly
        // the second error case.
        let key = <[u8; 32]>::try_from(decoded)
            .map_err(|_| "Key is not exactly 32 bytes long!".to_owned())?;
        Ok(Self { key })
    }
}
impl fmt::Debug for KeyArg {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
BASE64.encode_write(&self.key, f)
}
}
impl fmt::Display for KeyArg {
    /// Writes the key using `data_encoding::BASE64`, the same encoding that
    /// `FromStr` decodes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = BASE64.encode(&self.key);
        f.write_str(&encoded)
    }
}
impl Serialize for KeyArg {
    /// Serializes the key as a base64 string.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let encoded = BASE64.encode(&self.key);
        serializer.serialize_str(encoded.as_str())
    }
}
impl<'de> Deserialize<'de> for KeyArg {
    /// Deserializes a key from a base64 string via `KeyVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = KeyVisitor;
        deserializer.deserialize_str(visitor)
    }
}
struct KeyVisitor;
impl Visitor<'_> for KeyVisitor {
type Value = KeyArg;
fn expecting(&self, formatter: &mut std::fmt::Formatter)
-> std::fmt::Result
{
formatter.write_str("a base64-encoded 256bit key")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where E: de::Error,
{
value.parse::<KeyArg>().map_err(E::custom)
}
}
impl NodeIDArg {
    /// Returns the wrapped node ID (`NodeIDArg` is `Copy`, so this is cheap).
    pub fn get(self) -> NodeID {
        let Self { id } = self;
        id
    }

    /// Wraps a new ID produced by `NodeID::random()`.
    pub fn random() -> Self {
        let id = NodeID::random();
        Self { id }
    }
}
impl From<NodeIDArg> for NodeID {
    /// Unwraps the argument back into the plain `NodeID`.
    fn from(arg: NodeIDArg) -> Self {
        arg.id
    }
}
impl From<NodeID> for NodeIDArg {
    /// Wraps a plain `NodeID` so it can be parsed and serialized as an argument.
    fn from(node_id: NodeID) -> Self {
        Self { id: node_id }
    }
}
impl FromStr for NodeIDArg {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(NodeID::from_str(s)?.into())
}
}
impl std::fmt::Display for NodeIDArg {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.id, f)
}
}
impl Serialize for NodeIDArg {
    /// Serializes the ID as the string produced by `NodeID`'s `Display`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.id.to_string();
        serializer.serialize_str(text.as_str())
    }
}
impl<'de> Deserialize<'de> for NodeIDArg {
    /// Deserializes an ID from its string form via `NodeIDVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = NodeIDVisitor;
        deserializer.deserialize_str(visitor)
    }
}
struct NodeIDVisitor;
impl Visitor<'_> for NodeIDVisitor {
type Value = NodeIDArg;
fn expecting(&self, formatter: &mut std::fmt::Formatter)
-> std::fmt::Result
{
formatter.write_str("a URL-safe base64-encoded 6byte ID")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where E: de::Error,
{
match value.parse::<NodeID>() {
Ok(k) => Ok(k.into()),
Err(e) => Err(E::custom(e))
}
}
}
impl From<NodeIDArg> for NodeArg {
fn from(id: NodeIDArg) -> Self {Self::ID(id)}
}
impl From<String> for NodeArg {
fn from(name: String) -> Self {Self::Name(name)}
}
impl From<NodeID> for NodeArg {
fn from(id: NodeID) -> Self {Self::ID(id.into())}
}
impl FromStr for NodeArg {
    type Err = String;

    /// Parses a node argument:
    ///
    /// * text wrapped in square brackets (`[XXXXXXXX]`) is a node ID. The text
    ///   between the brackets must be exactly 8 bytes and is parsed with
    ///   `NodeIDArg`'s `FromStr`;
    /// * the empty string is rejected;
    /// * anything else is taken verbatim as a node name.
    ///
    /// # Errors
    /// Returns a message if the input is empty, if a bracketed ID is not
    /// exactly 8 characters, or if the ID fails to parse.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s.is_empty() {
            return Err("invalid empty node argument".to_string());
        }
        match s.strip_prefix('[').and_then(|rest| rest.strip_suffix(']')) {
            // The whole bracketed body must be the ID, so trailing garbage
            // such as "[AAAAAAAAjunk]" is rejected instead of being silently
            // truncated to its first 8 characters.
            Some(inner) if inner.len() == 8 => inner.parse().map(Self::ID),
            Some(_) => Err("node ID must be 8 characters (6B)".to_string()),
            None => Ok(Self::Name(s.to_owned())),
        }
    }
}
impl std::fmt::Display for NodeArg {
    /// Formats the argument in the text form that `FromStr` accepts and
    /// `Serialize` emits: IDs are wrapped in square brackets, and names are
    /// written verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ID(id) => write!(f, "[{id}]"),
            // Names must not be bracketed. `FromStr` treats any bracketed
            // text as a node ID, so "[name]" would not parse back to the
            // same value.
            Self::Name(name) => f.write_str(name),
        }
    }
}
impl Serialize for NodeArg {
    /// Serializes to the string form that `FromStr` parses: `[ID]` for IDs
    /// and the bare name for names.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Only the ID form needs a fresh allocation; names are borrowed.
        let text: std::borrow::Cow<'_, str> = match self {
            Self::ID(id) => format!("[{id}]").into(),
            Self::Name(name) => name.as_str().into(),
        };
        serializer.serialize_str(&text)
    }
}
impl<'de> Deserialize<'de> for NodeArg {
    /// Deserializes from a string in the `FromStr` format (`[ID]` or a name)
    /// via `NodeArgVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = NodeArgVisitor;
        deserializer.deserialize_str(visitor)
    }
}
struct NodeArgVisitor;
impl Visitor<'_> for NodeArgVisitor {
type Value = NodeArg;
fn expecting(&self, formatter: &mut std::fmt::Formatter)
-> std::fmt::Result
{
formatter.write_str(
"a name or a 6byte ID encoded as URL-safe\
base64 surrounded by \"[\" and \"]\""
)
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where E: de::Error,
{
value.parse::<NodeArg>().map_err(E::custom)
}
}