use std::sync::Arc;
use tokio::sync::RwLock;
use super::config::BehaviorModelConfig;
use super::types::LlmGenerationRequest;
use crate::Result;
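/// Client that lazily constructs the configured LLM provider and exposes
/// JSON-oriented chat-generation helpers.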
pub struct LlmClient {
    provider: Arc<RwLock<Option<Box<dyn LlmProvider>>>>,
config: BehaviorModelConfig,
}
impl LlmClient {
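    /// Creates a client for `config`; the underlying provider is constructed
    /// lazily on first use.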
pub fn new(config: BehaviorModelConfig) -> Self {
Self {
            provider: Arc::new(RwLock::new(None)),
config,
}
}
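    /// Builds the provider on first call; later calls are no-ops.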
async fn ensure_initialized(&self) -> Result<()> {
        let mut slot = self.provider.write().await;
        if slot.is_none() {
            *slot = Some(self.create_provider()?);
        }
Ok(())
}
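    /// Instantiates the provider named by `config.llm_provider`.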
fn create_provider(&self) -> Result<Box<dyn LlmProvider>> {
match self.config.llm_provider.to_lowercase().as_str() {
"openai" => Ok(Box::new(OpenAIProvider::new(&self.config)?)),
"anthropic" => Ok(Box::new(AnthropicProvider::new(&self.config)?)),
"ollama" => Ok(Box::new(OllamaProvider::new(&self.config)?)),
"openai-compatible" => Ok(Box::new(OpenAICompatibleProvider::new(&self.config)?)),
            _ => Err(crate::Error::internal(format!(
                "Unsupported LLM provider: {} (expected one of: openai, anthropic, ollama, openai-compatible)",
                self.config.llm_provider
            ))),
}
}
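    /// Sends the system and user prompts to the provider and returns the
    /// response parsed as JSON. If the raw response is not valid JSON, the
    /// outermost `{...}` span is extracted, or the text is wrapped in a
    /// `{"response": ..., "note": ...}` object.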
pub async fn generate(&self, request: &LlmGenerationRequest) -> Result<serde_json::Value> {
self.ensure_initialized().await?;
        let guard = self.provider.read().await;
        let provider = guard
            .as_ref()
            .ok_or_else(|| crate::Error::internal("LLM provider not initialized"))?;
let messages = vec![
ChatMessage {
role: "system".to_string(),
content: request.system_prompt.clone(),
},
ChatMessage {
role: "user".to_string(),
content: request.user_prompt.clone(),
},
];
let response_text = provider
.generate_chat(messages, request.temperature, request.max_tokens)
.await?;
match serde_json::from_str::<serde_json::Value>(&response_text) {
Ok(json) => Ok(json),
            Err(_) => {
                if let Some(start) = response_text.find('{') {
                    // Guard against a '}' that appears before the first '{',
                    // which would otherwise panic when slicing.
                    if let Some(end) = response_text.rfind('}').filter(|&end| end >= start) {
                        let json_str = &response_text[start..=end];
                        if let Ok(json) = serde_json::from_str::<serde_json::Value>(json_str) {
                            return Ok(json);
                        }
                    }
                }
                Ok(serde_json::json!({
                    "response": response_text,
                    "note": "Response was not valid JSON, wrapped in object"
                }))
            }
}
}
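    /// Like [`Self::generate`], but also returns token usage (reported by the
    /// provider, or estimated when the provider does not report it).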
pub async fn generate_with_usage(
&self,
request: &LlmGenerationRequest,
) -> Result<(serde_json::Value, LlmUsage)> {
self.ensure_initialized().await?;
        let guard = self.provider.read().await;
        let provider = guard
            .as_ref()
            .ok_or_else(|| crate::Error::internal("LLM provider not initialized"))?;
let messages = vec![
ChatMessage {
role: "system".to_string(),
content: request.system_prompt.clone(),
},
ChatMessage {
role: "user".to_string(),
content: request.user_prompt.clone(),
},
];
let (response_text, usage) = provider
.generate_chat_with_usage(messages, request.temperature, request.max_tokens)
.await?;
        let json_value = match serde_json::from_str::<serde_json::Value>(&response_text) {
            Ok(json) => json,
            Err(_) => {
                // Same salvage logic as `generate`: try the outermost `{...}` span,
                // otherwise wrap the raw text in a JSON object.
                let extracted = response_text.find('{').and_then(|start| {
                    response_text
                        .rfind('}')
                        .filter(|&end| end >= start)
                        .and_then(|end| serde_json::from_str(&response_text[start..=end]).ok())
                });
                extracted.unwrap_or_else(|| {
                    serde_json::json!({
                        "response": response_text,
                        "note": "Response was not valid JSON, wrapped in object"
                    })
                })
            }
        };
Ok((json_value, usage))
}
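    /// Returns the configuration this client was built with.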
pub fn config(&self) -> &BehaviorModelConfig {
&self.config
}
}
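/// Provider-agnostic chat message (role plus content).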
#[derive(Debug, Clone)]
struct ChatMessage {
role: String,
content: String,
}
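/// Token usage for a single generation call.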
#[derive(Debug, Clone, Default)]
pub struct LlmUsage {
pub prompt_tokens: u64,
pub completion_tokens: u64,
pub total_tokens: u64,
}
impl LlmUsage {
pub fn new(prompt_tokens: u64, completion_tokens: u64) -> Self {
Self {
prompt_tokens,
completion_tokens,
total_tokens: prompt_tokens + completion_tokens,
}
}
}
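/// Abstraction over the supported chat-completion backends.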
#[async_trait::async_trait]
trait LlmProvider: Send + Sync {
async fn generate_chat(
&self,
messages: Vec<ChatMessage>,
temperature: f64,
max_tokens: usize,
) -> Result<String>;
async fn generate_chat_with_usage(
&self,
messages: Vec<ChatMessage>,
temperature: f64,
max_tokens: usize,
) -> Result<(String, LlmUsage)> {
        // Default for providers that do not report usage: estimate both sides at
        // roughly 4 characters per token.
        let prompt_chars: usize = messages.iter().map(|m| m.content.len()).sum();
        let response = self.generate_chat(messages, temperature, max_tokens).await?;
        let prompt_tokens = (prompt_chars as f64 / 4.0) as u64;
        let completion_tokens = (response.len() as f64 / 4.0) as u64;
        Ok((response, LlmUsage::new(prompt_tokens, completion_tokens)))
}
}
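/// Provider backed by the OpenAI Chat Completions API.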
struct OpenAIProvider {
client: reqwest::Client,
api_key: String,
model: String,
endpoint: String,
}
impl OpenAIProvider {
fn new(config: &BehaviorModelConfig) -> Result<Self> {
let api_key = config
.api_key
.clone()
.or_else(|| std::env::var("OPENAI_API_KEY").ok())
.ok_or_else(|| crate::Error::internal("OpenAI API key not found"))?;
let endpoint = config
.api_endpoint
.clone()
.unwrap_or_else(|| "https://api.openai.com/v1/chat/completions".to_string());
Ok(Self {
client: reqwest::Client::new(),
api_key,
model: config.model.clone(),
endpoint,
})
}
}
#[async_trait::async_trait]
impl LlmProvider for OpenAIProvider {
async fn generate_chat(
&self,
messages: Vec<ChatMessage>,
temperature: f64,
max_tokens: usize,
) -> Result<String> {
let request_body = serde_json::json!({
"model": self.model,
"messages": messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content
})
}).collect::<Vec<_>>(),
"temperature": temperature,
"max_tokens": max_tokens,
});
let response = self
.client
.post(&self.endpoint)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await
.map_err(|e| crate::Error::internal(format!("OpenAI API request failed: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::internal(format!(
                "OpenAI API error ({}): {}",
                status, error_text
            )));
        }
let response_json: serde_json::Value = response.json().await.map_err(|e| {
crate::Error::internal(format!("Failed to parse OpenAI response: {}", e))
})?;
let content = response_json["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| crate::Error::internal("Invalid OpenAI response format"))?
.to_string();
Ok(content)
}
async fn generate_chat_with_usage(
&self,
messages: Vec<ChatMessage>,
temperature: f64,
max_tokens: usize,
) -> Result<(String, LlmUsage)> {
let request_body = serde_json::json!({
"model": self.model,
"messages": messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content
})
}).collect::<Vec<_>>(),
"temperature": temperature,
"max_tokens": max_tokens,
});
let response = self
.client
.post(&self.endpoint)
.header("Authorization", format!("Bearer {}", self.api_key))
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await
.map_err(|e| crate::Error::internal(format!("OpenAI API request failed: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::internal(format!(
                "OpenAI API error ({}): {}",
                status, error_text
            )));
        }
let response_json: serde_json::Value = response.json().await.map_err(|e| {
crate::Error::internal(format!("Failed to parse OpenAI response: {}", e))
})?;
let content = response_json["choices"][0]["message"]["content"]
.as_str()
.ok_or_else(|| crate::Error::internal("Invalid OpenAI response format"))?
.to_string();
let usage = if let Some(usage_obj) = response_json.get("usage") {
LlmUsage::new(
usage_obj["prompt_tokens"].as_u64().unwrap_or(0),
usage_obj["completion_tokens"].as_u64().unwrap_or(0),
)
        } else {
            // Fallback when the API omits `usage`: rough ~4-characters-per-token estimate.
            let prompt_chars: usize = messages.iter().map(|m| m.content.len()).sum();
            LlmUsage::new(
                (prompt_chars as f64 / 4.0) as u64,
                (content.len() as f64 / 4.0) as u64,
            )
        };
Ok((content, usage))
}
}
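/// Provider backed by a local Ollama server's `/api/chat` endpoint; token usage
/// falls back to the trait's default estimate.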
struct OllamaProvider {
client: reqwest::Client,
model: String,
endpoint: String,
}
impl OllamaProvider {
fn new(config: &BehaviorModelConfig) -> Result<Self> {
let endpoint = config
.api_endpoint
.clone()
.unwrap_or_else(|| "http://localhost:11434/api/chat".to_string());
Ok(Self {
client: reqwest::Client::new(),
model: config.model.clone(),
endpoint,
})
}
}
#[async_trait::async_trait]
impl LlmProvider for OllamaProvider {
async fn generate_chat(
&self,
messages: Vec<ChatMessage>,
temperature: f64,
max_tokens: usize,
) -> Result<String> {
let request_body = serde_json::json!({
"model": self.model,
"messages": messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content
})
}).collect::<Vec<_>>(),
"options": {
"temperature": temperature,
"num_predict": max_tokens,
},
"stream": false,
});
let response = self
.client
.post(&self.endpoint)
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await
.map_err(|e| crate::Error::internal(format!("Ollama API request failed: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::internal(format!(
                "Ollama API error ({}): {}",
                status, error_text
            )));
        }
let response_json: serde_json::Value = response.json().await.map_err(|e| {
crate::Error::internal(format!("Failed to parse Ollama response: {}", e))
})?;
let content = response_json["message"]["content"]
.as_str()
.ok_or_else(|| crate::Error::internal("Invalid Ollama response format"))?
.to_string();
Ok(content)
}
}
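/// Provider backed by the Anthropic Messages API; the system prompt is sent via
/// the top-level `system` field rather than as a chat message.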
struct AnthropicProvider {
client: reqwest::Client,
api_key: String,
model: String,
endpoint: String,
}
impl AnthropicProvider {
fn new(config: &BehaviorModelConfig) -> Result<Self> {
let api_key = config
.api_key
.clone()
.or_else(|| std::env::var("ANTHROPIC_API_KEY").ok())
.ok_or_else(|| crate::Error::internal("Anthropic API key not found"))?;
let endpoint = config
.api_endpoint
.clone()
.unwrap_or_else(|| "https://api.anthropic.com/v1/messages".to_string());
Ok(Self {
client: reqwest::Client::new(),
api_key,
model: config.model.clone(),
endpoint,
})
}
}
#[async_trait::async_trait]
impl LlmProvider for AnthropicProvider {
async fn generate_chat(
&self,
messages: Vec<ChatMessage>,
temperature: f64,
max_tokens: usize,
) -> Result<String> {
let system_message =
messages.iter().find(|m| m.role == "system").map(|m| m.content.clone());
let chat_messages: Vec<_> = messages
.iter()
.filter(|m| m.role != "system")
.map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content
})
})
.collect();
let mut request_body = serde_json::json!({
"model": self.model,
"messages": chat_messages,
"temperature": temperature,
"max_tokens": max_tokens,
});
if let Some(system) = system_message {
request_body["system"] = serde_json::Value::String(system);
}
let response = self
.client
.post(&self.endpoint)
.header("x-api-key", &self.api_key)
.header("anthropic-version", "2023-06-01")
.header("Content-Type", "application/json")
.json(&request_body)
.send()
.await
.map_err(|e| crate::Error::internal(format!("Anthropic API request failed: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::internal(format!(
                "Anthropic API error ({}): {}",
                status, error_text
            )));
        }
let response_json: serde_json::Value = response.json().await.map_err(|e| {
crate::Error::internal(format!("Failed to parse Anthropic response: {}", e))
})?;
let content = response_json["content"][0]["text"]
.as_str()
.ok_or_else(|| crate::Error::internal("Invalid Anthropic response format"))?
.to_string();
Ok(content)
}
}
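/// Provider for any server exposing an OpenAI-compatible chat-completions
/// endpoint; requires an explicit `api_endpoint` and sends a bearer token only
/// when an API key is configured.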
struct OpenAICompatibleProvider {
client: reqwest::Client,
api_key: Option<String>,
model: String,
endpoint: String,
}
impl OpenAICompatibleProvider {
fn new(config: &BehaviorModelConfig) -> Result<Self> {
let endpoint = config.api_endpoint.clone().ok_or_else(|| {
crate::Error::internal("API endpoint required for OpenAI-compatible provider")
})?;
Ok(Self {
client: reqwest::Client::new(),
api_key: config.api_key.clone(),
model: config.model.clone(),
endpoint,
})
}
}
#[async_trait::async_trait]
impl LlmProvider for OpenAICompatibleProvider {
async fn generate_chat(
&self,
messages: Vec<ChatMessage>,
temperature: f64,
max_tokens: usize,
) -> Result<String> {
let request_body = serde_json::json!({
"model": self.model,
"messages": messages.iter().map(|m| {
serde_json::json!({
"role": m.role,
"content": m.content
})
}).collect::<Vec<_>>(),
"temperature": temperature,
"max_tokens": max_tokens,
});
let mut request =
self.client.post(&self.endpoint).header("Content-Type", "application/json");
if let Some(api_key) = &self.api_key {
request = request.header("Authorization", format!("Bearer {}", api_key));
}
let response = request
.json(&request_body)
.send()
.await
.map_err(|e| crate::Error::internal(format!("API request failed: {}", e)))?;
        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            return Err(crate::Error::internal(format!(
                "API error ({}): {}",
                status, error_text
            )));
        }
let response_json: serde_json::Value = response
.json()
.await
.map_err(|e| crate::Error::internal(format!("Failed to parse API response: {}", e)))?;
let content = response_json["choices"][0]["message"]["content"]
.as_str()
.or_else(|| response_json["message"]["content"].as_str())
.ok_or_else(|| crate::Error::internal("Invalid API response format"))?
.to_string();
Ok(content)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_llm_client_creation() {
let config = BehaviorModelConfig::default();
let client = LlmClient::new(config);
assert_eq!(client.config().llm_provider, "openai");
}
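
    // Additional coverage for usage accounting and provider selection.
    #[test]
    fn test_llm_usage_totals() {
        let usage = LlmUsage::new(120, 30);
        assert_eq!(usage.prompt_tokens, 120);
        assert_eq!(usage.completion_tokens, 30);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_unsupported_provider_is_rejected() {
        let mut config = BehaviorModelConfig::default();
        config.llm_provider = "does-not-exist".into();
        let client = LlmClient::new(config);
        assert!(client.create_provider().is_err());
    }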
}