use std::sync::Arc;
use super::client::LlmClient;
use super::config::LlmConfigs;
use crate::throttle::ConcurrencyController;
/// A set of LLM clients, one per purpose: summary, retrieval, and `toc`.
///
/// Each client sits behind an [`Arc`]. Cloning the pool therefore only bumps reference
/// counts, and clones share the same client instances. A concurrency controller may
/// optionally be attached to the pool (and to its clients).
#[derive(Clone, Debug)]
pub struct LlmPool {
    /// Client returned for the `"summary"` / `"summarize"` purpose.
    summary: Arc<LlmClient>,
    /// Client returned for the `"retrieval"` / `"retrieve"` / `"navigate"` purpose.
    retrieval: Arc<LlmClient>,
    /// Client returned for the `"toc"` purpose.
    toc: Arc<LlmClient>,
    /// Controller attached via `with_concurrency` / `with_shared_concurrency`, if any.
    concurrency: Option<Arc<ConcurrencyController>>,
}
impl LlmPool {
    /// Creates a pool with one dedicated client per purpose, each built from the
    /// corresponding entry of `configs`.
    ///
    /// No concurrency controller is attached; see [`LlmPool::with_concurrency`] and
    /// [`LlmPool::with_shared_concurrency`].
    pub fn new(configs: LlmConfigs) -> Self {
        Self {
            summary: Arc::new(LlmClient::new(configs.summary)),
            retrieval: Arc::new(LlmClient::new(configs.retrieval)),
            toc: Arc::new(LlmClient::new(configs.toc)),
            concurrency: None,
        }
    }

    /// Creates a pool from `LlmConfigs::default()`.
    pub fn from_defaults() -> Self {
        Self::new(LlmConfigs::default())
    }

    /// Takes ownership of `controller` and attaches it to the pool and every client.
    ///
    /// This is exactly `self.with_shared_concurrency(Arc::new(controller))`; it exists
    /// for callers that do not already hold the controller in an [`Arc`].
    pub fn with_concurrency(self, controller: ConcurrencyController) -> Self {
        self.with_shared_concurrency(Arc::new(controller))
    }

    /// Attaches an already-shared `controller` to the pool and to every client.
    ///
    /// Each client is rebuilt from a clone of its current config, with the controller
    /// attached. Only the config is carried over from the previous client instance.
    /// Clients that previously shared one instance (as produced by
    /// [`LlmPool::single_model`]) each get their own rebuilt instance afterwards.
    /// The pool and all three clients end up holding the same controller.
    pub fn with_shared_concurrency(mut self, controller: Arc<ConcurrencyController>) -> Self {
        self.summary = Self::client_with_controller(&self.summary, &controller);
        self.retrieval = Self::client_with_controller(&self.retrieval, &controller);
        self.toc = Self::client_with_controller(&self.toc, &controller);
        self.concurrency = Some(controller);
        self
    }

    /// Builds a fresh client from `client`'s config with `controller` attached.
    fn client_with_controller(
        client: &LlmClient,
        controller: &Arc<ConcurrencyController>,
    ) -> Arc<LlmClient> {
        let rebuilt = LlmClient::new(client.config().clone())
            .with_shared_concurrency(Arc::clone(controller));
        Arc::new(rebuilt)
    }

    /// Returns the concurrency controller attached to the pool, if any.
    pub fn concurrency(&self) -> Option<&ConcurrencyController> {
        self.concurrency.as_deref()
    }

    /// Returns the client used for summarization.
    pub fn summary(&self) -> &LlmClient {
        &self.summary
    }

    /// Returns the client used for retrieval / navigation.
    pub fn retrieval(&self) -> &LlmClient {
        &self.retrieval
    }

    /// Returns the client used for the `toc` purpose.
    pub fn toc(&self) -> &LlmClient {
        &self.toc
    }

    /// Looks up a client by purpose name.
    ///
    /// Recognized names (matched case-sensitively):
    /// - `"summary"`, `"summarize"` → [`LlmPool::summary`]
    /// - `"retrieval"`, `"retrieve"`, `"navigate"` → [`LlmPool::retrieval`]
    /// - `"toc"` → [`LlmPool::toc`]
    ///
    /// Returns `None` for any other name.
    pub fn get(&self, purpose: &str) -> Option<&LlmClient> {
        match purpose {
            "summary" | "summarize" => Some(&self.summary),
            "retrieval" | "retrieve" | "navigate" => Some(&self.retrieval),
            "toc" => Some(&self.toc),
            _ => None,
        }
    }

    /// Creates a pool in which every purpose uses one client built from
    /// `LlmConfig::new(model)`.
    ///
    /// All three purposes point at the same shared client instance. No concurrency
    /// controller is attached.
    pub fn single_model(model: impl Into<String>) -> Self {
        let config = super::config::LlmConfig::new(model);
        let client = Arc::new(LlmClient::new(config));
        Self {
            summary: Arc::clone(&client),
            retrieval: Arc::clone(&client),
            toc: client,
            concurrency: None,
        }
    }
}
impl Default for LlmPool {
fn default() -> Self {
Self::from_defaults()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::throttle::ConcurrencyConfig;
    use std::sync::Arc;

    #[test]
    fn test_pool_creation() {
        let pool = LlmPool::from_defaults();
        assert!(pool.get("summary").is_some());
        assert!(pool.get("retrieval").is_some());
        assert!(pool.get("toc").is_some());
        assert!(pool.get("unknown").is_none());
    }

    #[test]
    fn test_pool_get_aliases() {
        let pool = LlmPool::from_defaults();
        assert!(pool.get("summarize").is_some());
        assert!(pool.get("retrieve").is_some());
        assert!(pool.get("navigate").is_some());
    }

    #[test]
    fn test_pool_without_concurrency_by_default() {
        let pool = LlmPool::from_defaults();
        assert!(pool.concurrency().is_none());
    }

    #[test]
    fn test_single_model_pool() {
        let pool = LlmPool::single_model("gpt-4o-mini");
        assert_eq!(pool.summary().config().model, "gpt-4o-mini");
        assert_eq!(pool.retrieval().config().model, "gpt-4o-mini");
        assert_eq!(pool.toc().config().model, "gpt-4o-mini");
    }

    #[test]
    fn test_single_model_pool_shares_one_client() {
        let pool = LlmPool::single_model("gpt-4o-mini");
        assert!(std::ptr::eq(pool.summary(), pool.retrieval()));
        assert!(std::ptr::eq(pool.summary(), pool.toc()));
    }

    #[test]
    fn test_pool_with_concurrency() {
        let controller = ConcurrencyController::new(ConcurrencyConfig::conservative());
        let pool = LlmPool::from_defaults().with_concurrency(controller);
        assert!(pool.concurrency().is_some());
        assert!(pool.summary().concurrency().is_some());
        assert!(pool.retrieval().concurrency().is_some());
        assert!(pool.toc().concurrency().is_some());
    }

    #[test]
    fn test_pool_with_shared_concurrency() {
        let controller = Arc::new(ConcurrencyController::new(ConcurrencyConfig::conservative()));
        let pool = LlmPool::from_defaults().with_shared_concurrency(Arc::clone(&controller));
        let attached = pool
            .concurrency()
            .expect("with_shared_concurrency must attach the controller to the pool");
        // The pool must hold the caller's instance, not a copy.
        assert!(std::ptr::eq(attached, Arc::as_ptr(&controller)));
        assert!(pool.summary().concurrency().is_some());
        assert!(pool.retrieval().concurrency().is_some());
        assert!(pool.toc().concurrency().is_some());
    }
}