use anyhow::{Context, Result};
use astrid_core::PrincipalId;
use astrid_core::SessionId;
use astrid_core::session_token::{
HandshakeRequest, HandshakeResponse, PROTOCOL_VERSION, SessionToken,
};
use astrid_types::ipc::{IpcMessage, IpcPayload};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::UnixStream;
use tracing::warn;
/// Returns the filesystem path of the daemon's IPC socket.
///
/// The path comes from the resolved Astrid home directory. If the home
/// directory cannot be resolved, a warning is logged and the fixed fallback
/// `/tmp/.astrid/run/system.sock` is returned instead, so this function never
/// fails.
#[must_use]
pub fn proxy_socket_path() -> std::path::PathBuf {
    astrid_core::dirs::AstridHome::resolve().map_or_else(
        |e| {
            warn!(error = %e, "Failed to resolve ASTRID_HOME; falling back to /tmp/.astrid/run/system.sock");
            std::path::PathBuf::from("/tmp/.astrid/run/system.sock")
        },
        |home| home.socket_path(),
    )
}
/// Returns the filesystem path of the daemon's readiness marker.
///
/// The path comes from the resolved Astrid home directory. If the home
/// directory cannot be resolved, a warning is logged and the fixed fallback
/// `/tmp/.astrid/run/system.ready` is returned instead, so this function
/// never fails.
#[must_use]
pub fn readiness_path() -> std::path::PathBuf {
    astrid_core::dirs::AstridHome::resolve().map_or_else(
        |e| {
            warn!(
                error = %e,
                "Failed to resolve ASTRID_HOME; falling back to /tmp/.astrid/run/system.ready"
            );
            std::path::PathBuf::from("/tmp/.astrid/run/system.ready")
        },
        |home| home.ready_path(),
    )
}
/// Returns the path of the session token file inside the Astrid home directory.
///
/// Unlike the socket and readiness path helpers, there is no `/tmp` fallback
/// here: the failure is reported to the caller.
///
/// # Errors
///
/// Returns an error when the Astrid home directory cannot be resolved.
pub fn token_path() -> Result<std::path::PathBuf> {
    match astrid_core::dirs::AstridHome::resolve() {
        Ok(home) => Ok(home.token_path()),
        Err(e) => Err(anyhow::anyhow!(
            "Failed to resolve ASTRID_HOME for token path: {e}"
        )),
    }
}
/// Client end of the IPC connection to the daemon over a Unix domain socket.
///
/// Frames on the wire are a 4-byte big-endian length followed by a JSON payload.
/// The stream is split into independently owned read and write halves after
/// the handshake completes.
pub struct SocketClient {
/// Read half of the split stream; incoming frames are read from here.
read_half: tokio::net::unix::OwnedReadHalf,
/// Write half of the split stream; outgoing frames are written here.
write_half: tokio::net::unix::OwnedWriteHalf,
/// Session identifier passed to `connect`; used when building user-input messages.
pub session_id: SessionId,
}
impl SocketClient {
    /// Largest payload length accepted for a single incoming frame (50 MiB).
    const MAX_FRAME_LEN: usize = 50 * 1024 * 1024;

    /// Connects to the daemon socket, performs the handshake, and returns a
    /// client bound to `session_id`.
    ///
    /// # Errors
    ///
    /// Returns an error when the socket path does not exist, when connecting
    /// to it fails, or when the handshake fails or is rejected.
    pub async fn connect(session_id: SessionId) -> Result<Self> {
        let path = proxy_socket_path();
        // Checked up front so a missing socket yields a message naming the path.
        if !path.exists() {
            anyhow::bail!("Global OS Socket not found at {}", path.display());
        }
        let mut stream = UnixStream::connect(&path)
            .await
            .context("Failed to connect to IPC socket")?;
        perform_handshake(&mut stream).await?;
        let (read_half, write_half) = stream.into_split();
        Ok(Self {
            read_half,
            write_half,
            session_id,
        })
    }

    /// Reads frames until one decodes as an [`IpcMessage`] and returns it.
    ///
    /// Frames whose payload does not parse as an `IpcMessage` are logged at
    /// debug level (with a preview of at most 120 bytes) and skipped.
    /// Returns `Ok(None)` once no further frame header can be read.
    ///
    /// # Errors
    ///
    /// Returns an error when a frame exceeds the size limit or its payload
    /// cannot be read in full.
    pub async fn read_message(&mut self) -> Result<Option<IpcMessage>> {
        while let Some(payload) = self.read_raw_frame().await? {
            match serde_json::from_slice::<IpcMessage>(&payload) {
                Ok(message) => return Ok(Some(message)),
                Err(err) => {
                    // Byte-level slice: `from_utf8_lossy` copes with a cut mid-character.
                    let preview = String::from_utf8_lossy(&payload[..payload.len().min(120)]);
                    tracing::debug!(
                        error = %err,
                        preview = %preview,
                        "skipping unparseable frame from daemon"
                    );
                },
            }
        }
        Ok(None)
    }

