Skip to main content

brainwires_network/ipc/
socket.rs

1//! Unix Socket IPC Utilities
2//!
3//! Provides async read/write helpers for Unix domain socket communication
4//! between the TUI viewer and Agent process.
5//!
6//! # Encryption
7//!
8//! This module provides both plaintext and encrypted IPC:
9//! - `IpcReader`/`IpcWriter` - Plaintext JSON over newlines (legacy, for handshake)
10//! - `EncryptedIpcReader`/`EncryptedIpcWriter` - ChaCha20-Poly1305 encrypted messages
11//!
12//! The encrypted variants use the session token to derive the encryption key,
13//! providing confidentiality and integrity for all messages after handshake.
14
15use anyhow::{Context, Result, bail};
16use serde::{Serialize, de::DeserializeOwned};
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter};
20use tokio::net::UnixStream;
21use tokio::net::unix::{OwnedReadHalf, OwnedWriteHalf};
22
23use super::crypto::IpcCipher;
24
/// Maximum size, in bytes, of a single IPC message (16 MiB).
///
/// Both the plaintext and the encrypted readers/writers in this module reject
/// messages larger than this.
const MAX_MESSAGE_SIZE: usize = 16 << 20;
27
28/// Reader for receiving IPC messages
29pub struct IpcReader {
30    reader: BufReader<OwnedReadHalf>,
31}
32
33impl IpcReader {
34    /// Create a new IPC reader from a Unix stream read half
35    pub fn new(read_half: OwnedReadHalf) -> Self {
36        Self {
37            reader: BufReader::new(read_half),
38        }
39    }
40
41    /// Read a message from the socket
42    ///
43    /// Messages are newline-delimited JSON.
44    /// Returns None on EOF.
45    pub async fn read<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
46        let mut line = String::new();
47        let bytes_read = self.reader.read_line(&mut line).await?;
48
49        if bytes_read == 0 {
50            return Ok(None); // EOF
51        }
52
53        if line.len() > MAX_MESSAGE_SIZE {
54            bail!("Message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
55        }
56
57        let message: T = serde_json::from_str(line.trim())
58            .with_context(|| format!("Failed to parse IPC message: {}", line.trim()))?;
59
60        Ok(Some(message))
61    }
62}
63
64/// Writer for sending IPC messages
65pub struct IpcWriter {
66    writer: BufWriter<OwnedWriteHalf>,
67}
68
69impl IpcWriter {
70    /// Create a new IPC writer from a Unix stream write half
71    pub fn new(write_half: OwnedWriteHalf) -> Self {
72        Self {
73            writer: BufWriter::new(write_half),
74        }
75    }
76
77    /// Write a message to the socket
78    ///
79    /// Messages are serialized as newline-delimited JSON.
80    pub async fn write<T: Serialize>(&mut self, message: &T) -> Result<()> {
81        let json = serde_json::to_string(message).context("Failed to serialize IPC message")?;
82
83        if json.len() > MAX_MESSAGE_SIZE {
84            bail!("Message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
85        }
86
87        self.writer.write_all(json.as_bytes()).await?;
88        self.writer.write_all(b"\n").await?;
89        self.writer.flush().await?;
90
91        Ok(())
92    }
93}
94
95/// IPC connection handle (combines reader and writer)
96pub struct IpcConnection {
97    /// The reader half of the connection.
98    pub reader: IpcReader,
99    /// The writer half of the connection.
100    pub writer: IpcWriter,
101}
102
103impl IpcConnection {
104    /// Create an IPC connection from a Unix stream
105    pub fn from_stream(stream: UnixStream) -> Self {
106        let (read_half, write_half) = stream.into_split();
107        Self {
108            reader: IpcReader::new(read_half),
109            writer: IpcWriter::new(write_half),
110        }
111    }
112
113    /// Connect to an agent socket by path
114    pub async fn connect(socket_path: &Path) -> Result<Self> {
115        if !socket_path.exists() {
116            bail!("Agent socket not found: {}", socket_path.display());
117        }
118
119        let stream = UnixStream::connect(socket_path).await.with_context(|| {
120            format!(
121                "Failed to connect to agent socket: {}",
122                socket_path.display()
123            )
124        })?;
125
126        Ok(Self::from_stream(stream))
127    }
128
129    /// Connect to an agent by session ID, looking up the socket in sessions_dir
130    pub async fn connect_to_agent(sessions_dir: &Path, session_id: &str) -> Result<Self> {
131        let socket_path = get_agent_socket_path(sessions_dir, session_id);
132        Self::connect(&socket_path).await
133    }
134
135    /// Split into reader and writer
136    pub fn split(self) -> (IpcReader, IpcWriter) {
137        (self.reader, self.writer)
138    }
139
140    /// Upgrade to encrypted connection using the session token
141    ///
142    /// This should be called after the handshake is complete and both
143    /// sides have agreed on the session token.
144    pub fn upgrade_to_encrypted(self, session_token: &str) -> EncryptedIpcConnection {
145        let cipher = Arc::new(IpcCipher::from_session_token(session_token));
146        let (read_half, write_half) = (self.reader, self.writer);
147        EncryptedIpcConnection {
148            reader: EncryptedIpcReader::new(read_half, Arc::clone(&cipher)),
149            writer: EncryptedIpcWriter::new(write_half, cipher),
150        }
151    }
152}
153
154// ============================================================================
155// Encrypted IPC (ChaCha20-Poly1305)
156// ============================================================================
157
158/// Encrypted reader for receiving IPC messages
159///
160/// Uses ChaCha20-Poly1305 authenticated encryption.
161/// Message format: [4-byte length][encrypted data]
162pub struct EncryptedIpcReader {
163    inner: IpcReader,
164    cipher: Arc<IpcCipher>,
165}
166
167impl EncryptedIpcReader {
168    /// Create a new encrypted IPC reader
169    pub fn new(reader: IpcReader, cipher: Arc<IpcCipher>) -> Self {
170        Self {
171            inner: reader,
172            cipher,
173        }
174    }
175
176    /// Read an encrypted message from the socket
177    ///
178    /// Messages are length-prefixed encrypted blobs.
179    /// Returns None on EOF.
180    pub async fn read<T: DeserializeOwned>(&mut self) -> Result<Option<T>> {
181        // Read length prefix (4 bytes, big-endian)
182        let mut len_buf = [0u8; 4];
183        match self.inner.reader.read_exact(&mut len_buf).await {
184            Ok(_) => {}
185            Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => return Ok(None),
186            Err(e) => return Err(e.into()),
187        }
188
189        let msg_len = u32::from_be_bytes(len_buf) as usize;
190
191        if msg_len > MAX_MESSAGE_SIZE {
192            bail!(
193                "Encrypted message exceeds maximum size of {} bytes",
194                MAX_MESSAGE_SIZE
195            );
196        }
197
198        // Read encrypted message
199        let mut encrypted = vec![0u8; msg_len];
200        self.inner.reader.read_exact(&mut encrypted).await?;
201
202        // Decrypt
203        let plaintext = self
204            .cipher
205            .decrypt(&encrypted)
206            .context("Failed to decrypt IPC message")?;
207
208        // Deserialize
209        let message: T =
210            serde_json::from_slice(&plaintext).context("Failed to parse decrypted IPC message")?;
211
212        Ok(Some(message))
213    }
214}
215
216/// Encrypted writer for sending IPC messages
217///
218/// Uses ChaCha20-Poly1305 authenticated encryption.
219/// Message format: [4-byte length][encrypted data]
220pub struct EncryptedIpcWriter {
221    inner: IpcWriter,
222    cipher: Arc<IpcCipher>,
223}
224
225impl EncryptedIpcWriter {
226    /// Create a new encrypted IPC writer
227    pub fn new(writer: IpcWriter, cipher: Arc<IpcCipher>) -> Self {
228        Self {
229            inner: writer,
230            cipher,
231        }
232    }
233
234    /// Write an encrypted message to the socket
235    ///
236    /// Messages are serialized, encrypted, then length-prefixed.
237    pub async fn write<T: Serialize>(&mut self, message: &T) -> Result<()> {
238        // Serialize
239        let json = serde_json::to_vec(message).context("Failed to serialize IPC message")?;
240
241        if json.len() > MAX_MESSAGE_SIZE {
242            bail!("Message exceeds maximum size of {} bytes", MAX_MESSAGE_SIZE);
243        }
244
245        // Encrypt
246        let encrypted = self
247            .cipher
248            .encrypt(&json)
249            .context("Failed to encrypt IPC message")?;
250
251        // Write length prefix (4 bytes, big-endian)
252        let len_buf = (encrypted.len() as u32).to_be_bytes();
253        self.inner.writer.write_all(&len_buf).await?;
254
255        // Write encrypted message
256        self.inner.writer.write_all(&encrypted).await?;
257        self.inner.writer.flush().await?;
258
259        Ok(())
260    }
261}
262
263/// Encrypted IPC connection handle
264pub struct EncryptedIpcConnection {
265    /// The encrypted reader half of the connection.
266    pub reader: EncryptedIpcReader,
267    /// The encrypted writer half of the connection.
268    pub writer: EncryptedIpcWriter,
269}
270
271impl EncryptedIpcConnection {
272    /// Create an encrypted IPC connection from a Unix stream and session token
273    pub fn from_stream(stream: UnixStream, session_token: &str) -> Self {
274        let cipher = Arc::new(IpcCipher::from_session_token(session_token));
275        let (read_half, write_half) = stream.into_split();
276        Self {
277            reader: EncryptedIpcReader::new(IpcReader::new(read_half), Arc::clone(&cipher)),
278            writer: EncryptedIpcWriter::new(IpcWriter::new(write_half), cipher),
279        }
280    }
281
282    /// Split into encrypted reader and writer
283    pub fn split(self) -> (EncryptedIpcReader, EncryptedIpcWriter) {
284        (self.reader, self.writer)
285    }
286}
287
288// ============================================================================
289// Path Helpers
290// ============================================================================
291
/// Return the Unix socket path for an agent session: `<sessions_dir>/<session_id>.sock`.
///
/// `session_id` is appended verbatim via path joining. An ID containing a path
/// separator, or an absolute path, therefore resolves outside `sessions_dir`. Callers
/// passing untrusted IDs must validate them first.
pub fn get_agent_socket_path(sessions_dir: &Path, session_id: &str) -> PathBuf {
    let file_name = format!("{session_id}.sock");
    let mut path = sessions_dir.to_path_buf();
    path.push(file_name);
    path
}
296
/// Return the token file path for an agent session: `<sessions_dir>/<session_id>.token`.
///
/// `session_id` is appended verbatim via path joining. An ID containing a path
/// separator, or an absolute path, therefore resolves outside `sessions_dir`. Callers
/// passing untrusted IDs must validate them first.
pub fn get_session_token_path(sessions_dir: &Path, session_id: &str) -> PathBuf {
    let file_name = [session_id, ".token"].concat();
    sessions_dir.join(file_name)
}
301
302// ============================================================================
303// Session Token Management (for secure IPC authentication)
304// ============================================================================
305
306/// Generate a cryptographically secure session token (64 hex characters = 256 bits)
307pub fn generate_session_token() -> String {
308    use rand::Rng;
309    let mut bytes = [0u8; 32];
310    rand::rng().fill_bytes(&mut bytes);
311    hex::encode(bytes)
312}
313
314/// Write session token to disk with secure permissions (0600)
315/// This should only be called by the agent process that owns the session
316pub fn write_session_token(sessions_dir: &Path, session_id: &str, token: &str) -> Result<()> {
317    let token_path = get_session_token_path(sessions_dir, session_id);
318
319    // Ensure parent directory exists
320    if let Some(parent) = token_path.parent() {
321        std::fs::create_dir_all(parent)?;
322    }
323
324    // Write token
325    std::fs::write(&token_path, token)?;
326
327    // Set secure permissions (0600 = owner read/write only)
328    #[cfg(unix)]
329    {
330        use std::os::unix::fs::PermissionsExt;
331        std::fs::set_permissions(&token_path, std::fs::Permissions::from_mode(0o600))?;
332    }
333
334    tracing::debug!(
335        "Wrote session token: {} (0600 permissions)",
336        token_path.display()
337    );
338    Ok(())
339}
340
341/// Read session token from disk
342/// This is used by clients that need to reattach to a session
343pub fn read_session_token(sessions_dir: &Path, session_id: &str) -> Result<Option<String>> {
344    let token_path = get_session_token_path(sessions_dir, session_id);
345
346    if !token_path.exists() {
347        return Ok(None);
348    }
349
350    let token = std::fs::read_to_string(&token_path)
351        .with_context(|| format!("Failed to read session token from {}", token_path.display()))?;
352
353    Ok(Some(token.trim().to_string()))
354}
355
356/// Delete session token file
357pub fn delete_session_token(sessions_dir: &Path, session_id: &str) -> Result<()> {
358    let token_path = get_session_token_path(sessions_dir, session_id);
359
360    if token_path.exists() {
361        std::fs::remove_file(&token_path)
362            .with_context(|| format!("Failed to delete session token: {}", token_path.display()))?;
363        tracing::debug!("Deleted session token: {}", token_path.display());
364    }
365
366    Ok(())
367}
368
369/// Validate that a provided token matches the stored token for a session
370/// Returns true if tokens match, false if they don't match or no token exists
371pub fn validate_session_token(sessions_dir: &Path, session_id: &str, provided_token: &str) -> bool {
372    match read_session_token(sessions_dir, session_id) {
373        Ok(Some(stored_token)) => {
374            // Use constant-time comparison to prevent timing attacks
375            use subtle::ConstantTimeEq;
376            provided_token
377                .as_bytes()
378                .ct_eq(stored_token.as_bytes())
379                .into()
380        }
381        Ok(None) => {
382            tracing::warn!("No session token found for session {}", session_id);
383            false
384        }
385        Err(e) => {
386            tracing::error!("Failed to read session token for {}: {}", session_id, e);
387            false
388        }
389    }
390}
391
#[cfg(test)]
mod tests {
    use super::super::protocol::{AgentMessage, ViewerMessage};
    use super::*;
    use tokio::net::UnixListener;

    /// Create a temp directory and bind a listener on `file_name` inside it.
    ///
    /// The returned `TempDir` must stay alive for the whole test, because dropping it
    /// deletes the directory holding the socket.
    fn bind_temp_socket(file_name: &str) -> (tempfile::TempDir, PathBuf, UnixListener) {
        let dir = tempfile::tempdir().unwrap();
        let socket_path = dir.path().join(file_name);
        let listener = UnixListener::bind(&socket_path).unwrap();
        (dir, socket_path, listener)
    }

    /// Build the viewer-side `UserInput` message used throughout these tests.
    fn user_input(content: impl Into<String>) -> ViewerMessage {
        ViewerMessage::UserInput {
            content: content.into(),
            context_files: vec![],
        }
    }

    #[tokio::test]
    async fn test_ipc_roundtrip() {
        let (_dir, socket_path, listener) = bind_temp_socket("test.sock");

        // Agent side: expect one plaintext UserInput, answer with an Ack.
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let IpcConnection {
                mut reader,
                mut writer,
            } = IpcConnection::from_stream(stream);

            let incoming: ViewerMessage = reader.read().await.unwrap().unwrap();
            let ViewerMessage::UserInput { content, .. } = incoming else {
                panic!("Unexpected message type");
            };
            assert_eq!(content, "Hello");

            let ack = AgentMessage::Ack {
                command: "user_input".to_string(),
            };
            writer.write(&ack).await.unwrap();
        });

        // Viewer side: send the input, then wait for the Ack.
        let stream = UnixStream::connect(&socket_path).await.unwrap();
        let IpcConnection {
            mut reader,
            mut writer,
        } = IpcConnection::from_stream(stream);

        writer.write(&user_input("Hello")).await.unwrap();

        let reply: AgentMessage = reader.read().await.unwrap().unwrap();
        let AgentMessage::Ack { command } = reply else {
            panic!("Unexpected response type");
        };
        assert_eq!(command, "user_input");

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_encrypted_ipc_roundtrip() {
        let (_dir, socket_path, listener) = bind_temp_socket("encrypted_test.sock");

        // Both ends derive their cipher from this shared token.
        const SESSION_TOKEN: &str = "test-session-token-for-encrypted-ipc";

        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let EncryptedIpcConnection {
                mut reader,
                mut writer,
            } = IpcConnection::from_stream(stream).upgrade_to_encrypted(SESSION_TOKEN);

            let incoming: ViewerMessage = reader.read().await.unwrap().unwrap();
            let ViewerMessage::UserInput { content, .. } = incoming else {
                panic!("Unexpected message type");
            };
            assert_eq!(content, "Encrypted Hello!");

            let ack = AgentMessage::Ack {
                command: "encrypted_user_input".to_string(),
            };
            writer.write(&ack).await.unwrap();
        });

        let stream = UnixStream::connect(&socket_path).await.unwrap();
        let EncryptedIpcConnection {
            mut reader,
            mut writer,
        } = IpcConnection::from_stream(stream).upgrade_to_encrypted(SESSION_TOKEN);

        writer.write(&user_input("Encrypted Hello!")).await.unwrap();

        let reply: AgentMessage = reader.read().await.unwrap().unwrap();
        let AgentMessage::Ack { command } = reply else {
            panic!("Unexpected response type");
        };
        assert_eq!(command, "encrypted_user_input");

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_encrypted_ipc_wrong_key_fails() {
        let (_dir, socket_path, listener) = bind_temp_socket("wrong_key_test.sock");

        // The agent derives its key from a different token than the viewer does.
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let mut agent =
                IpcConnection::from_stream(stream).upgrade_to_encrypted("server-token-different");

            let outcome: Result<Option<ViewerMessage>> = agent.reader.read().await;
            assert!(outcome.is_err(), "Should fail to decrypt with wrong key");
        });

        let stream = UnixStream::connect(&socket_path).await.unwrap();
        let mut viewer =
            IpcConnection::from_stream(stream).upgrade_to_encrypted("client-token-different");
        viewer
            .writer
            .write(&user_input("This will fail"))
            .await
            .unwrap();

        server_task.await.unwrap();
    }

    #[tokio::test]
    async fn test_encrypted_multiple_messages() {
        let (_dir, socket_path, listener) = bind_temp_socket("multi_msg_test.sock");
        const SESSION_TOKEN: &str = "multi-message-token";
        const ROUNDS: usize = 5;

        // Agent side: answer each numbered message with a matching numbered Ack.
        let server_task = tokio::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let EncryptedIpcConnection {
                mut reader,
                mut writer,
            } = IpcConnection::from_stream(stream).upgrade_to_encrypted(SESSION_TOKEN);

            for round in 0..ROUNDS {
                let incoming: ViewerMessage = reader.read().await.unwrap().unwrap();
                let ViewerMessage::UserInput { content, .. } = incoming else {
                    panic!("Unexpected message type");
                };
                assert_eq!(content, format!("Message {}", round));

                let ack = AgentMessage::Ack {
                    command: format!("ack_{}", round),
                };
                writer.write(&ack).await.unwrap();
            }
        });

        let stream = UnixStream::connect(&socket_path).await.unwrap();
        let EncryptedIpcConnection {
            mut reader,
            mut writer,
        } = IpcConnection::from_stream(stream).upgrade_to_encrypted(SESSION_TOKEN);

        // Viewer side: strict request/response lock-step.
        for round in 0..ROUNDS {
            writer
                .write(&user_input(format!("Message {}", round)))
                .await
                .unwrap();

            let reply: AgentMessage = reader.read().await.unwrap().unwrap();
            let AgentMessage::Ack { command } = reply else {
                panic!("Unexpected response type");
            };
            assert_eq!(command, format!("ack_{}", round));
        }

        server_task.await.unwrap();
    }

    #[test]
    fn test_session_token_roundtrip() {
        let temp_dir = tempfile::tempdir().unwrap();
        let dir = temp_dir.path();
        const SESSION: &str = "test-session";

        let token = generate_session_token();
        assert_eq!(token.len(), 64); // 32 random bytes, hex-encoded

        write_session_token(dir, SESSION, &token).unwrap();
        assert_eq!(
            read_session_token(dir, SESSION).unwrap(),
            Some(token.clone())
        );

        assert!(validate_session_token(dir, SESSION, &token));
        assert!(!validate_session_token(dir, SESSION, "wrong-token"));

        delete_session_token(dir, SESSION).unwrap();
        assert_eq!(read_session_token(dir, SESSION).unwrap(), None);
    }
}