astrid_uplink/
socket_client.rs1use anyhow::{Context, Result};
10use astrid_core::PrincipalId;
11use astrid_core::SessionId;
12use astrid_core::session_token::{
13 HandshakeRequest, HandshakeResponse, PROTOCOL_VERSION, SessionToken,
14};
15use astrid_types::ipc::{IpcMessage, IpcPayload};
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::net::UnixStream;
18use tracing::warn;
19
20#[must_use]
25pub fn proxy_socket_path() -> std::path::PathBuf {
26 use astrid_core::dirs::AstridHome;
27 match AstridHome::resolve() {
28 Ok(home) => home.socket_path(),
29 Err(e) => {
30 warn!(error = %e, "Failed to resolve ASTRID_HOME; falling back to /tmp/.astrid/run/system.sock");
31 std::path::PathBuf::from("/tmp/.astrid/run/system.sock")
32 },
33 }
34}
35
36#[must_use]
43pub fn readiness_path() -> std::path::PathBuf {
44 use astrid_core::dirs::AstridHome;
45 match AstridHome::resolve() {
46 Ok(home) => home.ready_path(),
47 Err(e) => {
48 warn!(
49 error = %e,
50 "Failed to resolve ASTRID_HOME; falling back to /tmp/.astrid/run/system.ready"
51 );
52 std::path::PathBuf::from("/tmp/.astrid/run/system.ready")
53 },
54 }
55}
56
57pub fn token_path() -> Result<std::path::PathBuf> {
64 use astrid_core::dirs::AstridHome;
65 let home = AstridHome::resolve()
66 .map_err(|e| anyhow::anyhow!("Failed to resolve ASTRID_HOME for token path: {e}"))?;
67 Ok(home.token_path())
68}
69
70pub struct SocketClient {
72 read_half: tokio::net::unix::OwnedReadHalf,
73 write_half: tokio::net::unix::OwnedWriteHalf,
74 pub session_id: SessionId,
76}
77
78impl SocketClient {
79 pub async fn connect(session_id: SessionId) -> Result<Self> {
86 let path = proxy_socket_path();
87
88 if !path.exists() {
89 anyhow::bail!("Global OS Socket not found at {}", path.display());
90 }
91
92 let mut stream = UnixStream::connect(&path)
93 .await
94 .context("Failed to connect to IPC socket")?;
95
96 perform_handshake(&mut stream).await?;
97
98 let (read_half, write_half) = stream.into_split();
99
100 Ok(Self {
101 read_half,
102 write_half,
103 session_id,
104 })
105 }
106
107 pub async fn read_message(&mut self) -> Result<Option<IpcMessage>> {
121 loop {
122 let mut len_buf = [0u8; 4];
123 if self.read_half.read_exact(&mut len_buf).await.is_err() {
124 return Ok(None);
125 }
126 let len = u32::from_be_bytes(len_buf) as usize;
127
128 if len > 50 * 1024 * 1024 {
129 anyhow::bail!("Message too large from kernel: {len} bytes");
130 }
131
132 let mut payload = vec![0u8; len];
133 self.read_half.read_exact(&mut payload).await?;
134
135 if let Ok(message) = serde_json::from_slice::<IpcMessage>(&payload) {
136 return Ok(Some(message));
137 }
138 let preview = String::from_utf8_lossy(&payload[..payload.len().min(120)]);
139 tracing::debug!(
140 preview = %preview,
141 "skipping unparseable frame from daemon"
142 );
143 }
144 }
145
146 pub async fn read_raw_frame(&mut self) -> Result<Option<Vec<u8>>> {
154 let mut len_buf = [0u8; 4];
155 if self.read_half.read_exact(&mut len_buf).await.is_err() {
156 return Ok(None);
157 }
158 let len = u32::from_be_bytes(len_buf) as usize;
159 if len > 50 * 1024 * 1024 {
160 anyhow::bail!("Message too large from kernel: {len} bytes");
161 }
162 let mut payload = vec![0u8; len];
163 self.read_half.read_exact(&mut payload).await?;
164 Ok(Some(payload))
165 }
166
167 pub async fn read_until_topic(
175 &mut self,
176 want_topic: &str,
177 timeout: std::time::Duration,
178 ) -> Result<serde_json::Value> {
179 let deadline = tokio::time::Instant::now()
180 .checked_add(timeout)
181 .unwrap_or_else(tokio::time::Instant::now);
182 loop {
183 let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
184 if remaining.is_zero() {
185 anyhow::bail!("timed out waiting for {want_topic}");
186 }
187 let read = tokio::time::timeout(remaining, self.read_raw_frame()).await;
188 let frame = match read {
189 Ok(Ok(Some(bytes))) => bytes,
190 Ok(Ok(None)) => anyhow::bail!("connection closed before {want_topic}"),
191 Ok(Err(e)) => return Err(e),
192 Err(_) => anyhow::bail!("timed out waiting for {want_topic}"),
193 };
194 let raw: serde_json::Value = match serde_json::from_slice(&frame) {
195 Ok(v) => v,
196 Err(_) => continue,
197 };
198 if raw.get("topic").and_then(|t| t.as_str()) == Some(want_topic) {
199 return Ok(raw);
200 }
201 }
202 }
203
204 #[must_use]
222 pub fn extract_kernel_response(
223 raw: &serde_json::Value,
224 ) -> Option<astrid_core::kernel_api::KernelResponse> {
225 let payload = raw.get("payload")?.clone();
226 let value = if payload
227 .as_object()
228 .is_some_and(|m| m.contains_key("type") && m.contains_key("value"))
229 {
230 payload.get("value").cloned().unwrap_or(payload)
231 } else {
232 payload
233 };
234 serde_json::from_value::<astrid_core::kernel_api::KernelResponse>(value).ok()
235 }
236
237 pub async fn send_input(&mut self, text: String, caller: &PrincipalId) -> Result<()> {
247 let payload = IpcPayload::UserInput {
248 text,
249 session_id: self.session_id.0.to_string(),
250 context: None,
251 };
252
253 let msg = IpcMessage::new("user.v1.prompt", payload, self.session_id.0)
254 .with_principal(caller.to_string());
255
256 self.send_message(msg).await
257 }
258
259 pub async fn send_message(&mut self, msg: IpcMessage) -> Result<()> {
268 let bytes = serde_json::to_vec(&msg)?;
269 let len =
270 u32::try_from(bytes.len()).context("IPC message too large (exceeds 4 GiB limit)")?;
271
272 self.write_half.write_all(&len.to_be_bytes()).await?;
273 self.write_half.write_all(&bytes).await?;
274 self.write_half.flush().await?;
275 Ok(())
276 }
277}
278
/// Time limit applied separately to each handshake I/O step: sending the
/// request, reading the response length, and reading the response payload.
const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10_000);

