Skip to main content

ai_session/ipc/
mod.rs

1//! Native IPC implementation using Unix domain sockets
2
3use anyhow::Result;
4use serde::{Deserialize, Serialize};
5use std::path::PathBuf;
6use std::sync::Arc;
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::net::{UnixListener, UnixStream};
9use uuid::Uuid;
10
11/// IPC message envelope
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct IpcMessage {
14    /// Message ID
15    pub id: String,
16    /// Message type
17    pub msg_type: IpcMessageType,
18    /// Payload
19    pub payload: serde_json::Value,
20    /// Timestamp
21    pub timestamp: chrono::DateTime<chrono::Utc>,
22}
23
24/// IPC message types
25#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
26pub enum IpcMessageType {
27    /// Create session
28    CreateSession,
29    /// Execute command
30    ExecuteCommand,
31    /// Get output
32    GetOutput,
33    /// Get status
34    GetStatus,
35    /// List sessions
36    ListSessions,
37    /// Delete session
38    DeleteSession,
39    /// Response
40    Response,
41    /// Error
42    Error,
43    /// Event notification
44    Event,
45}
46
47/// IPC server for native communication
48pub struct IpcServer {
49    /// Socket path
50    socket_path: PathBuf,
51    /// Session manager
52    session_manager: Arc<crate::SessionManager>,
53}
54
55impl IpcServer {
56    /// Create new IPC server
57    pub fn new(socket_path: PathBuf, session_manager: Arc<crate::SessionManager>) -> Self {
58        Self {
59            socket_path,
60            session_manager,
61        }
62    }
63
64    /// Start the IPC server
65    pub async fn start(&self) -> Result<()> {
66        // Remove existing socket if it exists
67        if self.socket_path.exists() {
68            tokio::fs::remove_file(&self.socket_path).await?;
69        }
70
71        // Create parent directory if needed
72        if let Some(parent) = self.socket_path.parent() {
73            tokio::fs::create_dir_all(parent).await?;
74        }
75
76        // Bind to socket
77        let listener = UnixListener::bind(&self.socket_path)?;
78        tracing::info!("IPC server listening on {:?}", self.socket_path);
79
80        loop {
81            let (stream, _) = listener.accept().await?;
82            let session_manager = self.session_manager.clone();
83
84            // Handle connection in separate task
85            tokio::spawn(async move {
86                if let Err(e) = handle_client(stream, session_manager).await {
87                    tracing::error!("Client handler error: {}", e);
88                }
89            });
90        }
91    }
92}
93
94/// Handle client connection
95async fn handle_client(
96    stream: UnixStream,
97    session_manager: Arc<crate::SessionManager>,
98) -> Result<()> {
99    let (reader, mut writer) = stream.into_split();
100    let mut reader = BufReader::new(reader);
101    let mut line = String::new();
102
103    loop {
104        line.clear();
105        match reader.read_line(&mut line).await {
106            Ok(0) => break, // EOF
107            Ok(_) => {
108                // Parse message
109                let msg: IpcMessage = match serde_json::from_str(&line) {
110                    Ok(msg) => msg,
111                    Err(e) => {
112                        let error_response = IpcMessage {
113                            id: Uuid::new_v4().to_string(),
114                            msg_type: IpcMessageType::Error,
115                            payload: serde_json::json!({
116                                "error": format!("Invalid message format: {}", e)
117                            }),
118                            timestamp: chrono::Utc::now(),
119                        };
120                        writer
121                            .write_all(serde_json::to_string(&error_response)?.as_bytes())
122                            .await?;
123                        writer.write_all(b"\n").await?;
124                        writer.flush().await?;
125                        continue;
126                    }
127                };
128
129                // Process message
130                let response = process_message(msg, &session_manager).await?;
131
132                // Send response
133                writer
134                    .write_all(serde_json::to_string(&response)?.as_bytes())
135                    .await?;
136                writer.write_all(b"\n").await?;
137                writer.flush().await?;
138            }
139            Err(e) => {
140                tracing::error!("Read error: {}", e);
141                break;
142            }
143        }
144    }
145
146    Ok(())
147}
148
149/// Process IPC message
150async fn process_message(
151    msg: IpcMessage,
152    session_manager: &Arc<crate::SessionManager>,
153) -> Result<IpcMessage> {
154    match msg.msg_type {
155        IpcMessageType::CreateSession => {
156            let ai_features = msg.payload["enable_ai_features"].as_bool().unwrap_or(false);
157
158            let mut config = crate::core::SessionConfig::default();
159            config.enable_ai_features = ai_features;
160
161            let session = session_manager.create_session_with_config(config).await?;
162
163            Ok(IpcMessage {
164                id: msg.id,
165                msg_type: IpcMessageType::Response,
166                payload: serde_json::json!({
167                    "success": true,
168                    "session_id": session.id.to_string(),
169                }),
170                timestamp: chrono::Utc::now(),
171            })
172        }
173
174        IpcMessageType::ExecuteCommand => {
175            let session_id = msg.payload["session"].as_str().unwrap_or("");
176            let command = msg.payload["command"].as_str().unwrap_or("");
177
178            let session_id = crate::core::SessionId::parse_str(session_id)?;
179
180            if let Some(session) = session_manager.get_session(&session_id) {
181                session.send_input(&format!("{}\n", command)).await?;
182
183                Ok(IpcMessage {
184                    id: msg.id,
185                    msg_type: IpcMessageType::Response,
186                    payload: serde_json::json!({
187                        "success": true,
188                    }),
189                    timestamp: chrono::Utc::now(),
190                })
191            } else {
192                Ok(IpcMessage {
193                    id: msg.id,
194                    msg_type: IpcMessageType::Error,
195                    payload: serde_json::json!({
196                        "error": "Session not found"
197                    }),
198                    timestamp: chrono::Utc::now(),
199                })
200            }
201        }
202
203        IpcMessageType::GetOutput => {
204            let session_id = msg.payload["session"].as_str().unwrap_or("");
205            let last_n = msg.payload["last_n"].as_u64().unwrap_or(100) as usize;
206
207            let session_id = crate::core::SessionId::parse_str(session_id)?;
208
209            if let Some(session) = session_manager.get_session(&session_id) {
210                let output = session.read_output().await?;
211                let output_str = String::from_utf8_lossy(&output);
212                let all_lines: Vec<&str> = output_str.lines().collect();
213                let lines: Vec<String> = all_lines
214                    .iter()
215                    .rev()
216                    .take(last_n)
217                    .rev()
218                    .map(|s| s.to_string())
219                    .collect();
220
221                Ok(IpcMessage {
222                    id: msg.id,
223                    msg_type: IpcMessageType::Response,
224                    payload: serde_json::json!({
225                        "output": lines,
226                    }),
227                    timestamp: chrono::Utc::now(),
228                })
229            } else {
230                Ok(IpcMessage {
231                    id: msg.id,
232                    msg_type: IpcMessageType::Error,
233                    payload: serde_json::json!({
234                        "error": "Session not found"
235                    }),
236                    timestamp: chrono::Utc::now(),
237                })
238            }
239        }
240
241        IpcMessageType::GetStatus => {
242            let session_id = msg.payload["session"].as_str().unwrap_or("");
243            let session_id = crate::core::SessionId::parse_str(session_id)?;
244
245            if let Some(session) = session_manager.get_session(&session_id) {
246                let status = session.status().await;
247                Ok(IpcMessage {
248                    id: msg.id,
249                    msg_type: IpcMessageType::Response,
250                    payload: serde_json::to_value(status)?,
251                    timestamp: chrono::Utc::now(),
252                })
253            } else {
254                Ok(IpcMessage {
255                    id: msg.id,
256                    msg_type: IpcMessageType::Error,
257                    payload: serde_json::json!({
258                        "error": "Session not found"
259                    }),
260                    timestamp: chrono::Utc::now(),
261                })
262            }
263        }
264
265        IpcMessageType::ListSessions => {
266            let sessions = session_manager.list_sessions();
267            let session_ids: Vec<String> = sessions.iter().map(|id| id.to_string()).collect();
268            Ok(IpcMessage {
269                id: msg.id,
270                msg_type: IpcMessageType::Response,
271                payload: serde_json::json!({
272                    "sessions": session_ids,
273                }),
274                timestamp: chrono::Utc::now(),
275            })
276        }
277
278        IpcMessageType::DeleteSession => {
279            let session_id = msg.payload["session"].as_str().unwrap_or("");
280            let session_id = crate::core::SessionId::parse_str(session_id)?;
281
282            session_manager.remove_session(&session_id).await?;
283
284            Ok(IpcMessage {
285                id: msg.id,
286                msg_type: IpcMessageType::Response,
287                payload: serde_json::json!({
288                    "success": true,
289                }),
290                timestamp: chrono::Utc::now(),
291            })
292        }
293
294        _ => Ok(IpcMessage {
295            id: msg.id,
296            msg_type: IpcMessageType::Error,
297            payload: serde_json::json!({
298                "error": "Unsupported message type"
299            }),
300            timestamp: chrono::Utc::now(),
301        }),
302    }
303}
304
/// Client side of the native IPC channel.
///
/// Holds only the socket location; a new connection is opened for each
/// message that is sent.
pub struct IpcClient {
    /// Filesystem location of the server's Unix domain socket.
    socket_path: PathBuf,
}
310
311impl IpcClient {
312    /// Create new IPC client
313    pub fn new(socket_path: PathBuf) -> Self {
314        Self { socket_path }
315    }
316
317    /// Send message and get response
318    pub async fn send_message(&self, msg: IpcMessage) -> Result<IpcMessage> {
319        let stream = UnixStream::connect(&self.socket_path).await?;
320        let (reader, mut writer) = stream.into_split();
321        let mut reader = BufReader::new(reader);
322
323        // Send message
324        writer
325            .write_all(serde_json::to_string(&msg)?.as_bytes())
326            .await?;
327        writer.write_all(b"\n").await?;
328        writer.flush().await?;
329
330        // Read response
331        let mut line = String::new();
332        reader.read_line(&mut line).await?;
333
334        let response: IpcMessage = serde_json::from_str(&line)?;
335        Ok(response)
336    }
337
338    /// Create session
339    pub async fn create_session(&self, name: &str, enable_ai_features: bool) -> Result<()> {
340        let msg = IpcMessage {
341            id: Uuid::new_v4().to_string(),
342            msg_type: IpcMessageType::CreateSession,
343            payload: serde_json::json!({
344                "name": name,
345                "enable_ai_features": enable_ai_features,
346            }),
347            timestamp: chrono::Utc::now(),
348        };
349
350        self.send_message(msg).await?;
351        Ok(())
352    }
353
354    /// Execute command
355    pub async fn execute_command(&self, session: &str, command: &str) -> Result<()> {
356        let msg = IpcMessage {
357            id: Uuid::new_v4().to_string(),
358            msg_type: IpcMessageType::ExecuteCommand,
359            payload: serde_json::json!({
360                "session": session,
361                "command": command,
362            }),
363            timestamp: chrono::Utc::now(),
364        };
365
366        self.send_message(msg).await?;
367        Ok(())
368    }
369
370    /// Get output
371    pub async fn get_output(&self, session: &str, last_n: usize) -> Result<Vec<String>> {
372        let msg = IpcMessage {
373            id: Uuid::new_v4().to_string(),
374            msg_type: IpcMessageType::GetOutput,
375            payload: serde_json::json!({
376                "session": session,
377                "last_n": last_n,
378            }),
379            timestamp: chrono::Utc::now(),
380        };
381
382        let response = self.send_message(msg).await?;
383        let output = response.payload["output"]
384            .as_array()
385            .ok_or_else(|| anyhow::anyhow!("Invalid output format"))?
386            .iter()
387            .filter_map(|v| v.as_str())
388            .map(|s| s.to_string())
389            .collect();
390
391        Ok(output)
392    }
393}
394
/// Return the default location of the IPC socket.
///
/// The socket is named `ai-session.sock` and lives in the directory given by
/// the `XDG_RUNTIME_DIR` environment variable. If that variable is unset,
/// empty, or not an absolute path, `/tmp` is used instead, so the result never
/// depends on the current working directory.
pub fn default_socket_path() -> PathBuf {
    // `var_os` keeps directories whose names are not valid UTF-8; `env::var`
    // would reject them and silently fall back to `/tmp`.
    std::env::var_os("XDG_RUNTIME_DIR")
        .map(PathBuf::from)
        // An empty path is not absolute, so this also covers the empty case.
        .filter(|dir| dir.is_absolute())
        .unwrap_or_else(|| PathBuf::from("/tmp"))
        .join("ai-session.sock")
}