pub mod auth;
mod client;
mod error;
mod stream;
pub mod token;
pub mod types;
use async_trait::async_trait;
use futures::stream::{BoxStream, StreamExt};
use std::time::Duration;
use tracing::debug;
pub use client::{AccountType, VsCodeCopilotClient};
pub use error::{Result, VsCodeError};
pub use types::{Model, ModelsResponse};
use crate::error::Result as LlmResult;
use crate::traits::{
ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, FunctionCall, LLMProvider,
LLMResponse, StreamChunk, ToolCall, ToolChoice, ToolDefinition,
};
use types::{
ChatCompletionRequest, ContentPart, EmbeddingInput, EmbeddingRequest, ImageUrlContent,
RequestContent, RequestFunction, RequestMessage, RequestTool, ResponseFormat,
};
/// Chat-completion and embedding provider backed by the GitHub Copilot API,
/// either contacted directly or through a local proxy (see the builder).
#[derive(Clone)]
pub struct VsCodeCopilotProvider {
    /// Underlying HTTP client handling auth, base URL, and transport.
    client: VsCodeCopilotClient,
    /// Chat model identifier sent with completion requests.
    model: String,
    /// Context window of `model`, in tokens.
    max_context_length: usize,
    /// Whether image input was enabled at build time. Only consulted when
    /// constructing the client in direct mode, hence the dead_code allowance.
    #[allow(dead_code)]
    supports_vision: bool,
    /// Model identifier used for embedding requests.
    embedding_model: String,
    /// Output vector dimension of `embedding_model`.
    embedding_dimension: usize,
}
impl VsCodeCopilotProvider {
    /// Entry point for configuration; returns a builder rather than `Self`.
    #[allow(clippy::new_ret_no_self)]
    pub fn new() -> VsCodeCopilotProviderBuilder {
        VsCodeCopilotProviderBuilder::default()
    }

    /// Shorthand for a builder already pointed at a proxy URL.
    pub fn with_proxy(proxy_url: impl Into<String>) -> VsCodeCopilotProviderBuilder {
        VsCodeCopilotProviderBuilder::new().proxy_url(proxy_url)
    }

    /// Borrow the underlying HTTP client.
    pub fn get_client(&self) -> &VsCodeCopilotClient {
        &self.client
    }

    /// Fetch the models available to this account from the API.
    pub async fn list_models(&self) -> Result<types::ModelsResponse> {
        self.client.list_models().await
    }

    /// Translate a single chat message into the wire-format request message.
    fn convert_message(msg: &ChatMessage) -> RequestMessage {
        let role = match msg.role {
            ChatRole::System => "system",
            ChatRole::User => "user",
            ChatRole::Assistant => "assistant",
            // Both tool results and legacy function results go out as "tool".
            ChatRole::Tool | ChatRole::Function => "tool",
        }
        .to_string();

        let tool_calls = msg.tool_calls.as_ref().map(|calls| {
            calls
                .iter()
                .map(|call| types::ResponseToolCall {
                    id: call.id.clone(),
                    call_type: "function".to_string(),
                    function: types::ResponseFunctionCall {
                        name: call.name().to_string(),
                        arguments: call.arguments().to_string(),
                    },
                })
                .collect()
        });

        let cache_control = msg
            .cache_control
            .as_ref()
            .map(|cc| types::RequestCacheControl {
                cache_type: cc.cache_type.clone(),
            });

        // An assistant turn that only carries tool calls sends no content field.
        let content = if msg.content.is_empty() && tool_calls.is_some() {
            None
        } else if msg.has_images() {
            // Multimodal message: optional text part followed by one
            // image_url part per image, each encoded as a base64 data URI.
            let mut parts = Vec::new();
            if !msg.content.is_empty() {
                parts.push(ContentPart::Text {
                    text: msg.content.clone(),
                });
            }
            for img in msg.images.iter().flatten() {
                parts.push(ContentPart::ImageUrl {
                    image_url: ImageUrlContent {
                        url: format!("data:{};base64,{}", img.mime_type, img.data),
                        detail: img.detail.clone(),
                    },
                });
            }
            Some(RequestContent::Parts(parts))
        } else {
            Some(RequestContent::Text(msg.content.clone()))
        };

        RequestMessage {
            role,
            content,
            name: msg.name.clone(),
            tool_calls,
            tool_call_id: msg.tool_call_id.clone(),
            cache_control,
        }
    }

    /// Translate every message in order.
    fn convert_messages(messages: &[ChatMessage]) -> Vec<RequestMessage> {
        messages.iter().map(Self::convert_message).collect()
    }

    /// Translate tool definitions into the request's tool format.
    fn convert_tools(tools: &[ToolDefinition]) -> Vec<RequestTool> {
        let mut converted = Vec::with_capacity(tools.len());
        for tool in tools {
            converted.push(RequestTool {
                tool_type: "function".to_string(),
                function: RequestFunction {
                    name: tool.function.name.clone(),
                    description: tool.function.description.clone(),
                    parameters: tool.function.parameters.clone(),
                    strict: tool.function.strict,
                },
            });
        }
        converted
    }

