pub(super) mod key;
pub mod rendezvous;
mod server_messages;
#[cfg(test)]
mod test;
pub mod wordlist;
use serde_derive::{Deserialize, Serialize};
use std::{borrow::Cow, str::FromStr};
use thiserror::Error;
use crate::Wordlist;
use self::{rendezvous::*, server_messages::EncryptedMessage};
use crypto_secretbox as secretbox;
/// Errors that can occur while setting up or using a wormhole connection.
///
/// The enum is `#[non_exhaustive]`: matches outside this crate need a wildcard arm.
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum WormholeError {
/// A `serde_json` error; created automatically through `From<serde_json::Error>`.
#[error("Corrupt message received from peer")]
ProtocolJson(#[from] serde_json::Error),
/// An error from the rendezvous layer; created automatically through
/// `From<rendezvous::RendezvousError>`.
#[error("Error with the rendezvous server connection")]
ServerError(#[from] rendezvous::RendezvousError),
/// A protocol error, described by the contained message.
#[error("Protocol error: {}", _0)]
Protocol(Box<str>),
/// The key exchange could not be completed, or the peer's version message could not be
/// decrypted with the resulting key.
#[error(
"Key confirmation failed. If you didn't mistype the code, \
this is a sign of an attacker guessing passwords. Please try \
again some time later."
)]
PakeFailed,
/// A received message could not be decrypted.
#[error("Cannot decrypt a received message")]
Crypto,
/// The nameplate of the supplied code was not among the nameplates listed by the server.
#[error("Nameplate is unclaimed: {}", _0)]
UnclaimedNameplate(Nameplate),
/// The supplied code could not be parsed; created automatically through
/// `From<ParseCodeError>`.
#[error("The provided code is invalid: {_0}")]
CodeInvalid(#[from] ParseCodeError),
}
impl WormholeError {
    /// Reports whether this error is [`WormholeError::PakeFailed`].
    ///
    /// Every other variant yields `false`.
    pub fn is_scared(&self) -> bool {
        match self {
            Self::PakeFailed => true,
            // Deliberately a wildcard: only the key-confirmation failure counts as "scared".
            _ => false,
        }
    }
}
/// Lets `?` be used on `Result<_, Infallible>` inside functions returning [`WormholeError`].
impl From<std::convert::Infallible> for WormholeError {
    fn from(never: std::convert::Infallible) -> Self {
        // `Infallible` has no values, so this function can never actually be called;
        // the empty match lets the compiler verify that.
        match never {}
    }
}
/// A connection to the rendezvous server on which a nameplate has been claimed and a
/// mailbox opened, but no key exchange with the peer has happened yet.
///
/// Pass it to [`Wormhole::connect`] to perform the key exchange.
pub struct MailboxConnection<V: serde::Serialize + Send + Sync + 'static> {
/// The application configuration this connection was created with.
config: AppConfig<V>,
/// Handle to the rendezvous server connection.
server: RendezvousServer,
/// The welcome value returned by `RendezvousServer::connect`, if any.
welcome: Option<String>,
/// The mailbox returned by the server when the nameplate was claimed.
mailbox: Mailbox,
/// The wormhole code (nameplate and password) used for this connection.
code: Code,
}
impl<V: serde::Serialize + Send + Sync + 'static> MailboxConnection<V> {
    /// Connects to the rendezvous server, allocates a new nameplate and pairs it with a
    /// password chosen from the default wordlist built for `code_length`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the rendezvous server calls.
    pub async fn create(config: AppConfig<V>, code_length: usize) -> Result<Self, WormholeError> {
        let password = Wordlist::default_wordlist(code_length).choose_words();
        Self::create_with_validated_password(config, password).await
    }

    /// Like [`MailboxConnection::create`], but uses the caller-supplied `password`
    /// instead of choosing one from the wordlist.
    ///
    /// # Errors
    ///
    /// Propagates any error from the rendezvous server calls.
    pub async fn create_with_password(
        config: AppConfig<V>,
        password: Password,
    ) -> Result<Self, WormholeError> {
        Self::create_with_validated_password(config, password).await
    }

    /// Shared implementation of the `create*` constructors: connects, lets the server
    /// allocate/claim/open a nameplate and mailbox, and assembles the code from the
    /// allocated nameplate and the given password.
    async fn create_with_validated_password(
        config: AppConfig<V>,
        password: Password,
    ) -> Result<Self, WormholeError> {
        let (mut rendezvous, welcome) =
            RendezvousServer::connect(&config.id, &config.rendezvous_url).await?;
        let (nameplate, mailbox) = rendezvous.allocate_claim_open().await?;

        Ok(Self {
            code: Code::from_components(nameplate, password),
            config,
            server: rendezvous,
            mailbox,
            welcome,
        })
    }

    /// Connects to the rendezvous server and claims the nameplate contained in `code`.
    ///
    /// When `allocate` is `false`, the nameplate must already be in the server's
    /// nameplate list; otherwise the server connection is shut down with
    /// [`Mood::Errory`] and [`WormholeError::UnclaimedNameplate`] is returned. When
    /// `allocate` is `true`, the list is not consulted at all.
    ///
    /// # Errors
    ///
    /// Besides `UnclaimedNameplate`, propagates any error from the rendezvous server
    /// calls (including a failure of the shutdown mentioned above).
    pub async fn connect(
        config: AppConfig<V>,
        code: Code,
        allocate: bool,
    ) -> Result<Self, WormholeError> {
        let (mut rendezvous, welcome) =
            RendezvousServer::connect(&config.id, &config.rendezvous_url).await?;
        let nameplate = code.nameplate();

        // Short-circuit keeps the listing request from being sent when allocation is allowed.
        let may_claim = allocate || rendezvous.list_nameplates().await?.contains(&nameplate);
        if !may_claim {
            rendezvous.shutdown(Mood::Errory).await?;
            return Err(WormholeError::UnclaimedNameplate(nameplate));
        }

        let mailbox = rendezvous.claim_open(nameplate).await?;
        Ok(Self {
            config,
            server: rendezvous,
            mailbox,
            code,
            welcome,
        })
    }

    /// Shuts the server connection down, reporting `mood`.
    ///
    /// # Errors
    ///
    /// A failure of the server shutdown is returned as [`WormholeError::ServerError`].
    pub async fn shutdown(self, mood: Mood) -> Result<(), WormholeError> {
        match self.server.shutdown(mood).await {
            Ok(()) => Ok(()),
            Err(failure) => Err(WormholeError::ServerError(failure)),
        }
    }

    /// The welcome value received when connecting to the server, if there was one.
    pub fn welcome(&self) -> Option<&str> {
        match &self.welcome {
            Some(text) => Some(text.as_str()),
            None => None,
        }
    }

    /// The wormhole code belonging to this connection.
    pub fn code(&self) -> &Code {
        &self.code
    }
}
/// An established wormhole: a server connection plus the key material negotiated with
/// the peer. Created by [`Wormhole::connect`].
#[derive(Debug)]
pub struct Wormhole {
/// Handle to the rendezvous server connection used to exchange peer messages.
server: RendezvousServer,
/// Number used as the phase of the next outgoing message; starts at 0 and is
/// incremented by every call to `send`.
phase: u64,
/// The shared key obtained from the PAKE exchange.
key: key::Key<key::WormholeKey>,
/// The application id taken from the configuration.
appid: AppID,
/// Value produced by `key::derive_verifier` from the shared key.
verifier: Box<secretbox::Key>,
/// The `app_version` value from our own configuration, type-erased.
our_version: Box<dyn std::any::Any + Send + Sync>,
/// The `app_versions` value taken from the peer's versions message.
peer_version: serde_json::Value,
}
impl Wormhole {
    /// Runs the key exchange over an open mailbox and returns the established wormhole.
    ///
    /// The handshake happens in this order:
    /// 1. A PAKE message is built from the code and the app id and sent; the peer's PAKE
    ///    message is awaited and both are combined into the shared key.
    /// 2. Our versions message (carrying the serialized `app_version`) is built and sent;
    ///    the peer's versions message is awaited, decrypted and parsed.
    /// 3. The nameplate is released if the server handle reports that this is needed.
    ///
    /// # Errors
    ///
    /// * [`WormholeError::PakeFailed`] if the PAKE exchange cannot be finished or the
    ///   peer's versions message cannot be decrypted with the resulting key.
    /// * [`WormholeError::ProtocolJson`] if the decrypted versions message cannot be
    ///   deserialized.
    /// * Errors from the rendezvous server calls and from extracting the peer's PAKE
    ///   message are propagated.
    ///
    /// # Panics
    ///
    /// Panics if `app_version` cannot be serialized into a JSON value.
    pub async fn connect(
        mailbox_connection: MailboxConnection<impl serde::Serialize + Send + Sync + 'static>,
    ) -> Result<Self, WormholeError> {
        let MailboxConnection {
            config,
            mut server,
            mailbox: _mailbox,
            code,
            welcome: _welcome,
        } = mailbox_connection;

        // Step 1: PAKE exchange.
        let (pake_state, pake_msg_ser) = key::make_pake(code.as_str(), &config.id);
        server.send_peer_message(Phase::PAKE, pake_msg_ser).await?;
        let peer_pake = key::extract_pake_msg(&server.next_peer_message_some().await?.body)?;
        let key = pake_state
            .finish(&peer_pake)
            .map_err(|_| WormholeError::PakeFailed)
            .map(|key| *secretbox::Key::from_slice(&key))?;

        // Step 2: versions exchange, encrypted with the freshly negotiated key.
        let mut versions = key::VersionsMessage::new();
        versions.set_app_versions(serde_json::to_value(&config.app_version).unwrap());
        let (version_phase, version_msg) = key::build_version_msg(server.side(), &key, &versions);
        server.send_peer_message(version_phase, version_msg).await?;
        let peer_version = server.next_peer_message_some().await?;
        // Failing to decrypt here means both sides ended up with different keys.
        let versions: key::VersionsMessage = peer_version
            .decrypt(&key)
            .ok_or(WormholeError::PakeFailed)
            .and_then(|plaintext| {
                serde_json::from_slice(&plaintext).map_err(WormholeError::ProtocolJson)
            })?;
        let peer_version = versions.app_versions;

        // Step 3: release the nameplate if still required.
        if server.needs_nameplate_release() {
            server.release_nameplate().await?;
        }
        tracing::info!("Found peer on the rendezvous server.");

        Ok(Self {
            server,
            appid: config.id,
            phase: 0,
            key: key::Key::new(key.into()),
            verifier: Box::new(key::derive_verifier(&key)),
            our_version: Box::new(config.app_version),
            peer_version,
        })
    }

