use std::sync::Arc;
use tokio::sync::Mutex;
use crate::config::{self, Config};
use crate::embedding::{self, Embedder};
use crate::error::{MemeError, Result};
use crate::llm::LlmClient;
use crate::meme::Meme;
use crate::pipeline::{Extractor, HybridRetriever};
use crate::store::{HistoryStore, VectorStore};
/// Builder for [`Meme`]: chain the setter methods, then call `build`.
///
/// NOTE(review): the derived `Debug` prints `config`, which may hold API keys;
/// confirm `Config`'s `Debug` impl redacts them before logging a builder.
#[derive(Debug, Clone, Default)]
pub struct MemeBuilder {
/// Configuration validated and consumed by `build`.
config: Config,
/// Caller-supplied HTTP client; when `None`, `build` creates a default one.
http_client: Option<reqwest::Client>,
/// When `true`, `build` clears the vector store after opening it.
clear_db: bool,
/// Optional namespace forwarded to the retriever and the resulting `Meme`.
namespace: Option<String>,
}
impl MemeBuilder {
    /// Creates a builder with every field at its `Default` value
    /// (default [`Config`], no custom HTTP client, `clear_db` off, no namespace).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Replaces the whole configuration.
    ///
    /// This overwrites anything set by earlier setter calls on this builder,
    /// so call it before the finer-grained setters.
    #[must_use]
    pub fn config(mut self, config: Config) -> Self {
        self.config = config;
        self
    }

    /// Sets the LLM API key (`config.llm.api_key`).
    #[must_use]
    pub fn api_key(mut self, key: impl Into<String>) -> Self {
        self.config.llm.api_key = Some(key.into());
        self
    }

