use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use rand::{RngCore, rngs::OsRng};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use thiserror::Error;
/// Number of random bytes encoded (unpadded URL-safe base64) after the `thd_` prefix of a thread id.
const THREAD_BYTES: usize = 20;
/// Number of random bytes behind an invite secret (256 bits).
const INVITE_SECRET_BYTES: usize = 32;
/// Errors returned when parsing a `ThreadId` or an `Invite` from text.
#[derive(Debug, Error)]
pub enum InviteError {
/// The invite is not exactly three `.`-separated parts whose first part is `f8s_invite_v1`.
#[error("invite must have format f8s_invite_v1.<thread_id>.<secret>")]
BadFormat,
/// The thread id lacks the `thd_` prefix, is not unpadded URL-safe base64, or does not
/// decode to exactly `THREAD_BYTES` bytes.
#[error("thread id is invalid")]
BadThreadId,
/// The secret is not unpadded URL-safe base64 or does not decode to exactly
/// `INVITE_SECRET_BYTES` bytes.
#[error("invite secret is invalid")]
BadSecret,
}
/// Thread identifier of the form `thd_<unpadded URL-safe base64 of THREAD_BYTES bytes>`.
///
/// `ThreadId::new` generates one randomly; `ThreadId::parse` validates an existing string.
///
/// NOTE(review): the inner `String` is `pub` and `Deserialize` is derived, so values built
/// directly or deserialized are NOT checked by `ThreadId::parse` — confirm callers validate.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ThreadId(pub String);
/// An invite: a thread id paired with a random secret.
///
/// The textual form is produced by `Invite::expose` and read back by `Invite::parse`.
///
/// `Debug` is implemented by hand so the secret is never printed. The secret is enough to
/// reconstruct the exposed invite string, so a derived `Debug` would leak it into logs and
/// panic messages.
///
/// NOTE(review): `Deserialize` is derived, so deserialized values skip the length/encoding
/// checks in `Invite::parse`. The derived `PartialEq` also compares secrets in non-constant
/// time. Confirm neither is used on untrusted input.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct Invite {
    /// The thread this invite refers to.
    pub thread_id: ThreadId,
    /// Unpadded URL-safe base64 text. `Invite::new` fills it from `INVITE_SECRET_BYTES`
    /// random bytes, and `Invite::parse` requires it to decode to exactly that many bytes.
    pub secret: String,
}

impl std::fmt::Debug for Invite {
    /// Formats the invite with the secret replaced by a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Invite")
            .field("thread_id", &self.thread_id)
            .field("secret", &"<redacted>")
            .finish()
    }
}
impl ThreadId {
    /// Generates a new random thread id.
    ///
    /// The id is `thd_` followed by `THREAD_BYTES` bytes from `OsRng`, encoded as unpadded
    /// URL-safe base64.
    pub fn new() -> Self {
        let mut raw = [0_u8; THREAD_BYTES];
        OsRng.fill_bytes(&mut raw);
        let encoded = URL_SAFE_NO_PAD.encode(raw);
        Self(format!("thd_{}", encoded))
    }

    /// Validates `value` as a thread id and wraps it unchanged.
    ///
    /// # Errors
    ///
    /// Returns `InviteError::BadThreadId` in any of these cases:
    /// * `value` lacks the `thd_` prefix.
    /// * The remainder is not unpadded URL-safe base64.
    /// * The remainder does not decode to exactly `THREAD_BYTES` bytes.
    pub fn parse(value: &str) -> Result<Self, InviteError> {
        // Every failure mode collapses to `false`, matching the single error variant.
        let well_formed = value
            .strip_prefix("thd_")
            .and_then(|body| URL_SAFE_NO_PAD.decode(body).ok())
            .map_or(false, |raw| raw.len() == THREAD_BYTES);
        if well_formed {
            Ok(Self(value.to_owned()))
        } else {
            Err(InviteError::BadThreadId)
        }
    }
}
impl Default for ThreadId {
fn default() -> Self {
Self::new()
}
}
impl Invite {
    /// Creates an invite for `thread_id` with a freshly generated secret.
    ///
    /// The secret is `INVITE_SECRET_BYTES` bytes from `OsRng`, encoded as unpadded URL-safe
    /// base64.
    pub fn new(thread_id: ThreadId) -> Self {
        let mut raw = [0_u8; INVITE_SECRET_BYTES];
        OsRng.fill_bytes(&mut raw);
        let secret = URL_SAFE_NO_PAD.encode(raw);
        Self { thread_id, secret }
    }

