Skip to main content

dot/provider/
copilot.rs

1use std::{collections::HashMap, future::Future, pin::Pin};
2
3use anyhow::Context;
4use async_openai::{
5    Client,
6    config::OpenAIConfig,
7    types::{
8        ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessage,
9        ChatCompletionRequestAssistantMessageContent, ChatCompletionRequestMessage,
10        ChatCompletionRequestMessageContentPartImage, ChatCompletionRequestMessageContentPartText,
11        ChatCompletionRequestSystemMessage, ChatCompletionRequestSystemMessageContent,
12        ChatCompletionRequestToolMessage, ChatCompletionRequestToolMessageContent,
13        ChatCompletionRequestUserMessage, ChatCompletionRequestUserMessageContent,
14        ChatCompletionRequestUserMessageContentPart, ChatCompletionTool, ChatCompletionToolType,
15        CreateChatCompletionRequest, FinishReason, FunctionCall, FunctionObject, ImageUrl,
16    },
17};
18use futures::StreamExt;
19use tokio::sync::mpsc;
20use tracing::warn;
21
22use crate::provider::{
23    ContentBlock, Message, Provider, Role, StopReason, StreamEvent, StreamEventType,
24    ToolDefinition, Usage,
25};
26
27const COPILOT_API_BASE: &str = "https://api.githubcopilot.com";
28const DEFAULT_MODEL: &str = "gpt-4o";
29
30pub struct CopilotProvider {
31    client: Client<OpenAIConfig>,
32    model: String,
33    cached_models: std::sync::Mutex<Option<Vec<String>>>,
34    context_windows: std::sync::Mutex<HashMap<String, u32>>,
35}
36
37impl CopilotProvider {
38    pub fn new(token: impl Into<String>, model: impl Into<String>) -> Self {
39        let config = OpenAIConfig::new()
40            .with_api_key(token)
41            .with_api_base(COPILOT_API_BASE);
42        Self {
43            client: Client::with_config(config),
44            model: model.into(),
45            cached_models: std::sync::Mutex::new(None),
46            context_windows: std::sync::Mutex::new(HashMap::new()),
47        }
48    }
49}
50
/// Returns a built-in context-window size (in tokens) for well-known model ids, or `0` when
/// the model is not recognised (callers treat `0` as "unknown").
///
/// Matching is by prefix/substring, so narrower families are checked before the broader
/// prefixes they would otherwise fall into (e.g. `gpt-4.1` must not hit the plain `gpt-4`
/// bucket, and `o1-mini` must not hit the generic `o1` bucket).
fn known_context_window(model: &str) -> u32 {
    // The first o1 releases (mini / preview) shipped with 128k; later o-series models have 200k.
    if model.starts_with("o1-mini") || model.starts_with("o1-preview") {
        return 128_000;
    }
    if model.starts_with("o1") || model.starts_with("o3") || model.starts_with("o4") {
        return 200_000;
    }
    // GPT-4o, GPT-4 Turbo and the dotted GPT-4.x releases (gpt-4.1, gpt-4.5, ...). Some of the
    // 4.x models accept more than this upstream; 128k is used as a conservative floor.
    if model.starts_with("gpt-4o") || model.starts_with("gpt-4-turbo") || model.starts_with("gpt-4.")
    {
        return 128_000;
    }
    if model.contains("claude") {
        return 200_000;
    }
    if model.starts_with("gpt-4-32k") {
        return 32_768;
    }
    // Dated `gpt-4-XXXX-preview` snapshots (1106, 0125, vision) are GPT-4 Turbo models.
    if model.starts_with("gpt-4-") && model.ends_with("-preview") {
        return 128_000;
    }
    // Original GPT-4 (e.g. `gpt-4`, `gpt-4-0613`).
    if model.starts_with("gpt-4") {
        return 8_192;
    }
    0
}
66
/// State for one tool call that the chat-completions stream delivers in fragments,
/// keyed by the fragment's tool-call index.
#[derive(Default)]
struct ToolCallAccum {
    /// Tool-call id, taken from the first fragment that carries a non-empty one.
    id: String,
    /// Function name, taken from the first fragment that carries a non-empty one.
    name: String,
    /// Argument fragments collected for this call.
    arguments: String,
    /// Whether a `ToolUseStart` event has already been sent for this call.
    started: bool,
}
74
75fn convert_messages(
76    messages: &[Message],
77    system: Option<&str>,
78) -> anyhow::Result<Vec<ChatCompletionRequestMessage>> {
79    let mut result: Vec<ChatCompletionRequestMessage> = Vec::new();
80
81    if let Some(sys) = system {
82        result.push(ChatCompletionRequestMessage::System(
83            ChatCompletionRequestSystemMessage {
84                content: ChatCompletionRequestSystemMessageContent::Text(sys.to_string()),
85                name: None,
86            },
87        ));
88    }
89
90    for msg in messages {
91        match msg.role {
92            Role::System => {
93                let text = extract_text(&msg.content);
94                result.push(ChatCompletionRequestMessage::System(
95                    ChatCompletionRequestSystemMessage {
96                        content: ChatCompletionRequestSystemMessageContent::Text(text),
97                        name: None,
98                    },
99                ));
100            }
101            Role::User => {
102                let mut tool_results: Vec<(String, String)> = Vec::new();
103                let mut texts: Vec<String> = Vec::new();
104                let mut images: Vec<(String, String)> = Vec::new();
105
106                for block in &msg.content {
107                    match block {
108                        ContentBlock::Text(t) => texts.push(t.clone()),
109                        ContentBlock::Image { media_type, data } => {
110                            images.push((media_type.clone(), data.clone()));
111                        }
112                        ContentBlock::ToolResult {
113                            tool_use_id,
114                            content,
115                            ..
116                        } => {
117                            tool_results.push((tool_use_id.clone(), content.clone()));
118                        }
119                        _ => {}
120                    }
121                }
122
123                for (id, content) in tool_results {
124                    result.push(ChatCompletionRequestMessage::Tool(
125                        ChatCompletionRequestToolMessage {
126                            content: ChatCompletionRequestToolMessageContent::Text(content),
127                            tool_call_id: id,
128                        },
129                    ));
130                }
131
132                if !images.is_empty() {
133                    let mut parts: Vec<ChatCompletionRequestUserMessageContentPart> = Vec::new();
134                    if !texts.is_empty() {
135                        parts.push(ChatCompletionRequestUserMessageContentPart::Text(
136                            ChatCompletionRequestMessageContentPartText {
137                                text: texts.join("\n"),
138                            },
139                        ));
140                    }
141                    for (media_type, data) in images {
142                        parts.push(ChatCompletionRequestUserMessageContentPart::ImageUrl(
143                            ChatCompletionRequestMessageContentPartImage {
144                                image_url: ImageUrl {
145                                    url: format!("data:{};base64,{}", media_type, data),
146                                    detail: None,
147                                },
148                            },
149                        ));
150                    }
151                    result.push(ChatCompletionRequestMessage::User(
152                        ChatCompletionRequestUserMessage {
153                            content: ChatCompletionRequestUserMessageContent::Array(parts),
154                            name: None,
155                        },
156                    ));
157                } else if !texts.is_empty() {
158                    result.push(ChatCompletionRequestMessage::User(
159                        ChatCompletionRequestUserMessage {
160                            content: ChatCompletionRequestUserMessageContent::Text(
161                                texts.join("\n"),
162                            ),
163                            name: None,
164                        },
165                    ));
166                }
167            }
168            Role::Assistant => {
169                let mut text_parts: Vec<String> = Vec::new();
170                let mut tool_calls: Vec<ChatCompletionMessageToolCall> = Vec::new();
171
172                for block in &msg.content {
173                    match block {
174                        ContentBlock::Text(t) => text_parts.push(t.clone()),
175                        ContentBlock::ToolUse { id, name, input } => {
176                            tool_calls.push(ChatCompletionMessageToolCall {
177                                id: id.clone(),
178                                r#type: ChatCompletionToolType::Function,
179                                function: FunctionCall {
180                                    name: name.clone(),
181                                    arguments: serde_json::to_string(input).unwrap_or_default(),
182                                },
183                            });
184                        }
185                        _ => {}
186                    }
187                }
188
189                let content = if text_parts.is_empty() {
190                    None
191                } else {
192                    Some(ChatCompletionRequestAssistantMessageContent::Text(
193                        text_parts.join("\n"),
194                    ))
195                };
196
197                result.push(ChatCompletionRequestMessage::Assistant(
198                    ChatCompletionRequestAssistantMessage {
199                        content,
200                        name: None,
201                        tool_calls: if tool_calls.is_empty() {
202                            None
203                        } else {
204                            Some(tool_calls)
205                        },
206                        refusal: None,
207                        ..Default::default()
208                    },
209                ));
210            }
211        }
212    }
213
214    Ok(result)
215}
216
217fn extract_text(blocks: &[ContentBlock]) -> String {
218    blocks
219        .iter()
220        .filter_map(|b| {
221            if let ContentBlock::Text(t) = b {
222                Some(t.as_str())
223            } else {
224                None
225            }
226        })
227        .collect::<Vec<_>>()
228        .join("\n")
229}
230
231fn convert_tools(tools: &[ToolDefinition]) -> Vec<ChatCompletionTool> {
232    tools
233        .iter()
234        .map(|t| ChatCompletionTool {
235            r#type: ChatCompletionToolType::Function,
236            function: FunctionObject {
237                name: t.name.clone(),
238                description: Some(t.description.clone()),
239                parameters: Some(t.input_schema.clone()),
240                strict: None,
241            },
242        })
243        .collect()
244}
245
246fn map_finish(reason: &FinishReason) -> StopReason {
247    match reason {
248        FinishReason::Stop => StopReason::EndTurn,
249        FinishReason::Length => StopReason::MaxTokens,
250        FinishReason::ToolCalls | FinishReason::FunctionCall => StopReason::ToolUse,
251        FinishReason::ContentFilter => StopReason::StopSequence,
252    }
253}
254
255impl Provider for CopilotProvider {
256    fn name(&self) -> &str {
257        "copilot"
258    }
259
260    fn model(&self) -> &str {
261        &self.model
262    }
263
264    fn set_model(&mut self, model: String) {
265        self.model = model;
266    }
267
268    fn available_models(&self) -> Vec<String> {
269        let cache = self.cached_models.lock().unwrap();
270        cache.clone().unwrap_or_default()
271    }
272
273    fn context_window(&self) -> u32 {
274        let cw = self.context_windows.lock().unwrap();
275        cw.get(&self.model).copied().unwrap_or(0)
276    }
277
278    fn fetch_context_window(
279        &self,
280    ) -> Pin<Box<dyn Future<Output = anyhow::Result<u32>> + Send + '_>> {
281        Box::pin(async move {
282            {
283                let cw = self.context_windows.lock().unwrap();
284                if let Some(&v) = cw.get(&self.model) {
285                    return Ok(v);
286                }
287            }
288            let v = known_context_window(&self.model);
289            if v > 0 {
290                let mut cw = self.context_windows.lock().unwrap();
291                cw.insert(self.model.clone(), v);
292            }
293            Ok(v)
294        })
295    }
296
297    fn fetch_models(
298        &self,
299    ) -> Pin<Box<dyn Future<Output = anyhow::Result<Vec<String>>> + Send + '_>> {
300        let client = self.client.clone();
301        Box::pin(async move {
302            {
303                let cache = self.cached_models.lock().unwrap();
304                if let Some(ref models) = *cache {
305                    return Ok(models.clone());
306                }
307            }
308
309            let resp = client.models().list().await;
310
311            match resp {
312                Ok(list) => {
313                    let mut models: Vec<String> = list.data.into_iter().map(|m| m.id).collect();
314                    models.sort();
315                    models.dedup();
316
317                    if models.is_empty() {
318                        models.push(DEFAULT_MODEL.to_string());
319                    }
320
321                    let mut cache = self.cached_models.lock().unwrap();
322                    *cache = Some(models.clone());
323                    Ok(models)
324                }
325                Err(e) => {
326                    warn!("Failed to fetch Copilot models, using defaults: {e}");
327                    let defaults = vec![DEFAULT_MODEL.to_string()];
328                    let mut cache = self.cached_models.lock().unwrap();
329                    *cache = Some(defaults.clone());
330                    Ok(defaults)
331                }
332            }
333        })
334    }
335
336    fn stream(
337        &self,
338        messages: &[Message],
339        system: Option<&str>,
340        tools: &[ToolDefinition],
341        max_tokens: u32,
342        thinking_budget: u32,
343    ) -> Pin<
344        Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
345    > {
346        self.stream_with_model(
347            &self.model,
348            messages,
349            system,
350            tools,
351            max_tokens,
352            thinking_budget,
353        )
354    }
355
356    fn stream_with_model(
357        &self,
358        model: &str,
359        messages: &[Message],
360        system: Option<&str>,
361        tools: &[ToolDefinition],
362        max_tokens: u32,
363        _thinking_budget: u32,
364    ) -> Pin<
365        Box<dyn Future<Output = anyhow::Result<mpsc::UnboundedReceiver<StreamEvent>>> + Send + '_>,
366    > {
367        let messages = messages.to_vec();
368        let system = system.map(String::from);
369        let tools = tools.to_vec();
370        let model = model.to_string();
371        let client = self.client.clone();
372
373        Box::pin(async move {
374            let converted = convert_messages(&messages, system.as_deref())
375                .context("converting messages for Copilot")?;
376            let converted_tools = convert_tools(&tools);
377
378            let request = CreateChatCompletionRequest {
379                model: model.clone(),
380                messages: converted,
381                max_completion_tokens: Some(max_tokens),
382                stream: Some(true),
383                tools: if converted_tools.is_empty() {
384                    None
385                } else {
386                    Some(converted_tools)
387                },
388                temperature: Some(1.0),
389                ..Default::default()
390            };
391
392            let mut stream = client
393                .chat()
394                .create_stream(request)
395                .await
396                .context("creating Copilot stream")?;
397
398            let (tx, rx) = mpsc::unbounded_channel::<StreamEvent>();
399
400            tokio::spawn(async move {
401                let mut tool_accum: HashMap<u32, ToolCallAccum> = HashMap::new();
402                let mut total_output: u32 = 0;
403                let mut stop: Option<StopReason> = None;
404
405                let _ = tx.send(StreamEvent {
406                    event_type: StreamEventType::MessageStart,
407                });
408
409                while let Some(result) = stream.next().await {
410                    match result {
411                        Err(e) => {
412                            warn!("Copilot stream error: {e}");
413                            let _ = tx.send(StreamEvent {
414                                event_type: StreamEventType::Error(e.to_string()),
415                            });
416                            return;
417                        }
418                        Ok(response) => {
419                            if let Some(usage) = response.usage {
420                                total_output = usage.completion_tokens;
421                            }
422
423                            for choice in response.choices {
424                                if let Some(reason) = &choice.finish_reason {
425                                    stop = Some(map_finish(reason));
426                                    if matches!(
427                                        reason,
428                                        FinishReason::ToolCalls | FinishReason::FunctionCall
429                                    ) {
430                                        for accum in tool_accum.values() {
431                                            if accum.started {
432                                                let _ = tx.send(StreamEvent {
433                                                    event_type: StreamEventType::ToolUseEnd,
434                                                });
435                                            }
436                                        }
437                                        tool_accum.clear();
438                                    }
439                                }
440
441                                let delta = choice.delta;
442
443                                if let Some(content) = delta.content
444                                    && !content.is_empty()
445                                {
446                                    let _ = tx.send(StreamEvent {
447                                        event_type: StreamEventType::TextDelta(content),
448                                    });
449                                }
450
451                                if let Some(chunks) = delta.tool_calls {
452                                    for chunk in chunks {
453                                        let idx = chunk.index;
454                                        let entry = tool_accum.entry(idx).or_default();
455
456                                        if let Some(id) = chunk.id
457                                            && !id.is_empty()
458                                        {
459                                            entry.id = id;
460                                        }
461
462                                        if let Some(func) = chunk.function {
463                                            if let Some(name) = func.name
464                                                && !name.is_empty()
465                                            {
466                                                entry.name = name;
467                                            }
468
469                                            if !entry.started
470                                                && !entry.id.is_empty()
471                                                && !entry.name.is_empty()
472                                            {
473                                                let _ = tx.send(StreamEvent {
474                                                    event_type: StreamEventType::ToolUseStart {
475                                                        id: entry.id.clone(),
476                                                        name: entry.name.clone(),
477                                                    },
478                                                });
479                                                entry.started = true;
480                                            }
481
482                                            if let Some(args) = func.arguments
483                                                && !args.is_empty()
484                                            {
485                                                entry.arguments.push_str(&args);
486                                                let _ = tx.send(StreamEvent {
487                                                    event_type: StreamEventType::ToolUseInputDelta(
488                                                        args,
489                                                    ),
490                                                });
491                                            }
492                                        }
493                                    }
494                                }
495                            }
496                        }
497                    }
498                }
499
500                for accum in tool_accum.values() {
501                    if accum.started {
502                        let _ = tx.send(StreamEvent {
503                            event_type: StreamEventType::ToolUseEnd,
504                        });
505                    }
506                }
507
508                let _ = tx.send(StreamEvent {
509                    event_type: StreamEventType::MessageEnd {
510                        stop_reason: stop.unwrap_or(StopReason::EndTurn),
511                        usage: Usage {
512                            input_tokens: 0,
513                            output_tokens: total_output,
514                            cache_read_tokens: 0,
515                            cache_write_tokens: 0,
516                        },
517                    },
518                });
519            });
520
521            Ok(rx)
522        })
523    }
524}