    /// Reads one frame and returns its payload bytes without decoding them.
    ///
    /// A frame is a 4-byte big-endian length followed by that many payload
    /// bytes. Returns `Ok(None)` when the length header cannot be read; any
    /// read error at that point is treated as the connection being closed
    /// (errors other than end-of-stream are logged at debug level).
    ///
    /// # Errors
    ///
    /// Returns an error when the announced length exceeds 50 MiB or when the
    /// payload cannot be read in full.
    pub async fn read_raw_frame(&mut self) -> Result<Option<Vec<u8>>> {
        let mut len_buf = [0u8; 4];
        if let Err(err) = self.read_half.read_exact(&mut len_buf).await {
            if err.kind() != std::io::ErrorKind::UnexpectedEof {
                tracing::debug!(
                    error = %err,
                    "IPC frame header read failed; treating connection as closed"
                );
            }
            return Ok(None);
        }
        let len = u32::from_be_bytes(len_buf) as usize;
        // Reject oversized frames before allocating a buffer for them.
        if len > Self::MAX_FRAME_LEN {
            anyhow::bail!("Message too large from kernel: {len} bytes");
        }
        let mut payload = vec![0u8; len];
        self.read_half.read_exact(&mut payload).await?;
        Ok(Some(payload))
    }

    /// Reads frames until one is a JSON value whose `"topic"` field equals
    /// `want_topic`, and returns that value.
    ///
    /// Frames that are not valid JSON, or that carry a different topic, are
    /// discarded. `timeout` bounds the whole wait, not each individual read.
    /// A timeout too large to be represented as a deadline is treated as
    /// "no deadline" rather than expiring immediately.
    ///
    /// NOTE: the timeout may fire while a frame is only partially read. Tokio
    /// documents `read_exact` as not cancellation safe, so after a timeout
    /// error the stream may be positioned mid-frame and later reads may be
    /// misframed.
    ///
    /// # Errors
    ///
    /// Returns an error when the wait times out, when the connection closes
    /// before a matching frame arrives, or when reading a frame fails.
    pub async fn read_until_topic(
        &mut self,
        want_topic: &str,
        timeout: std::time::Duration,
    ) -> Result<serde_json::Value> {
        // `None` means `now + timeout` overflowed, so there is no deadline to enforce.
        let deadline = tokio::time::Instant::now().checked_add(timeout);
        loop {
            let read = match deadline {
                Some(deadline) => {
                    let remaining =
                        deadline.saturating_duration_since(tokio::time::Instant::now());
                    if remaining.is_zero() {
                        anyhow::bail!("timed out waiting for {want_topic}");
                    }
                    tokio::time::timeout(remaining, self.read_raw_frame()).await
                },
                None => Ok(self.read_raw_frame().await),
            };
            let frame = match read {
                Ok(Ok(Some(bytes))) => bytes,
                Ok(Ok(None)) => anyhow::bail!("connection closed before {want_topic}"),
                Ok(Err(e)) => return Err(e),
                Err(_) => anyhow::bail!("timed out waiting for {want_topic}"),
            };
            let Ok(raw) = serde_json::from_slice::<serde_json::Value>(&frame) else {
                continue;
            };
            if raw.get("topic").and_then(|t| t.as_str()) == Some(want_topic) {
                return Ok(raw);
            }
        }
    }

    /// Extracts a `KernelResponse` from the `"payload"` field of a raw frame.
    ///
    /// If the payload is an object containing both `"type"` and `"value"`
    /// keys, the `"value"` entry is decoded; otherwise the payload itself is
    /// decoded. Returns `None` when `"payload"` is absent or decoding fails.
    #[must_use]
    pub fn extract_kernel_response(
        raw: &serde_json::Value,
    ) -> Option<astrid_core::kernel_api::KernelResponse> {
        let payload = raw.get("payload")?;
        // Borrow first and clone only the value that is actually decoded.
        let inner = payload
            .as_object()
            .filter(|map| map.contains_key("type"))
            .and_then(|map| map.get("value"))
            .unwrap_or(payload);
        serde_json::from_value(inner.clone()).ok()
    }