    /// Translate a tool-choice setting into the JSON the API expects:
    /// a bare string for auto/required/none, an object for a named function.
    fn convert_tool_choice(choice: Option<ToolChoice>) -> Option<serde_json::Value> {
        let choice = choice?;
        Some(match choice {
            ToolChoice::Auto(mode) | ToolChoice::Required(mode) => serde_json::Value::String(mode),
            ToolChoice::Function { function, .. } => serde_json::json!({
                "type": "function",
                "function": {
                    "name": function.name
                }
            }),
        })
    }

    /// Translate response tool calls back into the crate's `ToolCall` type;
    /// `None` becomes an empty list.
    fn convert_response_tool_calls(calls: Option<Vec<types::ResponseToolCall>>) -> Vec<ToolCall> {
        let Some(calls) = calls else {
            return Vec::new();
        };
        calls
            .into_iter()
            .map(|call| ToolCall {
                id: call.id,
                call_type: call.call_type,
                function: FunctionCall {
                    name: call.function.name,
                    arguments: call.function.arguments,
                },
                thought_signature: None,
            })
            .collect()
    }
}
impl Default for VsCodeCopilotProvider {
fn default() -> Self {
Self::new()
.build()
.expect("Failed to build default VsCodeCopilotProvider")
}
}
/// Builder for [`VsCodeCopilotProvider`]. Defaults are seeded from the
/// `VSCODE_COPILOT_*` environment variables (see the `Default` impl).
#[derive(Clone)]
pub struct VsCodeCopilotProviderBuilder {
    /// Explicit proxy base URL; `None` means direct mode or an env-derived proxy.
    base_url: Option<String>,
    /// Chat model identifier.
    model: String,
    /// Context window for `model`, in tokens; derived when `model()` is called.
    max_context_length: usize,
    /// Whether image inputs are enabled.
    supports_vision: bool,
    /// HTTP request timeout.
    timeout: Duration,
    /// When true, talk to the Copilot API directly instead of via a proxy.
    direct_mode: bool,
    /// Copilot account tier; selects the API base URL in direct mode.
    account_type: client::AccountType,
    /// Embedding model identifier.
    embedding_model: String,
    /// Output dimension of `embedding_model`; derived alongside the model.
    embedding_dimension: usize,
}
impl Default for VsCodeCopilotProviderBuilder {
    /// Seed the builder from environment variables, falling back to
    /// direct mode, `gpt-5-mini`, and `text-embedding-3-small`.
    fn default() -> Self {
        // Direct mode is on unless VSCODE_COPILOT_DIRECT is "false"/"0".
        let direct_mode = match std::env::var("VSCODE_COPILOT_DIRECT") {
            Ok(v) => v.to_lowercase() != "false" && v != "0",
            Err(_) => true,
        };
        // Unrecognized or missing account types fall back to the default tier.
        let account_type = match std::env::var("VSCODE_COPILOT_ACCOUNT_TYPE") {
            Ok(raw) => client::AccountType::from_str(&raw).unwrap_or_default(),
            Err(_) => client::AccountType::default(),
        };
        let embedding_model = match std::env::var("VSCODE_COPILOT_EMBEDDING_MODEL") {
            Ok(name) => name,
            Err(_) => "text-embedding-3-small".to_string(),
        };
        let embedding_dimension = Self::dimension_for_embedding_model(&embedding_model);
        Self {
            base_url: None,
            model: "gpt-5-mini".to_string(),
            max_context_length: 128_000,
            supports_vision: false,
            timeout: Duration::from_secs(120),
            direct_mode,
            account_type,
            embedding_model,
            embedding_dimension,
        }
    }
}
impl VsCodeCopilotProviderBuilder {
    /// Alias for [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Route requests through the proxy at `url`; turns direct mode off.
    pub fn proxy_url(mut self, url: impl Into<String>) -> Self {
        self.direct_mode = false;
        self.base_url = Some(url.into());
        self
    }

    /// Talk to the Copilot API directly, clearing any configured proxy URL.
    pub fn direct(mut self) -> Self {
        self.base_url = None;
        self.direct_mode = true;
        self
    }

    /// Select the Copilot account tier (affects the API base URL).
    pub fn account_type(mut self, account_type: client::AccountType) -> Self {
        self.account_type = account_type;
        self
    }

    /// Choose the chat model. Also derives the context window and extends the
    /// timeout for grok models.
    pub fn model(mut self, model: impl Into<String>) -> Self {
        let name = model.into();
        self.max_context_length = Self::context_length_for_model(&name);
        if name.contains("grok") {
            // grok models respond slowly enough to need a longer timeout.
            self.timeout = Duration::from_secs(300);
        }
        self.model = name;
        self
    }

    /// Choose the embedding model and derive its output dimension.
    pub fn embedding_model(mut self, model: impl Into<String>) -> Self {
        let name = model.into();
        self.embedding_dimension = Self::dimension_for_embedding_model(&name);
        self.embedding_model = name;
        self
    }

    /// Enable or disable image (vision) support.
    pub fn with_vision(mut self, enabled: bool) -> Self {
        self.supports_vision = enabled;
        self
    }

