use anyhow::{bail, Context, Result};
use serde::{de::DeserializeOwned, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::unix::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::UnixStream;
use super::crypto::IpcCipher;
/// Upper bound, in bytes, on a single IPC message (16 MiB).
///
/// Applied to the serialized JSON on the plain transport and to the length
/// prefix of encrypted frames.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
/// Reads newline-delimited JSON messages from the read half of a Unix socket.
pub struct IpcReader {
    reader: BufReader<OwnedReadHalf>,
}

impl IpcReader {
    /// Wraps `read_half` in a buffered reader.
    pub fn new(read_half: OwnedReadHalf) -> Self {
        Self {
            reader: BufReader::new(read_half),
        }
    }

    /// Reads and deserializes the next newline-terminated JSON message.
    ///
    /// Returns `Ok(None)` if the peer closed the connection before sending any
    /// byte of a new message. A final message that ends at EOF without a
    /// trailing newline is still parsed.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying read fails, if the payload (excluding
    /// its terminating newline) is longer than [`MAX_MESSAGE_SIZE`] bytes, or if
    /// the payload is not valid JSON for `T`.
    pub async fn read<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
        // Longest prefix of a bad payload echoed into a parse error, so a huge or
        // sensitive message does not end up verbatim in logs.
        const ERROR_PREVIEW_BYTES: usize = 256;

        // Bound the read itself: without `take`, a peer that never sends '\n'
        // could make the reader buffer arbitrarily much before any size check.
        // The `+ 1` leaves room for the newline after a maximum-size payload.
        let limit = MAX_MESSAGE_SIZE as u64 + 1;
        let mut buf = Vec::new();
        let mut limited = (&mut self.reader).take(limit);
        let bytes_read = limited.read_until(b'\n', &mut buf).await?;
        if bytes_read == 0 {
            return Ok(None);
        }

        // Measure the payload without its terminator so the limit matches the
        // one `IpcWriter::write` applies to the serialized JSON.
        let payload = buf.strip_suffix(b"\n").unwrap_or(&buf[..]);
        if payload.len() > MAX_MESSAGE_SIZE {
            bail!("Message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
        }

        // JSON tolerates surrounding whitespace (e.g. a trailing '\r'), so no trim is needed.
        let message: T = serde_json::from_slice(payload).with_context(|| {
            let shown = &payload[..payload.len().min(ERROR_PREVIEW_BYTES)];
            let ellipsis = if payload.len() > ERROR_PREVIEW_BYTES {
                "..."
            } else {
                ""
            };
            format!(
                "Failed to parse IPC message: {}{}",
                String::from_utf8_lossy(shown).trim(),
                ellipsis
            )
        })?;
        Ok(Some(message))
    }
}
/// Writes newline-delimited JSON messages to the write half of a Unix socket.
pub struct IpcWriter {
    writer: BufWriter<OwnedWriteHalf>,
}

impl IpcWriter {
    /// Wraps `write_half` in a buffered writer.
    pub fn new(write_half: OwnedWriteHalf) -> Self {
        let writer = BufWriter::new(write_half);
        Self { writer }
    }

    /// Serializes `message` as JSON, appends a `\n` delimiter and flushes it.
    ///
    /// # Errors
    ///
    /// Fails if serialization fails, if the JSON is longer than
    /// [`MAX_MESSAGE_SIZE`] bytes, or if writing/flushing the socket fails.
    pub async fn write<T: Serialize>(&mut self, message: &T) -> Result<()> {
        let mut frame =
            serde_json::to_vec(message).context("Failed to serialize IPC message")?;
        // The limit covers the JSON only; the delimiter is appended afterwards.
        if frame.len() > MAX_MESSAGE_SIZE {
            bail!("Message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
        }
        frame.push(b'\n');
        self.writer.write_all(&frame).await?;
        self.writer.flush().await?;
        Ok(())
    }
}
/// A newline-delimited JSON connection over a Unix domain socket.
///
/// Both halves are public so they can be driven independently; use
/// [`IpcConnection::split`] to take ownership of them separately.
pub struct IpcConnection {
    pub reader: IpcReader,
    pub writer: IpcWriter,
}

impl IpcConnection {
    /// Splits an already-connected `stream` into buffered reader/writer halves.
    pub fn from_stream(stream: UnixStream) -> Self {
        let (rx, tx) = stream.into_split();
        let reader = IpcReader::new(rx);
        let writer = IpcWriter::new(tx);
        Self { reader, writer }
    }

    /// Connects to the Unix socket at `socket_path`.
    ///
    /// # Errors
    ///
    /// Fails if nothing exists at `socket_path` or the connection attempt fails.
    pub async fn connect(socket_path: &Path) -> Result<Self> {
        if socket_path.exists() {
            let stream = UnixStream::connect(socket_path).await.with_context(|| {
                format!(
                    "Failed to connect to agent socket: {}",
                    socket_path.display()
                )
            })?;
            Ok(Self::from_stream(stream))
        } else {
            bail!("Agent socket not found: {}", socket_path.display())
        }
    }

    /// Connects to the socket of session `session_id` inside `sessions_dir`.
    pub async fn connect_to_agent(sessions_dir: &Path, session_id: &str) -> Result<Self> {
        Self::connect(&get_agent_socket_path(sessions_dir, session_id)).await
    }

    /// Consumes the connection, returning its reader and writer.
    pub fn split(self) -> (IpcReader, IpcWriter) {
        let Self { reader, writer } = self;
        (reader, writer)
    }

    /// Switches this connection to encrypted framing.
    ///
    /// Both directions share one cipher derived from `session_token`. The
    /// existing buffered halves are reused, so no already-buffered bytes are lost.
    pub fn upgrade_to_encrypted(self, session_token: &str) -> EncryptedIpcConnection {
        let cipher = Arc::new(IpcCipher::from_session_token(session_token));
        let Self { reader, writer } = self;
        EncryptedIpcConnection {
            writer: EncryptedIpcWriter::new(writer, Arc::clone(&cipher)),
            reader: EncryptedIpcReader::new(reader, cipher),
        }
    }
}
/// Reads length-prefixed, encrypted JSON messages.
///
/// Each frame is a 4-byte big-endian length followed by that many bytes of
/// ciphertext. The ciphertext is decrypted with the shared cipher and parsed as JSON.
pub struct EncryptedIpcReader {
    inner: IpcReader,
    cipher: Arc<IpcCipher>,
}

impl EncryptedIpcReader {
    /// Wraps `reader`, decrypting every frame with `cipher`.
    pub fn new(reader: IpcReader, cipher: Arc<IpcCipher>) -> Self {
        Self {
            inner: reader,
            cipher,
        }
    }

    /// Reads, decrypts and deserializes the next frame.
    ///
    /// Returns `Ok(None)` only when the peer closes the connection exactly on a
    /// frame boundary (before any byte of the next length header).
    ///
    /// # Errors
    ///
    /// Returns an error if the connection fails or closes partway through a frame,
    /// if the declared length exceeds [`MAX_MESSAGE_SIZE`], if decryption fails,
    /// or if the plaintext is not valid JSON for `T`.
    pub async fn read<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
        let mut len_buf = [0u8; 4];
        // Distinguish a clean EOF (zero bytes of a new frame) from a truncated
        // header. The latter is a protocol error, not an orderly shutdown.
        let header_read = self.inner.reader.read(&mut len_buf).await?;
        if header_read == 0 {
            return Ok(None);
        }
        self.inner
            .reader
            .read_exact(&mut len_buf[header_read..])
            .await
            .context("Failed to read encrypted frame header")?;

        // u32 -> usize widens on the 32/64-bit targets this Unix-socket code runs on.
        let msg_len = u32::from_be_bytes(len_buf) as usize;
        // Validate before allocating so a hostile length cannot force a huge buffer.
        if msg_len > MAX_MESSAGE_SIZE {
            bail!("Encrypted message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
        }
        let mut encrypted = vec![0u8; msg_len];
        self.inner
            .reader
            .read_exact(&mut encrypted)
            .await
            .context("Failed to read encrypted frame body")?;

        let plaintext = self
            .cipher
            .decrypt(&encrypted)
            .context("Failed to decrypt IPC message")?;
        let message: T = serde_json::from_slice(&plaintext)
            .context("Failed to parse decrypted IPC message")?;
        Ok(Some(message))
    }
}
/// Writes length-prefixed, encrypted JSON messages.
///
/// Each frame is a 4-byte big-endian ciphertext length followed by the ciphertext.
pub struct EncryptedIpcWriter {
    inner: IpcWriter,
    cipher: Arc<IpcCipher>,
}

impl EncryptedIpcWriter {
    /// Wraps `writer`, encrypting every message with `cipher`.
    pub fn new(writer: IpcWriter, cipher: Arc<IpcCipher>) -> Self {
        Self {
            inner: writer,
            cipher,
        }
    }

    /// Serializes, encrypts and sends `message` as a single frame, then flushes.
    ///
    /// # Errors
    ///
    /// Fails if serialization or encryption fails, if either the JSON or the
    /// resulting ciphertext exceeds [`MAX_MESSAGE_SIZE`] bytes, or if writing
    /// to the socket fails.
    pub async fn write<T: Serialize>(&mut self, message: &T) -> Result<()> {
        let json = serde_json::to_vec(message).context("Failed to serialize IPC message")?;
        // Cheap early rejection before spending time encrypting an oversized payload.
        if json.len() > MAX_MESSAGE_SIZE {
            bail!("Message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
        }
        let encrypted = self
            .cipher
            .encrypt(&json)
            .context("Failed to encrypt IPC message")?;
        // The reader enforces the limit on the *ciphertext* length. Encryption may
        // add bytes (NOTE(review): e.g. a nonce/tag, depending on `IpcCipher`), so
        // re-check here instead of sending a frame the peer will refuse.
        if encrypted.len() > MAX_MESSAGE_SIZE {
            bail!("Encrypted message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
        }
        let len_buf = u32::try_from(encrypted.len())
            .context("Encrypted message length does not fit in a u32 frame header")?
            .to_be_bytes();
        self.inner.writer.write_all(&len_buf).await?;
        self.inner.writer.write_all(&encrypted).await?;
        self.inner.writer.flush().await?;
        Ok(())
    }
}
pub struct EncryptedIpcConnection {
pub reader: EncryptedIpcReader,
pub writer: EncryptedIpcWriter,
}
impl EncryptedIpcConnection {
pub fn from_stream(stream: UnixStream, session_token: &str) -> Self {
let cipher = Arc::new(IpcCipher::from_session_token(session_token));
let (read_half, write_half) = stream.into_split();
Self {
reader: EncryptedIpcReader::new(
IpcReader::new(read_half),
Arc::clone(&cipher),
),
writer: EncryptedIpcWriter::new(
IpcWriter::new(write_half),
cipher,
),
}
}
pub fn split(self) -> (EncryptedIpcReader, EncryptedIpcWriter) {
(self.reader, self.writer)
}
}
/// Returns the path of the agent's Unix socket for `session_id`: `<sessions_dir>/<session_id>.sock`.
pub fn get_agent_socket_path(sessions_dir: &Path, session_id: &str) -> PathBuf {
    let file_name = [session_id, ".sock"].concat();
    sessions_dir.join(file_name)
}
/// Returns the path of the session-token file for `session_id`: `<sessions_dir>/<session_id>.token`.
pub fn get_session_token_path(sessions_dir: &Path, session_id: &str) -> PathBuf {
    let mut file_name = String::from(session_id);
    file_name.push_str(".token");
    sessions_dir.join(file_name)
}
/// Generates a new session token.
///
/// The token is 32 random bytes from `rand::thread_rng()`, hex-encoded into a
/// 64-character string.
pub fn generate_session_token() -> String {
    use rand::RngCore;
    let mut rng = rand::thread_rng();
    let mut secret = [0u8; 32];
    rng.fill_bytes(&mut secret);
    hex::encode(secret)
}
pub fn write_session_token(sessions_dir: &Path, session_id: &str, token: &str) -> Result<()> {
let token_path = get_session_token_path(sessions_dir, session_id);
if let Some(parent) = token_path.parent() {
std::fs::create_dir_all(parent)?;
}
std::fs::write(&token_path, token)?;
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
std::fs::set_permissions(&token_path, std::fs::Permissions::from_mode(0o600))?;
}
tracing::debug!("Wrote session token: {} (0600 permissions)", token_path.display());
Ok(())
}
pub fn read_session_token(sessions_dir: &Path, session_id: &str) -> Result<Option<String>> {
let token_path = get_session_token_path(sessions_dir, session_id);
if !token_path.exists() {
return Ok(None);
}
let token = std::fs::read_to_string(&token_path)
.with_context(|| format!("Failed to read session token from {}", token_path.display()))?;
Ok(Some(token.trim().to_string()))
}
pub fn delete_session_token(sessions_dir: &Path, session_id: &str) -> Result<()> {
let token_path = get_session_token_path(sessions_dir, session_id);
if token_path.exists() {
std::fs::remove_file(&token_path)
.with_context(|| format!("Failed to delete session token: {}", token_path.display()))?;
tracing::debug!("Deleted session token: {}", token_path.display());
}
Ok(())
}
/// Checks `provided_token` against the token stored for `session_id`.
///
/// Equal-length tokens are compared with `subtle::ConstantTimeEq`. Returns `false`
/// in every failure case: no stored token, an empty stored token, an unreadable
/// token file, or a mismatch. Missing or unreadable tokens are logged.
pub fn validate_session_token(
    sessions_dir: &Path,
    session_id: &str,
    provided_token: &str,
) -> bool {
    match read_session_token(sessions_dir, session_id) {
        // An empty stored token (e.g. a file left truncated by an interrupted
        // write) must never authenticate, even against an empty provided token.
        Ok(Some(stored_token)) if stored_token.is_empty() => {
            tracing::warn!("Empty session token stored for session {}", session_id);
            false
        }
        Ok(Some(stored_token)) => {
            use subtle::ConstantTimeEq;
            provided_token.as_bytes().ct_eq(stored_token.as_bytes()).into()
        }
        Ok(None) => {
            tracing::warn!("No session token found for session {}", session_id);
            false
        }
        Err(e) => {
            tracing::error!("Failed to read session token for {}: {}", session_id, e);
            false
        }
    }
}
// Tests for the IPC transport and the session-token helpers. The socket tests
// bind a real Unix socket inside a temporary directory, run the server side in a
// spawned task, and drive the client side from the test body.
#[cfg(test)]
mod tests {
use super::*;
use super::super::protocol::{AgentMessage, ViewerMessage};
use tokio::net::UnixListener;
// Plain newline-delimited JSON round trip: the client sends `UserInput`, the
// server checks its content and answers with an `Ack` that the client checks.
#[tokio::test]
async fn test_ipc_roundtrip() {
let temp_dir = tempfile::tempdir().unwrap();
let socket_path = temp_dir.path().join("test.sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let mut conn = IpcConnection::from_stream(stream);
let msg: ViewerMessage = conn.reader.read().await.unwrap().unwrap();
match msg {
ViewerMessage::UserInput { content, .. } => {
assert_eq!(content, "Hello");
}
_ => panic!("Unexpected message type"),
}
let response = AgentMessage::Ack {
command: "user_input".to_string(),
};
conn.writer.write(&response).await.unwrap();
});
let stream = UnixStream::connect(&socket_path).await.unwrap();
let mut conn = IpcConnection::from_stream(stream);
let msg = ViewerMessage::UserInput {
content: "Hello".to_string(),
context_files: vec![],
};
conn.writer.write(&msg).await.unwrap();
let response: AgentMessage = conn.reader.read().await.unwrap().unwrap();
match response {
AgentMessage::Ack { command } => {
assert_eq!(command, "user_input");
}
_ => panic!("Unexpected response type"),
}
// Awaiting the handle re-raises any assertion panic from the server task.
server_task.await.unwrap();
}
// Same exchange as above, but both sides upgrade the plain connection to the
// encrypted framing using the same session token.
#[tokio::test]
async fn test_encrypted_ipc_roundtrip() {
let temp_dir = tempfile::tempdir().unwrap();
let socket_path = temp_dir.path().join("encrypted_test.sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let session_token = "test-session-token-for-encrypted-ipc";
let server_token = session_token.to_string();
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let conn = IpcConnection::from_stream(stream);
let mut encrypted_conn = conn.upgrade_to_encrypted(&server_token);
let msg: ViewerMessage = encrypted_conn.reader.read().await.unwrap().unwrap();
match msg {
ViewerMessage::UserInput { content, .. } => {
assert_eq!(content, "Encrypted Hello!");
}
_ => panic!("Unexpected message type"),
}
let response = AgentMessage::Ack {
command: "encrypted_user_input".to_string(),
};
encrypted_conn.writer.write(&response).await.unwrap();
});
let stream = UnixStream::connect(&socket_path).await.unwrap();
let conn = IpcConnection::from_stream(stream);
let mut encrypted_conn = conn.upgrade_to_encrypted(session_token);
let msg = ViewerMessage::UserInput {
content: "Encrypted Hello!".to_string(),
context_files: vec![],
};
encrypted_conn.writer.write(&msg).await.unwrap();
let response: AgentMessage = encrypted_conn.reader.read().await.unwrap().unwrap();
match response {
AgentMessage::Ack { command } => {
assert_eq!(command, "encrypted_user_input");
}
_ => panic!("Unexpected response type"),
}
server_task.await.unwrap();
}
// Client and server derive their ciphers from different tokens, so the
// server's read must return an error instead of a message.
#[tokio::test]
async fn test_encrypted_ipc_wrong_key_fails() {
let temp_dir = tempfile::tempdir().unwrap();
let socket_path = temp_dir.path().join("wrong_key_test.sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let conn = IpcConnection::from_stream(stream);
let mut encrypted_conn = conn.upgrade_to_encrypted("server-token-different");
let result: Result<Option<ViewerMessage>> = encrypted_conn.reader.read().await;
assert!(result.is_err(), "Should fail to decrypt with wrong key");
});
let stream = UnixStream::connect(&socket_path).await.unwrap();
let conn = IpcConnection::from_stream(stream);
let mut encrypted_conn = conn.upgrade_to_encrypted("client-token-different");
let msg = ViewerMessage::UserInput {
content: "This will fail".to_string(),
context_files: vec![],
};
encrypted_conn.writer.write(&msg).await.unwrap();
server_task.await.unwrap();
}
// Five request/response pairs over one encrypted connection, checking that
// framing stays in sync across consecutive messages in both directions.
#[tokio::test]
async fn test_encrypted_multiple_messages() {
let temp_dir = tempfile::tempdir().unwrap();
let socket_path = temp_dir.path().join("multi_msg_test.sock");
let listener = UnixListener::bind(&socket_path).unwrap();
let session_token = "multi-message-token";
let server_token = session_token.to_string();
let server_task = tokio::spawn(async move {
let (stream, _) = listener.accept().await.unwrap();
let conn = IpcConnection::from_stream(stream);
let mut encrypted_conn = conn.upgrade_to_encrypted(&server_token);
for i in 0..5 {
let msg: ViewerMessage = encrypted_conn.reader.read().await.unwrap().unwrap();
match msg {
ViewerMessage::UserInput { content, .. } => {
assert_eq!(content, format!("Message {}", i));
}
_ => panic!("Unexpected message type"),
}
let response = AgentMessage::Ack {
command: format!("ack_{}", i),
};
encrypted_conn.writer.write(&response).await.unwrap();
}
});
let stream = UnixStream::connect(&socket_path).await.unwrap();
let conn = IpcConnection::from_stream(stream);
let mut encrypted_conn = conn.upgrade_to_encrypted(session_token);
for i in 0..5 {
let msg = ViewerMessage::UserInput {
content: format!("Message {}", i),
context_files: vec![],
};
encrypted_conn.writer.write(&msg).await.unwrap();
let response: AgentMessage = encrypted_conn.reader.read().await.unwrap().unwrap();
match response {
AgentMessage::Ack { command } => {
assert_eq!(command, format!("ack_{}", i));
}
_ => panic!("Unexpected response type"),
}
}
server_task.await.unwrap();
}
// Token lifecycle on disk: generation (64 hex chars), write, read back,
// validation with the right and a wrong token, then deletion.
#[test]
fn test_session_token_roundtrip() {
let temp_dir = tempfile::tempdir().unwrap();
let sessions_dir = temp_dir.path();
let token = generate_session_token();
assert_eq!(token.len(), 64);
write_session_token(sessions_dir, "test-session", &token).unwrap();
let read_token = read_session_token(sessions_dir, "test-session").unwrap();
assert_eq!(read_token, Some(token.clone()));
assert!(validate_session_token(sessions_dir, "test-session", &token));
assert!(!validate_session_token(sessions_dir, "test-session", "wrong-token"));
delete_session_token(sessions_dir, "test-session").unwrap();
let read_after_delete = read_session_token(sessions_dir, "test-session").unwrap();
assert_eq!(read_after_delete, None);
}
}