1pub mod auth;
49mod client;
50mod error;
51mod stream;
52pub mod token;
53pub mod types;
54
55use async_trait::async_trait;
56use futures::stream::{BoxStream, StreamExt};
57use std::time::Duration;
58use tracing::debug;
59
60pub use client::{AccountType, VsCodeCopilotClient};
61pub use error::{Result, VsCodeError};
62pub use types::{Model, ModelsResponse};
63
64use crate::error::Result as LlmResult;
65use crate::traits::{
66 ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, FunctionCall, LLMProvider,
67 LLMResponse, StreamChunk, ToolCall, ToolChoice, ToolDefinition,
68};
69use types::{
70 ChatCompletionRequest, ContentPart, EmbeddingInput, EmbeddingRequest, ImageUrlContent,
71 RequestContent, RequestFunction, RequestMessage, RequestTool, ResponseFormat,
72};
73
/// LLM provider backed by the VS Code GitHub Copilot API.
///
/// Implements both `LLMProvider` (chat / completions / streaming) and
/// `EmbeddingProvider` over a shared [`VsCodeCopilotClient`].
#[derive(Clone)]
pub struct VsCodeCopilotProvider {
    // HTTP client used for all chat, streaming, and embedding requests.
    client: VsCodeCopilotClient,

    // Chat model identifier (e.g. "gpt-4o-mini").
    model: String,

    // Maximum context window, derived from the model name at build time.
    max_context_length: usize,

    // Whether vision (image input) was requested at build time.
    // NOTE(review): only read in tests here; the runtime flag is applied to
    // the client in the builder (direct mode only) — hence dead_code.
    #[allow(dead_code)]
    supports_vision: bool,

    // Embedding model identifier (e.g. "text-embedding-3-small").
    embedding_model: String,

