Skip to main content

aft/commands/
bash_kill.rs

1use crate::context::AppContext;
2use crate::protocol::{RawRequest, Response};
3use serde::Deserialize;
4use serde_json::json;
5
6#[derive(Debug, Deserialize)]
7struct BashKillParams {
8    #[serde(default)]
9    task_id: Option<String>,
10}
11
12pub fn handle(req: &RawRequest, ctx: &AppContext) -> Response {
13    let raw_params = req
14        .params
15        .get("params")
16        .cloned()
17        .unwrap_or_else(|| req.params.clone());
18    let params = match serde_json::from_value::<BashKillParams>(raw_params) {
19        Ok(params) => params,
20        Err(e) => {
21            return Response::error(
22                &req.id,
23                "invalid_request",
24                format!("bash_kill: invalid params: {e}"),
25            );
26        }
27    };
28
29    let Some(task_id) = params.task_id else {
30        return Response::error(&req.id, "invalid_request", "bash_kill: missing task_id");
31    };
32
33    let storage_dir = crate::bash_background::storage_dir(ctx.config().storage_dir.as_deref());
34    let result = ctx
35        .bash_background()
36        .kill(&task_id, req.session())
37        .or_else(|message| {
38            if !message.contains("not found") {
39                return Err(message);
40            }
41            {
42                let config = ctx.config();
43                let _ = if let Some(project_root) = config.project_root.as_deref() {
44                    ctx.bash_background().replay_session_for_project(
45                        &storage_dir,
46                        req.session(),
47                        project_root,
48                    )
49                } else {
50                    ctx.bash_background()
51                        .replay_session(&storage_dir, req.session())
52                };
53            }
54            ctx.bash_background().kill(&task_id, req.session())
55        })
56        .or_else(|message| {
57            if !message.contains("not found") {
58                return Err(message);
59            }
60            let config = ctx.config();
61            let Some(project_root) = config.project_root.as_deref() else {
62                return Err(message);
63            };
64            ctx.bash_background()
65                .kill_relaxed(&task_id, project_root, &storage_dir)
66        });
67
68    match result {
69        Ok(snapshot) => Response::success(&req.id, json!(snapshot)),
70        Err(message) if message.contains("not found") => {
71            Response::error(&req.id, "task_not_found", message)
72        }
73        Err(message) => Response::error(&req.id, "kill_failed", message),
74    }
75}
76
#[cfg(test)]
mod tests {
    use std::fs;
    use std::path::Path;
    use std::sync::Arc;

    use serde_json::json;

    use super::*;
    use crate::bash_background::persistence::{task_paths, write_task, PersistedTask};
    use crate::bash_background::BgTaskStatus;
    use crate::config::Config;
    use crate::context::{App, AppContext};

    /// Builds an `AppContext` over the shared `app` whose config points at `project` as the
    /// project root and `storage` as the storage directory.
    fn actor(app: &Arc<App>, project: &Path, storage: &Path) -> AppContext {
        let config = Config {
            project_root: Some(project.to_path_buf()),
            storage_dir: Some(storage.to_path_buf()),
            ..Config::default()
        };
        AppContext::from_app(Arc::clone(app), config)
    }

    /// Persists a task record for `session`/`task_id` owned by `project`, marked `Running`,
    /// together with its stdout/stderr files. No process is actually started.
    fn write_running_project_task(storage: &Path, project: &Path, session: &str, task_id: &str) {
        let paths = task_paths(storage, session, task_id);
        let mut metadata = PersistedTask::starting(
            task_id.to_string(),
            session.to_string(),
            "sleep 60".to_string(),
            project.to_path_buf(),
            Some(project.to_path_buf()),
            Some(30_000),
            true,
            true,
        );
        metadata.status = BgTaskStatus::Running;
        write_task(&paths.json, &metadata).unwrap();
        fs::write(&paths.stdout, "still running\n").unwrap();
        fs::write(&paths.stderr, "").unwrap();
    }

    /// Builds a `bash_kill` request with request id `id`, session `session`, and `params`
    /// used verbatim as the request payload.
    fn request_with_params(id: &str, session: &str, params: serde_json::Value) -> RawRequest {
        RawRequest {
            id: id.to_string(),
            command: "bash_kill".to_string(),
            lsp_hints: None,
            session_id: Some(session.to_string()),
            params,
        }
    }

    /// Builds a `bash_kill` request for `task_id` using the wrapped `{"params": {...}}` form.
    fn kill_request(task_id: &str, session: &str) -> RawRequest {
        request_with_params(
            "kill-project-filter",
            session,
            json!({ "params": { "task_id": task_id } }),
        )
    }

    #[test]
    fn bash_kill_replay_filters_same_session_by_project_root() {
        let project_a = tempfile::tempdir().unwrap();
        let project_b = tempfile::tempdir().unwrap();
        let storage = tempfile::tempdir().unwrap();
        let app = App::default_shared();
        let ctx_a = actor(&app, project_a.path(), storage.path());
        let ctx_b = actor(&app, project_b.path(), storage.path());
        let session = "shared-session";
        let task_id = "bash-project-a";
        write_running_project_task(storage.path(), project_a.path(), session, task_id);

        let miss = serde_json::to_value(handle(&kill_request(task_id, session), &ctx_b)).unwrap();
        assert_eq!(
            miss["success"], false,
            "wrong project killed task: {miss:?}"
        );
        assert_eq!(miss["code"], "task_not_found");

        let killed = serde_json::to_value(handle(&kill_request(task_id, session), &ctx_a)).unwrap();
        assert_eq!(
            killed["success"], true,
            "owning project kill failed: {killed:?}"
        );
        assert_eq!(killed["status"], "killed");
    }

    /// A wrapped payload without `task_id` is rejected before any kill is attempted.
    #[test]
    fn bash_kill_missing_task_id_is_invalid_request() {
        let project = tempfile::tempdir().unwrap();
        let storage = tempfile::tempdir().unwrap();
        let app = App::default_shared();
        let ctx = actor(&app, project.path(), storage.path());

        let request = request_with_params("kill-missing-id", "session", json!({ "params": {} }));
        let response = serde_json::to_value(handle(&request, &ctx)).unwrap();
        assert_eq!(
            response["success"], false,
            "missing task_id accepted: {response:?}"
        );
        assert_eq!(response["code"], "invalid_request");
    }

    /// Without a `params` wrapper the top-level payload is parsed directly; a non-string
    /// `task_id` there fails deserialization.
    #[test]
    fn bash_kill_flat_params_with_wrong_type_is_invalid_request() {
        let project = tempfile::tempdir().unwrap();
        let storage = tempfile::tempdir().unwrap();
        let app = App::default_shared();
        let ctx = actor(&app, project.path(), storage.path());

        let request = request_with_params("kill-bad-type", "session", json!({ "task_id": 7 }));
        let response = serde_json::to_value(handle(&request, &ctx)).unwrap();
        assert_eq!(
            response["success"], false,
            "mistyped task_id accepted: {response:?}"
        );
        assert_eq!(response["code"], "invalid_request");
    }
}