use std::sync::Arc;
use super::client::LlmClient;
use super::config::LlmConfigs;
use crate::throttle::ConcurrencyController;
/// Pool of [`LlmClient`]s, one per purpose: `index`, `retrieval` and `pilot`.
///
/// Every field is held behind an `Arc`, so cloning the pool produces handles
/// to the same underlying clients and, when one is attached, the same
/// concurrency controller.
#[derive(Debug, Clone)]
pub struct LlmPool {
/// Client handed out by `index()`.
index: Arc<LlmClient>,
/// Client handed out by `retrieval()`.
retrieval: Arc<LlmClient>,
/// Client handed out by `pilot()`.
pilot: Arc<LlmClient>,
/// Controller shared by all three clients; `None` until one of the
/// `with_concurrency` / `with_shared_concurrency` builders has been applied.
concurrency: Option<Arc<ConcurrencyController>>,
}
impl LlmPool {
    /// Builds a pool with one client per entry of `configs`
    /// (`index`, `retrieval`, `pilot`).
    ///
    /// No concurrency controller is attached; use [`Self::with_concurrency`]
    /// or [`Self::with_shared_concurrency`] to add one.
    pub fn new(configs: LlmConfigs) -> Self {
        Self {
            index: Arc::new(LlmClient::new(configs.index)),
            retrieval: Arc::new(LlmClient::new(configs.retrieval)),
            pilot: Arc::new(LlmClient::new(configs.pilot)),
            concurrency: None,
        }
    }

    /// Builds a pool from `LlmConfigs::default()`.
    pub fn from_defaults() -> Self {
        Self::new(LlmConfigs::default())
    }

    /// Attaches `controller` to the pool, taking ownership of it.
    ///
    /// The controller is wrapped in an `Arc` and handed to
    /// [`Self::with_shared_concurrency`], so both builders share one code path.
    pub fn with_concurrency(self, controller: ConcurrencyController) -> Self {
        self.with_shared_concurrency(Arc::new(controller))
    }

    /// Attaches an already shared `controller` to the pool.
    ///
    /// Each of the three clients is replaced by a new client created from a
    /// clone of the previous client's configuration with `controller`
    /// attached. Only the configuration is carried over to the new clients.
    /// The pool itself keeps a handle to the same controller, retrievable
    /// through [`Self::concurrency`].
    pub fn with_shared_concurrency(mut self, controller: Arc<ConcurrencyController>) -> Self {
        self.index = Self::rebuild_with_controller(&self.index, &controller);
        self.retrieval = Self::rebuild_with_controller(&self.retrieval, &controller);
        self.pilot = Self::rebuild_with_controller(&self.pilot, &controller);
        // The clients hold their own handles by now, so the original handle
        // can be moved into the pool instead of being cloned once more.
        self.concurrency = Some(controller);
        self
    }

    /// Creates a new client from a clone of `client`'s configuration and
    /// attaches a handle to `controller` to it.
    fn rebuild_with_controller(
        client: &LlmClient,
        controller: &Arc<ConcurrencyController>,
    ) -> Arc<LlmClient> {
        Arc::new(
            LlmClient::new(client.config().clone()).with_shared_concurrency(Arc::clone(controller)),
        )
    }

    /// Returns the controller attached to the pool, or `None` if no
    /// concurrency builder has been applied.
    pub fn concurrency(&self) -> Option<&ConcurrencyController> {
        self.concurrency.as_deref()
    }

    /// Returns the `index` client.
    pub fn index(&self) -> &LlmClient {
        &self.index
    }

    /// Returns the `retrieval` client.
    pub fn retrieval(&self) -> &LlmClient {
        &self.retrieval
    }

    /// Returns the `pilot` client.
    pub fn pilot(&self) -> &LlmClient {
        &self.pilot
    }

    /// Looks a client up by purpose name.
    ///
    /// Accepted names (matched exactly, case-sensitive):
    ///
    /// * `"index"`, `"summary"`, `"summarize"` — the `index` client
    /// * `"retrieval"`, `"retrieve"`, `"navigate"` — the `retrieval` client
    /// * `"pilot"` — the `pilot` client
    ///
    /// Any other name yields `None`.
    pub fn get(&self, purpose: &str) -> Option<&LlmClient> {
        match purpose {
            "index" | "summary" | "summarize" => Some(&self.index),
            "retrieval" | "retrieve" | "navigate" => Some(&self.retrieval),
            "pilot" => Some(&self.pilot),
            _ => None,
        }
    }
}
impl Default for LlmPool {
fn default() -> Self {
Self::from_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Canonical purpose names resolve to a client; an unknown name does not.
    #[test]
    fn test_pool_creation() {
        let pool = LlmPool::from_defaults();
        for purpose in ["index", "retrieval", "pilot"] {
            assert!(pool.get(purpose).is_some());
        }
        assert!(pool.get("unknown").is_none());
    }

    /// Every documented alias resolves to a client.
    #[test]
    fn test_pool_get_aliases() {
        let pool = LlmPool::from_defaults();
        for alias in ["summary", "summarize", "retrieve", "navigate"] {
            assert!(pool.get(alias).is_some());
        }
    }

    /// Attaching a controller makes it visible on the pool and on each client.
    #[test]
    fn test_pool_with_concurrency() {
        use crate::throttle::ConcurrencyConfig;

        let pool = LlmPool::from_defaults()
            .with_concurrency(ConcurrencyController::new(ConcurrencyConfig::conservative()));

        assert!(pool.concurrency().is_some());
        for client in [pool.index(), pool.retrieval(), pool.pilot()] {
            assert!(client.concurrency().is_some());
        }
    }
}