    // Output dimension of `embedding_model`, derived from its name.
    embedding_dimension: usize,
}
98
99impl VsCodeCopilotProvider {
100 #[allow(clippy::new_ret_no_self)]
102 pub fn new() -> VsCodeCopilotProviderBuilder {
103 VsCodeCopilotProviderBuilder::default()
104 }
105
106 pub fn with_proxy(proxy_url: impl Into<String>) -> VsCodeCopilotProviderBuilder {
108 VsCodeCopilotProviderBuilder::new().proxy_url(proxy_url)
109 }
110
111 pub fn get_client(&self) -> &VsCodeCopilotClient {
113 &self.client
114 }
115
116 pub async fn list_models(&self) -> Result<types::ModelsResponse> {
123 self.client.list_models().await
124 }
125
126 fn convert_messages(messages: &[ChatMessage]) -> Vec<RequestMessage> {
146 messages
147 .iter()
148 .map(|msg| {
149 let tool_calls = msg.tool_calls.as_ref().map(|calls| {
151 calls
152 .iter()
153 .map(|tc| types::ResponseToolCall {
154 id: tc.id.clone(),
155 call_type: "function".to_string(),
156 function: types::ResponseFunctionCall {
157 name: tc.name().to_string(),
158 arguments: tc.arguments().to_string(),
159 },
160 })
161 .collect()
162 });
163
164 let cache_control =
166 msg.cache_control
167 .as_ref()
168 .map(|cc| types::RequestCacheControl {
169 cache_type: cc.cache_type.clone(),
170 });
171
172 let content = if msg.content.is_empty() && tool_calls.is_some() {
175 None } else if msg.has_images() {
177 let mut parts: Vec<ContentPart> = Vec::new();
179
180 if !msg.content.is_empty() {
182 parts.push(ContentPart::Text {
183 text: msg.content.clone(),
184 });
185 }
186
187 if let Some(images) = &msg.images {
189 for img in images {
190 let data_uri = format!("data:{};base64,{}", img.mime_type, img.data);
192 parts.push(ContentPart::ImageUrl {
193 image_url: ImageUrlContent {
194 url: data_uri,
195 detail: img.detail.clone(),
196 },
197 });
198 }
199 }
200
201 Some(RequestContent::Parts(parts))
202 } else {
203 Some(RequestContent::Text(msg.content.clone()))
205 };
206
207 RequestMessage {
208 role: match msg.role {
209 ChatRole::System => "system".to_string(),
210 ChatRole::User => "user".to_string(),
211 ChatRole::Assistant => "assistant".to_string(),
212 ChatRole::Tool => "tool".to_string(),
213 ChatRole::Function => "tool".to_string(),
214 },
215 content,
216 name: msg.name.clone(),
217 tool_calls,
218 tool_call_id: msg.tool_call_id.clone(),
219 cache_control,
220 }
221 })
222 .collect()
223 }
224
225 fn convert_tools(tools: &[ToolDefinition]) -> Vec<RequestTool> {
227 tools
228 .iter()
229 .map(|tool| RequestTool {
230 tool_type: "function".to_string(),
231 function: RequestFunction {
232 name: tool.function.name.clone(),
233 description: tool.function.description.clone(),
234 parameters: tool.function.parameters.clone(),
235 strict: tool.function.strict,
236 },
237 })
238 .collect()
239 }
240
241 fn convert_tool_choice(choice: Option<ToolChoice>) -> Option<serde_json::Value> {
243 choice.map(|c| match c {
244 ToolChoice::Auto(s) | ToolChoice::Required(s) => serde_json::Value::String(s),
245 ToolChoice::Function { function, .. } => {
246 serde_json::json!({
247 "type": "function",
248 "function": {
249 "name": function.name
250 }
251 })
252 }
253 })
254 }
255
256 fn convert_response_tool_calls(calls: Option<Vec<types::ResponseToolCall>>) -> Vec<ToolCall> {
258 calls
259 .unwrap_or_default()
260 .into_iter()
261 .map(|tc| ToolCall {
262 id: tc.id,
263 call_type: tc.call_type,
264 function: FunctionCall {
265 name: tc.function.name,
266 arguments: tc.function.arguments,
267 },
268 })
269 .collect()
270 }
271}
272
273impl Default for VsCodeCopilotProvider {
274 fn default() -> Self {
275 Self::new()
276 .build()
277 .expect("Failed to build default VsCodeCopilotProvider")
278 }
279}
280
/// Builder for [`VsCodeCopilotProvider`].
#[derive(Clone)]
pub struct VsCodeCopilotProviderBuilder {
    // Explicit proxy base URL; `None` means direct mode or env-configured proxy.
    base_url: Option<String>,
    // Chat model name.
    model: String,
    // Context window for `model`; kept in sync by `model()`.
    max_context_length: usize,
    // Whether to enable vision (image input) on the client.
    supports_vision: bool,
    // HTTP request timeout.
    timeout: Duration,
    // When true, talk to the Copilot API directly instead of via a proxy.
    direct_mode: bool,
    // Copilot account tier; selects the API base URL in direct mode.
    account_type: client::AccountType,
    // Embedding model name.
    embedding_model: String,
    // Output dimension for `embedding_model`; kept in sync by `embedding_model()`.
    embedding_dimension: usize,
}
322
323impl Default for VsCodeCopilotProviderBuilder {
324 fn default() -> Self {
325 let direct_mode = std::env::var("VSCODE_COPILOT_DIRECT")
327 .map(|v| v.to_lowercase() != "false" && v != "0")
328 .unwrap_or(true); let account_type = std::env::var("VSCODE_COPILOT_ACCOUNT_TYPE")
332 .ok()
333 .and_then(|s| client::AccountType::from_str(&s))
334 .unwrap_or_default();
335
336 let embedding_model = std::env::var("VSCODE_COPILOT_EMBEDDING_MODEL")
338 .unwrap_or_else(|_| "text-embedding-3-small".to_string());
339
340 let embedding_dimension = Self::dimension_for_embedding_model(&embedding_model);
342
343 Self {
344 base_url: None,
345 model: "gpt-4o-mini".to_string(),
346 max_context_length: 128_000,
347 supports_vision: false,
348 timeout: Duration::from_secs(120),
349 direct_mode,
350 account_type,
351 embedding_model,
352 embedding_dimension,
353 }
354 }
355}
356
357impl VsCodeCopilotProviderBuilder {
358 pub fn new() -> Self {
360 Self::default()
361 }
362
363 pub fn proxy_url(mut self, url: impl Into<String>) -> Self {
367 self.base_url = Some(url.into());
368 self.direct_mode = false;
369 self
370 }
371
372 pub fn direct(mut self) -> Self {
376 self.direct_mode = true;
377 self.base_url = None;
378 self
379 }
380
381 pub fn account_type(mut self, account_type: client::AccountType) -> Self {
387 self.account_type = account_type;
388 self
389 }
390
391 pub fn model(mut self, model: impl Into<String>) -> Self {
393 let model_str = model.into();
394 self.max_context_length = Self::context_length_for_model(&model_str);
395
396 if model_str.contains("grok") {
398 self.timeout = Duration::from_secs(300); }
400
401 self.model = model_str;
402 self
403 }
404
405 pub fn embedding_model(mut self, model: impl Into<String>) -> Self {
412 let model_str = model.into();
413 self.embedding_dimension = Self::dimension_for_embedding_model(&model_str);
414 self.embedding_model = model_str;
415 self
416 }
417
418 pub fn with_vision(mut self, enabled: bool) -> Self {
420 self.supports_vision = enabled;
421 self
422 }
423
424 pub fn timeout(mut self, duration: Duration) -> Self {
426 self.timeout = duration;
427 self
428 }
429
430 pub fn build(self) -> Result<VsCodeCopilotProvider> {
432 let client = if let Some(url) = &self.base_url {
433 VsCodeCopilotClient::with_base_url(url, self.timeout)?
435 } else if self.direct_mode {
436 VsCodeCopilotClient::new_with_options(self.timeout, true, self.account_type)?
438 .with_vision(self.supports_vision)
439 } else {
440 let proxy_url = std::env::var("VSCODE_COPILOT_PROXY_URL")
442 .unwrap_or_else(|_| "http://localhost:4141".to_string());
443 VsCodeCopilotClient::with_base_url(&proxy_url, self.timeout)?
444 };
445
446 let mode_str = if self.direct_mode { "direct" } else { "proxy" };
447
448 debug!(
449 model = %self.model,
450 max_context = self.max_context_length,
451 mode = mode_str,
452 account_type = ?self.account_type,
453 embedding_model = %self.embedding_model,
454 "Built VsCodeCopilotProvider"
455 );
456
457 Ok(VsCodeCopilotProvider {
458 client,
459 model: self.model,
460 max_context_length: self.max_context_length,
461 supports_vision: self.supports_vision,
462 embedding_model: self.embedding_model,
463 embedding_dimension: self.embedding_dimension,
464 })
465 }
466
467 fn context_length_for_model(model: &str) -> usize {
469 match model {
470 m if m.contains("grok") => 131_072, m if m.contains("gpt-4o") => 128_000,
472 m if m.contains("gpt-4-turbo") => 128_000,
473 m if m.contains("gpt-4-32k") => 32_768,
474 m if m.contains("gpt-4") => 8_192,
475 m if m.contains("gpt-3.5-turbo-16k") => 16_384,
476 m if m.contains("gpt-3.5") => 4_096,
477 m if m.contains("o1") || m.contains("o3") => 200_000,
478 _ => 128_000, }
480 }
481
482 fn dimension_for_embedding_model(model: &str) -> usize {
484 match model {
485 m if m.contains("text-embedding-3-large") => 3072,
486 m if m.contains("text-embedding-3-small") => 1536,
487 m if m.contains("text-embedding-ada") => 1536,
488 m if m.contains("copilot-text-embedding") => 1536,
489 _ => 1536, }
491 }
492}
493
#[async_trait]
impl LLMProvider for VsCodeCopilotProvider {
    /// Provider identifier.
    fn name(&self) -> &str {
        "vscode-copilot"
    }

