use std::sync::Arc;
use tracing::warn;
use crate::config::{MemoryBackend, MemoryConfig};
use crate::providers::LLMProvider;
use super::builtin_searcher::BuiltinSearcher;
use super::traits::MemorySearcher;
/// Builds the memory searcher selected by `config` when no LLM provider is available.
///
/// This is a convenience wrapper around [`create_searcher_with_provider`] that
/// always passes `None` as the provider.
pub fn create_searcher(config: &MemoryConfig) -> Arc<dyn MemorySearcher> {
    // The provider type is inferred from the callee's signature.
    let no_provider = None;
    create_searcher_with_provider(config, no_provider)
}
/// Builds the memory searcher for the backend selected in `config.backend`.
///
/// `provider` is only consumed by the `Embedding` and `Hnsw` backends, and only
/// when the matching cargo feature (`memory-embedding` / `memory-hnsw`) is
/// compiled in. Whenever the requested backend cannot be built (not
/// implemented, feature not compiled, or provider missing) a warning is logged
/// and a [`BuiltinSearcher`] is returned instead, so this function never fails.
pub fn create_searcher_with_provider(
    config: &MemoryConfig,
    provider: Option<Arc<dyn LLMProvider>>,
) -> Arc<dyn MemorySearcher> {
    match &config.backend {
        // Both variants resolve to the builtin searcher, without any warning.
        MemoryBackend::Disabled | MemoryBackend::Builtin => Arc::new(BuiltinSearcher),

        MemoryBackend::Qmd => {
            warn!("Memory backend 'qmd' not implemented; using builtin");
            Arc::new(BuiltinSearcher)
        }

        MemoryBackend::Bm25 => {
            // Exactly one of the two cfg-gated blocks survives compilation and
            // becomes the value of this arm.
            #[cfg(feature = "memory-bm25")]
            {
                let bm25 = super::bm25_searcher::Bm25Searcher::new();
                Arc::new(bm25)
            }
            #[cfg(not(feature = "memory-bm25"))]
            {
                warn!("memory-bm25 feature not compiled; falling back to builtin. Rebuild with: cargo build --features memory-bm25");
                Arc::new(BuiltinSearcher)
            }
        }

        MemoryBackend::Embedding => {
            #[cfg(feature = "memory-embedding")]
            {
                match provider {
                    Some(llm) => {
                        // Store location: <config dir>/memory/embeddings.json
                        let memory_dir = crate::config::Config::dir().join("memory");
                        let store = memory_dir.join("embeddings.json");
                        Arc::new(super::embedding_searcher::EmbeddingSearcher::new(llm, store))
                    }
                    None => {
                        warn!("memory-embedding backend requires a provider; falling back to builtin. Pass a provider via create_searcher_with_provider()");
                        Arc::new(BuiltinSearcher)
                    }
                }
            }
            #[cfg(not(feature = "memory-embedding"))]
            {
                // Explicitly discard the provider so it counts as used in this build.
                let _ = provider;
                warn!("memory-embedding feature not compiled; falling back to builtin. Rebuild with: cargo build --features memory-embedding");
                Arc::new(BuiltinSearcher)
            }
        }

        MemoryBackend::Hnsw => {
            #[cfg(feature = "memory-hnsw")]
            {
                match provider {
                    Some(llm) => {
                        // Store location: <config dir>/memory/hnsw_vectors.json
                        let memory_dir = crate::config::Config::dir().join("memory");
                        let store = memory_dir.join("hnsw_vectors.json");
                        Arc::new(super::hnsw_searcher::HnswSearcher::new(llm, store))
                    }
                    None => {
                        warn!("memory-hnsw backend requires a provider; falling back to builtin. Pass a provider via create_searcher_with_provider()");
                        Arc::new(BuiltinSearcher)
                    }
                }
            }
            #[cfg(not(feature = "memory-hnsw"))]
            {
                // Explicitly discard the provider so it counts as used in this build.
                let _ = provider;
                warn!("memory-hnsw feature not compiled; falling back to builtin. Rebuild with: cargo build --features memory-hnsw");
                Arc::new(BuiltinSearcher)
            }
        }

        MemoryBackend::Tantivy => {
            warn!("memory-tantivy feature not yet implemented; falling back to builtin");
            Arc::new(BuiltinSearcher)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the default memory config with only its backend replaced.
    fn config_for(backend: MemoryBackend) -> MemoryConfig {
        let mut cfg = MemoryConfig::default();
        cfg.backend = backend;
        cfg
    }

    #[test]
    fn test_create_searcher_builtin() {
        // The untouched default config is expected to resolve to the builtin searcher.
        let cfg = MemoryConfig::default();
        assert_eq!(create_searcher(&cfg).name(), "builtin");
    }

    #[test]
    fn test_create_searcher_disabled_returns_builtin() {
        let cfg = config_for(MemoryBackend::Disabled);
        assert_eq!(create_searcher(&cfg).name(), "builtin");
    }

    #[test]
    fn test_create_searcher_qmd_falls_back() {
        let cfg = config_for(MemoryBackend::Qmd);
        assert_eq!(create_searcher(&cfg).name(), "builtin");
    }

    #[test]
    fn test_create_searcher_embedding_falls_back() {
        // `create_searcher` passes no provider, so embedding falls back.
        let cfg = config_for(MemoryBackend::Embedding);
        assert_eq!(create_searcher(&cfg).name(), "builtin");
    }

    #[cfg(feature = "memory-bm25")]
    #[test]
    fn test_create_searcher_bm25() {
        let cfg = config_for(MemoryBackend::Bm25);
        assert_eq!(create_searcher(&cfg).name(), "bm25");
    }

    #[test]
    fn test_create_searcher_hnsw_falls_back() {
        // `create_searcher` passes no provider, so hnsw falls back.
        let cfg = config_for(MemoryBackend::Hnsw);
        assert_eq!(create_searcher(&cfg).name(), "builtin");
    }

    #[test]
    fn test_create_searcher_with_provider_none_embedding_falls_back() {
        let cfg = config_for(MemoryBackend::Embedding);
        let built = create_searcher_with_provider(&cfg, None);
        assert_eq!(built.name(), "builtin");
    }

    #[test]
    fn test_create_searcher_with_provider_builtin_ignores_provider() {
        use crate::providers::{ChatOptions, LLMProvider, LLMResponse, ToolDefinition};
        use crate::session::Message;
        use async_trait::async_trait;
        use std::sync::Arc;

        /// Minimal provider stub that only implements the chat path.
        struct SilentProvider;

        #[async_trait]
        impl LLMProvider for SilentProvider {
            fn name(&self) -> &str {
                "noop"
            }

            fn default_model(&self) -> &str {
                "noop"
            }

            async fn chat(
                &self,
                _messages: Vec<Message>,
                _tools: Vec<ToolDefinition>,
                _model: Option<&str>,
                _options: ChatOptions,
            ) -> crate::error::Result<LLMResponse> {
                Ok(LLMResponse::text("ok"))
            }
        }

        // Default config plus a provider: the provider must not change the outcome.
        let cfg = MemoryConfig::default();
        let stub: Arc<dyn LLMProvider> = Arc::new(SilentProvider);
        let built = create_searcher_with_provider(&cfg, Some(stub));
        assert_eq!(built.name(), "builtin");
    }

    #[cfg(feature = "memory-embedding")]
    #[test]
    fn test_create_searcher_with_provider_embedding_with_provider() {
        use crate::providers::{ChatOptions, LLMProvider, LLMResponse, ToolDefinition};
        use crate::session::Message;
        use async_trait::async_trait;
        use std::sync::Arc;

        /// Provider stub returning a fixed two-component vector per input text.
        struct EmbeddingStub;

        #[async_trait]
        impl LLMProvider for EmbeddingStub {
            fn name(&self) -> &str {
                "fake"
            }

            fn default_model(&self) -> &str {
                "fake"
            }

            async fn chat(
                &self,
                _messages: Vec<Message>,
                _tools: Vec<ToolDefinition>,
                _model: Option<&str>,
                _options: ChatOptions,
            ) -> crate::error::Result<LLMResponse> {
                Ok(LLMResponse::text("ok"))
            }

            async fn embed(&self, texts: &[String]) -> crate::error::Result<Vec<Vec<f32>>> {
                Ok(texts.iter().map(|_| vec![1.0f32, 0.0]).collect())
            }
        }

        let cfg = config_for(MemoryBackend::Embedding);
        let stub: Arc<dyn LLMProvider> = Arc::new(EmbeddingStub);
        let built = create_searcher_with_provider(&cfg, Some(stub));
        assert_eq!(built.name(), "embedding");
    }

    #[cfg(feature = "memory-hnsw")]
    #[test]
    fn test_create_searcher_with_provider_hnsw_with_provider() {
        use crate::providers::{ChatOptions, LLMProvider, LLMResponse, ToolDefinition};
        use crate::session::Message;
        use async_trait::async_trait;
        use std::sync::Arc;

        /// Provider stub returning a fixed two-component vector per input text.
        struct HnswStub;

        #[async_trait]
        impl LLMProvider for HnswStub {
            fn name(&self) -> &str {
                "fake"
            }

            fn default_model(&self) -> &str {
                "fake"
            }

            async fn chat(
                &self,
                _messages: Vec<Message>,
                _tools: Vec<ToolDefinition>,
                _model: Option<&str>,
                _options: ChatOptions,
            ) -> crate::error::Result<LLMResponse> {
                Ok(LLMResponse::text("ok"))
            }

            async fn embed(&self, texts: &[String]) -> crate::error::Result<Vec<Vec<f32>>> {
                Ok(texts.iter().map(|_| vec![1.0f32, 0.0]).collect())
            }
        }

        let cfg = config_for(MemoryBackend::Hnsw);
        let stub: Arc<dyn LLMProvider> = Arc::new(HnswStub);
        let built = create_searcher_with_provider(&cfg, Some(stub));
        assert_eq!(built.name(), "hnsw");
    }

    #[test]
    fn test_create_searcher_hnsw_without_provider_falls_back() {
        let cfg = config_for(MemoryBackend::Hnsw);
        let built = create_searcher_with_provider(&cfg, None);
        assert_eq!(built.name(), "builtin");
    }
}