1use anyhow::{Context, Result, bail};
16use serde::{Serialize, de::DeserializeOwned};
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
20use tokio::net::UnixStream;
21use tokio::net::unix::{OwnedReadHalf, OwnedWriteHalf};
22
23use super::crypto::IpcCipher;
24
/// Upper bound, in bytes, enforced on a single IPC message: the JSON body of a
/// plaintext line, or the length prefix of an encrypted frame (16 MiB).
const MAX_MESSAGE_SIZE: usize = 16 << 20;
27
28pub struct IpcReader {
30 reader: BufReader<OwnedReadHalf>,
31}
32
33impl IpcReader {
34 pub fn new(read_half: OwnedReadHalf) -> Self {
36 Self {
37 reader: BufReader::new(read_half),
38 }
39 }
40
41 pub async fn read<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
46 let mut line = String::new();
47 let bytes_read = self.reader.read_line(&mut line).await?;
48
49 if bytes_read == 0 {
50 return Ok(None); }
52
53 if line.len() > MAX_MESSAGE_SIZE {
54 bail!("Message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
55 }
56
57 let message: T = serde_json::from_str(line.trim())
58 .with_context(|| format!("Failed to parse IPC message: {}", line.trim()))?;
59
60 Ok(Some(message))
61 }
62}
63
64pub struct IpcWriter {
66 writer: BufWriter<OwnedWriteHalf>,
67}
68
69impl IpcWriter {
70 pub fn new(write_half: OwnedWriteHalf) -> Self {
72 Self {
73 writer: BufWriter::new(write_half),
74 }
75 }
76
77 pub async fn write<T: Serialize>(&mut self, message: &T) -> Result<()> {
81 let json = serde_json::to_string(message).context("Failed to serialize IPC message")?;
82
83 if json.len() > MAX_MESSAGE_SIZE {
84 bail!("Message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
85 }
86
87 self.writer.write_all(json.as_bytes()).await?;
88 self.writer.write_all(b"\n").await?;
89 self.writer.flush().await?;
90
91 Ok(())
92 }
93}
94
95pub struct IpcConnection {
97 pub reader: IpcReader,
99 pub writer: IpcWriter,
101}
102
103impl IpcConnection {
104 pub fn from_stream(stream: UnixStream) -> Self {
106 let (read_half, write_half) = stream.into_split();
107 Self {
108 reader: IpcReader::new(read_half),
109 writer: IpcWriter::new(write_half),
110 }
111 }
112
113 pub async fn connect(socket_path: &Path) -> Result<Self> {
115 if !socket_path.exists() {
116 bail!("Agent socket not found: {}", socket_path.display());
117 }
118
119 let stream = UnixStream::connect(socket_path).await.with_context(|| {
120 format!(
121 "Failed to connect to agent socket: {}",
122 socket_path.display()
123 )
124 })?;
125
126 Ok(Self::from_stream(stream))
127 }
128
129 pub async fn connect_to_agent(sessions_dir: &Path, session_id: &str) -> Result<Self> {
131 let socket_path = get_agent_socket_path(sessions_dir, session_id);
132 Self::connect(&socket_path).await
133 }
134
135 pub fn split(self) -> (IpcReader, IpcWriter) {
137 (self.reader, self.writer)
138 }
139
140 pub fn upgrade_to_encrypted(self, session_token: &str) -> EncryptedIpcConnection {
145 let cipher = Arc::new(IpcCipher::from_session_token(session_token));
146 let (read_half, write_half) = (self.reader, self.writer);
147 EncryptedIpcConnection {
148 reader: EncryptedIpcReader::new(read_half, Arc::clone(&cipher)),
149 writer: EncryptedIpcWriter::new(write_half, cipher),
150 }
151 }
152}
153
154pub struct EncryptedIpcReader {
163 inner: IpcReader,
164 cipher: Arc<IpcCipher>,
165}
166
167impl EncryptedIpcReader {
168 pub fn new(reader: IpcReader, cipher: Arc<IpcCipher>) -> Self {
170 Self {
171 inner: reader,
172 cipher,
173 }
174 }
175
176 pub async fn read<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
181 let mut len_buf = [0u8; 4];
183 match self.inner.reader.read_exact(&mut len_buf).await {
184 Ok(_) => {}
185 Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
186 Err(e) => return Err(e.into()),
187 }
188
189 let msg_len = u32::from_be_bytes(len_buf) as usize;
190
191 if msg_len > MAX_MESSAGE_SIZE {
192 bail!(
193 "Encrypted message exceeds maximum size of {} bytes",
194 MAX_MESSAGE_SIZE
195 );
196 }
197
198 let mut encrypted = vec![0u8; msg_len];
200 self.inner.reader.read_exact(&mut encrypted).await?;
201
202 let plaintext = self
204 .cipher
205 .decrypt(&encrypted)
206 .context("Failed to decrypt IPC message")?;
207
208 let message: T =
210 serde_json::from_slice(&plaintext).context("Failed to parse decrypted IPC message")?;
211
212 Ok(Some(message))
213 }
214}
215
216pub struct EncryptedIpcWriter {
221 inner: IpcWriter,
222 cipher: Arc<IpcCipher>,
223}
224
225impl EncryptedIpcWriter {
226 pub fn new(writer: IpcWriter, cipher: Arc<IpcCipher>) -> Self {
228 Self {
229 inner: writer,
230 cipher,
231 }
232 }
233
234 pub async fn write<T: Serialize>(&mut self, message: &T) -> Result<()> {
238 let json = serde_json::to_vec(message).context("Failed to serialize IPC message")?;
240
241 if json.len() > MAX_MESSAGE_SIZE {
242 bail!("Message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
243 }
244
245 let encrypted = self
247 .cipher
248 .encrypt(&json)
249 .context("Failed to encrypt IPC message")?;
250
251 let len_buf = (encrypted.len() as u32).to_be_bytes();
253 self.inner.writer.write_all(&len_buf).await?;
254
255 self.inner.writer.write_all(&encrypted).await?;
257 self.inner.writer.flush().await?;
258
259 Ok(())
260 }
261}
262
263pub struct EncryptedIpcConnection {
265 pub reader: EncryptedIpcReader,
267 pub writer: EncryptedIpcWriter,
269}
270
271impl EncryptedIpcConnection {
272 pub fn from_stream(stream: UnixStream, session_token: &str) -> Self {
274 let cipher = Arc::new(IpcCipher::from_session_token(session_token));
275 let (read_half, write_half) = stream.into_split();
276 Self {
277 reader: EncryptedIpcReader::new(IpcReader::new(read_half), Arc::clone(&cipher)),
278 writer: EncryptedIpcWriter::new(IpcWriter::new(write_half), cipher),
279 }
280 }
281
282 pub fn split(self) -> (EncryptedIpcReader, EncryptedIpcWriter) {
284 (self.reader, self.writer)
285 }
286}
287
/// Returns the path of the agent's Unix socket: `<sessions_dir>/<session_id>.sock`.
pub fn get_agent_socket_path(sessions_dir: &Path, session_id: &str) -> PathBuf {
    let file_name = [session_id, ".sock"].concat();
    sessions_dir.join(file_name)
}
296
/// Returns the path of the session token file: `<sessions_dir>/<session_id>.token`.
pub fn get_session_token_path(sessions_dir: &Path, session_id: &str) -> PathBuf {
    let file_name = [session_id, ".token"].concat();
    sessions_dir.join(file_name)
}
301
302pub fn generate_session_token() -> String {
308 use rand::Rng;
309 let mut bytes = [0u8; 32];
310 rand::rng().fill_bytes(&mut bytes);
311 hex::encode(bytes)
312}
313
314pub fn write_session_token(sessions_dir: &Path, session_id: &str, token: &str) -> Result<()> {
317 let token_path = get_session_token_path(sessions_dir, session_id);
318
319 if let Some(parent) = token_path.parent() {
321 std::fs::create_dir_all(parent)?;
322 }
323
324 std::fs::write(&token_path, token)?;
326
327 #[cfg(unix)]
329 {
330 use std::os::unix::fs::PermissionsExt;
331 std::fs::set_permissions(&token_path, std::fs::Permissions::from_mode(0o600))?;
332 }
333
334 tracing::debug!(
335 "Wrote session token: {} (0600 permissions)",
336 token_path.display()
337 );
338 Ok(())
339}
340
341pub fn read_session_token(sessions_dir: &Path, session_id: &str) -> Result<Option<String>> {
344 let token_path = get_session_token_path(sessions_dir, session_id);
345
346 if !token_path.exists() {
347 return Ok(None);
348 }
349
350 let token = std::fs::read_to_string(&token_path)
351 .with_context(|| format!("Failed to read session token from {}", token_path.display()))?;
352
353 Ok(Some(token.trim().to_string()))
354}
355
356pub fn delete_session_token(sessions_dir: &Path, session_id: &str) -> Result<()> {
358 let token_path = get_session_token_path(sessions_dir, session_id);
359
360 if token_path.exists() {
361 std::fs::remove_file(&token_path)
362 .with_context(|| format!("Failed to delete session token: {}", token_path.display()))?;
363 tracing::debug!("Deleted session token: {}", token_path.display());
364 }
365
366 Ok(())
367}
368
369pub fn validate_session_token(sessions_dir: &Path, session_id: &str, provided_token: &str) -> bool {
372 match read_session_token(sessions_dir, session_id) {
373 Ok(Some(stored_token)) => {
374 use subtle::ConstantTimeEq;
376 provided_token
377 .as_bytes()
378 .ct_eq(stored_token.as_bytes())
379 .into()
380 }
381 Ok(None) => {
382 tracing::warn!("No session token found for session {}", session_id);
383 false
384 }
385 Err(e) => {
386 tracing::error!("Failed to read session token for {}: {}", session_id, e);
387 false
388 }
389 }
390}
391
#[cfg(test)]
mod tests {
    use super::super::protocol::{AgentMessage, ViewerMessage};
    use super::*;
    use tokio::net::UnixListener;

    #[tokio::test]
    async fn test_ipc_roundtrip() {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket_path = temp_dir.path().join("test.sock");

        let listener = UnixListener::bind(&socket_path).unwrap();

        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut conn = IpcConnection::from_stream(stream);

            let msg: ViewerMessage = conn.reader.read().await.unwrap().unwrap();
            match msg {
                ViewerMessage::UserInput { content, .. } => {
                    assert_eq!(content, "Hello");
                }
                _ => panic!("Unexpected message type"),
            }

            let response = AgentMessage::Ack {
                command: "user_input".to_string(),
            };
            conn.writer.write(&response).await.unwrap();
        });

        let stream = UnixStream::connect(&socket_path).await.unwrap();
        let mut conn = IpcConnection::from_stream(stream);

        let msg = ViewerMessage::UserInput {
            content: "Hello".to_string(),
            context_files: vec![],
        };
        conn.writer.write(&msg).await.unwrap();

        let response: AgentMessage = conn.reader.read().await.unwrap().unwrap();
        match response {
            AgentMessage::Ack { command } => {
                assert_eq!(command, "user_input");
            }
            _ => panic!("Unexpected response type"),
        }

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_ipc_clean_eof_returns_none() {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket_path = temp_dir.path().join("eof_test.sock");

        let listener = UnixListener::bind(&socket_path).unwrap();

        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut conn = IpcConnection::from_stream(stream);

            let msg: Option<ViewerMessage> = conn.reader.read().await.unwrap();
            assert!(msg.is_none(), "Clean close should yield Ok(None)");
        });

        // Connect and immediately close without sending anything.
        let stream = UnixStream::connect(&socket_path).await.unwrap();
        drop(stream);

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_encrypted_ipc_roundtrip() {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket_path = temp_dir.path().join("encrypted_test.sock");

        let listener = UnixListener::bind(&socket_path).unwrap();

        let session_token = "test-session-token-for-encrypted-ipc";

        let server_token = session_token.to_string();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();

            let conn = IpcConnection::from_stream(stream);
            let mut encrypted_conn = conn.upgrade_to_encrypted(&server_token);

            let msg: ViewerMessage = encrypted_conn.reader.read().await.unwrap().unwrap();
            match msg {
                ViewerMessage::UserInput { content, .. } => {
                    assert_eq!(content, "Encrypted Hello!");
                }
                _ => panic!("Unexpected message type"),
            }

            let response = AgentMessage::Ack {
                command: "encrypted_user_input".to_string(),
            };
            encrypted_conn.writer.write(&response).await.unwrap();
        });

        let stream = UnixStream::connect(&socket_path).await.unwrap();
        let conn = IpcConnection::from_stream(stream);
        let mut encrypted_conn = conn.upgrade_to_encrypted(session_token);

        let msg = ViewerMessage::UserInput {
            content: "Encrypted Hello!".to_string(),
            context_files: vec![],
        };
        encrypted_conn.writer.write(&msg).await.unwrap();

        let response: AgentMessage = encrypted_conn.reader.read().await.unwrap().unwrap();
        match response {
            AgentMessage::Ack { command } => {
                assert_eq!(command, "encrypted_user_input");
            }
            _ => panic!("Unexpected response type"),
        }

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_encrypted_clean_eof_returns_none() {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket_path = temp_dir.path().join("encrypted_eof_test.sock");

        let listener = UnixListener::bind(&socket_path).unwrap();

        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let conn = IpcConnection::from_stream(stream);
            let mut encrypted_conn = conn.upgrade_to_encrypted("eof-token");

            let msg: Option<ViewerMessage> = encrypted_conn.reader.read().await.unwrap();
            assert!(msg.is_none(), "Clean close should yield Ok(None)");
        });

        // Connect and immediately close without sending a frame.
        let stream = UnixStream::connect(&socket_path).await.unwrap();
        drop(stream);

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_encrypted_ipc_wrong_key_fails() {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket_path = temp_dir.path().join("wrong_key_test.sock");

        let listener = UnixListener::bind(&socket_path).unwrap();

        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();

            let conn = IpcConnection::from_stream(stream);
            let mut encrypted_conn = conn.upgrade_to_encrypted("server-token-different");

            let result: Result<Option<ViewerMessage>> = encrypted_conn.reader.read().await;
            assert!(result.is_err(), "Should fail to decrypt with wrong key");
        });

        let stream = UnixStream::connect(&socket_path).await.unwrap();
        let conn = IpcConnection::from_stream(stream);
        let mut encrypted_conn = conn.upgrade_to_encrypted("client-token-different");

        let msg = ViewerMessage::UserInput {
            content: "This will fail".to_string(),
            context_files: vec![],
        };
        encrypted_conn.writer.write(&msg).await.unwrap();

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_encrypted_multiple_messages() {
        let temp_dir = tempfile::tempdir().unwrap();
        let socket_path = temp_dir.path().join("multi_msg_test.sock");

        let listener = UnixListener::bind(&socket_path).unwrap();
        let session_token = "multi-message-token";

        let server_token = session_token.to_string();
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let conn = IpcConnection::from_stream(stream);
            let mut encrypted_conn = conn.upgrade_to_encrypted(&server_token);

            for i in 0..5 {
                let msg: ViewerMessage = encrypted_conn.reader.read().await.unwrap().unwrap();
                match msg {
                    ViewerMessage::UserInput { content, .. } => {
                        assert_eq!(content, format!("Message {}", i));
                    }
                    _ => panic!("Unexpected message type"),
                }

                let response = AgentMessage::Ack {
                    command: format!("ack_{}", i),
                };
                encrypted_conn.writer.write(&response).await.unwrap();
            }
        });

        let stream = UnixStream::connect(&socket_path).await.unwrap();
        let conn = IpcConnection::from_stream(stream);
        let mut encrypted_conn = conn.upgrade_to_encrypted(session_token);

        for i in 0..5 {
            let msg = ViewerMessage::UserInput {
                content: format!("Message {}", i),
                context_files: vec![],
            };
            encrypted_conn.writer.write(&msg).await.unwrap();

            let response: AgentMessage = encrypted_conn.reader.read().await.unwrap().unwrap();
            match response {
                AgentMessage::Ack { command } => {
                    assert_eq!(command, format!("ack_{}", i));
                }
                _ => panic!("Unexpected response type"),
            }
        }

        server_task.await.unwrap();
    }

    #[test]
    fn test_session_token_roundtrip() {
        let temp_dir = tempfile::tempdir().unwrap();
        let sessions_dir = temp_dir.path();

        let token = generate_session_token();
        // 32 random bytes, hex-encoded.
        assert_eq!(token.len(), 64);

        write_session_token(sessions_dir, "test-session", &token).unwrap();
        let read_token = read_session_token(sessions_dir, "test-session").unwrap();
        assert_eq!(read_token, Some(token.clone()));

        assert!(validate_session_token(sessions_dir, "test-session", &token));
        assert!(!validate_session_token(
            sessions_dir,
            "test-session",
            "wrong-token"
        ));

        delete_session_token(sessions_dir, "test-session").unwrap();
        let read_after_delete = read_session_token(sessions_dir, "test-session").unwrap();
        assert_eq!(read_after_delete, None);
    }

    #[test]
    fn test_session_token_missing() {
        let temp_dir = tempfile::tempdir().unwrap();
        let sessions_dir = temp_dir.path();

        assert_eq!(read_session_token(sessions_dir, "absent").unwrap(), None);
        assert!(!validate_session_token(sessions_dir, "absent", "anything"));
        // Deleting a token that was never written is not an error.
        delete_session_token(sessions_dir, "absent").unwrap();
    }

    #[cfg(unix)]
    #[test]
    fn test_session_token_file_permissions() {
        use std::os::unix::fs::PermissionsExt;

        let temp_dir = tempfile::tempdir().unwrap();
        let sessions_dir = temp_dir.path();

        write_session_token(sessions_dir, "perm-session", "secret").unwrap();
        let mode = std::fs::metadata(get_session_token_path(sessions_dir, "perm-session"))
            .unwrap()
            .permissions()
            .mode();
        assert_eq!(mode & 0o777, 0o600);
    }
}