use super::agent::{ChatMessage, LlmProvider, Role};
use crate::traits::{CerebroError, Result};
use serde::{Deserialize, Serialize};
/// Provider-independent result of a single chat completion.
///
/// Every backend handled by `LlmClient` is normalized into this shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmResponse {
    /// Assistant reply text. When `LlmClient::chat` extracts a tool call, this is
    /// reduced to the trimmed text that preceded the `<tool_call>` tag.
    pub content: String,
    /// Prompt-side token count as reported by the provider; 0 when the provider
    /// response did not include it.
    pub input_tokens: usize,
    /// Completion-side token count as reported by the provider; 0 when the
    /// provider response did not include it.
    pub output_tokens: usize,
    /// Tool invocation parsed out of the reply text, if one was found and parsed.
    pub tool_call: Option<ToolCallRequest>,
}
/// A tool invocation requested by the model.
///
/// Deserialized from the JSON found between `<tool_call>` and `</tool_call>`
/// in a reply.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRequest {
    /// Name of the tool the model wants to run.
    pub tool_name: String,
    /// Arbitrary JSON arguments for the tool. Presumably an object of named
    /// parameters; the shape is not validated here.
    pub arguments: serde_json::Value,
}
/// Chat client that dispatches to Ollama, OpenAI(-compatible), Gemini or
/// Anthropic backends.
///
/// A single `reqwest::Client` is held and reused for every request this
/// client makes.
pub struct LlmClient {
    /// HTTP client used for all outgoing provider calls.
    http: reqwest::Client,
}
impl LlmClient {
/// Builds a client around a default-configured `reqwest::Client`.
///
/// NOTE(review): no request timeout is configured here, so a stalled provider
/// may hold a `chat` call open indefinitely. Consider a client builder with a
/// timeout if callers need bounded latency.
pub fn new() -> Self {
    let http = reqwest::Client::new();
    Self { http }
}
pub async fn chat(
&self,
provider: &LlmProvider,
messages: &[ChatMessage],
) -> Result<LlmResponse> {
let mut resp = match provider {
LlmProvider::Ollama { model, base_url } => {
self.chat_ollama(base_url, model, messages).await?
}
LlmProvider::OpenAI { model, api_key } => {
self.chat_openai_api("https://api.openai.com/v1", api_key, model, messages)
.await?
}
LlmProvider::Gemini { model, api_key } => {
self.chat_gemini(api_key, model, messages).await?
}
LlmProvider::Anthropic {
model,
api_key,
max_tokens,
} => {
self.chat_anthropic(api_key, model, *max_tokens, messages)
.await?
}
LlmProvider::OpenAICompatible {
model,
api_key,
base_url,
..
} => {
self.chat_openai_api(base_url, api_key, model, messages)
.await?
}
};
if let Some(start) = resp.content.find("<tool_call>") {
if let Some(end) = resp.content.find("</tool_call>") {
let json_str = &resp.content[start + 11..end].trim();
if let Ok(tc) = serde_json::from_str::<ToolCallRequest>(json_str) {
resp.tool_call = Some(tc);
resp.content = resp.content[..start].trim().to_string();
}
}
}
Ok(resp)
}
/// Sends `messages` to an Ollama server's `/api/chat` endpoint with streaming disabled.
///
/// Each message is sent as `{ "role": <Display of Role>, "content": ... }`.
/// `base_url` may be given with or without a trailing slash.
/// Token counts are read from `prompt_eval_count` / `eval_count` and default
/// to 0 when absent. A missing `message.content` yields an empty reply.
///
/// # Errors
/// Returns `CerebroError::EmbeddingError` when:
/// - the request cannot be sent;
/// - the server answers with a non-success status (status and body are included);
/// - the response body is not valid JSON.
///
/// NOTE(review): chat failures reuse the `EmbeddingError` variant; a dedicated
/// variant may be clearer if `CerebroError` can be extended.
async fn chat_ollama(
    &self,
    base_url: &str,
    model: &str,
    messages: &[ChatMessage],
) -> Result<LlmResponse> {
    // Trim a trailing slash so "http://host:11434/" does not become
    // "http://host:11434//api/chat". This matches how `chat_openai_api`
    // builds its URL.
    let url = format!("{}/api/chat", base_url.trim_end_matches('/'));
    let msgs: Vec<serde_json::Value> = messages
        .iter()
        .map(|m| {
            serde_json::json!({
                "role": m.role.to_string(),
                "content": m.content,
            })
        })
        .collect();
    let payload = serde_json::json!({
        "model": model,
        "messages": msgs,
        "stream": false,
    });
    let resp = self
        .http
        .post(&url)
        .json(&payload)
        .send()
        .await
        .map_err(|e| CerebroError::EmbeddingError(format!("Ollama request failed: {}", e)))?;
    if !resp.status().is_success() {
        let status = resp.status();
        let err_text = resp.text().await.unwrap_or_default();
        return Err(CerebroError::EmbeddingError(format!(
            "Ollama API error ({}): {}",
            status, err_text
        )));
    }
    let data: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| CerebroError::EmbeddingError(format!("Ollama JSON parse error: {}", e)))?;
    let content = data["message"]["content"]
        .as_str()
        .unwrap_or("")
        .to_string();
    let input_tokens = data["prompt_eval_count"].as_u64().unwrap_or(0) as usize;
    let output_tokens = data["eval_count"].as_u64().unwrap_or(0) as usize;
    Ok(LlmResponse {
        content,
        input_tokens,
        output_tokens,
        tool_call: None,
    })
}
/// Sends `messages` to an OpenAI-style `/chat/completions` endpoint under `base_url`.
///
/// Used for both the official OpenAI API and OpenAI-compatible servers.
/// Authentication is a bearer token built from `api_key`. Each message is sent
/// as `{ "role": <Display of Role>, "content": ... }`.
/// The reply text is `choices[0].message.content`, or empty if absent. Token
/// counts come from `usage.prompt_tokens` / `usage.completion_tokens`,
/// defaulting to 0.
///
/// # Errors
/// Returns `CerebroError::EmbeddingError` when:
/// - the request cannot be sent;
/// - the server answers with a non-success status (status and body are included);
/// - the response body is not valid JSON.
async fn chat_openai_api(
    &self,
    base_url: &str,
    api_key: &str,
    model: &str,
    messages: &[ChatMessage],
) -> Result<LlmResponse> {
    let url = format!("{}/chat/completions", base_url.trim_end_matches('/'));
    let msgs: Vec<serde_json::Value> = messages
        .iter()
        .map(|m| {
            serde_json::json!({
                "role": m.role.to_string(),
                "content": m.content,
            })
        })
        .collect();
    let payload = serde_json::json!({
        "model": model,
        "messages": msgs,
    });
    let resp = self
        .http
        .post(&url)
        .bearer_auth(api_key)
        .json(&payload)
        .send()
        .await
        .map_err(|e| {
            CerebroError::EmbeddingError(format!("OpenAI-compatible request failed: {}", e))
        })?;
    if !resp.status().is_success() {
        // Capture the status before `text()` consumes the response so that
        // 401/429/5xx failures can be told apart, as the Ollama path already does.
        let status = resp.status();
        let err_text = resp.text().await.unwrap_or_default();
        return Err(CerebroError::EmbeddingError(format!(
            "OpenAI-compatible API error ({}): {}",
            status, err_text
        )));
    }
    let data: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| CerebroError::EmbeddingError(format!("JSON parse error: {}", e)))?;
    let content = data["choices"][0]["message"]["content"]
        .as_str()
        .unwrap_or("")
        .to_string();
    let input_tokens = data["usage"]["prompt_tokens"].as_u64().unwrap_or(0) as usize;
    let output_tokens = data["usage"]["completion_tokens"].as_u64().unwrap_or(0) as usize;
    Ok(LlmResponse {
        content,
        input_tokens,
        output_tokens,
        tool_call: None,
    })
}
/// Sends `messages` to the Anthropic Messages API (`/v1/messages`).
///
/// Role mapping:
/// - `System` messages become the top-level `system` field; if several are
///   present, only the last one is kept.
/// - `User` and `Tool` messages are sent as `user` turns.
/// - `Assistant` messages are sent as `assistant` turns.
///
/// The reply text is the concatenation of every content block whose `type` is
/// `"text"`; other block types are ignored. Token counts come from
/// `usage.input_tokens` / `usage.output_tokens`, defaulting to 0.
///
/// # Errors
/// Returns `CerebroError::EmbeddingError` when:
/// - the request cannot be sent;
/// - the API answers with a non-success status (status and body are included);
/// - the response body is not valid JSON.
async fn chat_anthropic(
    &self,
    api_key: &str,
    model: &str,
    max_tokens: usize,
    messages: &[ChatMessage],
) -> Result<LlmResponse> {
    let url = "https://api.anthropic.com/v1/messages";
    let mut system_prompt: Option<String> = None;
    let mut api_messages: Vec<serde_json::Value> = Vec::new();
    for msg in messages {
        match msg.role {
            Role::System => {
                system_prompt = Some(msg.content.clone());
            }
            Role::User | Role::Tool => {
                api_messages.push(serde_json::json!({
                    "role": "user",
                    "content": msg.content,
                }));
            }
            Role::Assistant => {
                api_messages.push(serde_json::json!({
                    "role": "assistant",
                    "content": msg.content,
                }));
            }
        }
    }
    let mut payload = serde_json::json!({
        "model": model,
        "max_tokens": max_tokens,
        "messages": api_messages,
    });
    if let Some(sys) = system_prompt {
        payload["system"] = serde_json::Value::String(sys);
    }
    let resp = self
        .http
        .post(url)
        .header("x-api-key", api_key)
        .header("anthropic-version", "2023-06-01")
        .header("content-type", "application/json")
        .json(&payload)
        .send()
        .await
        .map_err(|e| {
            CerebroError::EmbeddingError(format!("Anthropic request failed: {}", e))
        })?;
    if !resp.status().is_success() {
        // Keep the status: the body alone does not always distinguish
        // auth, rate-limit and overload failures.
        let status = resp.status();
        let err_text = resp.text().await.unwrap_or_default();
        return Err(CerebroError::EmbeddingError(format!(
            "Anthropic API error ({}): {}",
            status, err_text
        )));
    }
    let data: serde_json::Value = resp.json().await.map_err(|e| {
        CerebroError::EmbeddingError(format!("Anthropic JSON parse error: {}", e))
    })?;
    // Join all text blocks into one String in a single pass, instead of
    // allocating a new String per block.
    let content = data["content"]
        .as_array()
        .map(|blocks| {
            blocks
                .iter()
                .filter(|b| b["type"].as_str() == Some("text"))
                .filter_map(|b| b["text"].as_str())
                .collect::<String>()
        })
        .unwrap_or_default();
    let input_tokens = data["usage"]["input_tokens"].as_u64().unwrap_or(0) as usize;
    let output_tokens = data["usage"]["output_tokens"].as_u64().unwrap_or(0) as usize;
    Ok(LlmResponse {
        content,
        input_tokens,
        output_tokens,
        tool_call: None,
    })
}
/// Sends `messages` to the Gemini `generateContent` endpoint for `model`.
///
/// Role mapping:
/// - `System` messages become `systemInstruction`; if several are present,
///   only the last one is kept.
/// - `User` messages become `user` turns.
/// - `Assistant` messages become `model` turns.
/// - `Tool` results are sent as `user` turns prefixed with `[Tool Result]: `.
///
/// The reply text is the concatenation of every text part of the first
/// candidate, skipping parts flagged `"thought": true`. Token counts come from
/// `usageMetadata.promptTokenCount` / `candidatesTokenCount`, defaulting to 0.
///
/// # Errors
/// Returns `CerebroError::EmbeddingError` when:
/// - the request cannot be sent;
/// - the API answers with a non-success status (status and body are included);
/// - the response body is not valid JSON.
async fn chat_gemini(
    &self,
    api_key: &str,
    model: &str,
    messages: &[ChatMessage],
) -> Result<LlmResponse> {
    // The API key is sent in the `x-goog-api-key` header, not as a `?key=`
    // query parameter. reqwest includes the request URL when displaying its
    // errors, so a key embedded in the URL would leak into the error strings
    // built below.
    let url = format!(
        "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent",
        model
    );
    let mut contents: Vec<serde_json::Value> = Vec::new();
    let mut system_instruction: Option<String> = None;
    for msg in messages {
        match msg.role {
            Role::System => {
                system_instruction = Some(msg.content.clone());
            }
            Role::User => {
                contents.push(serde_json::json!({
                    "role": "user",
                    "parts": [{"text": msg.content}]
                }));
            }
            Role::Assistant => {
                contents.push(serde_json::json!({
                    "role": "model",
                    "parts": [{"text": msg.content}]
                }));
            }
            Role::Tool => {
                contents.push(serde_json::json!({
                    "role": "user",
                    "parts": [{"text": format!("[Tool Result]: {}", msg.content)}]
                }));
            }
        }
    }
    let mut payload = serde_json::json!({
        "contents": contents,
    });
    if let Some(sys) = system_instruction {
        payload["systemInstruction"] = serde_json::json!({
            "parts": [{"text": sys}]
        });
    }
    let resp = self
        .http
        .post(&url)
        .header("x-goog-api-key", api_key)
        .json(&payload)
        .send()
        .await
        .map_err(|e| CerebroError::EmbeddingError(format!("Gemini request failed: {}", e)))?;
    if !resp.status().is_success() {
        let status = resp.status();
        let err_text = resp.text().await.unwrap_or_default();
        return Err(CerebroError::EmbeddingError(format!(
            "Gemini API error ({}): {}",
            status, err_text
        )));
    }
    let data: serde_json::Value = resp
        .json()
        .await
        .map_err(|e| CerebroError::EmbeddingError(format!("Gemini JSON parse error: {}", e)))?;
    // A candidate may carry several parts, and the first one is not always
    // text (e.g. a function call). Reading only `parts[0]` would truncate or
    // blank the reply, so join every non-thought text part.
    let content = data["candidates"][0]["content"]["parts"]
        .as_array()
        .map(|parts| {
            parts
                .iter()
                .filter(|p| p["thought"].as_bool() != Some(true))
                .filter_map(|p| p["text"].as_str())
                .collect::<String>()
        })
        .unwrap_or_default();
    let input_tokens = data["usageMetadata"]["promptTokenCount"]
        .as_u64()
        .unwrap_or(0) as usize;
    let output_tokens = data["usageMetadata"]["candidatesTokenCount"]
        .as_u64()
        .unwrap_or(0) as usize;
    Ok(LlmResponse {
        content,
        input_tokens,
        output_tokens,
        tool_call: None,
    })
}
}
impl Default for LlmClient {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An `LlmResponse` survives a JSON round-trip with its fields intact.
    #[test]
    fn test_llm_response_serialization() {
        let original = LlmResponse {
            content: String::from("Hello world"),
            input_tokens: 10,
            output_tokens: 5,
            tool_call: None,
        };
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded = serde_json::from_str::<LlmResponse>(&encoded).unwrap();
        assert_eq!(decoded.content, "Hello world");
        assert_eq!(decoded.input_tokens, 10);
    }

    /// Serializing a `ToolCallRequest` keeps the tool name in the output.
    #[test]
    fn test_tool_call_request_serialization() {
        let request = ToolCallRequest {
            tool_name: String::from("web_search"),
            arguments: serde_json::json!({"query": "Rust safety"}),
        };
        let encoded = serde_json::to_string(&request).unwrap();
        assert!(encoded.contains("web_search"));
    }
}