    /// Sends `text` as a user-input message on topic `user.v1.prompt`,
    /// tagged with this client's session id and the `caller` principal.
    ///
    /// # Errors
    ///
    /// Returns an error when the message cannot be serialized or written.
    pub async fn send_input(&mut self, text: String, caller: &PrincipalId) -> Result<()> {
        let payload = IpcPayload::UserInput {
            text,
            session_id: self.session_id.0.to_string(),
            context: None,
        };
        let msg = IpcMessage::new("user.v1.prompt", payload, self.session_id.0)
            .with_principal(caller.to_string());
        self.send_message(msg).await
    }

    /// Serializes `msg` to JSON and writes it as one frame: a 4-byte
    /// big-endian length followed by the JSON bytes, then flushes.
    ///
    /// # Errors
    ///
    /// Returns an error when serialization fails, when the encoded message
    /// does not fit in a `u32` length, or when writing to the socket fails.
    pub async fn send_message(&mut self, msg: IpcMessage) -> Result<()> {
        let bytes = serde_json::to_vec(&msg)?;
        let len =
            u32::try_from(bytes.len()).context("IPC message too large (exceeds 4 GiB limit)")?;
        self.write_half.write_all(&len.to_be_bytes()).await?;
        self.write_half.write_all(&bytes).await?;
        self.write_half.flush().await?;
        Ok(())
    }
}
/// Time limit applied separately to each handshake step: writing the request,
/// reading the response length, and reading the response payload.
const HANDSHAKE_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
/// Largest handshake response payload, in bytes, that the client will accept.
const MAX_HANDSHAKE_RESPONSE_SIZE: usize = 4096;
/// Authenticates a freshly connected stream with the daemon.
///
/// Reads the session token from the token file, sends a `HandshakeRequest`
/// as a length-prefixed JSON frame, then reads and checks the daemon's
/// `HandshakeResponse`. Each of the three I/O steps is limited by
/// `HANDSHAKE_TIMEOUT`.
///
/// # Errors
///
/// Returns an error when the token cannot be located or read, when any I/O
/// step fails or times out, when the response exceeds
/// `MAX_HANDSHAKE_RESPONSE_SIZE` or does not parse, or when the daemon
/// rejects the connection.
async fn perform_handshake(stream: &mut UnixStream) -> Result<()> {
    // Build the request from the token stored on disk.
    let token_file = token_path()?;
    let session_token = SessionToken::read_from_file(&token_file).with_context(|| {
        format!(
            "Failed to read session token from {}. Is the daemon running?",
            token_file.display()
        )
    })?;
    let hello = HandshakeRequest {
        token: session_token.to_hex(),
        protocol_version: PROTOCOL_VERSION,
        client_version: env!("CARGO_PKG_VERSION").to_string(),
    };
    let body = serde_json::to_vec(&hello).context("Failed to serialize handshake request")?;
    let body_len = u32::try_from(body.len()).context("Handshake request too large")?;

    // Send the frame: length header, payload, flush, all under one timeout.
    let send = async {
        stream.write_all(&body_len.to_be_bytes()).await?;
        stream.write_all(&body).await?;
        stream.flush().await?;
        Ok::<(), std::io::Error>(())
    };
    tokio::time::timeout(HANDSHAKE_TIMEOUT, send)
        .await
        .context("Handshake request write timed out")?
        .context("Failed to send handshake request")?;

    // Read the response header and enforce the size cap before allocating.
    let mut header = [0u8; 4];
    tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut header))
        .await
        .context("Handshake response timed out")?
        .context("Failed to read handshake response length")?;
    let resp_len = u32::from_be_bytes(header) as usize;
    if resp_len > MAX_HANDSHAKE_RESPONSE_SIZE {
        anyhow::bail!("Handshake response too large: {resp_len} bytes");
    }

    // Read and decode the response payload.
    let mut reply_bytes = vec![0u8; resp_len];
    tokio::time::timeout(HANDSHAKE_TIMEOUT, stream.read_exact(&mut reply_bytes))
        .await
        .context("Handshake response payload timed out")?
        .context("Failed to read handshake response payload")?;
    let reply: HandshakeResponse =
        serde_json::from_slice(&reply_bytes).context("Failed to parse handshake response")?;

    if reply.is_ok() {
        return Ok(());
    }
    let reason = reply
        .reason
        .unwrap_or_else(|| "unknown error".to_string());
    Err(anyhow::anyhow!("Daemon rejected connection: {reason}"))
}