use core::fmt;
use base64::prelude::*;
use rand::RngCore;
use x25519_dalek::{PublicKey as XPublicKey, StaticSecret};
use zeroize::{Zeroize, ZeroizeOnDrop};
use crate::WireguardError;
/// An X25519 private key, stored as an `x25519_dalek::StaticSecret`.
///
/// The wrapped secret is wiped when the value is dropped (`ZeroizeOnDrop`).
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct PrivateKey(StaticSecret);

impl PrivateKey {
    /// Generates a fresh random private key using `StaticSecret::random`.
    #[must_use]
    pub fn random() -> Self {
        let secret = StaticSecret::random();
        Self(secret)
    }

    /// Borrows the raw 32 secret-key bytes.
    #[inline]
    #[must_use]
    pub fn as_bytes(&self) -> &[u8; 32] {
        let Self(secret) = self;
        secret.as_bytes()
    }

    /// Returns an owned copy of the raw 32 secret-key bytes.
    ///
    /// The returned array is a plain copy: it is not zeroized automatically, so callers
    /// holding it are responsible for wiping it.
    #[inline]
    #[must_use]
    pub fn to_bytes(&self) -> [u8; 32] {
        let Self(secret) = self;
        secret.to_bytes()
    }
}
impl fmt::Debug for PrivateKey {
    /// Formats the key as `PrivateKey(<redacted>)`.
    ///
    /// The secret is deliberately never rendered here. `Debug` output routinely ends up in
    /// logs, panic messages and `assert!` failure reports. Use the `Display` impl when the
    /// base64 form of the key is actually required.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("PrivateKey")
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
impl fmt::Display for PrivateKey {
    /// Writes the full secret key encoded with the standard, padded base64 engine.
    ///
    /// NOTE(review): this exposes the secret. The temporary encoded `String` is dropped
    /// without being wiped.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = BASE64_STANDARD.encode(self.as_bytes());
        f.write_str(&encoded)
    }
}
impl PartialEq for PrivateKey {
    /// Compares two private keys by their raw 32 secret bytes.
    ///
    /// Every byte pair is examined: the XOR of each pair is OR-accumulated instead of
    /// returning at the first mismatch. This avoids making the comparison time depend on
    /// where the keys differ. This is best-effort and relies on the optimizer not
    /// reintroducing an early exit. NOTE(review): prefer `subtle::ConstantTimeEq` if the
    /// crate can take that dependency.
    fn eq(&self, other: &Self) -> bool {
        let diff = self
            .as_bytes()
            .iter()
            .zip(other.as_bytes().iter())
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));
        diff == 0
    }
}
impl TryFrom<&str> for PrivateKey {
    type Error = WireguardError;

    /// Parses a private key from standard, padded base64 that must decode to exactly 32 bytes.
    ///
    /// Intermediate buffers holding the decoded secret are wiped before returning.
    ///
    /// # Errors
    ///
    /// Returns [`WireguardError::InvalidPrivateKey`] in two cases:
    /// - `value` is not valid base64.
    /// - `value` decodes to a length other than 32 bytes.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let mut decoded = BASE64_STANDARD
            .decode(value)
            .map_err(|_| WireguardError::InvalidPrivateKey)?;
        // The decoded bytes *are* the secret. Copy them out, then wipe the heap buffer so
        // the allocation is not released with key material still in it.
        let parsed = <[u8; 32]>::try_from(decoded.as_slice());
        decoded.zeroize();
        let mut bytes = parsed.map_err(|_| WireguardError::InvalidPrivateKey)?;
        let key = Self(StaticSecret::from(bytes));
        // `[u8; 32]` is `Copy`, so the local array still holds the secret. Clear it.
        bytes.zeroize();
        Ok(key)
    }
}
impl TryFrom<String> for PrivateKey {
    type Error = WireguardError;

    /// Parses an owned base64 string by delegating to the `&str` conversion.
    ///
    /// NOTE(review): the owned `String` is dropped normally here, not zeroized.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}
impl From<[u8; 32]> for PrivateKey {
fn from(value: [u8; 32]) -> Self {
Self(StaticSecret::from(value))
}
}
/// An X25519 public key wrapping `x25519_dalek::PublicKey`.
#[derive(Clone, PartialEq, Zeroize, ZeroizeOnDrop)]
pub struct PublicKey(XPublicKey);

impl PublicKey {
    /// Borrows the raw 32 public-key bytes.
    #[inline]
    #[must_use]
    pub fn as_bytes(&self) -> &[u8; 32] {
        let Self(inner) = self;
        inner.as_bytes()
    }

    /// Returns the raw 32 public-key bytes by value.
    #[inline]
    #[must_use]
    pub fn to_bytes(&self) -> [u8; 32] {
        let Self(inner) = self;
        inner.to_bytes()
    }
}
impl fmt::Debug for PublicKey {
    /// Formats as `PublicKey("<base64>")`, reusing the `Display` rendering.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.to_string();
        f.debug_tuple("PublicKey").field(&rendered).finish()
    }
}
impl fmt::Display for PublicKey {
    /// Writes the key bytes encoded with the standard, padded base64 engine.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = BASE64_STANDARD.encode(self.as_bytes());
        f.write_str(&encoded)
    }
}
impl TryFrom<&str> for PublicKey {
    type Error = WireguardError;

