use async_trait::async_trait;
use futures::stream::BoxStream;
use reqwest::header::{self, HeaderMap, HeaderName, HeaderValue};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use tracing::{debug, warn};
use crate::error::{LlmError, Result};
use crate::model_config::{ModelCard, ModelType, ProviderConfig};
use crate::traits::{
ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, FunctionCall, LLMProvider,
LLMResponse, StreamChunk, ToolCall, ToolChoice, ToolDefinition,
};
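/// Request body for the OpenAI-compatible `/chat/completions` endpoint.
/// Optional fields are skipped during serialization, so backends that do not
/// support them never see the keys.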
#[derive(Debug, Serialize)]
struct ChatRequest<'a> {
model: &'a str,
messages: Vec<MessageRequest>,
#[serde(skip_serializing_if = "Option::is_none")]
temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
max_tokens: Option<usize>,
#[serde(skip_serializing_if = "Option::is_none")]
stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
seed: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
user: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
stream_options: Option<StreamOptions>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<ToolDefinition>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
thinking: Option<ThinkingConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<ResponseFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
reasoning_effort: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
safe_prompt: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
parallel_tool_calls: Option<bool>,
}
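/// When set on a streaming request, asks the server to append a final chunk
/// carrying token usage.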
#[derive(Debug, Serialize)]
struct StreamOptions {
include_usage: bool,
}
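/// Message content is either a plain string or, for multimodal messages, an
/// array of typed parts; `untagged` lets serde emit the right shape.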
#[derive(Debug, Serialize)]
#[serde(untagged)]
enum RequestContent {
Text(String),
Parts(Vec<ContentPart>),
}
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum ContentPart {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image_url")]
ImageUrl { image_url: ImageUrlContent },
}
#[derive(Debug, Serialize)]
struct ImageUrlContent {
url: String,
#[serde(skip_serializing_if = "Option::is_none")]
detail: Option<String>,
}
#[derive(Debug, Serialize)]
struct MessageRequest {
role: String,
content: RequestContent,
#[serde(skip_serializing_if = "Option::is_none")]
name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<ToolCallRequest>>,
}
#[derive(Debug, Serialize)]
struct ToolCallRequest {
id: String,
#[serde(rename = "type")]
call_type: String,
function: FunctionCallRequest,
}
#[derive(Debug, Serialize)]
struct FunctionCallRequest {
name: String,
arguments: String,
}
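/// Vendor extension toggling "thinking"/reasoning output; serialized as
/// `{"type": "enabled"}` when the provider config sets `supports_thinking`.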
#[derive(Debug, Serialize)]
struct ThinkingConfig {
#[serde(rename = "type")]
    thinking_type: String,
}
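/// Response format hint; currently only `{"type": "json_object"}` is emitted.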
#[derive(Debug, Serialize)]
struct ResponseFormat {
#[serde(rename = "type")]
    format_type: String,
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
#[allow(dead_code)]
id: Option<String>,
#[allow(dead_code)]
model: Option<String>,
choices: Vec<Choice>,
usage: Option<Usage>,
}
#[derive(Debug, Deserialize)]
struct Choice {
#[allow(dead_code)]
index: Option<usize>,
message: Option<MessageContent>,
#[allow(dead_code)]
delta: Option<serde_json::Value>,
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct MessageContent {
#[allow(dead_code)]
role: Option<String>,
content: Option<String>,
reasoning_content: Option<String>,
tool_calls: Option<Vec<ToolCallResponse>>,
}
#[derive(Debug, Deserialize)]
struct ToolCallResponse {
id: String,
#[serde(rename = "type")]
#[allow(dead_code)]
call_type: Option<String>,
function: FunctionCallResponse,
}
#[derive(Debug, Deserialize)]
struct FunctionCallResponse {
name: String,
arguments: String,
}
#[derive(Debug, Deserialize, Default)]
struct CompletionTokensDetails {
#[serde(default)]
reasoning_tokens: Option<usize>,
}
#[derive(Debug, Deserialize, Default)]
struct PromptTokensDetails {
#[serde(default)]
cached_tokens: Option<usize>,
}
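/// Token accounting from the API. The `*_details` objects are optional
/// extensions carrying reasoning-token and prompt-cache counts.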
#[derive(Debug, Deserialize)]
struct Usage {
prompt_tokens: usize,
completion_tokens: usize,
#[allow(dead_code)]
total_tokens: Option<usize>,
#[serde(default)]
prompt_tokens_details: Option<PromptTokensDetails>,
#[serde(default)]
completion_tokens_details: Option<CompletionTokensDetails>,
}
#[derive(Debug, Deserialize)]
struct ErrorResponse {
error: ErrorDetail,
}
#[derive(Debug, Deserialize)]
struct ErrorDetail {
message: String,
#[allow(dead_code)]
code: Option<String>,
}
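/// One parsed SSE chunk of a streaming chat response.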
#[derive(Debug, Deserialize)]
struct ChatStreamChunk {
#[allow(dead_code)]
id: Option<String>,
#[allow(dead_code)]
model: Option<String>,
choices: Vec<StreamChoice>,
#[allow(dead_code)]
#[serde(default)]
usage: Option<Usage>,
}
#[derive(Debug, Deserialize)]
struct StreamChoice {
#[allow(dead_code)]
index: Option<usize>,
delta: Option<StreamDelta>,
#[allow(dead_code)]
finish_reason: Option<String>,
}
#[derive(Debug, Deserialize)]
struct StreamDelta {
#[serde(default)]
content: Option<String>,
#[serde(default)]
reasoning_content: Option<String>,
tool_calls: Option<Vec<ToolCallDelta>>,
}
#[derive(Debug, Deserialize)]
struct ToolCallDelta {
index: Option<usize>,
id: Option<String>,
function: Option<FunctionDelta>,
}
#[derive(Debug, Deserialize)]
struct FunctionDelta {
name: Option<String>,
arguments: Option<String>,
}
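/// Request body for the OpenAI-compatible `/embeddings` endpoint.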
#[derive(Debug, Serialize)]
struct EmbeddingRequest<'a> {
model: &'a str,
input: &'a [String],
#[serde(skip_serializing_if = "Option::is_none")]
encoding_format: Option<&'a str>,
}
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
data: Vec<EmbeddingObject>,
#[allow(dead_code)]
model: Option<String>,
#[allow(dead_code)]
usage: Option<EmbeddingUsage>,
}
#[derive(Debug, Deserialize)]
struct EmbeddingObject {
embedding: Vec<f32>,
#[allow(dead_code)]
index: Option<usize>,
}
#[derive(Debug, Deserialize)]
struct EmbeddingUsage {
#[allow(dead_code)]
prompt_tokens: Option<usize>,
#[allow(dead_code)]
total_tokens: Option<usize>,
}
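/// LLM and embedding provider for any backend that speaks the OpenAI wire
/// format (`/chat/completions` and `/embeddings` under a configurable base
/// URL), e.g. OpenRouter, LMStudio, or Z.ai.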
#[derive(Debug)]
pub struct OpenAICompatibleProvider {
client: Client,
config: ProviderConfig,
api_key: String,
model: String,
model_card: Option<ModelCard>,
base_url: String,
}
impl OpenAICompatibleProvider {
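    /// Builds a provider from a [`ProviderConfig`], resolving the API key and
    /// base URL from the config or the environment.
    ///
    /// A minimal sketch (the exact fields you set depend on your deployment;
    /// the values below are illustrative only):
    ///
    /// ```ignore
    /// let config = ProviderConfig {
    ///     name: "my-provider".to_string(),
    ///     base_url: Some("https://api.example.com/v1".to_string()),
    ///     api_key_env: Some("MY_PROVIDER_API_KEY".to_string()),
    ///     default_llm_model: Some("my-model".to_string()),
    ///     ..Default::default()
    /// };
    /// let provider = OpenAICompatibleProvider::from_config(config)?;
    /// ```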
pub fn from_config(config: ProviderConfig) -> Result<Self> {
let api_key = Self::resolve_api_key(&config)?;
let base_url = Self::resolve_base_url(&config)?;
let client = Self::build_client(&config)?;
let model = config
.default_llm_model
.clone()
.unwrap_or_else(|| "default".to_string());
let model_card = config.models.iter().find(|m| m.name == model).cloned();
debug!(
provider = config.name,
model = model,
base_url = base_url,
"Created OpenAI-compatible provider"
);
Ok(Self {
client,
config,
api_key,
model,
model_card,
base_url,
})
}
fn resolve_api_key(config: &ProviderConfig) -> Result<String> {
if let Some(ref key) = config.api_key {
return Ok(key.clone());
}
if let Some(env_var) = &config.api_key_env {
std::env::var(env_var).map_err(|_| {
LlmError::ConfigError(format!(
"API key environment variable '{}' not set for provider '{}'. \
Please set it with: export {}=your-api-key",
env_var, config.name, env_var
))
})
} else {
Ok(String::new())
}
}
fn resolve_base_url(config: &ProviderConfig) -> Result<String> {
if let Some(env_var) = &config.base_url_env {
if let Ok(url) = std::env::var(env_var) {
return Ok(url);
}
}
config.base_url.clone().ok_or_else(|| {
LlmError::ConfigError(format!(
"Provider '{}' requires 'base_url' or 'base_url_env' to be set",
config.name
))
})
}
fn build_client(config: &ProviderConfig) -> Result<Client> {
let mut headers = HeaderMap::new();
headers.insert(
header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
for (key, value) in &config.headers {
let header_name = HeaderName::from_bytes(key.as_bytes()).map_err(|e| {
LlmError::ConfigError(format!("Invalid header name '{}': {}", key, e))
})?;
let header_value = HeaderValue::from_str(value).map_err(|e| {
LlmError::ConfigError(format!("Invalid header value for '{}': {}", key, e))
})?;
headers.insert(header_name, header_value);
}
Client::builder()
.default_headers(headers)
.timeout(Duration::from_secs(config.timeout_seconds))
.build()
.map_err(|e| LlmError::ConfigError(format!("Failed to build HTTP client: {}", e)))
}
fn chat_completions_url(&self) -> String {
let base = self.base_url.trim_end_matches('/');
format!("{}/chat/completions", base)
}
fn embeddings_url(&self) -> String {
let base = self.base_url.trim_end_matches('/');
format!("{}/embeddings", base)
}
fn extract_answer_content(message: &MessageContent) -> String {
message.content.clone().unwrap_or_default()
}
fn extract_thinking_content(message: &MessageContent) -> Option<String> {
message
.reasoning_content
.as_deref()
.filter(|s| !s.is_empty())
.map(ToOwned::to_owned)
}
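    /// Maps internal [`ChatMessage`]s onto the wire format, flattening
    /// text-only messages to a plain string and expanding messages with
    /// images into `text`/`image_url` content parts.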
fn convert_messages(messages: &[ChatMessage]) -> Vec<MessageRequest> {
messages
.iter()
.map(|msg| {
let role = match msg.role {
ChatRole::System => "system",
ChatRole::User => "user",
ChatRole::Assistant => "assistant",
ChatRole::Tool | ChatRole::Function => "tool",
};
let content = if msg.has_images() {
let mut parts = Vec::new();
if !msg.content.is_empty() {
parts.push(ContentPart::Text {
text: msg.content.clone(),
});
}
if let Some(ref images) = msg.images {
for img in images {
parts.push(ContentPart::ImageUrl {
image_url: ImageUrlContent {
url: img.to_data_uri(),
detail: img.detail.clone(),
},
});
}
}
RequestContent::Parts(parts)
} else {
RequestContent::Text(msg.content.clone())
};
MessageRequest {
role: role.to_string(),
content,
name: msg.name.clone(),
tool_call_id: msg.tool_call_id.clone(),
tool_calls: None,
}
})
.collect()
}
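    /// Converts a [`ToolChoice`] into the JSON the API expects: a bare string
    /// for auto/required, or a `{"type": "function", ...}` object for a
    /// forced function.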
fn convert_tool_choice(choice: &ToolChoice) -> serde_json::Value {
match choice {
ToolChoice::Auto(s) => serde_json::json!(s),
ToolChoice::Required(s) => serde_json::json!(s),
ToolChoice::Function { function, .. } => {
serde_json::json!({
"type": "function",
"function": {
"name": function.name
}
})
}
}
}
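    /// Shared non-streaming request path: POSTs the request with bearer auth,
    /// surfaces API errors with status and message, and parses the JSON body.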
async fn chat_request(&self, request: &ChatRequest<'_>) -> Result<ChatResponse> {
let url = self.chat_completions_url();
debug!(
"OpenAI-compatible API Request: url={} model={} provider={}",
url, request.model, self.config.name
);
let mut req_builder = self.client.post(&url);
if !self.api_key.is_empty() {
req_builder = req_builder.header("Authorization", format!("Bearer {}", self.api_key));
debug!("API key: {}...", &self.api_key[..4.min(self.api_key.len())]);
} else {
warn!("No API key set for provider: {}", self.config.name);
}
let response = req_builder.json(request).send().await.map_err(|e| {
warn!("Network error calling {} API: {}", self.config.name, e);
LlmError::NetworkError(format!("Failed to connect to {}: {}", url, e))
})?;
let status = response.status();
debug!("{} API Response: status={}", self.config.name, status);
let body = response.text().await.map_err(|e| {
warn!("Failed to read response body: {}", e);
LlmError::NetworkError(e.to_string())
})?;
if !status.is_success() {
warn!(
"{} API error: status={} body={}",
self.config.name, status, body
);
if let Ok(error_resp) = serde_json::from_str::<ErrorResponse>(&body) {
return Err(LlmError::ApiError(format!(
"{} API {}: {}",
self.config.name,
status.as_u16(),
error_resp.error.message
)));
}
return Err(LlmError::ApiError(format!(
"{} API error {}: {}",
self.config.name,
status.as_u16(),
body
)));
}
debug!(
"{} API success, parsing response body (length: {})",
self.config.name,
body.len()
);
        if self.model.to_lowercase().contains("glm") {
            // Character-based truncation avoids panicking on a UTF-8 boundary.
            let preview: String = body.chars().take(1000).collect();
            debug!("GLM response body: {}", preview);
        }
serde_json::from_str(&body).map_err(|e| {
warn!(
"Failed to parse {} response: {} | body: {}",
self.config.name,
e,
&body[..500.min(body.len())]
);
LlmError::ApiError(format!(
"Failed to parse {} response: {} | body preview: {}",
self.config.name,
e,
&body[..500.min(body.len())]
))
})
}
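    /// Builder-style switch to a different model, refreshing the cached model
    /// card when the model is listed in the provider config.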
pub fn with_model(mut self, model: impl Into<String>) -> Self {
let model_name = model.into();
self.model_card = self
.config
.models
.iter()
.find(|m| m.name == model_name)
.cloned();
self.model = model_name;
self
}
pub fn model_card(&self) -> Option<&ModelCard> {
self.model_card.as_ref()
}
}
#[async_trait]
impl LLMProvider for OpenAICompatibleProvider {
fn name(&self) -> &str {
&self.config.name
}
fn model(&self) -> &str {
&self.model
}
fn max_context_length(&self) -> usize {
self.model_card
.as_ref()
.map(|m| m.capabilities.context_length)
.unwrap_or(128000)
}
async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
self.complete_with_options(prompt, &CompletionOptions::default())
.await
}
async fn complete_with_options(
&self,
prompt: &str,
options: &CompletionOptions,
) -> Result<LLMResponse> {
let mut messages = Vec::new();
if let Some(system) = &options.system_prompt {
messages.push(ChatMessage::system(system));
}
messages.push(ChatMessage::user(prompt));
self.chat(&messages, Some(options)).await
}
async fn chat(
&self,
messages: &[ChatMessage],
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
let options = options.cloned().unwrap_or_default();
let messages_req = Self::convert_messages(messages);
let use_json_mode = options
.response_format
.as_ref()
.map(|f| f == "json_object" || f == "json")
.unwrap_or(false);
let request = ChatRequest {
model: &self.model,
messages: messages_req,
temperature: options.temperature,
top_p: options.top_p,
max_tokens: options.max_tokens,
stop: options.stop.clone(),
frequency_penalty: options.frequency_penalty,
presence_penalty: options.presence_penalty,
seed: None,
user: None,
stream: Some(false),
stream_options: None,
tools: None,
tool_choice: None,
thinking: if self.config.supports_thinking {
Some(ThinkingConfig {
thinking_type: "enabled".to_string(),
})
} else {
None
},
response_format: if use_json_mode {
Some(ResponseFormat {
format_type: "json_object".to_string(),
})
} else {
None
},
reasoning_effort: options.reasoning_effort.clone(),
safe_prompt: options.safe_prompt,
parallel_tool_calls: None,
};
let response = self.chat_request(&request).await?;
let choice = response
.choices
.first()
.ok_or_else(|| LlmError::ApiError("No choices in response".to_string()))?;
let message = choice
.message
.as_ref()
.ok_or_else(|| LlmError::ApiError("No message in choice".to_string()))?;
let content = Self::extract_answer_content(message);
        let thinking_content = Self::extract_thinking_content(message);
let (prompt_tokens, completion_tokens, reasoning_tokens, cache_hit_tokens) = response
.usage
.as_ref()
.map(|u| {
let reasoning = u
.completion_tokens_details
.as_ref()
.and_then(|d| d.reasoning_tokens);
let cached = u
.prompt_tokens_details
.as_ref()
.and_then(|d| d.cached_tokens);
(u.prompt_tokens, u.completion_tokens, reasoning, cached)
})
.unwrap_or((0, 0, None, None));
let mut llm_response = LLMResponse::new(content, &self.model)
.with_usage(prompt_tokens, completion_tokens)
.with_finish_reason(
choice
.finish_reason
.clone()
.unwrap_or_else(|| "stop".to_string()),
);
if let Some(ref id) = response.id {
llm_response = llm_response.with_metadata("id", serde_json::Value::String(id.clone()));
}
if let Some(tokens) = reasoning_tokens {
llm_response = llm_response.with_thinking_tokens(tokens);
}
if let Some(thinking) = thinking_content {
llm_response = llm_response.with_thinking_content(thinking);
}
if let Some(cached) = cache_hit_tokens {
llm_response = llm_response.with_cache_hit_tokens(cached);
}
Ok(llm_response)
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
tool_choice: Option<ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
debug!(
"chat_with_tools called: model={} provider={} messages={} tools={}",
self.model,
self.config.name,
messages.len(),
tools.len()
);
let options = options.cloned().unwrap_or_default();
let messages_req = Self::convert_messages(messages);
let have_tools = !tools.is_empty();
let api_tools = have_tools.then(|| tools.to_vec());
let api_tool_choice = if have_tools {
tool_choice.as_ref().map(Self::convert_tool_choice)
} else {
None
};
let request = ChatRequest {
model: &self.model,
messages: messages_req,
temperature: options.temperature,
top_p: options.top_p,
max_tokens: options.max_tokens,
stop: options.stop.clone(),
frequency_penalty: options.frequency_penalty,
presence_penalty: options.presence_penalty,
seed: None,
user: None,
stream: Some(false),
stream_options: None,
tools: api_tools,
tool_choice: api_tool_choice,
thinking: if self.config.supports_thinking {
Some(ThinkingConfig {
thinking_type: "enabled".to_string(),
})
} else {
None
},
response_format: None,
reasoning_effort: None,
safe_prompt: None,
parallel_tool_calls: options.parallel_tool_calls,
};
let response = self.chat_request(&request).await?;
let choice = response
.choices
.first()
.ok_or_else(|| LlmError::ApiError("No choices in response".to_string()))?;
let message = choice
.message
.as_ref()
.ok_or_else(|| LlmError::ApiError("No message in choice".to_string()))?;
let content = Self::extract_answer_content(message);
let thinking_content = Self::extract_thinking_content(message);
if self.model.to_lowercase().contains("glm") {
debug!(
"GLM message structure - content_len={} tool_calls_present={} tool_calls_count={}",
content.len(),
message.tool_calls.is_some(),
message.tool_calls.as_ref().map(|t| t.len()).unwrap_or(0)
);
}
let tool_calls: Vec<ToolCall> = message
.tool_calls
.as_ref()
.map(|calls| {
calls
.iter()
.map(|tc| {
if tc.function.arguments.is_empty() || tc.function.arguments == "{}" {
warn!(
"Empty tool arguments detected - tool={} id={} args='{}' (GLM model may not be providing arguments correctly)",
tc.function.name, tc.id, tc.function.arguments
);
} else {
debug!(
"Tool call extracted - tool={} args_len={} id={}",
tc.function.name, tc.function.arguments.len(), tc.id
);
}
ToolCall {
id: tc.id.clone(),
call_type: "function".to_string(),
function: FunctionCall {
name: tc.function.name.clone(),
arguments: tc.function.arguments.clone(),
},
thought_signature: None,
}
})
.collect()
})
.unwrap_or_default();
let (prompt_tokens, completion_tokens, reasoning_tokens) = response
.usage
.as_ref()
.map(|u| {
let reasoning = u
.completion_tokens_details
.as_ref()
.and_then(|d| d.reasoning_tokens);
(u.prompt_tokens, u.completion_tokens, reasoning)
})
.unwrap_or((0, 0, None));
let mut llm_response = LLMResponse::new(content, &self.model)
.with_usage(prompt_tokens, completion_tokens)
.with_tool_calls(tool_calls)
.with_finish_reason(
choice
.finish_reason
.clone()
.unwrap_or_else(|| "stop".to_string()),
);
if let Some(tokens) = reasoning_tokens {
llm_response = llm_response.with_thinking_tokens(tokens);
}
if let Some(thinking) = thinking_content {
llm_response = llm_response.with_thinking_content(thinking);
}
Ok(llm_response)
}
fn supports_streaming(&self) -> bool {
self.model_card
.as_ref()
.map(|m| m.capabilities.supports_streaming)
.unwrap_or(true)
}
fn supports_function_calling(&self) -> bool {
self.model_card
.as_ref()
.map(|m| m.capabilities.supports_function_calling)
.unwrap_or(true)
}
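    /// Streams a plain prompt completion over SSE, yielding text (and
    /// reasoning) deltas as they arrive.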
async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
use futures::StreamExt;
let messages = vec![MessageRequest {
role: "user".to_string(),
content: RequestContent::Text(prompt.to_string()),
name: None,
tool_call_id: None,
tool_calls: None,
}];
let request = ChatRequest {
model: &self.model,
messages,
temperature: None,
top_p: None,
max_tokens: None,
stop: None,
frequency_penalty: None,
presence_penalty: None,
seed: None,
user: None,
stream: Some(true),
stream_options: Some(StreamOptions {
include_usage: true,
}),
tools: None,
tool_choice: None,
thinking: None,
response_format: None,
reasoning_effort: None,
safe_prompt: None,
parallel_tool_calls: None,
};
let url = self.chat_completions_url();
debug!(
"{} Stream Request: url={} model={}",
self.config.name, url, &self.model
);
debug!(
"{} Stream Request body: {}",
self.config.name,
serde_json::to_string_pretty(&request).unwrap_or_default()
);
let mut req_builder = self.client.post(&url);
if !self.api_key.is_empty() {
req_builder = req_builder.header("Authorization", format!("Bearer {}", self.api_key));
}
let req_builder = req_builder.json(&request);
use reqwest_eventsource::EventSource;
let event_source = EventSource::new(req_builder)
.map_err(|e| {
let error_msg = e.to_string();
warn!("Failed to create event source: {}", error_msg);
if error_msg.contains("400") && error_msg.contains("Bad Request") {
let error_lower = error_msg.to_lowercase();
if error_lower.contains("tool")
|| error_lower.contains("function")
|| error_msg.contains("not supported")
|| error_msg.contains("No endpoints found") {
LlmError::ApiError(format!(
"stream failed: {}\n\n\
💡 Model doesn't support function calling required by EdgeCode React agent.\n\
\n\
Try one of these compatible models:\n\
- anthropic/claude-3.5-sonnet (recommended)\n\
- openai/gpt-4o\n\
- google/gemini-2.0-flash-exp\n\
- meta-llama/llama-3.3-70b-instruct\n\
\n\
Use /model to select a different model.",
error_msg
))
} else {
LlmError::ApiError(format!(
"stream failed: {}\n\n\
💡 Troubleshooting 400 Bad Request:\n\
\n\
If using LMStudio:\n\
• The prompt likely exceeds your model's configured context window\n\
• Solution 1: Increase context length in LMStudio model settings (32K+ recommended)\n\
• Solution 2: Set LMSTUDIO_CONTEXT_LENGTH environment variable (e.g., 32768 or 65536)\n\
• Solution 3: Use a model with larger context window\n\
• Solution 4: Reduce task complexity or working directory size\n\
\n\
If using other providers:\n\
• Check the model's context limits in the provider's documentation\n\
• Reduce the amount of context being sent (files, history, etc.)\n\
• Use a model with larger context window",
error_msg
))
}
} else {
LlmError::ApiError(format!("stream failed: {}", error_msg))
}
})?;
use futures::stream;
use reqwest_eventsource::Event;
let stream = stream::unfold(event_source, |mut es| async move {
match es.next().await {
Some(Ok(Event::Open)) => {
Some((Ok("".to_string()), es))
}
Some(Ok(Event::Message(msg))) => {
if msg.data == "[DONE]" {
es.close();
return None;
}
match serde_json::from_str::<ChatStreamChunk>(&msg.data) {
Ok(chunk) => {
if let Some(choice) = chunk.choices.first() {
if let Some(ref delta) = choice.delta {
if let Some(ref reasoning) = delta.reasoning_content {
if !reasoning.is_empty() {
return Some((Ok(reasoning.clone()), es));
}
}
if let Some(ref content) = delta.content {
if !content.is_empty() {
return Some((Ok(content.clone()), es));
}
}
}
}
Some((Ok("".to_string()), es))
}
Err(e) => {
warn!("Failed to parse stream chunk: {} | data: {}", e, msg.data);
Some((Err(LlmError::ApiError(format!("Parse error: {}", e))), es))
}
}
}
Some(Err(e)) => {
es.close();
Some((Err(LlmError::ApiError(format!("Stream error: {}", e))), es))
}
None => None,
}
});
Ok(stream.boxed())
}
fn supports_tool_streaming(&self) -> bool {
self.supports_streaming() && self.supports_function_calling()
}
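    /// Streams a tool-enabled chat, yielding [`StreamChunk`]s for content,
    /// reasoning, and incremental tool-call deltas.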
async fn chat_with_tools_stream(
&self,
messages: &[ChatMessage],
tools: &[ToolDefinition],
tool_choice: Option<ToolChoice>,
options: Option<&CompletionOptions>,
) -> Result<BoxStream<'static, Result<StreamChunk>>> {
use futures::stream::{self, StreamExt};
use reqwest_eventsource::{Event, EventSource};
let options = options.cloned().unwrap_or_default();
let messages_req = Self::convert_messages(messages);
let have_tools = !tools.is_empty();
let api_tools = have_tools.then(|| tools.to_vec());
let api_tool_choice = if have_tools {
tool_choice.as_ref().map(Self::convert_tool_choice)
} else {
None
};
let request = ChatRequest {
model: &self.model,
messages: messages_req,
temperature: options.temperature,
top_p: options.top_p,
max_tokens: options.max_tokens,
stop: options.stop.clone(),
frequency_penalty: options.frequency_penalty,
presence_penalty: options.presence_penalty,
seed: None,
user: None,
stream: Some(true),
stream_options: Some(StreamOptions {
include_usage: true,
}),
tools: api_tools,
tool_choice: api_tool_choice,
thinking: if self.config.supports_thinking {
Some(ThinkingConfig {
thinking_type: "enabled".to_string(),
})
} else {
None
},
response_format: None,
reasoning_effort: options.reasoning_effort.clone(),
safe_prompt: options.safe_prompt,
parallel_tool_calls: options.parallel_tool_calls,
};
let url = self.chat_completions_url();
debug!(
"Starting streaming request: model={} url={} tools={}",
self.model,
url,
tools.len()
);
let mut req_builder = self.client.post(&url);
if !self.api_key.is_empty() {
req_builder = req_builder.header("Authorization", format!("Bearer {}", self.api_key));
}
let req_builder = req_builder.json(&request);
let event_source = EventSource::new(req_builder).map_err(|e| {
warn!("Failed to create event source: {}", e);
LlmError::ApiError(format!("Failed to create event source: {}", e))
})?;
let stream = stream::unfold(event_source, |mut es| async move {
match es.next().await {
Some(Ok(Event::Open)) => {
Some((Ok(StreamChunk::Content("".to_string())), es))
}
Some(Ok(Event::Message(msg))) => {
if msg.data == "[DONE]" {
es.close();
return None;
}
if msg.data.contains("tool_calls") || msg.data.contains("write_file") {
debug!("RAW SSE message (len={}): {}", msg.data.len(), &msg.data);
}
match serde_json::from_str::<ChatStreamChunk>(&msg.data) {
Ok(chunk) => {
if let Some(choice) = chunk.choices.first() {
if let Some(ref delta) = choice.delta {
if let Some(ref reasoning) = delta.reasoning_content {
if !reasoning.is_empty() {
return Some((
Ok(StreamChunk::ThinkingContent {
text: reasoning.clone(),
                                            tokens_used: None,
                                            budget_total: None,
}),
es,
));
}
}
if let Some(ref content) = delta.content {
if !content.is_empty() {
return Some((
Ok(StreamChunk::Content(content.clone())),
es,
));
}
}
if let Some(ref tool_calls) = delta.tool_calls {
for (i, tool_call) in tool_calls.iter().enumerate() {
if let Some(ref function) = tool_call.function {
debug!(
"GLM tool_call[{}] - id={:?} index={:?} name={:?} args_len={} args={:?}",
i,
tool_call.id,
tool_call.index,
function.name,
function.arguments.as_ref().map(|s| s.len()).unwrap_or(0),
function.arguments
);
}
}
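                            // Forward the first actionable delta in this chunk;
                            // OpenAI-compatible backends generally stream one
                            // tool-call delta per SSE event.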
for tool_call in tool_calls {
if let Some(ref function) = tool_call.function {
let has_name = function.name.is_some();
let has_args = function.arguments.is_some();
if has_name || has_args {
return Some((
Ok(StreamChunk::ToolCallDelta {
index: tool_call.index.unwrap_or(0),
id: tool_call.id.clone(),
function_name: function.name.clone(),
function_arguments: function
.arguments
.clone(),
thought_signature: None,
}),
es,
));
}
}
}
}
}
}
Some((Ok(StreamChunk::Content("".to_string())), es))
}
Err(e) => {
warn!("Failed to parse stream chunk: {} | data: {}", e, msg.data);
es.close();
Some((
Err(LlmError::ApiError(format!(
"Failed to parse stream chunk: {} | data: {}",
e, msg.data
))),
es,
))
}
}
}
Some(Err(e)) => {
warn!("Stream error: {}", e);
Some((
Err(LlmError::NetworkError(format!("Stream error: {}", e))),
es,
))
}
                None => None,
}
});
Ok(Box::pin(stream))
}
}
#[async_trait]
impl EmbeddingProvider for OpenAICompatibleProvider {
fn name(&self) -> &str {
&self.config.name
}
fn model(&self) -> &str {
self.config
.default_embedding_model
.as_deref()
.unwrap_or("unknown")
}
fn dimension(&self) -> usize {
self.config
.models
.iter()
.find(|m| matches!(m.model_type, ModelType::Embedding))
.map(|m| m.capabilities.embedding_dimension)
.unwrap_or(1536)
}
fn max_tokens(&self) -> usize {
self.config
.models
.iter()
.find(|m| matches!(m.model_type, ModelType::Embedding))
.map(|m| {
if m.capabilities.max_embedding_tokens > 0 {
m.capabilities.max_embedding_tokens
} else if m.capabilities.context_length > 0 {
m.capabilities.context_length
} else {
8192
}
})
.unwrap_or(8192)
}
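    /// Embeds a batch of texts using the provider's configured embedding
    /// model, failing fast if none is configured.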
async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
let embedding_model = self
.config
.default_embedding_model
.as_ref()
.ok_or_else(|| {
LlmError::ConfigError(format!(
"Provider '{}' does not have an embedding model configured",
self.config.name
))
})?;
let url = self.embeddings_url();
debug!(
provider = self.config.name,
model = embedding_model,
text_count = texts.len(),
url = url,
"Sending embedding request"
);
let request = EmbeddingRequest {
model: embedding_model,
input: texts,
encoding_format: Some("float"),
};
let mut req_builder = self.client.post(&url);
if !self.api_key.is_empty() {
req_builder = req_builder.header("Authorization", format!("Bearer {}", self.api_key));
}
let response = req_builder
.json(&request)
.send()
.await
.map_err(|e| LlmError::NetworkError(format!("Embedding request failed: {}", e)))?;
let status = response.status();
let body = response.text().await.map_err(|e| {
LlmError::NetworkError(format!("Failed to read embedding response: {}", e))
})?;
if !status.is_success() {
if let Ok(err) = serde_json::from_str::<ErrorResponse>(&body) {
return Err(LlmError::ApiError(format!(
"[{}] {} – {}",
status.as_u16(),
self.config.name,
err.error.message
)));
}
return Err(LlmError::ApiError(format!(
"[{}] {} – {}",
status.as_u16(),
self.config.name,
body
)));
}
        let embed_resp: EmbeddingResponse = serde_json::from_str(&body).map_err(|e| {
            // Character-based truncation avoids panicking on a UTF-8 boundary.
            let preview: String = body.chars().take(500).collect();
            LlmError::InvalidRequest(format!(
                "Failed to parse embedding response: {} – body: {}",
                e, preview
            ))
        })?;
Ok(embed_resp.data.into_iter().map(|o| o.embedding).collect())
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_config() -> ProviderConfig {
ProviderConfig {
name: "test-provider".to_string(),
display_name: "Test Provider".to_string(),
provider_type: crate::model_config::ProviderType::OpenAICompatible,
api_key_env: Some("TEST_API_KEY".to_string()),
base_url: Some("https://api.example.com/v1".to_string()),
default_llm_model: Some("test-model".to_string()),
models: vec![ModelCard {
name: "test-model".to_string(),
display_name: "Test Model".to_string(),
model_type: ModelType::Llm,
capabilities: crate::model_config::ModelCapabilities {
context_length: 128000,
max_output_tokens: 8192,
supports_function_calling: true,
supports_streaming: true,
..Default::default()
},
..Default::default()
}],
..Default::default()
}
}
#[test]
fn test_provider_creation_requires_api_key() {
let config = create_test_config();
std::env::remove_var("TEST_API_KEY");
let result = OpenAICompatibleProvider::from_config(config);
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("TEST_API_KEY"));
}
#[test]
fn test_provider_creation_success() {
let config = create_test_config();
std::env::set_var("TEST_API_KEY", "test-key-12345");
let provider = OpenAICompatibleProvider::from_config(config).unwrap();
assert_eq!(LLMProvider::name(&provider), "test-provider");
assert_eq!(LLMProvider::model(&provider), "test-model");
assert_eq!(provider.max_context_length(), 128000);
assert!(provider.supports_function_calling());
std::env::remove_var("TEST_API_KEY");
}
#[test]
fn test_chat_completions_url() {
std::env::set_var("TEST_API_KEY2", "key");
let mut config = create_test_config();
config.api_key_env = Some("TEST_API_KEY2".to_string());
config.base_url = Some("https://api.z.ai/api/paas/v4".to_string());
let provider = OpenAICompatibleProvider::from_config(config).unwrap();
assert_eq!(
provider.chat_completions_url(),
"https://api.z.ai/api/paas/v4/chat/completions"
);
std::env::remove_var("TEST_API_KEY2");
}
#[test]
fn test_custom_headers() {
std::env::set_var("TEST_API_KEY3", "key");
let mut config = create_test_config();
config.api_key_env = Some("TEST_API_KEY3".to_string());
config
.headers
.insert("Accept-Language".to_string(), "en-US,en".to_string());
config
.headers
.insert("X-Custom-Header".to_string(), "custom-value".to_string());
let result = OpenAICompatibleProvider::from_config(config);
assert!(result.is_ok());
std::env::remove_var("TEST_API_KEY3");
}
#[test]
fn test_convert_messages() {
let messages = vec![
ChatMessage::system("You are a helpful assistant."),
ChatMessage::user("Hello!"),
ChatMessage::assistant("Hi there!"),
];
let converted = OpenAICompatibleProvider::convert_messages(&messages);
assert_eq!(converted.len(), 3);
assert_eq!(converted[0].role, "system");
assert_eq!(converted[1].role, "user");
assert_eq!(converted[2].role, "assistant");
}
#[test]
fn test_base_url_env_override() {
std::env::set_var("TEST_API_KEY4", "key");
std::env::set_var("CUSTOM_BASE_URL", "https://override.example.com/v1");
let mut config = create_test_config();
config.api_key_env = Some("TEST_API_KEY4".to_string());
config.base_url = Some("https://default.example.com/v1".to_string());
config.base_url_env = Some("CUSTOM_BASE_URL".to_string());
let provider = OpenAICompatibleProvider::from_config(config).unwrap();
assert_eq!(provider.base_url, "https://override.example.com/v1");
std::env::remove_var("TEST_API_KEY4");
std::env::remove_var("CUSTOM_BASE_URL");
}
#[test]
fn test_with_model() {
std::env::set_var("TEST_API_KEY5", "key");
let mut config = create_test_config();
config.api_key_env = Some("TEST_API_KEY5".to_string());
config.models.push(ModelCard {
name: "another-model".to_string(),
display_name: "Another Model".to_string(),
model_type: ModelType::Llm,
capabilities: crate::model_config::ModelCapabilities {
context_length: 32000,
..Default::default()
},
..Default::default()
});
let provider = OpenAICompatibleProvider::from_config(config)
.unwrap()
.with_model("another-model");
assert_eq!(LLMProvider::model(&provider), "another-model");
assert_eq!(provider.max_context_length(), 32000);
std::env::remove_var("TEST_API_KEY5");
}
#[test]
fn test_convert_messages_text_only() {
let messages = vec![ChatMessage::user("Hello, world!")];
let converted = OpenAICompatibleProvider::convert_messages(&messages);
assert_eq!(converted.len(), 1);
assert_eq!(converted[0].role, "user");
let json = serde_json::to_value(&converted[0]).unwrap();
assert_eq!(json["content"], "Hello, world!");
}
#[test]
fn test_convert_messages_with_images() {
use crate::traits::ImageData;
let images = vec![ImageData::new("base64data", "image/png")];
let messages = vec![ChatMessage::user_with_images("What's this?", images)];
let converted = OpenAICompatibleProvider::convert_messages(&messages);
assert_eq!(converted.len(), 1);
let json = serde_json::to_value(&converted[0]).unwrap();
let content = &json["content"];
assert!(content.is_array());
assert_eq!(content.as_array().unwrap().len(), 2);
assert_eq!(content[0]["type"], "text");
assert_eq!(content[0]["text"], "What's this?");
assert_eq!(content[1]["type"], "image_url");
assert!(content[1]["image_url"]["url"]
.as_str()
.unwrap()
.starts_with("data:image/png;base64,"));
}
#[test]
fn test_convert_messages_with_image_detail() {
use crate::traits::ImageData;
let images = vec![ImageData::new("data", "image/jpeg").with_detail("high")];
let messages = vec![ChatMessage::user_with_images("Analyze", images)];
let converted = OpenAICompatibleProvider::convert_messages(&messages);
let json = serde_json::to_value(&converted[0]).unwrap();
let content = &json["content"];
assert_eq!(content[1]["image_url"]["detail"], "high");
}
#[test]
fn test_usage_with_reasoning_tokens() {
let json = r#"{
"prompt_tokens": 32,
"completion_tokens": 9,
"total_tokens": 135,
"completion_tokens_details": {
"reasoning_tokens": 94
}
}"#;
let usage: Usage = serde_json::from_str(json).unwrap();
assert_eq!(usage.prompt_tokens, 32);
assert_eq!(usage.completion_tokens, 9);
assert_eq!(
usage
.completion_tokens_details
.as_ref()
.unwrap()
.reasoning_tokens,
Some(94)
);
}
#[test]
fn test_usage_without_reasoning_tokens() {
let json = r#"{
"prompt_tokens": 100,
"completion_tokens": 50,
"total_tokens": 150
}"#;
let usage: Usage = serde_json::from_str(json).unwrap();
assert_eq!(usage.prompt_tokens, 100);
assert_eq!(usage.completion_tokens, 50);
assert!(usage.completion_tokens_details.is_none());
}
#[test]
fn test_stream_delta_with_reasoning_content() {
let json = r#"{
"content": null,
"reasoning_content": "Let me think about this..."
}"#;
let delta: StreamDelta = serde_json::from_str(json).unwrap();
assert!(delta.content.is_none());
assert_eq!(
delta.reasoning_content,
Some("Let me think about this...".to_string())
);
}
#[test]
fn test_supports_streaming() {
std::env::set_var("TEST_API_KEY_STREAM", "key");
let config = create_test_config_with_key("TEST_API_KEY_STREAM");
let provider = OpenAICompatibleProvider::from_config(config).unwrap();
assert!(provider.supports_streaming());
std::env::remove_var("TEST_API_KEY_STREAM");
}
#[test]
fn test_thinking_config_serialization() {
let config = ThinkingConfig {
thinking_type: "enabled".to_string(),
};
let json = serde_json::to_value(&config).unwrap();
assert_eq!(json["type"], "enabled");
}
#[test]
fn test_response_format_serialization() {
let format = ResponseFormat {
format_type: "json_object".to_string(),
};
let json = serde_json::to_value(&format).unwrap();
assert_eq!(json["type"], "json_object");
}
#[test]
fn test_embedding_provider_name() {
std::env::set_var("TEST_API_KEY_EMB1", "key");
let config = create_test_config_with_key("TEST_API_KEY_EMB1");
let provider = OpenAICompatibleProvider::from_config(config).unwrap();
assert_eq!(EmbeddingProvider::name(&provider), "test-provider");
std::env::remove_var("TEST_API_KEY_EMB1");
}
#[test]
fn test_embedding_provider_model() {
std::env::set_var("TEST_API_KEY_EMB2", "key");
let mut config = create_test_config_with_key("TEST_API_KEY_EMB2");
config.default_embedding_model = Some("text-embedding-ada-002".to_string());
let provider = OpenAICompatibleProvider::from_config(config).unwrap();
assert_eq!(
EmbeddingProvider::model(&provider),
"text-embedding-ada-002"
);
std::env::remove_var("TEST_API_KEY_EMB2");
}
#[test]
fn test_tool_call_request_serialization() {
let tool_call = ToolCallRequest {
id: "call_123".to_string(),
call_type: "function".to_string(),
function: FunctionCallRequest {
name: "get_weather".to_string(),
arguments: r#"{"location":"NYC"}"#.to_string(),
},
};
let json = serde_json::to_value(&tool_call).unwrap();
assert_eq!(json["id"], "call_123");
assert_eq!(json["type"], "function");
assert_eq!(json["function"]["name"], "get_weather");
assert_eq!(json["function"]["arguments"], r#"{"location":"NYC"}"#);
}
#[test]
fn test_function_call_request_serialization() {
let func_call = FunctionCallRequest {
name: "search".to_string(),
arguments: r#"{"query":"rust programming"}"#.to_string(),
};
let json = serde_json::to_value(&func_call).unwrap();
assert_eq!(json["name"], "search");
assert_eq!(json["arguments"], r#"{"query":"rust programming"}"#);
}
fn create_test_config_with_key(env_var: &str) -> ProviderConfig {
ProviderConfig {
name: "test-provider".to_string(),
display_name: "Test Provider".to_string(),
provider_type: crate::model_config::ProviderType::OpenAICompatible,
api_key_env: Some(env_var.to_string()),
base_url: Some("https://api.test.com/v1".to_string()),
default_llm_model: Some("test-model".to_string()),
models: vec![ModelCard {
name: "test-model".to_string(),
display_name: "Test Model".to_string(),
model_type: ModelType::Llm,
capabilities: crate::model_config::ModelCapabilities {
context_length: 128000,
supports_streaming: true,
supports_function_calling: true,
..Default::default()
},
..Default::default()
}],
..Default::default()
}
}
#[test]
fn test_extract_answer_content_returns_content_not_reasoning() {
let msg = MessageContent {
role: None,
content: Some("actual answer".to_string()),
reasoning_content: Some("I should reason about this...".to_string()),
tool_calls: None,
};
assert_eq!(
OpenAICompatibleProvider::extract_answer_content(&msg),
"actual answer"
);
}
#[test]
fn test_extract_answer_content_none_returns_empty() {
let msg = MessageContent {
role: None,
content: None,
reasoning_content: Some("some thinking".to_string()),
tool_calls: None,
};
assert_eq!(OpenAICompatibleProvider::extract_answer_content(&msg), "");
}
#[test]
fn test_extract_thinking_content_returns_reasoning() {
let msg = MessageContent {
role: None,
content: Some("the real answer".to_string()),
reasoning_content: Some("thinking monologue".to_string()),
tool_calls: None,
};
assert_eq!(
OpenAICompatibleProvider::extract_thinking_content(&msg),
Some("thinking monologue".to_string())
);
}
#[test]
fn test_extract_thinking_content_none_when_absent() {
let msg = MessageContent {
role: None,
content: Some("answer".to_string()),
reasoning_content: None,
tool_calls: None,
};
assert!(OpenAICompatibleProvider::extract_thinking_content(&msg).is_none());
}
#[test]
fn test_extract_thinking_content_none_when_empty_string() {
let msg = MessageContent {
role: None,
content: Some("answer".to_string()),
reasoning_content: Some(String::new()),
tool_calls: None,
};
assert!(OpenAICompatibleProvider::extract_thinking_content(&msg).is_none());
}
#[test]
fn test_chat_stream_chunk_usage_deserialization() {
let json = r#"{
"id": "chatcmpl-xyz",
"model": "gpt-4o",
"choices": [],
"usage": {
"prompt_tokens": 42,
"completion_tokens": 100,
"total_tokens": 142
}
}"#;
let chunk: ChatStreamChunk = serde_json::from_str(json).unwrap();
let usage = chunk.usage.expect("usage must be present");
assert_eq!(usage.prompt_tokens, 42);
assert_eq!(usage.completion_tokens, 100);
}
#[test]
fn test_chat_stream_chunk_no_usage_ok() {
let json = r#"{"id":"x","model":"m","choices":[]}"#;
let chunk: ChatStreamChunk = serde_json::from_str(json).unwrap();
assert!(chunk.usage.is_none());
}
#[test]
fn test_chat_request_reasoning_effort_serialized() {
let msgs: Vec<MessageRequest> = vec![];
let req = ChatRequest {
model: "test-model",
messages: msgs,
temperature: None,
top_p: None,
max_tokens: None,
stop: None,
stream: Some(false),
stream_options: None,
tools: None,
tool_choice: None,
response_format: None,
thinking: None,
reasoning_effort: Some("high".to_string()),
seed: None,
frequency_penalty: None,
presence_penalty: None,
user: None,
safe_prompt: None,
parallel_tool_calls: None,
};
let json = serde_json::to_value(&req).unwrap();
assert_eq!(json["reasoning_effort"], "high");
}
#[test]
fn test_chat_request_reasoning_effort_omitted_when_none() {
let msgs: Vec<MessageRequest> = vec![];
let req = ChatRequest {
model: "test-model",
messages: msgs,
temperature: None,
top_p: None,
max_tokens: None,
stop: None,
stream: Some(false),
stream_options: None,
tools: None,
tool_choice: None,
response_format: None,
thinking: None,
reasoning_effort: None,
seed: None,
frequency_penalty: None,
presence_penalty: None,
user: None,
safe_prompt: None,
parallel_tool_calls: None,
};
let json = serde_json::to_value(&req).unwrap();
assert!(json.get("reasoning_effort").is_none());
}
}