    /// Chat model this provider was built with.
    fn model(&self) -> &str {
        &self.model
    }

    /// Context window derived from the model name at build time.
    fn max_context_length(&self) -> usize {
        self.max_context_length
    }

    /// One-shot completion with default options.
    async fn complete(&self, prompt: &str) -> LlmResult<LLMResponse> {
        self.complete_with_options(prompt, &CompletionOptions::default())
            .await
    }

    /// One-shot completion: wraps the prompt (plus optional system prompt)
    /// into chat messages and delegates to `chat`.
    async fn complete_with_options(
        &self,
        prompt: &str,
        options: &CompletionOptions,
    ) -> LlmResult<LLMResponse> {
        let mut messages = Vec::new();

        if let Some(system) = &options.system_prompt {
            messages.push(ChatMessage::system(system));
        }
        messages.push(ChatMessage::user(prompt));

        self.chat(&messages, Some(options)).await
    }

    /// Non-streaming chat completion without tools.
    ///
    /// # Errors
    /// Returns an `ApiError` if the response contains no choices, plus any
    /// transport/client errors from the underlying request.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: Option<&CompletionOptions>,
    ) -> LlmResult<LLMResponse> {
        let request_messages = Self::convert_messages(messages);

        let opts = options.cloned().unwrap_or_default();
        let request = ChatCompletionRequest {
            messages: request_messages,
            model: self.model.clone(),
            temperature: opts.temperature,
            top_p: opts.top_p,
            max_tokens: opts.max_tokens,
            stop: opts.stop,
            stream: Some(false),
            frequency_penalty: opts.frequency_penalty,
            presence_penalty: opts.presence_penalty,
            response_format: opts
                .response_format
                .map(|fmt| ResponseFormat { format_type: fmt }),
            tools: None,
            tool_choice: None,
            parallel_tool_calls: None,
        };

        debug!(
            model = %self.model,
            message_count = messages.len(),
            "Sending chat request"
        );

        let response = self.client.chat_completion(request).await?;

        // Only the first choice is surfaced to the caller.
        let choice = response
            .choices
            .first()
            .ok_or_else(|| crate::error::LlmError::ApiError("No choices in response".into()))?;

        let content = choice.message.content.clone().unwrap_or_default();

        // Some backends omit usage; fall back to all-zero accounting.
        let usage = response.usage.unwrap_or(types::Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            extra: None,
        });

        debug!(
            prompt_tokens = usage.prompt_tokens,
            completion_tokens = usage.completion_tokens,
            "Chat request completed"
        );

        // Tool calls can appear even on a plain chat request.
        let tool_calls = Self::convert_response_tool_calls(choice.message.tool_calls.clone());

        // Cached-prompt token count, when the backend reports it.
        let cache_hit_tokens = usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|d| d.cached_tokens);

        let mut response_builder = LLMResponse::new(content, response.model.clone())
            .with_usage(usage.prompt_tokens, usage.completion_tokens)
            .with_finish_reason(choice.finish_reason.clone().unwrap_or_default())
            .with_tool_calls(tool_calls)
            .with_metadata("id", serde_json::json!(response.id));

        if let Some(cached) = cache_hit_tokens {
            response_builder = response_builder.with_cache_hit_tokens(cached);
        }

        Ok(response_builder)
    }

    /// Non-streaming chat completion with tool/function calling enabled.
    ///
    /// Empty `tools` sends no tool list; parallel tool calls are requested.
    ///
    /// # Errors
    /// Returns an `ApiError` if the response contains no choices, plus any
    /// transport/client errors from the underlying request.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> LlmResult<LLMResponse> {
        let request_messages = Self::convert_messages(messages);

        let request_tools = if tools.is_empty() {
            None
        } else {
            Some(Self::convert_tools(tools))
        };

        let request_tool_choice = Self::convert_tool_choice(tool_choice);

        let opts = options.cloned().unwrap_or_default();
        let request = ChatCompletionRequest {
            messages: request_messages,
            model: self.model.clone(),
            temperature: opts.temperature,
            top_p: opts.top_p,
            max_tokens: opts.max_tokens,
            stop: opts.stop,
            stream: Some(false),
            frequency_penalty: opts.frequency_penalty,
            presence_penalty: opts.presence_penalty,
            response_format: opts
                .response_format
                .map(|fmt| ResponseFormat { format_type: fmt }),
            tools: request_tools,
            tool_choice: request_tool_choice,
            parallel_tool_calls: Some(true),
        };

        debug!(
            model = %self.model,
            message_count = messages.len(),
            tool_count = tools.len(),
            "Sending chat request with tools"
        );

        let response = self.client.chat_completion(request).await?;

        // Only the first choice is surfaced to the caller.
        let choice = response
            .choices
            .first()
            .ok_or_else(|| crate::error::LlmError::ApiError("No choices in response".into()))?;

        let content = choice.message.content.clone().unwrap_or_default();
        let tool_calls = Self::convert_response_tool_calls(choice.message.tool_calls.clone());

        // Some backends omit usage; fall back to all-zero accounting.
        let usage = response.usage.unwrap_or(types::Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            extra: None,
        });

        debug!(
            prompt_tokens = usage.prompt_tokens,
            completion_tokens = usage.completion_tokens,
            tool_call_count = tool_calls.len(),
            "Chat with tools request completed"
        );

        // Cached-prompt token count, when the backend reports it.
        let cache_hit_tokens = usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|d| d.cached_tokens);

        let mut response_builder = LLMResponse::new(content, response.model.clone())
            .with_usage(usage.prompt_tokens, usage.completion_tokens)
            .with_finish_reason(choice.finish_reason.clone().unwrap_or_default())
            .with_tool_calls(tool_calls)
            .with_metadata("id", serde_json::json!(response.id));

        if let Some(cached) = cache_hit_tokens {
            response_builder = response_builder.with_cache_hit_tokens(cached);
        }

        Ok(response_builder)
    }

    /// Streaming completion for a single user prompt; yields text deltas.
    async fn stream(&self, prompt: &str) -> LlmResult<BoxStream<'static, LlmResult<String>>> {
        let request_messages = vec![RequestMessage {
            role: "user".to_string(),
            content: Some(RequestContent::Text(prompt.to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }];

        let request = ChatCompletionRequest {
            messages: request_messages,
            model: self.model.clone(),
            stream: Some(true),
            ..Default::default()
        };

        debug!(model = %self.model, "Sending streaming request");

        let response = self.client.chat_completion_stream(request).await?;
        // Decode the SSE byte stream into text chunks.
        let stream = stream::parse_sse_stream(response);

        // Lift stream errors into the crate-level error type.
        let mapped = stream.map(|result| result.map_err(|e| e.into()));

        Ok(Box::pin(mapped))
    }

    fn supports_streaming(&self) -> bool {
        true
    }

    fn supports_json_mode(&self) -> bool {
        true
    }

    fn supports_function_calling(&self) -> bool {
        true
    }

    fn supports_tool_streaming(&self) -> bool {
        true
    }

    /// Streaming chat with tools; yields `StreamChunk`s that may carry text
    /// deltas or incremental tool-call fragments.
    async fn chat_with_tools_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> LlmResult<BoxStream<'static, LlmResult<StreamChunk>>> {
        let request_messages = Self::convert_messages(messages);

        let request_tools = if tools.is_empty() {
            None
        } else {
            Some(Self::convert_tools(tools))
        };

        let request_tool_choice = Self::convert_tool_choice(tool_choice);

        let opts = options.cloned().unwrap_or_default();
        let request = ChatCompletionRequest {
            messages: request_messages,
            model: self.model.clone(),
            temperature: opts.temperature,
            top_p: opts.top_p,
            max_tokens: opts.max_tokens,
            stop: opts.stop,
            stream: Some(true), frequency_penalty: opts.frequency_penalty,
            presence_penalty: opts.presence_penalty,
            response_format: opts
                .response_format
                .map(|fmt| ResponseFormat { format_type: fmt }),
            tools: request_tools,
            tool_choice: request_tool_choice,
            parallel_tool_calls: Some(true),
        };

        debug!(
            model = %self.model,
            message_count = messages.len(),
            tool_count = tools.len(),
            "Sending streaming chat request with tools (OODA-05)"
        );

        let response = self.client.chat_completion_stream(request).await?;

        // Tool-aware SSE decoder: reassembles tool-call fragments too.
        let stream = stream::parse_sse_stream_with_tools(response);

        // Lift stream errors into the crate-level error type.
        let mapped = stream.map(|result| result.map_err(|e| e.into()));

        Ok(Box::pin(mapped))
    }
}
819
820#[async_trait]
821impl EmbeddingProvider for VsCodeCopilotProvider {
822 fn name(&self) -> &str {
823 "vscode-copilot"
824 }
825
826 #[allow(clippy::misnamed_getters)]
827 fn model(&self) -> &str {
828 &self.embedding_model
831 }
832
833 fn dimension(&self) -> usize {
834 self.embedding_dimension
835 }
836
837 fn max_tokens(&self) -> usize {
838 8192 }
840
841 async fn embed(&self, texts: &[String]) -> LlmResult<Vec<Vec<f32>>> {
842 let input = if texts.len() == 1 {
843 EmbeddingInput::Single(texts[0].clone())
844 } else {
845 EmbeddingInput::Multiple(texts.to_vec())
846 };
847
848 let request = EmbeddingRequest::new(input, &self.embedding_model);
849
850 debug!(
851 model = %self.embedding_model,
852 input_count = texts.len(),
853 "Sending embedding request"
854 );
855
856 let response = self.client.create_embeddings(request).await?;
857
858 debug!(
859 prompt_tokens = response.usage.prompt_tokens,
860 total_tokens = response.usage.total_tokens,
861 embedding_count = response.data.len(),
862 "Embedding request completed"
863 );
864
865 let embeddings: Vec<Vec<f32>> = response
867 .data
868 .into_iter()
869 .map(|e| (e.index, e.embedding))
870 .collect::<Vec<_>>()
871 .into_iter()
872 .map(|(_, e)| e)
873 .collect();
874
875 if embeddings.len() != texts.len() {
877 return Err(crate::error::LlmError::ApiError(format!(
878 "Expected {} embeddings, got {}",
879 texts.len(),
880 embeddings.len()
881 )));
882 }
883
884 Ok(embeddings)
885 }
886}
887
#[cfg(test)]
mod tests {
    use super::*;
    use types::{ResponseFunctionCall, ResponseToolCall};