    /// Parses the `f8s_invite_v1.<thread_id>.<secret>` form produced by `Invite::expose`.
    ///
    /// # Errors
    ///
    /// * `InviteError::BadFormat` unless `value` is exactly three `.`-separated parts whose
    ///   first part is `f8s_invite_v1`.
    /// * `InviteError::BadThreadId` when the middle part fails `ThreadId::parse`.
    /// * `InviteError::BadSecret` when the last part is not unpadded URL-safe base64 of
    ///   exactly `INVITE_SECRET_BYTES` bytes.
    pub fn parse(value: &str) -> Result<Self, InviteError> {
        let mut pieces = value.split('.');
        // Pull four pieces: the fourth must be absent, which rejects extra segments.
        let (thread_part, secret_part) =
            match (pieces.next(), pieces.next(), pieces.next(), pieces.next()) {
                (Some("f8s_invite_v1"), Some(thread_part), Some(secret_part), None) => {
                    (thread_part, secret_part)
                }
                _ => return Err(InviteError::BadFormat),
            };
        // The thread id is checked before the secret, so its error takes precedence.
        let thread_id = ThreadId::parse(thread_part)?;
        match URL_SAFE_NO_PAD.decode(secret_part) {
            Ok(raw) if raw.len() == INVITE_SECRET_BYTES => Ok(Self {
                thread_id,
                secret: secret_part.to_owned(),
            }),
            _ => Err(InviteError::BadSecret),
        }
    }

    /// Renders the invite as `f8s_invite_v1.<thread_id>.<secret>`, the form accepted by
    /// `Invite::parse`.
    pub fn expose(&self) -> String {
        let Self { thread_id, secret } = self;
        format!("f8s_invite_v1.{}.{}", thread_id.0, secret)
    }

    /// Returns unpadded URL-safe base64 of `SHA-256(thread_id ++ ":" ++ secret)`.
    ///
    /// The result is deterministic: the same thread id and secret always give the same string.
    pub fn verifier(&self) -> String {
        let pieces: [&[u8]; 3] = [self.thread_id.0.as_bytes(), b":", self.secret.as_bytes()];
        let mut hasher = Sha256::new();
        for &piece in pieces.iter() {
            hasher.update(piece);
        }
        URL_SAFE_NO_PAD.encode(hasher.finalize())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn invite_round_trip() {
        let invite = Invite::new(ThreadId::new());
        let parsed = Invite::parse(&invite.expose()).unwrap();
        assert_eq!(invite.thread_id, parsed.thread_id);
        assert_eq!(invite.secret, parsed.secret);
        assert_eq!(invite.verifier(), parsed.verifier());
    }

    #[test]
    fn rejects_low_entropy_invite() {
        // Use a genuinely valid thread id so the secret check is what rejects the invite.
        // A malformed thread id would fail first and mask the behaviour under test.
        let thread_id = ThreadId::new();
        // "short" is not decodable base64 (5 chars). 22 'A's decode cleanly but to only 16 bytes.
        for secret in ["short".to_string(), "A".repeat(22)].iter() {
            let value = format!("f8s_invite_v1.{}.{}", thread_id.0, secret);
            assert!(
                matches!(Invite::parse(&value), Err(InviteError::BadSecret)),
                "{} should be rejected as BadSecret",
                value
            );
        }
    }

    #[test]
    fn rejects_malformed_thread_id() {
        let secret = Invite::new(ThreadId::new()).secret;
        // Valid base64, but decodes to 16 bytes instead of THREAD_BYTES.
        let short_id = format!("thd_{}", "A".repeat(22));
        for thread_part in ["thd_abc", "abc", short_id.as_str()].iter() {
            let value = format!("f8s_invite_v1.{}.{}", thread_part, secret);
            assert!(
                matches!(Invite::parse(&value), Err(InviteError::BadThreadId)),
                "{} should be rejected as BadThreadId",
                value
            );
        }
    }

    #[test]
    fn rejects_bad_format() {
        let exposed = Invite::new(ThreadId::new()).expose();
        let extra_segment = format!("{}.extra", exposed);
        let wrong_version = exposed.replacen("f8s_invite_v1", "f8s_invite_v2", 1);
        let cases = [
            "",
            "f8s_invite_v1",
            "f8s_invite_v1.thd_x",
            wrong_version.as_str(),
            extra_segment.as_str(),
        ];
        for value in cases.iter() {
            assert!(
                matches!(Invite::parse(value), Err(InviteError::BadFormat)),
                "{:?} should be rejected as BadFormat",
                value
            );
        }
    }
}