    /// Override the HTTP request timeout.
    pub fn timeout(mut self, duration: Duration) -> Self {
        self.timeout = duration;
        self
    }

    /// Construct the provider, creating the HTTP client for the configured
    /// mode: explicit proxy URL, direct API access, or an env-derived proxy.
    pub fn build(self) -> Result<VsCodeCopilotProvider> {
        let client = if let Some(url) = &self.base_url {
            VsCodeCopilotClient::with_base_url(url, self.timeout)?
        } else if self.direct_mode {
            VsCodeCopilotClient::new_with_options(self.timeout, true, self.account_type)?
                .with_vision(self.supports_vision)
        } else {
            // Proxy mode without an explicit URL: consult the environment,
            // then fall back to the conventional local port.
            let proxy_url = std::env::var("VSCODE_COPILOT_PROXY_URL")
                .unwrap_or_else(|_| "http://localhost:4141".to_string());
            VsCodeCopilotClient::with_base_url(&proxy_url, self.timeout)?
        };
        let mode = if self.direct_mode { "direct" } else { "proxy" };
        debug!(
            model = %self.model,
            max_context = self.max_context_length,
            mode = mode,
            account_type = ?self.account_type,
            embedding_model = %self.embedding_model,
            "Built VsCodeCopilotProvider"
        );
        Ok(VsCodeCopilotProvider {
            client,
            model: self.model,
            max_context_length: self.max_context_length,
            supports_vision: self.supports_vision,
            embedding_model: self.embedding_model,
            embedding_dimension: self.embedding_dimension,
        })
    }

    /// Context window (tokens) for a model name; substring checks run from
    /// most to least specific so e.g. "gpt-4-32k" wins over plain "gpt-4".
    fn context_length_for_model(model: &str) -> usize {
        if model.contains("grok") {
            131_072
        } else if model.contains("gpt-4o") || model.contains("gpt-4-turbo") {
            128_000
        } else if model.contains("gpt-4-32k") {
            32_768
        } else if model.contains("gpt-4") {
            8_192
        } else if model.contains("gpt-3.5-turbo-16k") {
            16_384
        } else if model.contains("gpt-3.5") {
            4_096
        } else if model.contains("o1") || model.contains("o3") {
            200_000
        } else {
            // Unknown models get a generous modern default.
            128_000
        }
    }

