Skip to main content

embacle/
goose_cli.rs

1// ABOUTME: Goose CLI runner implementing the `LlmProvider` trait
2// ABOUTME: Wraps the `goose` CLI with JSON/stream-JSON output parsing and session resume
3//
4// SPDX-License-Identifier: Apache-2.0
5// Copyright (c) 2026 dravr.ai
6
7use std::io;
8use std::process::Stdio;
9use std::str;
10
11use crate::cli_common::{CliRunnerBase, MAX_OUTPUT_BYTES};
12use crate::types::{
13    ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider, RunnerError, StreamChunk,
14};
15use async_trait::async_trait;
16use tempfile::NamedTempFile;
17use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
18use tokio::process::Command;
19use tokio_stream::wrappers::LinesStream;
20use tokio_stream::StreamExt;
21use tracing::instrument;
22
23use crate::config::RunnerConfig;
24use crate::process::{read_stderr_capped, run_cli_command};
25use crate::prompt::build_user_prompt;
26use crate::sandbox::{apply_sandbox, build_policy};
27use crate::stream::{GuardedStream, MAX_STREAMING_STDERR_BYTES};
28
/// Default model name handed to `CliRunnerBase` for the Goose runner
///
/// Goose is provider-agnostic: `auto` leaves the actual model choice to the
/// backend the user configured in Goose itself.
const DEFAULT_MODEL: &str = "auto";

