use crate::{
auth::middleware::AuthUser,
db::{AresVectorStore, VectorStore},
rag::{
chunker::{ChunkingStrategy, TextChunker},
embeddings::{EmbeddingModelType, EmbeddingService},
reranker::{Reranker, RerankerConfig, RerankerModelType},
search::{HybridWeights, SearchEngine, SearchStrategy},
},
types::{
AppError, Document, DocumentMetadata, RagDeleteCollectionRequest,
RagDeleteCollectionResponse, RagIngestRequest, RagIngestResponse, RagSearchRequest,
RagSearchResponse, RagSearchResult, Result,
},
AppState,
};
use axum::{extract::State, Json};
use chrono::Utc;
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::OnceCell;
use uuid::Uuid;
/// Builds the storage-level collection name for a user by prefixing the
/// user id, so each user's collections are namespaced away from everyone
/// else's (e.g. `user_<id>_<collection>`).
fn user_scoped_collection(user_id: &str, collection: &str) -> String {
    let mut scoped = String::with_capacity("user_".len() + user_id.len() + 1 + collection.len());
    scoped.push_str("user_");
    scoped.push_str(user_id);
    scoped.push('_');
    scoped.push_str(collection);
    scoped
}
/// Inverse of `user_scoped_collection`: strips the `user_<id>_` prefix and
/// returns the caller-visible collection name, or `None` when the scoped
/// name does not belong to this user.
fn extract_user_collection(user_id: &str, scoped_name: &str) -> Option<String> {
    let rest = scoped_name
        .strip_prefix("user_")?
        .strip_prefix(user_id)?
        .strip_prefix('_')?;
    Some(rest.to_owned())
}
/// Process-wide embedding service, created once on first use and shared
/// by every handler via `Arc`.
static EMBEDDING_SERVICE: OnceCell<Arc<EmbeddingService>> = OnceCell::const_new();

