use std::path::PathBuf;
use crate::config::{Config, ConfigLoader, RetrievalConfig};
use crate::memo::MemoStore;
use crate::retrieval::PipelineRetriever;
use crate::storage::Workspace;
use super::engine::Engine;
use super::events::EventEmitter;
/// Fluent builder that collects optional settings and assembles an `Engine`
/// in `build`.
///
/// Every field starts unset (`None` / `false`) in `new`; `build` merges the
/// values that were set on top of the base configuration.
#[derive(Debug)]
pub struct EngineBuilder {
/// Workspace directory; when unset, `build` uses `config.storage.workspace_dir`.
workspace: Option<PathBuf>,
/// Config file handed to `ConfigLoader` in `build`; only consulted when
/// `config` is unset.
config_path: Option<PathBuf>,
/// Fully constructed configuration; takes precedence over `config_path`.
config: Option<Config>,
/// Replacement for the whole `config.retrieval` section, applied before the
/// individual overrides below.
retrieval_config: Option<RetrievalConfig>,
/// Event emitter passed to the engine; `build` uses the default emitter
/// when unset.
events: Option<EventEmitter>,
/// API key written to the retrieval and summary sections in `build`.
api_key: Option<String>,
/// Model name written to the retrieval and summary sections in `build`.
model: Option<String>,
/// Endpoint written to the retrieval and summary sections in `build`.
endpoint: Option<String>,
/// Override for `config.retrieval.top_k`.
top_k: Option<usize>,
/// When set, `build` sets `retrieval.search.max_iterations` to 5.
fast_mode: bool,
/// When set, `build` sets `retrieval.search.max_iterations` to 100.
precise_mode: bool,
/// Memo store given to the retriever; `build` creates one from the
/// retrieval model when unset.
memo_store: Option<MemoStore>,
}
impl EngineBuilder {
    /// Creates a builder with every option unset and both speed presets off.
    #[must_use]
    pub fn new() -> Self {
        Self {
            workspace: None,
            config_path: None,
            config: None,
            retrieval_config: None,
            events: None,
            api_key: None,
            model: None,
            endpoint: None,
            top_k: None,
            fast_mode: false,
            precise_mode: false,
            memo_store: None,
        }
    }

    /// Sets the workspace directory, overriding `storage.workspace_dir` from
    /// the configuration.
    #[must_use]
    pub fn with_workspace(self, path: impl Into<PathBuf>) -> Self {
        Self {
            workspace: Some(path.into()),
            ..self
        }
    }

    /// Sets the configuration file to load in [`build`](Self::build).
    ///
    /// Ignored when a configuration was supplied through
    /// [`with_config`](Self::with_config).
    #[must_use]
    pub fn with_config_path(self, path: impl Into<PathBuf>) -> Self {
        Self {
            config_path: Some(path.into()),
            ..self
        }
    }

    /// Supplies a ready-made configuration; it wins over any config path.
    #[must_use]
    pub fn with_config(self, config: Config) -> Self {
        Self {
            config: Some(config),
            ..self
        }
    }

    /// Replaces the whole retrieval section of the configuration.
    ///
    /// The replacement happens before the key, model, endpoint, top-k and
    /// speed-preset overrides are applied, so those still take effect.
    #[must_use]
    pub fn with_retrieval_config(self, config: RetrievalConfig) -> Self {
        Self {
            retrieval_config: Some(config),
            ..self
        }
    }

    /// Sets the event emitter handed to the engine.
    #[must_use]
    pub fn with_events(self, events: EventEmitter) -> Self {
        Self {
            events: Some(events),
            ..self
        }
    }

    /// Sets the memo store attached to the retriever.
    #[must_use]
    pub fn with_memo_store(self, store: MemoStore) -> Self {
        Self {
            memo_store: Some(store),
            ..self
        }
    }

    /// Sets the API key.
    ///
    /// In [`build`](Self::build) the key overwrites the retrieval and summary
    /// keys, and fills `llm.summary` / `llm.retrieval` only where those have
    /// no key yet.
    #[must_use]
    pub fn with_key(self, key: impl Into<String>) -> Self {
        Self {
            api_key: Some(key.into()),
            ..self
        }
    }

    /// Sets the model name used for both retrieval and summary.
    #[must_use]
    pub fn with_model(self, model: impl Into<String>) -> Self {
        Self {
            model: Some(model.into()),
            ..self
        }
    }

