use std::sync::Arc;
use std::time::Duration;
use tracing::{debug, info, warn};
use async_openai::types::chat::{
ChatCompletionRequestSystemMessage, ChatCompletionRequestUserMessage,
CreateChatCompletionRequestArgs,
};
use super::config::LlmConfig;
use super::error::{LlmError, LlmResult};
use super::fallback::{FallbackChain, FallbackStep};
use super::throttle::ConcurrencyController;
use crate::metrics::MetricsHub;
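/// Executes chat-completion calls against an OpenAI-compatible endpoint,
/// combining the configured retry policy with optional concurrency
/// throttling, model fallback, and metrics reporting.
///
/// Illustrative usage (compile-ignored sketch; the model name, prompts, and
/// the conservative throttle config are placeholders, not values mandated by
/// this module):
///
/// ```ignore
/// let executor = LlmExecutor::for_model("gpt-4o-mini")
///     .with_throttle(ConcurrencyController::new(ConcurrencyConfig::conservative()));
/// let reply = executor.complete("You are a terse assistant.", "Say hello.").await?;
/// ```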
#[derive(Clone)]
pub struct LlmExecutor {
config: LlmConfig,
openai_client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
throttle: Option<Arc<ConcurrencyController>>,
fallback: Option<Arc<FallbackChain>>,
metrics: Option<Arc<MetricsHub>>,
}
impl std::fmt::Debug for LlmExecutor {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LlmExecutor")
.field("model", &self.config.model)
.field("endpoint", &self.config.endpoint)
.field("has_throttle", &self.throttle.is_some())
.field("has_fallback", &self.fallback.is_some())
.field("has_openai_client", &true)
.field("has_metrics", &self.metrics.is_some())
.finish()
}
}
impl LlmExecutor {
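/// Creates an executor for the given configuration, building an
/// `async_openai` client from the config's API key and endpoint.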
pub fn new(config: LlmConfig) -> Self {
let openai_client = Self::build_openai_client(&config);
Self {
config,
openai_client: Arc::new(openai_client),
throttle: None,
fallback: None,
metrics: None,
}
}
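/// Builds the underlying OpenAI-compatible client; an empty endpoint falls
/// back to the public OpenAI API base URL.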
fn build_openai_client(
config: &LlmConfig,
) -> async_openai::Client<async_openai::config::OpenAIConfig> {
let api_key = config.api_key.clone().unwrap_or_default();
let endpoint = if config.endpoint.is_empty() {
"https://api.openai.com/v1".to_string()
} else {
config.endpoint.clone()
};
let openai_config = async_openai::config::OpenAIConfig::new()
.with_api_key(api_key)
.with_api_base(endpoint);
async_openai::Client::with_config(openai_config)
}
pub fn with_defaults() -> Self {
Self::new(LlmConfig::default())
}
pub fn for_model(model: impl Into<String>) -> Self {
Self::new(LlmConfig::new(model))
}
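/// Attaches a concurrency controller that limits in-flight requests via
/// semaphore permits; `with_shared_throttle` accepts an already-shared
/// controller instead of wrapping a new one.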
pub fn with_throttle(mut self, controller: ConcurrencyController) -> Self {
self.throttle = Some(Arc::new(controller));
self
}
pub fn with_shared_throttle(mut self, controller: Arc<ConcurrencyController>) -> Self {
self.throttle = Some(controller);
self
}
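/// Attaches a fallback chain consulted when a request fails with an error
/// the chain considers fallback-worthy; `with_shared_fallback` accepts an
/// already-shared chain.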
pub fn with_fallback(mut self, chain: FallbackChain) -> Self {
self.fallback = Some(Arc::new(chain));
self
}
pub fn with_shared_fallback(mut self, chain: Arc<FallbackChain>) -> Self {
self.fallback = Some(chain);
self
}
pub fn with_shared_metrics(mut self, hub: Arc<MetricsHub>) -> Self {
self.metrics = Some(hub);
self
}
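/// Replaces the client built from the config with a caller-supplied one,
/// useful when several executors should share a single client.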
pub fn with_openai_client(
mut self,
client: Arc<async_openai::Client<async_openai::config::OpenAIConfig>>,
) -> Self {
self.openai_client = client;
self
}
pub fn config(&self) -> &LlmConfig {
&self.config
}
pub fn throttle(&self) -> Option<&ConcurrencyController> {
self.throttle.as_deref()
}
pub fn fallback(&self) -> Option<&FallbackChain> {
self.fallback.as_deref()
}
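/// Sends a system + user message pair and returns the assistant's text,
/// applying retries, fallback, and throttling as configured.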
pub async fn complete(&self, system: &str, user: &str) -> LlmResult<String> {
self.execute_with_context(system, user, None).await
}
pub async fn complete_with_max_tokens(
&self,
system: &str,
user: &str,
max_tokens: u16,
) -> LlmResult<String> {
self.execute_with_context(system, user, Some(max_tokens))
.await
}
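/// Core request loop: acquire a throttle permit, issue the request with an
/// optional timeout, retry per the retry policy, and fall back to the next
/// model in the chain when retries are exhausted, up to a hard cap on total
/// attempts across all models.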
async fn execute_with_context(
&self,
system: &str,
user: &str,
max_tokens: Option<u16>,
) -> LlmResult<String> {
let mut attempts = 0;
let mut current_model = self.config.model.clone();
let mut fallback_history: Vec<FallbackStep> = vec![];
let mut total_attempts_including_fallback = 0;
const MAX_TOTAL_ATTEMPTS: usize = 20;
loop {
attempts += 1;
total_attempts_including_fallback += 1;
if total_attempts_including_fallback > MAX_TOTAL_ATTEMPTS {
warn!(
total_attempts = total_attempts_including_fallback,
"Exceeded maximum total attempts, aborting"
);
return Err(LlmError::RetryExhausted {
attempts: total_attempts_including_fallback,
last_error: "Exceeded maximum total attempts including fallbacks".to_string(),
});
}
let _permit = self.acquire_throttle_permit().await;
debug!(
attempt = attempts,
model = %current_model,
"Executing LLM request"
);
let request_future = self.do_request(&current_model, system, user, max_tokens);
let result = if self.config.request_timeout_secs > 0 {
let timeout = Duration::from_secs(self.config.request_timeout_secs);
match tokio::time::timeout(timeout, request_future).await {
Ok(r) => r,
Err(_) => {
warn!(
timeout_secs = self.config.request_timeout_secs,
model = %current_model,
"LLM request timed out"
);
if let Some(ref metrics) = self.metrics {
metrics.record_llm_timeout();
}
Err(LlmError::Timeout(format!(
"Request timed out after {}s",
self.config.request_timeout_secs
)))
}
}
} else {
request_future.await
};
match result {
Ok(response) => {
if fallback_history.is_empty() {
debug!(
attempts = attempts,
"LLM request succeeded without fallback"
);
} else {
info!(
attempts = attempts,
fallback_steps = fallback_history.len(),
"LLM request succeeded after fallback"
);
}
return Ok(response);
}
Err(error) => {
if let Some(ref metrics) = self.metrics {
match &error {
LlmError::RateLimit(_) => metrics.record_llm_rate_limit(),
LlmError::Timeout(_) => metrics.record_llm_timeout(),
_ => {}
}
}
if self.should_retry(&error, attempts) {
let delay = self.retry_delay(attempts);
warn!(
attempt = attempts,
max_attempts = self.config.retry.max_attempts,
delay_ms = delay.as_millis() as u64,
error = %error,
"LLM call failed, retrying..."
);
tokio::time::sleep(delay).await;
continue;
}
if let Some(ref fallback) = self.fallback {
if fallback.should_fallback(&error) {
let mut fell_back = false;
if let Some(next_model) = fallback.next_model(&current_model) {
info!(
from_model = %current_model,
to_model = %next_model,
"Falling back to next model"
);
if let Some(ref metrics) = self.metrics {
metrics.record_llm_fallback();
}
fallback.record_fallback(
&mut fallback_history,
current_model.clone(),
Some(next_model.clone()),
self.config.endpoint.clone(),
None,
error.to_string(),
);
current_model = next_model;
attempts = 0;
fell_back = true;
}
if fell_back {
continue;
}
}
}
warn!(
attempts = attempts,
fallback_steps = fallback_history.len(),
error = %error,
"LLM call failed, no more retries or fallbacks available"
);
return Err(error);
}
}
}
}
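/// Acquires a permit from the throttle, if one is configured; the permit is
/// held for the duration of the attempt and released when dropped.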
async fn acquire_throttle_permit(&self) -> Option<tokio::sync::SemaphorePermit<'_>> {
if let Some(ref throttle) = self.throttle {
throttle.acquire().await
} else {
None
}
}
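/// A retry is allowed while attempts remain under the configured maximum and
/// the error is retryable; rate-limit errors additionally honor
/// `retry_on_rate_limit`.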
fn should_retry(&self, error: &LlmError, attempts: usize) -> bool {
if attempts >= self.config.retry.max_attempts {
return false;
}
if matches!(error, LlmError::RateLimit(_)) {
self.config.retry.retry_on_rate_limit
} else {
error.is_retryable()
}
}
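/// Delay before the next attempt, taken from the retry policy (attempt
/// numbers are 1-based here, 0-based in the policy).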
fn retry_delay(&self, attempt: usize) -> Duration {
self.config.retry.delay_for_attempt(attempt - 1)
}
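/// Issues a single chat-completion request, records token and latency
/// metrics for the outcome, and returns the message content of the first
/// choice (an empty string if the model returned no content).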
async fn do_request(
&self,
model: &str,
system: &str,
user: &str,
max_tokens: Option<u16>,
) -> LlmResult<String> {
// Apply max_tokens only when the caller supplied a cap (this assumes the
// builder's `max_tokens` setter takes a u32, as in recent async-openai
// releases).
let mut request_builder = CreateChatCompletionRequestArgs::default();
request_builder
.model(model)
.messages([
ChatCompletionRequestSystemMessage::from(system).into(),
ChatCompletionRequestUserMessage::from(user).into(),
])
.temperature(self.config.temperature);
if let Some(max_tokens) = max_tokens {
request_builder.max_tokens(u32::from(max_tokens));
}
let request = request_builder
.build()
.map_err(|e| LlmError::Request(format!("Failed to build request: {}", e)))?;
info!(
"LLM request → endpoint: {}, model: {}, system: {} chars, user: {} chars",
self.config.endpoint,
model,
system.len(),
user.len()
);
let request_start = std::time::Instant::now();
let response = match self.openai_client.chat().create(request).await {
Ok(r) => r,
Err(e) => {
let elapsed = request_start.elapsed();
if let Some(ref metrics) = self.metrics {
metrics.record_llm_call(0, 0, elapsed.as_millis() as u64, false);
}
let msg = e.to_string();
return Err(LlmError::from_api_message(&msg));
}
};
let request_elapsed = request_start.elapsed();
let usage = response.usage.as_ref();
let prompt_tokens = usage.map(|u| u.prompt_tokens).unwrap_or(0);
let completion_tokens = usage.map(|u| u.completion_tokens).unwrap_or(0);
let Some(choice) = response.choices.first() else {
if let Some(ref metrics) = self.metrics {
metrics.record_llm_call(
prompt_tokens as u64,
completion_tokens as u64,
request_elapsed.as_millis() as u64,
false,
);
}
return Err(LlmError::NoContent);
};
let content = choice.message.content.clone().unwrap_or_default();
if content.is_empty() {
let has_tool_calls = choice
.message
.tool_calls
.as_ref()
.map_or(false, |t| !t.is_empty());
let finish_reason = format!("{:?}", choice.finish_reason);
warn!(
elapsed_ms = request_elapsed.as_millis(),
prompt_tokens,
completion_tokens,
has_tool_calls,
finish_reason,
"LLM returned empty content field"
);
}
if let Some(ref metrics) = self.metrics {
metrics.record_llm_call(
prompt_tokens as u64,
completion_tokens as u64,
request_elapsed.as_millis() as u64,
true,
);
}
// The empty-content case was already logged above with finish-reason and
// tool-call details, so only successful content is reported here.
if !content.is_empty() {
info!(
"LLM response ← {}ms, tokens: {} prompt + {} completion, content: {} chars",
request_elapsed.as_millis(),
prompt_tokens,
completion_tokens,
content.len()
);
}
Ok(content)
}
}
impl Default for LlmExecutor {
fn default() -> Self {
Self::with_defaults()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_executor_creation() {
let executor = LlmExecutor::for_model("gpt-4o");
assert_eq!(executor.config().model, "gpt-4o");
assert!(executor.throttle().is_none());
assert!(executor.fallback().is_none());
}
#[test]
fn test_executor_with_throttle() {
use crate::llm::throttle::ConcurrencyConfig;
let controller = ConcurrencyController::new(ConcurrencyConfig::conservative());
let executor = LlmExecutor::for_model("gpt-4o-mini").with_throttle(controller);
assert!(executor.throttle().is_some());
}
#[test]
fn test_should_retry() {
let executor = LlmExecutor::with_defaults();
assert!(executor.should_retry(&LlmError::Timeout("test".to_string()), 1));
assert!(executor.should_retry(&LlmError::RateLimit("test".to_string()), 1));
assert!(!executor.should_retry(&LlmError::Config("test".to_string()), 1));
assert!(!executor.should_retry(&LlmError::Timeout("test".to_string()), 100));
}
#[test]
fn test_retry_delay() {
let executor = LlmExecutor::with_defaults();
let delay = executor.retry_delay(1);
assert_eq!(delay, Duration::from_millis(500));
}
#[test]
fn test_executor_with_metrics() {
let hub = MetricsHub::shared();
let executor = LlmExecutor::for_model("gpt-4o").with_shared_metrics(hub);
assert!(executor.metrics.is_some());
}
#[test]
fn test_executor_without_metrics() {
let executor = LlmExecutor::for_model("gpt-4o");
assert!(executor.metrics.is_none());
}
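// Sanity check on the Debug impl: the configured model name should appear
// in the formatted output.
#[test]
fn test_executor_debug_includes_model() {
let executor = LlmExecutor::for_model("gpt-4o");
let rendered = format!("{:?}", executor);
assert!(rendered.contains("gpt-4o"));
}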
}