1use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
2use rand::{RngCore, rngs::OsRng};
3use serde::{Deserialize, Serialize};
4use sha2::{Digest, Sha256};
5use thiserror::Error;
6
/// Number of random bytes drawn for a thread id (160 bits) before base64 encoding.
const THREAD_BYTES: usize = 20;
/// Number of random bytes drawn for an invite secret (256 bits) before base64 encoding.
const INVITE_SECRET_BYTES: usize = 32;
9
/// Errors returned when parsing an invite string or a thread id.
#[derive(Debug, Error)]
pub enum InviteError {
    /// The invite is not exactly `f8s_invite_v1.<thread_id>.<secret>`:
    /// the prefix is wrong, or there are not exactly three `.`-separated parts.
    #[error("invite must have format f8s_invite_v1.<thread_id>.<secret>")]
    BadFormat,
    /// The thread id lacks the `thd_` prefix, is not unpadded URL-safe base64,
    /// or does not decode to the expected number of bytes.
    #[error("thread id is invalid")]
    BadThreadId,
    /// The secret is not unpadded URL-safe base64 or does not decode to the
    /// expected number of bytes.
    #[error("invite secret is invalid")]
    BadSecret,
}
19
/// Identifier of a thread, kept in its string form `thd_<base64url>`.
///
/// `ThreadId::new` produces, and `ThreadId::parse` accepts, `thd_` followed by
/// unpadded URL-safe base64 of `THREAD_BYTES` bytes. Because the inner field is
/// public and `Deserialize` is derived, a `ThreadId` can also be built without
/// going through that validation.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ThreadId(pub String);
22
23#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
24pub struct Invite {
25 pub thread_id: ThreadId,
26 pub secret: String,
27}
28
29impl ThreadId {
30 pub fn new() -> Self {
31 let mut bytes = [0_u8; THREAD_BYTES];
32 OsRng.fill_bytes(&mut bytes);
33 Self(format!("thd_{}", URL_SAFE_NO_PAD.encode(bytes)))
34 }
35
36 pub fn parse(value: &str) -> Result<Self, InviteError> {
37 let encoded = value.strip_prefix("thd_").ok_or(InviteError::BadThreadId)?;
38 let bytes = URL_SAFE_NO_PAD
39 .decode(encoded)
40 .map_err(|_| InviteError::BadThreadId)?;
41 if bytes.len() != THREAD_BYTES {
42 return Err(InviteError::BadThreadId);
43 }
44 Ok(Self(value.to_string()))
45 }
46}
47
48impl Default for ThreadId {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl Invite {
55 pub fn new(thread_id: ThreadId) -> Self {
56 let mut bytes = [0_u8; INVITE_SECRET_BYTES];
57 OsRng.fill_bytes(&mut bytes);
58 Self {
59 thread_id,
60 secret: URL_SAFE_NO_PAD.encode(bytes),
61 }
62 }
63
64 pub fn parse(value: &str) -> Result<Self, InviteError> {
65 let mut parts = value.split('.');
66 if parts.next() != Some("f8s_invite_v1") {
67 return Err(InviteError::BadFormat);
68 }
69 let thread_id = parts.next().ok_or(InviteError::BadFormat)?;
70 let secret = parts.next().ok_or(InviteError::BadFormat)?;
71 if parts.next().is_some() {
72 return Err(InviteError::BadFormat);
73 }
74
75 let thread_id = ThreadId::parse(thread_id)?;
76 let bytes = URL_SAFE_NO_PAD
77 .decode(secret)
78 .map_err(|_| InviteError::BadSecret)?;
79 if bytes.len() != INVITE_SECRET_BYTES {
80 return Err(InviteError::BadSecret);
81 }
82
83 Ok(Self {
84 thread_id,
85 secret: secret.to_string(),
86 })
87 }
88
89 pub fn expose(&self) -> String {
90 format!("f8s_invite_v1.{}.{}", self.thread_id.0, self.secret)
91 }
92
93 pub fn verifier(&self) -> String {
94 let mut hasher = Sha256::new();
95 hasher.update(self.thread_id.0.as_bytes());
96 hasher.update(b":");
97 hasher.update(self.secret.as_bytes());
98 URL_SAFE_NO_PAD.encode(hasher.finalize())
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Assembles an invite string from parts without going through `Invite::expose`.
    fn invite_string(thread_id: &ThreadId, secret: &str) -> String {
        format!("f8s_invite_v1.{}.{}", thread_id.0, secret)
    }

    #[test]
    fn invite_round_trip() {
        let invite = Invite::new(ThreadId::new());
        let parsed = Invite::parse(&invite.expose()).unwrap();
        assert_eq!(invite.thread_id, parsed.thread_id);
        assert_eq!(invite.secret, parsed.secret);
        assert_eq!(invite.verifier(), parsed.verifier());
    }

    #[test]
    fn rejects_low_entropy_invite() {
        // The thread id must be well-formed here: with an invalid one (e.g.
        // `thd_abc`) parsing fails on the thread id and never checks the secret.
        let thread_id = ThreadId::new();

        // 42 base64 chars decode to 31 bytes: valid encoding, one byte short.
        let short_secret = "A".repeat(42);
        assert!(matches!(
            Invite::parse(&invite_string(&thread_id, &short_secret)),
            Err(InviteError::BadSecret)
        ));

        // Too short to even be a valid unpadded base64 length.
        assert!(matches!(
            Invite::parse(&invite_string(&thread_id, "short")),
            Err(InviteError::BadSecret)
        ));
    }

    #[test]
    fn accepts_secret_of_exact_length() {
        let thread_id = ThreadId::new();
        // 43 base64 chars decode to exactly 32 bytes.
        let secret = "A".repeat(43);
        let parsed = Invite::parse(&invite_string(&thread_id, &secret)).unwrap();
        assert_eq!(parsed.thread_id, thread_id);
        assert_eq!(parsed.secret, secret);
    }

    #[test]
    fn rejects_invalid_thread_id() {
        let secret = "A".repeat(43);
        for bad_thread in ["thd_abc", "abc", "", &"A".repeat(27)] {
            let value = format!("f8s_invite_v1.{bad_thread}.{secret}");
            assert!(
                matches!(Invite::parse(&value), Err(InviteError::BadThreadId)),
                "accepted thread id {bad_thread:?}"
            );
        }
    }

    #[test]
    fn rejects_malformed_envelope() {
        let exposed = Invite::new(ThreadId::new()).expose();
        let without_secret = exposed.rsplit_once('.').unwrap().0.to_string();
        let cases = [
            String::new(),
            exposed.replacen("f8s_invite_v1", "f8s_invite_v2", 1),
            format!("{exposed}.extra"),
            without_secret,
        ];
        for case in &cases {
            assert!(
                matches!(Invite::parse(case), Err(InviteError::BadFormat)),
                "accepted malformed invite {case:?}"
            );
        }
    }

    #[test]
    fn verifier_depends_on_secret() {
        let thread_id = ThreadId::new();
        let first = Invite::new(thread_id.clone());
        let second = Invite::new(thread_id);
        assert_ne!(first.verifier(), second.verifier());
    }
}