// llmg_providers/github_copilot.rs
1//! GitHub Copilot API client for LLMG
2//!
3//! Implements the Provider trait for GitHub Copilot Chat API using OAuth device flow.
4//!
5//! Authentication Flow:
6//! 1. Get device code from GitHub
7//! 2. User visits verification URL and enters code
8//! 3. Poll for access token
9//! 4. Exchange access token for Copilot API key
10//!
11//! Environment variables:
12//! - GITHUB_COPILOT_ACCESS_TOKEN: Cached OAuth access token
13//! - GITHUB_COPILOT_API_KEY: Cached Copilot API key
14//! - GITHUB_COPILOT_TOKEN_DIR: Directory to store tokens (default: ~/.config/llmg/github_copilot)
15
use std::collections::HashMap;
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;

use eventsource_stream::Eventsource;
use futures::{StreamExt, TryStreamExt};
use llmg_core::{
    provider::{ChatCompletionStream, LlmError, Provider},
    streaming::{ChatCompletionChunk, ChoiceDelta, DeltaContent},
    types::{
        ChatCompletionRequest, ChatCompletionResponse, Choice, EmbeddingRequest, EmbeddingResponse,
        Message, Usage,
    },
};
31
/// GitHub Copilot OAuth constants
///
/// `GITHUB_CLIENT_ID` is the OAuth app id used for the device-code flow; the
/// three URLs below drive its stages, and `GITHUB_COPILOT_API_BASE` is where
/// chat and embedding requests are sent.
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
// Step 1: request a device-code / user-code pair.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
// Step 3: poll here until the user authorizes the device code.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
// Step 4: exchange the OAuth access token for a Copilot API key.
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
// Base URL for chat-completion and embedding requests.
const GITHUB_COPILOT_API_BASE: &str = "https://api.githubcopilot.com";
38
/// GitHub Copilot API client
///
/// Holds the two credentials the provider needs: the long-lived OAuth
/// `access_token` obtained via the device flow, and the short-lived Copilot
/// `api_key` minted from it.
#[derive(Debug, Clone)]
pub struct GitHubCopilotClient {
    // Shared HTTP client reused across requests.
    http_client: reqwest::Client,
    // Short-lived Copilot API key; refreshed via `refresh_api_key`.
    api_key: String,
    // GitHub OAuth access token from the device flow.
    access_token: String,
    // NOTE(review): these two fields are set by the builder methods but the
    // request paths hard-code their own header values — confirm whether the
    // overrides should be wired into the headers.
    editor_version: String,
    integration_id: String,
}
48
/// Response body from the device-code endpoint (step 1 of the OAuth flow).
#[derive(Debug, serde::Deserialize)]
struct DeviceCodeResponse {
    // Opaque code polled against the access-token endpoint.
    device_code: String,
    // Code the user types at `verification_uri`.
    user_code: String,
    verification_uri: String,
    // Lifetime of the device code in seconds (currently unread).
    expires_in: i32,
    // Minimum seconds to wait between polls.
    interval: i32,
}
57
/// Response body from the access-token polling endpoint (step 3).
///
/// GitHub answers 200 with either `access_token` set or an `error` code
/// (e.g. `authorization_pending`), so every field is optional.
#[derive(Debug, serde::Deserialize)]
struct AccessTokenResponse {
    access_token: Option<String>,
    // Currently unread; part of the wire format.
    token_type: Option<String>,
    error: Option<String>,
    // Human-readable detail for `error` (currently unread).
    error_description: Option<String>,
}
65
/// Copilot API key payload (step 4); doubles as the on-disk cache format for
/// `api-key.json`, hence the extra `Serialize` derive.
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct CopilotApiKeyResponse {
    // Bearer token used against the Copilot API.
    token: String,
    // Unix timestamp (seconds) after which `token` is invalid.
    expires_at: i64,
    // Optional endpoint overrides (currently unread).
    endpoints: Option<HashMap<String, String>>,
}
72
/// JSON body for `POST {base}/chat/completions`.
///
/// Optional fields are omitted from the payload entirely rather than sent as
/// `null`.
#[derive(Debug, serde::Serialize)]
struct CopilotChatRequest {
    messages: Vec<CopilotMessage>,
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<llmg_core::types::Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<llmg_core::types::ToolChoice>,
}
92
/// One message in the Copilot wire format.
///
/// `tool_calls` is populated only for assistant messages and `tool_call_id`
/// only for tool-result messages (see `convert_request`).
#[derive(Debug, serde::Serialize)]
struct CopilotMessage {
    role: String,
    content: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<llmg_core::types::ToolCall>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
102
/// Non-streaming chat-completion response body.
///
/// `object` and `created` are defaulted so payloads omitting them still
/// deserialize.
#[derive(Debug, serde::Deserialize)]
struct CopilotChatResponse {
    id: String,
    #[serde(default)]
    object: String,
    #[serde(default)]
    created: i64,
    model: String,
    choices: Vec<CopilotChoice>,
    usage: Option<CopilotUsage>,
}
114
/// A single completion choice in a non-streaming response.
#[derive(Debug, serde::Deserialize)]
struct CopilotChoice {
    index: i32,
    message: CopilotMessageResponse,
    // Redundant rename (field already matches the wire name) but harmless.
    #[serde(rename = "finish_reason")]
    finish_reason: Option<String>,
}
122
/// Message payload inside a non-streaming choice; plain text only — any
/// tool-call data in the payload is not modeled here.
#[derive(Debug, serde::Deserialize)]
struct CopilotMessageResponse {
    role: String,
    content: String,
}
128
/// Token accounting block shared by chat and embedding responses.
#[derive(Debug, serde::Deserialize)]
struct CopilotUsage {
    // The renames below are redundant (field names already match the wire
    // names) but harmless.
    #[serde(rename = "prompt_tokens")]
    prompt_tokens: u32,
    // Defaults to 0 — embedding responses carry no completion tokens.
    #[serde(default, rename = "completion_tokens")]
    completion_tokens: u32,
    #[serde(rename = "total_tokens")]
    total_tokens: u32,
}
138
/// JSON body for `POST {base}/embeddings`.
#[derive(Debug, serde::Serialize)]
struct CopilotEmbeddingRequest {
    model: String,
    // The single LLMG input string is wrapped in a one-element batch.
    input: Vec<String>,
}
144
/// Embedding response body; `object` and `model` are defaulted so payloads
/// omitting them still deserialize.
#[derive(Debug, serde::Deserialize)]
struct CopilotEmbeddingResponse {
    #[serde(default)]
    object: String,
    data: Vec<CopilotEmbeddingData>,
    #[serde(default)]
    model: String,
    usage: CopilotUsage,
}
154
/// One embedding vector within an embedding response.
#[derive(Debug, serde::Deserialize)]
struct CopilotEmbeddingData {
    object: String,
    // Position of this vector within the batch.
    index: u32,
    embedding: Vec<f32>,
}
161
162impl GitHubCopilotClient {
    /// Create a new GitHub Copilot client with OAuth flow
    ///
    /// Ensures the token cache directory exists, loads (or interactively
    /// acquires) an OAuth access token, then loads a cached Copilot API key
    /// or exchanges the access token for a fresh one.
    ///
    /// # Errors
    /// Returns `LlmError` when the cache directory cannot be created, the
    /// device flow fails, or the API-key exchange fails.
    pub async fn new() -> Result<Self, LlmError> {
        let token_dir = Self::get_token_dir();
        // NOTE(review): blocking std::fs calls inside an async fn tie up the
        // executor thread — confirm this is acceptable at startup.
        std::fs::create_dir_all(&token_dir).map_err(|e| {
            LlmError::ProviderError(format!("Failed to create token directory: {}", e))
        })?;

        // NOTE(review): progress goes to stdout via println! while the OAuth
        // prompts elsewhere use stderr — confirm which stream is intended.
        println!("Loading access token...");
        let access_token = match Self::load_cached_access_token(&token_dir).await {
            Ok(token) => token,
            Err(e) => {
                println!("load_cached_access_token failed with {:?}", e);
                return Err(e);
            }
        };
        println!("Access token loaded.");

        let mut client = Self {
            http_client: reqwest::Client::new(),
            api_key: String::new(),
            access_token,
            editor_version: "vscode/1.85.1".to_string(),
            integration_id: "vscode-chat".to_string(),
        };

        // Prefer the cached API key; mint a new one only when the cache
        // misses or the cached key has expired.
        if let Ok(key) = Self::load_cached_api_key(&token_dir).await {
            client.api_key = key;
            println!("API key loaded from cache.");
        } else {
            println!("Refreshing API key...");
            match client.refresh_api_key().await {
                Ok(_) => println!("API key refreshed."),
                Err(e) => {
                    println!("refresh_api_key failed with {:?}", e);
                    return Err(e);
                }
            }
        }

        Ok(client)
    }
204
205    /// Create with explicit API key (for testing or pre-authenticated scenarios)
206    pub fn with_api_key(api_key: impl Into<String>, access_token: impl Into<String>) -> Self {
207        Self {
208            http_client: reqwest::Client::new(),
209            api_key: api_key.into(),
210            access_token: access_token.into(),
211            editor_version: "vscode/1.85.1".to_string(),
212            integration_id: "vscode-chat".to_string(),
213        }
214    }
215
216    fn get_token_dir() -> PathBuf {
217        std::env::var("GITHUB_COPILOT_TOKEN_DIR")
218            .map(PathBuf::from)
219            .unwrap_or_else(|_| {
220                dirs::config_dir()
221                    .unwrap_or_else(|| PathBuf::from("."))
222                    .join("llmg/github_copilot")
223            })
224    }
225
226    async fn load_cached_access_token(token_dir: &std::path::Path) -> Result<String, LlmError> {
227        let access_token_path = token_dir.join("access-token");
228
229        if let Ok(token) = std::fs::read_to_string(&access_token_path) {
230            let token = token.trim();
231            if !token.is_empty() {
232                return Ok(token.to_string());
233            }
234        }
235
236        Self::perform_oauth_flow(token_dir).await
237    }
238
239    async fn load_cached_api_key(token_dir: &std::path::Path) -> Result<String, LlmError> {
240        let api_key_path = token_dir.join("api-key.json");
241
242        if let Ok(content) = std::fs::read_to_string(&api_key_path) {
243            if let Ok(api_key_info) = serde_json::from_str::<CopilotApiKeyResponse>(&content) {
244                let now = std::time::SystemTime::now()
245                    .duration_since(std::time::UNIX_EPOCH)
246                    .unwrap()
247                    .as_secs() as i64;
248
249                if api_key_info.expires_at > now {
250                    return Ok(api_key_info.token);
251                }
252            }
253        }
254
255        Err(LlmError::AuthError)
256    }
257
258    async fn perform_oauth_flow(token_dir: &std::path::Path) -> Result<String, LlmError> {
259        let device_code_resp = Self::get_device_code().await?;
260
261        eprintln!("\nšŸ” GitHub Copilot Authentication Required");
262        eprintln!("Please visit: {}", device_code_resp.verification_uri);
263        eprintln!("And enter code: {}\n", device_code_resp.user_code);
264
265        let access_token = Self::poll_for_access_token(
266            &device_code_resp.device_code,
267            device_code_resp.interval as u64,
268        )
269        .await?;
270
271        let access_token_path = token_dir.join("access-token");
272        std::fs::write(&access_token_path, &access_token)
273            .map_err(|e| LlmError::ProviderError(format!("Failed to cache access token: {}", e)))?;
274
275        Ok(access_token)
276    }
277
278    async fn get_device_code() -> Result<DeviceCodeResponse, LlmError> {
279        let client = reqwest::Client::new();
280
281        let resp = client
282            .post(GITHUB_DEVICE_CODE_URL)
283            .header("Accept", "application/json")
284            .header("User-Agent", "GithubCopilot/1.155.0")
285            .json(&serde_json::json!({
286                "client_id": GITHUB_CLIENT_ID,
287                "scope": "read:user"
288            }))
289            .send()
290            .await
291            .map_err(|e| LlmError::HttpError(format!("Failed to get device code: {}", e)))?;
292
293        if !resp.status().is_success() {
294            return Err(LlmError::ApiError {
295                status: resp.status().as_u16(),
296                message: resp.text().await.unwrap_or_default(),
297            });
298        }
299
300        resp.json::<DeviceCodeResponse>()
301            .await
302            .map_err(|e| LlmError::HttpError(e.to_string()))
303    }
304
305    async fn poll_for_access_token(device_code: &str, interval: u64) -> Result<String, LlmError> {
306        let client = reqwest::Client::new();
307        let max_attempts = 60;
308
309        for attempt in 0..max_attempts {
310            tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await;
311
312            let resp = client
313                .post(GITHUB_ACCESS_TOKEN_URL)
314                .header("Accept", "application/json")
315                .header("User-Agent", "GithubCopilot/1.155.0")
316                .json(&serde_json::json!({
317                    "client_id": GITHUB_CLIENT_ID,
318                    "device_code": device_code,
319                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
320                }))
321                .send()
322                .await
323                .map_err(|e| LlmError::HttpError(format!("Failed to poll for token: {}", e)))?;
324
325            if !resp.status().is_success() {
326                continue;
327            }
328
329            let token_resp = resp
330                .json::<AccessTokenResponse>()
331                .await
332                .map_err(|e| LlmError::HttpError(e.to_string()))?;
333
334            if let Some(token) = token_resp.access_token {
335                eprintln!("āœ… Authentication successful!");
336                return Ok(token);
337            }
338
339            if let Some(error) = token_resp.error {
340                if error != "authorization_pending" {
341                    return Err(LlmError::AuthError);
342                }
343            }
344
345            if attempt % 6 == 0 {
346                eprintln!(
347                    "ā³ Waiting for authorization... (attempt {}/{})",
348                    attempt + 1,
349                    max_attempts
350                );
351            }
352        }
353
354        Err(LlmError::AuthError)
355    }
356
357    async fn refresh_api_key(&mut self) -> Result<(), LlmError> {
358        let client = reqwest::Client::new();
359
360        let resp = client
361            .get(GITHUB_COPILOT_TOKEN_URL)
362            .header("Authorization", format!("token {}", self.access_token))
363            .header("Accept", "application/json")
364            .header("User-Agent", "GithubCopilot/1.155.0")
365            .send()
366            .await
367            .map_err(|e| LlmError::HttpError(format!("Failed to refresh API key: {}", e)))?;
368
369        if !resp.status().is_success() {
370            if resp.status().as_u16() == 401 {
371                let token_dir = Self::get_token_dir();
372                self.access_token = Self::perform_oauth_flow(&token_dir).await?;
373                return Box::pin(self.refresh_api_key()).await;
374            }
375
376            return Err(LlmError::ApiError {
377                status: resp.status().as_u16(),
378                message: resp.text().await.unwrap_or_default(),
379            });
380        }
381
382        let api_key_info = resp
383            .json::<CopilotApiKeyResponse>()
384            .await
385            .map_err(|e| LlmError::HttpError(e.to_string()))?;
386
387        self.api_key = api_key_info.token.clone();
388
389        let token_dir = Self::get_token_dir();
390        let api_key_path = token_dir.join("api-key.json");
391        std::fs::write(&api_key_path, serde_json::to_string(&api_key_info).unwrap())
392            .map_err(|e| LlmError::ProviderError(format!("Failed to cache API key: {}", e)))?;
393
394        Ok(())
395    }
396
    /// Builder-style override for the editor version string.
    ///
    /// NOTE(review): the request paths hard-code their own `editor-version`
    /// header, so this value is stored but never sent — confirm whether the
    /// override should be wired into the headers.
    pub fn with_editor_version(mut self, version: impl Into<String>) -> Self {
        self.editor_version = version.into();
        self
    }
401
    /// Builder-style override for the Copilot integration id.
    ///
    /// NOTE(review): the request paths hard-code `Copilot-Integration-Id`,
    /// so this value is stored but never sent — confirm whether the override
    /// should be wired into the headers.
    pub fn with_integration_id(mut self, id: impl Into<String>) -> Self {
        self.integration_id = id.into();
        self
    }
406
407    fn convert_request(&self, request: ChatCompletionRequest) -> CopilotChatRequest {
408        let messages: Vec<CopilotMessage> = request
409            .messages
410            .into_iter()
411            .map(|msg| match msg {
412                Message::System { content, .. } => CopilotMessage {
413                    role: "system".to_string(),
414                    content,
415                    tool_calls: None,
416                    tool_call_id: None,
417                },
418                Message::User { content, .. } => CopilotMessage {
419                    role: "user".to_string(),
420                    content,
421                    tool_calls: None,
422                    tool_call_id: None,
423                },
424                Message::Assistant {
425                    content,
426                    tool_calls,
427                    ..
428                } => CopilotMessage {
429                    role: "assistant".to_string(),
430                    content: content.unwrap_or_default(),
431                    tool_calls,
432                    tool_call_id: None,
433                },
434                Message::Tool {
435                    content,
436                    tool_call_id,
437                } => CopilotMessage {
438                    role: "tool".to_string(),
439                    content,
440                    tool_calls: None,
441                    tool_call_id: Some(tool_call_id),
442                },
443            })
444            .collect();
445
446        CopilotChatRequest {
447            messages,
448            model: request.model,
449            temperature: request.temperature,
450            top_p: request.top_p,
451            stream: request.stream,
452            stop: request.stop,
453            max_tokens: request.max_tokens,
454            tools: request.tools,
455            tool_choice: request.tool_choice,
456        }
457    }
458
459    fn convert_response(&self, response: CopilotChatResponse) -> ChatCompletionResponse {
460        ChatCompletionResponse {
461            id: response.id,
462            object: response.object,
463            created: response.created,
464            model: response.model,
465            choices: response
466                .choices
467                .into_iter()
468                .map(|c| Choice {
469                    index: c.index as u32,
470                    message: Message::Assistant {
471                        content: Some(c.message.content),
472                        refusal: None,
473                        tool_calls: None,
474                    },
475                    finish_reason: c.finish_reason,
476                })
477                .collect(),
478            usage: response.usage.map(|u| Usage {
479                prompt_tokens: u.prompt_tokens,
480                completion_tokens: u.completion_tokens,
481                total_tokens: u.total_tokens,
482            }),
483        }
484    }
485
486    async fn make_request(
487        &mut self,
488        request: ChatCompletionRequest,
489    ) -> Result<ChatCompletionResponse, LlmError> {
490        if self.api_key.is_empty() {
491            self.refresh_api_key().await?;
492        }
493
494        let url = format!("{}/chat/completions", GITHUB_COPILOT_API_BASE);
495        let copilot_req = self.convert_request(request.clone());
496
497        let initiator = if request
498            .messages
499            .iter()
500            .any(|m| matches!(m, Message::Assistant { .. } | Message::Tool { .. }))
501        {
502            "agent"
503        } else {
504            "user"
505        };
506
507        let request_id = uuid::Uuid::new_v4().to_string();
508        let resp = self
509            .http_client
510            .post(&url)
511            .header("Authorization", format!("Bearer {}", self.api_key))
512            .header("Content-Type", "application/json")
513            .header("Accept", "application/json")
514            .header("editor-version", "vscode/1.95.0")
515            .header("editor-plugin-version", "copilot-chat/0.26.7")
516            .header("Copilot-Integration-Id", "vscode-chat")
517            .header("User-Agent", "GitHubCopilotChat/0.26.7")
518            .header("openai-intent", "conversation-panel")
519            .header("x-github-api-version", "2025-04-01")
520            .header("x-request-id", &request_id)
521            .header("x-vscode-user-agent-library-version", "electron-fetch")
522            .header("X-Initiator", initiator)
523            .json(&copilot_req)
524            .send()
525            .await
526            .map_err(|e| LlmError::HttpError(e.to_string()))?;
527
528        if resp.status().as_u16() == 401 {
529            self.refresh_api_key().await?;
530            return Box::pin(async move { self.make_request(request).await }).await;
531        }
532
533        if !resp.status().is_success() {
534            let status = resp.status().as_u16();
535            let text = resp.text().await.unwrap_or_default();
536
537            if status == 429 {
538                return Err(LlmError::RateLimitError);
539            }
540
541            return Err(LlmError::ApiError {
542                status,
543                message: text,
544            });
545        }
546
547        let text = resp
548            .text()
549            .await
550            .map_err(|e| LlmError::HttpError(e.to_string()))?;
551        let copilot_resp: CopilotChatResponse = serde_json::from_str(&text)
552            .map_err(|e| LlmError::HttpError(format!("error decoding response body: {}", e)))?;
553
554        Ok(self.convert_response(copilot_resp))
555    }
556
557    async fn make_stream_request(
558        &mut self,
559        request: ChatCompletionRequest,
560    ) -> Result<ChatCompletionStream, LlmError> {
561        if self.api_key.is_empty() {
562            self.refresh_api_key().await?;
563        }
564
565        let url = format!("{}/chat/completions", GITHUB_COPILOT_API_BASE);
566        let mut copilot_req = self.convert_request(request.clone());
567        copilot_req.stream = Some(true);
568
569        let initiator = if request
570            .messages
571            .iter()
572            .any(|m| matches!(m, Message::Assistant { .. } | Message::Tool { .. }))
573        {
574            "agent"
575        } else {
576            "user"
577        };
578
579        let request_id = uuid::Uuid::new_v4().to_string();
580        let resp = self
581            .http_client
582            .post(&url)
583            .header("Authorization", format!("Bearer {}", self.api_key))
584            .header("Content-Type", "application/json")
585            .header("Accept", "application/json")
586            .header("editor-version", "vscode/1.95.0")
587            .header("editor-plugin-version", "copilot-chat/0.26.7")
588            .header("Copilot-Integration-Id", "vscode-chat")
589            .header("User-Agent", "GitHubCopilotChat/0.26.7")
590            .header("openai-intent", "conversation-panel")
591            .header("x-github-api-version", "2025-04-01")
592            .header("x-request-id", &request_id)
593            .header("x-vscode-user-agent-library-version", "electron-fetch")
594            .header("X-Initiator", initiator)
595            .json(&copilot_req)
596            .send()
597            .await
598            .map_err(|e| LlmError::HttpError(e.to_string()))?;
599
600        if resp.status().as_u16() == 401 {
601            self.refresh_api_key().await?;
602            return Box::pin(async move { self.make_stream_request(request).await }).await;
603        }
604
605        if !resp.status().is_success() {
606            let status = resp.status().as_u16();
607            let text = resp.text().await.unwrap_or_default();
608
609            if status == 429 {
610                return Err(LlmError::RateLimitError);
611            }
612
613            return Err(LlmError::ApiError {
614                status,
615                message: text,
616            });
617        }
618
619        let chunk_id = ChatCompletionChunk::generate_id();
620        let model = copilot_req.model.clone();
621
622        let stream = resp
623            .bytes_stream()
624            .eventsource()
625            .map_err(|e| LlmError::HttpError(e.to_string()))
626            .then(move |event_result| {
627                let chunk_id = chunk_id.clone();
628                let model = model.clone();
629                async move {
630                    match event_result {
631                        Ok(event) => parse_copilot_sse_data(&event.data, &chunk_id, &model),
632                        Err(e) => Err(LlmError::HttpError(e.to_string())),
633                    }
634                }
635            })
636            .try_filter_map(|chunk| async move { Ok(chunk) });
637
638        Ok(Box::pin(stream) as ChatCompletionStream)
639    }
640
641    async fn make_embedding_request(
642        &mut self,
643        request: EmbeddingRequest,
644    ) -> Result<EmbeddingResponse, LlmError> {
645        if self.api_key.is_empty() {
646            self.refresh_api_key().await?;
647        }
648
649        let url = format!("{}/embeddings", GITHUB_COPILOT_API_BASE);
650        let copilot_req = CopilotEmbeddingRequest {
651            model: request.model.clone(),
652            input: vec![request.input.clone()],
653        };
654
655        let request_id = uuid::Uuid::new_v4().to_string();
656        let resp = self
657            .http_client
658            .post(&url)
659            .header("Authorization", format!("Bearer {}", self.api_key))
660            .header("Content-Type", "application/json")
661            .header("Accept", "application/json")
662            .header("editor-version", "vscode/1.95.0")
663            .header("editor-plugin-version", "copilot-chat/0.26.7")
664            .header("Copilot-Integration-Id", "vscode-chat")
665            .header("User-Agent", "GitHubCopilotChat/0.26.7")
666            .header("openai-intent", "conversation-panel")
667            .header("x-github-api-version", "2025-04-01")
668            .header("x-request-id", &request_id)
669            .header("x-vscode-user-agent-library-version", "electron-fetch")
670            .header("X-Initiator", "user")
671            .json(&copilot_req)
672            .send()
673            .await
674            .map_err(|e| LlmError::HttpError(e.to_string()))?;
675
676        if resp.status().as_u16() == 401 {
677            self.refresh_api_key().await?;
678            return Box::pin(async move { self.make_embedding_request(request).await }).await;
679        }
680
681        if !resp.status().is_success() {
682            let status = resp.status().as_u16();
683            let text = resp.text().await.unwrap_or_default();
684
685            if status == 429 {
686                return Err(LlmError::RateLimitError);
687            }
688
689            return Err(LlmError::ApiError {
690                status,
691                message: text,
692            });
693        }
694
695        let text = resp
696            .text()
697            .await
698            .map_err(|e| LlmError::HttpError(e.to_string()))?;
699        let copilot_resp: CopilotEmbeddingResponse = serde_json::from_str(&text)
700            .map_err(|e| LlmError::HttpError(format!("error decoding response body: {}", e)))?;
701
702        Ok(EmbeddingResponse {
703            id: uuid::Uuid::new_v4().to_string(),
704            object: if copilot_resp.object.is_empty() {
705                "list".to_string()
706            } else {
707                copilot_resp.object
708            },
709            data: copilot_resp
710                .data
711                .into_iter()
712                .map(|d| llmg_core::types::Embedding {
713                    index: d.index,
714                    object: d.object,
715                    embedding: d.embedding,
716                })
717                .collect(),
718            model: copilot_resp.model,
719            usage: Usage {
720                prompt_tokens: copilot_resp.usage.prompt_tokens,
721                completion_tokens: copilot_resp.usage.completion_tokens,
722                total_tokens: copilot_resp.usage.total_tokens,
723            },
724        })
725    }
726
727    pub fn get_models() -> Vec<String> {
728        vec![
729            "gpt-4".to_string(),
730            "gpt-4o".to_string(),
731            "gpt-4o-mini".to_string(),
732            "gpt-3.5-turbo".to_string(),
733            "o1-preview".to_string(),
734            "o1-mini".to_string(),
735            "claude-3-5-sonnet".to_string(),
736            "text-embedding-3-small".to_string(),
737        ]
738    }
739}
740
#[async_trait::async_trait]
impl Provider for GitHubCopilotClient {
    /// Non-streaming chat completion.
    ///
    /// The trait exposes `&self` while the internals need `&mut self` for
    /// lazy API-key refresh, so each call runs on a clone — a key refreshed
    /// during the call is not written back to `self`.
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, LlmError> {
        let mut client = self.clone();
        client.make_request(request).await
    }

