use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use futures::stream::BoxStream;
use futures::StreamExt;
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use tracing::{debug, instrument, warn};
use crate::error::{LlmError, Result};
use crate::traits::{
ChatMessage, ChatRole, CompletionOptions, EmbeddingProvider, FunctionCall, ImageData,
LLMProvider, LLMResponse, StreamChunk, StreamUsage, ToolCall, ToolChoice, ToolDefinition,
};
fn openrouter_build_content(msg: &ChatMessage) -> serde_json::Value {
match &msg.images {
Some(imgs) if !imgs.is_empty() => {
let mut parts: Vec<serde_json::Value> = vec![serde_json::json!({
"type": "text",
"text": msg.content
})];
for img in imgs {
parts.push(openrouter_build_image_part(img));
}
serde_json::Value::Array(parts)
}
_ => serde_json::Value::String(msg.content.clone()),
}
}
/// Encodes one image as an OpenAI-style `image_url` content part, embedding
/// the image bytes as a data URI and forwarding the optional `detail` hint.
fn openrouter_build_image_part(img: &ImageData) -> serde_json::Value {
    let mut image_url = serde_json::json!({ "url": img.to_data_uri() });
    if let Some(detail) = img.detail.as_ref() {
        image_url["detail"] = serde_json::json!(detail);
    }
    serde_json::json!({ "type": "image_url", "image_url": image_url })
}
/// Root of OpenRouter's OpenAI-compatible REST API.
const OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";
/// Model used when the caller does not configure one explicitly.
const DEFAULT_MODEL: &str = "anthropic/claude-3.5-sonnet";
/// Completion-token cap sent as `max_tokens` when options omit it.
const DEFAULT_MAX_TOKENS: u32 = 4096;
/// Context-window size reported by `max_context_length()` by default.
const DEFAULT_MAX_CONTEXT_LENGTH: usize = 128_000;
/// JSON body for `POST /chat/completions` (OpenAI-compatible schema).
/// Every optional field is omitted from serialization when `None`.
#[derive(Debug, Serialize)]
struct ChatRequest<'a> {
    /// Model id, borrowed from the provider for the request's lifetime.
    model: &'a str,
    messages: Vec<RequestMessage>,
    /// `Some(true)` selects the SSE streaming response mode.
    #[serde(skip_serializing_if = "Option::is_none")]
    stream: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    presence_penalty: Option<f32>,
    /// Function-calling tools advertised to the model, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<RequestTool>>,
    /// "auto" / "required" / a specific-function selector (see
    /// `convert_tool_choice`).
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<serde_json::Value>,
}
/// One entry of the outgoing `messages` array.
#[derive(Debug, Serialize)]
struct RequestMessage {
    /// "system" | "user" | "assistant" | "tool".
    role: String,
    /// Either a JSON string (text-only) or an array of content parts
    /// (text + images) — see `openrouter_build_content`.
    content: serde_json::Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    name: Option<String>,
    /// Required for role "tool": links the result to the assistant call.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    /// Assistant tool calls echoed back as conversation history.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<RequestToolCall>>,
}
/// A tool advertised to the model; `tool_type` is always "function".
#[derive(Debug, Serialize)]
struct RequestTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: RequestFunction,
}
/// Function schema: name, description, and a JSON-Schema `parameters` object.
#[derive(Debug, Serialize)]
struct RequestFunction {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
/// An assistant-issued tool call replayed in the request history.
#[derive(Debug, Serialize)]
struct RequestToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: RequestFunctionCall,
}
/// Function name plus its raw JSON-encoded argument string.
#[derive(Debug, Serialize)]
struct RequestFunctionCall {
    name: String,
    arguments: String,
}
/// Top-level non-streaming chat-completions response.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ChatResponse {
    id: String,
    /// The model that actually served the request (may differ from the
    /// requested id when OpenRouter routes to a fallback).
    model: String,
    choices: Vec<Choice>,
    usage: Option<Usage>,
}
/// One completion candidate; only the first is consumed by `parse_response`.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Choice {
    index: usize,
    message: Option<ResponseMessage>,
    finish_reason: Option<String>,
}
/// Assistant message payload: text content and/or tool calls.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ResponseMessage {
    role: String,
    /// Absent for pure tool-call responses.
    content: Option<String>,
    tool_calls: Option<Vec<ResponseToolCall>>,
}
/// A tool call requested by the model.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ResponseToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: Option<String>,
    function: ResponseFunctionCall,
}
/// Function name plus raw JSON argument string as returned by the model.
#[derive(Debug, Deserialize)]
struct ResponseFunctionCall {
    name: String,
    arguments: String,
}
/// Token accounting; `total_tokens` is derived client-side when omitted.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct Usage {
    prompt_tokens: u32,
    completion_tokens: u32,
    total_tokens: Option<u32>,
}
/// Standard `{"error": {...}}` error envelope.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ErrorResponse {
    error: ErrorDetail,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct ErrorDetail {
    message: String,
    code: Option<i32>,
}
/// Error payload embedded inside a streaming chunk; `code` may arrive as a
/// string or a number, hence the loose JSON value type.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamErrorDetail {
    code: Option<serde_json::Value>,
    message: String,
}
/// One parsed SSE `data:` payload from a streaming completion.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamChunkResponse {
    id: Option<String>,
    model: Option<String>,
    choices: Vec<StreamChoice>,
    /// Usage typically arrives on the final chunk(s) only.
    usage: Option<Usage>,
    /// In-band error reported mid-stream.
    error: Option<StreamErrorDetail>,
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamChoice {
    index: Option<usize>,
    delta: Option<StreamDelta>,
    finish_reason: Option<String>,
}
/// Incremental message delta; all fields are optional per-chunk.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamDelta {
    role: Option<String>,
    content: Option<String>,
    /// Reasoning text; aliases normalize the provider-specific field names
    /// "thinking" and "reasoning_content" onto one field.
    #[serde(alias = "thinking", alias = "reasoning_content")]
    reasoning: Option<String>,
    tool_calls: Option<Vec<StreamToolCall>>,
}
/// Incremental tool-call fragment; `index` pairs fragments of the same call.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamToolCall {
    index: Option<usize>,
    /// Present only on the fragment that opens a new call.
    id: Option<String>,
    #[serde(rename = "type")]
    call_type: Option<String>,
    function: Option<StreamFunction>,
}
/// Partial function payload: name on the opening fragment, then argument
/// string pieces on subsequent fragments.
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
struct StreamFunction {
    name: Option<String>,
    arguments: Option<String>,
}
/// Response envelope of `GET /models`.
#[derive(Debug, Deserialize, Clone)]
pub struct ModelsResponse {
    pub data: Vec<ModelInfo>,
}
/// Catalog entry for a single OpenRouter model. Defaulted fields tolerate
/// partial listings from the API.
#[derive(Debug, Deserialize, Clone)]
pub struct ModelInfo {
    /// Routing id, e.g. "anthropic/claude-3.5-sonnet".
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    #[serde(default)]
    pub context_length: usize,
    #[serde(default)]
    pub pricing: ModelPricing,
    #[serde(default)]
    pub architecture: ModelArchitecture,
    /// Request parameters this model accepts (e.g. "tools", "top_p").
    #[serde(default)]
    pub supported_parameters: Vec<String>,
    #[serde(default)]
    pub description: Option<String>,
    /// Creation timestamp — presumably Unix seconds; not interpreted here.
    #[serde(default)]
    pub created: Option<u64>,
}
/// Per-model pricing; values are decimal strings as returned by the API.
#[derive(Debug, Deserialize, Clone, Default)]
pub struct ModelPricing {
    #[serde(default)]
    pub prompt: String,
    #[serde(default)]
    pub completion: String,
    #[serde(default)]
    pub request: Option<String>,
    #[serde(default)]
    pub image: Option<String>,
}
/// Modality and tokenizer metadata used for capability filtering.
#[derive(Debug, Deserialize, Clone, Default)]
pub struct ModelArchitecture {
    /// Accepted inputs, e.g. ["text", "image"].
    #[serde(default)]
    pub input_modalities: Vec<String>,
    #[serde(default)]
    pub output_modalities: Vec<String>,
    #[serde(default)]
    pub tokenizer: Option<String>,
    #[serde(default)]
    pub instruct_type: Option<String>,
}
/// Snapshot of the `/models` listing plus when it was fetched, enabling
/// age-based cache validity checks.
#[derive(Debug)]
struct ModelCache {
    models: Vec<ModelInfo>,
    fetched_at: Instant,
}
/// LLM provider backed by OpenRouter's OpenAI-compatible API.
///
/// Clones share the HTTP client's connection pool and the model-list cache.
#[derive(Debug)]
pub struct OpenRouterProvider {
    client: Client,
    api_key: String,
    /// API root; overridable for proxies or testing.
    base_url: String,
    /// Model id used for all requests until changed.
    model: String,
    /// Default `max_tokens` when a request's options omit it.
    max_tokens: u32,
    max_context_length: usize,
    /// Optional attribution sent as the `HTTP-Referer` header.
    site_url: Option<String>,
    /// Optional attribution sent as the `X-Title` header.
    site_name: Option<String>,
    /// Shared, timestamped cache of the `/models` listing.
    model_cache: Arc<RwLock<Option<ModelCache>>>,
}
// Clones duplicate the configuration fields and share the model cache via
// `Arc::clone`, so all clones observe the same cached model list.
// NOTE(review): every field here is itself `Clone`, so `#[derive(Clone)]`
// on the struct appears equivalent — presumably kept manual for explicitness;
// confirm before simplifying.
impl Clone for OpenRouterProvider {
    fn clone(&self) -> Self {
        Self {
            client: self.client.clone(),
            api_key: self.api_key.clone(),
            base_url: self.base_url.clone(),
            model: self.model.clone(),
            max_tokens: self.max_tokens,
            max_context_length: self.max_context_length,
            site_url: self.site_url.clone(),
            site_name: self.site_name.clone(),
            model_cache: Arc::clone(&self.model_cache),
        }
    }
}
impl OpenRouterProvider {
pub fn new(api_key: impl Into<String>) -> Self {
Self {
client: Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()
.expect("Failed to create HTTP client"),
api_key: api_key.into(),
base_url: OPENROUTER_BASE_URL.to_string(),
model: DEFAULT_MODEL.to_string(),
max_tokens: DEFAULT_MAX_TOKENS,
max_context_length: DEFAULT_MAX_CONTEXT_LENGTH,
site_url: None,
site_name: None,
model_cache: Arc::new(RwLock::new(None)),
}
}
/// Builds a provider from environment variables.
///
/// Requires `OPENROUTER_API_KEY`; optionally honors `OPENROUTER_MODEL`,
/// `OPENROUTER_SITE_URL`, and `OPENROUTER_SITE_NAME`.
///
/// # Errors
/// Returns `LlmError::ConfigError` when the API key variable is unset.
pub fn from_env() -> Result<Self> {
    let api_key = std::env::var("OPENROUTER_API_KEY").map_err(|_| {
        LlmError::ConfigError(
            "OPENROUTER_API_KEY environment variable not set. \
            Get your API key at https://openrouter.ai/keys"
                .to_string(),
        )
    })?;
    let mut provider = Self::new(api_key);
    // Optional overrides; defaults from `new()` are kept when unset.
    if let Ok(model) = std::env::var("OPENROUTER_MODEL") {
        provider.model = model;
    }
    if let Ok(url) = std::env::var("OPENROUTER_SITE_URL") {
        provider.site_url = Some(url);
    }
    if let Ok(name) = std::env::var("OPENROUTER_SITE_NAME") {
        provider.site_name = Some(name);
    }
    Ok(provider)
}
/// Sets the model id (e.g. "openai/gpt-4o") and returns `self` for chaining.
pub fn with_model(mut self, model: impl Into<String>) -> Self {
    self.model = model.into();
    self
}
/// Overrides the API base URL (useful for proxies or test servers).
pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
    self.base_url = base_url.into();
    self
}
/// Sets the default completion-token cap used when options omit one.
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
    self.max_tokens = max_tokens;
    self
}
/// Sets the site URL sent as OpenRouter's `HTTP-Referer` attribution header.
pub fn with_site_url(mut self, url: impl Into<String>) -> Self {
    self.site_url = Some(url.into());
    self
}
/// Sets the app name sent as OpenRouter's `X-Title` attribution header.
pub fn with_site_name(mut self, name: impl Into<String>) -> Self {
    self.site_name = Some(name.into());
    self
}
/// Full URL of the chat-completions endpoint under the configured base URL.
fn endpoint(&self) -> String {
    format!("{}/chat/completions", self.base_url)
}
/// Builds the headers sent on every request: bearer auth, JSON content
/// type, and OpenRouter's optional attribution headers when configured.
///
/// # Panics
/// Panics if the API key contains bytes that are invalid in an HTTP header.
fn headers(&self) -> HeaderMap {
    let mut headers = HeaderMap::new();
    headers.insert(
        AUTHORIZATION,
        HeaderValue::from_str(&format!("Bearer {}", self.api_key))
            .expect("Invalid API key format"),
    );
    headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
    // Attribution headers are best-effort: values that are invalid as
    // header bytes are skipped rather than failing the request.
    if let Some(ref url) = self.site_url {
        if let Ok(value) = HeaderValue::from_str(url) {
            headers.insert("HTTP-Referer", value);
        }
    }
    if let Some(ref name) = self.site_name {
        if let Ok(value) = HeaderValue::from_str(name) {
            headers.insert("X-Title", value);
        }
    }
    headers
}
/// Converts provider-agnostic `ChatMessage`s into OpenRouter wire messages.
///
/// # Errors
/// Fails fast on messages the API cannot represent: a Tool-role message
/// without a `tool_call_id`, or the legacy Function role.
fn convert_messages(messages: &[ChatMessage]) -> Result<Vec<RequestMessage>> {
    messages
        .iter()
        .map(|msg| {
            let role = match msg.role {
                ChatRole::System => "system",
                ChatRole::User => "user",
                ChatRole::Assistant => "assistant",
                ChatRole::Tool => {
                    // A tool result must reference the assistant call it answers.
                    if msg.tool_call_id.is_none() {
                        return Err(LlmError::InvalidRequest(
                            "Tool role message is missing required `tool_call_id`. \
                            Ensure every ChatMessage with role=Tool has a tool_call_id \
                            matching the assistant's tool_call."
                                .to_string(),
                        ));
                    }
                    "tool"
                }
                ChatRole::Function => {
                    return Err(LlmError::InvalidRequest(
                        "ChatRole::Function is not supported by OpenRouter. \
                        Use ChatRole::Tool with a tool_call_id instead."
                            .to_string(),
                    ));
                }
            };
            // Plain string for text-only; content-parts array when images exist.
            let content = openrouter_build_content(msg);
            Ok(RequestMessage {
                role: role.to_string(),
                content,
                name: msg.name.clone(),
                tool_call_id: msg.tool_call_id.clone(),
                // Replay assistant tool calls so the provider can pair them
                // with the Tool-role results that follow.
                tool_calls: msg.tool_calls.as_ref().map(|calls| {
                    calls
                        .iter()
                        .map(|tc| RequestToolCall {
                            id: tc.id.clone(),
                            call_type: "function".to_string(),
                            function: RequestFunctionCall {
                                name: tc.function.name.clone(),
                                arguments: tc.function.arguments.clone(),
                            },
                        })
                        .collect()
                }),
            })
        })
        .collect()
}
/// Maps provider-agnostic tool definitions onto the wire format; every
/// tool is of type "function" in the OpenAI-compatible schema.
fn convert_tools(tools: &[ToolDefinition]) -> Vec<RequestTool> {
    let mut converted = Vec::with_capacity(tools.len());
    for tool in tools {
        converted.push(RequestTool {
            tool_type: "function".to_string(),
            function: RequestFunction {
                name: tool.function.name.clone(),
                description: tool.function.description.clone(),
                parameters: tool.function.parameters.clone(),
            },
        });
    }
    converted
}
fn convert_tool_choice(choice: &ToolChoice) -> serde_json::Value {
match choice {
ToolChoice::Auto(_) => serde_json::json!("auto"),
ToolChoice::Required(_) => serde_json::json!("required"),
ToolChoice::Function {
choice_type: _,
function,
} => serde_json::json!({
"type": "function",
"function": { "name": function.name }
}),
}
}
/// Maps a raw OpenRouter chat response onto the provider-agnostic
/// `LLMResponse`, reading only the first choice.
///
/// # Errors
/// Returns `LlmError::ApiError` when the response carries no choices.
fn parse_response(response: ChatResponse) -> Result<LLMResponse> {
    let choice = response.choices.first().ok_or_else(|| {
        LlmError::ApiError(
            "OpenRouter returned a response with no choices. \
            This may indicate a content-filter block or a provider error."
                .to_string(),
        )
    })?;
    let message = choice.message.as_ref();
    // Content may be absent entirely (e.g. pure tool-call responses).
    let content = message
        .and_then(|m| m.content.as_ref())
        .cloned()
        .unwrap_or_default();
    let tool_calls = message
        .and_then(|m| m.tool_calls.as_ref())
        .map(|tcs| {
            tcs.iter()
                .map(|tc| ToolCall {
                    id: tc.id.clone(),
                    call_type: "function".to_string(),
                    function: FunctionCall {
                        name: tc.function.name.clone(),
                        arguments: tc.function.arguments.clone(),
                    },
                    thought_signature: None,
                })
                .collect()
        })
        .unwrap_or_default();
    // Usage is optional; fall back to zeros and derive the total if missing.
    let usage = response.usage.as_ref();
    let prompt_tokens = usage.map(|u| u.prompt_tokens as usize).unwrap_or(0);
    let completion_tokens = usage.map(|u| u.completion_tokens as usize).unwrap_or(0);
    let total_tokens = usage
        .and_then(|u| u.total_tokens)
        .map(|t| t as usize)
        .unwrap_or(prompt_tokens + completion_tokens);
    Ok(LLMResponse {
        content,
        tool_calls,
        prompt_tokens,
        completion_tokens,
        total_tokens,
        model: response.model,
        finish_reason: choice.finish_reason.clone(),
        metadata: HashMap::new(),
        cache_hit_tokens: None,
        thinking_tokens: None,
        thinking_content: None,
    })
}
/// Returns true for transient failures worth retrying: rate limits,
/// network errors, or API errors whose message mentions a gateway 5xx.
#[allow(dead_code)]
fn is_retryable_error(error: &LlmError) -> bool {
    match error {
        LlmError::RateLimited(_) | LlmError::NetworkError(_) => true,
        LlmError::ApiError(msg) => ["502", "503", "504"].iter().any(|code| msg.contains(code)),
        _ => false,
    }
}
/// Returns true when the HTTP status is transient (429 or a common 5xx).
fn is_retryable_status(status: reqwest::StatusCode) -> bool {
    [429u16, 500, 502, 503, 504].contains(&status.as_u16())
}
/// Translates an HTTP error status plus response body into a typed
/// `LlmError`, attaching actionable guidance for OpenRouter's common
/// failure modes. Falls back to raw-body messages when the body is not
/// the standard `{"error": {...}}` JSON envelope.
fn handle_error(status: reqwest::StatusCode, body: &str) -> LlmError {
    if let Ok(error_response) = serde_json::from_str::<ErrorResponse>(body) {
        let message = error_response.error.message;
        match status.as_u16() {
            400 => {
                // Heuristic: 400s mentioning tools/functions usually mean the
                // selected model lacks function-calling support.
                if message.to_lowercase().contains("tool")
                    || message.to_lowercase().contains("function")
                    || message.contains("not supported")
                    || message.contains("No endpoints found")
                {
                    LlmError::InvalidRequest(format!(
                        "Model doesn't support function calling: {}.\n\
                        \n\
                        💡 EdgeCode React agent requires function calling support.\n\
                        \n\
                        Try one of these compatible models:\n\
                        - anthropic/claude-3.5-sonnet (recommended)\n\
                        - openai/gpt-4o\n\
                        - google/gemini-2.0-flash-exp\n\
                        - meta-llama/llama-3.3-70b-instruct\n\
                        \n\
                        Use /model to select a different model.",
                        message
                    ))
                } else {
                    LlmError::InvalidRequest(format!(
                        "{}. Check that the model name is correct and the request format is valid.",
                        message
                    ))
                }
            }
            401 => LlmError::AuthError(message),
            402 => {
                LlmError::ApiError(format!(
                    "Insufficient credits: {}. Add credits at https://openrouter.ai/credits",
                    message
                ))
            }
            403 => {
                // Distinguish geo-restrictions and moderation blocks from
                // generic access failures.
                if message.contains("not available in your region")
                    || message.contains("region")
                {
                    LlmError::ApiError(format!(
                        "Regional restriction: {}. This model is not available in your geographic region. Try selecting a different model with /model or check OpenRouter's model availability at https://openrouter.ai/docs/models",
                        message
                    ))
                } else if message.contains("moderation") {
                    LlmError::ApiError(format!(
                        "Content policy violation: {}. Your request was blocked by content moderation. Review OpenRouter's policies or try a different model.",
                        message
                    ))
                } else {
                    LlmError::ApiError(format!(
                        "Access forbidden: {}. This may be due to model availability, account restrictions, or content policy. Try a different model with /model",
                        message
                    ))
                }
            }
            404 => {
                LlmError::ApiError(format!(
                    "Model not found: {}. The model name may be incorrect or the model may have been removed. Use /model to see available models.",
                    message
                ))
            }
            429 => LlmError::RateLimited(message),
            _ => LlmError::ApiError(format!("{}: {}", status, message)),
        }
    } else {
        // Non-JSON body: include the raw text for the statuses users hit most.
        match status.as_u16() {
            403 => LlmError::ApiError(format!(
                "403 Forbidden: {}. This model may not be available in your region or due to account restrictions. Try selecting a different model with /model",
                body
            )),
            404 => LlmError::ApiError(format!(
                "404 Not Found: {}. The model may not exist or has been removed. Use /model to see available models.",
                body
            )),
            _ => LlmError::ApiError(format!("{}: {}", status, body)),
        }
    }
}
/// Sends a chat request with up to `MAX_RETRIES` retries and exponential
/// backoff (1s, 2s, 4s) for network failures and retryable HTTP statuses.
///
/// # Errors
/// Returns immediately for non-retryable failures (auth, bad request, …);
/// otherwise returns the last observed error once retries are exhausted.
#[instrument(skip(self, request))]
async fn send_request(&self, request: &ChatRequest<'_>) -> Result<ChatResponse> {
    const MAX_RETRIES: u32 = 3;
    const BASE_DELAY_MS: u64 = 1000;
    let mut last_error = None;
    for attempt in 0..=MAX_RETRIES {
        if attempt > 0 {
            // Exponential backoff: BASE * 2^(attempt - 1).
            let delay_ms = BASE_DELAY_MS * (1 << (attempt - 1));
            debug!(
                "Retry attempt {}/{} after {}ms delay",
                attempt, MAX_RETRIES, delay_ms
            );
            tokio::time::sleep(Duration::from_millis(delay_ms)).await;
        }
        debug!(
            "OpenRouter request: model={}, messages={}, tools={}, max_tokens={:?}, stream={:?}",
            request.model,
            request.messages.len(),
            request.tools.as_ref().map(|t| t.len()).unwrap_or(0),
            request.max_tokens,
            request.stream
        );
        let start = std::time::Instant::now();
        // Transport-level send failure: retry while budget remains.
        let response = match self
            .client
            .post(self.endpoint())
            .headers(self.headers())
            .json(request)
            .send()
            .await
        {
            Ok(resp) => resp,
            Err(e) => {
                let error = LlmError::NetworkError(e.to_string());
                if attempt < MAX_RETRIES {
                    warn!(
                        "Network error (attempt {}/{}): {}",
                        attempt + 1,
                        MAX_RETRIES + 1,
                        e
                    );
                    last_error = Some(error);
                    continue;
                }
                return Err(error);
            }
        };
        let status = response.status();
        // A failure while reading the body is treated like a network error.
        let body = match response.text().await {
            Ok(b) => b,
            Err(e) => {
                let error = LlmError::NetworkError(e.to_string());
                if attempt < MAX_RETRIES {
                    last_error = Some(error);
                    continue;
                }
                return Err(error);
            }
        };
        if !status.is_success() {
            let error = Self::handle_error(status, &body);
            // Only 429/5xx statuses are retried; everything else fails fast.
            if Self::is_retryable_status(status) && attempt < MAX_RETRIES {
                warn!(
                    "Retryable error (attempt {}/{}): {:?}",
                    attempt + 1,
                    MAX_RETRIES + 1,
                    error
                );
                last_error = Some(error);
                continue;
            }
            return Err(error);
        }
        let elapsed = start.elapsed();
        let response: ChatResponse = serde_json::from_str(&body)
            .map_err(|e| LlmError::ApiError(format!("Failed to parse response: {}", e)))?;
        debug!(
            "OpenRouter response: elapsed={}ms, prompt_tokens={}, completion_tokens={}, total_tokens={}, finish_reason={:?}",
            elapsed.as_millis(),
            response.usage.as_ref().map(|u| u.prompt_tokens).unwrap_or(0),
            response.usage.as_ref().map(|u| u.completion_tokens).unwrap_or(0),
            response.usage.as_ref().and_then(|u| u.total_tokens).unwrap_or(0),
            response.choices.first().and_then(|c| c.finish_reason.as_deref()).unwrap_or("none")
        );
        return Ok(response);
    }
    // Reached only when every attempt hit a retryable failure.
    Err(last_error.unwrap_or_else(|| LlmError::ApiError("Max retries exceeded".to_string())))
}
/// Fetches the full model catalog from `GET /models` and refreshes the
/// shared cache with the result.
///
/// # Errors
/// Network failures, non-success HTTP statuses, and JSON parse failures.
#[instrument(skip(self))]
pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
    debug!("Fetching models from OpenRouter API");
    let url = format!("{}/models", self.base_url);
    let response = self
        .client
        .get(&url)
        .headers(self.headers())
        .send()
        .await
        .map_err(|e| LlmError::NetworkError(e.to_string()))?;
    let status = response.status();
    let body = response
        .text()
        .await
        .map_err(|e| LlmError::NetworkError(e.to_string()))?;
    if !status.is_success() {
        return Err(Self::handle_error(status, &body));
    }
    let models_response: ModelsResponse = serde_json::from_str(&body)
        .map_err(|e| LlmError::ApiError(format!("Failed to parse models response: {}", e)))?;
    debug!(
        "Fetched {} models from OpenRouter",
        models_response.data.len()
    );
    // Refresh the shared cache; write guard scoped so it drops before return.
    {
        let mut cache = self.model_cache.write().await;
        *cache = Some(ModelCache {
            models: models_response.data.clone(),
            fetched_at: Instant::now(),
        });
    }
    Ok(models_response.data)
}
/// Returns the model list, serving from the cache when it is younger than
/// `max_age` and refreshing from the API otherwise.
pub async fn list_models_cached(&self, max_age: Duration) -> Result<Vec<ModelInfo>> {
    {
        // Read-lock scope: the guard must drop before the refresh call below.
        let cache = self.model_cache.read().await;
        if let Some(cached) = cache.as_ref() {
            let age = cached.fetched_at.elapsed();
            if age < max_age {
                debug!(
                    "Using cached models ({} models, age: {:?})",
                    cached.models.len(),
                    age
                );
                return Ok(cached.models.clone());
            }
        }
    }
    self.list_models().await
}
/// Clears any cached model list, forcing the next lookup to hit the API.
pub async fn invalidate_model_cache(&self) {
    *self.model_cache.write().await = None;
    debug!("Model cache invalidated");
}
/// Looks up a single model by exact id, using a cache up to one hour old.
pub async fn get_model(&self, model_id: &str) -> Result<Option<ModelInfo>> {
    Ok(self
        .list_models_cached(Duration::from_secs(3600))
        .await?
        .into_iter()
        .find(|m| m.id == model_id))
}
/// Returns all models whose architecture accepts the given input modality
/// (e.g. "image"), using a cache up to one hour old.
pub async fn get_models_by_modality(&self, modality: &str) -> Result<Vec<ModelInfo>> {
    let models = self.list_models_cached(Duration::from_secs(3600)).await?;
    let matching = models
        .into_iter()
        .filter(|m| {
            m.architecture
                .input_modalities
                .iter()
                .any(|im| im.as_str() == modality)
        })
        .collect();
    Ok(matching)
}
/// Number of models currently held in the cache (0 when empty).
pub async fn cached_model_count(&self) -> usize {
    match self.model_cache.read().await.as_ref() {
        Some(cache) => cache.models.len(),
        None => 0,
    }
}
/// Whether a cached model list exists and is younger than `max_age`.
pub async fn is_cache_valid(&self, max_age: Duration) -> bool {
    self.model_cache
        .read()
        .await
        .as_ref()
        .is_some_and(|c| c.fetched_at.elapsed() < max_age)
}
}
#[async_trait]
impl LLMProvider for OpenRouterProvider {
/// Provider identifier used in logs and provider lookups.
fn name(&self) -> &str {
    "openrouter"
}
/// Currently configured model id (e.g. "anthropic/claude-3.5-sonnet").
fn model(&self) -> &str {
    &self.model
}
/// Advertised context window; a static default unless overridden.
fn max_context_length(&self) -> usize {
    self.max_context_length
}
/// Convenience wrapper: completes a single prompt with default options.
#[instrument(skip(self, prompt))]
async fn complete(&self, prompt: &str) -> Result<LLMResponse> {
    let options = CompletionOptions::default();
    self.complete_with_options(prompt, &options).await
}
/// Completes a single user prompt, optionally prepending the system prompt
/// from `options`, by delegating to `chat`.
#[instrument(skip(self, prompt, options))]
async fn complete_with_options(
    &self,
    prompt: &str,
    options: &CompletionOptions,
) -> Result<LLMResponse> {
    let mut messages = match &options.system_prompt {
        Some(system) => vec![ChatMessage::system(system)],
        None => Vec::new(),
    };
    messages.push(ChatMessage::user(prompt));
    self.chat(&messages, Some(options)).await
}
/// Non-streaming chat completion without tools.
///
/// Missing options fall back to defaults; `max_tokens` defaults to the
/// provider's configured cap. NOTE(review): the usize→u32 cast assumes
/// callers pass max_tokens < 2^32 — larger values would silently truncate.
#[instrument(skip(self, messages, options))]
async fn chat(
    &self,
    messages: &[ChatMessage],
    options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
    let options = options.cloned().unwrap_or_default();
    let request = ChatRequest {
        model: &self.model,
        messages: Self::convert_messages(messages)?,
        stream: Some(false),
        max_tokens: Some(options.max_tokens.unwrap_or(self.max_tokens as usize) as u32),
        temperature: options.temperature,
        top_p: options.top_p,
        stop: options.stop.clone(),
        frequency_penalty: options.frequency_penalty,
        presence_penalty: options.presence_penalty,
        tools: None,
        tool_choice: None,
    };
    let response = self.send_request(&request).await?;
    Self::parse_response(response)
}
/// Non-streaming chat completion with function-calling tools.
///
/// Fix: the `tools` field is only serialized when at least one tool is
/// supplied. OpenAI-compatible endpoints reject requests whose `tools`
/// array is empty (the schema requires at least one item), so an empty
/// slice now omits the field — and `tool_choice`, which is meaningless
/// without tools — instead of sending `"tools": []`.
#[instrument(skip(self, messages, tools, options))]
async fn chat_with_tools(
    &self,
    messages: &[ChatMessage],
    tools: &[ToolDefinition],
    tool_choice: Option<ToolChoice>,
    options: Option<&CompletionOptions>,
) -> Result<LLMResponse> {
    let options = options.cloned().unwrap_or_default();
    let (tools_field, tool_choice_field) = if tools.is_empty() {
        (None, None)
    } else {
        (
            Some(Self::convert_tools(tools)),
            tool_choice.map(|tc| Self::convert_tool_choice(&tc)),
        )
    };
    let request = ChatRequest {
        model: &self.model,
        messages: Self::convert_messages(messages)?,
        stream: Some(false),
        // NOTE: usize→u32 cast mirrors `chat`; assumes values < 2^32.
        max_tokens: Some(options.max_tokens.unwrap_or(self.max_tokens as usize) as u32),
        temperature: options.temperature,
        top_p: options.top_p,
        stop: options.stop.clone(),
        frequency_penalty: options.frequency_penalty,
        presence_penalty: options.presence_penalty,
        tools: tools_field,
        tool_choice: tool_choice_field,
    };
    let response = self.send_request(&request).await?;
    Self::parse_response(response)
}
/// Streams plain-text completion deltas for a single user prompt.
///
/// Sends a streaming chat request and parses the SSE response, yielding one
/// `String` per network chunk containing the concatenated content deltas it
/// held (possibly empty for keep-alive chunks). Unparseable `data:` lines
/// are skipped; in-band stream errors become `LlmError::ApiError` items.
#[instrument(skip(self, prompt))]
async fn stream(&self, prompt: &str) -> Result<BoxStream<'static, Result<String>>> {
    let messages = vec![ChatMessage::user(prompt)];
    let request = ChatRequest {
        model: &self.model,
        messages: Self::convert_messages(&messages)?,
        stream: Some(true),
        max_tokens: Some(self.max_tokens),
        temperature: None,
        top_p: None,
        stop: None,
        frequency_penalty: None,
        presence_penalty: None,
        tools: None,
        tool_choice: None,
    };
    // Serialize up front so the borrowed request need not live inside the
    // 'static stream closure.
    let request_body = serde_json::to_string(&request)
        .map_err(|e| LlmError::InvalidRequest(format!("Failed to serialize request: {}", e)))?;
    let client = self.client.clone();
    let endpoint = self.endpoint();
    let headers = self.headers();
    let response = client
        .post(&endpoint)
        .headers(headers)
        .body(request_body)
        .send()
        .await
        .map_err(|e| LlmError::NetworkError(e.to_string()))?;
    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        return Err(Self::handle_error(status, &body));
    }
    // Carries partial SSE lines across network-chunk boundaries.
    let mut line_buffer = String::new();
    let stream = response.bytes_stream().map(move |chunk| {
        let chunk = chunk.map_err(|e| LlmError::NetworkError(e.to_string()))?;
        let text = String::from_utf8_lossy(&chunk);
        line_buffer.push_str(&text);
        let mut content = String::new();
        // Only complete lines are processed; the trailing partial stays buffered.
        while let Some(newline_idx) = line_buffer.find('\n') {
            let line = line_buffer[..newline_idx].trim().to_string();
            line_buffer.drain(..=newline_idx);
            // Skip blank separators and SSE comment lines.
            if line.is_empty() || line.starts_with(':') {
                continue;
            }
            if let Some(data) = line.strip_prefix("data: ") {
                if data == "[DONE]" {
                    continue;
                }
                if let Ok(chunk) = serde_json::from_str::<StreamChunkResponse>(data) {
                    // An in-band error aborts this item with an ApiError.
                    if let Some(err) = chunk.error {
                        return Err(LlmError::ApiError(format!(
                            "OpenRouter stream error: {}",
                            err.message
                        )));
                    }
                    for choice in chunk.choices {
                        if let Some(delta) = choice.delta {
                            if let Some(c) = delta.content {
                                content.push_str(&c);
                            }
                        }
                    }
                }
            }
        }
        Ok(content)
    });
    Ok(stream.boxed())
}
/// Streaming variant of `chat_with_tools`: yields structured `StreamChunk`s
/// (content deltas, thinking/reasoning deltas, incremental tool-call deltas,
/// and `Finished` markers carrying the latest usage report).
///
/// Tool-call arguments arrive incrementally; consumers must accumulate
/// `ToolCallDelta`s by `index` until a `Finished` chunk arrives.
async fn chat_with_tools_stream(
    &self,
    messages: &[ChatMessage],
    tools: &[ToolDefinition],
    tool_choice: Option<ToolChoice>,
    options: Option<&CompletionOptions>,
) -> Result<BoxStream<'static, Result<StreamChunk>>> {
    let options = options.cloned().unwrap_or_default();
    let request = ChatRequest {
        model: &self.model,
        messages: Self::convert_messages(messages)?,
        stream: Some(true),
        max_tokens: Some(options.max_tokens.unwrap_or(self.max_tokens as usize) as u32),
        temperature: options.temperature,
        top_p: options.top_p,
        stop: options.stop.clone(),
        frequency_penalty: options.frequency_penalty,
        presence_penalty: options.presence_penalty,
        tools: Some(Self::convert_tools(tools)),
        tool_choice: tool_choice.map(|tc| Self::convert_tool_choice(&tc)),
    };
    // Serialize up front so the stream closures can be 'static.
    let request_body = serde_json::to_string(&request)
        .map_err(|e| LlmError::InvalidRequest(format!("Failed to serialize request: {}", e)))?;
    let client = self.client.clone();
    let endpoint = self.endpoint();
    let headers = self.headers();
    let response = client
        .post(&endpoint)
        .headers(headers)
        .body(request_body)
        .send()
        .await
        .map_err(|e| LlmError::NetworkError(e.to_string()))?;
    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        return Err(Self::handle_error(status, &body));
    }
    // SSE line-reassembly buffer, plus the most recent usage report (it is
    // attached to every Finished chunk emitted afterwards).
    let mut line_buffer = String::new();
    let mut latest_usage: Option<StreamUsage> = None;
    let stream = response
        .bytes_stream()
        .map(move |chunk| -> Result<Vec<StreamChunk>> {
            let chunk = chunk.map_err(|e| LlmError::NetworkError(e.to_string()))?;
            let text = String::from_utf8_lossy(&chunk);
            line_buffer.push_str(&text);
            let mut chunks = Vec::new();
            // Only complete lines are parsed; the partial tail stays buffered.
            while let Some(newline_idx) = line_buffer.find('\n') {
                let line = line_buffer[..newline_idx].trim().to_string();
                line_buffer.drain(..=newline_idx);
                if line.is_empty() || line.starts_with(':') {
                    continue;
                }
                if let Some(data) = line.strip_prefix("data: ") {
                    if data == "[DONE]" {
                        chunks.push(StreamChunk::Finished {
                            reason: "stop".to_string(),
                            ttft_ms: None,
                            usage: latest_usage.clone(),
                        });
                        continue;
                    }
                    if let Ok(chunk_response) =
                        serde_json::from_str::<StreamChunkResponse>(data)
                    {
                        if let Some(usage) = chunk_response.usage.as_ref() {
                            latest_usage = Some(StreamUsage::new(
                                usage.prompt_tokens as usize,
                                usage.completion_tokens as usize,
                            ));
                        }
                        // An in-band error aborts this batch with an ApiError.
                        if let Some(err) = chunk_response.error {
                            return Err(LlmError::ApiError(format!(
                                "OpenRouter stream error: {}",
                                err.message
                            )));
                        }
                        for choice in chunk_response.choices {
                            if let Some(delta) = choice.delta {
                                // Reasoning deltas (serde aliases normalize the
                                // provider-specific field names).
                                if let Some(reasoning) = delta.reasoning {
                                    if !reasoning.is_empty() {
                                        chunks.push(StreamChunk::ThinkingContent {
                                            text: reasoning,
                                            tokens_used: None,
                                            budget_total: None,
                                        });
                                    }
                                }
                                if let Some(content) = delta.content {
                                    if !content.is_empty() {
                                        chunks.push(StreamChunk::Content(content));
                                    }
                                }
                                if let Some(tool_calls) = delta.tool_calls {
                                    for tc in tool_calls {
                                        if let Some(func) = tc.function {
                                            let args = func.arguments.unwrap_or_default();
                                            // Emit a delta when it carries new
                                            // arguments or opens a new call (id).
                                            if !args.is_empty() || tc.id.is_some() {
                                                chunks.push(StreamChunk::ToolCallDelta {
                                                    index: tc.index.unwrap_or(0),
                                                    id: tc.id,
                                                    function_name: func.name,
                                                    function_arguments: if args.is_empty() {
                                                        None
                                                    } else {
                                                        Some(args)
                                                    },
                                                    thought_signature: None,
                                                });
                                            }
                                        }
                                    }
                                }
                            }
                            if let Some(reason) = choice.finish_reason {
                                chunks.push(StreamChunk::Finished {
                                    reason,
                                    ttft_ms: None,
                                    usage: latest_usage.clone(),
                                });
                            }
                        }
                    }
                }
            }
            Ok(chunks)
        })
        .flat_map(|result: Result<Vec<StreamChunk>>| {
            // Flatten each parsed batch into individual stream items; an error
            // becomes a single Err item.
            futures::stream::iter(match result {
                Ok(chunks) => chunks.into_iter().map(Ok).collect::<Vec<_>>(),
                Err(e) => vec![Err(e)],
            })
        });
    Ok(stream.boxed())
}
/// SSE streaming is implemented (`stream`, `chat_with_tools_stream`).
fn supports_streaming(&self) -> bool {
    true
}
/// Function calling is implemented (`chat_with_tools`).
fn supports_function_calling(&self) -> bool {
    true
}
/// Incremental tool-call deltas are emitted during streaming.
fn supports_tool_streaming(&self) -> bool {
    true
}
}
// OpenRouter exposes no embeddings endpoint. This impl exists so the type
// can be passed where an EmbeddingProvider is expected, but every `embed`
// call fails with an explanatory error; the metadata getters return inert
// placeholder values.
#[async_trait]
impl EmbeddingProvider for OpenRouterProvider {
    fn name(&self) -> &str {
        "openrouter"
    }
    /// No embedding model exists; "none" signals the placeholder.
    fn model(&self) -> &str {
        "none"
    }
    /// Zero dimension: no embeddings are produced.
    fn dimension(&self) -> usize {
        0
    }
    fn max_tokens(&self) -> usize {
        0
    }
    /// Always fails — use a dedicated embedding provider instead.
    async fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
        Err(LlmError::InvalidRequest(
            "OpenRouter does not support embeddings. Use a dedicated embedding provider."
                .to_string(),
        ))
    }
}
#[cfg(test)]
mod tests {
use super::*;
// Constructor applies the documented defaults.
#[test]
fn test_new_provider() {
    let provider = OpenRouterProvider::new("test-key");
    assert_eq!(provider.api_key, "test-key");
    assert_eq!(provider.model, DEFAULT_MODEL);
    assert_eq!(provider.base_url, OPENROUTER_BASE_URL);
}
// Builder setters store their values.
#[test]
fn test_with_model() {
    let provider = OpenRouterProvider::new("test-key").with_model("openai/gpt-4o");
    assert_eq!(provider.model, "openai/gpt-4o");
}
#[test]
fn test_with_base_url() {
    let provider =
        OpenRouterProvider::new("test-key").with_base_url("https://custom.openrouter.ai");
    assert_eq!(provider.base_url, "https://custom.openrouter.ai");
}
#[test]
fn test_with_site_url() {
    let provider = OpenRouterProvider::new("test-key").with_site_url("https://myapp.com");
    assert_eq!(provider.site_url, Some("https://myapp.com".to_string()));
}
#[test]
fn test_with_site_name() {
    let provider = OpenRouterProvider::new("test-key").with_site_name("My App");
    assert_eq!(provider.site_name, Some("My App".to_string()));
}
// NOTE(review): mutates process-wide env state; this can race with other
// tests that read OPENROUTER_API_KEY when the harness runs in parallel.
#[test]
fn test_from_env_missing_key() {
    std::env::remove_var("OPENROUTER_API_KEY");
    let result = OpenRouterProvider::from_env();
    assert!(result.is_err());
}
// Endpoint is the chat-completions path under the base URL.
#[test]
fn test_endpoint() {
    let provider = OpenRouterProvider::new("test-key");
    assert_eq!(
        provider.endpoint(),
        format!("{}/chat/completions", OPENROUTER_BASE_URL)
    );
}
// Auth, content-type, and both attribution headers are present when set.
#[test]
fn test_headers() {
    let provider = OpenRouterProvider::new("test-key")
        .with_site_url("https://example.com")
        .with_site_name("Example");
    let headers = provider.headers();
    assert!(headers.contains_key(AUTHORIZATION));
    assert!(headers.contains_key(CONTENT_TYPE));
    assert!(headers.contains_key("HTTP-Referer"));
    assert!(headers.contains_key("X-Title"));
}
// Text-only messages convert to role strings plus plain-string content.
#[test]
fn test_convert_messages() {
    let messages = vec![
        ChatMessage::system("You are helpful."),
        ChatMessage::user("Hello!"),
    ];
    let converted = OpenRouterProvider::convert_messages(&messages).unwrap();
    assert_eq!(converted.len(), 2);
    assert_eq!(converted[0].role, "system");
    assert_eq!(converted[0].content, serde_json::json!("You are helpful."));
    assert_eq!(converted[1].role, "user");
    assert_eq!(converted[1].content, serde_json::json!("Hello!"));
}
// Vision messages become a content-parts array: text first, then a
// data-URI image_url part.
#[test]
fn test_convert_messages_with_vision_images() {
    use crate::traits::ImageData;
    let img = ImageData::new("iVBORw0KGgo=", "image/png");
    let messages = vec![ChatMessage::user_with_images(
        "What is in this image?",
        vec![img],
    )];
    let converted = OpenRouterProvider::convert_messages(&messages).unwrap();
    assert_eq!(converted.len(), 1);
    let content = &converted[0].content;
    assert!(
        content.is_array(),
        "Vision message must serialize as content-parts array"
    );
    let parts = content.as_array().unwrap();
    assert_eq!(parts[0]["type"], "text");
    assert_eq!(parts[0]["text"], "What is in this image?");
    assert_eq!(parts[1]["type"], "image_url");
    let url = &parts[1]["image_url"]["url"];
    assert!(
        url.as_str().unwrap().starts_with("data:image/png;base64,"),
        "Image URL must be a data URI, got: {}",
        url
    );
}
// Regression guard: messages without images must stay plain strings.
#[test]
fn test_convert_messages_text_only_still_string() {
    let messages = vec![ChatMessage::user("plain text only")];
    let converted = OpenRouterProvider::convert_messages(&messages).unwrap();
    let content = &converted[0].content;
    assert!(
        content.is_string(),
        "Text-only message must serialize as a plain JSON string"
    );
    assert_eq!(content.as_str().unwrap(), "plain text only");
}
// All three tool_choice encodings round-trip to the expected JSON.
#[test]
fn test_convert_tool_choice() {
    assert_eq!(
        OpenRouterProvider::convert_tool_choice(&ToolChoice::auto()),
        serde_json::json!("auto")
    );
    assert_eq!(
        OpenRouterProvider::convert_tool_choice(&ToolChoice::required()),
        serde_json::json!("required")
    );
    assert_eq!(
        OpenRouterProvider::convert_tool_choice(&ToolChoice::function("my_func")),
        serde_json::json!({"type": "function", "function": {"name": "my_func"}})
    );
}
#[test]
fn test_parse_response() {
    // A plain assistant reply should surface content, token usage, model id
    // and finish reason on the parsed LLMResponse.
    let raw = ChatResponse {
        id: "gen-123".to_string(),
        model: "anthropic/claude-3.5-sonnet".to_string(),
        usage: Some(Usage {
            prompt_tokens: 10,
            completion_tokens: 5,
            total_tokens: Some(15),
        }),
        choices: vec![Choice {
            index: 0,
            finish_reason: Some("stop".to_string()),
            message: Some(ResponseMessage {
                role: "assistant".to_string(),
                content: Some("Hello!".to_string()),
                tool_calls: None,
            }),
        }],
    };

    let parsed = OpenRouterProvider::parse_response(raw).unwrap();
    assert_eq!(parsed.content, "Hello!");
    assert_eq!(
        (
            parsed.prompt_tokens,
            parsed.completion_tokens,
            parsed.total_tokens
        ),
        (10, 5, 15)
    );
    assert_eq!(parsed.model, "anthropic/claude-3.5-sonnet");
    assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
}
#[test]
fn test_parse_response_with_tool_calls() {
    // When the model answers with tool calls only, content is None and the
    // calls must be surfaced on the parsed LLMResponse.
    let raw = ChatResponse {
        id: "gen-456".to_string(),
        model: "openai/gpt-4o".to_string(),
        usage: Some(Usage {
            prompt_tokens: 20,
            completion_tokens: 10,
            total_tokens: Some(30),
        }),
        choices: vec![Choice {
            index: 0,
            finish_reason: Some("tool_calls".to_string()),
            message: Some(ResponseMessage {
                role: "assistant".to_string(),
                content: None,
                tool_calls: Some(vec![ResponseToolCall {
                    id: "call_1".to_string(),
                    call_type: Some("function".to_string()),
                    function: ResponseFunctionCall {
                        name: "get_weather".to_string(),
                        arguments: r#"{"location":"Paris"}"#.to_string(),
                    },
                }]),
            }),
        }],
    };

    let parsed = OpenRouterProvider::parse_response(raw).unwrap();
    assert_eq!(parsed.tool_calls.len(), 1);
    let call = &parsed.tool_calls[0];
    assert_eq!(call.id, "call_1");
    assert_eq!(call.name(), "get_weather");
    assert!(call.arguments().contains("Paris"));
}
#[test]
fn test_provider_trait() {
    let p = OpenRouterProvider::new("test-key");
    // Fully-qualified calls prove these identifiers come from the LLMProvider
    // trait impl rather than inherent methods.
    assert_eq!(LLMProvider::name(&p), "openrouter");
    assert_eq!(LLMProvider::model(&p), DEFAULT_MODEL);
    assert!(p.supports_streaming());
    assert!(p.supports_function_calling());
}
#[test]
fn test_tool_message_with_id_succeeds() {
    // A well-formed tool result keeps role "tool", forwards the call id,
    // and keeps the content as a plain JSON string.
    let converted = OpenRouterProvider::convert_messages(&[ChatMessage::tool_result(
        "call_abc123",
        "42 degrees",
    )])
    .unwrap();
    assert_eq!(converted.len(), 1);

    let m = &converted[0];
    assert_eq!(m.role, "tool");
    assert_eq!(
        m.tool_call_id.as_deref(),
        Some("call_abc123"),
        "tool_call_id must be forwarded"
    );
    assert_eq!(
        m.content,
        serde_json::json!("42 degrees"),
        "Tool content must be a plain string"
    );
}
#[test]
fn test_tool_message_missing_id_returns_err() {
    // Build an ill-formed tool message: role = Tool but no tool_call_id.
    let mut orphan = ChatMessage::user("orphan tool result");
    orphan.role = ChatRole::Tool;
    orphan.tool_call_id = None;

    let err = OpenRouterProvider::convert_messages(&[orphan])
        .expect_err("Tool message without tool_call_id must return Err");
    let text = err.to_string();
    assert!(
        text.contains("tool_call_id"),
        "Error message must mention tool_call_id, got: {text}"
    );
}
#[test]
fn test_function_role_returns_err() {
    // The legacy OpenAI "function" role is not supported by OpenRouter and
    // must be rejected with an error that names the role.
    let mut legacy = ChatMessage::user("result");
    legacy.role = ChatRole::Function;

    let err = OpenRouterProvider::convert_messages(&[legacy])
        .expect_err("ChatRole::Function must return Err (not supported by OpenRouter)");
    let text = err.to_string();
    assert!(
        text.contains("Function"),
        "Error message must mention Function role, got: {text}"
    );
}
#[test]
fn test_parse_response_empty_choices_returns_err() {
    // No choices means there is nothing to parse; this must surface as a
    // hard error rather than a silent empty LLMResponse.
    let empty = ChatResponse {
        id: "gen-empty".to_string(),
        model: "openai/gpt-4o".to_string(),
        choices: Vec::new(),
        usage: None,
    };
    assert!(
        OpenRouterProvider::parse_response(empty).is_err(),
        "Empty choices array must return Err, not a silent empty LLMResponse"
    );
}
#[test]
fn test_parse_response_uses_api_total_tokens() {
    // The API-reported total (20) deliberately disagrees with
    // prompt + completion (15) so we can tell which one is propagated.
    let raw = ChatResponse {
        id: "gen-total".to_string(),
        model: "openai/gpt-4o".to_string(),
        usage: Some(Usage {
            prompt_tokens: 8,
            completion_tokens: 7,
            total_tokens: Some(20),
        }),
        choices: vec![Choice {
            index: 0,
            finish_reason: Some("stop".to_string()),
            message: Some(ResponseMessage {
                role: "assistant".to_string(),
                content: Some("hi".to_string()),
                tool_calls: None,
            }),
        }],
    };
    let parsed = OpenRouterProvider::parse_response(raw).unwrap();
    assert_eq!(
        parsed.total_tokens, 20,
        "total_tokens must use API-provided value, not prompt+completion sum"
    );
}
#[test]
fn test_parse_response_falls_back_to_sum_when_no_total() {
    // With total_tokens absent, the parser must derive 5 + 3 = 8 itself.
    let raw = ChatResponse {
        id: "gen-sum".to_string(),
        model: "openai/gpt-4o".to_string(),
        usage: Some(Usage {
            prompt_tokens: 5,
            completion_tokens: 3,
            total_tokens: None,
        }),
        choices: vec![Choice {
            index: 0,
            finish_reason: Some("stop".to_string()),
            message: Some(ResponseMessage {
                role: "assistant".to_string(),
                content: Some("ok".to_string()),
                tool_calls: None,
            }),
        }],
    };
    let parsed = OpenRouterProvider::parse_response(raw).unwrap();
    assert_eq!(parsed.total_tokens, 8);
}
#[test]
fn test_convert_messages_forwards_name() {
    // An optional participant name on the message must survive conversion
    // into the request message.
    let mut named = ChatMessage::user("Hello");
    named.name = Some("Alice".to_string());

    let converted = OpenRouterProvider::convert_messages(&[named]).unwrap();
    assert_eq!(converted[0].name.as_deref(), Some("Alice"));
}
#[test]
fn test_assistant_with_tool_calls_serialized() {
    use crate::traits::{FunctionCall, ToolCall};

    // An assistant turn that carries tool calls must forward them verbatim
    // (id and function name intact) on the converted request message.
    let call = ToolCall {
        id: "call_xyz".to_string(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: "get_weather".to_string(),
            arguments: r#"{"city":"Paris"}"#.to_string(),
        },
        thought_signature: None,
    };
    let msg = ChatMessage::assistant_with_tools("", vec![call]);

    let converted = OpenRouterProvider::convert_messages(&[msg]).unwrap();
    assert_eq!(converted.len(), 1);
    assert_eq!(converted[0].role, "assistant");

    let forwarded = converted[0]
        .tool_calls
        .as_ref()
        .expect("tool_calls must be present");
    assert_eq!(forwarded.len(), 1);
    assert_eq!(forwarded[0].id, "call_xyz");
    assert_eq!(forwarded[0].function.name, "get_weather");
}
#[test]
fn test_chat_request_serializes_penalty_fields() {
    // Penalty knobs set to Some(..) must appear in the wire JSON.
    let req = ChatRequest {
        model: "openai/gpt-4o",
        messages: Vec::new(),
        stream: None,
        max_tokens: None,
        temperature: None,
        top_p: None,
        stop: None,
        frequency_penalty: Some(0.5),
        presence_penalty: Some(-0.3),
        tools: None,
        tool_choice: None,
    };
    let json = serde_json::to_string(&req).unwrap();
    for needle in ["\"frequency_penalty\":0.5", "\"presence_penalty\":-0.3"] {
        assert!(
            json.contains(needle),
            "{needle} must appear in request JSON, got: {json}"
        );
    }
}
#[test]
fn test_chat_request_omits_penalty_fields_when_none() {
    // skip_serializing_if must drop unset penalty fields from the JSON
    // entirely (some upstream providers reject explicit nulls).
    let req = ChatRequest {
        model: "openai/gpt-4o",
        messages: Vec::new(),
        stream: None,
        max_tokens: None,
        temperature: None,
        top_p: None,
        stop: None,
        frequency_penalty: None,
        presence_penalty: None,
        tools: None,
        tool_choice: None,
    };
    let json = serde_json::to_string(&req).unwrap();
    for key in ["frequency_penalty", "presence_penalty"] {
        assert!(!json.contains(key), "{key} must be omitted when None");
    }
}
#[test]
fn test_stream_chunk_with_error_deserializes() {
    // Mid-stream provider failures arrive as a top-level "error" object on
    // the chunk; it must round-trip through deserialization.
    let payload = r#"{
"id": "cmpl-abc",
"model": "openai/gpt-4o",
"choices": [{"index": 0, "delta": {"content": ""}, "finish_reason": "error"}],
"error": {"code": "server_error", "message": "Provider disconnected unexpectedly"}
}"#;
    let chunk: StreamChunkResponse = serde_json::from_str(payload).unwrap();
    let err = chunk.error.expect("error field must be deserialized");
    assert_eq!(err.message, "Provider disconnected unexpectedly");
}
#[test]
fn test_stream_chunk_without_error_deserializes() {
    // Ordinary content chunks have no "error" key; the field must default
    // to None while the delta text is preserved.
    let payload = r#"{
"id": "cmpl-xyz",
"model": "openai/gpt-4o",
"choices": [{"index": 0, "delta": {"content": "Hello"}}]
}"#;
    let chunk: StreamChunkResponse = serde_json::from_str(payload).unwrap();
    assert!(chunk.error.is_none());
    let delta = chunk.choices[0].delta.as_ref().expect("delta must be present");
    assert_eq!(delta.content.as_deref(), Some("Hello"));
}
#[tokio::test]
#[ignore]
async fn test_chat_completion_live() {
    // Live smoke test: needs OPENROUTER_API_KEY; run via `cargo test -- --ignored`.
    let provider = OpenRouterProvider::from_env().expect("OPENROUTER_API_KEY not set");
    let prompt = vec![ChatMessage::user("Say 'hello' and nothing else.")];
    let response = provider
        .chat(&prompt, None)
        .await
        .expect("chat completion failed");
    assert!(!response.content.is_empty());
    assert!(response.prompt_tokens > 0);
    assert!(response.completion_tokens > 0);
}
#[test]
fn test_model_info_deserialization() {
    // A full model record: pricing, architecture and supported parameters.
    let payload = r#"{
"id": "openai/gpt-4o",
"name": "GPT-4o",
"context_length": 128000,
"pricing": {
"prompt": "0.000005",
"completion": "0.000015"
},
"architecture": {
"input_modalities": ["text", "image"],
"output_modalities": ["text"]
},
"supported_parameters": ["temperature", "max_tokens"]
}"#;
    let model: ModelInfo = serde_json::from_str(payload).unwrap();
    assert_eq!(model.id, "openai/gpt-4o");
    assert_eq!(model.name, "GPT-4o");
    assert_eq!(model.context_length, 128000);
    assert_eq!(model.pricing.prompt, "0.000005");

    let inputs = &model.architecture.input_modalities;
    assert_eq!(inputs.len(), 2);
    assert!(inputs.iter().any(|m| m == "image"));
}
#[test]
fn test_models_response_deserialization() {
    // The /models endpoint wraps the list in a {"data": [...]} envelope.
    let payload = r#"{
"data": [
{
"id": "openai/gpt-4o",
"name": "GPT-4o",
"context_length": 128000,
"pricing": {"prompt": "0.000005", "completion": "0.000015"}
},
{
"id": "anthropic/claude-3.5-sonnet",
"name": "Claude 3.5 Sonnet",
"context_length": 200000,
"pricing": {"prompt": "0.000003", "completion": "0.000015"}
}
]
}"#;
    let response: ModelsResponse = serde_json::from_str(payload).unwrap();
    let ids: Vec<&str> = response.data.iter().map(|m| m.id.as_str()).collect();
    assert_eq!(ids, ["openai/gpt-4o", "anthropic/claude-3.5-sonnet"]);
}
#[tokio::test]
async fn test_cache_initially_empty() {
    // A freshly constructed provider has no cached models, so the cache is
    // invalid for any TTL.
    let provider = OpenRouterProvider::new("test-key");
    assert_eq!(provider.cached_model_count().await, 0);
    let valid = provider
        .is_cache_valid(std::time::Duration::from_secs(3600))
        .await;
    assert!(!valid);
}
#[tokio::test]
async fn test_provider_clone_shares_cache() {
    // NOTE(review): despite the name, this only checks that both handles
    // start with an empty cache; it does not prove the cache is shared after
    // a fill — confirm with a populated-cache test if the API allows it.
    let original = OpenRouterProvider::new("test-key");
    let cloned = original.clone();
    for provider in [&original, &cloned] {
        assert_eq!(provider.cached_model_count().await, 0);
    }
}
#[tokio::test]
#[ignore]
async fn test_list_models_live() {
    // Live test (needs OPENROUTER_API_KEY); also verifies that listing
    // populates the model cache.
    let provider = OpenRouterProvider::from_env().expect("OPENROUTER_API_KEY not set");
    let models = provider.list_models().await.expect("Failed to list models");
    assert!(!models.is_empty(), "Should have at least one model");

    assert!(provider.cached_model_count().await > 0);
    assert!(
        provider
            .is_cache_valid(std::time::Duration::from_secs(3600))
            .await
    );
    assert!(
        models.iter().any(|m| m.id.contains("gpt-4")),
        "Should have a GPT-4 model"
    );
}
#[tokio::test]
#[ignore]
async fn test_list_models_cached_live() {
    use std::time::Duration;
    // Two cached calls within one TTL window must serve the same model set.
    let provider = OpenRouterProvider::from_env().expect("OPENROUTER_API_KEY not set");
    let ttl = Duration::from_secs(3600);
    let first = provider.list_models_cached(ttl).await.unwrap();
    let second = provider.list_models_cached(ttl).await.unwrap();
    assert_eq!(second.len(), first.len(), "Cache should return same models");
}
#[tokio::test]
#[ignore]
async fn test_get_model_live() {
    // Live single-model lookup by id (needs OPENROUTER_API_KEY).
    let provider = OpenRouterProvider::from_env().expect("OPENROUTER_API_KEY not set");
    let model = provider
        .get_model("openai/gpt-4o")
        .await
        .expect("model lookup failed")
        .expect("Should find gpt-4o model");
    assert_eq!(model.id, "openai/gpt-4o");
    assert!(model.context_length > 0);
}
#[tokio::test]
#[ignore]
async fn test_get_models_by_modality_live() {
    // Every model returned for the "image" modality must actually list
    // image among its input modalities.
    let provider = OpenRouterProvider::from_env().expect("OPENROUTER_API_KEY not set");
    let vision_models = provider
        .get_models_by_modality("image")
        .await
        .expect("modality query failed");
    assert!(!vision_models.is_empty(), "Should have vision models");

    for model in &vision_models {
        let has_image = model
            .architecture
            .input_modalities
            .iter()
            .any(|m| m == "image");
        assert!(has_image, "Model {} should support image input", model.id);
    }
}
#[test]
fn test_sse_line_buffering_algorithm() {
    // Mirrors the streaming loop's line reassembly: each network chunk is
    // appended to a buffer, then every complete (newline-terminated) line is
    // drained out, trimmed, and kept unless empty. A trailing partial line
    // stays buffered for the next chunk.
    fn process_chunks(chunks: &[&str]) -> Vec<String> {
        let mut line_buffer = String::new();
        let mut complete_lines = Vec::new();
        for chunk in chunks {
            line_buffer.push_str(chunk);
            while let Some(newline_idx) = line_buffer.find('\n') {
                let line = line_buffer[..newline_idx].trim().to_string();
                line_buffer.drain(..=newline_idx);
                if !line.is_empty() {
                    complete_lines.push(line);
                }
            }
        }
        complete_lines
    }

    // A JSON payload split mid-token across three chunks reassembles into
    // one complete line.
    let lines = process_chunks(&[
        "data: {\"function\":{\"name\":\"run_command\",\"arguments\":\"{\\\"command\\\":\\\"mkdir",
        " -p ./demo/snake_gemini406\\\"}\"}}",
        "\n",
    ]);
    assert_eq!(lines.len(), 1);
    assert!(lines[0].contains("mkdir -p ./demo/snake_gemini406"));

    // Two self-contained events yield two lines in order.
    let lines = process_chunks(&[
        "data: {\"content\":\"Hello\"}\n",
        "data: {\"content\":\"World\"}\n",
    ]);
    assert_eq!(lines.len(), 2);
    assert!(lines[0].contains("Hello"));
    assert!(lines[1].contains("World"));

    // An unterminated trailing line must NOT be emitted yet.
    let lines = process_chunks(&[
        "data: {\"content\":\"Complete\"}\n",
        "data: {\"content\":\"Incomplete",
    ]);
    assert_eq!(lines.len(), 1);
    assert!(lines[0].contains("Complete"));

    // Blank keep-alive lines are dropped; SSE comment lines (": ...") pass
    // through unchanged — filtering them is the parser's job, not the buffer's.
    let lines = process_chunks(&["\n", ": comment\n", "data: {\"content\":\"Real\"}\n", "\n"]);
    assert_eq!(lines.len(), 2);
    assert_eq!(lines[0], ": comment");
    assert!(lines[1].contains("Real"));

    // A newline arriving at the start of the next chunk terminates the
    // previously buffered line.
    let lines = process_chunks(&[
        "data: {\"content\":\"First\"}",
        "\ndata: {\"content\":\"Second\"}\n",
    ]);
    assert_eq!(lines.len(), 2);
    assert!(lines[0].contains("First"));
    assert!(lines[1].contains("Second"));

    // CRLF-terminated lines (permitted by the SSE spec) are handled because
    // trim() strips the '\r' left in front of the '\n'.
    let lines = process_chunks(&["data: {\"content\":\"CRLF\"}\r\n"]);
    assert_eq!(lines, vec!["data: {\"content\":\"CRLF\"}"]);
}
}