    /// Sets the endpoint used for both retrieval and summary.
    #[must_use]
    pub fn with_endpoint(self, url: impl Into<String>) -> Self {
        Self {
            endpoint: Some(url.into()),
            ..self
        }
    }

    /// Overrides `retrieval.top_k`.
    #[must_use]
    pub fn with_top_k(self, k: usize) -> Self {
        Self {
            top_k: Some(k),
            ..self
        }
    }

    /// Selects the fast preset (5 search iterations) and clears the precise
    /// preset.
    #[must_use]
    pub fn fast(self) -> Self {
        Self {
            fast_mode: true,
            precise_mode: false,
            ..self
        }
    }

    /// Selects the precise preset (100 search iterations) and clears the fast
    /// preset.
    #[must_use]
    pub fn precise(self) -> Self {
        Self {
            precise_mode: true,
            fast_mode: false,
            ..self
        }
    }

    /// Resolves the configuration, applies the builder overrides, and
    /// assembles the [`Engine`].
    ///
    /// The base configuration is the explicit config if given, otherwise the
    /// file at the config path, otherwise `Config::default()`.
    ///
    /// # Errors
    ///
    /// * [`BuildError::Config`] if the configuration file cannot be loaded.
    /// * [`BuildError::MissingApiKey`] if neither the summary nor the
    ///   retrieval section has an API key after overrides.
    /// * [`BuildError::MissingModel`] if the retrieval model is empty.
    /// * [`BuildError::Workspace`] if the workspace cannot be created.
    /// * [`BuildError::Other`] if assembling the engine fails.
    pub async fn build(self) -> Result<Engine, BuildError> {
        let Self {
            workspace,
            config_path,
            config,
            retrieval_config,
            events,
            api_key,
            model,
            endpoint,
            top_k,
            fast_mode,
            precise_mode,
            memo_store,
        } = self;

        // Base configuration: explicit config > config file > defaults.
        let mut config = match (config, config_path) {
            (Some(explicit), _) => explicit,
            (None, Some(path)) => ConfigLoader::new()
                .file(&path)
                .load()
                .map_err(|err| BuildError::Config(err.to_string()))?,
            (None, None) => Config::default(),
        };

        // A wholesale retrieval section goes in first so that the targeted
        // overrides below still apply on top of it.
        if let Some(section) = retrieval_config {
            config.retrieval = section;
        }

        if let Some(key) = api_key {
            config.retrieval.api_key = Some(key.clone());
            // The llm.* sections keep a key they already have.
            if config.llm.summary.api_key.is_none() {
                config.llm.summary.api_key = Some(key.clone());
            }
            if config.llm.retrieval.api_key.is_none() {
                config.llm.retrieval.api_key = Some(key.clone());
            }
            config.summary.api_key = Some(key);
        }

        if let Some(name) = model {
            config.summary.model = name.clone();
            config.retrieval.model = name;
        }

        if let Some(url) = endpoint {
            config.summary.endpoint = url.clone();
            config.retrieval.endpoint = url;
        }

        if let Some(k) = top_k {
            config.retrieval.top_k = k;
        }

        // Precise wins if both flags are somehow set.
        match (fast_mode, precise_mode) {
            (_, true) => {
                config.retrieval.search.max_iterations = 100;
            }
            (true, false) => {
                config.retrieval.search.max_iterations = 5;
            }
            (false, false) => {}
        }

        let has_api_key = config.summary.api_key.is_some() || config.retrieval.api_key.is_some();
        if !has_api_key {
            return Err(BuildError::MissingApiKey);
        }
        if config.retrieval.model.is_empty() {
            return Err(BuildError::MissingModel);
        }

        let workspace_root = match workspace.as_ref() {
            Some(dir) => dir,
            None => &config.storage.workspace_dir,
        };
        let workspace = Workspace::new(workspace_root)
            .await
            .map_err(|err| BuildError::Workspace(err.to_string()))?;

        // The indexer is LLM-backed only when the summary section has a key.
        let indexer = match config.summary.api_key.clone() {
            Some(summary_key) => {
                let summary_llm = crate::llm::LlmConfig::new(&config.summary.model)
                    .with_endpoint(config.summary.endpoint.clone())
                    .with_api_key(summary_key)
                    .with_max_tokens(config.summary.max_tokens)
                    .with_temperature(config.summary.temperature);
                crate::client::indexer::IndexerClient::with_llm(crate::llm::LlmClient::new(
                    summary_llm,
                ))
            }
            None => {
                crate::client::indexer::IndexerClient::new(crate::index::PipelineExecutor::new())
            }
        };

        let retrieval = config.retrieval.clone();
        let retriever =
            PipelineRetriever::new().with_max_iterations(retrieval.search.max_iterations);

        // Retrieval prefers its own key and falls back to the summary key.
        let retrieval_key = match retrieval
            .api_key
            .as_ref()
            .or(config.summary.api_key.as_ref())
            .cloned()
        {
            Some(key) => key,
            None => return Err(BuildError::MissingApiKey),
        };
        let retrieval_llm = crate::llm::LlmConfig::new(&retrieval.model)
            .with_endpoint(retrieval.endpoint.clone())
            .with_api_key(retrieval_key)
            .with_temperature(retrieval.temperature);
        let retriever = retriever.with_llm_client(crate::llm::LlmClient::new(retrieval_llm));

        let retriever = if retrieval.content.enabled {
            retriever.with_content_config(retrieval.content.to_aggregator_config())
        } else {
            retriever
        };

        // Without a caller-supplied memo store, create one from the retrieval model.
        let memo = memo_store.unwrap_or_else(|| {
            MemoStore::new()
                .with_model(&retrieval.model)
                .with_version(1)
        });
        let retriever = retriever.with_memo_store(memo);

        let emitter = events.unwrap_or_default();
        Engine::with_components(config, workspace, retriever, indexer, emitter)
            .await
            .map_err(|err| BuildError::Other(err.to_string()))
    }
}
impl Default for EngineBuilder {
fn default() -> Self {
Self::new()
}
}
/// Errors returned by `EngineBuilder::build`.
///
/// The `String` payloads hold the `to_string()` rendering of the underlying
/// error.
#[derive(Debug, thiserror::Error)]
pub enum BuildError {
/// Loading the configuration file failed.
#[error("Configuration error: {0}")]
Config(String),
/// Creating the workspace failed.
#[error("Workspace error: {0}")]
Workspace(String),
/// No API key is available for either the summary or the retrieval section.
#[error("Missing API key: call .with_key(\"sk-...\") or set api_key in config file")]
MissingApiKey,
/// The retrieval model name is empty.
#[error("Missing model: call .with_model(\"gpt-4o\") or set model in config file")]
MissingModel,
/// Assembling the engine from its components failed.
#[error("{0}")]
Other(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh builder has no workspace and neither speed preset selected.
    #[test]
    fn test_builder_defaults() {
        let fresh = EngineBuilder::new();
        assert_eq!(fresh.workspace, None);
        assert_eq!((fresh.fast_mode, fresh.precise_mode), (false, false));
    }

    /// `with_workspace` stores the given path.
    #[test]
    fn test_builder_with_workspace() {
        let configured = EngineBuilder::new().with_workspace("./test_workspace");
        assert_eq!(
            configured.workspace.as_deref(),
            Some(std::path::Path::new("./test_workspace"))
        );
    }

    /// `with_key` stores the API key.
    #[test]
    fn test_builder_with_key() {
        let configured = EngineBuilder::new().with_key("sk-test-key");
        assert_eq!(configured.api_key.as_deref(), Some("sk-test-key"));
    }

    /// `with_model` stores the model name.
    #[test]
    fn test_builder_with_model() {
        let configured = EngineBuilder::new().with_model("gpt-4o-mini");
        assert_eq!(configured.model.as_deref(), Some("gpt-4o-mini"));
    }

    /// Key and model setters compose without clobbering each other.
    #[test]
    fn test_builder_with_key_and_model() {
        let configured = EngineBuilder::new()
            .with_model("gpt-4o-mini")
            .with_key("sk-test");
        assert_eq!(
            (configured.model.as_deref(), configured.api_key.as_deref()),
            (Some("gpt-4o-mini"), Some("sk-test"))
        );
    }

    /// `fast` selects the fast preset only.
    #[test]
    fn test_builder_fast_mode() {
        let configured = EngineBuilder::new().fast();
        assert_eq!((configured.fast_mode, configured.precise_mode), (true, false));
    }

    /// `precise` selects the precise preset only.
    #[test]
    fn test_builder_precise_mode() {
        let configured = EngineBuilder::new().precise();
        assert_eq!((configured.fast_mode, configured.precise_mode), (false, true));
    }

    /// `with_top_k` stores the override.
    #[test]
    fn test_builder_top_k() {
        let configured = EngineBuilder::new().with_top_k(10);
        assert_eq!(configured.top_k, Some(10));
    }
}