    /// Encrypts `plaintext` and sends it to the peer as the next numbered message.
    ///
    /// The phase counter is incremented before the message is handed to the server, so
    /// it advances even when sending fails.
    ///
    /// # Errors
    ///
    /// Propagates the error of the underlying server send.
    pub async fn send(&mut self, plaintext: Vec<u8>) -> Result<(), WormholeError> {
        let phase_string = Phase::numeric(self.phase);
        self.phase += 1;
        let data_key = key::derive_phase_key(self.server.side(), self.key.as_ref(), &phase_string);
        let (_nonce, encrypted) = key::encrypt_data(&data_key, &plaintext);
        self.server
            .send_peer_message(phase_string, encrypted)
            .await?;
        Ok(())
    }

    /// Serializes `message` to JSON and sends it with [`Wormhole::send`].
    ///
    /// # Errors
    ///
    /// Same as [`Wormhole::send`].
    ///
    /// # Panics
    ///
    /// Panics if `message` cannot be serialized to JSON.
    pub async fn send_json<T: serde::Serialize>(
        &mut self,
        message: &T,
    ) -> Result<(), WormholeError> {
        self.send(serde_json::to_vec(message).unwrap()).await
    }

    /// Waits for the next numbered message from the peer and returns its decrypted body.
    ///
    /// Messages whose phase is not a number are logged and skipped; the method keeps
    /// waiting for the next message in that case.
    ///
    /// # Errors
    ///
    /// * [`WormholeError::Crypto`] if a numbered message cannot be decrypted.
    /// * Errors from the rendezvous server are propagated.
    pub async fn receive(&mut self) -> Result<Vec<u8>, WormholeError> {
        loop {
            let Some(peer_message) = self.server.next_peer_message().await? else {
                continue;
            };

            if peer_message.phase.to_num().is_none() {
                // The phase is chosen by the remote peer, so an unexpected value must not
                // be able to panic us. Log it and wait for the next message instead.
                tracing::warn!(
                    "Ignoring peer message with unsupported phase '{}'",
                    peer_message.phase
                );
                continue;
            }

            return peer_message
                .decrypt(self.key.as_ref())
                .ok_or(WormholeError::Crypto);
        }
    }