/// Lazily initializes (and thereafter returns) the shared [`EmbeddingService`].
///
/// Before constructing the service, this makes a best-effort attempt to
/// pre-populate the fastembed model cache by downloading the model files via
/// the `lancor` hub client. Every step of that pre-population is deliberately
/// non-fatal (errors are logged and swallowed with `.ok()`), because
/// `EmbeddingService::with_model` may still succeed from an existing cache or
/// through its own download path.
///
/// # Errors
/// Returns `AppError::Internal` only when the embedding service itself fails
/// to initialize.
async fn get_embedding_service() -> Result<Arc<EmbeddingService>> {
    EMBEDDING_SERVICE
        .get_or_try_init(|| async {
            let model = EmbeddingModelType::default();
            // Cache location is overridable via env; defaults to a local hidden dir.
            let cache_dir = std::env::var("FASTEMBED_CACHE_DIR")
                .unwrap_or_else(|_| ".fastembed_cache".to_string());
            let cache_path = std::path::PathBuf::from(&cache_dir);
            match lancor::hub::HubClient::with_cache_dir(cache_path.clone()) {
                // Non-fatal: log and fall through to service construction.
                Err(e) => tracing::error!("Failed to create lancor HubClient: {}", e),
                Ok(hub) => {
                    let repo_id = model.hf_repo_id();
                    for filename in &["onnx/model.onnx", "tokenizer.json", "config.json", "tokenizer_config.json"] {
                        // Mirrors the HuggingFace hub cache layout:
                        // models--{org}--{name}/snapshots/<revision>/<file>,
                        // with "lancor" used as the snapshot revision name.
                        let folder = format!("models--{}", repo_id.replace('/', "--"));
                        let snapshot_dir = cache_path.join(&folder).join("snapshots").join("lancor");
                        let target = snapshot_dir.join(filename);
                        // Skip files that already exist and are non-empty.
                        if target.exists() && std::fs::metadata(&target).map(|m| m.len() > 0).unwrap_or(false) {
                            tracing::debug!("Model file cached: {}", target.display());
                            continue;
                        }
                        tracing::info!("Downloading {}/{} via lancor...", repo_id, filename);
                        if let Some(parent) = target.parent() {
                            std::fs::create_dir_all(parent).ok();
                        }
                        match hub.download(repo_id, filename, None).await {
                            Ok(dl_path) => {
                                // The hub client may have written the file somewhere else;
                                // copy it into the snapshot dir we expect.
                                if dl_path != target {
                                    std::fs::copy(&dl_path, &target).ok();
                                }
                                tracing::info!("Downloaded: {} ({} bytes)", filename,
                                    std::fs::metadata(&target).map(|m| m.len()).unwrap_or(0));
                            }
                            Err(e) => tracing::warn!("Could not download {}: {}", filename, e),
                        }
                    }
                    // Point refs/main at the "lancor" snapshot so cache readers
                    // resolve the revision populated above.
                    let refs_dir = cache_path.join(format!("models--{}", repo_id.replace('/', "--"))).join("refs");
                    std::fs::create_dir_all(&refs_dir).ok();
                    std::fs::write(refs_dir.join("main"), "lancor").ok();
                }
            }
            let service = EmbeddingService::with_model(model)
                .map_err(|e| AppError::Internal(format!("Failed to init embeddings: {}", e)))?;
            Ok::<_, AppError>(Arc::new(service))
        })
        .await
        .cloned()
}
static VECTOR_STORE: OnceCell<Arc<AresVectorStore>> = OnceCell::const_new();
async fn get_vector_store(vector_path: &str) -> Result<Arc<AresVectorStore>> {
VECTOR_STORE
.get_or_try_init(|| async {
let store = AresVectorStore::new(Some(vector_path.to_string())).await?;
Ok::<_, AppError>(Arc::new(store))
})
.await
.cloned()
}
#[utoipa::path(
post,
path = "/api/rag/ingest",
request_body = RagIngestRequest,
responses(
(status = 200, description = "Document ingested successfully", body = RagIngestResponse),
(status = 400, description = "Invalid request"),
(status = 401, description = "Unauthorized"),
(status = 500, description = "Internal server error")
),
tag = "rag",
security(("bearer" = []))
)]
/// `POST /api/rag/ingest` — chunk a document, embed each chunk, and upsert
/// the chunks into the caller's user-scoped collection (created on demand).
///
/// # Errors
/// - `InvalidInput` when the collection name or content is empty, the
///   chunking strategy string is invalid, or the content is too small to chunk.
/// - `Internal` (or backend errors) when embedding or storage fails, or when
///   the embedding backend returns a vector count that does not match the
///   chunk count.
pub async fn ingest(
    State(state): State<AppState>,
    AuthUser(claims): AuthUser,
    Json(payload): Json<RagIngestRequest>,
) -> Result<Json<RagIngestResponse>> {
    let start = Instant::now();
    if payload.collection.is_empty() {
        return Err(AppError::InvalidInput("Collection name required".into()));
    }
    if payload.content.is_empty() {
        return Err(AppError::InvalidInput("Content required".into()));
    }
    // Collections are namespaced per user so tenants cannot touch each other's data.
    let scoped_collection = user_scoped_collection(&claims.sub, &payload.collection);
    let embedding_service = get_embedding_service().await?;
    let vector_path = &state.config_manager.config().rag.vector_path;
    let vector_store = get_vector_store(vector_path).await?;
    // Chunking strategy is optional in the request; fall back to the default.
    let strategy: ChunkingStrategy = payload
        .chunking_strategy
        .as_ref()
        .map(|s| s.parse())
        .transpose()?
        .unwrap_or_default();
    let chunker = match strategy {
        ChunkingStrategy::Word => TextChunker::with_word_chunking(200, 50),
        ChunkingStrategy::Semantic => TextChunker::with_semantic_chunking(500),
        ChunkingStrategy::Character => TextChunker::with_character_chunking(500, 100),
    };
    let chunks = chunker.chunk_with_metadata(&payload.content);
    if chunks.is_empty() {
        return Err(AppError::InvalidInput("Content too small to chunk".into()));
    }
    let dimensions = embedding_service.dimensions();
    if !vector_store.collection_exists(&scoped_collection).await? {
        vector_store
            .create_collection(&scoped_collection, dimensions)
            .await?;
    }
    let chunk_texts: Vec<String> = chunks.iter().map(|c| c.content.clone()).collect();
    let embeddings = embedding_service.embed_texts(&chunk_texts).await?;
    // Guard against a backend returning a different number of vectors than
    // inputs: zip() below would otherwise silently drop the excess chunks or
    // embeddings with no signal to the caller.
    if embeddings.len() != chunks.len() {
        return Err(AppError::Internal(format!(
            "Embedding count mismatch: {} chunks but {} embeddings",
            chunks.len(),
            embeddings.len()
        )));
    }
    // All chunks of one ingest share a base UUID; the chunk index is appended
    // to make each document id unique.
    let base_id = Uuid::new_v4().to_string();
    let mut documents = Vec::with_capacity(chunks.len());
    let mut document_ids = Vec::with_capacity(chunks.len());
    for (i, (chunk, embedding)) in chunks.iter().zip(embeddings.into_iter()).enumerate() {
        let doc_id = format!("{}_{}", base_id, i);
        document_ids.push(doc_id.clone());
        documents.push(Document {
            id: doc_id,
            content: chunk.content.clone(),
            metadata: DocumentMetadata {
                title: payload.title.clone().unwrap_or_default(),
                source: payload.source.clone().unwrap_or_default(),
                created_at: Utc::now(),
                tags: payload.tags.clone(),
            },
            embedding: Some(embedding),
        });
    }
    let count = vector_store.upsert(&scoped_collection, &documents).await?;
    tracing::info!(
        user_id = %claims.sub,
        collection = %payload.collection,
        scoped_collection = %scoped_collection,
        chunks = count,
        duration_ms = start.elapsed().as_millis() as u64,
        "Document ingested"
    );
    Ok(Json(RagIngestResponse {
        chunks_created: count,
        document_ids,
        collection: payload.collection,
    }))
}
#[utoipa::path(
post,
path = "/api/rag/search",
request_body = RagSearchRequest,
responses(
(status = 200, description = "Search completed", body = RagSearchResponse),
(status = 400, description = "Invalid request"),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Collection not found"),
(status = 500, description = "Internal server error")
),
tag = "rag",
security(("bearer" = []))
)]
/// `POST /api/rag/search` — search the caller's user-scoped collection.
///
/// Flow: embed the query, over-fetch semantic candidates from the vector
/// store, optionally re-score them with a lexical strategy (BM25 / fuzzy /
/// hybrid), and optionally rerank the final list with a cross-encoder model.
///
/// # Errors
/// - `InvalidInput` when the collection name or query is empty, or a
///   strategy / reranker model string fails to parse.
/// - `NotFound` when the scoped collection does not exist.
/// - `Internal` when embedding, storage, or reranking fails.
pub async fn search(
    State(state): State<AppState>,
    AuthUser(claims): AuthUser,
    Json(payload): Json<RagSearchRequest>,
) -> Result<Json<RagSearchResponse>> {
    let start = Instant::now();
    if payload.collection.is_empty() {
        return Err(AppError::InvalidInput("Collection name required".into()));
    }
    if payload.query.is_empty() {
        return Err(AppError::InvalidInput("Query required".into()));
    }
    let scoped_collection = user_scoped_collection(&claims.sub, &payload.collection);
    let embedding_service = get_embedding_service().await?;
    let vector_path = &state.config_manager.config().rag.vector_path;
    let vector_store = get_vector_store(vector_path).await?;
    if !vector_store.collection_exists(&scoped_collection).await? {
        return Err(AppError::NotFound(format!(
            "Collection '{}' not found",
            payload.collection
        )));
    }
    let strategy: SearchStrategy = payload
        .strategy
        .as_ref()
        .map(|s| s.parse())
        .transpose()?
        .unwrap_or(SearchStrategy::Semantic);
    let query_embedding = embedding_service.embed_text(&payload.query).await?;
    // Over-fetch (2x limit) so the lexical strategies and reranker have a
    // candidate pool to re-score. saturating_mul avoids the overflow panic
    // (debug builds) that `limit * 2` would hit on an absurdly large limit.
    let vector_results = vector_store
        .search(
            &scoped_collection,
            &query_embedding,
            payload.limit.saturating_mul(2),
            payload.threshold,
        )
        .await?;
    let mut results: Vec<RagSearchResult> = match strategy {
        SearchStrategy::Semantic => vector_results
            .iter()
            .take(payload.limit)
            .map(|r| RagSearchResult {
                id: r.document.id.clone(),
                content: r.document.content.clone(),
                score: r.score,
                metadata: r.document.metadata.clone(),
            })
            .collect(),
        SearchStrategy::Bm25 | SearchStrategy::Fuzzy | SearchStrategy::Hybrid => {
            // Lexical strategies operate on the semantic candidate pool,
            // indexed into an in-memory search engine.
            let mut search_engine = SearchEngine::new();
            for r in &vector_results {
                search_engine.index_document(&r.document);
            }
            let strategy_results = match strategy {
                SearchStrategy::Bm25 => search_engine.search_bm25(&payload.query, payload.limit),
                SearchStrategy::Fuzzy => search_engine.search_fuzzy(&payload.query, payload.limit),
                SearchStrategy::Hybrid => {
                    let semantic_scores: Vec<_> = vector_results
                        .iter()
                        .map(|r| (r.document.id.clone(), r.score))
                        .collect();
                    let weights = HybridWeights::default();
                    search_engine.search_hybrid(
                        &payload.query,
                        &semantic_scores,
                        &weights,
                        payload.limit,
                    )
                }
                // Semantic is fully handled by the outer match arm. The old
                // `_ => vec![]` silently returned no results if this was ever
                // reached; fail loudly instead, since reaching here is a bug.
                SearchStrategy::Semantic => unreachable!("semantic handled in outer match"),
            };
            // Map (id, score) pairs back to full documents from the pool.
            strategy_results
                .iter()
                .filter_map(|(id, score)| {
                    vector_results
                        .iter()
                        .find(|r| r.document.id == *id)
                        .map(|r| RagSearchResult {
                            id: r.document.id.clone(),
                            content: r.document.content.clone(),
                            score: *score,
                            metadata: r.document.metadata.clone(),
                        })
                })
                .collect()
        }
    };
    // Optional cross-encoder reranking pass over the strategy results.
    let reranked = if payload.rerank && !results.is_empty() {
        let model_type: RerankerModelType = payload
            .reranker_model
            .as_ref()
            .map(|s| s.parse())
            .transpose()?
            .unwrap_or_default();
        let config = RerankerConfig {
            model: model_type,
            ..Default::default()
        };
        let reranker = Reranker::new(config);
        let rerank_input: Vec<_> = results
            .iter()
            .map(|r| (r.id.clone(), r.content.clone(), r.score))
            .collect();
        let reranked_results = reranker
            .rerank(&payload.query, &rerank_input, Some(payload.limit))
            .await
            .map_err(|e| AppError::Internal(format!("Reranking failed: {}", e)))?;
        // Rebuild the result list in reranked order, carrying over the
        // original content/metadata but replacing the score.
        results = reranked_results
            .into_iter()
            .filter_map(|rr| {
                results
                    .iter()
                    .find(|r| r.id == rr.id)
                    .map(|r| RagSearchResult {
                        id: r.id.clone(),
                        content: r.content.clone(),
                        score: rr.final_score,
                        metadata: r.metadata.clone(),
                    })
            })
            .collect();
        true
    } else {
        false
    };
    let total = results.len();
    let strategy_name = format!("{:?}", strategy).to_lowercase();
    tracing::info!(
        user_id = %claims.sub,
        collection = %payload.collection,
        strategy = %strategy_name,
        results = total,
        reranked = reranked,
        duration_ms = start.elapsed().as_millis() as u64,
        "Search completed"
    );
    Ok(Json(RagSearchResponse {
        results,
        total,
        strategy: strategy_name,
        reranked,
        duration_ms: start.elapsed().as_millis() as u64,
    }))
}
#[utoipa::path(
delete,
path = "/api/rag/collection",
request_body = RagDeleteCollectionRequest,
responses(
(status = 200, description = "Collection deleted", body = RagDeleteCollectionResponse),
(status = 400, description = "Invalid request"),
(status = 401, description = "Unauthorized"),
(status = 404, description = "Collection not found"),
(status = 500, description = "Internal server error")
),
tag = "rag",
security(("bearer" = []))
)]
/// `DELETE /api/rag/collection` — drop one of the caller's collections and
/// report how many documents it contained.
///
/// # Errors
/// - `InvalidInput` when the collection name is empty.
/// - `NotFound` when the user-scoped collection does not exist.
pub async fn delete_collection(
    State(state): State<AppState>,
    AuthUser(claims): AuthUser,
    Json(payload): Json<RagDeleteCollectionRequest>,
) -> Result<Json<RagDeleteCollectionResponse>> {
    if payload.collection.is_empty() {
        return Err(AppError::InvalidInput("Collection name required".into()));
    }
    let scoped = user_scoped_collection(&claims.sub, &payload.collection);
    let vector_path = &state.config_manager.config().rag.vector_path;
    let store = get_vector_store(vector_path).await?;
    if !store.collection_exists(&scoped).await? {
        return Err(AppError::NotFound(format!(
            "Collection '{}' not found",
            payload.collection
        )));
    }
    // Capture the document count before dropping so it can be reported back.
    let documents_deleted = store.collection_stats(&scoped).await?.document_count;
    store.delete_collection(&scoped).await?;
    tracing::info!(
        user_id = %claims.sub,
        collection = %payload.collection,
        documents = documents_deleted,
        "Collection deleted"
    );
    Ok(Json(RagDeleteCollectionResponse {
        success: true,
        collection: payload.collection,
        documents_deleted,
    }))
}
#[utoipa::path(
get,
path = "/api/rag/collections",
responses(
(status = 200, description = "Collections listed", body = Vec<String>),
(status = 401, description = "Unauthorized"),
(status = 500, description = "Internal server error")
),
tag = "rag",
security(("bearer" = []))
)]
/// `GET /api/rag/collections` — list the caller's collections, with the
/// `user_<id>_` scope prefix stripped from each name. Collections belonging
/// to other users are filtered out.
pub async fn list_collections(
    State(state): State<AppState>,
    AuthUser(claims): AuthUser,
) -> Result<Json<Vec<crate::db::CollectionInfo>>> {
    let vector_path = &state.config_manager.config().rag.vector_path;
    let vector_store = get_vector_store(vector_path).await?;
    let mut user_collections = Vec::new();
    for mut info in vector_store.list_collections().await? {
        // Keep only collections scoped to this user, un-prefixing their names.
        if let Some(user_name) = extract_user_collection(&claims.sub, &info.name) {
            info.name = user_name;
            user_collections.push(info);
        }
    }
    Ok(Json(user_collections))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Search strategy strings used by the handlers parse to the right variants.
    #[test]
    fn test_default_search_strategy() {
        let cases = [
            ("semantic", SearchStrategy::Semantic),
            ("bm25", SearchStrategy::Bm25),
            ("hybrid", SearchStrategy::Hybrid),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<SearchStrategy>().unwrap(), expected);
        }
    }

    /// Chunking strategy strings used by the handlers parse to the right variants.
    #[test]
    fn test_default_chunking_strategy() {
        let cases = [
            ("word", ChunkingStrategy::Word),
            ("semantic", ChunkingStrategy::Semantic),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<ChunkingStrategy>().unwrap(), expected);
        }
    }
}