Skip to main content

pty_mcp/app/
local_sessions.rs

1use anyhow::Result;
2use chrono::Utc;
3
4use crate::{
5    buffer::{BufferReadPage, BufferReadRequest},
6    permission::SpawnValidationInput,
7    pty::PtySpawnRequest,
8    session::{
9        SessionId, SessionKillResult, SessionStatus, SessionSummary, SessionTransport,
10        SessionWaitResult, SessionWriteResult, SignalKind,
11    },
12};
13
14use super::{LocalSessionService, SshService, types::SpawnSessionRequest};
15
/// Picks the human-readable description stored on a session.
///
/// The caller-supplied `description` is used after trimming surrounding
/// whitespace. When it is absent, or blank once trimmed, the result falls
/// back to `"PTY session: <command>"`.
fn session_description(description: Option<String>, command: &str) -> String {
    match description.as_deref().map(str::trim) {
        Some(trimmed) if !trimmed.is_empty() => trimmed.to_owned(),
        _ => format!("PTY session: {command}"),
    }
}
22
23impl LocalSessionService {
24    pub fn list_sessions(&self) -> Vec<SessionSummary> {
25        self.context.registry.list()
26    }
27
28    pub fn get_session(&self, session_id: &SessionId) -> Option<SessionSummary> {
29        self.context.registry.get(session_id)
30    }
31
32    pub fn seed_session(&self, session: SessionSummary) {
33        self.context.registry.insert(session);
34    }
35
36    pub async fn spawn_session(&self, request: SpawnSessionRequest) -> Result<SessionSummary> {
37        let validated = self.context.guard.validate_spawn(SpawnValidationInput {
38            command: &request.command,
39            args: &request.args,
40            cwd: request.cwd.as_deref(),
41            env: request.env.as_ref(),
42        })?;
43
44        let session = SessionSummary {
45            session_id: SessionId::new(),
46            title: request.title,
47            description: session_description(request.description, &validated.command),
48            transport: SessionTransport::Local,
49            command: validated.command.clone(),
50            args: validated.args.clone(),
51            cwd: validated.cwd.as_ref().map(|cwd| cwd.display().to_string()),
52            connection_id: None,
53            target_summary: None,
54            remote_cwd: None,
55            remote_command: None,
56            remote_env_preview: Default::default(),
57            status: SessionStatus::Starting,
58            pid: None,
59            started_at: Utc::now(),
60            buffer_stats: Default::default(),
61            exit_info: None,
62        };
63        let session_id = self.context.registry.create_starting(session)?;
64
65        let mut runtime_request = PtySpawnRequest::new(validated.command).args(validated.args);
66        if let Some(cwd) = validated.cwd {
67            runtime_request = runtime_request.cwd(cwd);
68        }
69        for (key, value) in validated.env {
70            runtime_request = runtime_request.env(key, value);
71        }
72
73        match self.context.runtime.spawn(runtime_request).await {
74            Ok(spawned) => {
75                self.context.registry.attach_runtime(
76                    &session_id,
77                    spawned.pid,
78                    spawned.handle,
79                    spawned.output,
80                )?;
81            }
82            Err(error) => {
83                let _ = self.context.registry.mark_failed_to_spawn(&session_id);
84                return Err(error);
85            }
86        }
87
88        Ok(self
89            .context
90            .registry
91            .get(&session_id)
92            .expect("session disappeared after spawn"))
93    }
94
95    pub async fn write_session(
96        &self,
97        session_id: &SessionId,
98        data: &str,
99        escaped: bool,
100    ) -> Result<SessionWriteResult> {
101        if escaped {
102            self.context.registry.write_escaped(session_id, data).await
103        } else {
104            self.context.registry.write_plain(session_id, data).await
105        }
106    }
107
108    pub fn read_session(
109        &self,
110        session_id: &SessionId,
111        request: &BufferReadRequest,
112    ) -> Result<BufferReadPage> {
113        self.context.registry.read_output(session_id, request)
114    }
115
116    pub async fn kill_session(
117        &self,
118        session_id: &SessionId,
119        signal: SignalKind,
120        cleanup: bool,
121    ) -> Result<SessionKillResult> {
122        let outcome = self
123            .context
124            .registry
125            .kill(session_id, signal, cleanup)
126            .await?;
127        SshService::refresh_session_tracking_with_context(&self.context, session_id);
128        Ok(outcome)
129    }
130
131    pub async fn wait_session(
132        &self,
133        session_id: &SessionId,
134        timeout: Option<std::time::Duration>,
135    ) -> Result<SessionWaitResult> {
136        let outcome = self.context.registry.wait(session_id, timeout).await?;
137        SshService::refresh_session_tracking_with_context(&self.context, session_id);
138        Ok(outcome)
139    }
140
141    pub async fn shutdown(&self) -> Result<()> {
142        self.context.registry.shutdown().await
143    }
144}
145
#[cfg(test)]
mod tests {
    use chrono::Utc;