    /// Receives the next message and parses it as JSON into `T`.
    ///
    /// The outer `Result` reports failures of [`Wormhole::receive`]; the inner one
    /// reports a JSON parse failure, which is also logged together with a lossy UTF-8
    /// rendering of the received bytes.
    pub async fn receive_json<T>(&mut self) -> Result<Result<T, serde_json::Error>, WormholeError>
    where
        T: for<'a> serde::Deserialize<'a>,
    {
        self.receive().await.map(|data: Vec<u8>| {
            serde_json::from_slice(&data).inspect_err(|_| {
                tracing::error!(
                    "Received invalid data from peer: '{}'",
                    String::from_utf8_lossy(&data)
                );
            })
        })
    }

    /// Shuts the server connection down, reporting [`Mood::Happy`].
    ///
    /// # Errors
    ///
    /// Returns the converted error of the server shutdown.
    pub async fn close(self) -> Result<(), WormholeError> {
        tracing::debug!("Closing Wormhole…");
        self.server.shutdown(Mood::Happy).await.map_err(Into::into)
    }

    /// The application id this wormhole was established with.
    pub fn appid(&self) -> &AppID {
        &self.appid
    }

    /// The shared key negotiated with the peer.
    pub fn key(&self) -> &key::Key<key::WormholeKey> {
        &self.key
    }