    /// Streaming chat completion; returns a boxed future because this trait
    /// method is not declared `async`. Same clone-per-call pattern as above.
    fn chat_completion_stream(
        &self,
        request: ChatCompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<ChatCompletionStream, LlmError>> + Send + '_>> {
        let mut client = self.clone();
        Box::pin(async move { client.make_stream_request(request).await })
    }

    /// Embedding request; same clone-per-call pattern as `chat_completion`.
    async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, LlmError> {
        let mut client = self.clone();
        client.make_embedding_request(request).await
    }
    /// Stable provider identifier.
    fn provider_name(&self) -> &'static str {
        "github_copilot"
    }
}
767
768fn parse_copilot_sse_data(
769    data: &str,
770    chunk_id: &str,
771    model: &str,
772) -> Result<Option<ChatCompletionChunk>, LlmError> {
773    let data = data.trim();
774    if data.is_empty() || data == "[DONE]" {
775        return Ok(None);
776    }
777
778    let parsed: serde_json::Value =
779        serde_json::from_str(data).map_err(LlmError::SerializationError)?;
780
781    let choices = parsed
782        .get("choices")
783        .and_then(|c| c.as_array())
784        .map(|arr| {
785            arr.iter()
786                .filter_map(|choice| {
787                    let index = choice.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as u32;
788                    let delta = choice.get("delta")?;
789                    let finish_reason = choice
790                        .get("finish_reason")
791                        .and_then(|f| f.as_str())
792                        .map(|s| s.to_string());
793
794                    let role = delta
795                        .get("role")
796                        .and_then(|r| r.as_str())
797                        .map(|s| s.to_string());
798                    let content = delta
799                        .get("content")
800                        .and_then(|c| c.as_str())
801                        .map(|s| s.to_string());
802                    let tool_calls = delta
803                        .get("tool_calls")
804                        .and_then(|t| serde_json::from_value(t.clone()).ok());
805
806                    Some(ChoiceDelta {
807                        index,
808                        delta: DeltaContent {
809                            role,
810                            content,
811                            tool_calls,
812                        },
813                        finish_reason,
814                    })
815                })
816                .collect::<Vec<_>>()
817        })
818        .unwrap_or_default();
819
820    if choices.is_empty() {
821        return Ok(None);
822    }
823
824    Ok(Some(ChatCompletionChunk {
825        id: chunk_id.to_string(),
826        object: "chat.completion.chunk".to_string(),
827        created: chrono::Utc::now().timestamp(),
828        model: model.to_string(),
829        choices,
830        usage: None,
831    }))
832}
833
#[cfg(test)]
mod tests {
    use super::*;

