1use eventsource_stream::Eventsource;
17use futures::{StreamExt, TryStreamExt};
18use llmg_core::{
19 provider::{ChatCompletionStream, LlmError, Provider},
20 streaming::{ChatCompletionChunk, ChoiceDelta, DeltaContent},
21 types::{
22 ChatCompletionRequest, ChatCompletionResponse, Choice, EmbeddingRequest, EmbeddingResponse,
23 Message, Usage,
24 },
25};
26use std::future::Future;
27use std::pin::Pin;
28use std::collections::HashMap;
30use std::path::PathBuf;
31
/// OAuth client id sent as `client_id` in both GitHub device-flow requests.
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
/// Device-flow step 1: returns the device code, user code and verification URL.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
/// Device-flow step 2: polled until the user authorizes and an access token is issued.
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
/// Exchanges the GitHub access token for a Copilot API key (with an `expires_at` timestamp).
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
/// Base URL of the Copilot `/chat/completions` and `/embeddings` endpoints.
const GITHUB_COPILOT_API_BASE: &str = "https://api.githubcopilot.com";
38
39#[derive(Debug, Clone)]
41pub struct GitHubCopilotClient {
42 http_client: reqwest::Client,
43 api_key: String,
44 access_token: String,
45 editor_version: String,
46 integration_id: String,
47}
48
/// Response from `GITHUB_DEVICE_CODE_URL`: the codes the user needs to authorize
/// this device in a browser.
#[derive(Debug, serde::Deserialize)]
struct DeviceCodeResponse {
    /// Opaque code sent back on every access-token poll.
    device_code: String,
    /// Short code the user types at `verification_uri`.
    user_code: String,
    /// Page the user is told to visit.
    verification_uri: String,
    /// Lifetime of the codes in seconds.
    // NOTE(review): not read anywhere; polling is bounded by an attempt count instead.
    expires_in: i32,
    /// Seconds to sleep between access-token polls.
    interval: i32,
}
57
/// One poll response from `GITHUB_ACCESS_TOKEN_URL`.
///
/// The polling loop treats a present `access_token` as success and otherwise
/// inspects `error` to decide whether to keep waiting.
#[derive(Debug, serde::Deserialize)]
struct AccessTokenResponse {
    /// The issued OAuth access token, once the user has authorized the device.
    access_token: Option<String>,
    token_type: Option<String>,
    /// Error code such as `authorization_pending` while the user has not finished.
    error: Option<String>,
    error_description: Option<String>,
}
65
/// Copilot API key returned by `GITHUB_COPILOT_TOKEN_URL`.
///
/// Also serialized verbatim to `api-key.json` in the token directory and read
/// back from there as the API-key cache.
#[derive(Debug, serde::Deserialize, serde::Serialize)]
struct CopilotApiKeyResponse {
    /// The key sent as `Authorization: Bearer …` to the Copilot API.
    token: String,
    /// Expiry as seconds since the Unix epoch (compared with the current time).
    expires_at: i64,
    /// Endpoint map reported by the service; currently not used by the client.
    endpoints: Option<HashMap<String, String>>,
}
72
/// JSON body POSTed to `/chat/completions`; `None` fields are omitted.
// NOTE(review): `frequency_penalty`, `presence_penalty`, `user` and
// `response_format` from `ChatCompletionRequest` have no counterpart here, so
// they are silently dropped when a request is converted.
#[derive(Debug, serde::Serialize)]
struct CopilotChatRequest {
    messages: Vec<CopilotMessage>,
    model: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    /// Forced to `Some(true)` for streaming requests.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<llmg_core::types::Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<llmg_core::types::ToolChoice>,
}
92
/// One chat message in the wire format sent to Copilot.
#[derive(Debug, serde::Serialize)]
struct CopilotMessage {
    /// `"system"`, `"user"`, `"assistant"` or `"tool"`.
    role: String,
    /// Message text; an assistant turn without text is sent as `""`.
    content: String,
    /// Tool calls made by an assistant turn; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<llmg_core::types::ToolCall>>,
    /// Id of the tool call a `"tool"` message answers; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
102
/// Non-streaming `/chat/completions` response body.
#[derive(Debug, serde::Deserialize)]
struct CopilotChatResponse {
    id: String,
    /// Empty string when the field is absent.
    #[serde(default)]
    object: String,
    /// `0` when the field is absent.
    #[serde(default)]
    created: i64,
    model: String,
    choices: Vec<CopilotChoice>,
    usage: Option<CopilotUsage>,
}
114
115#[derive(Debug, serde::Deserialize)]
116struct CopilotChoice {
117 index: i32,
118 message: CopilotMessageResponse,
119 #[serde(rename = "finish_reason")]
120 finish_reason: Option<String>,
121}
122
123#[derive(Debug, serde::Deserialize)]
124struct CopilotMessageResponse {
125 role: String,
126 content: String,
127}
128
129#[derive(Debug, serde::Deserialize)]
130struct CopilotUsage {
131 #[serde(rename = "prompt_tokens")]
132 prompt_tokens: u32,
133 #[serde(default, rename = "completion_tokens")]
134 completion_tokens: u32,
135 #[serde(rename = "total_tokens")]
136 total_tokens: u32,
137}
138
/// JSON body POSTed to `/embeddings`.
#[derive(Debug, serde::Serialize)]
struct CopilotEmbeddingRequest {
    model: String,
    /// Texts to embed; the client currently always sends exactly one.
    input: Vec<String>,
}
144
/// `/embeddings` response body.
#[derive(Debug, serde::Deserialize)]
struct CopilotEmbeddingResponse {
    /// Empty when absent; the caller substitutes `"list"` in that case.
    #[serde(default)]
    object: String,
    data: Vec<CopilotEmbeddingData>,
    /// Empty string when absent.
    #[serde(default)]
    model: String,
    // NOTE(review): required; a response without `usage` fails to decode.
    usage: CopilotUsage,
}
154
/// One embedding vector in an `/embeddings` response.
#[derive(Debug, serde::Deserialize)]
struct CopilotEmbeddingData {
    object: String,
    /// Position of the corresponding input text.
    index: u32,
    embedding: Vec<f32>,
}
161
162impl GitHubCopilotClient {
163 pub async fn new() -> Result<Self, LlmError> {
165 let token_dir = Self::get_token_dir();
166 std::fs::create_dir_all(&token_dir).map_err(|e| {
167 LlmError::ProviderError(format!("Failed to create token directory: {}", e))
168 })?;
169
170 println!("Loading access token...");
171 let access_token = match Self::load_cached_access_token(&token_dir).await {
172 Ok(token) => token,
173 Err(e) => {
174 println!("load_cached_access_token failed with {:?}", e);
175 return Err(e);
176 }
177 };
178 println!("Access token loaded.");
179
180 let mut client = Self {
181 http_client: reqwest::Client::new(),
182 api_key: String::new(),
183 access_token,
184 editor_version: "vscode/1.85.1".to_string(),
185 integration_id: "vscode-chat".to_string(),
186 };
187
188 if let Ok(key) = Self::load_cached_api_key(&token_dir).await {
189 client.api_key = key;
190 println!("API key loaded from cache.");
191 } else {
192 println!("Refreshing API key...");
193 match client.refresh_api_key().await {
194 Ok(_) => println!("API key refreshed."),
195 Err(e) => {
196 println!("refresh_api_key failed with {:?}", e);
197 return Err(e);
198 }
199 }
200 }
201
202 Ok(client)
203 }
204
205 pub fn with_api_key(api_key: impl Into<String>, access_token: impl Into<String>) -> Self {
207 Self {
208 http_client: reqwest::Client::new(),
209 api_key: api_key.into(),
210 access_token: access_token.into(),
211 editor_version: "vscode/1.85.1".to_string(),
212 integration_id: "vscode-chat".to_string(),
213 }
214 }
215
216 fn get_token_dir() -> PathBuf {
217 std::env::var("GITHUB_COPILOT_TOKEN_DIR")
218 .map(PathBuf::from)
219 .unwrap_or_else(|_| {
220 dirs::config_dir()
221 .unwrap_or_else(|| PathBuf::from("."))
222 .join("llmg/github_copilot")
223 })
224 }
225
226 async fn load_cached_access_token(token_dir: &std::path::Path) -> Result<String, LlmError> {
227 let access_token_path = token_dir.join("access-token");
228
229 if let Ok(token) = std::fs::read_to_string(&access_token_path) {
230 let token = token.trim();
231 if !token.is_empty() {
232 return Ok(token.to_string());
233 }
234 }
235
236 Self::perform_oauth_flow(token_dir).await
237 }
238
239 async fn load_cached_api_key(token_dir: &std::path::Path) -> Result<String, LlmError> {
240 let api_key_path = token_dir.join("api-key.json");
241
242 if let Ok(content) = std::fs::read_to_string(&api_key_path) {
243 if let Ok(api_key_info) = serde_json::from_str::<CopilotApiKeyResponse>(&content) {
244 let now = std::time::SystemTime::now()
245 .duration_since(std::time::UNIX_EPOCH)
246 .unwrap()
247 .as_secs() as i64;
248
249 if api_key_info.expires_at > now {
250 return Ok(api_key_info.token);
251 }
252 }
253 }
254
255 Err(LlmError::AuthError)
256 }
257
258 async fn perform_oauth_flow(token_dir: &std::path::Path) -> Result<String, LlmError> {
259 let device_code_resp = Self::get_device_code().await?;
260
261 eprintln!("\nš GitHub Copilot Authentication Required");
262 eprintln!("Please visit: {}", device_code_resp.verification_uri);
263 eprintln!("And enter code: {}\n", device_code_resp.user_code);
264
265 let access_token = Self::poll_for_access_token(
266 &device_code_resp.device_code,
267 device_code_resp.interval as u64,
268 )
269 .await?;
270
271 let access_token_path = token_dir.join("access-token");
272 std::fs::write(&access_token_path, &access_token)
273 .map_err(|e| LlmError::ProviderError(format!("Failed to cache access token: {}", e)))?;
274
275 Ok(access_token)
276 }
277
278 async fn get_device_code() -> Result<DeviceCodeResponse, LlmError> {
279 let client = reqwest::Client::new();
280
281 let resp = client
282 .post(GITHUB_DEVICE_CODE_URL)
283 .header("Accept", "application/json")
284 .header("User-Agent", "GithubCopilot/1.155.0")
285 .json(&serde_json::json!({
286 "client_id": GITHUB_CLIENT_ID,
287 "scope": "read:user"
288 }))
289 .send()
290 .await
291 .map_err(|e| LlmError::HttpError(format!("Failed to get device code: {}", e)))?;
292
293 if !resp.status().is_success() {
294 return Err(LlmError::ApiError {
295 status: resp.status().as_u16(),
296 message: resp.text().await.unwrap_or_default(),
297 });
298 }
299
300 resp.json::<DeviceCodeResponse>()
301 .await
302 .map_err(|e| LlmError::HttpError(e.to_string()))
303 }
304
305 async fn poll_for_access_token(device_code: &str, interval: u64) -> Result<String, LlmError> {
306 let client = reqwest::Client::new();
307 let max_attempts = 60;
308
309 for attempt in 0..max_attempts {
310 tokio::time::sleep(tokio::time::Duration::from_secs(interval)).await;
311
312 let resp = client
313 .post(GITHUB_ACCESS_TOKEN_URL)
314 .header("Accept", "application/json")
315 .header("User-Agent", "GithubCopilot/1.155.0")
316 .json(&serde_json::json!({
317 "client_id": GITHUB_CLIENT_ID,
318 "device_code": device_code,
319 "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
320 }))
321 .send()
322 .await
323 .map_err(|e| LlmError::HttpError(format!("Failed to poll for token: {}", e)))?;
324
325 if !resp.status().is_success() {
326 continue;
327 }
328
329 let token_resp = resp
330 .json::<AccessTokenResponse>()
331 .await
332 .map_err(|e| LlmError::HttpError(e.to_string()))?;
333
334 if let Some(token) = token_resp.access_token {
335 eprintln!("ā
Authentication successful!");
336 return Ok(token);
337 }
338
339 if let Some(error) = token_resp.error {
340 if error != "authorization_pending" {
341 return Err(LlmError::AuthError);
342 }
343 }
344
345 if attempt % 6 == 0 {
346 eprintln!(
347 "ā³ Waiting for authorization... (attempt {}/{})",
348 attempt + 1,
349 max_attempts
350 );
351 }
352 }
353
354 Err(LlmError::AuthError)
355 }
356
357 async fn refresh_api_key(&mut self) -> Result<(), LlmError> {
358 let client = reqwest::Client::new();
359
360 let resp = client
361 .get(GITHUB_COPILOT_TOKEN_URL)
362 .header("Authorization", format!("token {}", self.access_token))
363 .header("Accept", "application/json")
364 .header("User-Agent", "GithubCopilot/1.155.0")
365 .send()
366 .await
367 .map_err(|e| LlmError::HttpError(format!("Failed to refresh API key: {}", e)))?;
368
369 if !resp.status().is_success() {
370 if resp.status().as_u16() == 401 {
371 let token_dir = Self::get_token_dir();
372 self.access_token = Self::perform_oauth_flow(&token_dir).await?;
373 return Box::pin(self.refresh_api_key()).await;
374 }
375
376 return Err(LlmError::ApiError {
377 status: resp.status().as_u16(),
378 message: resp.text().await.unwrap_or_default(),
379 });
380 }
381
382 let api_key_info = resp
383 .json::<CopilotApiKeyResponse>()
384 .await
385 .map_err(|e| LlmError::HttpError(e.to_string()))?;
386
387 self.api_key = api_key_info.token.clone();
388
389 let token_dir = Self::get_token_dir();
390 let api_key_path = token_dir.join("api-key.json");
391 std::fs::write(&api_key_path, serde_json::to_string(&api_key_info).unwrap())
392 .map_err(|e| LlmError::ProviderError(format!("Failed to cache API key: {}", e)))?;
393
394 Ok(())
395 }
396
397 pub fn with_editor_version(mut self, version: impl Into<String>) -> Self {
398 self.editor_version = version.into();
399 self
400 }
401
402 pub fn with_integration_id(mut self, id: impl Into<String>) -> Self {
403 self.integration_id = id.into();
404 self
405 }
406
407 fn convert_request(&self, request: ChatCompletionRequest) -> CopilotChatRequest {
408 let messages: Vec<CopilotMessage> = request
409 .messages
410 .into_iter()
411 .map(|msg| match msg {
412 Message::System { content, .. } => CopilotMessage {
413 role: "system".to_string(),
414 content,
415 tool_calls: None,
416 tool_call_id: None,
417 },
418 Message::User { content, .. } => CopilotMessage {
419 role: "user".to_string(),
420 content,
421 tool_calls: None,
422 tool_call_id: None,
423 },
424 Message::Assistant {
425 content,
426 tool_calls,
427 ..
428 } => CopilotMessage {
429 role: "assistant".to_string(),
430 content: content.unwrap_or_default(),
431 tool_calls,
432 tool_call_id: None,
433 },
434 Message::Tool {
435 content,
436 tool_call_id,
437 } => CopilotMessage {
438 role: "tool".to_string(),
439 content,
440 tool_calls: None,
441 tool_call_id: Some(tool_call_id),
442 },
443 })
444 .collect();
445
446 CopilotChatRequest {
447 messages,
448 model: request.model,
449 temperature: request.temperature,
450 top_p: request.top_p,
451 stream: request.stream,
452 stop: request.stop,
453 max_tokens: request.max_tokens,
454 tools: request.tools,
455 tool_choice: request.tool_choice,
456 }
457 }
458
459 fn convert_response(&self, response: CopilotChatResponse) -> ChatCompletionResponse {
460 ChatCompletionResponse {
461 id: response.id,
462 object: response.object,
463 created: response.created,
464 model: response.model,
465 choices: response
466 .choices
467 .into_iter()
468 .map(|c| Choice {
469 index: c.index as u32,
470 message: Message::Assistant {
471 content: Some(c.message.content),
472 refusal: None,
473 tool_calls: None,
474 },
475 finish_reason: c.finish_reason,
476 })
477 .collect(),
478 usage: response.usage.map(|u| Usage {
479 prompt_tokens: u.prompt_tokens,
480 completion_tokens: u.completion_tokens,
481 total_tokens: u.total_tokens,
482 }),
483 }
484 }
485
486 async fn make_request(
487 &mut self,
488 request: ChatCompletionRequest,
489 ) -> Result<ChatCompletionResponse, LlmError> {
490 if self.api_key.is_empty() {
491 self.refresh_api_key().await?;
492 }
493
494 let url = format!("{}/chat/completions", GITHUB_COPILOT_API_BASE);
495 let copilot_req = self.convert_request(request.clone());
496
497 let initiator = if request
498 .messages
499 .iter()
500 .any(|m| matches!(m, Message::Assistant { .. } | Message::Tool { .. }))
501 {
502 "agent"
503 } else {
504 "user"
505 };
506
507 let request_id = uuid::Uuid::new_v4().to_string();
508 let resp = self
509 .http_client
510 .post(&url)
511 .header("Authorization", format!("Bearer {}", self.api_key))
512 .header("Content-Type", "application/json")
513 .header("Accept", "application/json")
514 .header("editor-version", "vscode/1.95.0")
515 .header("editor-plugin-version", "copilot-chat/0.26.7")
516 .header("Copilot-Integration-Id", "vscode-chat")
517 .header("User-Agent", "GitHubCopilotChat/0.26.7")
518 .header("openai-intent", "conversation-panel")
519 .header("x-github-api-version", "2025-04-01")
520 .header("x-request-id", &request_id)
521 .header("x-vscode-user-agent-library-version", "electron-fetch")
522 .header("X-Initiator", initiator)
523 .json(&copilot_req)
524 .send()
525 .await
526 .map_err(|e| LlmError::HttpError(e.to_string()))?;
527
528 if resp.status().as_u16() == 401 {
529 self.refresh_api_key().await?;
530 return Box::pin(async move { self.make_request(request).await }).await;
531 }
532
533 if !resp.status().is_success() {
534 let status = resp.status().as_u16();
535 let text = resp.text().await.unwrap_or_default();
536
537 if status == 429 {
538 return Err(LlmError::RateLimitError);
539 }
540
541 return Err(LlmError::ApiError {
542 status,
543 message: text,
544 });
545 }
546
547 let text = resp
548 .text()
549 .await
550 .map_err(|e| LlmError::HttpError(e.to_string()))?;
551 let copilot_resp: CopilotChatResponse = serde_json::from_str(&text)
552 .map_err(|e| LlmError::HttpError(format!("error decoding response body: {}", e)))?;
553
554 Ok(self.convert_response(copilot_resp))
555 }
556
557 async fn make_stream_request(
558 &mut self,
559 request: ChatCompletionRequest,
560 ) -> Result<ChatCompletionStream, LlmError> {
561 if self.api_key.is_empty() {
562 self.refresh_api_key().await?;
563 }
564
565 let url = format!("{}/chat/completions", GITHUB_COPILOT_API_BASE);
566 let mut copilot_req = self.convert_request(request.clone());
567 copilot_req.stream = Some(true);
568
569 let initiator = if request
570 .messages
571 .iter()
572 .any(|m| matches!(m, Message::Assistant { .. } | Message::Tool { .. }))
573 {
574 "agent"
575 } else {
576 "user"
577 };
578
579 let request_id = uuid::Uuid::new_v4().to_string();
580 let resp = self
581 .http_client
582 .post(&url)
583 .header("Authorization", format!("Bearer {}", self.api_key))
584 .header("Content-Type", "application/json")
585 .header("Accept", "application/json")
586 .header("editor-version", "vscode/1.95.0")
587 .header("editor-plugin-version", "copilot-chat/0.26.7")
588 .header("Copilot-Integration-Id", "vscode-chat")
589 .header("User-Agent", "GitHubCopilotChat/0.26.7")
590 .header("openai-intent", "conversation-panel")
591 .header("x-github-api-version", "2025-04-01")
592 .header("x-request-id", &request_id)
593 .header("x-vscode-user-agent-library-version", "electron-fetch")
594 .header("X-Initiator", initiator)
595 .json(&copilot_req)
596 .send()
597 .await
598 .map_err(|e| LlmError::HttpError(e.to_string()))?;
599
600 if resp.status().as_u16() == 401 {
601 self.refresh_api_key().await?;
602 return Box::pin(async move { self.make_stream_request(request).await }).await;
603 }
604
605 if !resp.status().is_success() {
606 let status = resp.status().as_u16();
607 let text = resp.text().await.unwrap_or_default();
608
609 if status == 429 {
610 return Err(LlmError::RateLimitError);
611 }
612
613 return Err(LlmError::ApiError {
614 status,
615 message: text,
616 });
617 }
618
619 let chunk_id = ChatCompletionChunk::generate_id();
620 let model = copilot_req.model.clone();
621
622 let stream = resp
623 .bytes_stream()
624 .eventsource()
625 .map_err(|e| LlmError::HttpError(e.to_string()))
626 .then(move |event_result| {
627 let chunk_id = chunk_id.clone();
628 let model = model.clone();
629 async move {
630 match event_result {
631 Ok(event) => parse_copilot_sse_data(&event.data, &chunk_id, &model),
632 Err(e) => Err(LlmError::HttpError(e.to_string())),
633 }
634 }
635 })
636 .try_filter_map(|chunk| async move { Ok(chunk) });
637
638 Ok(Box::pin(stream) as ChatCompletionStream)
639 }
640
641 async fn make_embedding_request(
642 &mut self,
643 request: EmbeddingRequest,
644 ) -> Result<EmbeddingResponse, LlmError> {
645 if self.api_key.is_empty() {
646 self.refresh_api_key().await?;
647 }
648
649 let url = format!("{}/embeddings", GITHUB_COPILOT_API_BASE);
650 let copilot_req = CopilotEmbeddingRequest {
651 model: request.model.clone(),
652 input: vec![request.input.clone()],
653 };
654
655 let request_id = uuid::Uuid::new_v4().to_string();
656 let resp = self
657 .http_client
658 .post(&url)
659 .header("Authorization", format!("Bearer {}", self.api_key))
660 .header("Content-Type", "application/json")
661 .header("Accept", "application/json")
662 .header("editor-version", "vscode/1.95.0")
663 .header("editor-plugin-version", "copilot-chat/0.26.7")
664 .header("Copilot-Integration-Id", "vscode-chat")
665 .header("User-Agent", "GitHubCopilotChat/0.26.7")
666 .header("openai-intent", "conversation-panel")
667 .header("x-github-api-version", "2025-04-01")
668 .header("x-request-id", &request_id)
669 .header("x-vscode-user-agent-library-version", "electron-fetch")
670 .header("X-Initiator", "user")
671 .json(&copilot_req)
672 .send()
673 .await
674 .map_err(|e| LlmError::HttpError(e.to_string()))?;
675
676 if resp.status().as_u16() == 401 {
677 self.refresh_api_key().await?;
678 return Box::pin(async move { self.make_embedding_request(request).await }).await;
679 }
680
681 if !resp.status().is_success() {
682 let status = resp.status().as_u16();
683 let text = resp.text().await.unwrap_or_default();
684
685 if status == 429 {
686 return Err(LlmError::RateLimitError);
687 }
688
689 return Err(LlmError::ApiError {
690 status,
691 message: text,
692 });
693 }
694
695 let text = resp
696 .text()
697 .await
698 .map_err(|e| LlmError::HttpError(e.to_string()))?;
699 let copilot_resp: CopilotEmbeddingResponse = serde_json::from_str(&text)
700 .map_err(|e| LlmError::HttpError(format!("error decoding response body: {}", e)))?;
701
702 Ok(EmbeddingResponse {
703 id: uuid::Uuid::new_v4().to_string(),
704 object: if copilot_resp.object.is_empty() {
705 "list".to_string()
706 } else {
707 copilot_resp.object
708 },
709 data: copilot_resp
710 .data
711 .into_iter()
712 .map(|d| llmg_core::types::Embedding {
713 index: d.index,
714 object: d.object,
715 embedding: d.embedding,
716 })
717 .collect(),
718 model: copilot_resp.model,
719 usage: Usage {
720 prompt_tokens: copilot_resp.usage.prompt_tokens,
721 completion_tokens: copilot_resp.usage.completion_tokens,
722 total_tokens: copilot_resp.usage.total_tokens,
723 },
724 })
725 }
726
727 pub fn get_models() -> Vec<String> {
728 vec![
729 "gpt-4".to_string(),
730 "gpt-4o".to_string(),
731 "gpt-4o-mini".to_string(),
732 "gpt-3.5-turbo".to_string(),
733 "o1-preview".to_string(),
734 "o1-mini".to_string(),
735 "claude-3-5-sonnet".to_string(),
736 "text-embedding-3-small".to_string(),
737 ]
738 }
739}
740
741#[async_trait::async_trait]
742impl Provider for GitHubCopilotClient {
743 async fn chat_completion(
744 &self,
745 request: ChatCompletionRequest,
746 ) -> Result<ChatCompletionResponse, LlmError> {
747 let mut client = self.clone();
748 client.make_request(request).await
749 }
750
751 fn chat_completion_stream(
752 &self,
753 request: ChatCompletionRequest,
754 ) -> Pin<Box<dyn Future<Output = Result<ChatCompletionStream, LlmError>> + Send + '_>> {
755 let mut client = self.clone();
756 Box::pin(async move { client.make_stream_request(request).await })
757 }
758
759 async fn embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse, LlmError> {
760 let mut client = self.clone();
761 client.make_embedding_request(request).await
762 }
763 fn provider_name(&self) -> &'static str {
764 "github_copilot"
765 }
766}
767
768fn parse_copilot_sse_data(
769 data: &str,
770 chunk_id: &str,
771 model: &str,
772) -> Result<Option<ChatCompletionChunk>, LlmError> {
773 let data = data.trim();
774 if data.is_empty() || data == "[DONE]" {
775 return Ok(None);
776 }
777
778 let parsed: serde_json::Value =
779 serde_json::from_str(data).map_err(LlmError::SerializationError)?;
780
781 let choices = parsed
782 .get("choices")
783 .and_then(|c| c.as_array())
784 .map(|arr| {
785 arr.iter()
786 .filter_map(|choice| {
787 let index = choice.get("index").and_then(|i| i.as_u64()).unwrap_or(0) as u32;
788 let delta = choice.get("delta")?;
789 let finish_reason = choice
790 .get("finish_reason")
791 .and_then(|f| f.as_str())
792 .map(|s| s.to_string());
793
794 let role = delta
795 .get("role")
796 .and_then(|r| r.as_str())
797 .map(|s| s.to_string());
798 let content = delta
799 .get("content")
800 .and_then(|c| c.as_str())
801 .map(|s| s.to_string());
802 let tool_calls = delta
803 .get("tool_calls")
804 .and_then(|t| serde_json::from_value(t.clone()).ok());
805
806 Some(ChoiceDelta {
807 index,
808 delta: DeltaContent {
809 role,
810 content,
811 tool_calls,
812 },
813 finish_reason,
814 })
815 })
816 .collect::<Vec<_>>()
817 })
818 .unwrap_or_default();
819
820 if choices.is_empty() {
821 return Ok(None);
822 }
823
824 Ok(Some(ChatCompletionChunk {
825 id: chunk_id.to_string(),
826 object: "chat.completion.chunk".to_string(),
827 created: chrono::Utc::now().timestamp(),
828 model: model.to_string(),
829 choices,
830 usage: None,
831 }))
832}
833
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline `gpt-4` request around `messages`; tests override fields with
    /// struct-update syntax.
    fn gpt4_request(messages: Vec<Message>) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages,
            temperature: None,
            max_tokens: None,
            stream: None,
            top_p: None,
            frequency_penalty: None,
            presence_penalty: None,
            stop: None,
            user: None,
            tools: None,
            tool_choice: None,
            response_format: None,
        }
    }

    #[test]
    fn test_copilot_client_with_api_key() {
        let client = GitHubCopilotClient::with_api_key("test-api-key", "test-access-token");
        assert_eq!(client.provider_name(), "github_copilot");
    }

    #[test]
    fn test_request_conversion() {
        let client = GitHubCopilotClient::with_api_key("test-key", "test-token");
        let request = ChatCompletionRequest {
            temperature: Some(0.7),
            max_tokens: Some(1000),
            ..gpt4_request(vec![
                Message::System {
                    content: "You are a helpful coding assistant".to_string(),
                    name: None,
                },
                Message::User {
                    content: "Write a Python function".to_string(),
                    name: None,
                },
            ])
        };

        let converted = client.convert_request(request);

        assert_eq!(converted.model, "gpt-4");
        let roles: Vec<&str> = converted.messages.iter().map(|m| m.role.as_str()).collect();
        assert_eq!(roles, ["system", "user"]);
    }

    #[test]
    fn test_tool_calling_conversion() {
        let client = GitHubCopilotClient::with_api_key("test-key", "test-token");
        let weather_tool = llmg_core::types::Tool {
            r#type: "function".to_string(),
            function: llmg_core::types::FunctionDefinition {
                name: "get_weather".to_string(),
                description: Some("Get the weather".to_string()),
                parameters: serde_json::json!({"type": "object", "properties": {"location": {"type": "string"}}}),
            },
        };

        let converted = client.convert_request(ChatCompletionRequest {
            tools: Some(vec![weather_tool]),
            tool_choice: Some(llmg_core::types::ToolChoice::String("auto".to_string())),
            ..gpt4_request(vec![Message::User {
                content: "Weather?".to_string(),
                name: None,
            }])
        });

        assert_eq!(converted.tools.as_ref().map(Vec::len), Some(1));
        assert!(converted.tool_choice.is_some());
    }

    #[test]
    fn test_parse_copilot_sse_data_tool_calls() {
        let raw_sse = r#"{"id":"chatcmpl-123","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_abc","type":"function","function":{"name":"get_weather","arguments":"{\"location\":\"Boston\"}"}}]},"finish_reason":null}]}"#;

        let chunk = parse_copilot_sse_data(raw_sse, "chatcmpl-123", "gpt-4")
            .expect("payload is valid JSON")
            .expect("payload carries one choice");

        assert_eq!(chunk.choices.len(), 1);
        let tool_calls = chunk.choices[0]
            .delta
            .tool_calls
            .as_ref()
            .expect("delta carries tool_calls");
        assert_eq!(tool_calls.len(), 1);

        let call = &tool_calls[0];
        assert_eq!(call.id.as_deref(), Some("call_abc"));
        assert_eq!(
            call.function.as_ref().and_then(|f| f.name.as_deref()),
            Some("get_weather")
        );
    }
}
943