use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client as ReqwestClient, Response, StatusCode};
use serde::de::DeserializeOwned;
use std::time::Duration;
use tracing::{debug, error, warn};
use uuid::Uuid;

use super::error::{Result, VsCodeError};
use super::token::TokenManager;
use super::types::*;
81
/// Value of the `x-github-api-version` header sent in direct mode.
const COPILOT_API_VERSION: &str = "2025-04-01";

/// Expands to the Copilot Chat extension version this client identifies as.
///
/// This is a macro rather than a `const` so the same literal can be spliced
/// into the `concat!` calls below. Bump the version here and every derived
/// header string follows.
macro_rules! copilot_chat_version {
    () => {
        "0.26.7"
    };
}

/// Bare Copilot Chat extension version.
// Not read anywhere in this module at the moment; kept as the bare version
// string next to the derived ones, hence the dead_code allowance.
#[allow(dead_code)]
const COPILOT_VERSION: &str = copilot_chat_version!();
/// Value of the `editor-version` header.
const EDITOR_VERSION: &str = "vscode/1.95.0";
/// Value of the `editor-plugin-version` header.
const EDITOR_PLUGIN_VERSION: &str = concat!("copilot-chat/", copilot_chat_version!());
/// Value of the `User-Agent` header.
const USER_AGENT: &str = concat!("GitHubCopilotChat/", copilot_chat_version!());

/// Number of retries after the first attempt (up to `MAX_RETRIES + 1` calls total).
const MAX_RETRIES: u32 = 3;
/// Delay before the first retry in milliseconds; doubled for each later retry.
const INITIAL_RETRY_DELAY_MS: u64 = 1000;

/// GitHub Copilot subscription tier, which selects the API host in direct mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AccountType {
    /// Personal subscription (`api.githubcopilot.com`); the default.
    #[default]
    Individual,
    /// Copilot Business (`api.business.githubcopilot.com`).
    Business,
    /// Copilot Enterprise (`api.enterprise.githubcopilot.com`).
    Enterprise,
}