    // Construction via `with_api_key` must bypass OAuth entirely and still
    // report the correct provider name.
    #[test]
    fn test_copilot_client_with_api_key() {
        let client = GitHubCopilotClient::with_api_key("test-api-key", "test-access-token");
        assert_eq!(client.provider_name(), "github_copilot");
    }

    // System/user messages should map 1:1 onto Copilot wire messages with
    // the matching role strings.
    #[test]
    fn test_request_conversion() {
        let client = GitHubCopilotClient::with_api_key("test-key", "test-token");

        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![
                Message::System {
                    content: "You are a helpful coding assistant".to_string(),
                    name: None,
                },
                Message::User {
                    content: "Write a Python function".to_string(),
                    name: None,
                },
            ],
            temperature: Some(0.7),
            max_tokens: Some(1000),
            stream: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            stop: None,
            user: None,
            tools: None,
            tool_choice: None,
            response_format: None,
        };

        let copilot_req = client.convert_request(request);

        assert_eq!(copilot_req.model, "gpt-4");
        assert_eq!(copilot_req.messages.len(), 2);
        assert_eq!(copilot_req.messages[0].role, "system");
        assert_eq!(copilot_req.messages[1].role, "user");
    }

    // `tools` and `tool_choice` must pass through the conversion untouched.
    #[test]
    fn test_tool_calling_conversion() {
        let client = GitHubCopilotClient::with_api_key("test-key", "test-token");

        let tool = llmg_core::types::Tool {
            r#type: "function".to_string(),
            function: llmg_core::types::FunctionDefinition {
                name: "get_weather".to_string(),
                description: Some("Get the weather".to_string()),
                parameters: serde_json::json!({"type": "object", "properties": {"location": {"type": "string"}}}),
            },
        };

        let request = ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages: vec![Message::User {
                content: "Weather?".to_string(),
                name: None,
            }],
            temperature: None,
            max_tokens: None,
            stream: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            stop: None,
            user: None,
            tools: Some(vec![tool]),
            tool_choice: Some(llmg_core::types::ToolChoice::String("auto".to_string())),
            response_format: None,
        };

        let copilot_req = client.convert_request(request);

        assert!(copilot_req.tools.is_some());
        assert_eq!(copilot_req.tools.unwrap().len(), 1);
        assert!(copilot_req.tool_choice.is_some());
    }

    // Streaming deltas carrying tool-call fragments must survive SSE parsing
    // with id and function name intact.
    #[test]
    fn test_parse_copilot_sse_data_tool_calls() {
        let raw_sse = r#"{"id":"chatcmpl-123","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Boston\"}"}}]},"finish_reason":null}]}"#;
        let chunk = parse_copilot_sse_data(raw_sse, "chatcmpl-123", "gpt-4")
            .unwrap()
            .unwrap();

        assert_eq!(chunk.choices.len(), 1);
        let choice = &chunk.choices[0];
        assert!(choice.delta.tool_calls.is_some());

        let tool_calls = choice.delta.tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id.as_deref(), Some("call_abc"));
        assert_eq!(
            tool_calls[0]
                .function
                .as_ref()
                .and_then(|f| f.name.as_deref()),
            Some("get_weather")
        );
    }
}
943