use crate::config::{Config, ConfigLoader, RetrievalConfig};
use crate::memo::MemoStore;
use crate::retrieval::PipelineRetriever;
use crate::storage::Workspace;
use super::engine::Engine;
use crate::events::EventEmitter;
/// Fluent builder for [`Engine`].
///
/// Collects an optional base configuration plus individual overrides; nothing is
/// loaded, validated or constructed until [`EngineBuilder::build`] is awaited.
pub struct EngineBuilder {
    /// Configuration file to load; used only when `config` is `None`.
    config_path: Option<std::path::PathBuf>,
    /// Explicit base configuration; takes precedence over `config_path`.
    config: Option<Config>,
    /// Replacement for the `retrieval` section of the base configuration.
    retrieval_config: Option<RetrievalConfig>,
    /// Event emitter handed to the engine; a default one is used when unset.
    events: Option<EventEmitter>,
    /// API key written into the `llm`, `retrieval` and `summary` sections.
    api_key: Option<String>,
    /// Model name; fills empty per-slot LLM models and overwrites the
    /// `retrieval` and `summary` models.
    model: Option<String>,
    /// Endpoint URL written into the `llm`, `retrieval` and `summary` sections.
    endpoint: Option<String>,
    /// Override for `retrieval.top_k`.
    top_k: Option<usize>,
    /// Sets `retrieval.search.max_iterations` to 5; the setters keep this
    /// mutually exclusive with `precise_mode`.
    fast_mode: bool,
    /// Sets `retrieval.search.max_iterations` to 100; the setters keep this
    /// mutually exclusive with `fast_mode`.
    precise_mode: bool,
    /// Memo store for the retriever; a fresh one is created when unset.
    memo_store: Option<MemoStore>,
}

// Hand-written so the API key never appears in `{:?}` output (e.g. in logs).
// NOTE(review): `Config`'s own Debug output may still include keys — check its impl.
impl std::fmt::Debug for EngineBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("EngineBuilder")
            .field("config_path", &self.config_path)
            .field("config", &self.config)
            .field("retrieval_config", &self.retrieval_config)
            .field("events", &self.events)
            // Only reveal whether a key was supplied, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("model", &self.model)
            .field("endpoint", &self.endpoint)
            .field("top_k", &self.top_k)
            .field("fast_mode", &self.fast_mode)
            .field("precise_mode", &self.precise_mode)
            .field("memo_store", &self.memo_store)
            .finish()
    }
}
impl EngineBuilder {
    /// Creates a builder with no configuration sources, no overrides and
    /// neither `fast` nor `precise` mode selected.
    #[must_use]
    pub fn new() -> Self {
        Self {
            config_path: None,
            config: None,
            retrieval_config: None,
            events: None,
            api_key: None,
            model: None,
            endpoint: None,
            top_k: None,
            fast_mode: false,
            precise_mode: false,
            memo_store: None,
        }
    }

    /// Sets a configuration file to load as the base configuration.
    ///
    /// The file is read during [`build`](Self::build), and only when no explicit
    /// configuration was supplied via [`with_config`](Self::with_config).
    #[must_use]
    pub fn with_config_path(mut self, path: impl Into<std::path::PathBuf>) -> Self {
        self.config_path = Some(path.into());
        self
    }

    /// Uses `config` as the base configuration, taking precedence over any
    /// [`with_config_path`](Self::with_config_path) file.
    #[must_use]
    pub fn with_config(mut self, config: Config) -> Self {
        self.config = Some(config);
        self
    }

    /// Replaces the whole `retrieval` section of the base configuration.
    ///
    /// The key, model, endpoint, `top_k` and mode overrides are applied after this
    /// replacement, so they still take effect.
    #[must_use]
    pub fn with_retrieval_config(mut self, config: RetrievalConfig) -> Self {
        self.retrieval_config = Some(config);
        self
    }

    /// Sets the event emitter handed to the engine; a default emitter is used
    /// when this is not called.
    #[must_use]
    pub fn with_events(mut self, events: EventEmitter) -> Self {
        self.events = Some(events);
        self
    }

