astrid_core/
session_token.rs1use std::fmt;
12use std::io;
13use std::path::Path;
14
15use rand::RngCore;
16use serde::{Deserialize, Serialize};
17use subtle::ConstantTimeEq;
18
/// Version of the session handshake protocol implemented by this build.
///
/// Every [`HandshakeResponse`] built by [`HandshakeResponse::ok`] or
/// [`HandshakeResponse::error`] carries this value in `protocol_version`.
pub const PROTOCOL_VERSION: u8 = 1;

/// A 32-byte secret session token.
///
/// Create one with [`SessionToken::generate`] or decode one with
/// [`SessionToken::from_hex`]. The raw bytes are deliberately not exposed:
/// `Debug` prints `SessionToken([REDACTED])`, and tokens should be compared
/// with [`SessionToken::ct_eq`] (constant time) rather than byte-by-byte.
pub struct SessionToken([u8; 32]);

26impl SessionToken {
27 #[must_use]
29 pub fn generate() -> Self {
30 let mut bytes = [0u8; 32];
31 rand::rngs::OsRng.fill_bytes(&mut bytes);
32 Self(bytes)
33 }
34
35 #[must_use]
37 pub fn to_hex(&self) -> String {
38 let mut hex = String::with_capacity(64);
39 for byte in &self.0 {
40 use fmt::Write;
41 let _ = write!(hex, "{byte:02x}");
42 }
43 hex
44 }
45
46 pub fn from_hex(hex: &str) -> Result<Self, io::Error> {
53 if hex.len() != 64 {
54 return Err(io::Error::new(
55 io::ErrorKind::InvalidData,
56 format!("session token hex must be 64 chars, got {}", hex.len()),
57 ));
58 }
59 let mut bytes = [0u8; 32];
60 for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
61 let hi = hex_digit(chunk[0])?;
62 let lo = hex_digit(chunk[1])?;
63 bytes[i] = (hi << 4) | lo;
64 }
65 Ok(Self(bytes))
66 }
67
68 pub fn write_to_file(&self, path: &Path) -> io::Result<()> {
80 let hex = self.to_hex();
81
82 #[cfg(unix)]
83 {
84 use io::Write;
85 use std::os::unix::fs::OpenOptionsExt;
86
87 let tmp_path = path.with_extension(format!("{}.tmp", std::process::id()));
88 let mut f = std::fs::OpenOptions::new()
89 .write(true)
90 .create(true)
91 .truncate(true)
92 .mode(0o600)
93 .open(&tmp_path)?;
94 f.write_all(hex.as_bytes())?;
95 f.sync_all()?;
96 drop(f);
97
98 if let Err(e) = std::fs::rename(&tmp_path, path) {
101 let _ = std::fs::remove_file(&tmp_path);
102 return Err(e);
103 }
104 }
105
106 #[cfg(not(unix))]
111 {
112 std::fs::write(path, hex.as_bytes())?;
113 }
114
115 Ok(())
116 }
117
118 pub fn read_from_file(path: &Path) -> io::Result<Self> {
124 let contents = std::fs::read_to_string(path)?;
125 Self::from_hex(contents.trim())
126 }
127
128 #[must_use]
130 pub fn ct_eq(&self, other: &Self) -> bool {
131 self.0.ct_eq(&other.0).into()
132 }
133}
134
135impl fmt::Debug for SessionToken {
136 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137 f.write_str("SessionToken([REDACTED])")
138 }
139}
140
/// Converts one ASCII hex digit (`0-9`, `a-f`, `A-F`) to its value in `0..=15`.
///
/// # Errors
///
/// Returns an [`io::ErrorKind::InvalidData`] error for any other byte; the
/// message shows the offending byte in `0x..` form.
fn hex_digit(byte: u8) -> io::Result<u8> {
    // `to_digit(16)` accepts exactly the same characters as the hex ranges above
    // and yields a value below 16, so the narrowing to `u8` always succeeds.
    char::from(byte)
        .to_digit(16)
        .and_then(|value| u8::try_from(value).ok())
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                format!("invalid hex digit: {byte:#04x}"),
            )
        })
}

