use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use super::types::{safe_truncate, ChatMessage, EventSink, ToolCall, ToolDefinition, ToolFormat};
mod claude;
mod openai;
#[cfg(test)]
mod tests;
/// Pick the tool-calling wire format for a request.
///
/// Heuristic: a model id starting with `claude`, or an API base mentioning
/// `anthropic`/`claude`, selects the Claude format; everything else falls
/// back to the OpenAI-compatible format. Matching is case-insensitive.
pub fn detect_tool_format(model: &str, api_base: &str) -> ToolFormat {
    let model = model.to_lowercase();
    let base = api_base.to_lowercase();
    let is_claude = model.starts_with("claude")
        || ["anthropic", "claude"].iter().any(|needle| base.contains(needle));
    if is_claude {
        ToolFormat::Claude
    } else {
        ToolFormat::OpenAI
    }
}
/// HTTP client for chat-completion and embedding endpoints, speaking either
/// the Claude or OpenAI wire format depending on the target (see
/// `detect_tool_format`).
pub struct LlmClient {
    // Reused connection pool; configured in `new` with a 300 s timeout.
    http: reqwest::Client,
    // Base URL with any trailing '/' stripped (normalized in `new`).
    api_base: String,
    // Bearer token sent on each request.
    api_key: String,
}
impl LlmClient {
pub fn new(api_base: &str, api_key: &str) -> Result<Self> {
let http = reqwest::Client::builder()
.timeout(std::time::Duration::from_secs(300))
.build()
.context("build HTTP client for LLM")?;
Ok(Self {
http,
api_base: api_base.trim_end_matches('/').to_string(),
api_key: api_key.to_string(),
})
}
pub async fn chat_completion(
&self,
model: &str,
messages: &[ChatMessage],
tools: Option<&[ToolDefinition]>,
temperature: Option<f64>,
) -> Result<ChatCompletionResponse> {
let format = detect_tool_format(model, &self.api_base);
match format {
ToolFormat::Claude => {
self.claude_chat_completion(model, messages, tools, temperature)
.await
}
ToolFormat::OpenAI => {
self.openai_chat_completion(model, messages, tools, temperature)
.await
}
}
}
pub async fn chat_completion_stream(
&self,
model: &str,
messages: &[ChatMessage],
tools: Option<&[ToolDefinition]>,
temperature: Option<f64>,
event_sink: &mut dyn EventSink,
) -> Result<ChatCompletionResponse> {
let format = detect_tool_format(model, &self.api_base);
match format {
ToolFormat::Claude => {
self.claude_chat_completion_stream(model, messages, tools, temperature, event_sink)
.await
}
ToolFormat::OpenAI => {
self.openai_chat_completion_stream(model, messages, tools, temperature, event_sink)
.await
}
}
}
#[allow(dead_code)]
pub async fn embed(
&self,
model: &str,
texts: &[&str],
custom_url: Option<&str>,
custom_key: Option<&str>,
) -> Result<Vec<Vec<f32>>> {
if texts.is_empty() {
return Ok(Vec::new());
}
let api_base = custom_url.unwrap_or(&self.api_base);
let api_key = custom_key.unwrap_or(&self.api_key);
let base = api_base.trim_end_matches('/');
let url = if api_base.to_lowercase().contains("minimax") {
format!("{}/text/embeddings", base)
} else {
format!("{}/embeddings", base)
};
let input: Value = if texts.len() == 1 {
json!(texts[0])
} else {
json!(texts.iter().map(|s| s.to_string()).collect::<Vec<_>>())
};
let body = json!({ "model": model, "input": input });
let resp = self
.http
.post(&url)
.header("Authorization", format!("Bearer {}", api_key))
.header("Content-Type", "application/json")
.json(&body)
.send()
.await
.context("Embedding API request failed")?;
let status = resp.status();
if !status.is_success() {
let body_text = resp.text().await.unwrap_or_default();
anyhow::bail!("Embedding API error ({}): {}", status, body_text);
}
let json: Value = resp
.json()
.await
.context("Failed to parse embedding response")?;
if let Some(data) = json.get("data").and_then(|d| d.as_array()) {
return Self::extract_embeddings_from_items(data);
}
if let Some(items) = json
.get("output")
.and_then(|o| o.get("embeddings"))
.and_then(|e| e.as_array())
{
tracing::debug!("Embedding response uses Dashscope native format (output.embeddings)");
return Self::extract_embeddings_from_items(items);
}
let preview = serde_json::to_string(&json).unwrap_or_default();
let preview = &preview[..preview.len().min(500)];
anyhow::bail!(
"Unexpected embedding response format (no 'data' or 'output.embeddings'): {}",
preview
)
}
fn extract_embeddings_from_items(items: &[Value]) -> Result<Vec<Vec<f32>>> {
let mut embeddings = Vec::with_capacity(items.len());
for item in items {
let emb = item
.get("embedding")
.and_then(|e| e.as_array())
.context("Missing 'embedding' in embedding item")?;
let vec: Vec<f32> = emb
.iter()
.filter_map(|v| v.as_f64().map(|f| f as f32))
.collect();
embeddings.push(vec);
}
Ok(embeddings)
}
}
/// Top-level chat-completion response; field layout follows the OpenAI
/// chat-completions schema (provider submodules presumably normalize to
/// this shape — confirm against `claude.rs`).
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub model: String,
    pub choices: Vec<Choice>,
    // Optional: some responses omit usage accounting.
    pub usage: Option<Usage>,
}
/// One completion candidate within a [`ChatCompletionResponse`].
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct Choice {
    // Position of this candidate in the response's `choices` array.
    pub index: u32,
    pub message: ChoiceMessage,
    // e.g. why generation stopped; optional because providers may omit it.
    pub finish_reason: Option<String>,
}
/// The assistant message inside a [`Choice`].
#[derive(Debug, Deserialize)]
#[allow(dead_code)]
pub struct ChoiceMessage {
    pub role: String,
    // Text content; may be absent when the model only emits tool calls.
    pub content: Option<String>,
    // `#[serde(default)]` so providers that omit the field deserialize to
    // None. Presumably the model's reasoning/thinking text where a provider
    // emits one — NOTE(review): confirm against provider docs.
    #[serde(default)]
    pub reasoning_content: Option<String>,
    // Tool invocations requested by the model, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
}
/// Token accounting reported by the API for one request.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
}
/// Heuristically classify a provider error message as a context-window
/// overflow (case-insensitive substring match against known phrasings).
pub fn is_context_overflow_error(err_msg: &str) -> bool {
    // Known phrasings across providers for "prompt too long".
    const MARKERS: [&str; 6] = [
        "context_length_exceeded",
        "maximum context length",
        "token limit",
        "too many tokens",
        "context window",
        "max_tokens",
    ];
    let lower = err_msg.to_lowercase();
    MARKERS.iter().any(|marker| lower.contains(marker))
}
pub fn truncate_tool_messages(messages: &mut [ChatMessage], max_chars: usize) {
for msg in messages.iter_mut() {
if msg.role == "tool" {
if let Some(ref mut content) = msg.content {
if content.len() > max_chars {
let truncated = format!(
"{}...\n[truncated: {} chars → {}]",
safe_truncate(content, max_chars),
content.len(),
max_chars
);
*content = truncated;
}
}
}
}
}