    /// Sets the memo store given to the retriever.
    ///
    /// When not called, a new store is created with the resolved retrieval model
    /// and version 1.
    #[must_use]
    pub fn with_memo_store(mut self, store: MemoStore) -> Self {
        self.memo_store = Some(store);
        self
    }

    /// Sets the API key.
    ///
    /// During build it overwrites the key in the `llm`, `retrieval` and `summary`
    /// configuration sections.
    #[must_use]
    pub fn with_key(mut self, key: impl Into<String>) -> Self {
        self.api_key = Some(key.into());
        self
    }

    /// Sets the model name.
    ///
    /// During build it fills `llm.index`, `llm.retrieval` and `llm.pilot` models
    /// only where they are empty. It always overwrites `retrieval.model` and
    /// `summary.model`.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Sets the endpoint URL.
    ///
    /// During build it overwrites the endpoint in the `llm`, `retrieval` and
    /// `summary` configuration sections.
    #[must_use]
    pub fn with_endpoint(mut self, url: impl Into<String>) -> Self {
        self.endpoint = Some(url.into());
        self
    }

    /// Overrides `retrieval.top_k`.
    #[must_use]
    pub fn with_top_k(mut self, k: usize) -> Self {
        self.top_k = Some(k);
        self
    }

    /// Selects fast mode: `retrieval.search.max_iterations` is set to 5 during build.
    ///
    /// Cancels a previous call to [`precise`](Self::precise).
    #[must_use]
    pub fn fast(mut self) -> Self {
        self.fast_mode = true;
        self.precise_mode = false;
        self
    }

    /// Selects precise mode: `retrieval.search.max_iterations` is set to 100 during build.
    ///
    /// Cancels a previous call to [`fast`](Self::fast).
    #[must_use]
    pub fn precise(mut self) -> Self {
        self.precise_mode = true;
        self.fast_mode = false;
        self
    }

    /// Resolves the configuration, validates it, and assembles the [`Engine`].
    ///
    /// The base configuration is chosen in this order of precedence:
    /// 1. the [`with_config`](Self::with_config) value;
    /// 2. the file from [`with_config_path`](Self::with_config_path);
    /// 3. `Config::default()`.
    ///
    /// The builder's overrides are then applied on top.
    ///
    /// # Errors
    ///
    /// - [`BuildError::Config`] if the configuration file cannot be loaded.
    /// - [`BuildError::MissingApiKey`] if none of the `llm`, `llm.retrieval`,
    ///   `summary` or `retrieval` sections has an API key.
    /// - [`BuildError::MissingModel`] if no retrieval model can be resolved.
    /// - [`BuildError::Workspace`] if creating the workspace fails.
    /// - [`BuildError::Other`] if assembling the engine fails.
    pub async fn build(mut self) -> Result<Engine, BuildError> {
        let mut config = self.load_base_config()?;
        self.apply_overrides(&mut config);
        let retrieval_model = Self::validate(&config)?;

        let workspace = Workspace::new(&config.storage.workspace_dir)
            .await
            .map_err(|e| BuildError::Workspace(e.to_string()))?;

        let llm_configs: crate::llm::LlmConfigs = config.llm.clone().into();
        let pool = {
            let controller = crate::throttle::ConcurrencyController::new(
                crate::throttle::ConcurrencyConfig::new()
                    .with_max_concurrent_requests(config.concurrency.max_concurrent_requests)
                    .with_requests_per_minute(config.concurrency.requests_per_minute)
                    .with_enabled(config.concurrency.enabled),
            );
            crate::llm::LlmPool::new(llm_configs).with_concurrency(controller)
        };
        let indexer = crate::client::indexer::IndexerClient::with_llm(pool.index().clone());

        // Borrow the retrieval section rather than cloning it: only a few of its
        // fields are read here, and `config` itself is moved into the engine below.
        let retrieval = &config.retrieval;
        let mut retriever = PipelineRetriever::new()
            .with_max_iterations(retrieval.search.max_iterations)
            .with_llm_client(pool.retrieval().clone());
        if retrieval.content.enabled {
            retriever = retriever.with_content_config(retrieval.content.to_aggregator_config());
        }
        let memo_store = self
            .memo_store
            .unwrap_or_else(|| MemoStore::new().with_model(retrieval_model).with_version(1));
        let retriever = retriever.with_memo_store(memo_store);

        let events = self.events.unwrap_or_default();
        Engine::with_components(config, workspace, retriever, indexer, events)
            .await
            .map_err(|e| BuildError::Other(e.to_string()))
    }

