use std::sync::Arc;
use super::client::LlmClient;
use super::config::LlmConfig;
use super::fallback::{FallbackChain, FallbackConfig};
use super::throttle::ConcurrencyController;
use crate::metrics::MetricsHub;
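
/// A pair of preconfigured LLM clients, one for the `index` slot and one
/// for the `retrieval` slot.
///
/// Both clients share a single `async_openai` HTTP client, concurrency
/// controller, and fallback chain, so cloning the pool is cheap and the
/// concurrency limit is enforced pool-wide.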
#[derive(Debug, Clone)]
pub struct LlmPool {
    index: Arc<LlmClient>,
    retrieval: Arc<LlmClient>,
}

impl LlmPool {
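    /// Builds a pool from the application-level LLM configuration,
    /// optionally wiring both slot clients into a shared metrics hub.
    ///
    /// A minimal usage sketch, using only the builder methods exercised in
    /// the tests below (`ignore`d since it needs the full `crate::config`
    /// surface):
    ///
    /// ```ignore
    /// let config = crate::config::LlmConfig::new("gpt-4o")
    ///     .with_api_key("sk-test")
    ///     .with_endpoint("https://api.openai.com/v1");
    /// let pool = LlmPool::from_config(&config, None);
    /// assert_eq!(pool.index().config().model, "gpt-4o");
    /// ```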
    pub fn from_config(
        config: &crate::config::LlmConfig,
        metrics: Option<Arc<MetricsHub>>,
    ) -> Self {
        let api_key = config.api_key.clone();
        let endpoint = config.endpoint.clone().unwrap_or_default();
        let retry = config.retry.to_runtime_config();

        // Per-slot runtime config: the model is resolved per slot, the token
        // budget and temperature come from the slot; endpoint, API key, and
        // retry policy are shared across slots.
        let make_config = |slot: &crate::config::SlotConfig| -> LlmConfig {
            LlmConfig {
                model: config.resolve_model(slot),
                endpoint: endpoint.clone(),
                api_key: api_key.clone(),
                max_tokens: slot.max_tokens,
                temperature: slot.temperature,
                retry: retry.clone(),
                request_timeout_secs: 0,
            }
        };

        // Default to the public OpenAI endpoint when none is configured;
        // one `async_openai` client is built up front and shared by both slots.
        let openai_base = if endpoint.is_empty() {
            "https://api.openai.com/v1".to_string()
        } else {
            endpoint.clone()
        };
        let openai_client = Arc::new(async_openai::Client::with_config(
            async_openai::config::OpenAIConfig::new()
                .with_api_key(api_key.clone().unwrap_or_default())
                .with_api_base(openai_base),
        ));

        // A single concurrency controller and fallback chain, shared by
        // both slots.
        let concurrency_config = config.throttle.to_runtime_config();
        let controller = Arc::new(ConcurrencyController::new(concurrency_config));
        let fallback_config: FallbackConfig = config.fallback.clone().into();
        let fallback_chain = Arc::new(FallbackChain::new(fallback_config));

        // Builds a slot client that borrows the shared HTTP client, throttle,
        // and fallback chain, plus the metrics hub when one is provided.
        let build_client = |slot_config: &crate::config::SlotConfig| {
            let mut client = LlmClient::new(make_config(slot_config))
                .with_shared_concurrency(controller.clone())
                .with_shared_openai_client(openai_client.clone())
                .with_shared_fallback(fallback_chain.clone());
            if let Some(ref hub) = metrics {
                client = client.with_shared_metrics(hub.clone());
            }
            Arc::new(client)
        };

        Self {
            index: build_client(&config.index),
            retrieval: build_client(&config.retrieval),
        }
    }

    /// Builds a pool from `LlmConfig::default()` with metrics disabled.
    pub fn from_defaults() -> Self {
        Self::from_config(&crate::config::LlmConfig::default(), None)
    }

    /// The client configured for the indexing slot.
    pub fn index(&self) -> &LlmClient {
        &self.index
    }

    /// The client configured for the retrieval slot.
    pub fn retrieval(&self) -> &LlmClient {
        &self.retrieval
    }
}

impl Default for LlmPool {
    fn default() -> Self {
        Self::from_defaults()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pool_from_config() {
        let config = crate::config::LlmConfig::new("gpt-4o")
            .with_api_key("sk-test")
            .with_endpoint("https://api.openai.com/v1")
            .with_index(crate::config::SlotConfig::fast().with_model("gpt-4o-mini"));
        let pool = LlmPool::from_config(&config, None);

        // The index slot override wins; retrieval falls back to the pool
        // model, and `SlotConfig::fast()` supplies the small token budget.
        assert_eq!(pool.index().config().model, "gpt-4o-mini");
        assert_eq!(pool.retrieval().config().model, "gpt-4o");
        assert_eq!(pool.index().config().max_tokens, 100);
    }

    #[test]
    fn test_pool_from_config_with_metrics() {
        let config = crate::config::LlmConfig::new("gpt-4o")
            .with_api_key("sk-test")
            .with_endpoint("https://api.openai.com/v1");
        let hub = MetricsHub::shared();
        let pool = LlmPool::from_config(&config, Some(hub.clone()));

        assert!(pool.index().fallback().is_some());
        assert!(pool.retrieval().fallback().is_some());
        assert_eq!(pool.index().config().model, "gpt-4o");
        assert_eq!(pool.retrieval().config().model, "gpt-4o");
    }

    #[test]
    fn test_pool_shared_metrics_hub() {
        let config = crate::config::LlmConfig::new("gpt-4o")
            .with_api_key("sk-test")
            .with_endpoint("https://api.openai.com/v1");
        let hub = MetricsHub::shared();
        let _pool = LlmPool::from_config(&config, Some(hub.clone()));

        // Both slot clients hold a clone of the hub, so the strong count
        // rises above the caller's single reference.
        assert!(Arc::strong_count(&hub) > 1);
    }
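
    // A minimal sketch of the derived `Clone`: both copies hold `Arc`s to
    // the same slot clients, so their configs match by construction. Uses
    // only APIs exercised above; assumes `LlmConfig::default()` constructs
    // without network access.
    #[test]
    fn test_pool_clone_shares_config() {
        let pool = LlmPool::default();
        let cloned = pool.clone();
        assert_eq!(pool.index().config().model, cloned.index().config().model);
        assert_eq!(
            pool.retrieval().config().model,
            cloned.retrieval().config().model
        );
    }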
}