    /// The verifier derived from the shared key.
    pub fn verifier(&self) -> &secretbox::Key {
        &self.verifier
    }

    /// Our own `app_version` from the configuration, as a type-erased value.
    pub fn our_version(&self) -> &(dyn std::any::Any + Send + Sync) {
        &*self.our_version
    }

    /// The `app_versions` value the peer sent in its versions message.
    pub fn peer_version(&self) -> &serde_json::Value {
        &self.peer_version
    }
}
/// The mood passed to the rendezvous server when a connection is shut down.
///
/// Each variant is (de)serialized under the lowercase name given in its `serde`
/// attribute.
#[derive(Debug, PartialEq, Copy, Clone, Deserialize, Serialize, derive_more::Display)]
pub enum Mood {
/// Serialized as `"happy"`; reported by `Wormhole::close`.
#[serde(rename = "happy")]
Happy,
/// Serialized as `"lonely"`.
#[serde(rename = "lonely")]
Lonely,
/// Serialized as `"errory"`; reported by `MailboxConnection::connect` when the
/// nameplate is not listed by the server.
#[serde(rename = "errory")]
Errory,
/// Serialized as `"scary"` (note: not the variant name).
#[serde(rename = "scary")]
Scared,
/// Serialized as `"unwelcome"`.
#[serde(rename = "unwelcome")]
Unwelcome,
}
/// Application-level configuration used when connecting to the rendezvous server.
#[derive(PartialEq, Eq, Clone, Debug)]
pub struct AppConfig<V> {
/// The application id; passed to `RendezvousServer::connect` and to the PAKE setup.
pub id: AppID,
/// URL of the rendezvous server to connect to.
pub rendezvous_url: Cow<'static, str>,
/// Application version information; serialized to JSON and sent to the peer in the
/// versions message.
pub app_version: V,
}
impl<V> AppConfig<V> {
    /// Returns the configuration with its application id replaced by `id`.
    pub fn id(self, id: AppID) -> Self {
        Self { id, ..self }
    }

