use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};
use super::config::LlmConfig;
use super::error::{LlmError, LlmResult};
use super::fallback::{FallbackChain, FallbackStep};
use crate::throttle::ConcurrencyController;
/// Executes LLM chat-completion requests with retry, optional concurrency
/// throttling, and optional model-fallback support.
///
/// Cheap to clone: throttle and fallback state are shared behind `Arc`s.
#[derive(Clone)]
pub struct LlmExecutor {
    // Model, endpoint, temperature, retry policy, and API key settings.
    config: LlmConfig,
    // Optional shared request limiter (hands out semaphore permits);
    // `None` means requests run unthrottled.
    throttle: Option<Arc<ConcurrencyController>>,
    // Optional chain of alternate models tried after retries are exhausted;
    // `None` means failures are returned directly.
    fallback: Option<Arc<FallbackChain>>,
}
impl std::fmt::Debug for LlmExecutor {
    /// Manual `Debug`: reports the optional components as presence flags
    /// rather than dumping their contents (which also avoids requiring
    /// `Debug` on `ConcurrencyController` / `FallbackChain`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("LlmExecutor");
        dbg.field("model", &self.config.model);
        dbg.field("endpoint", &self.config.endpoint);
        dbg.field("has_throttle", &self.throttle.is_some());
        dbg.field("has_fallback", &self.fallback.is_some());
        dbg.finish()
    }
}
impl LlmExecutor {
pub fn new(config: LlmConfig) -> Self {
Self {
config,
throttle: None,
fallback: None,
}
}
pub fn with_defaults() -> Self {
Self::new(LlmConfig::default())
}
pub fn for_model(model: impl Into<String>) -> Self {
Self::new(LlmConfig::new(model))
}
pub fn with_throttle(mut self, controller: ConcurrencyController) -> Self {
self.throttle = Some(Arc::new(controller));
self
}
pub fn with_shared_throttle(mut self, controller: Arc<ConcurrencyController>) -> Self {
self.throttle = Some(controller);
self
}
pub fn with_fallback(mut self, chain: FallbackChain) -> Self {
self.fallback = Some(Arc::new(chain));
self
}
pub fn with_shared_fallback(mut self, chain: Arc<FallbackChain>) -> Self {
self.fallback = Some(chain);
self
}
pub fn config(&self) -> &LlmConfig {
&self.config
}
pub fn throttle(&self) -> Option<&ConcurrencyController> {
self.throttle.as_deref()
}
pub fn fallback(&self) -> Option<&FallbackChain> {
self.fallback.as_deref()
}
pub async fn complete(&self, system: &str, user: &str) -> LlmResult<String> {
self.execute_with_context(system, user, None).await
}
pub async fn complete_with_max_tokens(
&self,
system: &str,
user: &str,
max_tokens: u16,
) -> LlmResult<String> {
self.execute_with_context(system, user, Some(max_tokens))
.await
}
async fn execute_with_context(
&self,
system: &str,
user: &str,
max_tokens: Option<u16>,
) -> LlmResult<String> {
let mut attempts = 0;
let mut current_model = self.config.model.clone();
let current_endpoint = self.config.endpoint.clone();
let mut fallback_history: Vec<FallbackStep> = vec![];
let mut total_attempts_including_fallback = 0;
loop {
attempts += 1;
total_attempts_including_fallback += 1;
const MAX_TOTAL_ATTEMPTS: usize = 20;
if total_attempts_including_fallback > MAX_TOTAL_ATTEMPTS {
warn!(
total_attempts = total_attempts_including_fallback,
"Exceeded maximum total attempts, aborting"
);
return Err(LlmError::RetryExhausted {
attempts: total_attempts_including_fallback,
last_error: "Exceeded maximum total attempts including fallbacks".to_string(),
});
}
let _permit = self.acquire_throttle_permit().await;
debug!(
attempt = attempts,
model = %current_model,
endpoint = %current_endpoint,
"Executing LLM request"
);
let result = self
.do_request(¤t_model, ¤t_endpoint, system, user, max_tokens)
.await;
match result {
Ok(response) => {
if fallback_history.is_empty() {
debug!(
attempts = attempts,
"LLM request succeeded without fallback"
);
} else {
info!(
attempts = attempts,
fallback_steps = fallback_history.len(),
"LLM request succeeded after fallback"
);
}
return Ok(response);
}
Err(error) => {
if self.should_retry(&error, attempts) {
let delay = self.retry_delay(attempts);
warn!(
attempt = attempts,
max_attempts = self.config.retry.max_attempts,
delay_ms = delay.as_millis() as u64,
error = %error,
"LLM call failed, retrying..."
);
tokio::time::sleep(delay).await;
continue;
}
if let Some(ref fallback) = self.fallback {
if fallback.should_fallback(&error) {
let mut fell_back = false;
if let Some(next_model) = fallback.next_model(¤t_model) {
info!(
from_model = %current_model,
to_model = %next_model,
"Falling back to next model"
);
fallback.record_fallback(
&mut fallback_history,
current_model.clone(),
Some(next_model.clone()),
current_endpoint.clone(),
None,
error.to_string(),
);
current_model = next_model;
attempts = 0; fell_back = true;
}
if fell_back {
continue;
}
}
}
warn!(
attempts = attempts,
fallback_steps = fallback_history.len(),
error = %error,
"LLM call failed, no more retries or fallbacks available"
);
return Err(error);
}
}
}
}
async fn acquire_throttle_permit(&self) -> Option<tokio::sync::SemaphorePermit<'_>> {
if let Some(ref throttle) = self.throttle {
throttle.acquire().await
} else {
None
}
}
fn should_retry(&self, error: &LlmError, attempts: usize) -> bool {
if attempts >= self.config.retry.max_attempts {
return false;
}
match error {
LlmError::RateLimit(_) => self.config.retry.retry_on_rate_limit,
LlmError::Timeout(_) => true,
LlmError::Api(msg) => {
let msg_lower = msg.to_lowercase();
msg_lower.contains("rate limit")
|| msg_lower.contains("429")
|| msg_lower.contains("503")
|| msg_lower.contains("502")
|| msg_lower.contains("timeout")
|| msg_lower.contains("overloaded")
}
_ => false,
}
}
fn retry_delay(&self, attempt: usize) -> Duration {
self.config.retry.delay_for_attempt(attempt - 1)
}
async fn do_request(
&self,
model: &str,
endpoint: &str,
system: &str,
user: &str,
max_tokens: Option<u16>,
) -> LlmResult<String> {
use async_openai::{
Client,
config::OpenAIConfig,
types::chat::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequestArgs,
},
};
let api_key = self.config.api_key.clone().ok_or_else(|| {
LlmError::Config(
"No API key configured. Call .with_key(\"sk-...\") when building the engine.".to_string(),
)
})?;
let openai_config = OpenAIConfig::new()
.with_api_key(api_key)
.with_api_base(endpoint);
let client = Client::with_config(openai_config);
let truncated = self.truncate_prompt(user);
let request = if let Some(tokens) = max_tokens {
CreateChatCompletionRequestArgs::default()
.model(model)
.messages([
ChatCompletionRequestSystemMessage::from(system).into(),
ChatCompletionRequestUserMessage::from(truncated).into(),
])
.temperature(self.config.temperature)
.build()
} else {
CreateChatCompletionRequestArgs::default()
.model(model)
.messages([
ChatCompletionRequestSystemMessage::from(system).into(),
ChatCompletionRequestUserMessage::from(truncated).into(),
])
.temperature(self.config.temperature)
.build()
};
let request =
request.map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?;
info!(
"LLM request → endpoint: {}, model: {}, system: {} chars, user: {} chars",
endpoint,
model,
system.len(),
truncated.len()
);
let request_start = std::time::Instant::now();
let response = client.chat().create(request).await.map_err(|e| {
let msg = e.to_string();
LlmError::from_api_message(&msg)
})?;
let request_elapsed = request_start.elapsed();
let usage = response.usage.as_ref();
let prompt_tokens = usage.map(|u| u.prompt_tokens).unwrap_or(0);
let completion_tokens = usage.map(|u| u.completion_tokens).unwrap_or(0);
let content = response
.choices
.first()
.and_then(|choice| choice.message.content.clone())
.ok_or(LlmError::NoContent)?;
info!(
"LLM response ← {}ms, tokens: {} prompt + {} completion, content: {} chars",
request_elapsed.as_millis(),
prompt_tokens,
completion_tokens,
content.len()
);
Ok(content)
}
fn truncate_prompt<'a>(&self, text: &'a str) -> &'a str {
const MAX_CHARS: usize = 30000;
if text.len() > MAX_CHARS {
&text[..MAX_CHARS]
} else {
text
}
}
}
impl Default for LlmExecutor {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_executor_creation() {
        let executor = LlmExecutor::for_model("gpt-4o");
        assert_eq!("gpt-4o", executor.config().model);
        assert!(executor.throttle().is_none());
        assert!(executor.fallback().is_none());
    }

    #[test]
    fn test_executor_with_throttle() {
        use crate::throttle::ConcurrencyConfig;
        let throttled = LlmExecutor::for_model("gpt-4o-mini")
            .with_throttle(ConcurrencyController::new(ConcurrencyConfig::conservative()));
        assert!(throttled.throttle().is_some());
    }

    #[test]
    fn test_should_retry() {
        let executor = LlmExecutor::with_defaults();
        let timeout = LlmError::Timeout("test".to_string());
        let rate_limited = LlmError::RateLimit("test".to_string());
        let bad_config = LlmError::Config("test".to_string());
        assert!(executor.should_retry(&timeout, 1));
        assert!(executor.should_retry(&rate_limited, 1));
        assert!(!executor.should_retry(&bad_config, 1));
        // Beyond the configured cap, even retryable errors stop retrying.
        assert!(!executor.should_retry(&timeout, 100));
    }

    #[test]
    fn test_retry_delay() {
        let executor = LlmExecutor::with_defaults();
        assert_eq!(Duration::from_millis(500), executor.retry_delay(1));
    }
}