1use super::{AiBackend, AiRequest, AiResponse};
2use anyhow::{Context, Result};
3use async_trait::async_trait;
4use tokio::process::Command;
5
/// Model passed to `gh copilot --model` when the backend has no explicit override.
const DEFAULT_MODEL: &str = "gpt-4.1";

/// AI backend that delegates requests to the GitHub CLI (`gh copilot`).
pub struct CopilotBackend {
    /// Optional model override; `None` means [`DEFAULT_MODEL`] is used.
    model: Option<String>,
    /// When set, diagnostic `[DEBUG]` lines are written to stderr.
    debug: bool,
}

impl CopilotBackend {
    /// Creates a backend with an optional model override and a debug-logging flag.
    pub fn new(model: Option<String>, debug: bool) -> Self {
        CopilotBackend { debug, model }
    }
}
18
/// Builds the system prompt sent to the model.
///
/// Without a schema the base prompt is returned unchanged. With a schema, the
/// base prompt is followed by an instruction to answer with JSON matching the
/// schema (embedded in a ```` ```json ```` fence) and nothing else.
fn build_system_prompt(base: &str, json_schema: Option<&str>) -> String {
    json_schema.map_or_else(
        || base.to_string(),
        |schema| {
            format!(
                "{base}\n\n\
                 You MUST respond with valid JSON matching this schema:\n\
                 ```json\n{schema}\n```\n\n\
                 Respond ONLY with the JSON object, no markdown fences, no explanation."
            )
        },
    )
}
31
32#[async_trait]
33impl AiBackend for CopilotBackend {
34 fn name(&self) -> &str {
35 "copilot"
36 }
37
38 async fn is_available(&self) -> bool {
39 Command::new("gh")
40 .args(["copilot", "--version"])
41 .output()
42 .await
43 .is_ok_and(|o| o.status.success())
44 }
45
46 async fn request(&self, req: &AiRequest) -> Result<AiResponse> {
47 let model = self.model.as_deref().unwrap_or(DEFAULT_MODEL);
48 let system = build_system_prompt(&req.system_prompt, req.json_schema.as_deref());
49
50 let mut cmd = Command::new("gh");
51 cmd.current_dir(&req.working_dir)
52 .arg("copilot")
53 .arg("-p")
54 .arg(&req.user_prompt)
55 .arg("-s")
56 .arg("--model")
57 .arg(model)
58 .arg("--allow-tool")
59 .arg("shell(git:*)")
60 .arg("--no-custom-instructions")
61 .arg("--system-prompt")
62 .arg(&system);
63
64 if self.debug {
65 eprintln!("[DEBUG] Calling gh copilot (model={model})");
66 }
67
68 let output = cmd.output().await.context("failed to run gh copilot")?;
69
70 let raw = String::from_utf8_lossy(&output.stdout).to_string();
71 let stderr = String::from_utf8_lossy(&output.stderr);
72
73 if self.debug {
74 eprintln!("[DEBUG] gh copilot exit code: {}", output.status);
75 eprintln!(
76 "[DEBUG] Raw response (first 500 chars): {}",
77 &raw[..raw.len().min(500)]
78 );
79 if !stderr.is_empty() {
80 eprintln!("[DEBUG] Stderr: {stderr}");
81 }
82 }
83
84 if !output.status.success() {
85 anyhow::bail!(crate::error::SrAiError::AiBackend(format!(
86 "gh copilot failed (exit {}): {}",
87 output.status,
88 stderr.trim()
89 )));
90 }
91
92 let text = extract_json(&raw).unwrap_or(raw);
94
95 Ok(AiResponse { text })
96 }
97}
98
99pub(crate) fn extract_json(raw: &str) -> Option<String> {
101 let trimmed = raw.trim();
102
103 if serde_json::from_str::<serde_json::Value>(trimmed).is_ok() {
105 return Some(trimmed.to_string());
106 }
107
108 if let Some(start) = trimmed.find("```json") {
110 let after = &trimmed[start + 7..];
111 if let Some(end) = after.find("```") {
112 let json_str = after[..end].trim();
113 if serde_json::from_str::<serde_json::Value>(json_str).is_ok() {
114 return Some(json_str.to_string());
115 }
116 }
117 }
118
119 if let Some(start) = trimmed.find("```") {
121 let after = &trimmed[start + 3..];
122 let after = if let Some(nl) = after.find('\n') {
123 &after[nl + 1..]
124 } else {
125 after
126 };
127 if let Some(end) = after.find("```") {
128 let json_str = after[..end].trim();
129 if serde_json::from_str::<serde_json::Value>(json_str).is_ok() {
130 return Some(json_str.to_string());
131 }
132 }
133 }
134
135 for (open, close) in [("{", "}"), ("[", "]")] {
137 if let Some(start) = trimmed.find(open)
138 && let Some(end) = trimmed.rfind(close)
139 && end > start
140 {
141 let candidate = &trimmed[start..=end];
142 if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
143 return Some(candidate.to_string());
144 }
145 }
146 }
147
148 None
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected `extract_json` result for a successfully extracted literal.
    fn found(json: &str) -> Option<String> {
        Some(json.to_owned())
    }

    // ---- extract_json ----------------------------------------------------

    #[test]
    fn extract_direct_json() {
        let json = r#"{"commits": []}"#;
        assert_eq!(extract_json(json), found(json));
    }

    #[test]
    fn extract_from_json_fences() {
        let response = "Here is the plan:\n```json\n{\"commits\": []}\n```\nDone.";
        assert_eq!(extract_json(response), found(r#"{"commits": []}"#));
    }

    #[test]
    fn extract_from_plain_fences() {
        let response = "Result:\n```\n{\"commits\": [{\"order\": 1}]}\n```";
        assert_eq!(
            extract_json(response),
            found(r#"{"commits": [{"order": 1}]}"#)
        );
    }

    #[test]
    fn extract_from_surrounding_text() {
        let response = "The result is {\"commits\": []} and that's it.";
        assert_eq!(extract_json(response), found(r#"{"commits": []}"#));
    }

    #[test]
    fn extract_array_json() {
        assert_eq!(extract_json("Here: [1, 2, 3] done"), found("[1, 2, 3]"));
    }

    #[test]
    fn extract_returns_none_for_invalid() {
        for response in ["no json here", "", "{not valid json}"] {
            assert_eq!(extract_json(response), None, "input: {response:?}");
        }
    }

    #[test]
    fn extract_with_whitespace() {
        let padded = " \n {\"key\": \"value\"} \n ";
        assert_eq!(extract_json(padded), found(r#"{"key": "value"}"#));
    }

    // ---- build_system_prompt ---------------------------------------------

    #[test]
    fn system_prompt_without_schema() {
        let base = "You are a commit assistant.";
        assert_eq!(build_system_prompt(base, None), base);
    }

    #[test]
    fn system_prompt_with_schema() {
        let schema = r#"{"type": "object"}"#;
        let prompt = build_system_prompt("Base prompt.", Some(schema));
        assert!(prompt.starts_with("Base prompt."));
        for needle in ["You MUST respond with valid JSON", schema, "no markdown fences"] {
            assert!(prompt.contains(needle), "prompt is missing {needle:?}");
        }
    }

    // ---- backend ----------------------------------------------------------

    #[test]
    fn backend_name() {
        assert_eq!(CopilotBackend::new(None, false).name(), "copilot");
    }

    #[test]
    fn default_model_constant() {
        assert_eq!(DEFAULT_MODEL, "gpt-4.1");
    }

    #[test]
    fn system_prompt_preserves_multiline_base() {
        let base = "Line one.\nLine two.\nLine three.";
        assert_eq!(build_system_prompt(base, None), base);
        assert!(build_system_prompt(base, Some("{}")).starts_with(base));
    }
}
245}