Skip to main content

cnctd_service_ssh/sessions/
mod.rs

1//! Interactive shell session management.
2//!
3//! This module provides functionality for creating and managing persistent
4//! interactive shell sessions over SSH. Sessions support:
5//!
6//! - Full TUI applications (vim, htop, tmux)
7//! - Simple REPL-style interactions
8//! - Persistence across LLM conversations via remote tmux
9//! - Screen state capture for AI agents
10
11pub mod connection;
12pub mod registry;
13pub mod terminal;
14pub mod types;
15
16use crate::operations::lookup_target;
17use crate::service_error::ServiceError;
18use connection::SshConnection;
19use once_cell::sync::Lazy;
20use registry::{ShellSession, ShellSessionRegistry};
21use std::sync::Arc;
22use tracing::info;
23use types::*;
24
25/// Global session registry for MCP server usage
26static GLOBAL_SESSION_REGISTRY: Lazy<Arc<ShellSessionRegistry>> =
27    Lazy::new(|| Arc::new(ShellSessionRegistry::new()));
28
29/// Service for managing interactive shell sessions
30pub struct ShellSessionService {
31    registry: Arc<ShellSessionRegistry>,
32}
33
34impl ShellSessionService {
35    /// Create a new shell session service with its own registry
36    pub fn new() -> Self {
37        Self {
38            registry: Arc::new(ShellSessionRegistry::new()),
39        }
40    }
41
42    /// Create a new shell session service using the global registry
43    pub fn global() -> Self {
44        Self {
45            registry: Arc::clone(&GLOBAL_SESSION_REGISTRY),
46        }
47    }
48
49    /// Create a new interactive shell session
50    pub async fn create(&self, args: ShellSessionCreateArgs) -> Result<ShellSessionCreateResult, ServiceError> {
51        create_session_impl(&self.registry, args).await
52    }
53
54    /// Write input to a session
55    pub async fn write(&self, args: ShellSessionWriteArgs) -> Result<ShellSessionWriteResult, ServiceError> {
56        write_session_impl(&self.registry, args).await
57    }
58
59    /// Read output from a session
60    pub async fn read(&self, args: ShellSessionReadArgs) -> Result<ShellSessionReadResult, ServiceError> {
61        read_session_impl(&self.registry, args).await
62    }
63
64    /// List sessions
65    pub async fn list(&self, args: ShellSessionListArgs) -> Result<ShellSessionListResult, ServiceError> {
66        list_sessions_impl(&self.registry, args).await
67    }
68
69    /// Reconnect to a disconnected session
70    pub async fn reconnect(&self, args: ShellSessionReconnectArgs) -> Result<ShellSessionReconnectResult, ServiceError> {
71        reconnect_session_impl(&self.registry, args).await
72    }
73
74    /// Resize a session's terminal
75    pub async fn resize(&self, args: ShellSessionResizeArgs) -> Result<ShellSessionResizeResult, ServiceError> {
76        resize_session_impl(&self.registry, args).await
77    }
78
79    /// Close a session
80    pub async fn close(&self, args: ShellSessionCloseArgs) -> Result<ShellSessionCloseResult, ServiceError> {
81        close_session_impl(&self.registry, args).await
82    }
83}
84
85impl Default for ShellSessionService {
86    fn default() -> Self {
87        Self::new()
88    }
89}
90
91// ============================================================================
92// Global API functions (for MCP server)
93// ============================================================================
94
95/// Create a new interactive shell session (global registry)
96pub async fn shell_session_create(args: ShellSessionCreateArgs) -> Result<ShellSessionCreateResult, ServiceError> {
97    create_session_impl(&GLOBAL_SESSION_REGISTRY, args).await
98}
99
100/// Write input to a session (global registry)
101pub async fn shell_session_write(args: ShellSessionWriteArgs) -> Result<ShellSessionWriteResult, ServiceError> {
102    write_session_impl(&GLOBAL_SESSION_REGISTRY, args).await
103}
104
105/// Read output from a session (global registry)
106pub async fn shell_session_read(args: ShellSessionReadArgs) -> Result<ShellSessionReadResult, ServiceError> {
107    read_session_impl(&GLOBAL_SESSION_REGISTRY, args).await
108}
109
110/// List sessions (global registry)
111pub async fn shell_session_list(args: ShellSessionListArgs) -> Result<ShellSessionListResult, ServiceError> {
112    list_sessions_impl(&GLOBAL_SESSION_REGISTRY, args).await
113}
114
115/// Reconnect to a session (global registry)
116pub async fn shell_session_reconnect(args: ShellSessionReconnectArgs) -> Result<ShellSessionReconnectResult, ServiceError> {
117    reconnect_session_impl(&GLOBAL_SESSION_REGISTRY, args).await
118}
119
120/// Resize a session (global registry)
121pub async fn shell_session_resize(args: ShellSessionResizeArgs) -> Result<ShellSessionResizeResult, ServiceError> {
122    resize_session_impl(&GLOBAL_SESSION_REGISTRY, args).await
123}
124
125/// Close a session (global registry)
126pub async fn shell_session_close(args: ShellSessionCloseArgs) -> Result<ShellSessionCloseResult, ServiceError> {
127    close_session_impl(&GLOBAL_SESSION_REGISTRY, args).await
128}
129
130// ============================================================================
131// Implementation functions
132// ============================================================================
133
134async fn create_session_impl(
135    registry: &ShellSessionRegistry,
136    args: ShellSessionCreateArgs,
137) -> Result<ShellSessionCreateResult, ServiceError> {
138    // Look up the target configuration
139    let target = lookup_target(&args.target_id).await?;
140
141    // Generate session ID
142    let session_id = uuid::Uuid::new_v4().to_string();
143    let tmux_session = format!("cnctd-ssh-{}", &session_id[..8]);
144
145    info!(
146        "Creating shell session {} for target {} (tmux: {})",
147        session_id, args.target_id, tmux_session
148    );
149
150    // Build the shell command that will run inside tmux
151    // We create a tmux session and then attach to it
152    let shell_cmd = args.shell.as_deref();
153
154    // Connect via SSH with PTY
155    let connection = SshConnection::connect(
156        &target.host,
157        target.port,
158        &target.user,
159        &target.key_path,
160        target.key_passphrase.as_deref(),
161        args.cols,
162        args.rows,
163        shell_cmd,
164    )
165    .await?;
166
167    // Create the session
168    let mut session = ShellSession::new(
169        session_id.clone(),
170        args.target_id.clone(),
171        args.name.clone(),
172        args.client_id.clone(),
173        tmux_session,
174        args.cols,
175        args.rows,
176        connection,
177    );
178
179    // Start output reader
180    session.start_output_reader();
181
182    // Wait briefly for initial output
183    tokio::time::sleep(std::time::Duration::from_millis(100)).await;
184
185    // Get initial screen state
186    let screen = session.screen_state().await;
187    let info = session.info().await;
188
189    // Add to registry
190    registry.add(session).await;
191
192    Ok(ShellSessionCreateResult {
193        session_id,
194        info,
195        screen,
196    })
197}
198
199async fn write_session_impl(
200    registry: &ShellSessionRegistry,
201    args: ShellSessionWriteArgs,
202) -> Result<ShellSessionWriteResult, ServiceError> {
203    let session_lock = registry
204        .get(&args.session_id)
205        .await
206        .ok_or_else(|| ServiceError::NotFound(format!("Session not found: {}", args.session_id)))?;
207
208    let session = session_lock.read().await;
209
210    // Process input - handle escape sequences
211    let mut data = process_escape_sequences(&args.input);
212    if args.newline {
213        data.push(b'\n');
214    }
215
216    let bytes_sent = session.write(&data).await?;
217
218    Ok(ShellSessionWriteResult {
219        session_id: args.session_id,
220        bytes_sent,
221    })
222}
223
224async fn read_session_impl(
225    registry: &ShellSessionRegistry,
226    args: ShellSessionReadArgs,
227) -> Result<ShellSessionReadResult, ServiceError> {
228    let session_lock = registry
229        .get(&args.session_id)
230        .await
231        .ok_or_else(|| ServiceError::NotFound(format!("Session not found: {}", args.session_id)))?;
232
233    let session = session_lock.read().await;
234
235    // Determine effective timeout (use wait_ms as the max timeout for pattern/stable waits)
236    let timeout_ms = if args.wait_ms > 0 { args.wait_ms } else { 30000 }; // Default 30s max
237
238    // Wait for pattern if requested (takes precedence)
239    let pattern_matched = if let Some(ref pattern) = args.wait_for_pattern {
240        Some(session.wait_for_pattern(pattern, timeout_ms).await)
241    } else {
242        None
243    };
244
245    // Wait for stable output if requested (and pattern wasn't requested or already matched)
246    let stabilized = if let Some(stable_ms) = args.wait_for_stable_ms {
247        // Only wait for stable if we're not waiting for pattern, or pattern was found
248        if args.wait_for_pattern.is_none() || pattern_matched == Some(true) {
249            Some(session.wait_for_stable(stable_ms, timeout_ms).await)
250        } else {
251            Some(false)
252        }
253    } else {
254        None
255    };
256
257    // If no special waits, use the basic wait_ms/min_bytes
258    if args.wait_for_pattern.is_none() && args.wait_for_stable_ms.is_none() && args.wait_ms > 0 {
259        session.wait_for_output(args.wait_ms, args.min_bytes).await;
260    }
261
262    let state = *session.state.read().await;
263
264    // Get output based on format
265    let (raw, screen, buffer_size, truncated) = match args.format {
266        OutputFormat::Raw => {
267            let (text, remaining, truncated) = session.read(args.consume).await;
268            (Some(text), None, remaining, truncated)
269        }
270        OutputFormat::Stripped => {
271            let (text, remaining, truncated) = session.read(args.consume).await;
272            let stripped = strip_ansi_codes(&text);
273            (Some(stripped), None, remaining, truncated)
274        }
275        OutputFormat::Screen => {
276            let screen = session.screen_state().await;
277            // Don't consume buffer when only getting screen
278            let (_, remaining, truncated) = session.read(false).await;
279            (None, Some(screen), remaining, truncated)
280        }
281        OutputFormat::Both => {
282            let (text, remaining, truncated) = session.read(args.consume).await;
283            let screen = session.screen_state().await;
284            (Some(text), Some(screen), remaining, truncated)
285        }
286    };
287
288    Ok(ShellSessionReadResult {
289        session_id: args.session_id,
290        raw,
291        screen,
292        buffer_size,
293        truncated,
294        state,
295        pattern_matched,
296        stabilized,
297    })
298}
299
300async fn list_sessions_impl(
301    registry: &ShellSessionRegistry,
302    args: ShellSessionListArgs,
303) -> Result<ShellSessionListResult, ServiceError> {
304    let sessions = registry
305        .list(
306            args.target_id.as_deref(),
307            args.client_id.as_deref(),
308            args.include_disconnected,
309        )
310        .await;
311
312    Ok(ShellSessionListResult { sessions })
313}
314
315async fn reconnect_session_impl(
316    _registry: &ShellSessionRegistry,
317    _args: ShellSessionReconnectArgs,
318) -> Result<ShellSessionReconnectResult, ServiceError> {
319    // TODO: Implement reconnection
320    // This would involve:
321    // 1. Finding the session in the registry
322    // 2. Re-establishing the SSH connection
323    // 3. Attaching to the existing tmux session
324    // 4. Restarting the output reader
325
326    Err(ServiceError::Internal(
327        "Reconnection not yet implemented - coming soon".to_string(),
328    ))
329}
330
331async fn resize_session_impl(
332    registry: &ShellSessionRegistry,
333    args: ShellSessionResizeArgs,
334) -> Result<ShellSessionResizeResult, ServiceError> {
335    let session_lock = registry
336        .get(&args.session_id)
337        .await
338        .ok_or_else(|| ServiceError::NotFound(format!("Session not found: {}", args.session_id)))?;
339
340    let mut session = session_lock.write().await;
341    session.resize(args.cols, args.rows).await?;
342
343    Ok(ShellSessionResizeResult {
344        session_id: args.session_id,
345        size: (args.cols, args.rows),
346    })
347}
348
349async fn close_session_impl(
350    registry: &ShellSessionRegistry,
351    args: ShellSessionCloseArgs,
352) -> Result<ShellSessionCloseResult, ServiceError> {
353    let session_lock = registry
354        .remove(&args.session_id)
355        .await
356        .ok_or_else(|| ServiceError::NotFound(format!("Session not found: {}", args.session_id)))?;
357
358    // Take ownership of the session
359    let session = match Arc::try_unwrap(session_lock) {
360        Ok(rwlock) => rwlock.into_inner(),
361        Err(_) => {
362            return Err(ServiceError::Internal(
363                "Session is still in use".to_string(),
364            ))
365        }
366    };
367
368    let closed = session.close(args.force).await?;
369
370    Ok(ShellSessionCloseResult {
371        session_id: args.session_id,
372        closed,
373    })
374}
375
/// Decode the escape sequences agents use to describe terminal input into raw bytes.
///
/// Supported escapes:
/// - `\xNN`: one byte from one or two ASCII hex digits (e.g. `\x03` = Ctrl+C, `\x1b` = ESC).
///   Two characters are consumed after `\x` when available. If they are not all hex digits
///   (including a sign, as in `\x+5`), `\x` and those characters are emitted literally.
/// - `\n`, `\r`, `\t`, `\0`: newline, carriage return, tab, NUL.
/// - `\\`: a single backslash.
///
/// A backslash before any other character, or at the end of input, is kept as a literal
/// backslash. Every other character is emitted as its UTF-8 encoding.
fn process_escape_sequences(input: &str) -> Vec<u8> {
    let mut result = Vec::with_capacity(input.len());
    let mut chars = input.chars().peekable();

    while let Some(c) = chars.next() {
        if c != '\\' {
            let mut buf = [0u8; 4];
            result.extend_from_slice(c.encode_utf8(&mut buf).as_bytes());
            continue;
        }

        match chars.peek() {
            Some('x') => {
                chars.next(); // consume 'x'
                let hex: String = chars.by_ref().take(2).collect();
                // `from_str_radix` tolerates a leading `+`, so check the digits ourselves.
                let all_hex = !hex.is_empty() && hex.chars().all(|h| h.is_ascii_hexdigit());
                match u8::from_str_radix(&hex, 16) {
                    Ok(byte) if all_hex => result.push(byte),
                    _ => {
                        // Invalid hex: emit what was consumed as-is.
                        result.extend_from_slice(b"\\x");
                        result.extend_from_slice(hex.as_bytes());
                    }
                }
            }
            Some('n') => {
                chars.next();
                result.push(b'\n');
            }
            Some('r') => {
                chars.next();
                result.push(b'\r');
            }
            Some('t') => {
                chars.next();
                result.push(b'\t');
            }
            Some('\\') => {
                chars.next();
                result.push(b'\\');
            }
            Some('0') => {
                chars.next();
                result.push(0);
            }
            // Unknown escape or trailing backslash: keep the backslash. The following
            // character (if any) is emitted normally on the next iteration.
            _ => result.push(b'\\'),
        }
    }

    result
}
429
/// Strip ANSI/ECMA-48 escape sequences and stray control characters from terminal output.
///
/// Removed:
/// - CSI sequences: `ESC [`, then parameter/intermediate bytes, ending at a final byte in
///   `0x40..=0x7E`. Covers colors, cursor movement, and also non-letter finals such as
///   `ESC [ 2 @`.
/// - Control strings ended by BEL or ST (`ESC \`): OSC (`ESC ]`), DCS (`ESC P`), SOS (`ESC X`),
///   PM (`ESC ^`) and APC (`ESC _`).
/// - Escape sequences with intermediate bytes (`0x20..=0x2F`) plus one final byte, e.g. the
///   charset selection `ESC ( B`.
/// - Any other `ESC` plus the single following character (e.g. `ESC =`, `ESC 7`).
/// - Control characters other than `\n`, `\r` and `\t`.
///
/// Everything else is kept unchanged.
fn strip_ansi_codes(input: &str) -> String {
    let mut result = String::with_capacity(input.len());
    let mut chars = input.chars().peekable();

    while let Some(c) = chars.next() {
        if c == '\x1b' {
            match chars.next() {
                Some('[') => {
                    // CSI: skip parameter/intermediate bytes through the final byte.
                    for ch in chars.by_ref() {
                        if ('\x40'..='\x7e').contains(&ch) {
                            break;
                        }
                    }
                }
                Some(']' | 'P' | 'X' | '^' | '_') => {
                    // Control string: skip until BEL or ST (ESC \).
                    while let Some(ch) = chars.next() {
                        if ch == '\x07' {
                            break;
                        }
                        if ch == '\x1b' && chars.peek() == Some(&'\\') {
                            chars.next(); // consume '\' of ST
                            break;
                        }
                    }
                }
                Some(' '..='/') => {
                    // Intermediate bytes: skip any further ones, then the final byte.
                    for ch in chars.by_ref() {
                        if !(' '..='/').contains(&ch) {
                            break;
                        }
                    }
                }
                // Two-character escape (its second character is already consumed), or ESC at
                // the end of input.
                _ => {}
            }
        } else if !c.is_control() || matches!(c, '\n' | '\r' | '\t') {
            result.push(c);
        }
    }

    result
}
478
479/// Get tool definitions for shell session tools
480pub fn get_shell_session_tool_definitions() -> Vec<crate::operations::ToolDefinition> {
481    use crate::operations::ToolDefinition;
482    use schemars::schema_for;
483
484    vec![
485        ToolDefinition {
486            name: "shell_session_create".to_string(),
487            description: "Create a new interactive shell session on a registered SSH target. Sessions are persistent and survive disconnections.".to_string(),
488            input_schema: serde_json::to_value(schema_for!(ShellSessionCreateArgs)).unwrap_or_default(),
489        },
490        ToolDefinition {
491            name: "shell_session_write".to_string(),
492            description: "Send input to an interactive shell session. Use for commands, keystrokes (\\x03 for Ctrl+C), or any terminal input.".to_string(),
493            input_schema: serde_json::to_value(schema_for!(ShellSessionWriteArgs)).unwrap_or_default(),
494        },
495        ToolDefinition {
496            name: "shell_session_read".to_string(),
497            description: "Read output from an interactive shell session. Supports raw output, screen state (for TUI apps), or both.".to_string(),
498            input_schema: serde_json::to_value(schema_for!(ShellSessionReadArgs)).unwrap_or_default(),
499        },
500        ToolDefinition {
501            name: "shell_session_list".to_string(),
502            description: "List interactive shell sessions. Can filter by target or client ID.".to_string(),
503            input_schema: serde_json::to_value(schema_for!(ShellSessionListArgs)).unwrap_or_default(),
504        },
505        ToolDefinition {
506            name: "shell_session_reconnect".to_string(),
507            description: "Reconnect to a disconnected shell session. The session must still exist on the remote server.".to_string(),
508            input_schema: serde_json::to_value(schema_for!(ShellSessionReconnectArgs)).unwrap_or_default(),
509        },
510        ToolDefinition {
511            name: "shell_session_resize".to_string(),
512            description: "Resize a shell session's terminal dimensions. Important for TUI applications.".to_string(),
513            input_schema: serde_json::to_value(schema_for!(ShellSessionResizeArgs)).unwrap_or_default(),
514        },
515        ToolDefinition {
516            name: "shell_session_close".to_string(),
517            description: "Close an interactive shell session. Terminates the remote shell.".to_string(),
518            input_schema: serde_json::to_value(schema_for!(ShellSessionCloseArgs)).unwrap_or_default(),
519        },
520    ]
521}