use crate::config::Config;
use crate::event_bus::{Event, EventBus, EventEmitter};
use crate::impl_event_emitter;
use anyhow::Result;
use async_trait::async_trait;
use std::sync::Arc;
#[async_trait]
/// Common interface for all LLM backends (local stub, remote APIs, ...).
///
/// Implementors must be `Send + Sync` so a provider can be shared across
/// async tasks behind `Box<dyn LLMProvider>`.
pub trait LLMProvider: Send + Sync {
/// Short machine-readable provider identifier (e.g. "local", "openai").
/// Used as the `provider` field on emitted events and for cost lookup.
fn name(&self) -> &str;
// Kept on the trait for future use; not all call sites query it yet.
#[allow(dead_code)]
/// Maximum context window size in tokens for this provider.
fn context_size(&self) -> usize;
/// Sends `prompt` to the backend and returns the raw response text.
async fn send_prompt(&self, prompt: &str) -> Result<String>;
/// Human-readable model name; defaults to "Unknown" for providers
/// that do not expose a specific model.
fn model_name(&self) -> &str {
"Unknown"
}
/// Whether the provider reports its own token/cost metrics. When
/// `false` (default), `LLMManager` estimates tokens and emits
/// `APICallCompleted` on the provider's behalf.
fn handles_own_metrics(&self) -> bool {
false
}
}
/// Offline stub provider used for testing and local development.
/// Answers deterministically without any network access.
pub struct LocalProvider;

#[async_trait]
impl LLMProvider for LocalProvider {
    fn name(&self) -> &str {
        "local"
    }

    fn context_size(&self) -> usize {
        4096
    }

    /// Pattern-matches on well-known prompt prefixes and fabricates a
    /// plausible response; unknown prompts are echoed back verbatim.
    async fn send_prompt(&self, prompt: &str) -> Result<String> {
        if let Some(task) = prompt.strip_prefix("Plan the following task:") {
            // Split the task on '.' into sentences and number the non-empty
            // ones. Filtering BEFORE enumerate keeps the numbering sequential
            // (the old code enumerated the raw split, so empty segments —
            // e.g. from "A..B" or a leading period — left gaps like
            // "1. A" / "3. B").
            let steps: Vec<String> = task
                .split('.')
                .map(str::trim)
                .filter(|s| !s.is_empty())
                .enumerate()
                .map(|(i, s)| format!("{}. {}", i + 1, s))
                .collect();
            if steps.is_empty() {
                Ok("1. No steps generated".to_string())
            } else {
                Ok(steps.join("\n"))
            }
        } else if let Some(step) = prompt.strip_prefix("Execute step:") {
            Ok(format!("Executed: {}", step.trim()))
        } else if prompt.starts_with("Review") {
            Ok("All good".to_string())
        } else {
            // Fallback: echo the prompt so callers always get a response.
            Ok(prompt.to_string())
        }
    }
    // NOTE: the trait's default `handles_own_metrics` (false) applies; the
    // previous identical override was redundant and has been removed.
}
/// Routes prompts to the first configured provider and emits usage
/// metrics (tokens, estimated cost) on the shared event bus.
pub struct LLMManager {
// Ordered list of backends; index 0 is the active provider.
providers: Vec<Box<dyn LLMProvider>>,
// Optional so metrics emission can be disabled; `new` always sets Some.
event_bus: Option<Arc<EventBus>>,
// Used for per-provider cost-per-token lookup; `new` always sets Some.
config: Option<Arc<Config>>,
}
impl LLMManager {
    /// Creates a manager over `providers`, wiring in the event bus for
    /// metrics emission and the config for cost calculation.
    pub fn new(
        providers: Vec<Box<dyn LLMProvider>>,
        event_bus: Arc<EventBus>,
        config: Arc<Config>,
    ) -> Self {
        Self {
            providers,
            event_bus: Some(event_bus),
            config: Some(config),
        }
    }

    /// Returns the active (first) provider. Panics if no provider is
    /// configured, matching the previous indexing behavior.
    #[allow(dead_code)]
    pub fn provider(&self) -> &dyn LLMProvider {
        self.providers[0].as_ref()
    }

    /// Context window of the active provider, or a 4096-token default
    /// when no provider is configured.
    pub fn get_context_size(&self) -> usize {
        match self.providers.first() {
            Some(active) => active.context_size(),
            None => 4096,
        }
    }

    /// Forwards `prompt` to the active provider, surrounding the call
    /// with `APICallStarted` / `APICallCompleted` (or `APIError`) events.
    /// Event emission is best-effort: send failures are ignored.
    pub async fn send_prompt(&self, prompt: &str) -> anyhow::Result<String> {
        let active = self
            .providers
            .first()
            .ok_or_else(|| anyhow::anyhow!("No providers available"))?;

        if let Some(bus) = &self.event_bus {
            let _ = bus
                .emit(Event::APICallStarted {
                    provider: active.name().to_string(),
                    model: active.model_name().to_string(),
                })
                .await;
        }

        let outcome = active.send_prompt(prompt).await;

        if let Some(bus) = &self.event_bus {
            match &outcome {
                // Only estimate usage for providers that don't report
                // their own metrics.
                Ok(response) if !active.handles_own_metrics() => {
                    // Crude heuristic: roughly 4 characters per token.
                    let input_tokens = prompt.len() / 4;
                    let output_tokens = response.len() / 4;
                    let cost = self.calculate_cost(active.name(), input_tokens, output_tokens);
                    let _ = bus
                        .emit(Event::APICallCompleted {
                            provider: active.name().to_string(),
                            tokens: input_tokens + output_tokens,
                            cost,
                        })
                        .await;
                }
                Ok(_) => {}
                Err(e) => {
                    let _ = bus
                        .emit(Event::APIError {
                            provider: active.name().to_string(),
                            error: e.to_string(),
                        })
                        .await;
                }
            }
        }

        outcome
    }

    /// Estimated USD cost for a call, derived from the per-million-token
    /// rates configured for `provider_name`. Returns 0.0 when no config,
    /// an unknown provider, or missing rates make a price unavailable.
    fn calculate_cost(
        &self,
        provider_name: &str,
        input_tokens: usize,
        output_tokens: usize,
    ) -> f32 {
        let config = match &self.config {
            Some(c) => c,
            None => return 0.0,
        };
        let slot = match provider_name.to_lowercase().as_str() {
            "openai" => &config.ai_providers.openai,
            "anthropic" => &config.ai_providers.anthropic,
            "openrouter" => &config.ai_providers.openrouter,
            "gemini" => &config.ai_providers.gemini,
            "xai" => &config.ai_providers.xai,
            _ => return 0.0,
        };
        match slot {
            Some(rates) => {
                let input_cost = rates.cost_per_1m_input_tokens.unwrap_or(0.0)
                    * (input_tokens as f32)
                    / 1_000_000.0;
                let output_cost = rates.cost_per_1m_output_tokens.unwrap_or(0.0)
                    * (output_tokens as f32)
                    / 1_000_000.0;
                input_cost + output_cost
            }
            None => 0.0,
        }
    }
}
impl_event_emitter!(LLMManager);