    /// Returns the configuration with its rendezvous server URL replaced by
    /// `rendezvous_url`.
    pub fn rendezvous_url(self, rendezvous_url: Cow<'static, str>) -> Self {
        Self {
            rendezvous_url,
            ..self
        }
    }
}
impl<V: serde::Serialize> AppConfig<V> {
    /// Returns the configuration with its application version replaced by `app_version`.
    pub fn app_version(self, app_version: V) -> Self {
        Self {
            app_version,
            ..self
        }
    }
}
/// An application id: a string newtype, stored either as a borrowed `'static` string or
/// as an owned `String`.
///
/// Dereferencing is forwarded to the wrapped string.
#[derive(
PartialEq, Eq, Clone, Debug, Deserialize, Serialize, derive_more::Display, derive_more::Deref,
)]
#[deref(forward)]
pub struct AppID(#[deref] pub(crate) Cow<'static, str>);
impl AppID {
pub fn new(id: impl Into<Cow<'static, str>>) -> Self {
AppID(id.into())
}
}
impl From<String> for AppID {
fn from(s: String) -> Self {
Self::new(s)
}
}
/// Gives access to the application id as a plain string slice.
impl AsRef<str> for AppID {
    fn as_ref(&self) -> &str {
        &*self.0
    }
}
/// The side identifier of this end of the connection.
///
/// Wraps an [`EitherSide`]; serialized transparently and displayed as `MySide(<id>)`.
#[derive(
PartialEq, Eq, Clone, Debug, Deserialize, Serialize, derive_more::Display, derive_more::Deref,
)]
#[serde(transparent)]
#[display("MySide({})", "&*_0")]
pub(crate) struct MySide(EitherSide);
impl MySide {
pub fn generate() -> MySide {
use rand::{RngCore, rngs::OsRng};
let mut bytes: [u8; 5] = [0; 5];
OsRng.fill_bytes(&mut bytes);
MySide(EitherSide(hex::encode(bytes)))
}
#[cfg(test)]
pub fn unchecked_from_string(s: String) -> MySide {
MySide(EitherSide(s))
}
}
/// A side identifier that is not our own (presumably the peer's; confirm against the
/// rendezvous module).
///
/// Wraps an [`EitherSide`]; serialized transparently and displayed as `TheirSide(<id>)`.
#[derive(
PartialEq, Eq, Clone, Debug, Deserialize, Serialize, derive_more::Display, derive_more::Deref,
)]
#[serde(transparent)]
#[display("TheirSide({})", "&*_0")]
pub(crate) struct TheirSide(EitherSide);
impl<S: Into<String>> From<S> for TheirSide {
fn from(s: S) -> TheirSide {
TheirSide(EitherSide(s.into()))
}
}
/// A side identifier stored as a plain string, without saying whose side it is.
///
/// Serialized transparently, displayed as the bare string, and dereferencing is
/// forwarded to the wrapped `String`.
#[derive(
PartialEq, Eq, Clone, Debug, Deserialize, Serialize, derive_more::Display, derive_more::Deref,
)]
#[serde(transparent)]
#[deref(forward)]
#[display("{}", "&*_0")]
pub(crate) struct EitherSide(pub String);
impl<S: Into<String>> From<S> for EitherSide {
fn from(s: S) -> EitherSide {
EitherSide(s.into())
}
}
/// The phase name attached to a peer message: one of the fixed names (`"version"`,
/// `"pake"`) or the decimal rendering of a message counter.
///
/// Serialized transparently as a string.
#[derive(PartialEq, Eq, Clone, Debug, Hash, Deserialize, Serialize, derive_more::Display)]
#[serde(transparent)]
pub(crate) struct Phase(Cow<'static, str>);
impl Phase {
pub const VERSION: Self = Phase(Cow::Borrowed("version"));
pub const PAKE: Self = Phase(Cow::Borrowed("pake"));
pub fn numeric(phase: u64) -> Self {
Phase(phase.to_string().into())
}
#[allow(dead_code)]
pub fn is_version(&self) -> bool {
self == &Self::VERSION
}
#[allow(dead_code)]
pub fn is_pake(&self) -> bool {
self == &Self::PAKE
}
pub fn to_num(&self) -> Option<u64> {
self.0.parse().ok()
}
}
/// Gives access to the phase name as a plain string slice.
impl AsRef<str> for Phase {
    fn as_ref(&self) -> &str {
        &*self.0
    }
}
/// A mailbox identifier, stored and (de)serialized as a plain string.
#[derive(PartialEq, Eq, Clone, Debug, Deserialize, Serialize, derive_more::Display)]
#[serde(transparent)]
pub(crate) struct Mailbox(pub String);
/// Error returned when a string cannot be parsed into a [`Nameplate`].
///
/// Carries no data; the `Display` text explains the requirement.
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Copy, derive_more::Display, Error)]
#[display("Nameplate is not a number. Nameplates must be a number >= 1.")]
#[non_exhaustive]
pub struct ParseNameplateError {}
/// The nameplate part of a wormhole code, stored as a string.
///
/// Parsing through `FromStr` validates the text; (de)serialization is transparent and
/// `Display` prints the bare string.
#[derive(PartialEq, Eq, Clone, Debug, Deserialize, Serialize, derive_more::Display)]
#[serde(transparent)]
#[display("{}", _0)]
pub struct Nameplate(String);
impl Nameplate {
#[expect(unsafe_code)]
#[doc(hidden)]
pub unsafe fn new_unchecked(n: &str) -> Self {
Nameplate(n.into())
}
}
/// Parses a nameplate: a non-empty string of ASCII digits that is not all zeros.
impl FromStr for Nameplate {
    type Err = ParseNameplateError;