/// Largest handshake response payload, in bytes, the client will accept (4 KiB).
const MAX_HANDSHAKE_RESPONSE_SIZE: usize = 4 * 1024;
285
286async fn perform_handshake(stream: &mut UnixStream) -> Result<()> {
289 let tok_path = token_path()?;
290 let token = SessionToken::read_from_file(&tok_path).with_context(|| {
291 format!(
292 "Failed to read session token from {}. Is the daemon running?",
293 tok_path.display()
294 )
295 })?;
296
297 let request = HandshakeRequest {
298 token: token.to_hex(),
299 protocol_version: PROTOCOL_VERSION,
300 client_version: env!("CARGO_PKG_VERSION").to_string(),
301 };
302
303 let request_bytes =
304 serde_json::to_vec(&request).context("Failed to serialize handshake request")?;
305 let len = u32::try_from(request_bytes.len()).context("Handshake request too large")?;
306
307 tokio::time::timeout(HANDSHAKE_TIMEOUT, async {
308 stream.write_all(&len.to_be_bytes()).await?;
309 stream.write_all(&request_bytes).await?;
310 stream.flush().await?;
311 Ok::<(), std::io::Error>(())
312 })
313 .await
314 .context("Handshake request write timed out")?
315 .context("Failed to send handshake request")?;
316
317 let mut len_buf = [0u8; 4];
318 tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut len_buf))
319 .await
320 .context("Handshake response timed out")?
321 .context("Failed to read handshake response length")?;
322
323 let resp_len = u32::from_be_bytes(len_buf) as usize;
324 if resp_len > MAX_HANDSHAKE_RESPONSE_SIZE {
325 anyhow::bail!("Handshake response too large: {resp_len} bytes");
326 }
327
328 let mut resp_payload = vec![0u8; resp_len];
329 tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut resp_payload))
330 .await
331 .context("Handshake response payload timed out")?
332 .context("Failed to read handshake response payload")?;
333
334 let response: HandshakeResponse =
335 serde_json::from_slice(&resp_payload).context("Failed to parse handshake response")?;
336
337 if !response.is_ok() {
338 let reason = response
339 .reason
340 .unwrap_or_else(|| "unknown error".to_string());
341 anyhow::bail!("Daemon rejected connection: {reason}");
342 }
343
344 Ok(())
345}