use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
use crate::backend::Backend;
use crate::model::LlamaModel;
use crate::tokenizer::Tokenizer;
use super::{
Document, EmbeddingGenerator, MetadataFilter, NewDocument, RagConfig, RagError, RagResult,
RagStore, SearchType, TextChunker,
};
/// Top-level configuration for a [`KnowledgeBase`].
///
/// Every field except `name` carries a serde default, so a config file may
/// specify only the name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseConfig {
    /// Name identifying this knowledge base.
    pub name: String,
    /// Optional human-readable description.
    #[serde(default)]
    pub description: Option<String>,
    /// Backing vector-store configuration.
    #[serde(default)]
    pub storage: RagConfig,
    /// How ingested documents are split into chunks before embedding.
    #[serde(default)]
    pub chunking: ChunkingStrategy,
    /// Retrieval settings used when a call supplies none of its own.
    #[serde(default)]
    pub retrieval: RetrievalConfig,
    /// Whether hybrid (vector + keyword) search is enabled.
    /// NOTE(review): this flag is never read by `KnowledgeBase::retrieve`,
    /// which decides hybrid vs. semantic from the storage config — confirm.
    #[serde(default)]
    pub hybrid_search: bool,
    /// Optional reranking applied to retrieved chunks.
    #[serde(default)]
    pub reranking: Option<RerankingConfig>,
}
impl Default for KnowledgeBaseConfig {
fn default() -> Self {
Self {
name: "default".into(),
description: None,
storage: RagConfig::default(),
chunking: ChunkingStrategy::default(),
retrieval: RetrievalConfig::default(),
hybrid_search: false,
reranking: None,
}
}
}
/// How ingested text is split into chunks before embedding.
///
/// Token budgets are approximated as `max_tokens * 4` characters during
/// chunking (see `KnowledgeBase::chunk_text`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ChunkingStrategy {
    /// Store each document as a single, unsplit chunk.
    None,
    /// Split into chunks of roughly `max_tokens`, with consecutive chunks
    /// overlapping by `overlap_percentage` percent.
    FixedSize {
        max_tokens: usize,
        overlap_percentage: u8,
    },
    /// Group whole sentences into chunks up to roughly `max_tokens`.
    /// `buffer_size` is currently unused by `chunk_text`.
    Semantic {
        max_tokens: usize,
        buffer_size: usize,
    },
    /// Split into large parent chunks, then split each parent into
    /// overlapping child chunks; only the children are stored.
    Hierarchical {
        parent_max_tokens: usize,
        child_max_tokens: usize,
        child_overlap_percentage: u8,
    },
}
impl Default for ChunkingStrategy {
fn default() -> Self {
Self::FixedSize {
max_tokens: 300,
overlap_percentage: 20,
}
}
}
/// Settings controlling a single retrieval request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalConfig {
    /// Maximum number of chunks to return (default 5).
    #[serde(default = "default_max_results")]
    pub max_results: usize,
    /// Minimum similarity score; lower-scoring chunks are dropped (default 0.5).
    #[serde(default = "default_min_score")]
    pub min_score: f32,
    /// Requested search mode.
    /// NOTE(review): `KnowledgeBase::retrieve` currently ignores this field
    /// and decides hybrid vs. semantic from the storage config — confirm intent.
    #[serde(default)]
    pub search_type: SearchType,
    /// Optional metadata filter; runtime-only, never (de)serialized.
    #[serde(skip)]
    pub filter: Option<MetadataFilter>,
    /// Optional prompt template using `{context}` and `{query}`/`{question}`
    /// placeholders (see `KnowledgeBase::retrieve_and_generate`).
    #[serde(default)]
    pub prompt_template: Option<String>,
}
/// Serde default for [`RetrievalConfig::max_results`].
fn default_max_results() -> usize {
    5
}
/// Serde default for [`RetrievalConfig::min_score`].
fn default_min_score() -> f32 {
    0.5
}
impl Default for RetrievalConfig {
fn default() -> Self {
Self {
max_results: 5,
min_score: 0.5,
search_type: SearchType::Semantic,
filter: None,
prompt_template: None,
}
}
}
/// Configuration for reranking retrieved chunks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RerankingConfig {
    /// Number of candidates to consider for reranking.
    /// NOTE(review): not currently read by `rerank` — confirm intended use.
    pub num_candidates: usize,
    /// Reranking algorithm to apply.
    pub method: RerankingMethod,
}
/// Algorithm used to reorder retrieved chunks.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum RerankingMethod {
    /// Sort by retrieval score, descending.
    ScoreBased,
    /// Rerank with a cross-encoder model. Not yet implemented; `rerank`
    /// logs a warning and falls back to a score-based sort.
    CrossEncoder { model_path: String },
    /// Reciprocal Rank Fusion with constant `k`. `rerank` currently
    /// performs a plain score sort; `k` is unused.
    RRF { k: usize },
}
/// A source of documents for ingestion.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DataSource {
    /// A single file read as UTF-8 text.
    File {
        path: PathBuf,
    },
    /// A directory of files, optionally filtered by a glob `pattern`,
    /// walked recursively by default.
    Directory {
        path: PathBuf,
        pattern: Option<String>,
        #[serde(default = "default_true")]
        recursive: bool,
    },
    /// Raw text supplied inline, identified by a caller-chosen `source_id`.
    Text {
        content: String,
        source_id: String,
        metadata: Option<serde_json::Value>,
    },
    /// A web URL. Ingestion is not yet implemented; the URL is recorded
    /// as a failure. `depth` is currently unused.
    Url {
        url: String,
        #[serde(default)]
        depth: usize,
    },
    /// An object-storage location. Ingestion is not yet implemented;
    /// the bucket/prefix pair is recorded as a failure.
    ObjectStorage {
        bucket: String,
        prefix: String,
        endpoint: Option<String>,
    },
}
/// Serde default for `DataSource::Directory::recursive`.
fn default_true() -> bool {
    true
}
/// Summary of a single `KnowledgeBase::ingest` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestionResult {
    /// Number of documents successfully processed.
    pub documents_processed: usize,
    /// Total number of chunks written to the store.
    pub chunks_created: usize,
    /// Per-source failures: source identifier -> error message.
    pub failures: HashMap<String, String>,
    /// Metadata describing this ingestion run.
    pub metadata: IngestionMetadata,
}
/// Metadata recorded for an ingestion run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IngestionMetadata {
    /// Identifier of the ingested source (path, URL, or caller-supplied id).
    pub source_id: String,
    /// Ingestion start time as seconds since the Unix epoch, e.g. "1700000000s"
    /// (see `chrono_now`).
    pub timestamp: String,
    /// Debug rendering of the chunking strategy in effect.
    pub chunking_strategy: String,
    /// Total characters across all ingested documents.
    pub total_characters: usize,
}
/// A chunk returned from retrieval, with its similarity score and origin.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievedChunk {
    /// Chunk text.
    pub content: String,
    /// Similarity score; 0.0 when the store returned no score.
    pub score: f32,
    /// Where this chunk came from.
    pub source: SourceLocation,
    /// Metadata stored alongside the chunk, if any.
    pub metadata: Option<serde_json::Value>,
}
/// Identifies the origin of a retrieved chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SourceLocation {
    /// Kind of source; `doc_to_chunk` currently always emits "document".
    pub source_type: String,
    /// URI or identifier of the source ("unknown" if no metadata was stored).
    pub uri: String,
    /// Character span within the source, when known.
    pub location: Option<TextLocation>,
}
/// A span of text identified by start/end character offsets.
/// NOTE(review): whether `end` is inclusive is not established by this
/// file — confirm against the producer before relying on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TextLocation {
    pub start: usize,
    pub end: usize,
}
/// Links a span of generated text back to a supporting source chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Citation {
    /// Span of the generated output this citation supports.
    /// `retrieve_and_generate` currently always sets this to `None`.
    pub generated_text_span: Option<TextLocation>,
    /// Origin of the cited chunk.
    pub source: SourceLocation,
    /// Retrieval score of the cited chunk.
    pub score: f32,
    /// Text of the cited chunk.
    pub content: String,
}
/// Result of a `KnowledgeBase::retrieve` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResponse {
    /// Retrieved chunks, score-filtered and optionally reranked.
    pub chunks: Vec<RetrievedChunk>,
    /// The original query text.
    pub query: String,
    /// Pagination token; currently always `None`.
    pub next_token: Option<String>,
}
/// Result of a `KnowledgeBase::retrieve_and_generate` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrieveAndGenerateResponse {
    /// Output text. NOTE(review): `retrieve_and_generate` currently places
    /// the assembled prompt here, not model-generated text — confirm.
    pub output: String,
    /// Citations linking the output to retrieved chunks.
    pub citations: Vec<Citation>,
    /// The chunks used to build the context.
    pub retrieved_chunks: Vec<RetrievedChunk>,
    /// Guardrail intervention, if any; currently always `None`.
    pub guardrail_action: Option<String>,
}
/// Rerank `chunks` according to `config.method`.
///
/// Every currently-supported method reduces to a descending sort by the
/// existing retrieval score: `ScoreBased` and `RRF` sort directly, and
/// `CrossEncoder` logs a warning and falls back to the same sort until
/// model-based scoring is implemented.
pub fn rerank(mut chunks: Vec<RetrievedChunk>, config: &RerankingConfig) -> Vec<RetrievedChunk> {
    if let RerankingMethod::CrossEncoder { model_path } = &config.method {
        tracing::warn!(
            "CrossEncoder reranking with model '{}' is not yet implemented; falling back to score-based sort",
            model_path
        );
    }
    // All three arms previously duplicated this sort; do it once. `total_cmp`
    // is a total order, so NaN scores get a deterministic position instead of
    // comparing "equal to everything" as with partial_cmp + unwrap_or(Equal).
    chunks.sort_by(|a, b| b.score.total_cmp(&a.score));
    chunks
}
/// A RAG knowledge base: a vector store plus chunking/retrieval policy
/// and an optional embedding generator.
pub struct KnowledgeBase {
    // Configuration the knowledge base was created with.
    config: KnowledgeBaseConfig,
    // Backing vector store.
    store: RagStore,
    // Embedding generator; when absent, zero vectors are used (see embed_text).
    embedding_gen: Option<EmbeddingGenerator>,
}
impl KnowledgeBase {
pub async fn create(config: KnowledgeBaseConfig) -> RagResult<Self> {
let store = RagStore::connect(config.storage.clone()).await?;
store.create_table().await?;
Ok(Self {
config,
store,
embedding_gen: None,
})
}
pub async fn connect(config: KnowledgeBaseConfig) -> RagResult<Self> {
let store = RagStore::connect(config.storage.clone()).await?;
Ok(Self {
config,
store,
embedding_gen: None,
})
}
pub fn with_embedding_generator(mut self, emb_gen: EmbeddingGenerator) -> Self {
self.embedding_gen = Some(emb_gen);
self
}
pub fn name(&self) -> &str {
&self.config.name
}
pub fn config(&self) -> &KnowledgeBaseConfig {
&self.config
}
pub async fn ingest(&self, source: DataSource) -> RagResult<IngestionResult> {
let mut result = IngestionResult {
documents_processed: 0,
chunks_created: 0,
failures: HashMap::new(),
metadata: IngestionMetadata {
source_id: self.source_id(&source),
timestamp: chrono_now(),
chunking_strategy: format!("{:?}", self.config.chunking),
total_characters: 0,
},
};
match source {
DataSource::File { path } => {
self.ingest_file(&path, &mut result).await?;
}
DataSource::Directory {
path,
pattern,
recursive,
} => {
self.ingest_directory(&path, pattern.as_deref(), recursive, &mut result)
.await?;
}
DataSource::Text {
content,
source_id,
metadata,
} => {
self.ingest_text(&content, &source_id, metadata, &mut result)
.await?;
}
DataSource::Url { url, depth: _ } => {
result.failures.insert(
url,
"URL ingestion not yet implemented".into(),
);
}
DataSource::ObjectStorage {
bucket,
prefix,
endpoint: _,
} => {
result.failures.insert(
format!("{}:{}", bucket, prefix),
"Object storage ingestion not yet implemented".into(),
);
}
}
Ok(result)
}
pub async fn retrieve(
&self,
query: &str,
config: Option<RetrievalConfig>,
) -> RagResult<RetrievalResponse> {
let config = config.unwrap_or_else(|| self.config.retrieval.clone());
let query_embedding = self.embed_query(query)?;
let docs = if self.config.storage.search_type() == SearchType::Hybrid {
self.store
.search_hybrid(
&query_embedding,
query,
Some(config.max_results),
config.filter,
)
.await?
} else {
self.store
.search_with_filter(&query_embedding, Some(config.max_results), config.filter)
.await?
};
let mut chunks: Vec<RetrievedChunk> = docs
.into_iter()
.filter(|d| d.score.unwrap_or(0.0) >= config.min_score)
.map(|d| self.doc_to_chunk(d))
.collect();
if let Some(reranking_config) = &self.config.reranking {
chunks = rerank(chunks, reranking_config);
}
Ok(RetrievalResponse {
chunks,
query: query.to_string(),
next_token: None,
})
}
pub async fn retrieve_and_generate(
&self,
query: &str,
config: Option<RetrievalConfig>,
) -> RagResult<RetrieveAndGenerateResponse> {
let config = config.unwrap_or_else(|| self.config.retrieval.clone());
let retrieval = self.retrieve(query, Some(config.clone())).await?;
let context = self.build_context(&retrieval.chunks);
let prompt = if let Some(template) = &config.prompt_template {
template
.replace("{context}", &context)
.replace("{query}", query)
.replace("{question}", query)
} else {
self.default_prompt(&context, query)
};
let citations: Vec<Citation> = retrieval
.chunks
.iter()
.map(|chunk| Citation {
generated_text_span: None, source: chunk.source.clone(),
score: chunk.score,
content: chunk.content.clone(),
})
.collect();
Ok(RetrieveAndGenerateResponse {
output: prompt,
citations,
retrieved_chunks: retrieval.chunks,
guardrail_action: None,
})
}
pub async fn sync(&self) -> RagResult<()> {
Ok(())
}
pub async fn delete(&self) -> RagResult<()> {
self.store.clear().await?;
Ok(())
}
pub async fn stats(&self) -> RagResult<KnowledgeBaseStats> {
let document_count = self.store.count().await? as usize;
Ok(KnowledgeBaseStats {
name: self.config.name.clone(),
document_count,
embedding_dimension: self.config.storage.embedding_dim(),
chunking_strategy: format!("{:?}", self.config.chunking),
hybrid_search_enabled: self.config.hybrid_search,
})
}
fn source_id(&self, source: &DataSource) -> String {
match source {
DataSource::File { path } => path.to_string_lossy().to_string(),
DataSource::Directory { path, .. } => path.to_string_lossy().to_string(),
DataSource::Text { source_id, .. } => source_id.clone(),
DataSource::Url { url, .. } => url.clone(),
DataSource::ObjectStorage { bucket, prefix, .. } => {
format!("s3://{}/{}", bucket, prefix)
}
}
}
async fn ingest_file(
&self,
path: &std::path::Path,
result: &mut IngestionResult,
) -> RagResult<()> {
match std::fs::read_to_string(path) {
Ok(content) => {
let source_id = path.to_string_lossy().to_string();
let metadata = serde_json::json!({
"source": source_id,
"source_type": "file",
"filename": path.file_name().map(|n| n.to_string_lossy().to_string()),
});
self.ingest_text(&content, &source_id, Some(metadata), result)
.await?;
result.documents_processed += 1;
}
Err(e) => {
result
.failures
.insert(path.to_string_lossy().to_string(), e.to_string());
}
}
Ok(())
}
async fn ingest_directory(
&self,
path: &std::path::Path,
pattern: Option<&str>,
recursive: bool,
result: &mut IngestionResult,
) -> RagResult<()> {
let entries = if recursive {
self.walk_directory_recursive(path, pattern)?
} else {
self.walk_directory_flat(path, pattern)?
};
for entry in entries {
self.ingest_file(&entry, result).await?;
}
Ok(())
}
fn walk_directory_recursive(
&self,
path: &std::path::Path,
pattern: Option<&str>,
) -> RagResult<Vec<PathBuf>> {
let mut files = Vec::new();
fn visit_dir(dir: &std::path::Path, files: &mut Vec<PathBuf>) -> std::io::Result<()> {
for entry in std::fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
visit_dir(&path, files)?;
} else if path.is_file() {
files.push(path);
}
}
Ok(())
}
visit_dir(path, &mut files)
.map_err(|e| RagError::ConfigError(format!("Failed to read directory: {}", e)))?;
if let Some(pattern) = pattern {
files.retain(|f| {
let path_str = f.to_string_lossy();
matches_glob_pattern(&path_str, pattern)
});
}
Ok(files)
}
fn walk_directory_flat(
&self,
path: &std::path::Path,
pattern: Option<&str>,
) -> RagResult<Vec<PathBuf>> {
let mut files = Vec::new();
let entries = std::fs::read_dir(path)
.map_err(|e| RagError::ConfigError(format!("Failed to read directory: {}", e)))?;
for entry in entries.flatten() {
let path = entry.path();
if path.is_file() {
files.push(path);
}
}
if let Some(pattern) = pattern {
files.retain(|f| {
let path_str = f.to_string_lossy();
matches_glob_pattern(&path_str, pattern)
});
}
Ok(files)
}
async fn ingest_text(
&self,
content: &str,
source_id: &str,
metadata: Option<serde_json::Value>,
result: &mut IngestionResult,
) -> RagResult<()> {
result.metadata.total_characters += content.len();
let chunks = self.chunk_text(content);
for (i, chunk_text) in chunks.iter().enumerate() {
let chunk_metadata = serde_json::json!({
"source": source_id,
"chunk_index": i,
"total_chunks": chunks.len(),
"parent": metadata.clone(),
});
let embedding = self.embed_text(chunk_text)?;
let doc = NewDocument {
content: chunk_text.clone(),
embedding,
metadata: Some(chunk_metadata),
};
self.store.insert(doc).await?;
result.chunks_created += 1;
}
Ok(())
}
fn chunk_text(&self, text: &str) -> Vec<String> {
match &self.config.chunking {
ChunkingStrategy::None => vec![text.to_string()],
ChunkingStrategy::FixedSize {
max_tokens,
overlap_percentage,
} => {
let char_size = max_tokens * 4; let overlap = (char_size * *overlap_percentage as usize) / 100;
let chunker = TextChunker::new(char_size).with_overlap(overlap);
chunker.chunk(text)
}
ChunkingStrategy::Semantic {
max_tokens,
buffer_size: _,
} => {
let char_size = max_tokens * 4;
let sentences: Vec<&str> = text
.split(['.', '!', '?'])
.filter(|s| !s.trim().is_empty())
.collect();
let mut chunks = Vec::new();
let mut current_chunk = String::new();
for sentence in sentences {
let sentence = sentence.trim().to_string() + ".";
if current_chunk.len() + sentence.len() > char_size {
if !current_chunk.is_empty() {
chunks.push(current_chunk.trim().to_string());
}
current_chunk = sentence;
} else {
if !current_chunk.is_empty() {
current_chunk.push(' ');
}
current_chunk.push_str(&sentence);
}
}
if !current_chunk.is_empty() {
chunks.push(current_chunk.trim().to_string());
}
chunks
}
ChunkingStrategy::Hierarchical {
parent_max_tokens,
child_max_tokens,
child_overlap_percentage,
} => {
let parent_char_size = parent_max_tokens * 4;
let child_char_size = child_max_tokens * 4;
let child_overlap = (child_char_size * *child_overlap_percentage as usize) / 100;
let parent_chunker = TextChunker::new(parent_char_size);
let child_chunker = TextChunker::new(child_char_size).with_overlap(child_overlap);
let parents = parent_chunker.chunk(text);
let mut all_chunks = Vec::new();
for parent in parents {
let children = child_chunker.chunk(&parent);
all_chunks.extend(children);
}
all_chunks
}
}
}
fn embed_text(&self, text: &str) -> RagResult<Vec<f32>> {
if let Some(emb) = &self.embedding_gen {
emb.embed(text)
} else {
Ok(vec![0.0f32; self.config.storage.embedding_dim()])
}
}
fn embed_query(&self, query: &str) -> RagResult<Vec<f32>> {
if let Some(emb) = &self.embedding_gen {
emb.embed(query)
} else {
Ok(vec![0.0f32; self.config.storage.embedding_dim()])
}
}
fn doc_to_chunk(&self, doc: Document) -> RetrievedChunk {
let source_uri = doc
.metadata
.as_ref()
.and_then(|m| m.get("source"))
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_string();
RetrievedChunk {
content: doc.content,
score: doc.score.unwrap_or(0.0),
source: SourceLocation {
source_type: "document".into(),
uri: source_uri,
location: None,
},
metadata: doc.metadata,
}
}
fn build_context(&self, chunks: &[RetrievedChunk]) -> String {
chunks
.iter()
.enumerate()
.map(|(i, c)| format!("[{}] {}", i + 1, c.content))
.collect::<Vec<_>>()
.join("\n\n")
}
fn default_prompt(&self, context: &str, query: &str) -> String {
format!(
r#"Use the following pieces of context to answer the question at the end. If you don't know the answer, just say that you don't know, don't try to make up an answer.
Context:
{context}
Question: {query}
Helpful Answer:"#
)
}
}
/// Summary statistics for a knowledge base (see `KnowledgeBase::stats`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseStats {
    /// Knowledge base name.
    pub name: String,
    /// Number of documents (chunks) in the store.
    pub document_count: usize,
    /// Dimensionality of stored embeddings.
    pub embedding_dimension: usize,
    /// Debug rendering of the chunking strategy.
    pub chunking_strategy: String,
    /// Value of the `hybrid_search` config flag.
    pub hybrid_search_enabled: bool,
}
/// Current time rendered as whole seconds since the Unix epoch, e.g. "1700000000s".
/// A clock set before the epoch yields "0s" rather than an error.
fn chrono_now() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let secs = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs())
        .unwrap_or(0);
    format!("{secs}s")
}
/// True when `path` matches the glob `pattern`; an invalid pattern never matches.
fn matches_glob_pattern(path: &str, pattern: &str) -> bool {
    match glob::Pattern::new(pattern) {
        Ok(compiled) => compiled.matches(path),
        Err(_) => false,
    }
}
/// Fluent builder for a [`KnowledgeBaseConfig`] (via `build`) or a live
/// [`KnowledgeBase`] (via `create`).
pub struct KnowledgeBaseBuilder {
    // Configuration accumulated so far.
    config: KnowledgeBaseConfig,
    // Optional model used to construct an EmbeddingGenerator on create().
    model: Option<Arc<LlamaModel>>,
    // Tokenizer paired with `model`.
    tokenizer: Option<Arc<Tokenizer>>,
    // Backend paired with `model`.
    backend: Option<Arc<dyn Backend>>,
}
impl KnowledgeBaseBuilder {
pub fn new(name: impl Into<String>) -> Self {
Self {
config: KnowledgeBaseConfig {
name: name.into(),
..Default::default()
},
model: None,
tokenizer: None,
backend: None,
}
}
pub fn with_model(
mut self,
model: Arc<LlamaModel>,
tokenizer: Arc<Tokenizer>,
backend: Arc<dyn Backend>,
) -> Self {
self.model = Some(model);
self.tokenizer = Some(tokenizer);
self.backend = Some(backend);
self
}
pub fn description(mut self, desc: impl Into<String>) -> Self {
self.config.description = Some(desc.into());
self
}
pub fn storage(mut self, storage: RagConfig) -> Self {
self.config.storage = storage;
self
}
pub fn chunking(mut self, strategy: ChunkingStrategy) -> Self {
self.config.chunking = strategy;
self
}
pub fn fixed_size_chunking(mut self, max_tokens: usize, overlap_pct: u8) -> Self {
self.config.chunking = ChunkingStrategy::FixedSize {
max_tokens,
overlap_percentage: overlap_pct.min(50),
};
self
}
pub fn semantic_chunking(mut self, max_tokens: usize) -> Self {
self.config.chunking = ChunkingStrategy::Semantic {
max_tokens,
buffer_size: 100,
};
self
}
pub fn hierarchical_chunking(
mut self,
parent_tokens: usize,
child_tokens: usize,
overlap_pct: u8,
) -> Self {
self.config.chunking = ChunkingStrategy::Hierarchical {
parent_max_tokens: parent_tokens,
child_max_tokens: child_tokens,
child_overlap_percentage: overlap_pct.min(50),
};
self
}
pub fn retrieval(mut self, retrieval: RetrievalConfig) -> Self {
self.config.retrieval = retrieval;
self
}
pub fn max_results(mut self, max: usize) -> Self {
self.config.retrieval.max_results = max;
self
}
pub fn min_score(mut self, min: f32) -> Self {
self.config.retrieval.min_score = min.clamp(0.0, 1.0);
self
}
pub fn hybrid_search(mut self, enabled: bool) -> Self {
self.config.hybrid_search = enabled;
self
}
pub fn reranking(mut self, config: RerankingConfig) -> Self {
self.config.reranking = Some(config);
self
}
pub fn build(self) -> KnowledgeBaseConfig {
self.config
}
pub async fn create(self) -> RagResult<KnowledgeBase> {
let mut kb = KnowledgeBase::create(self.config).await?;
if let (Some(model), Some(tokenizer), Some(backend)) =
(self.model, self.tokenizer, self.backend)
{
kb.embedding_gen = Some(EmbeddingGenerator::new(model, tokenizer, backend));
}
Ok(kb)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal chunk with the given content, score, and source URI.
    fn chunk(content: &str, score: f32, uri: &str) -> RetrievedChunk {
        RetrievedChunk {
            content: content.into(),
            score,
            source: SourceLocation {
                source_type: "document".into(),
                uri: uri.into(),
                location: None,
            },
            metadata: None,
        }
    }

    #[test]
    fn test_rerank_score_based() {
        // Intentionally out of order: the low-scoring chunk comes first.
        let chunks = vec![
            chunk("low score chunk", 0.3, "low.txt"),
            chunk("high score chunk", 0.9, "high.txt"),
        ];
        let config = RerankingConfig {
            num_candidates: 10,
            method: RerankingMethod::ScoreBased,
        };
        let result = rerank(chunks, &config);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].content, "high score chunk");
        assert_eq!(result[1].content, "low score chunk");
        assert!(result[0].score > result[1].score);
    }

    #[test]
    fn test_glob_pattern_matching() {
        // (path, pattern, expected match)
        let cases = [
            ("docs/readme.md", "**/*.md", true),
            ("src/lib.rs", "**/*.rs", true),
            ("image.png", "**/*.md", false),
            ("a.md", "**/*.md", true),
            ("docs/readme.txt", "**/*.md", false),
        ];
        for (path, pattern, expected) in cases {
            assert_eq!(
                matches_glob_pattern(path, pattern),
                expected,
                "{path} vs {pattern}"
            );
        }
    }
}