    /// Takes the base configuration out of the builder: explicit config first,
    /// then the config file, then `Config::default()`.
    fn load_base_config(&mut self) -> Result<Config, BuildError> {
        if let Some(config) = self.config.take() {
            return Ok(config);
        }
        match self.config_path.take() {
            Some(path) => ConfigLoader::new()
                .file(&path)
                .load()
                .map_err(|e| BuildError::Config(e.to_string())),
            None => Ok(Config::default()),
        }
    }

    /// Moves the builder's overrides into `config`.
    fn apply_overrides(&mut self, config: &mut Config) {
        // Whole-section replacement goes first so the finer-grained overrides
        // below still apply on top of it.
        if let Some(retrieval_config) = self.retrieval_config.take() {
            config.retrieval = retrieval_config;
        }
        if let Some(api_key) = self.api_key.take() {
            config.llm.api_key = Some(api_key.clone());
            config.retrieval.api_key = Some(api_key.clone());
            config.summary.api_key = Some(api_key);
        }
        if let Some(model) = self.model.take() {
            // Per-slot LLM models from the base config win; the builder model
            // only fills empty slots. NOTE(review): `retrieval` and `summary` are
            // overwritten unconditionally — confirm this asymmetry is intended.
            if config.llm.index.model.is_empty() {
                config.llm.index.model = model.clone();
            }
            if config.llm.retrieval.model.is_empty() {
                config.llm.retrieval.model = model.clone();
            }
            if config.llm.pilot.model.is_empty() {
                config.llm.pilot.model = model.clone();
            }
            config.retrieval.model = model.clone();
            config.summary.model = model;
        }
        if let Some(endpoint) = self.endpoint.take() {
            config.llm.endpoint = Some(endpoint.clone());
            config.retrieval.endpoint = endpoint.clone();
            config.summary.endpoint = endpoint;
        }
        if let Some(top_k) = self.top_k {
            config.retrieval.top_k = top_k;
        }
        // `fast()` and `precise()` clear each other, so at most one of these fires.
        if self.fast_mode {
            config.retrieval.search.max_iterations = 5;
        }
        if self.precise_mode {
            config.retrieval.search.max_iterations = 100;
        }
    }

    /// Checks the settings the engine cannot run without.
    ///
    /// Returns the effective retrieval model: `llm.retrieval.model` when it is
    /// non-empty, otherwise `retrieval.model`.
    fn validate(config: &Config) -> Result<&str, BuildError> {
        let has_api_key = config.llm.api_key.is_some()
            || config.llm.retrieval.api_key.is_some()
            || config.summary.api_key.is_some()
            || config.retrieval.api_key.is_some();
        if !has_api_key {
            return Err(BuildError::MissingApiKey);
        }
        let retrieval_model = if config.llm.retrieval.model.is_empty() {
            config.retrieval.model.as_str()
        } else {
            config.llm.retrieval.model.as_str()
        };
        if retrieval_model.is_empty() {
            return Err(BuildError::MissingModel);
        }
        Ok(retrieval_model)
    }
}
impl Default for EngineBuilder {
fn default() -> Self {
Self::new()
}
}
/// Errors that can occur while building an [`Engine`] with [`EngineBuilder`].
#[derive(Debug, thiserror::Error)]
pub enum BuildError {
    /// Loading the configuration failed; carries the underlying message.
    #[error("Configuration error: {0}")]
    Config(String),

    /// Creating the workspace failed; carries the underlying message.
    #[error("Workspace error: {0}")]
    Workspace(String),

    /// Neither the builder nor the configuration supplied an API key.
    #[error("Missing API key: call .with_key(\"sk-...\") or set api_key in config file")]
    MissingApiKey,

    /// Neither the builder nor the configuration supplied a retrieval model.
    #[error("Missing model: call .with_model(\"gpt-4o\") or set model in config file")]
    MissingModel,

