code_mesh_core/llm/
github_copilot.rs

1use async_trait::async_trait;
2use futures::Stream;
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use std::time::{SystemTime, UNIX_EPOCH};
10
11use super::{
12    FinishReason, GenerateOptions, GenerateResult, LanguageModel, Message, MessageContent,
13    MessagePart, MessageRole, StreamChunk, StreamOptions, ToolCall, ToolDefinition, Usage,
14};
15use crate::auth::{Auth, AuthCredentials};
16
17/// GitHub Copilot provider implementation
18pub struct GitHubCopilotProvider {
19    auth: Box<dyn Auth>,
20    client: Client,
21    models: HashMap<String, GitHubCopilotModel>,
22}
23
/// Static description of a single model offered through GitHub Copilot.
#[derive(Clone, Debug)]
pub struct GitHubCopilotModel {
    /// Model identifier (for example `gpt-4o`).
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// Token limit associated with the model.
    pub max_tokens: u32,
    /// Whether the model accepts tool (function) definitions.
    pub supports_tools: bool,
    /// Whether the model accepts image input.
    pub supports_vision: bool,
    /// Whether the model supports caching.
    pub supports_caching: bool,
}
33
34#[derive(Debug, Serialize, Deserialize)]
35struct DeviceCodeRequest {
36    client_id: String,
37    scope: String,
38}
39
40#[derive(Debug, Deserialize)]
41struct DeviceCodeResponse {
42    device_code: String,
43    user_code: String,
44    verification_uri: String,
45    expires_in: u32,
46    interval: u32,
47}
48
49#[derive(Debug, Serialize)]
50struct AccessTokenRequest {
51    client_id: String,
52    device_code: String,
53    grant_type: String,
54}
55
56#[derive(Debug, Deserialize)]
57struct AccessTokenResponse {
58    access_token: Option<String>,
59    error: Option<String>,
60    error_description: Option<String>,
61}
62
63#[derive(Debug, Deserialize)]
64struct CopilotTokenResponse {
65    token: String,
66    expires_at: u64,
67    refresh_in: u64,
68    endpoints: CopilotEndpoints,
69}
70
71#[derive(Debug, Deserialize)]
72struct CopilotEndpoints {
73    api: String,
74}
75
76#[derive(Debug, Serialize)]
77struct CopilotRequest {
78    model: String,
79    messages: Vec<CopilotMessage>,
80    max_tokens: u32,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    temperature: Option<f32>,
83    #[serde(skip_serializing_if = "Vec::is_empty")]
84    tools: Vec<CopilotTool>,
85    #[serde(skip_serializing_if = "Vec::is_empty")]
86    stop: Vec<String>,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    stream: Option<bool>,
89}
90
91#[derive(Debug, Serialize, Deserialize)]
92struct CopilotMessage {
93    role: String,
94    content: CopilotContent,
95    #[serde(skip_serializing_if = "Option::is_none")]
96    name: Option<String>,
97    #[serde(skip_serializing_if = "Option::is_none")]
98    tool_calls: Option<Vec<CopilotToolCall>>,
99    #[serde(skip_serializing_if = "Option::is_none")]
100    tool_call_id: Option<String>,
101}
102
103#[derive(Debug, Serialize, Deserialize)]
104#[serde(untagged)]
105enum CopilotContent {
106    Text(String),
107    Parts(Vec<CopilotContentPart>),
108}
109
110#[derive(Debug, Serialize, Deserialize)]
111#[serde(tag = "type")]
112enum CopilotContentPart {
113    #[serde(rename = "text")]
114    Text { text: String },
115    #[serde(rename = "image_url")]
116    ImageUrl { image_url: CopilotImageUrl },
117}
118
119#[derive(Debug, Serialize, Deserialize)]
120struct CopilotImageUrl {
121    url: String,
122    detail: Option<String>,
123}
124
125#[derive(Debug, Serialize, Deserialize)]
126struct CopilotTool {
127    #[serde(rename = "type")]
128    tool_type: String,
129    function: CopilotFunction,
130}
131
132#[derive(Debug, Serialize, Deserialize)]
133struct CopilotFunction {
134    name: String,
135    description: String,
136    parameters: Value,
137}
138
139#[derive(Debug, Serialize, Deserialize)]
140struct CopilotToolCall {
141    id: String,
142    #[serde(rename = "type")]
143    tool_type: String,
144    function: CopilotFunctionCall,
145}
146
147#[derive(Debug, Serialize, Deserialize)]
148struct CopilotFunctionCall {
149    name: String,
150    arguments: String,
151}
152
153#[derive(Debug, Deserialize)]
154struct CopilotResponse {
155    choices: Vec<CopilotChoice>,
156    usage: CopilotUsage,
157}
158
159#[derive(Debug, Deserialize)]
160struct CopilotChoice {
161    message: CopilotMessage,
162    finish_reason: Option<String>,
163}
164
165#[derive(Debug, Deserialize)]
166struct CopilotUsage {
167    prompt_tokens: u32,
168    completion_tokens: u32,
169    total_tokens: u32,
170}
171
172impl GitHubCopilotProvider {
    /// OAuth client id sent with device-code and access-token requests.
    const CLIENT_ID: &'static str = "Iv1.b507a08c87ecfe98";
    /// Endpoint that starts the OAuth device flow.
    const DEVICE_CODE_URL: &'static str = "https://github.com/login/device/code";
    /// Endpoint polled to exchange a device code for a GitHub access token.
    const ACCESS_TOKEN_URL: &'static str = "https://github.com/login/oauth/access_token";
    /// Endpoint that exchanges a GitHub access token for a Copilot API token.
    const COPILOT_TOKEN_URL: &'static str = "https://api.github.com/copilot_internal/v2/token";
    /// Base URL of the Copilot chat API.
    const API_BASE: &'static str = "https://api.githubcopilot.com";
178    
179    pub fn new(auth: Box<dyn Auth>) -> Self {
180        let client = Client::new();
181        let models = Self::default_models();
182        
183        Self {
184            auth,
185            client,
186            models,
187        }
188    }
189    
190    fn default_models() -> HashMap<String, GitHubCopilotModel> {
191        let mut models = HashMap::new();
192        
193        models.insert(
194            "gpt-4o".to_string(),
195            GitHubCopilotModel {
196                id: "gpt-4o".to_string(),
197                name: "GPT-4o".to_string(),
198                max_tokens: 4096,
199                supports_tools: true,
200                supports_vision: true,
201                supports_caching: false,
202            },
203        );
204        
205        models.insert(
206            "gpt-4o-mini".to_string(),
207            GitHubCopilotModel {
208                id: "gpt-4o-mini".to_string(),
209                name: "GPT-4o Mini".to_string(),
210                max_tokens: 4096,
211                supports_tools: true,
212                supports_vision: true,
213                supports_caching: false,
214            },
215        );
216        
217        models.insert(
218            "o1-preview".to_string(),
219            GitHubCopilotModel {
220                id: "o1-preview".to_string(),
221                name: "OpenAI o1 Preview".to_string(),
222                max_tokens: 32768,
223                supports_tools: false,
224                supports_vision: false,
225                supports_caching: false,
226            },
227        );
228        
229        models
230    }
231    
232    /// Start device code flow for authentication
233    pub async fn start_device_flow() -> crate::Result<DeviceCodeResponse> {
234        let client = Client::new();
235        let request = DeviceCodeRequest {
236            client_id: Self::CLIENT_ID.to_string(),
237            scope: "read:user".to_string(),
238        };
239        
240        let response = client
241            .post(Self::DEVICE_CODE_URL)
242            .header("Accept", "application/json")
243            .header("Content-Type", "application/json")
244            .header("User-Agent", "GitHubCopilotChat/0.26.7")
245            .json(&request)
246            .send()
247            .await
248            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Device code request failed: {}", e)))?;
249            
250        if !response.status().is_success() {
251            return Err(crate::Error::Other(anyhow::anyhow!(
252                "Device code request failed with status: {}",
253                response.status()
254            )));
255        }
256        
257        let device_response: DeviceCodeResponse = response
258            .json()
259            .await
260            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse device code response: {}", e)))?;
261            
262        Ok(device_response)
263    }
264    
265    /// Poll for access token
266    pub async fn poll_for_token(device_code: &str) -> crate::Result<Option<String>> {
267        let client = Client::new();
268        let request = AccessTokenRequest {
269            client_id: Self::CLIENT_ID.to_string(),
270            device_code: device_code.to_string(),
271            grant_type: "urn:ietf:params:oauth:grant-type:device_code".to_string(),
272        };
273        
274        let response = client
275            .post(Self::ACCESS_TOKEN_URL)
276            .header("Accept", "application/json")
277            .header("Content-Type", "application/json")
278            .header("User-Agent", "GitHubCopilotChat/0.26.7")
279            .json(&request)
280            .send()
281            .await
282            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Token poll request failed: {}", e)))?;
283            
284        if !response.status().is_success() {
285            return Ok(None);
286        }
287        
288        let token_response: AccessTokenResponse = response
289            .json()
290            .await
291            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse token response: {}", e)))?;
292            
293        if let Some(access_token) = token_response.access_token {
294            Ok(Some(access_token))
295        } else if token_response.error.as_deref() == Some("authorization_pending") {
296            Ok(None)
297        } else {
298            Err(crate::Error::Other(anyhow::anyhow!(
299                "Token exchange failed: {:?}",
300                token_response.error
301            )))
302        }
303    }
304    
305    /// Get Copilot API token using GitHub OAuth token
306    pub async fn get_copilot_token(github_token: &str) -> crate::Result<AuthCredentials> {
307        let client = Client::new();
308        
309        let response = client
310            .get(Self::COPILOT_TOKEN_URL)
311            .header("Accept", "application/json")
312            .header("Authorization", format!("Bearer {}", github_token))
313            .header("User-Agent", "GitHubCopilotChat/0.26.7")
314            .header("Editor-Version", "vscode/1.99.3")
315            .header("Editor-Plugin-Version", "copilot-chat/0.26.7")
316            .send()
317            .await
318            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Copilot token request failed: {}", e)))?;
319            
320        if !response.status().is_success() {
321            return Err(crate::Error::Other(anyhow::anyhow!(
322                "Copilot token request failed with status: {}",
323                response.status()
324            )));
325        }
326        
327        let token_response: CopilotTokenResponse = response
328            .json()
329            .await
330            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse copilot token response: {}", e)))?;
331            
332        Ok(AuthCredentials::OAuth {
333            access_token: token_response.token,
334            refresh_token: Some(github_token.to_string()), // Store GitHub token for refresh
335            expires_at: Some(token_response.expires_at),
336        })
337    }
338    
339    async fn get_auth_headers(&self) -> crate::Result<HashMap<String, String>> {
340        let credentials = self.auth.get_credentials().await?;
341        
342        let mut headers = HashMap::new();
343        headers.insert("User-Agent".to_string(), "GitHubCopilotChat/0.26.7".to_string());
344        headers.insert("Editor-Version".to_string(), "vscode/1.99.3".to_string());
345        headers.insert("Editor-Plugin-Version".to_string(), "copilot-chat/0.26.7".to_string());
346        headers.insert("Openai-Intent".to_string(), "conversation-edits".to_string());
347        
348        match credentials {
349            AuthCredentials::OAuth { access_token, refresh_token, expires_at } => {
350                // Check if token is expired and refresh if needed
351                if let Some(exp) = expires_at {
352                    let now = SystemTime::now()
353                        .duration_since(UNIX_EPOCH)
354                        .unwrap()
355                        .as_secs();
356                    
357                    if now >= exp {
358                        if let Some(github_token) = refresh_token {
359                            let new_creds = Self::get_copilot_token(&github_token).await?;
360                            self.auth.set_credentials(new_creds.clone()).await?;
361                            
362                            if let AuthCredentials::OAuth { access_token, .. } = new_creds {
363                                headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
364                            }
365                        } else {
366                            return Err(crate::Error::Other(anyhow::anyhow!("Token expired and no refresh token available")));
367                        }
368                    } else {
369                        headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
370                    }
371                } else {
372                    headers.insert("Authorization".to_string(), format!("Bearer {}", access_token));
373                }
374            }
375            _ => {
376                return Err(crate::Error::Other(anyhow::anyhow!(
377                    "Invalid credentials for GitHub Copilot"
378                )));
379            }
380        }
381        
382        Ok(headers)
383    }
384    
385    fn convert_messages(&self, messages: Vec<Message>) -> Vec<CopilotMessage> {
386        messages
387            .into_iter()
388            .map(|msg| self.convert_message(msg))
389            .collect()
390    }
391    
392    fn convert_message(&self, message: Message) -> CopilotMessage {
393        let role = match message.role {
394            MessageRole::System => "system",
395            MessageRole::User => "user",
396            MessageRole::Assistant => "assistant",
397            MessageRole::Tool => "tool",
398        }
399        .to_string();
400        
401        let content = match message.content {
402            MessageContent::Text(text) => CopilotContent::Text(text),
403            MessageContent::Parts(parts) => {
404                let copilot_parts: Vec<CopilotContentPart> = parts
405                    .into_iter()
406                    .filter_map(|part| match part {
407                        MessagePart::Text { text } => Some(CopilotContentPart::Text { text }),
408                        MessagePart::Image { image } => {
409                            if let Some(url) = image.url {
410                                Some(CopilotContentPart::ImageUrl {
411                                    image_url: CopilotImageUrl {
412                                        url,
413                                        detail: Some("auto".to_string()),
414                                    },
415                                })
416                            } else if let Some(base64) = image.base64 {
417                                Some(CopilotContentPart::ImageUrl {
418                                    image_url: CopilotImageUrl {
419                                        url: format!("data:{};base64,{}", image.mime_type, base64),
420                                        detail: Some("auto".to_string()),
421                                    },
422                                })
423                            } else {
424                                None
425                            }
426                        }
427                    })
428                    .collect();
429                CopilotContent::Parts(copilot_parts)
430            }
431        };
432        
433        let tool_calls = message.tool_calls.map(|calls| {
434            calls
435                .into_iter()
436                .map(|call| CopilotToolCall {
437                    id: call.id,
438                    tool_type: "function".to_string(),
439                    function: CopilotFunctionCall {
440                        name: call.name,
441                        arguments: call.arguments.to_string(),
442                    },
443                })
444                .collect()
445        });
446        
447        CopilotMessage {
448            role,
449            content,
450            name: message.name,
451            tool_calls,
452            tool_call_id: message.tool_call_id,
453        }
454    }
455    
456    fn convert_tools(&self, tools: Vec<ToolDefinition>) -> Vec<CopilotTool> {
457        tools
458            .into_iter()
459            .map(|tool| CopilotTool {
460                tool_type: "function".to_string(),
461                function: CopilotFunction {
462                    name: tool.name,
463                    description: tool.description,
464                    parameters: tool.parameters,
465                },
466            })
467            .collect()
468    }
469    
470    fn parse_finish_reason(&self, reason: Option<String>) -> FinishReason {
471        match reason.as_deref() {
472            Some("stop") => FinishReason::Stop,
473            Some("length") => FinishReason::Length,
474            Some("tool_calls") => FinishReason::ToolCalls,
475            Some("content_filter") => FinishReason::ContentFilter,
476            _ => FinishReason::Stop,
477        }
478    }
479}
480
481pub struct GitHubCopilotModelWithProvider {
482    model: GitHubCopilotModel,
483    provider: GitHubCopilotProvider,
484}
485
486impl GitHubCopilotModelWithProvider {
487    pub fn new(model: GitHubCopilotModel, provider: GitHubCopilotProvider) -> Self {
488        Self { model, provider }
489    }
490}
491
492#[async_trait]
493impl LanguageModel for GitHubCopilotModelWithProvider {
494    async fn generate(
495        &self,
496        messages: Vec<Message>,
497        options: GenerateOptions,
498    ) -> crate::Result<GenerateResult> {
499        let headers = self.provider.get_auth_headers().await?;
500        let copilot_messages = self.provider.convert_messages(messages);
501        let tools = self.provider.convert_tools(options.tools);
502        
503        let request = CopilotRequest {
504            model: self.model.id.clone(),
505            messages: copilot_messages,
506            max_tokens: options.max_tokens.unwrap_or(self.model.max_tokens),
507            temperature: options.temperature,
508            tools,
509            stop: options.stop_sequences,
510            stream: Some(false),
511        };
512        
513        let mut req_builder = self
514            .provider
515            .client
516            .post(&format!("{}/v1/chat/completions", GitHubCopilotProvider::API_BASE))
517            .header("Content-Type", "application/json");
518            
519        for (key, value) in headers {
520            req_builder = req_builder.header(&key, &value);
521        }
522        
523        let response = req_builder
524            .json(&request)
525            .send()
526            .await
527            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Request failed: {}", e)))?;
528            
529        if !response.status().is_success() {
530            let status = response.status();
531            let body = response.text().await.unwrap_or_default();
532            return Err(crate::Error::Other(anyhow::anyhow!(
533                "API request failed with status {}: {}",
534                status,
535                body
536            )));
537        }
538        
539        let copilot_response: CopilotResponse = response
540            .json()
541            .await
542            .map_err(|e| crate::Error::Other(anyhow::anyhow!("Failed to parse response: {}", e)))?;
543            
544        let choice = copilot_response
545            .choices
546            .into_iter()
547            .next()
548            .ok_or_else(|| crate::Error::Other(anyhow::anyhow!("No choices in response")))?;
549            
550        let content = match choice.message.content {
551            CopilotContent::Text(text) => text,
552            CopilotContent::Parts(parts) => {
553                parts
554                    .into_iter()
555                    .filter_map(|part| match part {
556                        CopilotContentPart::Text { text } => Some(text),
557                        _ => None,
558                    })
559                    .collect::<Vec<_>>()
560                    .join("")
561            }
562        };
563        
564        let tool_calls = choice
565            .message
566            .tool_calls
567            .unwrap_or_default()
568            .into_iter()
569            .map(|call| ToolCall {
570                id: call.id,
571                name: call.function.name,
572                arguments: serde_json::from_str(&call.function.arguments)
573                    .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
574            })
575            .collect();
576            
577        Ok(GenerateResult {
578            content,
579            tool_calls,
580            usage: Usage {
581                prompt_tokens: copilot_response.usage.prompt_tokens,
582                completion_tokens: copilot_response.usage.completion_tokens,
583                total_tokens: copilot_response.usage.total_tokens,
584            },
585            finish_reason: self.provider.parse_finish_reason(choice.finish_reason),
586        })
587    }
588    
589    async fn stream(
590        &self,
591        messages: Vec<Message>,
592        options: StreamOptions,
593    ) -> crate::Result<Box<dyn Stream<Item = crate::Result<StreamChunk>> + Send + Unpin>> {
594        // Similar to generate but with stream: true
595        // Implementation would handle SSE stream parsing
596        Err(crate::Error::Other(anyhow::anyhow!(
597            "Streaming not yet implemented for GitHub Copilot"
598        )))
599    }
600    
601    fn supports_tools(&self) -> bool {
602        self.model.supports_tools
603    }
604    
605    fn supports_vision(&self) -> bool {
606        self.model.supports_vision
607    }
608    
609    fn supports_caching(&self) -> bool {
610        self.model.supports_caching
611    }
612}
613
614#[derive(Debug, thiserror::Error)]
615pub enum GitHubCopilotError {
616    #[error("Device code flow failed")]
617    DeviceCodeFailed,
618    
619    #[error("Token exchange failed")]
620    TokenExchangeFailed,
621    
622    #[error("Authentication expired")]
623    AuthenticationExpired,
624    
625    #[error("Copilot token request failed")]
626    CopilotTokenFailed,
627}
628
#[cfg(test)]
mod tests {
    use super::*;

    /// The built-in table is non-empty and lists every expected model id.
    #[test]
    fn test_default_models() {
        let catalog = GitHubCopilotProvider::default_models();
        assert!(!catalog.is_empty());
        for id in ["gpt-4o", "gpt-4o-mini", "o1-preview"] {
            assert!(catalog.contains_key(id));
        }
    }

    /// Capability flags: `gpt-4o` has tools and vision, `o1-preview` has neither.
    #[test]
    fn test_model_capabilities() {
        let catalog = GitHubCopilotProvider::default_models();

        let gpt4o = catalog.get("gpt-4o").unwrap();
        assert!(gpt4o.supports_tools);
        assert!(gpt4o.supports_vision);

        let o1 = catalog.get("o1-preview").unwrap();
        assert!(!o1.supports_tools);
        assert!(!o1.supports_vision);
    }
}