Skip to main content

astrid_uplink/
socket_client.rs

1//! Unix-domain socket client for the kernel.
2//!
3//! Performs the session-token handshake and exposes length-prefixed
4//! JSON framing for [`IpcMessage`](astrid_types::ipc::IpcMessage).
5//! Callers are responsible for stamping `principal` on outbound
6//! messages — this crate has no opinion on how a consumer resolves
7//! the caller (CLI active-agent context vs gateway-verified bearer).
8
9use anyhow::{Context, Result};
10use astrid_core::PrincipalId;
11use astrid_core::SessionId;
12use astrid_core::session_token::{
13    HandshakeRequest, HandshakeResponse, PROTOCOL_VERSION, SessionToken,
14};
15use astrid_types::ipc::{IpcMessage, IpcPayload};
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::net::UnixStream;
18use tracing::warn;
19
20/// Path to the kernel's Unix-domain socket. Falls back to
21/// `/tmp/.astrid/run/system.sock` if `ASTRID_HOME` can't be resolved
22/// — matches the pre-existing CLI behaviour so single-host development
23/// continues to work without env setup.
24#[must_use]
25pub fn proxy_socket_path() -> std::path::PathBuf {
26    use astrid_core::dirs::AstridHome;
27    match AstridHome::resolve() {
28        Ok(home) => home.socket_path(),
29        Err(e) => {
30            warn!(error = %e, "Failed to resolve ASTRID_HOME; falling back to /tmp/.astrid/run/system.sock");
31            std::path::PathBuf::from("/tmp/.astrid/run/system.sock")
32        },
33    }
34}
35
36/// Path to the daemon readiness sentinel.
37///
38/// Polled by uplinks after spawning the daemon to determine when it is
39/// fully initialized. NOTE: also duplicated in
40/// `astrid-kernel/src/socket.rs` because the kernel cannot depend on
41/// this crate; the canonical path is `AstridHome::ready_path()`.
42#[must_use]
43pub fn readiness_path() -> std::path::PathBuf {
44    use astrid_core::dirs::AstridHome;
45    match AstridHome::resolve() {
46        Ok(home) => home.ready_path(),
47        Err(e) => {
48            warn!(
49                error = %e,
50                "Failed to resolve ASTRID_HOME; falling back to /tmp/.astrid/run/system.ready"
51            );
52            std::path::PathBuf::from("/tmp/.astrid/run/system.ready")
53        },
54    }
55}
56
57/// Path to the session-authentication token file.
58///
59/// # Errors
60/// Returns an error if `ASTRID_HOME` cannot be resolved. No `/tmp`
61/// fallback — the daemon refuses to write its token under
62/// world-listable directories.
63pub fn token_path() -> Result<std::path::PathBuf> {
64    use astrid_core::dirs::AstridHome;
65    let home = AstridHome::resolve()
66        .map_err(|e| anyhow::anyhow!("Failed to resolve ASTRID_HOME for token path: {e}"))?;
67    Ok(home.token_path())
68}
69
/// A client connection to the kernel's Unix-domain socket.
///
/// Built by [`SocketClient::connect`], which performs the token handshake
/// on the raw [`UnixStream`] and then splits it into independent read and
/// write halves.
pub struct SocketClient {
    /// Read half of the split stream; inbound length-prefixed frames are
    /// read from here.
    read_half: tokio::net::unix::OwnedReadHalf,
    /// Write half of the split stream; outbound length-prefixed frames are
    /// written here.
    write_half: tokio::net::unix::OwnedWriteHalf,
    /// The unique identifier for this session.
    pub session_id: SessionId,
}
77
78impl SocketClient {
79    /// Connect to an existing session socket and perform the
80    /// authenticated handshake.
81    ///
82    /// # Errors
83    /// Returns an error if the socket file does not exist, connection
84    /// fails, or the handshake is rejected.
85    pub async fn connect(session_id: SessionId) -> Result<Self> {
86        let path = proxy_socket_path();
87
88        if !path.exists() {
89            anyhow::bail!("Global OS Socket not found at {}", path.display());
90        }
91
92        let mut stream = UnixStream::connect(&path)
93            .await
94            .context("Failed to connect to IPC socket")?;
95
96        perform_handshake(&mut stream).await?;
97
98        let (read_half, write_half) = stream.into_split();
99
100        Ok(Self {
101            read_half,
102            write_half,
103            session_id,
104        })
105    }
106
107    /// Read the next IPC message from the daemon.
108    ///
109    /// Frames that don't deserialize cleanly as
110    /// [`IpcMessage`](astrid_types::ipc::IpcMessage) (notably the
111    /// kernel's `astrid.v1.capsules_loaded` broadcast, whose
112    /// [`IpcPayload::RawJson`] inner value is emitted without the
113    /// `type` discriminator) are logged at `debug` and skipped. Without
114    /// this tolerance interactive clients would die on the first
115    /// broadcast.
116    ///
117    /// # Errors
118    /// Returns an error if the connection is unrecoverable (over-large
119    /// frame, IO failure mid-read).
120    pub async fn read_message(&mut self) -> Result<Option<IpcMessage>> {
121        loop {
122            let mut len_buf = [0u8; 4];
123            if self.read_half.read_exact(&mut len_buf).await.is_err() {
124                return Ok(None);
125            }
126            let len = u32::from_be_bytes(len_buf) as usize;
127
128            if len > 50 * 1024 * 1024 {
129                anyhow::bail!("Message too large from kernel: {len} bytes");
130            }
131
132            let mut payload = vec![0u8; len];
133            self.read_half.read_exact(&mut payload).await?;
134
135            if let Ok(message) = serde_json::from_slice::<IpcMessage>(&payload) {
136                return Ok(Some(message));
137            }
138            let preview = String::from_utf8_lossy(&payload[..payload.len().min(120)]);
139            tracing::debug!(
140                preview = %preview,
141                "skipping unparseable frame from daemon"
142            );
143        }
144    }
145
146    /// Read the next length-prefixed frame as raw bytes, without
147    /// attempting to deserialize. Used by [`crate::admin_client`] when
148    /// it needs to tolerate broadcast messages that don't deserialize
149    /// cleanly into [`IpcMessage`].
150    ///
151    /// # Errors
152    /// Returns an error if the frame cannot be read.
153    pub async fn read_raw_frame(&mut self) -> Result<Option<Vec<u8>>> {
154        let mut len_buf = [0u8; 4];
155        if self.read_half.read_exact(&mut len_buf).await.is_err() {
156            return Ok(None);
157        }
158        let len = u32::from_be_bytes(len_buf) as usize;
159        if len > 50 * 1024 * 1024 {
160            anyhow::bail!("Message too large from kernel: {len} bytes");
161        }
162        let mut payload = vec![0u8; len];
163        self.read_half.read_exact(&mut payload).await?;
164        Ok(Some(payload))
165    }
166
167    /// Read frames until one arrives on `want_topic` or `timeout`
168    /// elapses. Frames that fail to deserialize as JSON or carry a
169    /// different topic are silently skipped.
170    ///
171    /// # Errors
172    /// Returns an error if the deadline elapses, the connection
173    /// closes, or a read fails.
174    pub async fn read_until_topic(
175        &mut self,
176        want_topic: &str,
177        timeout: std::time::Duration,
178    ) -> Result<serde_json::Value> {
179        let deadline = tokio::time::Instant::now()
180            .checked_add(timeout)
181            .unwrap_or_else(tokio::time::Instant::now);
182        loop {
183            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
184            if remaining.is_zero() {
185                anyhow::bail!("timed out waiting for {want_topic}");
186            }
187            let read = tokio::time::timeout(remaining, self.read_raw_frame()).await;
188            let frame = match read {
189                Ok(Ok(Some(bytes))) => bytes,
190                Ok(Ok(None)) => anyhow::bail!("connection closed before {want_topic}"),
191                Ok(Err(e)) => return Err(e),
192                Err(_) => anyhow::bail!("timed out waiting for {want_topic}"),
193            };
194            let raw: serde_json::Value = match serde_json::from_slice(&frame) {
195                Ok(v) => v,
196                Err(_) => continue,
197            };
198            if raw.get("topic").and_then(|t| t.as_str()) == Some(want_topic) {
199                return Ok(raw);
200            }
201        }
202    }
203
204    /// Extract the inner kernel response from a raw frame previously
205    /// returned by [`read_until_topic`](Self::read_until_topic).
206    ///
207    /// The kernel emits one of two on-wire shapes depending on which
208    /// router branch produced the response:
209    ///
210    /// * Bare typed payload — `{ "type": "...", ... }`, already a
211    ///   `KernelResponse`-shaped object that `serde_json::from_value`
212    ///   can deserialize directly.
213    /// * `RawJson`-wrapped payload — `{ "type": "raw_json", "value":
214    ///   { "type": "...", ... } }` (the older router branch wraps the
215    ///   typed body in `IpcPayload::RawJson`).
216    ///
217    /// Both have to be tolerated by every consumer of the bare verbs.
218    /// Returns `None` when the frame has no `payload` field or the
219    /// deserialization fails — callers fall back to an empty display
220    /// rather than crashing.
221    #[must_use]
222    pub fn extract_kernel_response(
223        raw: &serde_json::Value,
224    ) -> Option<astrid_core::kernel_api::KernelResponse> {
225        let payload = raw.get("payload")?.clone();
226        let value = if payload
227            .as_object()
228            .is_some_and(|m| m.contains_key("type") && m.contains_key("value"))
229        {
230            payload.get("value").cloned().unwrap_or(payload)
231        } else {
232            payload
233        };
234        serde_json::from_value::<astrid_core::kernel_api::KernelResponse>(value).ok()
235    }
236
237    /// Send a user-prompt message on behalf of `caller`.
238    ///
239    /// Convenience helper for chat-style uplinks. Stamps
240    /// `IpcMessage.principal` from the caller so the kernel's
241    /// `resolve_caller` sees the right principal for session, KV,
242    /// home, secret, and quota scoping.
243    ///
244    /// # Errors
245    /// Returns an error if the message cannot be sent.
246    pub async fn send_input(&mut self, text: String, caller: &PrincipalId) -> Result<()> {
247        let payload = IpcPayload::UserInput {
248            text,
249            session_id: self.session_id.0.to_string(),
250            context: None,
251        };
252
253        let msg = IpcMessage::new("user.v1.prompt", payload, self.session_id.0)
254            .with_principal(caller.to_string());
255
256        self.send_message(msg).await
257    }
258
259    /// Send a raw IPC message to the kernel.
260    ///
261    /// The caller is responsible for stamping
262    /// [`IpcMessage::principal`](astrid_types::ipc::IpcMessage::principal)
263    /// before calling — this transport does not infer it.
264    ///
265    /// # Errors
266    /// Returns an error if the message cannot be serialized or sent.
267    pub async fn send_message(&mut self, msg: IpcMessage) -> Result<()> {
268        let bytes = serde_json::to_vec(&msg)?;
269        let len =
270            u32::try_from(bytes.len()).context("IPC message too large (exceeds 4 GiB limit)")?;
271
272        self.write_half.write_all(&len.to_be_bytes()).await?;
273        self.write_half.write_all(&bytes).await?;
274        self.write_half.flush().await?;
275        Ok(())
276    }
277}
278
/// Client-side deadline applied to each individual handshake write or read.
/// Kept a little longer than the server-side handshake timeout so that a
/// loaded daemon is not cut off prematurely.
const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10_000);