    /// Any other failure; the message is displayed as-is.
    #[error("{0}")]
    Other(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh builder carries no sources, no overrides and neither mode flag.
    #[test]
    fn test_builder_defaults() {
        let builder = EngineBuilder::new();
        assert!(builder.config_path.is_none());
        assert!(builder.config.is_none());
        assert!(builder.retrieval_config.is_none());
        assert!(builder.events.is_none());
        assert!(builder.api_key.is_none());
        assert!(builder.model.is_none());
        assert!(builder.endpoint.is_none());
        assert!(builder.top_k.is_none());
        assert!(builder.memo_store.is_none());
        assert!(!builder.fast_mode);
        assert!(!builder.precise_mode);
    }

    /// `Default` must produce the same empty builder as `new`.
    #[test]
    fn test_builder_default_matches_new() {
        let from_default = EngineBuilder::default();
        let from_new = EngineBuilder::new();
        assert_eq!(from_default.config_path, from_new.config_path);
        assert_eq!(from_default.api_key, from_new.api_key);
        assert_eq!(from_default.model, from_new.model);
        assert_eq!(from_default.endpoint, from_new.endpoint);
        assert_eq!(from_default.top_k, from_new.top_k);
        assert_eq!(from_default.fast_mode, from_new.fast_mode);
        assert_eq!(from_default.precise_mode, from_new.precise_mode);
        assert!(from_default.config.is_none());
        assert!(from_default.retrieval_config.is_none());
        assert!(from_default.events.is_none());
        assert!(from_default.memo_store.is_none());
    }

    #[test]
    fn test_builder_with_key() {
        let builder = EngineBuilder::new().with_key("sk-test-key");
        assert_eq!(builder.api_key, Some("sk-test-key".to_string()));
    }

    #[test]
    fn test_builder_with_model() {
        let builder = EngineBuilder::new().with_model("gpt-4o-mini");
        assert_eq!(builder.model, Some("gpt-4o-mini".to_string()));
    }

    #[test]
    fn test_builder_with_key_and_model() {
        let builder = EngineBuilder::new()
            .with_model("gpt-4o-mini")
            .with_key("sk-test");
        assert_eq!(builder.model, Some("gpt-4o-mini".to_string()));
        assert_eq!(builder.api_key, Some("sk-test".to_string()));
    }

    #[test]
    fn test_builder_with_config_path() {
        let builder = EngineBuilder::new().with_config_path("engine.toml");
        assert_eq!(
            builder.config_path,
            Some(std::path::PathBuf::from("engine.toml"))
        );
    }

    #[test]
    fn test_builder_with_endpoint() {
        let builder = EngineBuilder::new().with_endpoint("http://localhost:8080/v1");
        assert_eq!(builder.endpoint.as_deref(), Some("http://localhost:8080/v1"));
    }

    /// Calling a setter twice keeps the most recent value.
    #[test]
    fn test_builder_setter_overwrites_previous_value() {
        let builder = EngineBuilder::new().with_key("sk-old").with_key("sk-new");
        assert_eq!(builder.api_key.as_deref(), Some("sk-new"));
    }

    #[test]
    fn test_builder_fast_mode() {
        let builder = EngineBuilder::new().fast();
        assert!(builder.fast_mode);
        assert!(!builder.precise_mode);
    }

    #[test]
    fn test_builder_precise_mode() {
        let builder = EngineBuilder::new().precise();
        assert!(builder.precise_mode);
        assert!(!builder.fast_mode);
    }

    /// The two modes are mutually exclusive: whichever is called last wins.
    #[test]
    fn test_builder_last_mode_wins() {
        let builder = EngineBuilder::new().fast().precise();
        assert!(builder.precise_mode);
        assert!(!builder.fast_mode);

        let builder = EngineBuilder::new().precise().fast();
        assert!(builder.fast_mode);
        assert!(!builder.precise_mode);
    }

    #[test]
    fn test_builder_top_k() {
        let builder = EngineBuilder::new().with_top_k(10);
        assert_eq!(builder.top_k, Some(10));
    }
}