//! LLM client integration using Rig framework (rig-core 0.14.0 compatible)
//!
//! This module provides a professional LLM client that leverages the Rig framework
//! for token counting, cost tracking, and provider abstraction.
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use rig::{
providers::{anthropic, ollama, openai},
completion::Prompt,
agent::Agent,
client::CompletionClient,
};
use crate::{
AgentId, CostRecord, CostTracker, MultiAgentResult, TokenUsageRecord, TokenUsageTracker,
};
/// LLM client configuration.
///
/// Consumed by `RigLlmClient::new`, which selects a backend from `provider`.
/// Can be populated from role parameters via `extract_llm_config`; see the
/// `Default` impl for baseline values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmClientConfig {
    /// Provider type (openai, anthropic, ollama)
    pub provider: String,
    /// Model name handed to the provider's agent builder
    pub model: String,
    /// API key (optional for Ollama, required for others)
    pub api_key: Option<String>,
    /// Base URL (for custom endpoints)
    /// NOTE(review): not every provider branch in `RigLlmClient::new` reads this —
    /// check the branch for your provider before relying on it.
    pub base_url: Option<String>,
    /// Default temperature
    pub temperature: f32,
    /// Max tokens per request
    pub max_tokens: u32,
    /// Request timeout in seconds
    pub timeout_seconds: u64,
    /// Enable cost tracking. When false, no `CostRecord` is written; token usage
    /// records are still added to the token tracker.
    pub track_costs: bool,
}
impl Default for LlmClientConfig {
fn default() -> Self {
Self {
provider: "openai".to_string(),
model: "gpt-3.5-turbo".to_string(),
api_key: None,
base_url: None,
temperature: 0.7,
max_tokens: 4096,
timeout_seconds: 30,
track_costs: true,
}
}
}
/// Message role in conversation.
///
/// The `Display` form of each variant is used as the speaker label when
/// messages are flattened into a prompt. The role is a fieldless enum, so it is
/// `Copy` and comparable with `==` (also usable as a hash-map key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MessageRole {
    System,
    User,
    Assistant,
    Tool,
}
impl std::fmt::Display for MessageRole {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
MessageRole::System => write!(f, "System"),
MessageRole::User => write!(f, "User"),
MessageRole::Assistant => write!(f, "Assistant"),
MessageRole::Tool => write!(f, "Tool"),
}
}
}
/// LLM message: one turn of a conversation.
///
/// When sent through `RigLlmClient`, each message is flattened into a
/// `"<Role>: <content>"` line of a single prompt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmMessage {
    /// Speaker of this message.
    pub role: MessageRole,
    /// Message text.
    pub content: String,
    /// Set to `Utc::now()` by the `LlmMessage` constructors.
    pub timestamp: DateTime<Utc>,
}
impl LlmMessage {
pub fn system(content: String) -> Self {
Self {
role: MessageRole::System,
content,
timestamp: Utc::now(),
}
}
pub fn user(content: String) -> Self {
Self {
role: MessageRole::User,
content,
timestamp: Utc::now(),
}
}
pub fn assistant(content: String) -> Self {
Self {
role: MessageRole::Assistant,
content,
timestamp: Utc::now(),
}
}
}
/// LLM request parameters.
///
/// Build with `LlmRequest::new` and the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct LlmRequest {
    /// Conversation to send, in order.
    pub messages: Vec<LlmMessage>,
    /// Optional per-request temperature.
    /// NOTE(review): not read by `RigLlmClient::complete` / `stream_complete`.
    pub temperature: Option<f32>,
    /// Optional per-request output-token limit.
    /// NOTE(review): not read by `RigLlmClient::complete` / `stream_complete`.
    pub max_tokens: Option<u32>,
    /// Free-form caller metadata; not read by `RigLlmClient` either.
    pub metadata: std::collections::HashMap<String, String>,
}
impl LlmRequest {
pub fn new(messages: Vec<LlmMessage>) -> Self {
Self {
messages,
temperature: None,
max_tokens: None,
metadata: std::collections::HashMap::new(),
}
}
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_metadata(mut self, key: String, value: String) -> Self {
self.metadata.insert(key, value);
self
}
}
/// LLM response with detailed usage information.
///
/// Produced by `RigLlmClient::complete`.
#[derive(Debug, Clone)]
pub struct LlmResponse {
    /// Generated text.
    pub content: String,
    /// Model name copied from the client configuration (not reported by the provider).
    pub model: String,
    /// Token usage; `RigLlmClient::complete` fills this with a text-length estimate.
    pub usage: TokenUsage,
    /// Client-generated random UUID (v4); also used in log lines and the token-usage record.
    pub request_id: Uuid,
    /// UTC time at which the request started.
    pub timestamp: DateTime<Utc>,
    /// Elapsed request time in milliseconds.
    pub duration_ms: u64,
    /// Always `"completed"` as set by `RigLlmClient::complete`.
    pub finish_reason: String,
}
/// Token usage information for one request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Prompt-side token count.
    pub input_tokens: u64,
    /// Completion-side token count.
    pub output_tokens: u64,
    /// Sum of `input_tokens` and `output_tokens` when built via `TokenUsage::new`.
    pub total_tokens: u64,
}
impl TokenUsage {
    /// Build a usage record and derive `total_tokens`.
    ///
    /// The total saturates at `u64::MAX` instead of overflowing. Plain `+` would
    /// panic in debug builds and wrap in release builds.
    pub fn new(input_tokens: u64, output_tokens: u64) -> Self {
        Self {
            input_tokens,
            output_tokens,
            total_tokens: input_tokens.saturating_add(output_tokens),
        }
    }
}
/// Enum to abstract over different agent types from different providers.
///
/// Each variant holds its Rig agent behind an `Arc`, so cloning an `LlmAgent`
/// (as `RigLlmClient::stream_complete` does for its background task) only bumps
/// a reference count.
#[derive(Clone)]
pub enum LlmAgent {
    /// OpenAI agent built on the Responses API completion model.
    OpenAI(Arc<Agent<openai::responses_api::ResponsesCompletionModel>>),
    /// Anthropic agent.
    Anthropic(Arc<Agent<anthropic::completion::CompletionModel>>),
    /// Ollama agent.
    Ollama(Arc<Agent<ollama::CompletionModel>>),
}
impl LlmAgent {
pub async fn prompt(&self, prompt: &str) -> Result<String, String> {
match self {
LlmAgent::OpenAI(agent) => {
agent.as_ref().prompt(prompt).await.map_err(|e| e.to_string())
}
LlmAgent::Anthropic(agent) => {
agent.as_ref().prompt(prompt).await.map_err(|e| e.to_string())
}
LlmAgent::Ollama(agent) => {
agent.as_ref().prompt(prompt).await.map_err(|e| e.to_string())
}
}
}
}
/// LLM client using the Rig framework.
///
/// Wraps one provider agent. Every `complete` call is recorded in the token
/// tracker, and also in the cost tracker when `config.track_costs` is set.
pub struct RigLlmClient {
    /// Settings the client was built with.
    config: LlmClientConfig,
    /// Agent this client acts for; stamped on every usage and cost record.
    agent_id: AgentId,
    /// Token-usage tracker, behind `Arc<RwLock<_>>` so it can be shared.
    token_tracker: Arc<RwLock<TokenUsageTracker>>,
    /// Cost tracker, behind `Arc<RwLock<_>>` so it can be shared.
    cost_tracker: Arc<RwLock<CostTracker>>,
    /// Provider agent that executes prompts.
    agent: LlmAgent,
}
impl RigLlmClient {
/// Create a new LLM client
pub async fn new(
config: LlmClientConfig,
agent_id: AgentId,
token_tracker: Arc<RwLock<TokenUsageTracker>>,
cost_tracker: Arc<RwLock<CostTracker>>,
) -> MultiAgentResult<Self> {
let agent = match config.provider.as_str() {
"openai" => {
let client = openai::Client::new(
&config.api_key.clone().ok_or_else(|| anyhow::anyhow!("OpenAI API key required"))?
);
let agent = client
.agent(&config.model)
.preamble("You are a helpful AI assistant.")
.build();
LlmAgent::OpenAI(Arc::new(agent))
},
"anthropic" => {
// Anthropic client in rig-core 0.5.0 requires more parameters
let api_key = config.api_key.clone().ok_or_else(|| anyhow::anyhow!("Anthropic API key required"))?;
// Using defaults for other parameters
let client = anthropic::Client::new(
&api_key,
"https://api.anthropic.com", // base_url
None, // betas
"2023-06-01" // version
);
let agent = client
.agent(&config.model)
.preamble("You are a helpful AI assistant.")
.build();
LlmAgent::Anthropic(Arc::new(agent))
},
"ollama" => {
let client = if let Some(base_url) = config.base_url.clone() {
ollama::Client::from_url(&base_url)
} else {
ollama::Client::new()
};
let agent = client
.agent(&config.model)
.preamble("You are a helpful AI assistant.")
.build();
LlmAgent::Ollama(Arc::new(agent))
},
_ => {
return Err(anyhow::anyhow!(
"Provider '{}' is not supported. Supported providers: openai, anthropic, ollama",
config.provider
).into());
},
};
log::info!(
"Created LLM client for agent {} using {} model {}",
agent_id,
config.provider,
config.model
);
Ok(Self {
config,
agent_id,
token_tracker,
cost_tracker,
agent,
})
}
/// Generate completion using Rig framework
pub async fn complete(&self, request: LlmRequest) -> MultiAgentResult<LlmResponse> {
let start_time = Utc::now();
let request_id = Uuid::new_v4();
log::debug!(
"Starting LLM completion for agent {} (request: {})",
self.agent_id,
request_id
);
// Convert messages to a single prompt
let prompt = request
.messages
.iter()
.map(|msg| format!("{}: {}", msg.role, msg.content))
.collect::<Vec<_>>()
.join("\n");
// Use Rig agent to generate completion
let response_content = self.agent
.prompt(&prompt)
.await
.map_err(|e| anyhow::anyhow!("Rig agent prompt error: {}", e))?;
let end_time = Utc::now();
let duration_ms = (end_time - start_time).num_milliseconds() as u64;
// Estimate token usage (Rig doesn't always provide detailed usage)
let input_tokens = (prompt.len() / 4) as u64; // rough estimation: 4 chars = 1 token
let output_tokens = (response_content.len() / 4) as u64;
let response = LlmResponse {
content: response_content,
model: self.config.model.clone(),
usage: TokenUsage::new(input_tokens, output_tokens),
request_id,
timestamp: start_time,
duration_ms,
finish_reason: "completed".to_string(),
};
// Track usage
self.track_usage(&response).await?;
log::debug!(
"Completed LLM request {} in {}ms (tokens: {} input, {} output)",
request_id,
response.duration_ms,
response.usage.input_tokens,
response.usage.output_tokens
);
Ok(response)
}
/// Generate streaming completion
pub async fn stream_complete(
&self,
request: LlmRequest,
) -> MultiAgentResult<tokio::sync::mpsc::Receiver<String>> {
let (tx, rx) = tokio::sync::mpsc::channel(100);
// Convert messages to prompt
let prompt = request
.messages
.iter()
.map(|msg| format!("{}: {}", msg.role, msg.content))
.collect::<Vec<_>>()
.join("\n");
// Clone what we need for the spawned task
let agent = self.agent.clone();
let agent_id = self.agent_id;
tokio::spawn(async move {
// TODO: Implement streaming when Rig supports it better
// For now, generate full response and send as chunks
match agent.prompt(&prompt).await {
Ok(response) => {
// Split response into chunks for streaming effect
let words: Vec<&str> = response.split_whitespace().collect();
for word in words {
if tx.send(format!("{} ", word)).await.is_err() {
break; // Receiver dropped
}
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
},
Err(e) => {
log::error!("Streaming completion error for agent {}: {}", agent_id, e);
}
}
});
log::debug!(
"Started streaming completion for agent {} using Rig",
self.agent_id
);
Ok(rx)
}
/// Check if model is available
pub async fn check_availability(&self) -> MultiAgentResult<bool> {
// TODO: Implement actual availability check
log::debug!(
"Checking availability for {} model {} (mock: available)",
self.config.provider,
self.config.model
);
Ok(true)
}
/// Get model capabilities
pub fn get_capabilities(&self) -> ModelCapabilities {
// TODO: Get actual capabilities from Rig model info
ModelCapabilities {
max_context_tokens: match self.config.model.as_str() {
model if model.contains("gpt-4") => 128000,
model if model.contains("gpt-3.5") => 16384,
model if model.contains("claude-3") => 200000,
_ => 8192,
},
supports_streaming: true,
supports_function_calling: self.config.model.contains("gpt-4")
|| self.config.model.contains("gpt-3.5"),
supports_vision: self.config.model.contains("vision")
|| self.config.model.contains("gpt-4"),
}
}
// Private methods
/// Track token usage and costs
async fn track_usage(&self, response: &LlmResponse) -> MultiAgentResult<()> {
// Track tokens
{
let mut tracker = self.token_tracker.write().await;
let record = TokenUsageRecord {
request_id: response.request_id,
timestamp: response.timestamp,
agent_id: self.agent_id,
model: response.model.clone(),
input_tokens: response.usage.input_tokens,
output_tokens: response.usage.output_tokens,
total_tokens: response.usage.total_tokens,
cost_usd: self.calculate_cost(&response.usage).await?,
duration_ms: response.duration_ms,
quality_score: None, // TODO: Implement quality scoring
};
tracker.add_record(record)?;
}
// Track costs if enabled
if self.config.track_costs {
let mut cost_tracker = self.cost_tracker.write().await;
let cost = self.calculate_cost(&response.usage).await?;
let cost_record = CostRecord {
timestamp: response.timestamp,
agent_id: self.agent_id,
operation_type: "llm_completion".to_string(),
cost_usd: cost,
metadata: [
("model".to_string(), response.model.clone()),
(
"input_tokens".to_string(),
response.usage.input_tokens.to_string(),
),
(
"output_tokens".to_string(),
response.usage.output_tokens.to_string(),
),
]
.into(),
};
cost_tracker.add_record(cost_record)?;
}
Ok(())
}
/// Calculate cost based on token usage
async fn calculate_cost(&self, usage: &TokenUsage) -> MultiAgentResult<f64> {
// TODO: Get actual pricing from Rig or pricing database
let (input_cost_per_1k, output_cost_per_1k) = match self.config.model.as_str() {
"gpt-4" => (0.03, 0.06),
"gpt-4-turbo" => (0.01, 0.03),
"gpt-3.5-turbo" => (0.0015, 0.002),
"claude-3-opus" => (0.015, 0.075),
"claude-3-sonnet" => (0.003, 0.015),
"claude-3-haiku" => (0.00025, 0.00125),
_ => (0.001, 0.002), // Default fallback
};
let input_cost = (usage.input_tokens as f64 / 1000.0) * input_cost_per_1k;
let output_cost = (usage.output_tokens as f64 / 1000.0) * output_cost_per_1k;
Ok(input_cost + output_cost)
}
}
/// Model capabilities information.
///
/// NOTE(review): `RigLlmClient::get_capabilities` derives these from substring
/// matches on the model name rather than from provider metadata, so treat them
/// as approximate.
#[derive(Debug, Clone)]
pub struct ModelCapabilities {
    /// Approximate context-window size in tokens.
    pub max_context_tokens: u64,
    /// Whether streaming is offered.
    pub supports_streaming: bool,
    /// Whether function/tool calling is expected to work.
    pub supports_function_calling: bool,
    /// Whether image input is expected to work.
    pub supports_vision: bool,
}
use ahash::AHashMap;

/// Extract LLM configuration from role extra parameters.
///
/// Starts from `LlmClientConfig::default()` and overrides a field whenever its key
/// is present with a value of the expected JSON type. Absent or mistyped entries
/// are ignored.
///
/// Recognised keys:
/// - `llm_provider`, `llm_model`, `llm_api_key`, `llm_base_url`: strings
/// - `llm_temperature`: number
/// - `llm_max_tokens`: non-negative integer; values above `u32::MAX` are clamped
/// - `llm_timeout_seconds`: non-negative integer
/// - `llm_track_costs`: boolean
pub fn extract_llm_config(extra: &AHashMap<String, serde_json::Value>) -> LlmClientConfig {
    let mut config = LlmClientConfig::default();

    let text = |key: &str| {
        extra
            .get(key)
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    };

    if let Some(provider) = text("llm_provider") {
        config.provider = provider;
    }
    if let Some(model) = text("llm_model") {
        config.model = model;
    }
    if let Some(api_key) = text("llm_api_key") {
        config.api_key = Some(api_key);
    }
    if let Some(base_url) = text("llm_base_url") {
        config.base_url = Some(base_url);
    }
    if let Some(temperature) = extra
        .get("llm_temperature")
        .and_then(serde_json::Value::as_f64)
    {
        config.temperature = temperature as f32;
    }
    if let Some(max_tokens) = extra
        .get("llm_max_tokens")
        .and_then(serde_json::Value::as_u64)
    {
        // Clamp rather than `as u32`, which would silently wrap huge values.
        config.max_tokens = u32::try_from(max_tokens).unwrap_or(u32::MAX);
    }
    if let Some(timeout_seconds) = extra
        .get("llm_timeout_seconds")
        .and_then(serde_json::Value::as_u64)
    {
        config.timeout_seconds = timeout_seconds;
    }
    if let Some(track_costs) = extra
        .get("llm_track_costs")
        .and_then(serde_json::Value::as_bool)
    {
        config.track_costs = track_costs;
    }

    config
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_llm_message_creation() {
        let system_msg = LlmMessage::system("You are a helpful assistant".to_string());
        assert!(matches!(system_msg.role, MessageRole::System));
        assert_eq!(system_msg.content, "You are a helpful assistant");

        let user_msg = LlmMessage::user("Hello!".to_string());
        assert!(matches!(user_msg.role, MessageRole::User));
        assert_eq!(user_msg.content, "Hello!");
    }

    #[test]
    fn test_message_role_display() {
        // These labels become the speaker prefixes of the flattened prompt.
        assert_eq!(MessageRole::System.to_string(), "System");
        assert_eq!(MessageRole::User.to_string(), "User");
        assert_eq!(MessageRole::Assistant.to_string(), "Assistant");
        assert_eq!(MessageRole::Tool.to_string(), "Tool");
    }

    #[test]
    fn test_llm_request_builder() {
        let messages = vec![
            LlmMessage::system("System prompt".to_string()),
            LlmMessage::user("User message".to_string()),
        ];

        let request = LlmRequest::new(messages)
            .with_temperature(0.8)
            .with_max_tokens(2048)
            .with_metadata("test_key".to_string(), "test_value".to_string());

        assert_eq!(request.temperature, Some(0.8));
        assert_eq!(request.max_tokens, Some(2048));
        assert_eq!(
            request.metadata.get("test_key"),
            Some(&"test_value".to_string())
        );
    }

    #[test]
    fn test_default_config() {
        let config = LlmClientConfig::default();
        assert_eq!(config.provider, "openai");
        assert_eq!(config.model, "gpt-3.5-turbo");
        assert!(config.api_key.is_none());
        assert!(config.base_url.is_none());
        assert_eq!(config.max_tokens, 4096);
        assert_eq!(config.timeout_seconds, 30);
        assert!(config.track_costs);
    }

    #[test]
    fn test_extract_llm_config() {
        let mut extra = ahash::AHashMap::new();
        extra.insert(
            "llm_provider".to_string(),
            serde_json::Value::String("anthropic".to_string()),
        );
        extra.insert(
            "llm_model".to_string(),
            serde_json::Value::String("claude-3-sonnet".to_string()),
        );
        extra.insert(
            "llm_temperature".to_string(),
            serde_json::Value::Number(serde_json::Number::from_f64(0.5).unwrap()),
        );

        let config = extract_llm_config(&extra);
        assert_eq!(config.provider, "anthropic");
        assert_eq!(config.model, "claude-3-sonnet");
        assert_eq!(config.temperature, 0.5);
    }

    #[test]
    fn test_extract_llm_config_ignores_absent_and_mistyped_keys() {
        let empty = ahash::AHashMap::new();
        let config = extract_llm_config(&empty);
        assert_eq!(config.provider, "openai");
        assert_eq!(config.max_tokens, 4096);

        let mut extra = ahash::AHashMap::new();
        extra.insert("llm_max_tokens".to_string(), serde_json::Value::from(2048u64));
        // Wrong JSON type: must be ignored, leaving the default provider.
        extra.insert("llm_provider".to_string(), serde_json::Value::from(42u64));

        let config = extract_llm_config(&extra);
        assert_eq!(config.max_tokens, 2048);
        assert_eq!(config.provider, "openai");
    }

    #[test]
    fn test_token_usage_calculation() {
        let usage = TokenUsage::new(100, 50);
        assert_eq!(usage.input_tokens, 100);
        assert_eq!(usage.output_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }
}