    /// Sets the LLM model name (`config.llm.model`).
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.config.llm.model = model.into();
        self
    }

    /// Sets the LLM endpoint base URL (`config.llm.base_url`).
    #[must_use]
    pub fn base_url(mut self, url: impl Into<String>) -> Self {
        self.config.llm.base_url = url.into();
        self
    }

    /// Selects the embedding backend (`config.embedding.provider`).
    ///
    /// Choosing the ONNX provider makes [`Self::build`] fail with a config
    /// error unless the crate is compiled with the `onnx` feature.
    #[must_use]
    pub const fn embedding_provider(mut self, provider: config::EmbeddingProviderKind) -> Self {
        self.config.embedding.provider = provider;
        self
    }

    /// Sets the embedding model (`config.embedding.model`).
    ///
    /// For the ONNX provider, this value is what `build` passes to the ONNX
    /// embedding constructor.
    #[must_use]
    pub fn embedding_model(mut self, model: impl Into<String>) -> Self {
        self.config.embedding.model = model.into();
        self
    }

    /// Sets the embedding dimension (`config.embedding.dimension`).
    #[must_use]
    pub const fn embedding_dimension(mut self, dim: usize) -> Self {
        self.config.embedding.dimension = dim;
        self
    }

    /// Sets a dedicated API key for the embedding service
    /// (`config.embedding.api_key`).
    #[must_use]
    pub fn embedding_api_key(mut self, key: impl Into<String>) -> Self {
        self.config.embedding.api_key = Some(key.into());
        self
    }

    /// Sets a dedicated base URL for the embedding service
    /// (`config.embedding.base_url`).
    #[must_use]
    pub fn embedding_base_url(mut self, url: impl Into<String>) -> Self {
        self.config.embedding.base_url = Some(url.into());
        self
    }

    /// Sets the LanceDB location opened by the vector store
    /// (`config.store.lancedb_path`).
    #[must_use]
    pub fn lancedb_path(mut self, path: impl Into<std::path::PathBuf>) -> Self {
        self.config.store.lancedb_path = path.into();
        self
    }

    /// Sets the path of the history database (`config.store.history_db_path`).
    #[must_use]
    pub fn history_db_path(mut self, path: impl Into<std::path::PathBuf>) -> Self {
        self.config.store.history_db_path = path.into();
        self
    }

    /// Enables or disables the planning step (`config.pipeline.enable_planning`).
    #[must_use]
    pub const fn enable_planning(mut self, enable: bool) -> Self {
        self.config.pipeline.enable_planning = enable;
        self
    }

    /// Enables or disables reflection (`config.pipeline.enable_reflection`).
    #[must_use]
    pub const fn enable_reflection(mut self, enable: bool) -> Self {
        self.config.pipeline.enable_reflection = enable;
        self
    }

    /// Sets `config.pipeline.semantic_top_k`.
    #[must_use]
    pub const fn semantic_top_k(mut self, k: usize) -> Self {
        self.config.pipeline.semantic_top_k = k;
        self
    }

    /// Sets the reranker model (`config.pipeline.reranker_model`).
    ///
    /// Requires the `onnx` feature; without it, [`Self::build`] returns a
    /// config error.
    #[must_use]
    pub fn reranker(mut self, model: impl Into<String>) -> Self {
        self.config.pipeline.reranker_model = Some(model.into());
        self
    }

    /// Sets `config.pipeline.rerank_top_n`.
    #[must_use]
    pub const fn rerank_top_n(mut self, n: usize) -> Self {
        self.config.pipeline.rerank_top_n = n;
        self
    }

    /// Supplies a preconfigured HTTP client.
    ///
    /// `build` uses it for the LLM client and the API embedding client instead
    /// of creating its own.
    #[must_use]
    pub fn http_client(mut self, client: reqwest::Client) -> Self {
        self.http_client = Some(client);
        self
    }

    /// When `true`, `build` wipes the vector store via `clear_all` after
    /// opening it.
    ///
    /// The history database is not cleared by this flag.
    #[must_use]
    pub const fn clear_db(mut self, clear: bool) -> Self {
        self.clear_db = clear;
        self
    }

    /// Sets the namespace forwarded to the retriever and the resulting [`Meme`].
    #[must_use]
    pub fn namespace(mut self, ns: impl Into<String>) -> Self {
        self.namespace = Some(ns.into());
        self
    }

    /// Validates the configuration and assembles a ready-to-use [`Meme`].
    ///
    /// The components that do not touch persistent storage are created before
    /// any store is opened or cleared. These are the HTTP client, LLM client,
    /// embedder and reranker. As a result, a misconfiguration or a failing
    /// model load never leaves behind a database that `clear_db` has already
    /// wiped.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `Config::validate` rejects the configuration;
    /// - the default HTTP client cannot be built;
    /// - the LLM client, embedder or (with `onnx`) reranker fails to construct;
    /// - the ONNX embedding provider or a reranker is requested without the
    ///   `onnx` feature ([`MemeError::Config`]);
    /// - the vector store or history store cannot be opened, or clearing the
    ///   vector store fails.
    pub async fn build(self) -> Result<Meme> {
        let config = self.config;
        config.validate()?;

        let http = self.http_client.map_or_else(build_http_client, Ok)?;
        let llm = Arc::new(LlmClient::new(http.clone(), &config.llm)?);
        let embedder = Arc::new(match config.embedding.provider {
            config::EmbeddingProviderKind::Api => Embedder::Api(embedding::ApiEmbedding::new(
                http,
                &config.embedding,
                &config.llm,
            )?),
            #[cfg(feature = "onnx")]
            config::EmbeddingProviderKind::Onnx => {
                Embedder::Onnx(embedding::OnnxEmbedding::new(&config.embedding.model)?)
            }
            #[cfg(not(feature = "onnx"))]
            config::EmbeddingProviderKind::Onnx => {
                return Err(MemeError::Config(
                    "ONNX provider requires the 'onnx' feature flag".into(),
                ));
            }
        });

        // Resolve the reranker BEFORE opening/clearing the stores. Previously
        // this ran after `clear_all`, so a missing feature flag or a failing
        // model load was reported only after the data had been wiped.
        #[cfg(feature = "onnx")]
        let reranker = config
            .pipeline
            .reranker_model
            .as_deref()
            .map(crate::reranking::OnnxReranker::new)
            .transpose()?
            .map(Arc::new);
        #[cfg(not(feature = "onnx"))]
        if config.pipeline.reranker_model.is_some() {
            return Err(MemeError::Config(
                "reranker requires the 'onnx' feature flag".into(),
            ));
        }

        let store = Arc::new(
            VectorStore::open(
                &config.store.lancedb_path.to_string_lossy(),
                &config.store.table_name,
                embedder.dimension(),
            )
            .await?,
        );
        let history = Arc::new(HistoryStore::open(&config.store.history_db_path)?);
        if self.clear_db {
            // Only the vector store is cleared; the history store is left as-is.
            store.clear_all().await?;
        }

        let extractor = Extractor::new(
            Arc::clone(&llm),
            &config.pipeline,
            config.pipeline.max_build_workers,
        );
        let retriever = HybridRetriever::new(
            Arc::clone(&llm),
            Arc::clone(&store),
            Arc::clone(&embedder),
            config.pipeline.clone(),
            self.namespace.clone(),
            #[cfg(feature = "onnx")]
            reranker,
        );
        tracing::info!("meme system initialized");
        Ok(Meme {
            llm,
            store,
            embedder,
            history,
            extractor: Mutex::new(extractor),
            retriever,
            config,
            namespace: self.namespace,
        })
    }
}
/// Request timeout applied by `build_http_client` (one minute).
const DEFAULT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60);
/// Connect timeout applied by `build_http_client` (ten seconds).
const DEFAULT_CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(10_000);
/// Idle-connection pool cap per host applied by `build_http_client`.
const DEFAULT_POOL_IDLE_PER_HOST: usize = 10;
/// Creates the default HTTP client, used when the caller did not supply one.
///
/// It applies the module's default request timeout, connect timeout and
/// per-host idle-pool cap, and sends a `meme/<crate version>` user agent.
///
/// # Errors
///
/// Returns [`MemeError::Internal`] if the client cannot be built.
fn build_http_client() -> Result<reqwest::Client> {
    const USER_AGENT: &str = concat!("meme/", env!("CARGO_PKG_VERSION"));

    let builder = reqwest::Client::builder()
        .user_agent(USER_AGENT)
        .pool_max_idle_per_host(DEFAULT_POOL_IDLE_PER_HOST)
        .connect_timeout(DEFAULT_CONNECT_TIMEOUT)
        .timeout(DEFAULT_TIMEOUT);

    match builder.build() {
        Ok(client) => Ok(client),
        Err(err) => Err(MemeError::Internal(format!(
            "failed to build HTTP client: {err}"
        ))),
    }
}