astrid_kernel/
pair_token.rs1use std::path::PathBuf;
34use std::time::{SystemTime, UNIX_EPOCH};
35
36use astrid_core::PrincipalId;
37use astrid_core::dirs::AstridHome;
38use base64::Engine;
39use rand::RngCore;
40use serde::{Deserialize, Serialize};
41use sha2::{Digest, Sha256};
42use subtle::ConstantTimeEq;
43
/// Number of random bytes drawn for each freshly generated pairing token.
///
/// 24 bytes (192 bits) encode to exactly 32 characters of unpadded
/// URL-safe base64, which is what [`generate_token`] returns.
pub const TOKEN_RAW_LEN: usize = 24;

/// Maximum pair-token lifetime in seconds (one hour).
///
/// NOTE(review): not enforced anywhere in this file; presumably checked by
/// the code that issues tokens — confirm.
pub const MAX_EXPIRY_SECS: u64 = 3_600;
52
53#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
56pub struct PairToken {
57 pub token_hash: String,
59 pub principal: PrincipalId,
61 pub expires_at_epoch: u64,
63 pub issued_at_epoch: u64,
65 #[serde(default, skip_serializing_if = "Option::is_none")]
68 pub label: Option<String>,
69}
70
/// File-backed store of [`PairToken`]s.
///
/// Holds only the path to the TOML file; every `load`/`save` goes straight
/// to disk, with no in-memory caching.
#[derive(Debug)]
pub struct PairTokenStore {
    /// Location of the TOML token file.
    path: PathBuf,
}
78
79impl PairTokenStore {
80 #[must_use]
82 pub const fn new(path: PathBuf) -> Self {
83 Self { path }
84 }
85
86 #[must_use]
88 pub fn path_for(home: &AstridHome) -> PathBuf {
89 home.etc_dir().join("pair-tokens.toml")
90 }
91
92 pub fn load(&self) -> Result<Vec<PairToken>, PairTokenStoreError> {
98 let bytes = match std::fs::read(&self.path) {
99 Ok(b) => b,
100 Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
101 Err(e) => return Err(PairTokenStoreError::Io(e)),
102 };
103 let text = std::str::from_utf8(&bytes).map_err(|e| {
104 PairTokenStoreError::Io(std::io::Error::new(std::io::ErrorKind::InvalidData, e))
105 })?;
106 if text.trim().is_empty() {
107 return Ok(Vec::new());
108 }
109 let parsed: PersistedFile = toml::from_str(text).map_err(PairTokenStoreError::Toml)?;
110 Ok(parsed.pair_token)
111 }
112
113 pub fn save(&self, tokens: &[PairToken]) -> Result<(), PairTokenStoreError> {
119 if let Some(parent) = self.path.parent() {
120 std::fs::create_dir_all(parent).map_err(PairTokenStoreError::Io)?;
121 }
122 let body = PersistedFile {
123 pair_token: tokens.to_vec(),
124 };
125 let text = toml::to_string_pretty(&body).map_err(PairTokenStoreError::TomlSer)?;
126
127 #[cfg(unix)]
128 {
129 use std::io::Write;
130 use std::os::unix::fs::OpenOptionsExt;
131 let tmp_path = self
132 .path
133 .with_extension(format!("{}.tmp", std::process::id()));
134 let mut f = std::fs::OpenOptions::new()
135 .write(true)
136 .create(true)
137 .truncate(true)
138 .mode(0o600)
139 .open(&tmp_path)
140 .map_err(PairTokenStoreError::Io)?;
141 f.write_all(text.as_bytes())
142 .map_err(PairTokenStoreError::Io)?;
143 f.sync_all().map_err(PairTokenStoreError::Io)?;
144 drop(f);
145 if let Err(e) = std::fs::rename(&tmp_path, &self.path) {
146 let _ = std::fs::remove_file(&tmp_path);
147 return Err(PairTokenStoreError::Io(e));
148 }
149 }
150 #[cfg(not(unix))]
151 {
152 std::fs::write(&self.path, text.as_bytes()).map_err(PairTokenStoreError::Io)?;
153 }
154 Ok(())
155 }
156}
157
158#[derive(Debug)]
160pub enum PairTokenStoreError {
161 Io(std::io::Error),
163 Toml(toml::de::Error),
165 TomlSer(toml::ser::Error),
167}
168
169impl std::fmt::Display for PairTokenStoreError {
170 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171 match self {
172 Self::Io(e) => write!(f, "pair-token store io: {e}"),
173 Self::Toml(e) => write!(f, "pair-token store parse: {e}"),
174 Self::TomlSer(e) => write!(f, "pair-token store serialise: {e}"),
175 }
176 }
177}
178
179impl std::error::Error for PairTokenStoreError {}
180
181#[derive(Debug, Default, Serialize, Deserialize)]
182struct PersistedFile {
183 #[serde(default)]
184 pair_token: Vec<PairToken>,
185}
186
187#[must_use]
189pub fn generate_token() -> String {
190 let mut bytes = [0u8; TOKEN_RAW_LEN];
191 rand::rngs::OsRng.fill_bytes(&mut bytes);
192 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
193}
194
195#[must_use]
197pub fn hash_token(token: &str) -> String {
198 let mut hasher = Sha256::new();
199 hasher.update(token.as_bytes());
200 hex::encode(hasher.finalize())
201}
202
203#[must_use]
205pub fn ct_hash_eq(a: &str, b: &str) -> bool {
206 if a.len() != b.len() {
207 return false;
208 }
209 a.as_bytes().ct_eq(b.as_bytes()).into()
210}
211
/// Current wall-clock time in whole seconds since the Unix epoch.
///
/// A system clock set before 1970-01-01 yields `0` rather than a panic.
#[must_use]
pub fn now_epoch() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
219
220pub fn prune_expired(tokens: &mut Vec<PairToken>) -> usize {
222 let now = now_epoch();
223 let before = tokens.len();
224 tokens.retain(|t| t.expires_at_epoch > now);
225 before.saturating_sub(tokens.len())
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Token fixture with a default principal and no label.
    fn token(hash: &str, issued_at_epoch: u64, expires_at_epoch: u64) -> PairToken {
        PairToken {
            token_hash: hash.into(),
            principal: PrincipalId::default(),
            expires_at_epoch,
            issued_at_epoch,
            label: None,
        }
    }

    #[test]
    fn token_is_random_and_short() {
        let a = generate_token();
        let b = generate_token();
        assert_ne!(a, b);
        // 24 random bytes -> 32 characters of unpadded base64.
        assert_eq!(a.len(), 32);
        assert!(a
            .bytes()
            .all(|c| c.is_ascii_alphanumeric() || c == b'-' || c == b'_'));
    }

    #[test]
    fn hash_is_deterministic_hex() {
        let h = hash_token("hello");
        assert_eq!(h.len(), 64);
        assert_eq!(h, hash_token("hello"));
        assert_ne!(h, hash_token("world"));
        // Well-known SHA-256 test vector.
        assert_eq!(
            h,
            "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
        );
    }

    #[test]
    fn ct_hash_eq_matches_plain_equality() {
        let h = hash_token("hello");
        assert!(ct_hash_eq(&h, &hash_token("hello")));
        assert!(!ct_hash_eq(&h, &hash_token("world")));
        assert!(!ct_hash_eq(&h, &h[..63]));
        assert!(ct_hash_eq("", ""));
    }

    #[test]
    fn round_trip_save_load() {
        let dir = tempfile::tempdir().unwrap();
        let store = PairTokenStore::new(dir.path().join("pair-tokens.toml"));
        let token = PairToken {
            token_hash: "abc".into(),
            principal: PrincipalId::new("alice").unwrap(),
            expires_at_epoch: 9_999_999_999,
            issued_at_epoch: 1,
            label: Some("phone".into()),
        };
        store.save(&[token.clone()]).unwrap();
        let loaded = store.load().unwrap();
        assert_eq!(loaded, vec![token]);
    }

    #[test]
    fn load_missing_file_is_empty() {
        let dir = tempfile::tempdir().unwrap();
        let store = PairTokenStore::new(dir.path().join("absent.toml"));
        assert!(store.load().unwrap().is_empty());
    }

    #[test]
    fn load_blank_file_is_empty() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("pair-tokens.toml");
        std::fs::write(&path, " \n\t\n").unwrap();
        assert!(PairTokenStore::new(path).load().unwrap().is_empty());
    }

    #[test]
    fn save_creates_parent_and_leaves_no_temp_files() {
        let dir = tempfile::tempdir().unwrap();
        let nested = dir.path().join("etc");
        let store = PairTokenStore::new(nested.join("pair-tokens.toml"));
        store.save(&[]).unwrap();
        store.save(&[]).unwrap();
        let names: Vec<_> = std::fs::read_dir(&nested)
            .unwrap()
            .map(|entry| entry.unwrap().file_name())
            .collect();
        assert_eq!(names, vec![std::ffi::OsString::from("pair-tokens.toml")]);
        assert!(store.load().unwrap().is_empty());
    }

    #[cfg(unix)]
    #[test]
    fn saved_file_is_owner_only() {
        use std::os::unix::fs::PermissionsExt;

        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("pair-tokens.toml");
        PairTokenStore::new(path.clone()).save(&[]).unwrap();
        let mode = std::fs::metadata(&path).unwrap().permissions().mode();
        assert_eq!(mode & 0o777, 0o600);
    }

    #[test]
    fn prune_drops_expired() {
        let now = now_epoch();
        let mut v = vec![
            token("a", now, now.saturating_add(60)),
            token("b", now.saturating_sub(120), now.saturating_sub(60)),
        ];
        assert_eq!(prune_expired(&mut v), 1);
        assert_eq!(v.len(), 1);
        assert_eq!(v[0].token_hash, "a");
    }

    #[test]
    fn prune_treats_expiry_second_as_expired() {
        let now = now_epoch();
        let mut v = vec![token("edge", now.saturating_sub(10), now)];
        assert_eq!(prune_expired(&mut v), 1);
        assert!(v.is_empty());
        assert_eq!(prune_expired(&mut Vec::new()), 0);
    }
}