    /// Output dimension for an embedding model name; unknown models default
    /// to the common 1536.
    fn dimension_for_embedding_model(model: &str) -> usize {
        if model.contains("text-embedding-3-large") {
            3072
        } else if model.contains("text-embedding-3-small")
            || model.contains("text-embedding-ada")
            || model.contains("copilot-text-embedding")
        {
            1536
        } else {
            1536
        }
    }
}
#[async_trait]
impl LLMProvider for VsCodeCopilotProvider {
    /// Provider identifier used in logs and registries.
    fn name(&self) -> &str {
        "vscode-copilot"
    }
    /// Configured chat model identifier.
    fn model(&self) -> &str {
        &self.model
    }
    /// Context window of the configured chat model, in tokens.
    fn max_context_length(&self) -> usize {
        self.max_context_length
    }
    /// Complete `prompt` with default options.
    async fn complete(&self, prompt: &str) -> LlmResult<LLMResponse> {
        self.complete_with_options(prompt, &CompletionOptions::default())
            .await
    }
    /// Complete `prompt`, wrapping it (plus any system prompt from `options`)
    /// into a chat exchange and delegating to `chat`.
    async fn complete_with_options(
        &self,
        prompt: &str,
        options: &CompletionOptions,
    ) -> LlmResult<LLMResponse> {
        let mut messages = Vec::new();
        if let Some(system) = &options.system_prompt {
            messages.push(ChatMessage::system(system));
        }
        messages.push(ChatMessage::user(prompt));
        self.chat(&messages, Some(options)).await
    }
    /// Non-streaming chat completion without tools.
    ///
    /// # Errors
    /// Fails if the HTTP request fails or the response contains no choices.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: Option<&CompletionOptions>,
    ) -> LlmResult<LLMResponse> {
        let request_messages = Self::convert_messages(messages);
        let opts = options.cloned().unwrap_or_default();
        let request = ChatCompletionRequest {
            messages: request_messages,
            model: self.model.clone(),
            temperature: opts.temperature,
            top_p: opts.top_p,
            max_tokens: opts.max_tokens,
            stop: opts.stop,
            // Explicitly non-streaming: a single JSON response body.
            stream: Some(false),
            frequency_penalty: opts.frequency_penalty,
            presence_penalty: opts.presence_penalty,
            response_format: opts
                .response_format
                .map(|fmt| ResponseFormat { format_type: fmt }),
            tools: None,
            tool_choice: None,
            parallel_tool_calls: None,
        };
        debug!(
            model = %self.model,
            message_count = messages.len(),
            "Sending chat request"
        );
        let response = self.client.chat_completion(request).await?;
        // Only the first choice is consumed; n > 1 is never requested.
        let choice = response
            .choices
            .first()
            .ok_or_else(|| crate::error::LlmError::ApiError("No choices in response".into()))?;
        let content = choice.message.content.clone().unwrap_or_default();
        // Usage may be absent; substitute zeros so downstream accounting
        // never sees a missing value.
        let usage = response.usage.unwrap_or(types::Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            extra: None,
        });
        debug!(
            prompt_tokens = usage.prompt_tokens,
            completion_tokens = usage.completion_tokens,
            "Chat request completed"
        );
        // Tool calls can appear even on a plain chat response.
        let tool_calls = Self::convert_response_tool_calls(choice.message.tool_calls.clone());
        let cache_hit_tokens = usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|d| d.cached_tokens);
        let mut response_builder = LLMResponse::new(content, response.model.clone())
            .with_usage(usage.prompt_tokens, usage.completion_tokens)
            .with_finish_reason(choice.finish_reason.clone().unwrap_or_default())
            .with_tool_calls(tool_calls)
            .with_metadata("id", serde_json::json!(response.id));
        if let Some(cached) = cache_hit_tokens {
            response_builder = response_builder.with_cache_hit_tokens(cached);
        }
        Ok(response_builder)
    }
    /// Non-streaming chat completion with tool definitions and an optional
    /// tool-choice directive.
    ///
    /// # Errors
    /// Fails if the HTTP request fails or the response contains no choices.
    async fn chat_with_tools(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> LlmResult<LLMResponse> {
        let request_messages = Self::convert_messages(messages);
        // An empty tool list is omitted entirely rather than sent as [].
        let request_tools = if tools.is_empty() {
            None
        } else {
            Some(Self::convert_tools(tools))
        };
        let request_tool_choice = Self::convert_tool_choice(tool_choice);
        let opts = options.cloned().unwrap_or_default();
        let request = ChatCompletionRequest {
            messages: request_messages,
            model: self.model.clone(),
            temperature: opts.temperature,
            top_p: opts.top_p,
            max_tokens: opts.max_tokens,
            stop: opts.stop,
            stream: Some(false),
            frequency_penalty: opts.frequency_penalty,
            presence_penalty: opts.presence_penalty,
            response_format: opts
                .response_format
                .map(|fmt| ResponseFormat { format_type: fmt }),
            tools: request_tools,
            tool_choice: request_tool_choice,
            // Allow the model to emit several tool calls in one turn.
            parallel_tool_calls: Some(true),
        };
        debug!(
            model = %self.model,
            message_count = messages.len(),
            tool_count = tools.len(),
            "Sending chat request with tools"
        );
        let response = self.client.chat_completion(request).await?;
        let choice = response
            .choices
            .first()
            .ok_or_else(|| crate::error::LlmError::ApiError("No choices in response".into()))?;
        let content = choice.message.content.clone().unwrap_or_default();
        let tool_calls = Self::convert_response_tool_calls(choice.message.tool_calls.clone());
        let usage = response.usage.unwrap_or(types::Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
            prompt_tokens_details: None,
            extra: None,
        });
        debug!(
            prompt_tokens = usage.prompt_tokens,
            completion_tokens = usage.completion_tokens,
            tool_call_count = tool_calls.len(),
            "Chat with tools request completed"
        );
        let cache_hit_tokens = usage
            .prompt_tokens_details
            .as_ref()
            .and_then(|d| d.cached_tokens);
        let mut response_builder = LLMResponse::new(content, response.model.clone())
            .with_usage(usage.prompt_tokens, usage.completion_tokens)
            .with_finish_reason(choice.finish_reason.clone().unwrap_or_default())
            .with_tool_calls(tool_calls)
            .with_metadata("id", serde_json::json!(response.id));
        if let Some(cached) = cache_hit_tokens {
            response_builder = response_builder.with_cache_hit_tokens(cached);
        }
        Ok(response_builder)
    }
    /// Stream a completion of `prompt` as text deltas.
    ///
    /// Builds a minimal single-user-message request; remaining request
    /// fields come from `ChatCompletionRequest::default()`.
    async fn stream(&self, prompt: &str) -> LlmResult<BoxStream<'static, LlmResult<String>>> {
        let request_messages = vec![RequestMessage {
            role: "user".to_string(),
            content: Some(RequestContent::Text(prompt.to_string())),
            name: None,
            tool_calls: None,
            tool_call_id: None,
            cache_control: None,
        }];
        let request = ChatCompletionRequest {
            messages: request_messages,
            model: self.model.clone(),
            stream: Some(true),
            ..Default::default()
        };
        debug!(model = %self.model, "Sending streaming request");
        let response = self.client.chat_completion_stream(request).await?;
        // Decode the SSE byte stream into text chunks, converting this
        // module's error type into the crate-level one.
        let stream = stream::parse_sse_stream(response);
        let mapped = stream.map(|result| result.map_err(|e| e.into()));
        Ok(Box::pin(mapped))
    }
    fn supports_streaming(&self) -> bool {
        true
    }
    fn supports_json_mode(&self) -> bool {
        true
    }
    fn supports_function_calling(&self) -> bool {
        true
    }
    fn supports_tool_streaming(&self) -> bool {
        true
    }
    /// Streaming chat completion with tools; yields structured chunks that
    /// can carry both text deltas and incremental tool-call fragments.
    async fn chat_with_tools_stream(
        &self,
        messages: &[ChatMessage],
        tools: &[ToolDefinition],
        tool_choice: Option<ToolChoice>,
        options: Option<&CompletionOptions>,
    ) -> LlmResult<BoxStream<'static, LlmResult<StreamChunk>>> {
        let request_messages = Self::convert_messages(messages);
        let request_tools = if tools.is_empty() {
            None
        } else {
            Some(Self::convert_tools(tools))
        };
        let request_tool_choice = Self::convert_tool_choice(tool_choice);
        let opts = options.cloned().unwrap_or_default();
        let request = ChatCompletionRequest {
            messages: request_messages,
            model: self.model.clone(),
            temperature: opts.temperature,
            top_p: opts.top_p,
            max_tokens: opts.max_tokens,
            stop: opts.stop,
            stream: Some(true),
            frequency_penalty: opts.frequency_penalty,
            presence_penalty: opts.presence_penalty,
            response_format: opts
                .response_format
                .map(|fmt| ResponseFormat { format_type: fmt }),
            tools: request_tools,
            tool_choice: request_tool_choice,
            parallel_tool_calls: Some(true),
        };
        debug!(
            model = %self.model,
            message_count = messages.len(),
            tool_count = tools.len(),
            "Sending streaming chat request with tools (OODA-05)"
        );
        let response = self.client.chat_completion_stream(request).await?;
        // Tool-aware SSE parser: emits StreamChunk rather than plain text.
        let stream = stream::parse_sse_stream_with_tools(response);
        let mapped = stream.map(|result| result.map_err(|e| e.into()));
        Ok(Box::pin(mapped))
    }
}
#[async_trait]
impl EmbeddingProvider for VsCodeCopilotProvider {
    /// Provider identifier used in logs and registries.
    fn name(&self) -> &str {
        "vscode-copilot"
    }
    /// Returns the embedding model — intentionally distinct from the chat
    /// model exposed by the `LLMProvider` impl, hence the clippy allowance.
    #[allow(clippy::misnamed_getters)]
    fn model(&self) -> &str {
        &self.embedding_model
    }
    /// Output vector dimension of the configured embedding model.
    fn dimension(&self) -> usize {
        self.embedding_dimension
    }
    /// Maximum tokens per embedding input.
    // NOTE(review): the 8192 limit is asserted here, not derived from the
    // API or the model — confirm against the endpoint's documentation.
    fn max_tokens(&self) -> usize {
        8192
    }
    /// Embed `texts`, returning one vector per input in input order.
    ///
    /// # Errors
    /// Fails if the request fails or if the number of returned embeddings
    /// does not match the number of inputs.
    async fn embed(&self, texts: &[String]) -> LlmResult<Vec<Vec<f32>>> {
        // A single text is sent in the scalar form of the request schema.
        let input = if texts.len() == 1 {
            EmbeddingInput::Single(texts[0].clone())
        } else {
            EmbeddingInput::Multiple(texts.to_vec())
        };
        let request = EmbeddingRequest::new(input, &self.embedding_model);
        debug!(
            model = %self.embedding_model,
            input_count = texts.len(),
            "Sending embedding request"
        );
        let response = self.client.create_embeddings(request).await?;
        debug!(
            prompt_tokens = response.usage.prompt_tokens,
            total_tokens = response.usage.total_tokens,
            embedding_count = response.data.len(),
            "Embedding request completed"
        );
        // The API reports an index per embedding and does not guarantee
        // arrival order; sort by index so each output vector lines up with
        // its input text before the indices are discarded.
        let mut indexed: Vec<_> = response
            .data
            .into_iter()
            .map(|e| (e.index, e.embedding))
            .collect();
        indexed.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        let embeddings: Vec<Vec<f32>> = indexed.into_iter().map(|(_, e)| e).collect();
        if embeddings.len() != texts.len() {
            return Err(crate::error::LlmError::ApiError(format!(
                "Expected {} embeddings, got {}",
                texts.len(),
                embeddings.len()
            )));
        }
        Ok(embeddings)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use types::{ResponseFunctionCall, ResponseToolCall};

    /// Serializes tests that read or mutate process-global environment
    /// variables: `cargo test` runs tests on multiple threads, and the
    /// builder's `Default` impl reads `VSCODE_COPILOT_*` from the
    /// environment, so unguarded set/remove calls race between tests.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Acquire the env lock, recovering from poisoning so one failed test
    /// does not cascade into failures of every later env-dependent test.
    fn env_guard() -> std::sync::MutexGuard<'static, ()> {
        ENV_LOCK.lock().unwrap_or_else(|e| e.into_inner())
    }