/// Fallback model list handed to `CliRunnerBase`
///
/// Goose delegates to whatever backend the user configured, so the list holds
/// only the default model.
const FALLBACK_MODELS: &[&str] = &[DEFAULT_MODEL];
34
/// Goose CLI runner
///
/// Implements `LlmProvider` by delegating to the `goose` binary via `goose run`,
/// using `--output-format json` for complete responses and
/// `--output-format stream-json` for streaming. Every invocation passes
/// `--quiet` to suppress progress output and `--no-session`. When a session ID
/// has been registered with [`GooseCliRunner::set_session`] under the request's
/// model name, `--session-id <id> --resume` is appended as well.
pub struct GooseCliRunner {
    /// Shared CLI runner state: built from the `RunnerConfig` plus this module's
    /// default/fallback models, and holding the session store used by
    /// `set_session` / `get_session`
    base: CliRunnerBase,
}
44
45impl GooseCliRunner {
46    /// Create a new Goose CLI runner with the given configuration
47    #[must_use]
48    pub fn new(config: RunnerConfig) -> Self {
49        Self {
50            base: CliRunnerBase::new(config, DEFAULT_MODEL, FALLBACK_MODELS),
51        }
52    }
53
54    /// Store a session ID for later resumption
55    pub async fn set_session(&self, key: &str, session_id: &str) {
56        self.base.set_session(key, session_id).await;
57    }
58
59    /// Build the base command with common arguments (without prompt delivery)
60    fn build_command_base(&self, output_format: &str) -> Command {
61        let mut cmd = Command::new(&self.base.config.binary_path);
62        cmd.args([
63            "run",
64            "--quiet",
65            "--no-session",
66            "--output-format",
67            output_format,
68        ]);
69
70        for arg in &self.base.config.extra_args {
71            cmd.arg(arg);
72        }
73
74        if let Ok(policy) = build_policy(
75            self.base.config.working_directory.as_deref(),
76            &self.base.config.allowed_env_keys,
77        ) {
78            apply_sandbox(&mut cmd, &policy);
79        }
80
81        cmd
82    }
83
84    /// Parse a JSON response from `goose run --output-format json`
85    ///
86    /// The response contains a `messages` array; we extract the last assistant
87    /// message and join its `content` text parts.
88    fn parse_json_response(raw: &[u8]) -> Result<ChatResponse, RunnerError> {
89        let text = str::from_utf8(raw).map_err(|e| {
90            RunnerError::internal(format!("Goose CLI output is not valid UTF-8: {e}"))
91        })?;
92
93        let value: serde_json::Value = serde_json::from_str(text).map_err(|e| {
94            RunnerError::internal(format!("Failed to parse Goose JSON response: {e}"))
95        })?;
96
97        let messages = value
98            .get("messages")
99            .and_then(|v| v.as_array())
100            .ok_or_else(|| RunnerError::internal("Goose response missing 'messages' array"))?;
101
102        // Find the last assistant message and join its content text parts
103        let mut content = String::new();
104        for msg in messages.iter().rev() {
105            let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("");
106            if role == "assistant" {
107                if let Some(parts) = msg.get("content").and_then(|v| v.as_array()) {
108                    for part in parts {
109                        let part_type = part.get("type").and_then(|v| v.as_str()).unwrap_or("");
110                        if part_type == "text" {
111                            if let Some(t) = part.get("text").and_then(|v| v.as_str()) {
112                                content.push_str(t);
113                            }
114                        }
115                    }
116                }
117                break;
118            }
119        }
120
121        Ok(ChatResponse {
122            content,
123            model: "goose".to_owned(),
124            usage: None,
125            finish_reason: Some("stop".to_owned()),
126            warnings: None,
127        })
128    }
129}
130
131#[async_trait]
132impl LlmProvider for GooseCliRunner {
133    crate::delegate_provider_base!("goose", "Goose CLI", LlmCapabilities::STREAMING);
134
135    #[instrument(skip_all, fields(runner = "goose"))]
136    async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
137        let prompt = build_user_prompt(&request.messages);
138
139        // Write prompt to a temp file since Goose reads from `-i <path>`
140        let mut prompt_file = NamedTempFile::new().map_err(|e| {
141            RunnerError::internal(format!("Failed to create temp file for Goose prompt: {e}"))
142        })?;
143        std::io::Write::write_all(&mut prompt_file, prompt.as_bytes()).map_err(|e| {
144            RunnerError::internal(format!("Failed to write Goose prompt to temp file: {e}"))
145        })?;
146
147        let mut cmd = self.build_command_base("json");
148        cmd.args(["-i", &prompt_file.path().display().to_string()]);
149
150        if let Some(model) = &request.model {
151            if let Some(sid) = self.base.get_session(model).await {
152                cmd.args(["--session-id", &sid, "--resume"]);
153            }
154        }
155
156        let output = run_cli_command(&mut cmd, self.base.config.timeout, MAX_OUTPUT_BYTES).await?;
157        self.base.check_exit_code(&output, "goose")?;
158
159        Self::parse_json_response(&output.stdout)
160    }
161
162    #[instrument(skip_all, fields(runner = "goose"))]
163    async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream, RunnerError> {
164        let prompt = build_user_prompt(&request.messages);
165
166        let mut cmd = self.build_command_base("stream-json");
167        cmd.args(["-i", "-"]);
168
169        if let Some(model) = &request.model {
170            if let Some(sid) = self.base.get_session(model).await {
171                cmd.args(["--session-id", &sid, "--resume"]);
172            }
173        }
174
175        cmd.stdin(Stdio::piped());
176        cmd.stdout(Stdio::piped());
177        cmd.stderr(Stdio::piped());
178
179        let mut child = cmd.spawn().map_err(|e| {
180            RunnerError::internal(format!("Failed to spawn goose for streaming: {e}"))
181        })?;
182
183        // Write prompt to stdin then close it
184        let mut stdin = child
185            .stdin
186            .take()
187            .ok_or_else(|| RunnerError::internal("Failed to capture goose stdin for streaming"))?;
188        tokio::spawn(async move {
189            let _ = stdin.write_all(prompt.as_bytes()).await;
190            let _ = stdin.shutdown().await;
191        });
192
193        let stdout = child
194            .stdout
195            .take()
196            .ok_or_else(|| RunnerError::internal("Failed to capture goose stdout for streaming"))?;
197
198        let stderr_task = tokio::spawn(read_stderr_capped(
199            child.stderr.take(),
200            MAX_STREAMING_STDERR_BYTES,
201        ));
202
203        let reader = BufReader::new(stdout);
204        let lines = LinesStream::new(reader.lines());
205
206        let stream = lines.map(move |line_result: Result<String, io::Error>| {
207            let line = line_result
208                .map_err(|e| RunnerError::internal(format!("Error reading goose stream: {e}")))?;
209
210            if line.trim().is_empty() {
211                return Ok(StreamChunk {
212                    delta: String::new(),
213                    is_final: false,
214                    finish_reason: None,
215                });
216            }
217
218            let value: serde_json::Value = serde_json::from_str(&line)
219                .map_err(|e| RunnerError::internal(format!("Invalid JSON in goose stream: {e}")))?;
220
221            let chunk_type = value.get("type").and_then(|v| v.as_str()).unwrap_or("");
222            match chunk_type {
223                "message" => {
224                    // Extract text from message.content[] text parts
225                    let mut delta = String::new();
226                    if let Some(msg) = value.get("message") {
227                        if let Some(parts) = msg.get("content").and_then(|v| v.as_array()) {
228                            for part in parts {
229                                let pt = part.get("type").and_then(|v| v.as_str()).unwrap_or("");
230                                if pt == "text" {
231                                    if let Some(t) = part.get("text").and_then(|v| v.as_str()) {
232                                        delta.push_str(t);
233                                    }
234                                }
235                            }
236                        }
237                    }
238                    Ok(StreamChunk {
239                        delta,
240                        is_final: false,
241                        finish_reason: None,
242                    })
243                }
244                "complete" => Ok(StreamChunk {
245                    delta: String::new(),
246                    is_final: true,
247                    finish_reason: Some("stop".to_owned()),
248                }),
249                _ => Ok(StreamChunk {
250                    delta: String::new(),
251                    is_final: false,
252                    finish_reason: None,
253                }),
254            }
255        });
256
257        Ok(Box::pin(GuardedStream::new(stream, child, stderr_task)))
258    }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_parse_json_response_basic() {
        let json = br#"{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]},{"role":"assistant","content":[{"type":"text","text":"hello"}]}],"metadata":{"total_tokens":null,"status":"completed"}}"#;
        let resp = GooseCliRunner::parse_json_response(json).unwrap();
        assert_eq!(resp.content, "hello");
        assert_eq!(resp.model, "goose");
        assert_eq!(resp.finish_reason.as_deref(), Some("stop"));
        assert!(resp.usage.is_none());
    }

    #[test]
    fn test_parse_json_response_multi_content() {
        let json = br#"{"messages":[{"role":"assistant","content":[{"type":"text","text":"part1"},{"type":"text","text":"part2"}]}],"metadata":{"status":"completed"}}"#;
        let resp = GooseCliRunner::parse_json_response(json).unwrap();
        assert_eq!(resp.content, "part1part2");
    }

    #[test]
    fn test_parse_json_response_skips_user_messages() {
        let json = br#"{"messages":[{"role":"user","content":[{"type":"text","text":"ignored"}]},{"role":"assistant","content":[{"type":"text","text":"kept"}]}],"metadata":{}}"#;
        let resp = GooseCliRunner::parse_json_response(json).unwrap();
        assert_eq!(resp.content, "kept");
    }

    #[test]
    fn test_parse_json_response_uses_last_assistant_message() {
        let json = br#"{"messages":[{"role":"assistant","content":[{"type":"text","text":"first"}]},{"role":"user","content":[{"type":"text","text":"again"}]},{"role":"assistant","content":[{"type":"text","text":"second"}]}]}"#;
        let resp = GooseCliRunner::parse_json_response(json).unwrap();
        assert_eq!(resp.content, "second");
    }

    #[test]
    fn test_parse_json_response_ignores_non_text_parts() {
        let json = br#"{"messages":[{"role":"assistant","content":[{"type":"toolRequest","id":"1"},{"type":"text","text":"done"}]}]}"#;
        let resp = GooseCliRunner::parse_json_response(json).unwrap();
        assert_eq!(resp.content, "done");
    }

    #[test]
    fn test_parse_json_response_without_assistant_is_empty() {
        let json = br#"{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}"#;
        let resp = GooseCliRunner::parse_json_response(json).unwrap();
        assert_eq!(resp.content, "");
    }

    #[test]
    fn test_parse_json_response_missing_messages_is_error() {
        let json = br#"{"metadata":{"status":"completed"}}"#;
        assert!(GooseCliRunner::parse_json_response(json).is_err());
    }

    #[test]
    fn test_parse_json_response_invalid_json_is_error() {
        assert!(GooseCliRunner::parse_json_response(b"not json").is_err());
    }

    #[test]
    fn test_parse_json_response_invalid_utf8_is_error() {
        assert!(GooseCliRunner::parse_json_response(&[0xff, 0xfe, 0xfd]).is_err());
    }

    #[test]
    fn test_default_model() {
        let config = RunnerConfig::new(PathBuf::from("goose"));
        let runner = GooseCliRunner::new(config);
        assert_eq!(runner.default_model(), "auto");
    }

    #[test]
    fn test_capabilities() {
        let config = RunnerConfig::new(PathBuf::from("goose"));
        let runner = GooseCliRunner::new(config);
        assert!(runner.capabilities().supports_streaming());
    }

    #[test]
    fn test_name_and_display() {
        let config = RunnerConfig::new(PathBuf::from("goose"));
        let runner = GooseCliRunner::new(config);
        assert_eq!(runner.name(), "goose");
        assert_eq!(runner.display_name(), "Goose CLI");
    }
}