    use crate::{
        Config,
        session::{
            ExitInfo, SessionId, SessionStatus, SessionSummary, SessionTransport, SignalKind,
        },
        ssh::SshConnectionStatus,
    };

    use super::*;

    /// A command that cannot be spawned must still leave exactly one new
    /// registry entry behind, flagged as `FailedToSpawn`.
    #[tokio::test]
    async fn spawn_failure_marks_session_failed_to_spawn() {
        let service = super::super::AppState::new(Config::default())
            .local()
            .clone();
        let initial_count = service.list_sessions().len();

        let request = SpawnSessionRequest {
            command: "/definitely/not/a/real/command".into(),
            args: Vec::new(),
            cwd: None,
            env: None,
            title: None,
            description: Some("spawn failure".into()),
        };
        let spawn_error = service.spawn_session(request).await.unwrap_err();

        assert!(!spawn_error.to_string().is_empty());

        let after = service.list_sessions();
        assert_eq!(after.len(), initial_count + 1);

        let has_failed_entry = after
            .iter()
            .any(|entry| entry.status == SessionStatus::FailedToSpawn);
        assert!(has_failed_entry);
    }

    /// Killing an already-exited SSH session through the local service must
    /// drop it from the connection's tracked sessions while the connection
    /// itself stays `Ready`.
    #[tokio::test]
    async fn kill_session_refreshes_ssh_tracking() {
        let app = super::super::AppState::new(Config::default());

        // Register a ready SSH connection to hang the session off.
        let target = crate::ssh::SshTarget {
            host_alias: None,
            host: "example.com".into(),
            user: None,
            port: None,
        };
        let mut connection = app.ssh().create_placeholder_connection(target);
        connection.status = SshConnectionStatus::Ready;
        app.ssh().upsert_connection(connection.clone());
        let connection_id = connection.connection_id.clone();

        // Seed a session that has already exited and is tracked by that connection.
        let exited = SessionSummary {
            session_id: SessionId::new(),
            title: None,
            description: "exited".into(),
            command: "ssh".into(),
            args: Vec::new(),
            cwd: None,
            transport: SessionTransport::Ssh,
            connection_id: Some(connection_id.clone()),
            target_summary: Some(connection.target_summary.clone()),
            remote_cwd: None,
            remote_command: None,
            remote_env_preview: Default::default(),
            status: SessionStatus::Exited,
            pid: None,
            started_at: Utc::now(),
            buffer_stats: Default::default(),
            exit_info: Some(ExitInfo::default()),
        };
        let session_id = exited.session_id.clone();
        app.local().seed_session(exited);
        app.ssh()
            .track_session(&connection_id, session_id.clone())
            .unwrap();

        let kill_result = app
            .local()
            .kill_session(&session_id, SignalKind::Sigterm, false)
            .await
            .unwrap();
        assert_eq!(kill_result.current_status, SessionStatus::Exited);

        let relations = app.ssh().connection_relations(&connection_id).unwrap();
        assert!(relations.session_ids.is_empty());

        let refreshed = app.ssh().get_connection(&connection_id).unwrap();
        assert_eq!(refreshed.status, SshConnectionStatus::Ready);
    }
}