    // ---- tool definition conversion ----

    #[test]
    fn test_convert_single_tool() {
        let tools = vec![ToolDefinition::function(
            "read_file",
            "Read contents of a file",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "path": {"type": "string"}
                },
                "required": ["path"]
            }),
        )];

        let converted = VsCodeCopilotProvider::convert_tools(&tools);

        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].tool_type, "function");
        assert_eq!(converted[0].function.name, "read_file");
        assert_eq!(converted[0].function.description, "Read contents of a file");
        // ToolDefinition::function populates `strict`; conversion must keep it.
        assert!(converted[0].function.strict.is_some());
    }

    #[test]
    fn test_convert_multiple_tools() {
        let tools = vec![
            ToolDefinition::function("tool_a", "First tool", serde_json::json!({})),
            ToolDefinition::function("tool_b", "Second tool", serde_json::json!({})),
            ToolDefinition::function("tool_c", "Third tool", serde_json::json!({})),
        ];

        let converted = VsCodeCopilotProvider::convert_tools(&tools);

        // Order must be preserved.
        assert_eq!(converted.len(), 3);
        assert_eq!(converted[0].function.name, "tool_a");
        assert_eq!(converted[1].function.name, "tool_b");
        assert_eq!(converted[2].function.name, "tool_c");
    }

    #[test]
    fn test_convert_tool_with_complex_parameters() {
        // Nested JSON schemas must pass through unchanged.
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"},
                "options": {
                    "type": "object",
                    "properties": {
                        "regex": {"type": "boolean"},
                        "case_sensitive": {"type": "boolean"}
                    }
                }
            },
            "required": ["query"]
        });

        let tools = vec![ToolDefinition::function(
            "grep_search",
            "Search codebase",
            params.clone(),
        )];

        let converted = VsCodeCopilotProvider::convert_tools(&tools);

        assert_eq!(converted[0].function.parameters, params);
    }

    // ---- tool-choice encoding ----

    #[test]
    fn test_tool_choice_none() {
        let result = VsCodeCopilotProvider::convert_tool_choice(None);
        assert!(result.is_none());
    }

    #[test]
    fn test_tool_choice_auto() {
        // "auto" is encoded as a bare JSON string.
        let choice = ToolChoice::auto();
        let result = VsCodeCopilotProvider::convert_tool_choice(Some(choice));

        assert_eq!(result, Some(serde_json::Value::String("auto".to_string())));
    }

    #[test]
    fn test_tool_choice_required() {
        let choice = ToolChoice::required();
        let result = VsCodeCopilotProvider::convert_tool_choice(Some(choice));

        assert_eq!(
            result,
            Some(serde_json::Value::String("required".to_string()))
        );
    }

    #[test]
    fn test_tool_choice_function() {
        // A specific function becomes a {"type":"function", ...} selector.
        let choice = ToolChoice::function("read_file");
        let result = VsCodeCopilotProvider::convert_tool_choice(Some(choice));

        let expected = serde_json::json!({
            "type": "function",
            "function": {
                "name": "read_file"
            }
        });

        assert_eq!(result, Some(expected));
    }

    #[test]
    fn test_tool_choice_none_value() {
        // ToolChoice::none() is Some("none"), distinct from passing None.
        let choice = ToolChoice::none();
        let result = VsCodeCopilotProvider::convert_tool_choice(Some(choice));

        assert_eq!(result, Some(serde_json::Value::String("none".to_string())));
    }

    // ---- response tool-call decoding ----

    #[test]
    fn test_response_tool_calls_none() {
        // Absent tool calls decode to an empty vector, not None.
        let result = VsCodeCopilotProvider::convert_response_tool_calls(None);
        assert!(result.is_empty());
    }

    #[test]
    fn test_response_tool_calls_single() {
        let calls = vec![ResponseToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: ResponseFunctionCall {
                name: "read_file".to_string(),
                arguments: r#"{"path":"src/main.rs"}"#.to_string(),
            },
        }];

        let result = VsCodeCopilotProvider::convert_response_tool_calls(Some(calls));

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].id, "call_123");
        assert_eq!(result[0].call_type, "function");
        assert_eq!(result[0].function.name, "read_file");
        assert_eq!(result[0].function.arguments, r#"{"path":"src/main.rs"}"#);
    }

    #[test]
    fn test_response_tool_calls_multiple() {
        let calls = vec![
            ResponseToolCall {
                id: "call_1".to_string(),
                call_type: "function".to_string(),
                function: ResponseFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{}".to_string(),
                },
            },
            ResponseToolCall {
                id: "call_2".to_string(),
                call_type: "function".to_string(),
                function: ResponseFunctionCall {
                    name: "search_code".to_string(),
                    arguments: "{}".to_string(),
                },
            },
        ];

        let result = VsCodeCopilotProvider::convert_response_tool_calls(Some(calls));

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].id, "call_1");
        assert_eq!(result[1].id, "call_2");
    }

    // ---- message conversion with tool calls ----

    #[test]
    fn test_message_with_tool_calls() {
        let mut msg = ChatMessage::assistant("I'll read that file for you.");
        msg.tool_calls = Some(vec![ToolCall {
            id: "call_abc".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: r#"{"path":"Cargo.toml"}"#.to_string(),
            },
        }]);

        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);

        assert_eq!(converted.len(), 1);
        assert!(converted[0].tool_calls.is_some());

        let tool_calls = converted[0].tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc");
        assert_eq!(tool_calls[0].function.name, "read_file");
    }

    #[test]
    fn test_tool_message_conversion() {
        // Tool-result messages keep their content and tool_call_id.
        let msg = ChatMessage {
            role: ChatRole::Tool,
            content: "File contents: ...".to_string(),
            name: Some("read_file".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_xyz".to_string()),
            cache_control: None,
            images: None,
        };

        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);

        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].role, "tool");
        assert_eq!(
            converted[0].content,
            Some(RequestContent::Text("File contents: ...".to_string()))
        );
        assert_eq!(converted[0].tool_call_id, Some("call_xyz".to_string()));
    }

    #[test]
    fn test_assistant_message_with_only_tool_calls() {
        // Empty text + tool calls must serialize with content omitted (None).
        let mut msg = ChatMessage::assistant("");
        msg.tool_calls = Some(vec![ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "list_files".to_string(),
                arguments: "{}".to_string(),
            },
        }]);

        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);

        assert!(converted[0].content.is_none());
        assert!(converted[0].tool_calls.is_some());
    }

    // ---- message content conversion (text / images) ----

    #[test]
    fn test_convert_messages_text_only() {
        let messages = vec![ChatMessage::user("Hello, world!")];
        let converted = VsCodeCopilotProvider::convert_messages(&messages);

        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].role, "user");
        match &converted[0].content {
            Some(RequestContent::Text(text)) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected RequestContent::Text"),
        }
    }

    #[test]
    fn test_convert_messages_with_images() {
        use crate::traits::ImageData;

        let msg = ChatMessage::user_with_images(
            "What's in this image?",
            vec![ImageData {
                data: "iVBORw0KGgo=".to_string(),
                mime_type: "image/png".to_string(),
                detail: None,
            }],
        );

        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);

        assert_eq!(converted.len(), 1);
        match &converted[0].content {
            Some(RequestContent::Parts(parts)) => {
                // Text part first, then one part per image.
                assert_eq!(parts.len(), 2); match &parts[0] {
                    ContentPart::Text { text } => {
                        assert_eq!(text, "What's in this image?");
                    }
                    _ => panic!("First part should be text"),
                }

                match &parts[1] {
                    ContentPart::ImageUrl { image_url } => {
                        // Images are embedded as base64 data URIs.
                        assert!(image_url.url.starts_with("data:image/png;base64,"));
                        assert!(image_url.url.contains("iVBORw0KGgo="));
                    }
                    _ => panic!("Second part should be image_url"),
                }
            }
            _ => panic!("Expected RequestContent::Parts for image message"),
        }
    }

    #[test]
    fn test_convert_messages_with_image_detail() {
        use crate::traits::ImageData;

        let msg = ChatMessage::user_with_images(
            "Describe in detail",
            vec![ImageData {
                data: "base64data".to_string(),
                mime_type: "image/jpeg".to_string(),
                detail: Some("high".to_string()),
            }],
        );

        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);

        match &converted[0].content {
            Some(RequestContent::Parts(parts)) => {
                assert_eq!(parts.len(), 2);

                match &parts[1] {
                    ContentPart::ImageUrl { image_url } => {
                        // The per-image detail hint must survive conversion.
                        assert_eq!(image_url.detail, Some("high".to_string()));
                    }
                    _ => panic!("Expected ImageUrl part"),
                }
            }
            _ => panic!("Expected Parts content"),
        }
    }

    // ---- model-name heuristics ----

    #[test]
    fn test_context_length_detection() {
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("gpt-4o"),
            128_000
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("gpt-4o-mini"),
            128_000
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("gpt-4"),
            8_192
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("gpt-3.5-turbo"),
            4_096
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("o1-preview"),
            200_000
        );
    }

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            ChatMessage::system("You are helpful."),
            ChatMessage::user("Hello!"),
            ChatMessage::assistant("Hi there!"),
        ];

        let converted = VsCodeCopilotProvider::convert_messages(&messages);

        assert_eq!(converted.len(), 3);
        assert_eq!(converted[0].role, "system");
        assert_eq!(
            converted[0].content,
            Some(RequestContent::Text("You are helpful.".to_string()))
        );
        assert_eq!(converted[1].role, "user");
        assert_eq!(converted[2].role, "assistant");
    }

    // ---- builder behavior ----

    #[test]
    fn test_builder_defaults() {
        // NOTE(review): mutates process-wide env; racy if tests run in
        // parallel — consider a serial-test guard.
        std::env::set_var("VSCODE_COPILOT_DIRECT", "true");
        let builder = VsCodeCopilotProviderBuilder::default();
        assert_eq!(builder.model, "gpt-4o-mini");
        assert_eq!(builder.max_context_length, 128_000);
        assert!(!builder.supports_vision);
        assert!(builder.direct_mode); std::env::remove_var("VSCODE_COPILOT_DIRECT");
    }

    #[test]
    fn test_builder_proxy_mode() {
        let provider = VsCodeCopilotProvider::new()
            .proxy_url("http://localhost:8080")
            .model("gpt-4")
            .with_vision(true)
            .build()
            .unwrap();

        assert_eq!(provider.model, "gpt-4");
        assert_eq!(provider.max_context_length, 8_192);
        assert!(provider.supports_vision);
    }

    #[test]
    fn test_builder_direct_mode() {
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .model("gpt-4o")
            .build()
            .unwrap();

        assert_eq!(provider.model, "gpt-4o");
        assert_eq!(provider.max_context_length, 128_000);
    }

    #[test]
    fn test_account_type_base_url() {
        assert_eq!(
            client::AccountType::Individual.base_url(),
            "https://api.githubcopilot.com"
        );
        assert_eq!(
            client::AccountType::Business.base_url(),
            "https://api.business.githubcopilot.com"
        );
        assert_eq!(
            client::AccountType::Enterprise.base_url(),
            "https://api.enterprise.githubcopilot.com"
        );
    }

    // ---- embedding configuration ----

    #[test]
    fn test_embedding_dimension_detection() {
        assert_eq!(
            VsCodeCopilotProviderBuilder::dimension_for_embedding_model("text-embedding-3-small"),
            1536
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::dimension_for_embedding_model("text-embedding-3-large"),
            3072
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::dimension_for_embedding_model("text-embedding-ada-002"),
            1536
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::dimension_for_embedding_model("unknown-model"),
            1536 );
    }

    #[test]
    fn test_builder_embedding_model() {
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .embedding_model("text-embedding-3-large")
            .build()
            .unwrap();

        assert_eq!(provider.embedding_model, "text-embedding-3-large");
        assert_eq!(provider.embedding_dimension, 3072);
    }

    // ---- vision flag plumbing ----

    #[test]
    fn test_builder_vision_disabled_by_default() {
        let builder = VsCodeCopilotProvider::new().direct();

        let provider = builder.build();
        assert!(provider.is_ok());

        let provider = provider.unwrap();
        assert!(!provider.supports_vision);
    }

    #[test]
    fn test_builder_with_vision_true() {
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .with_vision(true)
            .build()
            .unwrap();

        assert!(provider.supports_vision);
    }

    #[test]
    fn test_builder_with_vision_false() {
        // The last with_vision call wins.
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .with_vision(true)
            .with_vision(false) .build()
            .unwrap();

        assert!(!provider.supports_vision);
    }

    #[test]
    fn test_builder_vision_with_model() {
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .model("gpt-4o") .with_vision(true)
            .build()
            .unwrap();

        assert!(provider.supports_vision);
        assert_eq!(provider.model, "gpt-4o");
    }

    #[test]
    fn test_builder_vision_with_proxy_mode() {
        // Only checks the builder flag; the proxy client itself never
        // receives the vision flag (see build()).
        let builder = VsCodeCopilotProvider::new()
            .proxy_url("http://localhost:4141")
            .with_vision(true);

        assert!(builder.supports_vision);
    }

    // ---- builder chaining and defaults ----

    #[test]
    fn test_builder_chain_all_options() {
        use std::time::Duration;

        // timeout() after model() keeps the explicit timeout (non-grok model
        // doesn't override it).
        let builder = VsCodeCopilotProvider::new()
            .model("claude-3.5-sonnet")
            .embedding_model("text-embedding-3-large")
            .with_vision(true)
            .timeout(Duration::from_secs(120));

        assert_eq!(builder.model, "claude-3.5-sonnet");
        assert_eq!(builder.embedding_model, "text-embedding-3-large");
        assert!(builder.supports_vision);
        assert_eq!(builder.timeout.as_secs(), 120);
    }

    #[test]
    fn test_builder_account_type_business() {
        use client::AccountType;

        let builder = VsCodeCopilotProvider::new().account_type(AccountType::Business);

        assert!(matches!(builder.account_type, AccountType::Business));
    }

    #[test]
    fn test_builder_account_type_enterprise() {
        use client::AccountType;

        let builder = VsCodeCopilotProvider::new().account_type(AccountType::Enterprise);

        assert!(matches!(builder.account_type, AccountType::Enterprise));
    }

    #[test]
    fn test_builder_default_embedding_model() {
        // NOTE(review): mutates process-wide env; racy in parallel runs.
        std::env::remove_var("VSCODE_COPILOT_EMBEDDING_MODEL"); let builder = VsCodeCopilotProviderBuilder::default();
        assert_eq!(builder.embedding_model, "text-embedding-3-small");
        assert_eq!(builder.embedding_dimension, 1536);
    }

    #[test]
    fn test_builder_default_timeout() {
        let builder = VsCodeCopilotProviderBuilder::default();
        assert_eq!(builder.timeout.as_secs(), 120);
    }

    #[test]
    fn test_builder_default_context_length() {
        // NOTE(review): mutates process-wide env; racy in parallel runs.
        std::env::set_var("VSCODE_COPILOT_DIRECT", "true");
        let builder = VsCodeCopilotProviderBuilder::default();
        assert_eq!(builder.max_context_length, 128_000);
        std::env::remove_var("VSCODE_COPILOT_DIRECT");
    }
}