    #[test]
    fn test_convert_single_tool() {
        let tools = vec![ToolDefinition::function(
            "read_file",
            "Read contents of a file",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "path": {"type": "string"}
                },
                "required": ["path"]
            }),
        )];
        let converted = VsCodeCopilotProvider::convert_tools(&tools);
        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].tool_type, "function");
        assert_eq!(converted[0].function.name, "read_file");
        assert_eq!(converted[0].function.description, "Read contents of a file");
        assert!(converted[0].function.strict.is_some());
    }

    #[test]
    fn test_convert_multiple_tools() {
        let tools = vec![
            ToolDefinition::function("tool_a", "First tool", serde_json::json!({})),
            ToolDefinition::function("tool_b", "Second tool", serde_json::json!({})),
            ToolDefinition::function("tool_c", "Third tool", serde_json::json!({})),
        ];
        let converted = VsCodeCopilotProvider::convert_tools(&tools);
        assert_eq!(converted.len(), 3);
        assert_eq!(converted[0].function.name, "tool_a");
        assert_eq!(converted[1].function.name, "tool_b");
        assert_eq!(converted[2].function.name, "tool_c");
    }

    #[test]
    fn test_convert_tool_with_complex_parameters() {
        // Nested parameter schemas must pass through unchanged.
        let params = serde_json::json!({
            "type": "object",
            "properties": {
                "query": {"type": "string", "description": "Search query"},
                "options": {
                    "type": "object",
                    "properties": {
                        "regex": {"type": "boolean"},
                        "case_sensitive": {"type": "boolean"}
                    }
                }
            },
            "required": ["query"]
        });
        let tools = vec![ToolDefinition::function(
            "grep_search",
            "Search codebase",
            params.clone(),
        )];
        let converted = VsCodeCopilotProvider::convert_tools(&tools);
        assert_eq!(converted[0].function.parameters, params);
    }

    #[test]
    fn test_tool_choice_none() {
        let result = VsCodeCopilotProvider::convert_tool_choice(None);
        assert!(result.is_none());
    }

    #[test]
    fn test_tool_choice_auto() {
        let choice = ToolChoice::auto();
        let result = VsCodeCopilotProvider::convert_tool_choice(Some(choice));
        assert_eq!(result, Some(serde_json::Value::String("auto".to_string())));
    }

    #[test]
    fn test_tool_choice_required() {
        let choice = ToolChoice::required();
        let result = VsCodeCopilotProvider::convert_tool_choice(Some(choice));
        assert_eq!(
            result,
            Some(serde_json::Value::String("required".to_string()))
        );
    }

    #[test]
    fn test_tool_choice_function() {
        let choice = ToolChoice::function("read_file");
        let result = VsCodeCopilotProvider::convert_tool_choice(Some(choice));
        let expected = serde_json::json!({
            "type": "function",
            "function": {
                "name": "read_file"
            }
        });
        assert_eq!(result, Some(expected));
    }

    #[test]
    fn test_tool_choice_none_value() {
        // ToolChoice::none() serializes as the bare string "none".
        let choice = ToolChoice::none();
        let result = VsCodeCopilotProvider::convert_tool_choice(Some(choice));
        assert_eq!(result, Some(serde_json::Value::String("none".to_string())));
    }

    #[test]
    fn test_response_tool_calls_none() {
        let result = VsCodeCopilotProvider::convert_response_tool_calls(None);
        assert!(result.is_empty());
    }

    #[test]
    fn test_response_tool_calls_single() {
        let calls = vec![ResponseToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: ResponseFunctionCall {
                name: "read_file".to_string(),
                arguments: r#"{"path":"src/main.rs"}"#.to_string(),
            },
        }];
        let result = VsCodeCopilotProvider::convert_response_tool_calls(Some(calls));
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].id, "call_123");
        assert_eq!(result[0].call_type, "function");
        assert_eq!(result[0].function.name, "read_file");
        assert_eq!(result[0].function.arguments, r#"{"path":"src/main.rs"}"#);
    }

    #[test]
    fn test_response_tool_calls_multiple() {
        let calls = vec![
            ResponseToolCall {
                id: "call_1".to_string(),
                call_type: "function".to_string(),
                function: ResponseFunctionCall {
                    name: "read_file".to_string(),
                    arguments: "{}".to_string(),
                },
            },
            ResponseToolCall {
                id: "call_2".to_string(),
                call_type: "function".to_string(),
                function: ResponseFunctionCall {
                    name: "search_code".to_string(),
                    arguments: "{}".to_string(),
                },
            },
        ];
        let result = VsCodeCopilotProvider::convert_response_tool_calls(Some(calls));
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].id, "call_1");
        assert_eq!(result[1].id, "call_2");
    }

    #[test]
    fn test_message_with_tool_calls() {
        let mut msg = ChatMessage::assistant("I'll read that file for you.");
        msg.tool_calls = Some(vec![ToolCall {
            id: "call_abc".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: r#"{"path":"Cargo.toml"}"#.to_string(),
            },
            thought_signature: None,
        }]);
        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);
        assert_eq!(converted.len(), 1);
        assert!(converted[0].tool_calls.is_some());
        let tool_calls = converted[0].tool_calls.as_ref().unwrap();
        assert_eq!(tool_calls.len(), 1);
        assert_eq!(tool_calls[0].id, "call_abc");
        assert_eq!(tool_calls[0].function.name, "read_file");
    }

    #[test]
    fn test_tool_message_conversion() {
        let msg = ChatMessage {
            role: ChatRole::Tool,
            content: "File contents: ...".to_string(),
            name: Some("read_file".to_string()),
            tool_calls: None,
            tool_call_id: Some("call_xyz".to_string()),
            cache_control: None,
            images: None,
        };
        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);
        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].role, "tool");
        assert_eq!(
            converted[0].content,
            Some(RequestContent::Text("File contents: ...".to_string()))
        );
        assert_eq!(converted[0].tool_call_id, Some("call_xyz".to_string()));
    }

    #[test]
    fn test_assistant_message_with_only_tool_calls() {
        // Empty content plus tool calls must serialize with no content field.
        let mut msg = ChatMessage::assistant("");
        msg.tool_calls = Some(vec![ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "list_files".to_string(),
                arguments: "{}".to_string(),
            },
            thought_signature: None,
        }]);
        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);
        assert!(converted[0].content.is_none());
        assert!(converted[0].tool_calls.is_some());
    }

    #[test]
    fn test_convert_messages_text_only() {
        let messages = vec![ChatMessage::user("Hello, world!")];
        let converted = VsCodeCopilotProvider::convert_messages(&messages);
        assert_eq!(converted.len(), 1);
        assert_eq!(converted[0].role, "user");
        match &converted[0].content {
            Some(RequestContent::Text(text)) => {
                assert_eq!(text, "Hello, world!");
            }
            _ => panic!("Expected RequestContent::Text"),
        }
    }

    #[test]
    fn test_convert_messages_with_images() {
        use crate::traits::ImageData;
        let msg = ChatMessage::user_with_images(
            "What's in this image?",
            vec![ImageData {
                data: "iVBORw0KGgo=".to_string(),
                mime_type: "image/png".to_string(),
                detail: None,
            }],
        );
        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);
        assert_eq!(converted.len(), 1);
        match &converted[0].content {
            Some(RequestContent::Parts(parts)) => {
                assert_eq!(parts.len(), 2);
                match &parts[0] {
                    ContentPart::Text { text } => {
                        assert_eq!(text, "What's in this image?");
                    }
                    _ => panic!("First part should be text"),
                }
                match &parts[1] {
                    ContentPart::ImageUrl { image_url } => {
                        assert!(image_url.url.starts_with("data:image/png;base64,"));
                        assert!(image_url.url.contains("iVBORw0KGgo="));
                    }
                    _ => panic!("Second part should be image_url"),
                }
            }
            _ => panic!("Expected RequestContent::Parts for image message"),
        }
    }

    #[test]
    fn test_convert_messages_with_image_detail() {
        use crate::traits::ImageData;
        let msg = ChatMessage::user_with_images(
            "Describe in detail",
            vec![ImageData {
                data: "base64data".to_string(),
                mime_type: "image/jpeg".to_string(),
                detail: Some("high".to_string()),
            }],
        );
        let converted = VsCodeCopilotProvider::convert_messages(&[msg]);
        match &converted[0].content {
            Some(RequestContent::Parts(parts)) => {
                assert_eq!(parts.len(), 2);
                match &parts[1] {
                    ContentPart::ImageUrl { image_url } => {
                        assert_eq!(image_url.detail, Some("high".to_string()));
                    }
                    _ => panic!("Expected ImageUrl part"),
                }
            }
            _ => panic!("Expected Parts content"),
        }
    }

    #[test]
    fn test_context_length_detection() {
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("gpt-4o"),
            128_000
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("gpt-4o-mini"),
            128_000
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("gpt-4"),
            8_192
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("gpt-3.5-turbo"),
            4_096
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::context_length_for_model("o1-preview"),
            200_000
        );
    }

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            ChatMessage::system("You are helpful."),
            ChatMessage::user("Hello!"),
            ChatMessage::assistant("Hi there!"),
        ];
        let converted = VsCodeCopilotProvider::convert_messages(&messages);
        assert_eq!(converted.len(), 3);
        assert_eq!(converted[0].role, "system");
        assert_eq!(
            converted[0].content,
            Some(RequestContent::Text("You are helpful.".to_string()))
        );
        assert_eq!(converted[1].role, "user");
        assert_eq!(converted[2].role, "assistant");
    }

    #[test]
    fn test_builder_defaults() {
        let _guard = env_guard();
        std::env::set_var("VSCODE_COPILOT_DIRECT", "true");
        let builder = VsCodeCopilotProviderBuilder::default();
        assert_eq!(builder.model, "gpt-5-mini");
        assert_eq!(builder.max_context_length, 128_000);
        assert!(!builder.supports_vision);
        assert!(builder.direct_mode);
        std::env::remove_var("VSCODE_COPILOT_DIRECT");
    }

    #[test]
    fn test_builder_proxy_mode() {
        let _guard = env_guard();
        let provider = VsCodeCopilotProvider::new()
            .proxy_url("http://localhost:8080")
            .model("gpt-4")
            .with_vision(true)
            .build()
            .unwrap();
        assert_eq!(provider.model, "gpt-4");
        assert_eq!(provider.max_context_length, 8_192);
        assert!(provider.supports_vision);
    }

    #[test]
    fn test_builder_direct_mode() {
        let _guard = env_guard();
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .model("gpt-4o")
            .build()
            .unwrap();
        assert_eq!(provider.model, "gpt-4o");
        assert_eq!(provider.max_context_length, 128_000);
    }

    #[test]
    fn test_account_type_base_url() {
        assert_eq!(
            client::AccountType::Individual.base_url(),
            "https://api.githubcopilot.com"
        );
        assert_eq!(
            client::AccountType::Business.base_url(),
            "https://api.business.githubcopilot.com"
        );
        assert_eq!(
            client::AccountType::Enterprise.base_url(),
            "https://api.enterprise.githubcopilot.com"
        );
    }

    #[test]
    fn test_embedding_dimension_detection() {
        assert_eq!(
            VsCodeCopilotProviderBuilder::dimension_for_embedding_model("text-embedding-3-small"),
            1536
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::dimension_for_embedding_model("text-embedding-3-large"),
            3072
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::dimension_for_embedding_model("text-embedding-ada-002"),
            1536
        );
        assert_eq!(
            VsCodeCopilotProviderBuilder::dimension_for_embedding_model("unknown-model"),
            1536
        );
    }

    #[test]
    fn test_builder_embedding_model() {
        let _guard = env_guard();
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .embedding_model("text-embedding-3-large")
            .build()
            .unwrap();
        assert_eq!(provider.embedding_model, "text-embedding-3-large");
        assert_eq!(provider.embedding_dimension, 3072);
    }

    #[test]
    fn test_builder_vision_disabled_by_default() {
        let _guard = env_guard();
        let builder = VsCodeCopilotProvider::new().direct();
        let provider = builder.build();
        assert!(provider.is_ok());
        let provider = provider.unwrap();
        assert!(!provider.supports_vision);
    }

    #[test]
    fn test_builder_with_vision_true() {
        let _guard = env_guard();
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .with_vision(true)
            .build()
            .unwrap();
        assert!(provider.supports_vision);
    }

    #[test]
    fn test_builder_with_vision_false() {
        let _guard = env_guard();
        // The last with_vision call wins.
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .with_vision(true)
            .with_vision(false)
            .build()
            .unwrap();
        assert!(!provider.supports_vision);
    }

    #[test]
    fn test_builder_vision_with_model() {
        let _guard = env_guard();
        let provider = VsCodeCopilotProvider::new()
            .direct()
            .model("gpt-4o")
            .with_vision(true)
            .build()
            .unwrap();
        assert!(provider.supports_vision);
        assert_eq!(provider.model, "gpt-4o");
    }

    #[test]
    fn test_builder_vision_with_proxy_mode() {
        let _guard = env_guard();
        let builder = VsCodeCopilotProvider::new()
            .proxy_url("http://localhost:4141")
            .with_vision(true);
        assert!(builder.supports_vision);
    }

    #[test]
    fn test_builder_chain_all_options() {
        use std::time::Duration;
        let _guard = env_guard();
        let builder = VsCodeCopilotProvider::new()
            .model("claude-3.5-sonnet")
            .embedding_model("text-embedding-3-large")
            .with_vision(true)
            .timeout(Duration::from_secs(120));
        assert_eq!(builder.model, "claude-3.5-sonnet");
        assert_eq!(builder.embedding_model, "text-embedding-3-large");
        assert!(builder.supports_vision);
        assert_eq!(builder.timeout.as_secs(), 120);
    }

    #[test]
    fn test_builder_account_type_business() {
        use client::AccountType;
        let _guard = env_guard();
        let builder = VsCodeCopilotProvider::new().account_type(AccountType::Business);
        assert!(matches!(builder.account_type, AccountType::Business));
    }

    #[test]
    fn test_builder_account_type_enterprise() {
        use client::AccountType;
        let _guard = env_guard();
        let builder = VsCodeCopilotProvider::new().account_type(AccountType::Enterprise);
        assert!(matches!(builder.account_type, AccountType::Enterprise));
    }

    #[test]
    fn test_builder_default_embedding_model() {
        let _guard = env_guard();
        std::env::remove_var("VSCODE_COPILOT_EMBEDDING_MODEL");
        let builder = VsCodeCopilotProviderBuilder::default();
        assert_eq!(builder.embedding_model, "text-embedding-3-small");
        assert_eq!(builder.embedding_dimension, 1536);
    }

    #[test]
    fn test_builder_default_timeout() {
        let _guard = env_guard();
        let builder = VsCodeCopilotProviderBuilder::default();
        assert_eq!(builder.timeout.as_secs(), 120);
    }

    #[test]
    fn test_builder_default_context_length() {
        let _guard = env_guard();
        std::env::set_var("VSCODE_COPILOT_DIRECT", "true");
        let builder = VsCodeCopilotProviderBuilder::default();
        assert_eq!(builder.max_context_length, 128_000);
        std::env::remove_var("VSCODE_COPILOT_DIRECT");
    }
}