157#[derive(Debug, Serialize, Deserialize)]
159pub struct HandshakeRequest {
160 pub token: String,
162 pub protocol_version: u8,
164 pub client_version: String,
166}
167
168#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "lowercase")]
172pub enum HandshakeStatus {
173 Ok,
175 Error,
177}
178
179#[derive(Debug, Serialize, Deserialize)]
181pub struct HandshakeResponse {
182 pub status: HandshakeStatus,
184 pub protocol_version: u8,
186 pub server_version: String,
188 #[serde(skip_serializing_if = "Option::is_none")]
190 pub reason: Option<String>,
191}
192
193impl HandshakeResponse {
194 #[must_use]
196 pub fn ok() -> Self {
197 Self {
198 status: HandshakeStatus::Ok,
199 protocol_version: PROTOCOL_VERSION,
200 server_version: env!("CARGO_PKG_VERSION").to_string(),
201 reason: None,
202 }
203 }
204
205 #[must_use]
207 pub fn error(reason: impl Into<String>) -> Self {
208 Self {
209 status: HandshakeStatus::Error,
210 protocol_version: PROTOCOL_VERSION,
211 server_version: env!("CARGO_PKG_VERSION").to_string(),
212 reason: Some(reason.into()),
213 }
214 }
215
216 #[must_use]
218 pub fn is_ok(&self) -> bool {
219 self.status == HandshakeStatus::Ok
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generate_produces_unique_tokens() {
        let a = SessionToken::generate();
        let b = SessionToken::generate();
        assert!(!a.ct_eq(&b), "two generated tokens must differ");
    }

    #[test]
    fn hex_round_trip() {
        let token = SessionToken::generate();
        let hex = token.to_hex();
        assert_eq!(hex.len(), 64);
        let decoded = SessionToken::from_hex(&hex).expect("valid hex");
        assert!(token.ct_eq(&decoded));
    }

    #[test]
    fn to_hex_is_lowercase() {
        let token = SessionToken::from_hex(&"AB".repeat(32)).expect("valid hex");
        assert_eq!(token.to_hex(), "ab".repeat(32));
    }

    #[test]
    fn from_hex_accepts_uppercase() {
        let token = SessionToken::generate();
        let upper = token.to_hex().to_uppercase();
        let decoded = SessionToken::from_hex(&upper).expect("uppercase hex is valid");
        assert!(token.ct_eq(&decoded));
    }

    #[test]
    fn from_hex_rejects_wrong_length() {
        let err = SessionToken::from_hex("abcd").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("64 chars"));
    }

    #[test]
    fn from_hex_rejects_invalid_chars() {
        let bad = "zz".repeat(32);
        let err = SessionToken::from_hex(&bad).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("invalid hex digit"));
    }

    #[test]
    fn constant_time_eq_matches() {
        let token = SessionToken::generate();
        let same = SessionToken::from_hex(&token.to_hex()).expect("valid");
        assert!(token.ct_eq(&same));
    }

    #[test]
    fn constant_time_eq_rejects_different() {
        let a = SessionToken::generate();
        let b = SessionToken::generate();
        assert!(!a.ct_eq(&b));
    }

    #[test]
    fn file_round_trip() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("test.token");

        let token = SessionToken::generate();
        token.write_to_file(&path).expect("write");

        let loaded = SessionToken::read_from_file(&path).expect("read");
        assert!(token.ct_eq(&loaded));

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let perms = std::fs::metadata(&path).expect("metadata").permissions();
            assert_eq!(perms.mode() & 0o777, 0o600);
        }
    }

    #[test]
    fn write_to_file_overwrites_and_leaves_no_temp_files() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("test.token");

        SessionToken::generate()
            .write_to_file(&path)
            .expect("first write");
        let second = SessionToken::generate();
        second.write_to_file(&path).expect("second write");

        let loaded = SessionToken::read_from_file(&path).expect("read");
        assert!(second.ct_eq(&loaded));

        // Only the final token file may remain in the directory.
        let names: Vec<_> = std::fs::read_dir(dir.path())
            .expect("read_dir")
            .map(|entry| entry.expect("dir entry").file_name())
            .collect();
        assert_eq!(names, vec![std::ffi::OsString::from("test.token")]);
    }

    #[test]
    fn read_from_file_ignores_surrounding_whitespace() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("padded.token");

        let token = SessionToken::generate();
        std::fs::write(&path, format!("  {}\n", token.to_hex())).expect("write");

        let loaded = SessionToken::read_from_file(&path).expect("read");
        assert!(token.ct_eq(&loaded));
    }

    #[test]
    fn debug_redacts_token() {
        let token = SessionToken::generate();
        let debug = format!("{token:?}");
        assert_eq!(debug, "SessionToken([REDACTED])");
        assert!(!debug.contains(&token.to_hex()));
    }

    #[test]
    fn handshake_request_round_trips_through_json() {
        let req = HandshakeRequest {
            token: "ab".repeat(32),
            protocol_version: PROTOCOL_VERSION,
            client_version: "1.2.3".to_string(),
        };
        let json = serde_json::to_string(&req).expect("serialize");
        let back: HandshakeRequest = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.token, req.token);
        assert_eq!(back.protocol_version, req.protocol_version);
        assert_eq!(back.client_version, req.client_version);
    }

    #[test]
    fn handshake_response_ok_serializes() {
        let resp = HandshakeResponse::ok();
        assert_eq!(resp.status, HandshakeStatus::Ok);
        assert!(resp.is_ok());
        assert_eq!(resp.protocol_version, PROTOCOL_VERSION);
        assert!(resp.reason.is_none());

        let json = serde_json::to_value(&resp).expect("serialize");
        assert_eq!(json["status"], "ok");
        assert!(json.get("reason").is_none(), "reason should be skipped");
    }

    #[test]
    fn handshake_response_error_serializes() {
        let resp = HandshakeResponse::error("bad token");
        assert_eq!(resp.status, HandshakeStatus::Error);
        assert!(!resp.is_ok());
        assert_eq!(resp.reason.as_deref(), Some("bad token"));

        let json = serde_json::to_value(&resp).expect("serialize");
        assert_eq!(json["status"], "error");
        assert_eq!(json["reason"], "bad token");
    }

    #[test]
    fn handshake_response_deserializes_without_reason() {
        let json = format!(
            r#"{{"status":"ok","protocol_version":{PROTOCOL_VERSION},"server_version":"x"}}"#
        );
        let resp: HandshakeResponse = serde_json::from_str(&json).expect("deserialize");
        assert!(resp.is_ok());
        assert!(resp.reason.is_none());
    }
}