use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::sync::Arc;
use tokio::sync::Mutex;
use rust_memex::{
BM25Config, EmbeddingClient, EmbeddingConfig, HybridConfig, HybridSearcher, RAGPipeline,
SearchMode, SearchOptions, SliceLayer, StorageManager,
};
use crate::cli::config::*;
use crate::cli::formatting::*;
/// Derives the BM25 index location from the database path.
///
/// A leading `~` in `db_path` is expanded first. The index is placed in a
/// `bm25` directory under the parent of the database path. When the path has
/// no parent, the `index_path` of the default `BM25Config` is returned.
fn bm25_path_from_db(db_path: &str) -> String {
    let db_location = shellexpand::tilde(db_path).to_string();
    match std::path::Path::new(&db_location).parent() {
        Some(dir) => dir.join("bm25").to_string_lossy().into_owned(),
        None => BM25Config::default().index_path,
    }
}
/// Runs automatic storage maintenance when it is enabled and due.
///
/// Nothing happens unless `maintenance_config` is present and has
/// `auto_optimize` set. When the version count reported by the storage exceeds
/// `version_threshold`, the storage is optimized and, if `auto_cleanup_days`
/// is set, cleaned up with that day count. Progress is reported on stderr.
///
/// Returns `Ok(true)` when an optimization ran and `Ok(false)` otherwise.
///
/// # Errors
///
/// Propagates any error from the storage `stats`, `optimize` or `cleanup` calls.
pub async fn check_and_maybe_optimize(
    storage: &StorageManager,
    maintenance_config: &Option<MaintenanceFileConfig>,
) -> Result<bool> {
    // Maintenance is opt-in: bail out unless it is configured and enabled.
    let settings = match maintenance_config.as_ref().filter(|cfg| cfg.auto_optimize) {
        Some(cfg) => cfg,
        None => return Ok(false),
    };

    let stats = storage.stats().await?;
    if stats.version_count <= settings.version_threshold {
        return Ok(false);
    }

    eprintln!(
        "Auto-optimizing: {} versions exceed threshold {}",
        stats.version_count, settings.version_threshold
    );
    storage.optimize().await?;
    if let Some(days) = settings.auto_cleanup_days {
        storage.cleanup(Some(days)).await?;
    }
    eprintln!("Auto-optimization complete");

    Ok(true)
}
/// Parameters for [`run_search`], bundled into one value.
pub struct SearchConfig<'a> {
/// Namespace the search is restricted to.
pub namespace: String,
/// Query text; it is embedded and also passed to the searcher as-is.
pub query: String,
/// Maximum number of results requested from the searcher.
pub limit: usize,
/// When `true`, results are printed as pretty JSON on stdout instead of
/// being rendered by the display helpers.
pub json_output: bool,
/// Database path; used to open the storage and to derive the BM25 index path.
pub db_path: String,
/// Optional slice layer to restrict results to.
pub layer_filter: Option<SliceLayer>,
/// Search mode; `SearchMode::Vector` goes through the RAG pipeline, every
/// other mode goes through the hybrid searcher.
pub search_mode: SearchMode,
/// Configuration used to build the embedding client.
pub embedding_config: &'a EmbeddingConfig,
}
/// Runs a single-namespace search and prints the results.
///
/// `SearchMode::Vector` is served by the RAG pipeline; every other mode is
/// served by a `HybridSearcher` that opens the BM25 index (derived from the
/// database path) read-only. With `json_output` the results are printed as
/// pretty JSON on stdout; otherwise they are handed to the display helpers.
///
/// # Errors
///
/// Propagates failures from creating the embedding client, opening the
/// storage, building the searcher/pipeline, embedding the query, searching,
/// or serializing the JSON output.
pub async fn run_search(config: SearchConfig<'_>) -> Result<()> {
    let embedder = Arc::new(Mutex::new(
        EmbeddingClient::new(config.embedding_config).await?,
    ));
    let store = Arc::new(StorageManager::new_lance_only(&config.db_path).await?);

    // Pure vector search: handled by the RAG pipeline, no BM25 index involved.
    if config.search_mode == SearchMode::Vector {
        let pipeline = RAGPipeline::new(embedder, store).await?;
        let hits = pipeline
            .memory_search_with_layer(
                &config.namespace,
                &config.query,
                config.limit,
                config.layer_filter,
            )
            .await?;

        if config.json_output {
            let payload = json_search_results(
                &config.query,
                Some(&config.namespace),
                &hits,
                config.layer_filter,
            );
            println!("{}", serde_json::to_string_pretty(&payload)?);
        } else {
            display_search_results(
                &config.query,
                Some(&config.namespace),
                &hits,
                config.layer_filter,
            );
        }
        return Ok(());
    }

    // Keyword / hybrid search: the BM25 index is opened read-only.
    let bm25 = BM25Config {
        index_path: bm25_path_from_db(&config.db_path),
        read_only: true,
        ..Default::default()
    };
    let searcher_config = HybridConfig {
        mode: config.search_mode,
        bm25,
        ..Default::default()
    };
    let searcher = HybridSearcher::new(store, searcher_config).await?;
    let query_vector = embedder.lock().await.embed(&config.query).await?;

    let options = SearchOptions {
        layer_filter: config.layer_filter,
        project_filter: None,
    };
    let hits = searcher
        .search(
            &config.query,
            query_vector,
            Some(&config.namespace),
            config.limit,
            options,
        )
        .await?;

    if config.json_output {
        let payload = json_hybrid_search_results(
            &config.query,
            Some(&config.namespace),
            &hits,
            config.layer_filter,
            config.search_mode,
        );
        println!("{}", serde_json::to_string_pretty(&payload)?);
    } else {
        display_hybrid_search_results(
            &config.query,
            Some(&config.namespace),
            &hits,
            config.layer_filter,
            config.search_mode,
        );
    }

    Ok(())
}
/// Prints the child slices of the slice `id` in `namespace`.
///
/// With `json_output` the children are printed as pretty JSON on stdout.
/// Otherwise a numbered listing is written to stderr, showing each child's
/// layer (or `flat` when it has none), its id, a preview of at most 100
/// characters of its text, and its keywords if any.
///
/// # Errors
///
/// Propagates failures from creating the embedding client, opening the
/// storage, building the RAG pipeline, expanding the slice, or serializing
/// the JSON output.
pub async fn run_expand(
    namespace: String,
    id: String,
    json_output: bool,
    db_path: String,
    embedding_config: &EmbeddingConfig,
) -> Result<()> {
    let embedding_client = Arc::new(Mutex::new(EmbeddingClient::new(embedding_config).await?));
    let storage = Arc::new(StorageManager::new_lance_only(&db_path).await?);
    let rag = RAGPipeline::new(embedding_client, storage).await?;
    let results = rag.expand_result(&namespace, &id).await?;

    if json_output {
        let children: Vec<_> = results
            .iter()
            .map(|r| {
                serde_json::json!({
                    "id": r.id,
                    "layer": r.layer.map(|l| l.name()),
                    "text": r.text,
                    "keywords": r.keywords,
                    "parent_id": r.parent_id,
                    "children_ids": r.children_ids,
                })
            })
            .collect();
        let json = serde_json::json!({
            "parent_id": id,
            "namespace": namespace,
            "children_count": results.len(),
            "children": children
        });
        println!("{}", serde_json::to_string_pretty(&json)?);
        return Ok(());
    }

    eprintln!("\n-> Children of slice \"{id}\" in [{namespace}]\n");
    if results.is_empty() {
        eprintln!("No children found (this may be a leaf/outer slice).");
        return Ok(());
    }

    for (i, result) in results.iter().enumerate() {
        let layer_str = result.layer.map(|l| l.name()).unwrap_or("flat");

        // Take the first 100 characters, then check whether anything is left.
        // Comparing the byte length against 100 would append "..." to short
        // multi-byte text that was not truncated at all.
        let mut chars = result.text.chars();
        let preview = chars
            .by_ref()
            .take(100)
            .collect::<String>()
            .replace('\n', " ");
        let ellipsis = if chars.next().is_some() { "..." } else { "" };

        eprintln!("{}. [{}] {}", i + 1, layer_str, result.id);
        eprintln!(" \"{}{ellipsis}\"", preview);
        if !result.keywords.is_empty() {
            eprintln!(" Keywords: {}", result.keywords.join(", "));
        }
        eprintln!();
    }

    Ok(())
}
/// Looks up the chunk `id` in `namespace` and prints it.
///
/// With `json_output` a JSON object with a `found` flag is printed on stdout
/// in both the found and the not-found case. Otherwise the chunk header goes
/// to stderr and the chunk text to stdout; a missing chunk is reported on
/// stderr. A missing chunk is not an error: the function still returns `Ok`.
///
/// # Errors
///
/// Propagates failures from creating the embedding client, opening the
/// storage, building the RAG pipeline, the lookup itself, or serializing the
/// JSON output.
pub async fn run_get(
    namespace: String,
    id: String,
    json_output: bool,
    db_path: String,
    embedding_config: &EmbeddingConfig,
) -> Result<()> {
    let embedder = Arc::new(Mutex::new(EmbeddingClient::new(embedding_config).await?));
    let store = Arc::new(StorageManager::new_lance_only(&db_path).await?);
    let pipeline = RAGPipeline::new(embedder, store).await?;

    let chunk = match pipeline.lookup_memory(&namespace, &id).await? {
        Some(chunk) => chunk,
        None => {
            // Not found: report it in the requested format and finish normally.
            if json_output {
                let payload = serde_json::json!({
                    "found": false,
                    "namespace": namespace,
                    "id": id
                });
                println!("{}", serde_json::to_string_pretty(&payload)?);
            } else {
                eprintln!("Chunk '{}' not found in namespace '{}'", id, namespace);
            }
            return Ok(());
        }
    };

    if json_output {
        let payload = serde_json::json!({
            "found": true,
            "id": chunk.id,
            "namespace": chunk.namespace,
            "text": chunk.text,
            "metadata": chunk.metadata
        });
        println!("{}", serde_json::to_string_pretty(&payload)?);
        return Ok(());
    }

    eprintln!("\n-> Found chunk in [{namespace}]\n");
    eprintln!("ID: {}", chunk.id);
    eprintln!("Namespace: {}", chunk.namespace);
    // Null or empty-object metadata carries no information, so it is not shown.
    let metadata_is_blank = chunk.metadata.is_null() || chunk.metadata == serde_json::json!({});
    if !metadata_is_blank {
        eprintln!("Metadata: {}", chunk.metadata);
    }
    eprintln!("\n--- Content ---\n");
    println!("{}", chunk.text);

    Ok(())
}
/// Runs a RAG pipeline search, optionally scoped to one namespace, and prints
/// the results.
///
/// With `json_output` the results are printed as pretty JSON on stdout;
/// otherwise they are handed to `display_search_results`. No layer filter is
/// applied.
///
/// # Errors
///
/// Propagates failures from creating the embedding client, opening the
/// storage, building the RAG pipeline, the search itself, or serializing the
/// JSON output.
pub async fn run_rag_search(
    query: String,
    limit: usize,
    namespace: Option<String>,
    json_output: bool,
    db_path: String,
    embedding_config: &EmbeddingConfig,
) -> Result<()> {
    let client = EmbeddingClient::new(embedding_config).await?;
    let embedder = Arc::new(Mutex::new(client));
    let store = Arc::new(StorageManager::new_lance_only(&db_path).await?);
    let pipeline = RAGPipeline::new(embedder, store).await?;

    let scope = namespace.as_deref();
    let hits = pipeline.search_inner(scope, &query, limit).await?;

    if !json_output {
        display_search_results(&query, scope, &hits, None);
        return Ok(());
    }

    let payload = json_search_results(&query, scope, &hits, None);
    println!("{}", serde_json::to_string_pretty(&payload)?);
    Ok(())
}
pub async fn run_list_namespaces(stats: bool, json_output: bool, db_path: String) -> Result<()> {
let storage = StorageManager::new_lance_only(&db_path).await?;
let storage = Arc::new(storage);
let all_docs = storage.all_documents(None, 10000).await?;
let mut namespace_counts: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
for doc in &all_docs {
*namespace_counts.entry(doc.namespace.clone()).or_insert(0) += 1;
}
let mut namespaces: Vec<_> = namespace_counts.into_iter().collect();
namespaces.sort_by(|a, b| a.0.cmp(&b.0));
if json_output {
let json = if stats {
serde_json::json!({
"namespaces": namespaces.iter().map(|(ns, count)| serde_json::json!({
"name": ns,
"document_count": count
})).collect::<Vec<_>>()
})
} else {
serde_json::json!({
"namespaces": namespaces.iter().map(|(ns, _)| ns).collect::<Vec<_>>()
})
};
println!("{}", serde_json::to_string_pretty(&json)?);
} else {
eprintln!("\n-> Namespaces in {}\n", storage.lance_path());
if namespaces.is_empty() {
eprintln!("No namespaces found (database may be empty).");
} else {
for (ns, count) in &namespaces {
if stats {
eprintln!(" {} ({} documents)", ns, count);
} else {
eprintln!(" {}", ns);
}
}
eprintln!();
eprintln!("Total: {} namespace(s)", namespaces.len());
}
}
Ok(())
}
/// One hit of a cross-namespace search, in the shape that is serialized to
/// JSON by `run_cross_search`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CrossSearchResult {
/// Identifier of the matching entry.
pub id: String,
/// Namespace the hit was found in.
pub namespace: String,
/// Text of the matching entry.
pub text: String,
/// Score used to rank hits across namespaces (higher ranks first).
pub score: f32,
/// Metadata attached to the entry.
pub metadata: serde_json::Value,
/// Layer of the entry rendered as a string; omitted from the serialized
/// output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub layer: Option<String>,
/// Keywords of the entry; omitted from the serialized output when empty and
/// defaulting to an empty list when absent on deserialization.
#[serde(skip_serializing_if = "Vec::is_empty", default)]
pub keywords: Vec<String>,
}
/// Searches every namespace in the database and merges the hits into one
/// ranked list.
///
/// Namespaces are discovered from the stored documents. Each namespace is
/// queried through a `HybridSearcher` for up to `limit_per_ns` hits; the
/// combined hits are sorted by descending score and cut to `total_limit`.
///
/// `mode` selects the search mode: `"vector"`, or `"keyword"` / `"bm25"`;
/// any other value falls back to hybrid search. With `json_output` the merged
/// result is printed as JSON on stdout; otherwise a human-readable listing is
/// written to stderr, with a text preview of at most 200 characters per hit.
///
/// # Errors
///
/// Propagates failures from creating the embedding client, opening the
/// storage, fetching documents, building the searcher, embedding the query,
/// any per-namespace search, or serializing the JSON output.
pub async fn run_cross_search(
    query: String,
    limit_per_ns: usize,
    total_limit: usize,
    mode: String,
    json_output: bool,
    db_path: String,
    embedding_config: &EmbeddingConfig,
) -> Result<()> {
    let embedding_client = Arc::new(Mutex::new(EmbeddingClient::new(embedding_config).await?));
    let storage = Arc::new(StorageManager::new_lance_only(&db_path).await?);

    // NOTE(review): namespaces are discovered from the documents returned by
    // `all_documents(None, 10000)`; presumably 10000 is a row limit, in which
    // case namespaces beyond it would be missed — confirm against the API.
    let all_docs = storage.all_documents(None, 10000).await?;
    let mut namespaces: Vec<String> = all_docs
        .iter()
        .map(|doc| doc.namespace.clone())
        .collect::<HashSet<_>>()
        .into_iter()
        .collect();
    // Sort so the query order (and therefore the order of equal-score hits
    // after the stable sort below) does not depend on hash iteration order.
    namespaces.sort_unstable();

    if namespaces.is_empty() {
        if json_output {
            println!(
                "{}",
                serde_json::json!({ "results": [], "total": 0, "namespaces_searched": 0 })
            );
        } else {
            eprintln!("No namespaces found in database.");
        }
        return Ok(());
    }

    if !json_output {
        eprintln!(
            "Searching {} namespaces for: \"{}\"",
            namespaces.len(),
            query
        );
        eprintln!(
            "Mode: {}, limit per namespace: {}, total limit: {}",
            mode, limit_per_ns, total_limit
        );
        eprintln!();
    }

    // Unrecognized mode strings deliberately fall back to hybrid search.
    let search_mode = match mode.as_str() {
        "vector" => SearchMode::Vector,
        "keyword" | "bm25" => SearchMode::Keyword,
        _ => SearchMode::Hybrid,
    };
    let hybrid_config = HybridConfig {
        mode: search_mode,
        bm25: BM25Config {
            index_path: bm25_path_from_db(&db_path),
            read_only: true,
            ..Default::default()
        },
        ..Default::default()
    };
    let hybrid_searcher = HybridSearcher::new(storage.clone(), hybrid_config).await?;

    // The query is embedded once and the vector reused for every namespace.
    let query_embedding = embedding_client.lock().await.embed(&query).await?;

    let mut all_results: Vec<CrossSearchResult> = Vec::new();
    for ns in &namespaces {
        let ns_results = hybrid_searcher
            .search(
                &query,
                query_embedding.clone(),
                Some(ns.as_str()),
                limit_per_ns,
                SearchOptions::default(),
            )
            .await?;
        all_results.extend(ns_results.into_iter().map(|r| CrossSearchResult {
            id: r.id,
            namespace: r.namespace,
            text: r.document,
            score: r.combined_score,
            metadata: r.metadata,
            layer: r.layer.map(|l| l.to_string()),
            keywords: r.keywords,
        }));
    }

    // Highest score first; incomparable scores (NaN) are treated as equal.
    all_results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    all_results.truncate(total_limit);

    if json_output {
        let output = serde_json::json!({
            "query": query,
            "mode": mode,
            "namespaces_searched": namespaces.len(),
            "total_results": all_results.len(),
            "results": all_results
        });
        println!("{}", serde_json::to_string_pretty(&output)?);
        return Ok(());
    }

    eprintln!(
        "Found {} results across {} namespaces:\n",
        all_results.len(),
        namespaces.len()
    );
    for (idx, r) in all_results.iter().enumerate() {
        eprintln!(
            "{}. [{}] {} (score: {:.4})",
            idx + 1,
            r.namespace,
            &r.id,
            r.score
        );
        if let Some(ref layer) = r.layer {
            eprintln!(" Layer: {}", layer);
        }
        if !r.keywords.is_empty() {
            eprintln!(" Keywords: {}", r.keywords.join(", "));
        }

        // Cut the preview on a character boundary: slicing at a fixed byte
        // offset panics when that byte falls inside a multi-byte UTF-8
        // character.
        let (head, suffix) = match r.text.char_indices().nth(200) {
            Some((cut, _)) => (&r.text[..cut], "..."),
            None => (r.text.as_str(), ""),
        };
        eprintln!(" {}{}\n", head.replace('\n', " "), suffix);
    }

    Ok(())
}