    /// Validates `s` and wraps it unchanged (leading zeros are kept).
    ///
    /// # Errors
    ///
    /// Returns [`ParseNameplateError`] when `s` is empty, contains anything other than
    /// ASCII digits, or consists only of zeros.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // The emptiness check is essential: `all` is vacuously true for an empty
        // string, which would otherwise slip through as a "number".
        let is_number = !s.is_empty() && s.bytes().all(|byte| byte.is_ascii_digit());
        // For a digits-only string, "only zeros" is exactly "the value is 0". Checking
        // the text avoids any limit on the number of digits.
        let is_zero = s.bytes().all(|byte| byte == b'0');

        if is_number && !is_zero {
            Ok(Self(s.to_owned()))
        } else {
            Err(ParseNameplateError {})
        }
    }
}
/// Reasons why a string is rejected as a [`Password`].
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Copy, derive_more::Display, Error)]
#[non_exhaustive]
pub enum ParsePasswordError {
/// The password has fewer bytes than required.
#[display("Password too short. It is only {value} bytes, but must be at least {required}")]
TooShort {
/// Length of the rejected password in bytes.
value: usize,
/// Minimum length in bytes.
required: usize,
},
/// The estimated number of guesses needed for the password is below the minimum.
#[display(
"Password too weak. It can be guessed with an average of {value} tries, but must be at least {required}"
)]
LittleEntropy {
/// Estimated number of guesses for the rejected password.
value: u64,
/// Minimum number of guesses required.
required: u64,
},
}
/// The password part of a wormhole code, together with its strength estimate.
///
/// Serialized transparently as the password string; `Display` prints the password.
#[derive(Clone, Debug, Serialize, derive_more::Display)]
#[serde(transparent)]
#[display("{password}")]
pub struct Password {
/// The password text.
password: String,
/// The `zxcvbn` estimate computed for `password`; not serialized.
#[serde(skip)]
entropy: zxcvbn::Entropy,
}
/// Passwords compare by their text only; the stored strength estimate is ignored.
impl PartialEq for Password {
    fn eq(&self, other: &Self) -> bool {
        self.password.as_str() == other.password.as_str()
    }
}

impl Eq for Password {}
impl Password {
    /// Wraps `n` as a password without applying the length and strength checks done by
    /// `FromStr`. The strength estimate is still computed and stored.
    ///
    /// # Safety
    ///
    /// No memory-safety requirement is visible in this function; `unsafe` appears to be
    /// used to mark that validation is skipped.
    #[expect(unsafe_code)]
    #[doc(hidden)]
    pub unsafe fn new_unchecked(n: impl Into<String>) -> Self {
        let password: String = n.into();
        Password {
            entropy: Self::calculate_entropy(&password),
            password,
        }
    }

    /// Runs `zxcvbn` on `password`, passing the words of `Wordlist::default_wordlist(2)`
    /// as additional inputs.
    ///
    /// The word list is built once per process and cached; its strings are leaked on
    /// purpose so they can be stored as `'static` slices.
    fn calculate_entropy(password: &str) -> zxcvbn::Entropy {
        static PGP_WORDLIST: std::sync::OnceLock<Vec<&str>> = std::sync::OnceLock::new();

        let known_words = PGP_WORDLIST.get_or_init(|| {
            let mut leaked = Vec::new();
            for word in Wordlist::default_wordlist(2).into_words() {
                leaked.push(&*word.leak());
            }
            leaked
        });
        zxcvbn::zxcvbn(password, known_words.as_slice())
    }
}
impl From<Password> for String {
fn from(value: Password) -> Self {
value.password
}
}
/// Gives access to the password text as a plain string slice.
impl AsRef<str> for Password {
    fn as_ref(&self) -> &str {
        self.password.as_str()
    }
}
/// Parses a password, enforcing a minimum length and a minimum estimated strength.
impl FromStr for Password {
    type Err = ParsePasswordError;

