1use std::path::PathBuf;
43use std::time::{SystemTime, UNIX_EPOCH};
44
45use astrid_core::dirs::AstridHome;
46use base64::Engine;
47use serde::{Deserialize, Serialize};
48use sha2::{Digest, Sha256};
49use subtle::ConstantTimeEq;
50
/// Number of random bytes drawn for each invite token by [`generate_token`].
/// Twenty-four bytes encode to exactly 32 unpadded base64 characters.
pub const TOKEN_RAW_LEN: usize = 24;

/// Thirty days expressed in seconds.
///
/// NOTE(review): presumably the upper bound on how far in the future an invite
/// may expire, enforced by the issuing code — confirm where it is checked.
pub const MAX_EXPIRY_SECS: u64 = 30 * 24 * 60 * 60;
61
62#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
65pub struct Invite {
66 pub token_hash: String,
68 pub group: String,
70 pub remaining_uses: u32,
72 #[serde(default, skip_serializing_if = "Option::is_none")]
75 pub expires_at_epoch: Option<u64>,
76 pub issued_at_epoch: u64,
78 #[serde(default, skip_serializing_if = "Option::is_none")]
80 pub metadata: Option<String>,
81}
82
83#[derive(Debug)]
87pub struct InviteStore {
88 path: PathBuf,
89}
90
91impl InviteStore {
92 #[must_use]
95 pub const fn new(path: PathBuf) -> Self {
96 Self { path }
97 }
98
99 #[must_use]
101 pub fn path_for(home: &AstridHome) -> PathBuf {
102 home.etc_dir().join("invites.toml")
103 }
104
105 pub fn load(&self) -> Result<Vec<Invite>, InviteStoreError> {
111 let bytes = match std::fs::read(&self.path) {
112 Ok(b) => b,
113 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
114 Err(e) => return Err(InviteStoreError::Io(e)),
115 };
116 let text = std::str::from_utf8(&bytes).map_err(|e| {
117 InviteStoreError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
118 })?;
119 if text.trim().is_empty() {
120 return Ok(Vec::new());
121 }
122 let parsed: PersistedFile = toml::from_str(text).map_err(InviteStoreError::Toml)?;
123 Ok(parsed.invite)
124 }
125
126 pub fn save(&self, invites: &[Invite]) -> Result<(), InviteStoreError> {
134 if let Some(parent) = self.path.parent() {
135 std::fs::create_dir_all(parent).map_err(InviteStoreError::Io)?;
136 }
137 let body = PersistedFile {
138 invite: invites.to_vec(),
139 };
140 let text = toml::to_string_pretty(&body).map_err(InviteStoreError::TomlSer)?;
141
142 #[cfg(unix)]
143 {
144 use std::io::Write;
145 use std::os::unix::fs::OpenOptionsExt;
146 let tmp_path = self
147 .path
148 .with_extension(format!("{}.tmp", std::process::id()));
149 let mut f = std::fs::OpenOptions::new()
150 .write(true)
151 .create(true)
152 .truncate(true)
153 .mode(0o600)
154 .open(&tmp_path)
155 .map_err(InviteStoreError::Io)?;
156 f.write_all(text.as_bytes()).map_err(InviteStoreError::Io)?;
157 f.sync_all().map_err(InviteStoreError::Io)?;
158 drop(f);
159 if let Err(e) = std::fs::rename(&tmp_path, &self.path) {
160 let _ = std::fs::remove_file(&tmp_path);
161 return Err(InviteStoreError::Io(e));
162 }
163 }
164 #[cfg(not(unix))]
165 {
166 std::fs::write(&self.path, text.as_bytes()).map_err(InviteStoreError::Io)?;
167 }
168 Ok(())
169 }
170}
171
172#[derive(Debug)]
174pub enum InviteStoreError {
175 Io(std::io::Error),
177 Toml(toml::de::Error),
179 TomlSer(toml::ser::Error),
181}
182
183impl std::fmt::Display for InviteStoreError {
184 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
185 match self {
186 Self::Io(e) => write!(f, "invite store io: {e}"),
187 Self::Toml(e) => write!(f, "invite store parse: {e}"),
188 Self::TomlSer(e) => write!(f, "invite store serialise: {e}"),
189 }
190 }
191}
192
193impl std::error::Error for InviteStoreError {}
194
195#[derive(Debug, Default, Serialize, Deserialize)]
196struct PersistedFile {
197 #[serde(default)]
198 invite: Vec<Invite>,
199}
200
201#[must_use]
203pub fn generate_token() -> String {
204 use rand::RngCore;
205 let mut bytes = [0u8; TOKEN_RAW_LEN];
206 rand::rngs::OsRng.fill_bytes(&mut bytes);
207 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
208}
209
210#[must_use]
212pub fn hash_token(token: &str) -> String {
213 let mut hasher = Sha256::new();
214 hasher.update(token.as_bytes());
215 hex::encode(hasher.finalize())
216}
217
218#[must_use]
222pub fn ct_hash_eq(a: &str, b: &str) -> bool {
223 if a.len() != b.len() {
224 return false;
225 }
226 a.as_bytes().ct_eq(b.as_bytes()).into()
227}
228
/// Current wall-clock time in whole seconds since the UNIX epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
#[must_use]
pub fn now_epoch() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
237
238pub fn prune_expired(invites: &mut Vec<Invite>) -> usize {
242 let now = now_epoch();
243 let before = invites.len();
244 invites.retain(|i| {
245 if i.remaining_uses == 0 {
246 return false;
247 }
248 i.expires_at_epoch.is_none_or(|exp| exp > now)
249 });
250 before.saturating_sub(invites.len())
251}
252
253pub fn prune_file(store: &InviteStore) -> Result<usize, InviteStoreError> {
259 let mut invites = store.load()?;
260 let removed = prune_expired(&mut invites);
261 if removed > 0 {
262 store.save(&invites)?;
263 }
264 Ok(removed)
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an invite in the `agent` group with the given hash, remaining
    /// uses and expiry; other fields are zero/empty.
    fn agent_invite(hash: &str, uses: u32, expires: Option<u64>) -> Invite {
        Invite {
            token_hash: hash.into(),
            group: "agent".into(),
            remaining_uses: uses,
            expires_at_epoch: expires,
            issued_at_epoch: 0,
            metadata: None,
        }
    }

    /// Create a store pointing at `invites.toml` in a fresh temp directory.
    /// The `TempDir` guard is returned so the directory outlives the store.
    fn temp_store() -> (tempfile::TempDir, InviteStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = InviteStore::new(dir.path().join("invites.toml"));
        (dir, store)
    }

    #[test]
    fn token_round_trip_is_random_and_url_safe() {
        let first = generate_token();
        let second = generate_token();
        assert_ne!(first, second, "two tokens must differ");

        let url_safe = |c: char| c.is_ascii_alphanumeric() || matches!(c, '-' | '_');
        assert!(first.chars().all(url_safe));
        // 24 raw bytes encode to 32 unpadded base64 characters.
        assert_eq!(first.len(), 32);
    }

    #[test]
    fn hash_token_is_deterministic_hex_sha256() {
        let digest = hash_token("hello");
        assert_eq!(digest.len(), 64);
        assert!(digest.bytes().all(|b| b.is_ascii_hexdigit()));
        assert_eq!(digest, hash_token("hello"));
        assert_ne!(digest, hash_token("world"));
    }

    #[test]
    fn ct_hash_eq_rejects_length_mismatch() {
        assert!(!ct_hash_eq("abc", "abcd"));
        assert!(ct_hash_eq("abc", "abc"));
        assert!(!ct_hash_eq("abc", "abd"));
    }

    #[test]
    fn prune_removes_expired_and_consumed() {
        let now = now_epoch();
        let mut invites = vec![
            agent_invite("a", 1, Some(now.saturating_add(60))),
            agent_invite("b", 0, None),
            agent_invite("c", 1, Some(now.saturating_sub(60))),
        ];
        assert_eq!(prune_expired(&mut invites), 2);
        assert_eq!(invites.len(), 1);
        assert_eq!(invites[0].token_hash, "a");
    }

    #[test]
    fn save_round_trips() {
        let (_dir, store) = temp_store();
        let now = now_epoch();
        let invite = Invite {
            issued_at_epoch: now,
            metadata: Some("alice".into()),
            ..agent_invite("deadbeef", 2, Some(now.saturating_add(3600)))
        };
        store.save(std::slice::from_ref(&invite)).unwrap();
        assert_eq!(store.load().unwrap(), vec![invite]);
    }

    #[test]
    fn empty_file_loads_as_empty_vec() {
        let (_dir, store) = temp_store();
        // No file on disk yet.
        assert!(store.load().unwrap().is_empty());
        // File present but empty.
        std::fs::write(&store.path, "").unwrap();
        assert!(store.load().unwrap().is_empty());
    }

    #[cfg(unix)]
    #[test]
    fn save_writes_0600_perms() {
        use std::os::unix::fs::PermissionsExt;
        let (_dir, store) = temp_store();
        store.save(&[]).unwrap();
        let mode = std::fs::metadata(&store.path).unwrap().permissions().mode();
        assert_eq!(mode & 0o777, 0o600);
    }
}
373}