1use std::io;
8use std::process::Stdio;
9use std::str;
10
11use crate::cli_common::{CliRunnerBase, MAX_OUTPUT_BYTES};
12use crate::types::{
13 ChatRequest, ChatResponse, ChatStream, LlmCapabilities, LlmProvider, RunnerError, StreamChunk,
14};
15use async_trait::async_trait;
16use tempfile::NamedTempFile;
17use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
18use tokio::process::Command;
19use tokio_stream::wrappers::LinesStream;
20use tokio_stream::StreamExt;
21use tracing::instrument;
22
23use crate::config::RunnerConfig;
24use crate::process::{read_stderr_capped, run_cli_command};
25use crate::prompt::build_user_prompt;
26use crate::sandbox::{apply_sandbox, build_policy};
27use crate::stream::{GuardedStream, MAX_STREAMING_STDERR_BYTES};
28
/// Model name handed to [`CliRunnerBase`] as this runner's default.
const DEFAULT_MODEL: &str = "auto";

/// Fallback model list handed to [`CliRunnerBase`]; currently only the default model.
const FALLBACK_MODELS: &[&str] = &[DEFAULT_MODEL];
34
/// [`LlmProvider`] implementation that shells out to the `goose` CLI.
///
/// Non-streaming requests run `goose run --output-format json` and pass the
/// prompt through a temporary file (`-i <path>`). Streaming requests run
/// `goose run --output-format stream-json -i -` and write the prompt to stdin.
pub struct GooseCliRunner {
    /// Shared CLI-runner state: configuration (binary path, extra args,
    /// timeout, sandbox inputs), default/fallback models and session lookup.
    base: CliRunnerBase,
}
44
45impl GooseCliRunner {
46 #[must_use]
48 pub fn new(config: RunnerConfig) -> Self {
49 Self {
50 base: CliRunnerBase::new(config, DEFAULT_MODEL, FALLBACK_MODELS),
51 }
52 }
53
54 pub async fn set_session(&self, key: &str, session_id: &str) {
56 self.base.set_session(key, session_id).await;
57 }
58
59 fn build_command_base(&self, output_format: &str) -> Command {
61 let mut cmd = Command::new(&self.base.config.binary_path);
62 cmd.args([
63 "run",
64 "--quiet",
65 "--no-session",
66 "--output-format",
67 output_format,
68 ]);
69
70 for arg in &self.base.config.extra_args {
71 cmd.arg(arg);
72 }
73
74 if let Ok(policy) = build_policy(
75 self.base.config.working_directory.as_deref(),
76 &self.base.config.allowed_env_keys,
77 ) {
78 apply_sandbox(&mut cmd, &policy);
79 }
80
81 cmd
82 }
83
84 fn parse_json_response(raw: &[u8]) -> Result<ChatResponse, RunnerError> {
89 let text = str::from_utf8(raw).map_err(|e| {
90 RunnerError::internal(format!("Goose CLI output is not valid UTF-8: {e}"))
91 })?;
92
93 let value: serde_json::Value = serde_json::from_str(text).map_err(|e| {
94 RunnerError::internal(format!("Failed to parse Goose JSON response: {e}"))
95 })?;
96
97 let messages = value
98 .get("messages")
99 .and_then(|v| v.as_array())
100 .ok_or_else(|| RunnerError::internal("Goose response missing 'messages' array"))?;
101
102 let mut content = String::new();
104 for msg in messages.iter().rev() {
105 let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("");
106 if role == "assistant" {
107 if let Some(parts) = msg.get("content").and_then(|v| v.as_array()) {
108 for part in parts {
109 let part_type = part.get("type").and_then(|v| v.as_str()).unwrap_or("");
110 if part_type == "text" {
111 if let Some(t) = part.get("text").and_then(|v| v.as_str()) {
112 content.push_str(t);
113 }
114 }
115 }
116 }
117 break;
118 }
119 }
120
121 Ok(ChatResponse {
122 content,
123 model: "goose".to_owned(),
124 usage: None,
125 finish_reason: Some("stop".to_owned()),
126 warnings: None,
127 })
128 }
129}
130
131#[async_trait]
132impl LlmProvider for GooseCliRunner {
133 crate::delegate_provider_base!("goose", "Goose CLI", LlmCapabilities::STREAMING);
134
135 #[instrument(skip_all, fields(runner = "goose"))]
136 async fn complete(&self, request: &ChatRequest) -> Result<ChatResponse, RunnerError> {
137 let prompt = build_user_prompt(&request.messages);
138
139 let mut prompt_file = NamedTempFile::new().map_err(|e| {
141 RunnerError::internal(format!("Failed to create temp file for Goose prompt: {e}"))
142 })?;
143 std::io::Write::write_all(&mut prompt_file, prompt.as_bytes()).map_err(|e| {
144 RunnerError::internal(format!("Failed to write Goose prompt to temp file: {e}"))
145 })?;
146
147 let mut cmd = self.build_command_base("json");
148 cmd.args(["-i", &prompt_file.path().display().to_string()]);
149
150 if let Some(model) = &request.model {
151 if let Some(sid) = self.base.get_session(model).await {
152 cmd.args(["--session-id", &sid, "--resume"]);
153 }
154 }
155
156 let output = run_cli_command(&mut cmd, self.base.config.timeout, MAX_OUTPUT_BYTES).await?;
157 self.base.check_exit_code(&output, "goose")?;
158
159 Self::parse_json_response(&output.stdout)
160 }
161
162 #[instrument(skip_all, fields(runner = "goose"))]
163 async fn complete_stream(&self, request: &ChatRequest) -> Result<ChatStream, RunnerError> {
164 let prompt = build_user_prompt(&request.messages);
165
166 let mut cmd = self.build_command_base("stream-json");
167 cmd.args(["-i", "-"]);
168
169 if let Some(model) = &request.model {
170 if let Some(sid) = self.base.get_session(model).await {
171 cmd.args(["--session-id", &sid, "--resume"]);
172 }
173 }
174
175 cmd.stdin(Stdio::piped());
176 cmd.stdout(Stdio::piped());
177 cmd.stderr(Stdio::piped());
178
179 let mut child = cmd.spawn().map_err(|e| {
180 RunnerError::internal(format!("Failed to spawn goose for streaming: {e}"))
181 })?;
182
183 let mut stdin = child
185 .stdin
186 .take()
187 .ok_or_else(|| RunnerError::internal("Failed to capture goose stdin for streaming"))?;
188 tokio::spawn(async move {
189 let _ = stdin.write_all(prompt.as_bytes()).await;
190 let _ = stdin.shutdown().await;
191 });
192
193 let stdout = child
194 .stdout
195 .take()
196 .ok_or_else(|| RunnerError::internal("Failed to capture goose stdout for streaming"))?;
197
198 let stderr_task = tokio::spawn(read_stderr_capped(
199 child.stderr.take(),
200 MAX_STREAMING_STDERR_BYTES,
201 ));
202
203 let reader = BufReader::new(stdout);
204 let lines = LinesStream::new(reader.lines());
205
206 let stream = lines.map(move |line_result: Result<String, io::Error>| {
207 let line = line_result
208 .map_err(|e| RunnerError::internal(format!("Error reading goose stream: {e}")))?;
209
210 if line.trim().is_empty() {
211 return Ok(StreamChunk {
212 delta: String::new(),
213 is_final: false,
214 finish_reason: None,
215 });
216 }
217
218 let value: serde_json::Value = serde_json::from_str(&line)
219 .map_err(|e| RunnerError::internal(format!("Invalid JSON in goose stream: {e}")))?;
220
221 let chunk_type = value.get("type").and_then(|v| v.as_str()).unwrap_or("");
222 match chunk_type {
223 "message" => {
224 let mut delta = String::new();
226 if let Some(msg) = value.get("message") {
227 if let Some(parts) = msg.get("content").and_then(|v| v.as_array()) {
228 for part in parts {
229 let pt = part.get("type").and_then(|v| v.as_str()).unwrap_or("");
230 if pt == "text" {
231 if let Some(t) = part.get("text").and_then(|v| v.as_str()) {
232 delta.push_str(t);
233 }
234 }
235 }
236 }
237 }
238 Ok(StreamChunk {
239 delta,
240 is_final: false,
241 finish_reason: None,
242 })
243 }
244 "complete" => Ok(StreamChunk {
245 delta: String::new(),
246 is_final: true,
247 finish_reason: Some("stop".to_owned()),
248 }),
249 _ => Ok(StreamChunk {
250 delta: String::new(),
251 is_final: false,
252 finish_reason: None,
253 }),
254 }
255 });
256
257 Ok(Box::pin(GuardedStream::new(stream, child, stderr_task)))
258 }
259}
260
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Runner with a dummy binary path; no process is spawned by these tests.
    fn runner() -> GooseCliRunner {
        GooseCliRunner::new(RunnerConfig::new(PathBuf::from("goose")))
    }

    /// Parses a fixture that is expected to be well-formed.
    fn parse(json: &[u8]) -> ChatResponse {
        GooseCliRunner::parse_json_response(json).expect("fixture should parse")
    }

    #[test]
    fn test_parse_json_response_basic() {
        let json = br#"{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]},{"role":"assistant","content":[{"type":"text","text":"hello"}]}],"metadata":{"total_tokens":null,"status":"completed"}}"#;
        let resp = parse(json);
        assert_eq!(resp.content, "hello");
        assert_eq!(resp.model, "goose");
        assert_eq!(resp.finish_reason.as_deref(), Some("stop"));
        assert!(resp.usage.is_none());
    }

    #[test]
    fn test_parse_json_response_multi_content() {
        let json = br#"{"messages":[{"role":"assistant","content":[{"type":"text","text":"part1"},{"type":"text","text":"part2"}]}],"metadata":{"status":"completed"}}"#;
        assert_eq!(parse(json).content, "part1part2");
    }

    #[test]
    fn test_parse_json_response_skips_user_messages() {
        let json = br#"{"messages":[{"role":"user","content":[{"type":"text","text":"ignored"}]},{"role":"assistant","content":[{"type":"text","text":"kept"}]}],"metadata":{}}"#;
        assert_eq!(parse(json).content, "kept");
    }

    #[test]
    fn test_parse_json_response_uses_last_assistant_message() {
        let json = br#"{"messages":[{"role":"assistant","content":[{"type":"text","text":"old"}]},{"role":"user","content":[{"type":"text","text":"again"}]},{"role":"assistant","content":[{"type":"text","text":"new"}]}]}"#;
        assert_eq!(parse(json).content, "new");
    }

    #[test]
    fn test_parse_json_response_skips_non_text_parts() {
        let json = br#"{"messages":[{"role":"assistant","content":[{"type":"toolRequest","id":"t1"},{"type":"text","text":"answer"}]}]}"#;
        assert_eq!(parse(json).content, "answer");
    }

    #[test]
    fn test_parse_json_response_without_assistant_is_empty() {
        let json = br#"{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}"#;
        assert_eq!(parse(json).content, "");
    }

    #[test]
    fn test_parse_json_response_rejects_bad_input() {
        assert!(GooseCliRunner::parse_json_response(&[0xff, 0xfe]).is_err());
        assert!(GooseCliRunner::parse_json_response(b"not json").is_err());
        assert!(GooseCliRunner::parse_json_response(br#"{"metadata":{}}"#).is_err());
    }

    #[test]
    fn test_default_model() {
        assert_eq!(runner().default_model(), "auto");
    }

    #[test]
    fn test_capabilities() {
        assert!(runner().capabilities().supports_streaming());
    }

    #[test]
    fn test_name_and_display() {
        let runner = runner();
        assert_eq!(runner.name(), "goose");
        assert_eq!(runner.display_name(), "Goose CLI");
    }
}