    /// Accepts `password` when it is at least 4 bytes long and its estimated number of
    /// guesses is at least 2^16.
    ///
    /// # Errors
    ///
    /// * [`ParsePasswordError::TooShort`] when the password has fewer than 4 bytes.
    /// * [`ParsePasswordError::LittleEntropy`] when the estimate is below 2^16 guesses.
    fn from_str(password: &str) -> Result<Self, Self::Err> {
        const MIN_BYTES: usize = 4;
        const MIN_GUESSES: u64 = 1 << 16;

        let password = password.to_owned();
        let byte_len = password.len();
        if byte_len < MIN_BYTES {
            return Err(ParsePasswordError::TooShort {
                value: byte_len,
                required: MIN_BYTES,
            });
        }

        let entropy = Self::calculate_entropy(&password);
        let guesses = entropy.guesses();
        if guesses < MIN_GUESSES {
            return Err(ParsePasswordError::LittleEntropy {
                value: guesses,
                required: MIN_GUESSES,
            });
        }

        Ok(Self { password, entropy })
    }
}
/// Reasons why a string is rejected as a wormhole [`Code`].
#[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Debug, Copy, derive_more::Display, Error)]
#[non_exhaustive]
pub enum ParseCodeError {
/// The input was the empty string.
#[display("The code is empty")]
Empty,
/// The input was not empty but contained no `-`.
#[display("A code must contain at least one '-' to separate nameplate from password")]
SeparatorMissing,
/// The part before the first `-` is not a valid nameplate.
#[display("{_0}")]
Nameplate(#[from] ParseNameplateError),
/// The part after the first `-` is not a valid password.
#[display("{_0}")]
Password(#[from] ParsePasswordError),
}
/// A wormhole code: the text `<nameplate>-<password>` stored as one string.
///
/// `Display` prints the whole code.
#[derive(PartialEq, Eq, Clone, Debug, derive_more::Display)]
#[display("{}", _0)]
pub struct Code(String);
impl Code {
    /// Joins `nameplate` and `password` with a `-` into a code.
    pub fn from_components(nameplate: Nameplate, password: Password) -> Self {
        Self(format!("{}-{}", nameplate, password))
    }

    /// The nameplate: everything before the first `-` (the whole text if there is none).
    pub fn nameplate(&self) -> Nameplate {
        let head = self
            .0
            .split_once('-')
            .map_or(self.0.as_str(), |(before, _)| before);
        // No validation is repeated here; the text is taken from this code as-is.
        #[expect(unsafe_code)]
        unsafe {
            Nameplate::new_unchecked(head)
        }
    }

    /// The password: everything after the first `-` (the whole text if there is none).
    pub fn password(&self) -> Password {
        let tail = self
            .0
            .split_once('-')
            .map_or(self.0.as_str(), |(_, after)| after);
        // No validation is repeated here; the text is taken from this code as-is.
        #[expect(unsafe_code)]
        unsafe {
            Password::new_unchecked(tail)
        }
    }

    /// The full code text.
    pub(crate) fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Parses a code of the form `<nameplate>-<password>`, splitting at the first `-`.
impl FromStr for Code {
    type Err = ParseCodeError;

    /// Validates both halves and stores the code text.
    ///
    /// # Errors
    ///
    /// * [`ParseCodeError::Empty`] when `s` is empty.
    /// * [`ParseCodeError::SeparatorMissing`] when `s` is not empty but has no `-`.
    /// * [`ParseCodeError::Password`] / [`ParseCodeError::Nameplate`] when the respective
    ///   half is invalid. The password is checked first, so its error wins when both
    ///   halves are invalid.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let Some((nameplate_text, password_text)) = s.split_once('-') else {
            return Err(if s.is_empty() {
                ParseCodeError::Empty
            } else {
                ParseCodeError::SeparatorMissing
            });
        };

        let password: Password = password_text.parse()?;
        let nameplate: Nameplate = nameplate_text.parse()?;
        Ok(Self(format!("{}-{}", nameplate, password)))
    }
}