/// Upper bound, in bytes, on the handshake response body the client will
/// read; a larger announced length aborts the handshake.
const MAX_HANDSHAKE_RESPONSE_SIZE: usize = 4 * 1024;
285
286/// Read the session token from disk and execute the authentication
287/// handshake.
288async fn perform_handshake(stream: &mut UnixStream) -> Result<()> {
289    let tok_path = token_path()?;
290    let token = SessionToken::read_from_file(&tok_path).with_context(|| {
291        format!(
292            "Failed to read session token from {}. Is the daemon running?",
293            tok_path.display()
294        )
295    })?;
296
297    let request = HandshakeRequest {
298        token: token.to_hex(),
299        protocol_version: PROTOCOL_VERSION,
300        client_version: env!("CARGO_PKG_VERSION").to_string(),
301    };
302
303    let request_bytes =
304        serde_json::to_vec(&request).context("Failed to serialize handshake request")?;
305    let len = u32::try_from(request_bytes.len()).context("Handshake request too large")?;
306
307    tokio::time::timeout(HANDSHAKE_TIMEOUT, async {
308        stream.write_all(&len.to_be_bytes()).await?;
309        stream.write_all(&request_bytes).await?;
310        stream.flush().await?;
311        Ok::<(), std::io::Error>(())
312    })
313    .await
314    .context("Handshake request write timed out")?
315    .context("Failed to send handshake request")?;
316
317    let mut len_buf = [0u8; 4];
318    tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut len_buf))
319        .await
320        .context("Handshake response timed out")?
321        .context("Failed to read handshake response length")?;
322
323    let resp_len = u32::from_be_bytes(len_buf) as usize;
324    if resp_len > MAX_HANDSHAKE_RESPONSE_SIZE {
325        anyhow::bail!("Handshake response too large: {resp_len} bytes");
326    }
327
328    let mut resp_payload = vec![0u8; resp_len];
329    tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut resp_payload))
330        .await
331        .context("Handshake response payload timed out")?
332        .context("Failed to read handshake response payload")?;
333
334    let response: HandshakeResponse =
335        serde_json::from_slice(&resp_payload).context("Failed to parse handshake response")?;
336
337    if !response.is_ok() {
338        let reason = response
339            .reason
340            .unwrap_or_else(|| "unknown error".to_string());
341        anyhow::bail!("Daemon rejected connection: {reason}");
342    }
343
344    Ok(())
345}