// astrid_core/session_token.rs
//! Session token for Unix socket authentication.
//!
//! The daemon generates a random 256-bit token at startup and writes it to
//! `~/.astrid/run/system.token` with 0o600 permissions. The CLI reads this
//! token and sends it (hex-encoded, inside a `HandshakeRequest`) as the first
//! message after connecting. The daemon validates the token with a
//! constant-time comparison and rejects any connection whose token does not
//! match.
//!
//! This follows the same pattern used by Jupyter notebooks and Docker.

11use std::fmt;
12use std::io;
13use std::path::Path;
14
15use rand::RngCore;
16use serde::{Deserialize, Serialize};
17use subtle::ConstantTimeEq;
18
/// Version of the wire protocol spoken over the daemon socket.
///
/// Increment this whenever the handshake or IPC message format changes in a
/// way that older peers cannot understand.
pub const PROTOCOL_VERSION: u8 = 0x01;
22
/// Secret 32-byte (256-bit) token used to authenticate socket clients.
///
/// The bytes are private; obtain a token via `SessionToken::generate`,
/// `SessionToken::from_hex`, or `SessionToken::read_from_file`.
pub struct SessionToken([u8; 256 / 8]);
25
26impl SessionToken {
27    /// Generate a new random session token from the OS CSPRNG.
28    #[must_use]
29    pub fn generate() -> Self {
30        let mut bytes = [0u8; 32];
31        rand::rngs::OsRng.fill_bytes(&mut bytes);
32        Self(bytes)
33    }
34
35    /// Hex-encode the token for file storage and wire transmission.
36    #[must_use]
37    pub fn to_hex(&self) -> String {
38        let mut hex = String::with_capacity(64);
39        for byte in &self.0 {
40            use fmt::Write;
41            let _ = write!(hex, "{byte:02x}");
42        }
43        hex
44    }
45
46    /// Decode a hex-encoded token string.
47    ///
48    /// # Errors
49    ///
50    /// Returns an error if the hex string is not exactly 64 characters or
51    /// contains invalid hex digits.
52    pub fn from_hex(hex: &str) -> Result<Self, io::Error> {
53        if hex.len() != 64 {
54            return Err(io::Error::new(
55                io::ErrorKind::InvalidData,
56                format!("session token hex must be 64 chars, got {}", hex.len()),
57            ));
58        }
59        let mut bytes = [0u8; 32];
60        for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
61            let hi = hex_digit(chunk[0])?;
62            let lo = hex_digit(chunk[1])?;
63            bytes[i] = (hi << 4) | lo;
64        }
65        Ok(Self(bytes))
66    }
67
68    /// Write the token to a file with owner-only permissions (0o600).
69    ///
70    /// On Unix, this uses write-then-rename atomicity: writes to a temporary
71    /// file at 0o600 (via `OpenOptions::mode` to avoid a TOCTOU permissions
72    /// window), then atomically renames it to the target path. This prevents
73    /// a racing `read_from_file` from seeing a truncated/empty file during
74    /// daemon restarts.
75    ///
76    /// # Errors
77    ///
78    /// Returns an error if the file cannot be written or permissions cannot be set.
79    pub fn write_to_file(&self, path: &Path) -> io::Result<()> {
80        let hex = self.to_hex();
81
82        #[cfg(unix)]
83        {
84            use io::Write;
85            use std::os::unix::fs::OpenOptionsExt;
86
87            let tmp_path = path.with_extension(format!("{}.tmp", std::process::id()));
88            let mut f = std::fs::OpenOptions::new()
89                .write(true)
90                .create(true)
91                .truncate(true)
92                .mode(0o600)
93                .open(&tmp_path)?;
94            f.write_all(hex.as_bytes())?;
95            f.sync_all()?;
96            drop(f);
97
98            // Atomic rename on the same filesystem. Clean up temp file on
99            // failure to avoid orphaned secret-containing files.
100            if let Err(e) = std::fs::rename(&tmp_path, path) {
101                let _ = std::fs::remove_file(&tmp_path);
102                return Err(e);
103            }
104        }
105
106        // Non-Unix fallback: no atomic rename, no explicit permissions.
107        // The token file will inherit the process umask (likely 0o644).
108        // Windows is not a supported daemon platform; this exists only
109        // for compilation and test compatibility.
110        #[cfg(not(unix))]
111        {
112            std::fs::write(path, hex.as_bytes())?;
113        }
114
115        Ok(())
116    }
117
118    /// Read and decode a token from a file.
119    ///
120    /// # Errors
121    ///
122    /// Returns an error if the file cannot be read or contains invalid hex.
123    pub fn read_from_file(path: &Path) -> io::Result<Self> {
124        let contents = std::fs::read_to_string(path)?;
125        Self::from_hex(contents.trim())
126    }
127
128    /// Constant-time comparison. Returns `true` if the tokens are equal.
129    #[must_use]
130    pub fn ct_eq(&self, other: &Self) -> bool {
131        self.0.ct_eq(&other.0).into()
132    }
133}
134
135impl fmt::Debug for SessionToken {
136    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
137        f.write_str("SessionToken([REDACTED])")
138    }
139}
140
/// Convert one ASCII hex digit (`0-9`, `a-f`, `A-F`) into its numeric value.
///
/// Any other byte, including non-ASCII bytes, produces an `InvalidData` error
/// that names the offending byte.
fn hex_digit(byte: u8) -> io::Result<u8> {
    // Fold `A-F` onto `a-f`; non-letters and non-ASCII bytes are unchanged.
    let folded = byte.to_ascii_lowercase();
    // The range patterns bound `folded`, so the wrapping ops never wrap.
    let value = match folded {
        b'0'..=b'9' => Some(folded.wrapping_sub(b'0')),
        b'a'..=b'f' => Some(folded.wrapping_sub(b'a').wrapping_add(10)),
        _ => None,
    };
    value.ok_or_else(|| {
        io::Error::new(
            io::ErrorKind::InvalidData,
            format!("invalid hex digit: {byte:#04x}"),
        )
    })
}
156
157/// First message sent by the CLI after connecting to the daemon socket.
158#[derive(Debug, Serialize, Deserialize)]
159pub struct HandshakeRequest {
160    /// Hex-encoded session token.
161    pub token: String,
162    /// Wire protocol version supported by this client.
163    pub protocol_version: u8,
164    /// Semantic version of the client binary (e.g. "0.1.1").
165    pub client_version: String,
166}
167
168/// Typed status for handshake responses. Using an enum instead of a raw
169/// string prevents typo-induced mismatches between client and server.
170#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
171#[serde(rename_all = "lowercase")]
172pub enum HandshakeStatus {
173    /// Handshake succeeded.
174    Ok,
175    /// Handshake failed.
176    Error,
177}
178
179/// Response sent by the daemon after validating the handshake.
180#[derive(Debug, Serialize, Deserialize)]
181pub struct HandshakeResponse {
182    /// Whether the handshake succeeded or failed.
183    pub status: HandshakeStatus,
184    /// Wire protocol version of the daemon.
185    pub protocol_version: u8,
186    /// Semantic version of the daemon binary.
187    pub server_version: String,
188    /// Human-readable reason for rejection (only set when status is `Error`).
189    #[serde(skip_serializing_if = "Option::is_none")]
190    pub reason: Option<String>,
191}
192
193impl HandshakeResponse {
194    /// Create a successful handshake response.
195    #[must_use]
196    pub fn ok() -> Self {
197        Self {
198            status: HandshakeStatus::Ok,
199            protocol_version: PROTOCOL_VERSION,
200            server_version: env!("CARGO_PKG_VERSION").to_string(),
201            reason: None,
202        }
203    }
204
205    /// Create an error handshake response.
206    #[must_use]
207    pub fn error(reason: impl Into<String>) -> Self {
208        Self {
209            status: HandshakeStatus::Error,
210            protocol_version: PROTOCOL_VERSION,
211            server_version: env!("CARGO_PKG_VERSION").to_string(),
212            reason: Some(reason.into()),
213        }
214    }
215
216    /// Returns `true` if the handshake succeeded.
217    #[must_use]
218    pub fn is_ok(&self) -> bool {
219        self.status == HandshakeStatus::Ok
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generate_produces_unique_tokens() {
        let a = SessionToken::generate();
        let b = SessionToken::generate();
        assert!(!a.ct_eq(&b), "two generated tokens must differ");
    }

    #[test]
    fn hex_round_trip() {
        let token = SessionToken::generate();
        let hex = token.to_hex();
        assert_eq!(hex.len(), 64);
        assert!(
            hex.bytes().all(|b| b.is_ascii_digit() || (b'a'..=b'f').contains(&b)),
            "to_hex must emit lowercase hex only"
        );
        let decoded = SessionToken::from_hex(&hex).expect("valid hex");
        assert!(token.ct_eq(&decoded));
    }

    #[test]
    fn from_hex_decodes_known_bytes() {
        let hex = format!("00ff10{}", "ab".repeat(29));
        let token = SessionToken::from_hex(&hex).expect("valid hex");
        assert_eq!(token.0[..3], [0x00_u8, 0xff, 0x10]);
        assert!(token.0[3..].iter().all(|&b| b == 0xab));
        assert_eq!(token.to_hex(), hex);
    }

    #[test]
    fn from_hex_accepts_uppercase() {
        let token = SessionToken::generate();
        let upper = token.to_hex().to_ascii_uppercase();
        let decoded = SessionToken::from_hex(&upper).expect("uppercase hex is valid");
        assert!(token.ct_eq(&decoded));
    }

    #[test]
    fn from_hex_rejects_wrong_length() {
        let cases = [String::new(), "abcd".to_string(), "a".repeat(63), "a".repeat(65)];
        for bad in &cases {
            let err = SessionToken::from_hex(bad).unwrap_err();
            assert_eq!(err.kind(), io::ErrorKind::InvalidData);
            assert!(err.to_string().contains("64 chars"));
        }
    }

    #[test]
    fn from_hex_rejects_invalid_chars() {
        let bad = "zz".repeat(32);
        let err = SessionToken::from_hex(&bad).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("invalid hex digit"));
    }

    #[test]
    fn from_hex_rejects_non_ascii() {
        // 32 two-byte characters: the right byte length, but not hex.
        let bad = "é".repeat(32);
        assert_eq!(bad.len(), 64);
        let err = SessionToken::from_hex(&bad).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("invalid hex digit"));
    }

    #[test]
    fn constant_time_eq_matches() {
        let token = SessionToken::generate();
        let same = SessionToken::from_hex(&token.to_hex()).expect("valid");
        assert!(token.ct_eq(&same));
    }

    #[test]
    fn constant_time_eq_detects_single_byte_difference() {
        let token = SessionToken::generate();
        let mut bytes = token.0;
        bytes[31] ^= 0x01;
        let other = SessionToken(bytes);
        assert!(!token.ct_eq(&other));
        assert!(!other.ct_eq(&token));
    }

    #[test]
    fn file_round_trip() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("test.token");

        let token = SessionToken::generate();
        token.write_to_file(&path).expect("write");

        let loaded = SessionToken::read_from_file(&path).expect("read");
        assert!(token.ct_eq(&loaded));

        // Verify 0600 permissions on Unix
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let perms = std::fs::metadata(&path).expect("metadata").permissions();
            assert_eq!(perms.mode() & 0o777, 0o600);
        }
    }

    #[test]
    fn write_to_file_replaces_existing_token() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("test.token");

        let first = SessionToken::generate();
        first.write_to_file(&path).expect("first write");
        let second = SessionToken::generate();
        second.write_to_file(&path).expect("second write");

        let loaded = SessionToken::read_from_file(&path).expect("read");
        assert!(second.ct_eq(&loaded));
        assert!(!first.ct_eq(&loaded));

        // The temporary file must have been renamed away, leaving only the token.
        let entries = std::fs::read_dir(dir.path()).expect("read_dir").count();
        assert_eq!(entries, 1);

        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            let perms = std::fs::metadata(&path).expect("metadata").permissions();
            assert_eq!(perms.mode() & 0o777, 0o600);
        }
    }

    #[test]
    fn read_from_file_ignores_surrounding_whitespace() {
        let dir = tempfile::tempdir().expect("tmpdir");
        let path = dir.path().join("test.token");

        let token = SessionToken::generate();
        std::fs::write(&path, format!("  {}\n", token.to_hex())).expect("write");

        let loaded = SessionToken::read_from_file(&path).expect("read");
        assert!(token.ct_eq(&loaded));
    }

    #[test]
    fn debug_redacts_token() {
        let token = SessionToken::generate();
        let debug = format!("{token:?}");
        assert_eq!(debug, "SessionToken([REDACTED])");
        assert!(!debug.contains(&token.to_hex()));
    }

    #[test]
    fn handshake_response_ok_serializes() {
        let resp = HandshakeResponse::ok();
        assert_eq!(resp.status, HandshakeStatus::Ok);
        assert!(resp.is_ok());
        assert_eq!(resp.protocol_version, PROTOCOL_VERSION);
        assert!(resp.reason.is_none());

        let json = serde_json::to_value(&resp).expect("serialize");
        assert_eq!(json["status"], "ok");
        assert!(json.get("reason").is_none(), "reason should be skipped");
    }

    #[test]
    fn handshake_response_error_serializes() {
        let resp = HandshakeResponse::error("bad token");
        assert_eq!(resp.status, HandshakeStatus::Error);
        assert!(!resp.is_ok());
        assert_eq!(resp.reason.as_deref(), Some("bad token"));

        let json = serde_json::to_value(&resp).expect("serialize");
        assert_eq!(json["status"], "error");
        assert_eq!(json["reason"], "bad token");
    }

    #[test]
    fn handshake_response_deserializes_without_reason() {
        let resp: HandshakeResponse = serde_json::from_value(serde_json::json!({
            "status": "error",
            "protocol_version": PROTOCOL_VERSION,
            "server_version": "0.0.0"
        }))
        .expect("deserialize");
        assert!(!resp.is_ok());
        assert!(resp.reason.is_none());
    }

    #[test]
    fn handshake_status_uses_lowercase_wire_names() {
        let ok: HandshakeStatus = serde_json::from_str("\"ok\"").expect("ok");
        assert_eq!(ok, HandshakeStatus::Ok);
        let err: HandshakeStatus = serde_json::from_str("\"error\"").expect("error");
        assert_eq!(err, HandshakeStatus::Error);
        assert!(serde_json::from_str::<HandshakeStatus>("\"Ok\"").is_err());
    }
}