use std::{collections::HashMap, future::Future, pin::Pin};

use anyhow::Context;
use async_openai::{
    Client,
    config::OpenAIConfig,
    types::{
        ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
        ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
        ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText,
        ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
        ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
        ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
        ChatCompletionRequestUserMessageContentPart, ChatCompletionTool, ChatCompletionToolType,
        CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionObject, ImageUrl,
    },
};
use futures::StreamExt;
use tokio::sync::mpsc;
use tracing::warn;

use crate::provider::{
    ContentBlock, Message, Provider, Role, StopReason, StreamEvent, StreamEventType,
    ToolDefinition, Usage,
};

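// Copilot exposes an OpenAI-compatible API; DEFAULT_MODEL is the fallback
// used when model discovery fails or returns nothing.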
const COPILOT_API_BASE: &str = "https://api.githubcopilot.com";
const DEFAULT_MODEL: &str = "gpt-4o";

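/// Provider backed by GitHub Copilot's OpenAI-compatible chat completions API.
///
/// A minimal construction sketch (the token shown is a placeholder, not a real
/// credential):
///
/// ```ignore
/// let provider = CopilotProvider::new("<copilot-token>", "gpt-4o");
/// ```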
pub struct CopilotProvider {
    client: Client<OpenAIConfig>,
    model: String,
    cached_models: std::sync::Mutex<Option<Vec<String>>>,
    context_windows: std::sync::Mutex<HashMap<String, u32>>,
}

impl CopilotProvider {
    pub fn new(token: impl Into<String>, model: impl Into<String>) -> Self {
        let config = OpenAIConfig::new()
            .with_api_key(token)
            .with_api_base(COPILOT_API_BASE);
        Self {
            client: Client::with_config(config),
            model: model.into(),
            cached_models: std::sync::Mutex::new(None),
            context_windows: std::sync::Mutex::new(HashMap::new()),
        }
    }
}

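/// Best-effort context-window sizes for model families Copilot commonly
/// exposes. The figures reflect publicly documented limits and may lag behind
/// newer models; 0 means "unknown". Order matters: the `gpt-4o` prefix must be
/// checked before the plain `gpt-4` prefix.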
fn known_context_window(model: &str) -> u32 {
    if model.starts_with("o1") || model.starts_with("o3") || model.starts_with("o4") {
        return 200_000;
    }
    if model.starts_with("gpt-4o") || model.starts_with("gpt-4-turbo") {
        return 128_000;
    }
    if model.contains("claude") {
        return 200_000;
    }
    if model.starts_with("gpt-4") {
        return 8_192;
    }
    0
}

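/// Accumulates the fragments of one streamed tool call (id, name, argument
/// JSON) as they arrive in delta chunks, keyed by the call's index.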
#[derive(Default)]
struct ToolCallAccum {
    id: String,
    name: String,
    arguments: String,
    started: bool,
}

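/// Converts internal messages, plus an optional system prompt, into the
/// OpenAI-style request messages Copilot expects.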
fn convert_messages(
    messages: &[Message],
    system: Option<&str>,
) -> anyhow::Result<Vec<ChatCompletionRequestMessage>> {
    let mut result: Vec<ChatCompletionRequestMessage> = Vec::new();

    if let Some(sys) = system {
        result.push(ChatCompletionRequestMessage::System(
            ChatCompletionRequestSystemMessage {
                content: ChatCompletionRequestSystemMessageContent::Text(sys.to_string()),
                name: None,
            },
        ));
    }

    for msg in messages {
        match msg.role {
            Role::System => {
                let text = extract_text(&msg.content);
                result.push(ChatCompletionRequestMessage::System(
                    ChatCompletionRequestSystemMessage {
                        content: ChatCompletionRequestSystemMessageContent::Text(text),
                        name: None,
                    },
                ));
            }
            Role::User => {
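                // OpenAI-style APIs carry tool results as separate `tool` role
                // messages, so they are split out of the user content first.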
                let mut tool_results: Vec<(String, String)> = Vec::new();
                let mut texts: Vec<String> = Vec::new();
                let mut images: Vec<(String, String)> = Vec::new();

                for block in &msg.content {
                    match block {
                        ContentBlock::Text(t) => texts.push(t.clone()),
                        ContentBlock::Image { media_type, data } => {
                            images.push((media_type.clone(), data.clone()));
                        }
                        ContentBlock::ToolResult {
                            tool_use_id,
                            content,
                            ..
                        } => {
                            tool_results.push((tool_use_id.clone(), content.clone()));
                        }
                        _ => {}
                    }
                }

                for (id, content) in tool_results {
                    result.push(ChatCompletionRequestMessage::Tool(
                        ChatCompletionRequestToolMessage {
                            content: ChatCompletionRequestToolMessageContent::Text(content),
                            tool_call_id: id,
                        },
                    ));
                }

                if !images.is_empty() {
                    let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
                    if !texts.is_empty() {
                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(
                            ChatCompletionRequestMessageContentPartText {
                                text: texts.join("\n"),
                            },
                        ));
                    }
                    for (media_type, data) in images {
                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
                            ChatCompletionRequestMessageContentPartImage {
                                image_url: ImageUrl {
                                    url: format!("data:{};base64,{}", media_type, data),
                                    detail: None,
                                },
                            },
                        ));
                    }
                    result.push(ChatCompletionRequestMessage::User(
                        ChatCompletionRequestUserMessage {
                            content: ChatCompletionRequestUserMessageContent::Array(parts),
                            name: None,
                        },
                    ));
                } else if !texts.is_empty() {
                    result.push(ChatCompletionRequestMessage::User(
                        ChatCompletionRequestUserMessage {
                            content: ChatCompletionRequestUserMessageContent::Text(
                                texts.join("\n"),
                            ),
                            name: None,
                        },
                    ));
                }
            }
            Role::Assistant => {
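                // Assistant text and tool invocations travel in separate
                // fields: plain content plus an optional `tool_calls` array.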
                let mut text_parts: Vec<String> = Vec::new();
                let mut tool_calls: Vec<ChatCompletionMessageToolCall> = Vec::new();

                for block in &msg.content {
                    match block {
                        ContentBlock::Text(t) => text_parts.push(t.clone()),
                        ContentBlock::ToolUse { id, name, input } => {
                            tool_calls.push(ChatCompletionMessageToolCall {
                                id: id.clone(),
                                r#type: ChatCompletionToolType::Function,
                                function: FunctionCall {
                                    name: name.clone(),
                                    arguments: serde_json::to_string(input).unwrap_or_default(),
                                },
                            });
                        }
                        _ => {}
                    }
                }

                let content = if text_parts.is_empty() {
                    None
                } else {
                    Some(ChatCompletionRequestAssistantMessageContent::Text(
                        text_parts.join("\n"),
                    ))
                };

                result.push(ChatCompletionRequestMessage::Assistant(
                    ChatCompletionRequestAssistantMessage {
                        content,
                        name: None,
                        tool_calls: if tool_calls.is_empty() {
                            None
                        } else {
                            Some(tool_calls)
                        },
                        refusal: None,
                        ..Default::default()
                    },
                ));
            }
        }
    }

    Ok(result)
}

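/// Concatenates all text blocks of a message into one newline-joined string.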
fn extract_text(blocks: &[ContentBlock]) -> String {
    blocks
        .iter()
        .filter_map(|b| {
            if let ContentBlock::Text(t) = b {
                Some(t.as_str())
            } else {
                None
            }
        })
        .collect::<Vec<_>>()
        .join("\n")
}

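/// Maps internal tool definitions onto OpenAI function-calling tool schemas.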
fn convert_tools(tools: &[ToolDefinition]) -> Vec<ChatCompletionTool> {
    tools
        .iter()
        .map(|t| ChatCompletionTool {
            r#type: ChatCompletionToolType::Function,
            function: FunctionObject {
                name: t.name.clone(),
                description: Some(t.description.clone()),
                parameters: Some(t.input_schema.clone()),
                strict: None,
            },
        })
        .collect()
}

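/// Maps OpenAI finish reasons onto internal stop reasons. `ContentFilter` has
/// no direct counterpart, so it is mapped to the closest available variant.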
fn map_finish(reason: &FinishReason) -> StopReason {
    match reason {
        FinishReason::Stop => StopReason::EndTurn,
        FinishReason::Length => StopReason::MaxTokens,
        FinishReason::ToolCalls | FinishReason::FunctionCall => StopReason::ToolUse,
        FinishReason::ContentFilter => StopReason::StopSequence,
    }
}

impl Provider for CopilotProvider {
    fn name(&self) -> &str {
        "copilot"
    }

    fn model(&self) -> &str {
        &self.model
    }

    fn set_model(&mut self, model: String) {
        self.model = model;
    }

    fn available_models(&self) -> Vec<String> {
        let cache = self.cached_models.lock().unwrap();
        cache.clone().unwrap_or_default()
    }

    fn context_window(&self) -> u32 {
        let cw = self.context_windows.lock().unwrap();
        cw.get(&self.model).copied().unwrap_or(0)
    }

    fn fetch_context_window(
        &self,
    ) -> Pin<Box<dyn Future<Output = anyhow::Result<u32>> + Send + '_>> {
        Box::pin(async move {
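            // Serve the cached value if present; otherwise fall back to the
            // static table and cache any non-zero result.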
            {
                let cw = self.context_windows.lock().unwrap();
                if let Some(&v) = cw.get(&self.model) {
                    return Ok(v);
                }
            }
            let v = known_context_window(&self.model);
            if v > 0 {
                let mut cw = self.context_windows.lock().unwrap();
                cw.insert(self.model.clone(), v);
            }
            Ok(v)
        })
    }

    fn fetch_models(
        &self,
    ) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<String>>> + Send + '_>> {
        let client = self.client.clone();
        Box::pin(async move {
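            // The cache check lives in its own block so the std::sync::Mutex
            // guard is dropped before the network await below.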
            {
                let cache = self.cached_models.lock().unwrap();
                if let Some(ref models) = *cache {
                    return Ok(models.clone());
                }
            }

            let resp = client.models().list().await;

            match resp {
                Ok(list) => {
                    let mut models: Vec<String> = list.data.into_iter().map(|m| m.id).collect();
                    models.sort();
                    models.dedup();

                    if models.is_empty() {
                        models.push(DEFAULT_MODEL.to_string());
                    }

                    let mut cache = self.cached_models.lock().unwrap();
                    *cache = Some(models.clone());
                    Ok(models)
                }
                Err(e) => {
                    warn!("Failed to fetch Copilot models, using defaults: {e}");
                    let defaults = vec![DEFAULT_MODEL.to_string()];
                    let mut cache = self.cached_models.lock().unwrap();
                    *cache = Some(defaults.clone());
                    Ok(defaults)
                }
            }
        })
    }

    fn stream(
        &self,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
        max_tokens: u32,
        thinking_budget: u32,
    ) -> Pin<
        Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
    > {
        self.stream_with_model(
            &self.model,
            messages,
            system,
            tools,
            max_tokens,
            thinking_budget,
        )
    }

    fn stream_with_model(
        &self,
        model: &str,
        messages: &[Message],
        system: Option<&str>,
        tools: &[ToolDefinition],
        max_tokens: u32,
        _thinking_budget: u32,
    ) -> Pin<
        Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
    > {
        let messages = messages.to_vec();
        let system = system.map(String::from);
        let tools = tools.to_vec();
        let model = model.to_string();
        let client = self.client.clone();

        Box::pin(async move {
            let converted = convert_messages(&messages, system.as_deref())
                .context("converting messages for Copilot")?;
            let converted_tools = convert_tools(&tools);

            let request = CreateChatCompletionRequest {
                model: model.clone(),
                messages: converted,
                max_completion_tokens: Some(max_tokens),
                stream: Some(true),
                tools: if converted_tools.is_empty() {
                    None
                } else {
                    Some(converted_tools)
                },
                temperature: Some(1.0),
                ..Default::default()
            };

            let mut stream = client
                .chat()
                .create_stream(request)
                .await
                .context("creating Copilot stream")?;

            let (tx, rx) = mpsc::unbounded_channel::<StreamEvent>();

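            // Drive the SSE stream on a background task, translating chunks
            // into StreamEvents; the receiver is handed back to the caller.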
            tokio::spawn(async move {
                let mut tool_accum: HashMap<u32, ToolCallAccum> = HashMap::new();
                let mut total_input: u32 = 0;
                let mut total_output: u32 = 0;
                let mut stop: Option<StopReason> = None;

                let _ = tx.send(StreamEvent {
                    event_type: StreamEventType::MessageStart,
                });

                while let Some(result) = stream.next().await {
                    match result {
                        Err(e) => {
                            warn!("Copilot stream error: {e}");
                            let _ = tx.send(StreamEvent {
                                event_type: StreamEventType::Error(e.to_string()),
                            });
                            return;
                        }
                        Ok(response) => {
                            if let Some(usage) = response.usage {
                                total_input = usage.prompt_tokens;
                                total_output = usage.completion_tokens;
                            }

                            for choice in response.choices {
                                if let Some(reason) = &choice.finish_reason {
                                    stop = Some(map_finish(reason));
                                    if matches!(
                                        reason,
                                        FinishReason::ToolCalls | FinishReason::FunctionCall
                                    ) {
                                        for accum in tool_accum.values() {
                                            if accum.started {
                                                let _ = tx.send(StreamEvent {
                                                    event_type: StreamEventType::ToolUseEnd,
                                                });
                                            }
                                        }
                                        tool_accum.clear();
                                    }
                                }

                                let delta = choice.delta;

                                if let Some(content) = delta.content
                                    && !content.is_empty()
                                {
                                    let _ = tx.send(StreamEvent {
                                        event_type: StreamEventType::TextDelta(content),
                                    });
                                }

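                                // Tool-call deltas arrive as indexed fragments:
                                // the id and name typically come first, then the
                                // argument JSON is spread across later chunks.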
                                if let Some(chunks) = delta.tool_calls {
                                    for chunk in chunks {
                                        let idx = chunk.index;
                                        let entry = tool_accum.entry(idx).or_default();

                                        if let Some(id) = chunk.id
                                            && !id.is_empty()
                                        {
                                            entry.id = id;
                                        }

                                        if let Some(func) = chunk.function {
                                            if let Some(name) = func.name
                                                && !name.is_empty()
                                            {
                                                entry.name = name;
                                            }

                                            if !entry.started
                                                && !entry.id.is_empty()
                                                && !entry.name.is_empty()
                                            {
                                                let _ = tx.send(StreamEvent {
                                                    event_type: StreamEventType::ToolUseStart {
                                                        id: entry.id.clone(),
                                                        name: entry.name.clone(),
                                                    },
                                                });
                                                entry.started = true;
                                            }

                                            if let Some(args) = func.arguments
                                                && !args.is_empty()
                                            {
                                                entry.arguments.push_str(&args);
                                                let _ = tx.send(StreamEvent {
                                                    event_type: StreamEventType::ToolUseInputDelta(
                                                        args,
                                                    ),
                                                });
                                            }
                                        }
                                    }
                                }
                            }
                        }
                    }
                }

                for accum in tool_accum.values() {
                    if accum.started {
                        let _ = tx.send(StreamEvent {
                            event_type: StreamEventType::ToolUseEnd,
                        });
                    }
                }

                let _ = tx.send(StreamEvent {
                    event_type: StreamEventType::MessageEnd {
                        stop_reason: stop.unwrap_or(StopReason::EndTurn),
                        usage: Usage {
                            input_tokens: total_input,
                            output_tokens: total_output,
                            cache_read_tokens: 0,
                            cache_write_tokens: 0,
                        },
                    },
                });
            });

            Ok(rx)
        })
    }
}