    /// Parses a public key from standard, padded base64 that must decode to exactly 32 bytes.
    ///
    /// # Errors
    ///
    /// Returns [`WireguardError::InvalidPublicKey`] in two cases:
    /// - the input is not valid base64.
    /// - the input does not decode to exactly 32 bytes.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        BASE64_STANDARD
            .decode(value)
            .ok()
            .and_then(|decoded| <[u8; 32]>::try_from(decoded).ok())
            .map(|bytes| Self(XPublicKey::from(bytes)))
            .ok_or(WireguardError::InvalidPublicKey)
    }
}
impl TryFrom<String> for PublicKey {
    type Error = WireguardError;

    /// Parses an owned base64 string by delegating to the `&str` conversion.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}
impl From<[u8; 32]> for PublicKey {
fn from(value: [u8; 32]) -> Self {
Self(XPublicKey::from(value))
}
}
impl From<&PrivateKey> for PublicKey {
fn from(value: &PrivateKey) -> Self {
Self(XPublicKey::from(&value.0))
}
}
/// A pre-shared symmetric key: 32 raw secret bytes.
///
/// The bytes are wiped when the value is dropped (`ZeroizeOnDrop`). Equality is implemented
/// by hand below rather than derived. This keeps comparisons of two keys from
/// short-circuiting on the first differing byte.
#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct PresharedKey([u8; 32]);

impl PartialEq for PresharedKey {
    /// Compares two keys byte-for-byte without an early exit.
    ///
    /// The XOR of each byte pair is OR-accumulated, so the running time does not reveal
    /// where the keys differ. This is best-effort and relies on the optimizer not
    /// reintroducing an early exit. NOTE(review): prefer `subtle::ConstantTimeEq` if the
    /// crate can take that dependency.
    fn eq(&self, other: &Self) -> bool {
        let diff = self
            .0
            .iter()
            .zip(other.0.iter())
            .fold(0u8, |acc, (a, b)| acc | (a ^ b));
        diff == 0
    }
}
impl PresharedKey {
#[must_use]
pub fn random() -> Self {
let mut key = [0u8; 32];
rand::rng().fill_bytes(&mut key);
Self(key)
}
}
impl PresharedKey {
#[inline]
#[must_use]
pub fn to_bytes(&self) -> [u8; 32] {
self.0
}
#[inline]
#[must_use]
pub fn as_bytes(&self) -> &[u8; 32] {
&self.0
}
}
impl fmt::Debug for PresharedKey {
    /// Formats the key as `PresharedKey(<redacted>)`.
    ///
    /// The secret bytes are deliberately never rendered here. `Debug` output routinely ends
    /// up in logs, panic messages and `assert!` failure reports. Use the `Display` impl when
    /// the base64 form of the key is actually required.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_tuple("PresharedKey")
            .field(&format_args!("<redacted>"))
            .finish()
    }
}
impl fmt::Display for PresharedKey {
    /// Writes the full secret key encoded with the standard, padded base64 engine.
    ///
    /// NOTE(review): this exposes the secret. The temporary encoded `String` is dropped
    /// without being wiped.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let encoded = BASE64_STANDARD.encode(self.as_bytes());
        f.write_str(&encoded)
    }
}
impl From<[u8; 32]> for PresharedKey {
    /// Wraps 32 raw key bytes as-is.
    fn from(bytes: [u8; 32]) -> Self {
        Self(bytes)
    }
}
impl TryFrom<&str> for PresharedKey {
    type Error = WireguardError;

    /// Parses a pre-shared key from standard, padded base64 that must decode to exactly
    /// 32 bytes.
    ///
    /// Intermediate buffers holding the decoded secret are wiped before returning.
    ///
    /// # Errors
    ///
    /// Returns [`WireguardError::InvalidPresharedKey`] in two cases:
    /// - `value` is not valid base64.
    /// - `value` decodes to a length other than 32 bytes.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        let mut decoded = BASE64_STANDARD
            .decode(value)
            .map_err(|_| WireguardError::InvalidPresharedKey)?;
        // The decoded bytes *are* the secret. Copy them out, then wipe the heap buffer so
        // the allocation is not released with key material still in it.
        let parsed = <[u8; 32]>::try_from(decoded.as_slice());
        decoded.zeroize();
        let mut bytes = parsed.map_err(|_| WireguardError::InvalidPresharedKey)?;
        let key = Self(bytes);
        // `[u8; 32]` is `Copy`, so the local array still holds the secret. Clear it.
        bytes.zeroize();
        Ok(key)
    }
}
impl TryFrom<String> for PresharedKey {
    type Error = WireguardError;

    /// Parses an owned base64 string by delegating to the `&str` conversion.
    ///
    /// NOTE(review): the owned `String` is dropped normally here, not zeroized.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.as_str().try_into()
    }
}