impl AccountType {
    /// Returns the Copilot API base URL for this tier, without a trailing slash.
    pub fn base_url(&self) -> &'static str {
        match self {
            AccountType::Individual => "https://api.githubcopilot.com",
            AccountType::Business => "https://api.business.githubcopilot.com",
            AccountType::Enterprise => "https://api.enterprise.githubcopilot.com",
        }
    }

    /// Parses an account type name case-insensitively.
    ///
    /// Accepts `"individual"`, `"business"` and `"enterprise"` in any letter
    /// case and returns `None` for anything else (including surrounding
    /// whitespace).
    // Kept as an inherent method returning `Option` for API compatibility,
    // which is why the `FromStr` trait (returning `Result`) is not used.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        // All accepted names are ASCII, so ASCII case folding matches the
        // previous `to_lowercase` comparison without allocating a copy.
        [
            ("individual", AccountType::Individual),
            ("business", AccountType::Business),
            ("enterprise", AccountType::Enterprise),
        ]
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case(s))
        .map(|&(_, kind)| kind)
    }
}
125
126#[derive(Clone)]
130pub struct VsCodeCopilotClient {
131 client: ReqwestClient,
132 base_url: String,
133 token_manager: TokenManager,
134 direct_mode: bool,
136 #[allow(dead_code)]
138 account_type: AccountType,
139 vision_enabled: bool,
141}
142
143impl VsCodeCopilotClient {
144 pub fn new(timeout: Duration) -> Result<Self> {
148 Self::new_with_options(timeout, true, AccountType::Individual)
149 }
150
151 pub fn with_base_url(base_url: impl Into<String>, timeout: Duration) -> Result<Self> {
155 let base_url = base_url.into();
156 let is_direct = base_url.contains("githubcopilot.com");
157
158 let client = ReqwestClient::builder()
159 .timeout(timeout)
160 .pool_max_idle_per_host(10)
161 .pool_idle_timeout(Duration::from_secs(90))
162 .build()
163 .map_err(|e| VsCodeError::ClientInit(e.to_string()))?;
164
165 let token_manager =
166 TokenManager::new().map_err(|e| VsCodeError::ClientInit(e.to_string()))?;
167
168 debug!(
169 base_url = %base_url,
170 timeout_secs = timeout.as_secs(),
171 direct_mode = is_direct,
172 "VSCode Copilot client initialized"
173 );
174
175 Ok(Self {
176 client,
177 base_url,
178 token_manager,
179 direct_mode: is_direct,
180 account_type: AccountType::Individual,
181 vision_enabled: false,
182 })
183 }
184
185 pub fn new_with_options(
193 timeout: Duration,
194 direct_mode: bool,
195 account_type: AccountType,
196 ) -> Result<Self> {
197 let base_url = if direct_mode {
198 account_type.base_url().to_string()
199 } else {
200 std::env::var("VSCODE_COPILOT_PROXY_URL")
201 .unwrap_or_else(|_| "http://localhost:4141".to_string())
202 };
203
204 let client = ReqwestClient::builder()
205 .timeout(timeout)
206 .pool_max_idle_per_host(10)
207 .pool_idle_timeout(Duration::from_secs(90))
208 .build()
209 .map_err(|e| VsCodeError::ClientInit(e.to_string()))?;
210
211 let token_manager =
212 TokenManager::new().map_err(|e| VsCodeError::ClientInit(e.to_string()))?;
213
214 debug!(
215 base_url = %base_url,
216 timeout_secs = timeout.as_secs(),
217 direct_mode = direct_mode,
218 account_type = ?account_type,
219 "VSCode Copilot client initialized"
220 );
221
222 Ok(Self {
223 client,
224 base_url,
225 token_manager,
226 direct_mode,
227 account_type,
228 vision_enabled: false,
229 })
230 }
231
232 pub fn with_vision(mut self, enabled: bool) -> Self {
234 self.vision_enabled = enabled;
235 self
236 }
237
238 async fn get_token(&self) -> Result<String> {
240 self.token_manager
241 .get_valid_copilot_token()
242 .await
243 .map_err(|e| VsCodeError::Authentication(e.to_string()))
244 }
245
246 async fn build_headers(&self) -> Result<reqwest::header::HeaderMap> {
251 let token = self.get_token().await?;
252
253 let mut headers = reqwest::header::HeaderMap::new();
254
255 headers.insert(
257 reqwest::header::AUTHORIZATION,
258 format!("Bearer {}", token).parse().unwrap(),
259 );
260
261 headers.insert(
263 reqwest::header::CONTENT_TYPE,
264 "application/json".parse().unwrap(),
265 );
266
267 headers.insert(reqwest::header::USER_AGENT, USER_AGENT.parse().unwrap());
269
270 headers.insert("editor-version", EDITOR_VERSION.parse().unwrap());
272 headers.insert(
273 "editor-plugin-version",
274 EDITOR_PLUGIN_VERSION.parse().unwrap(),
275 );
276
277 if self.direct_mode {
279 headers.insert("copilot-integration-id", "vscode-chat".parse().unwrap());
281 headers.insert("openai-intent", "conversation-panel".parse().unwrap());
282
283 headers.insert("x-github-api-version", COPILOT_API_VERSION.parse().unwrap());
285
286 headers.insert("x-request-id", Uuid::new_v4().to_string().parse().unwrap());
288
289 headers.insert(
291 "x-vscode-user-agent-library-version",
292 "electron-fetch".parse().unwrap(),
293 );
294
295 if self.vision_enabled {
297 headers.insert("copilot-vision-request", "true".parse().unwrap());
298 }
299 }
300
301 Ok(headers)
302 }
303
304 async fn retry_with_backoff<F, Fut, T>(&self, operation: F, operation_name: &str) -> Result<T>
314 where
315 F: Fn() -> Fut,
316 Fut: std::future::Future<Output = Result<T>>,
317 {
318 let mut last_error = None;
319
320 for attempt in 0..=MAX_RETRIES {
321 match operation().await {
322 Ok(result) => return Ok(result),
323 Err(e) => {
324 if !e.is_retryable() || attempt == MAX_RETRIES {
326 return Err(e);
327 }
328
329 let delay = Duration::from_millis(INITIAL_RETRY_DELAY_MS * 2_u64.pow(attempt));
331
332 warn!(
333 operation = operation_name,
334 attempt = attempt + 1,
335 max_retries = MAX_RETRIES,
336 delay_ms = delay.as_millis(),
337 error = %e,
338 "Retrying after retryable error"
339 );
340
341 tokio::time::sleep(delay).await;
342 last_error = Some(e);
343 }
344 }
345 }
346
347 Err(last_error
348 .unwrap_or_else(|| VsCodeError::ApiError("Operation failed after retries".to_string())))
349 }
350
351 fn is_agent_call(messages: &[RequestMessage]) -> bool {
378 messages
379 .iter()
380 .any(|m| matches!(m.role.as_str(), "assistant" | "tool"))
381 }
382
383 pub async fn chat_completion(
385 &self,
386 request: ChatCompletionRequest,
387 ) -> Result<ChatCompletionResponse> {
388 let request_clone = request.clone();
389
390 self.retry_with_backoff(
392 || async {
393 let url = if self.direct_mode {
395 format!("{}/chat/completions", self.base_url)
396 } else {
397 format!("{}/v1/chat/completions", self.base_url)
398 };
399 let mut headers = self.build_headers().await?;
400
401 if self.direct_mode {
403 let initiator = if Self::is_agent_call(&request_clone.messages) {
404 "agent"
405 } else {
406 "user"
407 };
408 headers.insert("X-Initiator", initiator.parse().unwrap());
409 }
410
411 debug!(
412 url = %url,
413 model = %request_clone.model,
414 message_count = request_clone.messages.len(),
415 direct_mode = self.direct_mode,
416 "Sending chat completion request"
417 );
418
419 let response = self
420 .client
421 .post(&url)
422 .headers(headers)
423 .json(&request_clone)
424 .send()
425 .await
426 .map_err(|e| VsCodeError::Network(e.to_string()))?;
427
428 let mut response: ChatCompletionResponse = Self::handle_response(response).await?;
429
430 response = Self::normalize_choices(response);
432
433 Ok(response)
434 },
435 "chat_completion",
436 )
437 .await
438 }
439
440 pub async fn chat_completion_stream(&self, request: ChatCompletionRequest) -> Result<Response> {
442 let url = if self.direct_mode {
444 format!("{}/chat/completions", self.base_url)
445 } else {
446 format!("{}/v1/chat/completions", self.base_url)
447 };
448 let mut headers = self.build_headers().await?;
449
450 if self.direct_mode {
452 let initiator = if Self::is_agent_call(&request.messages) {
453 "agent"
454 } else {
455 "user"
456 };
457 headers.insert("X-Initiator", initiator.parse().unwrap());
458 }
459
460 debug!(
461 url = %url,
462 model = %request.model,
463 message_count = request.messages.len(),
464 "Sending streaming chat completion request"
465 );
466
467 let response = self
468 .client
469 .post(&url)
470 .headers(headers)
471 .json(&request)
472 .send()
473 .await
474 .map_err(|e| VsCodeError::Network(e.to_string()))?;
475
476 if response.status().is_success() {
477 Ok(response)
478 } else {
479 let status = response.status();
480 let error_body = response
481 .text()
482 .await
483 .unwrap_or_else(|_| "Unknown error".to_string());
484
485 warn!(
486 status = %status,
487 error = %error_body,
488 "Streaming request failed"
489 );
490
491 Err(Self::map_error_status(status, error_body))
492 }
493 }
494
495 pub async fn list_models(&self) -> Result<ModelsResponse> {
500 let url = if self.direct_mode {
502 format!("{}/models", self.base_url)
503 } else {
504 format!("{}/v1/models", self.base_url)
505 };
506 let headers = self.build_headers().await?;
507
508 debug!(
509 url = %url,
510 direct_mode = self.direct_mode,
511 "Fetching models list"
512 );
513
514 let response = self
515 .client
516 .get(&url)
517 .headers(headers)
518 .send()
519 .await
520 .map_err(|e| VsCodeError::Network(e.to_string()))?;
521
522 Self::handle_response(response).await
523 }
524
525 pub async fn create_embeddings(&self, request: EmbeddingRequest) -> Result<EmbeddingResponse> {
543 let url = if self.direct_mode {
545 format!("{}/embeddings", self.base_url)
546 } else {
547 format!("{}/v1/embeddings", self.base_url)
548 };
549 let headers = self.build_headers().await?;
550
551 debug!(
552 url = %url,
553 model = %request.model,
554 direct_mode = self.direct_mode,
555 "Sending embedding request"
556 );
557
558 let response = self
559 .client
560 .post(&url)
561 .headers(headers)
562 .json(&request)
563 .send()
564 .await
565 .map_err(|e| VsCodeError::Network(e.to_string()))?;
566
567 Self::handle_response(response).await
568 }
569
570 fn normalize_choices(mut response: ChatCompletionResponse) -> ChatCompletionResponse {
620 if response.choices.len() <= 1 {
622 return response;
623 }
624
625 let needs_merge = response
627 .choices
628 .iter()
629 .all(|c| c.index.is_none() || c.index == Some(0));
630
631 if !needs_merge {
632 return response;
633 }
634
635 debug!(
636 choice_count = response.choices.len(),
637 model = %response.model,
638 "OODA-07.2: Normalizing Anthropic-style split choices"
639 );
640
641 let mut choices_iter = response.choices.into_iter();
643
644 let mut merged = choices_iter.next().unwrap();
646
647 for choice in choices_iter {
649 if merged.message.content.is_none()
651 || merged
652 .message
653 .content
654 .as_ref()
655 .map(|s| s.is_empty())
656 .unwrap_or(true)
657 {
658 merged.message.content = choice.message.content;
659 }
660
661 if merged.message.tool_calls.is_none() {
663 merged.message.tool_calls = choice.message.tool_calls;
664 } else if let Some(mut existing) = merged.message.tool_calls.take() {
665 if let Some(new_calls) = choice.message.tool_calls {
666 existing.extend(new_calls);
667 }
668 merged.message.tool_calls = Some(existing);
669 }
670
671 if merged.finish_reason.is_none() {
673 merged.finish_reason = choice.finish_reason;
674 }
675 }
676
677 merged.index = Some(0);
679
680 response.choices = vec![merged];
681 response
682 }
683
684 async fn handle_response<T: DeserializeOwned>(response: Response) -> Result<T> {
686 let status = response.status();
687
688 if status.is_success() {
689 let body_text = response
691 .text()
692 .await
693 .map_err(|e| VsCodeError::Decode(format!("Failed to read response body: {}", e)))?;
694
695 debug!(
697 status = %status,
698 body_length = body_text.len(),
699 body_preview = &body_text[..body_text.len().min(500)],
700 "Raw API response"
701 );
702
703 serde_json::from_str(&body_text).map_err(|e| {
705 error!(
706 error = %e,
707 body = %body_text,
708 "Failed to deserialize response"
709 );
710 VsCodeError::Decode(format!(
711 "Deserialization failed: {} | Body: {}",
712 e, body_text
713 ))
714 })
715 } else {
716 let error_body = response
717 .text()
718 .await
719 .unwrap_or_else(|_| "Unknown error".to_string());
720
721 warn!(
722 status = %status,
723 error = %error_body,
724 "Request failed"
725 );
726
727 Err(Self::map_error_status(status, error_body))
728 }
729 }
730
731 fn map_error_status(status: StatusCode, body: String) -> VsCodeError {
733 match status {
734 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
735 VsCodeError::Authentication(format!("Copilot authentication failed: {}", body))
736 }
737 StatusCode::TOO_MANY_REQUESTS => VsCodeError::RateLimited,
738 StatusCode::BAD_REQUEST => VsCodeError::InvalidRequest(body),
739 StatusCode::SERVICE_UNAVAILABLE => VsCodeError::ServiceUnavailable,
740 StatusCode::BAD_GATEWAY => {
741 VsCodeError::Network(format!("Upstream error (502): {}", body))
742 }
743 StatusCode::GATEWAY_TIMEOUT | StatusCode::REQUEST_TIMEOUT => {
744 VsCodeError::Network(format!("Timeout: {}", body))
745 }
746 _ => VsCodeError::ApiError(format!("HTTP {}: {}", status, body)),
747 }
748 }
749}
750
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports (Duration, the
    // `types::*` glob), so no per-test `use` lines are needed.
    use super::*;

    // ----- AccountType -------------------------------------------------------

    #[test]
    fn test_account_type_base_url_individual() {
        assert_eq!(
            AccountType::Individual.base_url(),
            "https://api.githubcopilot.com"
        );
    }

    #[test]
    fn test_account_type_base_url_business() {
        assert_eq!(
            AccountType::Business.base_url(),
            "https://api.business.githubcopilot.com"
        );
    }

    #[test]
    fn test_account_type_base_url_enterprise() {
        assert_eq!(
            AccountType::Enterprise.base_url(),
            "https://api.enterprise.githubcopilot.com"
        );
    }

    #[test]
    fn test_account_type_from_str_individual() {
        assert_eq!(
            AccountType::from_str("individual"),
            Some(AccountType::Individual)
        );
        // Parsing ignores letter case.
        assert_eq!(
            AccountType::from_str("INDIVIDUAL"),
            Some(AccountType::Individual)
        );
        assert_eq!(
            AccountType::from_str("Individual"),
            Some(AccountType::Individual)
        );
    }

    #[test]
    fn test_account_type_from_str_business() {
        assert_eq!(
            AccountType::from_str("business"),
            Some(AccountType::Business)
        );
        assert_eq!(
            AccountType::from_str("BUSINESS"),
            Some(AccountType::Business)
        );
    }

    #[test]
    fn test_account_type_from_str_enterprise() {
        assert_eq!(
            AccountType::from_str("enterprise"),
            Some(AccountType::Enterprise)
        );
        assert_eq!(
            AccountType::from_str("Enterprise"),
            Some(AccountType::Enterprise)
        );
    }

    #[test]
    fn test_account_type_from_str_unknown_returns_none() {
        assert_eq!(AccountType::from_str("unknown"), None);
        assert_eq!(AccountType::from_str(""), None);
        assert_eq!(AccountType::from_str("personal"), None);
        assert_eq!(AccountType::from_str("team"), None);
    }

    #[test]
    fn test_account_type_default_is_individual() {
        let default: AccountType = Default::default();
        assert_eq!(default, AccountType::Individual);
    }

    // ----- map_error_status --------------------------------------------------

    #[test]
    fn test_map_error_status_unauthorized() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::UNAUTHORIZED,
            "Invalid token".to_string(),
        );
        match err {
            VsCodeError::Authentication(msg) => {
                assert!(msg.contains("authentication failed"));
                assert!(msg.contains("Invalid token"));
            }
            other => panic!("Expected Authentication error, got {:?}", other),
        }
    }

    #[test]
    fn test_map_error_status_forbidden() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::FORBIDDEN,
            "Access denied".to_string(),
        );
        match err {
            VsCodeError::Authentication(msg) => {
                assert!(msg.contains("Access denied"));
            }
            other => panic!("Expected Authentication error, got {:?}", other),
        }
    }

    #[test]
    fn test_map_error_status_rate_limited() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::TOO_MANY_REQUESTS,
            "Rate limit exceeded".to_string(),
        );
        assert!(matches!(err, VsCodeError::RateLimited));
    }

    #[test]
    fn test_map_error_status_bad_request() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::BAD_REQUEST,
            "Invalid JSON".to_string(),
        );
        match err {
            VsCodeError::InvalidRequest(msg) => assert_eq!(msg, "Invalid JSON"),
            other => panic!("Expected InvalidRequest error, got {:?}", other),
        }
    }

    #[test]
    fn test_map_error_status_service_unavailable() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::SERVICE_UNAVAILABLE,
            "Maintenance".to_string(),
        );
        assert!(matches!(err, VsCodeError::ServiceUnavailable));
    }

    #[test]
    fn test_map_error_status_timeout() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::GATEWAY_TIMEOUT,
            "Upstream timeout".to_string(),
        );
        match err {
            VsCodeError::Network(msg) => {
                assert!(msg.contains("Timeout"));
                assert!(msg.contains("Upstream timeout"));
            }
            other => panic!("Expected Network error, got {:?}", other),
        }
    }

    #[test]
    fn test_map_error_status_request_timeout() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::REQUEST_TIMEOUT,
            "Request took too long".to_string(),
        );
        match err {
            VsCodeError::Network(msg) => assert!(msg.contains("Timeout")),
            other => panic!("Expected Network error, got {:?}", other),
        }
    }

    #[test]
    fn test_map_error_status_internal_server_error() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::INTERNAL_SERVER_ERROR,
            "Something went wrong".to_string(),
        );
        match err {
            VsCodeError::ApiError(msg) => {
                assert!(msg.contains("500"));
                assert!(msg.contains("Something went wrong"));
            }
            other => panic!("Expected ApiError, got {:?}", other),
        }
    }

    #[test]
    fn test_map_error_status_not_found() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::NOT_FOUND,
            "Endpoint not found".to_string(),
        );
        match err {
            VsCodeError::ApiError(msg) => {
                assert!(msg.contains("404"));
                assert!(msg.contains("not found"));
            }
            other => panic!("Expected ApiError, got {:?}", other),
        }
    }

    #[test]
    fn test_map_error_status_bad_gateway() {
        let err = VsCodeCopilotClient::map_error_status(
            StatusCode::BAD_GATEWAY,
            "Upstream server error".to_string(),
        );
        match err {
            VsCodeError::Network(msg) => {
                assert!(msg.contains("Upstream") || msg.contains("502"));
            }
            other => panic!("Expected Network, got {:?}", other),
        }
    }

    // ----- header constants --------------------------------------------------

    #[test]
    fn test_header_constants_match_typescript() {
        assert_eq!(COPILOT_API_VERSION, "2025-04-01");
        assert_eq!(EDITOR_VERSION, "vscode/1.95.0");
        assert!(EDITOR_PLUGIN_VERSION.contains("copilot"));
        assert!(USER_AGENT.contains("Copilot"));
    }

    #[test]
    fn test_api_version_format() {
        // YYYY-MM-DD
        assert_eq!(COPILOT_API_VERSION.len(), 10);
        assert!(COPILOT_API_VERSION.starts_with("2025"));
    }

    #[test]
    fn test_editor_version_format() {
        assert!(EDITOR_VERSION.starts_with("vscode/"));
    }

    #[test]
    fn test_user_agent_contains_copilot() {
        assert!(USER_AGENT.contains("Copilot"));
    }

    // ----- is_agent_call -----------------------------------------------------

    /// Builds a plain-text request message with the given role.
    fn make_message(role: &str) -> RequestMessage {
        RequestMessage {
            role: role.to_string(),
            content: Some(RequestContent::Text("test".to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }
    }

    #[test]
    fn test_is_agent_call_empty_messages() {
        let messages: Vec<RequestMessage> = vec![];
        assert!(!VsCodeCopilotClient::is_agent_call(&messages));
    }

    #[test]
    fn test_is_agent_call_user_only() {
        let messages = vec![make_message("user")];
        assert!(!VsCodeCopilotClient::is_agent_call(&messages));
    }

    #[test]
    fn test_is_agent_call_system_and_user() {
        let messages = vec![make_message("system"), make_message("user")];
        assert!(!VsCodeCopilotClient::is_agent_call(&messages));
    }

    #[test]
    fn test_is_agent_call_with_assistant() {
        let messages = vec![
            make_message("system"),
            make_message("user"),
            make_message("assistant"),
            make_message("user"),
        ];
        assert!(VsCodeCopilotClient::is_agent_call(&messages));
    }

    #[test]
    fn test_is_agent_call_with_tool() {
        let messages = vec![
            make_message("system"),
            make_message("user"),
            make_message("assistant"),
            make_message("tool"),
            make_message("user"),
        ];
        assert!(VsCodeCopilotClient::is_agent_call(&messages));
    }

    #[test]
    fn test_is_agent_call_assistant_only() {
        let messages = vec![make_message("assistant")];
        assert!(VsCodeCopilotClient::is_agent_call(&messages));
    }

    #[test]
    fn test_is_agent_call_tool_only() {
        let messages = vec![make_message("tool")];
        assert!(VsCodeCopilotClient::is_agent_call(&messages));
    }

    #[test]
    fn test_is_agent_call_developer_role() {
        let messages = vec![make_message("developer"), make_message("user")];
        assert!(!VsCodeCopilotClient::is_agent_call(&messages));
    }

    // ----- vision flag -------------------------------------------------------

    #[test]
    fn test_client_vision_disabled_by_default() {
        let client = VsCodeCopilotClient::new(Duration::from_secs(30));
        assert!(client.is_ok(), "Client should be created successfully");

        let client = client.unwrap();
        assert!(
            !client.vision_enabled,
            "Vision must stay off until explicitly enabled"
        );

        let client = client.with_vision(false);
        assert!(!client.vision_enabled);
    }

    #[test]
    fn test_client_with_vision_enables_mode() {
        let client = VsCodeCopilotClient::new(Duration::from_secs(30))
            .unwrap()
            .with_vision(true);

        assert!(client.vision_enabled, "with_vision(true) should enable vision");
    }

    #[test]
    fn test_client_with_vision_chain() {
        let client = VsCodeCopilotClient::new(Duration::from_secs(30))
            .unwrap()
            .with_vision(true)
            .with_vision(false)
            .with_vision(true);

        assert!(client.vision_enabled, "The last with_vision call should win");
    }

    #[test]
    fn test_client_with_base_url_vision() {
        let client =
            VsCodeCopilotClient::with_base_url("http://localhost:4141", Duration::from_secs(30))
                .unwrap()
                .with_vision(true);

        assert!(client.vision_enabled);
        assert!(!client.direct_mode, "A localhost URL should be proxy mode");
    }

    #[test]
    fn test_client_with_options_vision() {
        for account_type in [
            AccountType::Individual,
            AccountType::Business,
            AccountType::Enterprise,
        ] {
            let client = VsCodeCopilotClient::new_with_options(
                Duration::from_secs(30),
                true,
                account_type,
            )
            .unwrap()
            .with_vision(true);

            assert!(client.vision_enabled);
            assert!(client.direct_mode);
            assert_eq!(client.base_url, account_type.base_url());
        }
    }

    // ----- embeddings --------------------------------------------------------

    #[test]
    fn test_embedding_url_direct_mode() {
        let client =
            VsCodeCopilotClient::new(Duration::from_secs(30)).expect("Failed to create client");

        assert!(client.direct_mode, "Default should be direct mode");

        let base_url = &client.base_url;
        assert!(
            !base_url.ends_with("/v1"),
            "Direct mode base URL should not end with /v1"
        );
        assert!(
            base_url.starts_with("https://api"),
            "Direct mode should use GitHub API: {}",
            base_url
        );
    }

    #[test]
    fn test_embedding_url_proxy_mode() {
        let client =
            VsCodeCopilotClient::with_base_url("http://localhost:1337", Duration::from_secs(30))
                .expect("Failed to create proxy client");

        assert!(!client.direct_mode, "Should be proxy mode");
        assert_eq!(
            client.base_url, "http://localhost:1337",
            "Proxy base URL should be preserved"
        );
    }

    #[test]
    fn test_embedding_single_input_format() {
        let input = EmbeddingInput::Single("Hello, world!".to_string());
        let request = EmbeddingRequest::new(input, "text-embedding-3-small");

        let json = serde_json::to_value(&request).expect("Failed to serialize");

        assert_eq!(
            json["input"],
            serde_json::json!("Hello, world!"),
            "Single input should serialize as string"
        );
        assert_eq!(json["model"], "text-embedding-3-small");
    }

    #[test]
    fn test_embedding_multiple_inputs_format() {
        let inputs = vec![
            "First".to_string(),
            "Second".to_string(),
            "Third".to_string(),
        ];
        let input = EmbeddingInput::Multiple(inputs);
        let request = EmbeddingRequest::new(input, "text-embedding-3-small");

        let json = serde_json::to_value(&request).expect("Failed to serialize");

        assert_eq!(
            json["input"],
            serde_json::json!(["First", "Second", "Third"]),
            "Multiple inputs should serialize as array"
        );
    }

    #[test]
    fn test_embedding_model_in_request() {
        let request = EmbeddingRequest::new("test", "text-embedding-ada-002");
        let json = serde_json::to_value(&request).expect("Failed to serialize");
        assert_eq!(
            json["model"], "text-embedding-ada-002",
            "Model name should be in request"
        );

        let request2 = EmbeddingRequest::new("test", "text-embedding-3-large");
        let json2 = serde_json::to_value(&request2).expect("Failed to serialize");
        assert_eq!(json2["model"], "text-embedding-3-large");
    }

    // ----- list_models -------------------------------------------------------

    #[test]
    fn test_list_models_url_direct_mode() {
        let client =
            VsCodeCopilotClient::new(Duration::from_secs(30)).expect("Failed to create client");

        assert!(client.direct_mode, "Default should be direct mode");

        let base_url = &client.base_url;
        assert!(
            !base_url.ends_with("/v1"),
            "Direct mode base URL should not end with /v1"
        );
        assert!(
            base_url.starts_with("https://api"),
            "Direct mode should use GitHub API: {}",
            base_url
        );
    }

    #[test]
    fn test_list_models_url_proxy_mode() {
        let client =
            VsCodeCopilotClient::with_base_url("http://localhost:1337", Duration::from_secs(30))
                .expect("Failed to create proxy client");

        assert!(!client.direct_mode, "Should be proxy mode");
        assert_eq!(
            client.base_url, "http://localhost:1337",
            "Proxy base URL should be preserved"
        );
    }

    // ----- timeouts ----------------------------------------------------------

    #[test]
    fn test_client_timeout_short() {
        let client = VsCodeCopilotClient::new(Duration::from_secs(5));
        assert!(client.is_ok(), "Client should accept short timeout");
    }

    #[test]
    fn test_client_timeout_long() {
        let client = VsCodeCopilotClient::new(Duration::from_secs(300));
        assert!(client.is_ok(), "Client should accept long timeout");
    }

    // ----- chat URLs ---------------------------------------------------------

    #[test]
    fn test_chat_url_direct_mode() {
        let url = format!("{}/chat/completions", AccountType::Individual.base_url());
        assert_eq!(
            url, "https://api.githubcopilot.com/chat/completions",
            "Individual account chat URL should use main Copilot API"
        );
    }

    #[test]
    fn test_chat_url_business_mode() {
        let url = format!("{}/chat/completions", AccountType::Business.base_url());
        assert_eq!(
            url, "https://api.business.githubcopilot.com/chat/completions",
            "Business account chat URL should use business subdomain"
        );
    }

    #[test]
    fn test_chat_url_enterprise_mode() {
        let url = format!("{}/chat/completions", AccountType::Enterprise.base_url());
        assert_eq!(
            url, "https://api.enterprise.githubcopilot.com/chat/completions",
            "Enterprise account chat URL should use enterprise subdomain"
        );
    }

    #[test]
    fn test_chat_url_proxy_mode() {
        let proxy_url = "http://localhost:1337";
        let client = VsCodeCopilotClient::with_base_url(proxy_url, Duration::from_secs(30))
            .expect("Failed to create proxy client");
        assert!(!client.direct_mode, "Should be proxy mode");

        // Proxy mode targets the OpenAI-compatible `/v1` namespace.
        let url = format!("{}/v1/chat/completions", client.base_url);
        assert_eq!(
            url, "http://localhost:1337/v1/chat/completions",
            "Proxy mode chat URL should add /v1 to the configured proxy base"
        );
    }

    // ----- normalize_choices -------------------------------------------------

    #[test]
    fn test_normalize_choices_single_choice() {
        let response = ChatCompletionResponse {
            id: "test1".to_string(),
            object: None,
            created: None,
            model: "gpt-4.1".to_string(),
            choices: vec![Choice {
                index: Some(0),
                message: ResponseMessage {
                    role: "assistant".to_string(),
                    content: Some("Hello".to_string()),
                    tool_calls: None,
                    extra: None,
                },
                finish_reason: Some("stop".to_string()),
                extra: None,
            }],
            usage: None,
            extra: None,
        };

        let normalized = VsCodeCopilotClient::normalize_choices(response.clone());

        assert_eq!(normalized.choices.len(), 1);
        assert_eq!(normalized.choices[0].index, Some(0));
        assert_eq!(
            normalized.choices[0].message.content,
            Some("Hello".to_string())
        );
    }

    #[test]
    fn test_normalize_choices_anthropic_split() {
        // Text and tool call arrive as two index-less choices and must be merged.
        let response = ChatCompletionResponse {
            id: "msg_haiku".to_string(),
            object: None,
            created: Some(1768984171),
            model: "claude-haiku-4.5".to_string(),
            choices: vec![
                Choice {
                    index: None,
                    message: ResponseMessage {
                        role: "assistant".to_string(),
                        content: Some("I'll examine the file".to_string()),
                        tool_calls: None,
                        extra: None,
                    },
                    finish_reason: Some("tool_calls".to_string()),
                    extra: None,
                },
                Choice {
                    index: None,
                    message: ResponseMessage {
                        role: "assistant".to_string(),
                        content: None,
                        tool_calls: Some(vec![ResponseToolCall {
                            id: "toolu_123".to_string(),
                            function: ResponseFunctionCall {
                                name: "read_file".to_string(),
                                arguments: "{\"path\":\"test.js\"}".to_string(),
                            },
                            call_type: "function".to_string(),
                        }]),
                        extra: None,
                    },
                    finish_reason: Some("tool_calls".to_string()),
                    extra: None,
                },
            ],
            usage: Some(Usage {
                prompt_tokens: 100,
                completion_tokens: 50,
                total_tokens: 150,
                prompt_tokens_details: None,
                extra: None,
            }),
            extra: None,
        };

        let normalized = VsCodeCopilotClient::normalize_choices(response);

        assert_eq!(normalized.choices.len(), 1);

        let choice = &normalized.choices[0];
        assert_eq!(choice.index, Some(0));
        assert_eq!(
            choice.message.content,
            Some("I'll examine the file".to_string())
        );

        let tool_calls = choice
            .message
            .tool_calls
            .as_ref()
            .expect("merged choice should carry the tool call");
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].function.name, "read_file");
    }

    #[test]
    fn test_normalize_choices_no_merge_with_indices() {
        // Distinct indices mean genuine alternatives; they must not be merged.
        let response = ChatCompletionResponse {
            id: "test_multiple".to_string(),
            object: None,
            created: None,
            model: "gpt-4.1".to_string(),
            choices: vec![
                Choice {
                    index: Some(0),
                    message: ResponseMessage {
                        role: "assistant".to_string(),
                        content: Some("First".to_string()),
                        tool_calls: None,
                        extra: None,
                    },
                    finish_reason: Some("stop".to_string()),
                    extra: None,
                },
                Choice {
                    index: Some(1),
                    message: ResponseMessage {
                        role: "assistant".to_string(),
                        content: Some("Second".to_string()),
                        tool_calls: None,
                        extra: None,
                    },
                    finish_reason: Some("stop".to_string()),
                    extra: None,
                },
            ],
            usage: None,
            extra: None,
        };

        let normalized = VsCodeCopilotClient::normalize_choices(response.clone());

        assert_eq!(normalized.choices.len(), 2);
        assert_eq!(normalized.choices[0].index, Some(0));
        assert_eq!(normalized.choices[1].index, Some(1));
    }
}
1586}