use super::{KnowledgeEntry, SearchOptions, SearchResult};
use crate::embedding::EmbeddingEngine;
use crate::error::{Error, Result};
use crate::learning::LearningEngine;
use crate::storage::StorageBackend;
use dashmap::DashMap;
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::path::Path;
use std::sync::Arc;
use tracing::{debug, info, instrument};
use uuid::Uuid;
/// Tunable settings for a [`KnowledgeBase`].
///
/// The `Default` implementation supplies a value for every field; the `with_*` /
/// `without_*` builder methods override individual ones.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseConfig {
/// Dimensionality handed to the embedding engine and the learning engine.
pub dimensions: usize,
/// Location handed to the storage backend when the knowledge base is opened.
pub storage_path: String,
/// Whether a learning engine is created alongside the knowledge base.
pub learning_enabled: bool,
/// Rate passed to the learning engine's constructor.
pub learning_rate: f32,
/// NOTE(review): presumably the HNSW `M` parameter; not read anywhere in this file — confirm where it is consumed.
pub hnsw_m: usize,
/// NOTE(review): presumably the HNSW `ef_construction` parameter; not read in this file — confirm.
pub hnsw_ef_construction: usize,
/// NOTE(review): presumably the HNSW `ef_search` parameter; not read in this file — confirm.
pub hnsw_ef_search: usize,
/// Number of entries embedded and saved per chunk by `KnowledgeBase::add_entries`.
pub batch_size: usize,
}
impl Default for KnowledgeBaseConfig {
    /// Returns the stock configuration: 384 dimensions, storage at `./knowledge.db`,
    /// learning enabled with a rate of `0.01`, HNSW values of 16 / 200 / 100 and a
    /// batch size of 1000.
    fn default() -> Self {
        let storage_path = String::from("./knowledge.db");
        Self {
            storage_path,
            dimensions: 384,
            batch_size: 1000,
            learning_enabled: true,
            learning_rate: 0.01,
            hnsw_m: 16,
            hnsw_ef_construction: 200,
            hnsw_ef_search: 100,
        }
    }
}
impl KnowledgeBaseConfig {
    /// Returns this configuration with `storage_path` replaced by `path`.
    pub fn with_path(self, path: impl Into<String>) -> Self {
        Self {
            storage_path: path.into(),
            ..self
        }
    }

    /// Returns this configuration with `dimensions` replaced by `dims`.
    pub fn with_dimensions(self, dims: usize) -> Self {
        Self {
            dimensions: dims,
            ..self
        }
    }

    /// Returns this configuration with `learning_enabled` switched off.
    pub fn without_learning(self) -> Self {
        Self {
            learning_enabled: false,
            ..self
        }
    }
}
/// A store of [`KnowledgeEntry`] values that keeps every entry and its embedding
/// vector in memory and mirrors changes to a [`StorageBackend`].
pub struct KnowledgeBase {
/// Settings this instance was constructed with.
config: KnowledgeBaseConfig,
/// Backend that entries are loaded from at start-up and saved to / deleted from afterwards.
storage: Arc<StorageBackend>,
/// Produces the embedding vector for entry text and for search queries.
embeddings: Arc<EmbeddingEngine>,
/// Learning engine; `None` when `config.learning_enabled` was false at construction.
learning: Option<Arc<RwLock<LearningEngine>>>,
/// In-memory copy of every entry, keyed by entry id.
entries: DashMap<Uuid, KnowledgeEntry>,
/// Embedding vector of every entry, keyed by entry id.
vectors: DashMap<Uuid, Vec<f32>>,
/// Entry count reported by `len`; tracked separately from `entries`.
count: Arc<RwLock<usize>>,
}
impl KnowledgeBase {
/// Opens a knowledge base whose storage lives at `path`, using the default
/// configuration for every other setting.
///
/// The path is converted with `to_string_lossy`, so non-UTF-8 sequences in it are
/// replaced rather than preserved.
///
/// # Errors
/// Returns whatever [`KnowledgeBase::with_config`] returns.
#[instrument(skip_all)]
pub async fn open(path: impl AsRef<Path>) -> Result<Self> {
    let location = path.as_ref().to_string_lossy();
    let settings = KnowledgeBaseConfig::default().with_path(location);
    Self::with_config(settings).await
}
/// Builds a knowledge base from an explicit configuration.
///
/// Opens the storage backend at `config.storage_path`, creates the embedding engine
/// (and the learning engine when `config.learning_enabled` is set), then loads every
/// stored entry into memory.
///
/// # Errors
/// Fails if the storage backend cannot be opened or the stored entries cannot be loaded.
#[instrument(skip_all, fields(path = %config.storage_path))]
pub async fn with_config(config: KnowledgeBaseConfig) -> Result<Self> {
    info!("Initializing knowledge base at {}", config.storage_path);

    let backend = StorageBackend::open(&config.storage_path).await?;
    let embedder = EmbeddingEngine::new(config.dimensions);
    let learning = config.learning_enabled.then(|| {
        let engine = LearningEngine::new(config.dimensions, config.learning_rate);
        Arc::new(RwLock::new(engine))
    });

    let kb = Self {
        storage: Arc::new(backend),
        embeddings: Arc::new(embedder),
        learning,
        entries: DashMap::new(),
        vectors: DashMap::new(),
        count: Arc::new(RwLock::new(0)),
        // Moved last: the lines above still read from `config`.
        config,
    };

    kb.load_entries().await?;
    info!("Knowledge base initialized with {} entries", kb.len());
    Ok(kb)
}
/// Fills the in-memory maps from everything the storage backend returns and sets
/// the entry count to the number of entries now held.
///
/// # Errors
/// Fails if the storage backend cannot load its contents.
async fn load_entries(&self) -> Result<()> {
    let stored = self.storage.load_all().await?;
    for (entry, embedding) in stored {
        let id = entry.id;
        self.vectors.insert(id, embedding);
        self.entries.insert(id, entry);
    }
    let loaded = self.entries.len();
    *self.count.write() = loaded;
    Ok(())
}
/// Returns the tracked number of entries.
pub fn len(&self) -> usize {
    let guard = self.count.read();
    *guard
}
/// Reports whether the tracked entry count is zero.
pub fn is_empty(&self) -> bool {
    matches!(self.len(), 0)
}
pub fn config(&self) -> &KnowledgeBaseConfig {
&self.config
}
/// Embeds `entry`, persists it, and then makes it visible in memory.
///
/// The entry is written to storage *before* the in-memory maps are touched, so a
/// storage failure leaves the knowledge base exactly as it was. If an entry with the
/// same id already exists it is overwritten and the entry count is left unchanged.
///
/// Returns the id of the stored entry.
///
/// # Errors
/// Fails if the embedding cannot be computed or the entry cannot be saved.
#[instrument(skip(self, entry), fields(title = %entry.title))]
pub async fn add_entry(&self, entry: KnowledgeEntry) -> Result<Uuid> {
    let id = entry.id;
    let text = entry.embedding_text();
    let embedding = self.embeddings.embed(&text).await?;

    // Persist first: if this fails nothing in memory has changed yet.
    self.storage.save_entry(&entry, &embedding).await?;

    // Only a genuinely new id grows the count; a replacement does not.
    let is_new = self.entries.insert(id, entry).is_none();
    self.vectors.insert(id, embedding);
    if is_new {
        *self.count.write() += 1;
    }

    debug!("Added entry {}", id);
    Ok(id)
}
/// Adds many entries, embedding and saving them in chunks of `config.batch_size`.
///
/// Each chunk is saved to storage before it becomes visible in memory, and the entry
/// count is updated per chunk. If a later chunk fails, the chunks already committed
/// stay fully consistent (persisted, in memory and counted) and the failing chunk
/// leaves no trace in memory. Ids that already exist are overwritten without being
/// counted again.
///
/// Returns the ids of all entries in input order.
///
/// # Errors
/// Fails on the first embedding or storage error; earlier chunks remain stored.
#[instrument(skip(self, entries), fields(count = entries.len()))]
pub async fn add_entries(&self, entries: Vec<KnowledgeEntry>) -> Result<Vec<Uuid>> {
    let mut ids = Vec::with_capacity(entries.len());

    // `chunks` panics on a chunk size of zero, so a zero batch size degrades to one.
    let chunk_size = self.config.batch_size.max(1);

    for chunk in entries.chunks(chunk_size) {
        let mut batch = Vec::with_capacity(chunk.len());
        for entry in chunk {
            let text = entry.embedding_text();
            let embedding = self.embeddings.embed(&text).await?;
            batch.push((entry.clone(), embedding));
        }

        // Persist before publishing so a failed save leaves memory untouched.
        self.storage.save_batch(&batch).await?;

        let mut newly_added = 0usize;
        for (entry, embedding) in batch {
            let id = entry.id;
            if self.entries.insert(id, entry).is_none() {
                newly_added += 1;
            }
            self.vectors.insert(id, embedding);
            ids.push(id);
        }
        *self.count.write() += newly_added;
    }

    info!("Added {} entries in batch", ids.len());
    Ok(ids)
}
/// Returns a copy of the entry stored under `id`, or `None` when there is none.
pub fn get(&self, id: Uuid) -> Option<KnowledgeEntry> {
    let guard = self.entries.get(&id)?;
    Some(guard.value().clone())
}
/// Replaces an existing entry and recomputes its embedding.
///
/// The new version is saved to storage before the in-memory copy is replaced, so a
/// storage failure leaves the previous version in memory.
///
/// # Errors
/// Returns a not-found error when no entry with `entry.id` exists, and otherwise
/// fails if the embedding cannot be computed or the entry cannot be saved.
#[instrument(skip(self, entry), fields(id = %entry.id))]
pub async fn update_entry(&self, entry: KnowledgeEntry) -> Result<()> {
    let id = entry.id;
    if !self.entries.contains_key(&id) {
        return Err(Error::not_found(id.to_string()));
    }

    let text = entry.embedding_text();
    let embedding = self.embeddings.embed(&text).await?;

    // Persist first: memory must never get ahead of storage.
    self.storage.save_entry(&entry, &embedding).await?;

    // If the entry vanished while we were awaiting, this insert re-creates it,
    // so the count has to follow.
    if self.entries.insert(id, entry).is_none() {
        *self.count.write() += 1;
    }
    self.vectors.insert(id, embedding);

    debug!("Updated entry {}", id);
    Ok(())
}
/// Removes the entry stored under `id` from memory and from storage.
///
/// If the storage backend refuses the delete, the entry is put back in memory and
/// its vector and the entry count are left untouched, so the in-memory state keeps
/// matching what is persisted.
///
/// # Errors
/// Returns a not-found error when no entry with `id` exists, and otherwise fails if
/// the storage backend cannot delete the entry.
#[instrument(skip(self), fields(id = %id))]
pub async fn delete_entry(&self, id: Uuid) -> Result<()> {
    let Some((_, removed)) = self.entries.remove(&id) else {
        return Err(Error::not_found(id.to_string()));
    };

    let outcome = self.storage.delete_entry(id).await;
    if outcome.is_err() {
        // Roll back: the entry is still persisted, so it must stay visible.
        self.entries.insert(id, removed);
    }
    outcome?;

    self.vectors.remove(&id);
    {
        let mut count = self.count.write();
        // Saturate rather than underflow if the count ever drifted.
        *count = (*count).saturating_sub(1);
    }

    debug!("Deleted entry {}", id);
    Ok(())
}
/// Searches the knowledge base for entries similar to `query`.
///
/// Steps:
/// 1. embed the query and compute the cosine distance to every stored vector;
/// 2. sort by ascending distance (optionally re-ranked by the learning engine when
///    `options.use_learning` is set and learning is enabled);
/// 3. walk the ranking, dropping candidates below `options.min_similarity` or not
///    matching the category / tag filters, until `options.limit` results are found;
/// 4. apply MMR re-ordering when `options.diversity > 0.0`;
/// 5. report the query and its results to the learning engine, if there is one
///    (this happens regardless of `options.use_learning`).
///
/// The whole ranking is scanned, so filters cannot cause fewer than `limit` results
/// to be returned while further matches exist.
///
/// # Errors
/// Fails if the query cannot be embedded.
#[instrument(skip(self), fields(k = options.limit))]
pub async fn search(&self, query: &str, options: SearchOptions) -> Result<Vec<SearchResult>> {
    let query_embedding = self.embeddings.embed(query).await?;

    let mut candidates: Vec<(Uuid, f32)> = self
        .vectors
        .iter()
        .map(|item| (*item.key(), cosine_distance(&query_embedding, item.value())))
        .collect();

    // `total_cmp` is a total order even if a distance is NaN; a comparator that is
    // not a total order may make `sort_by` panic or yield an unspecified order.
    candidates.sort_by(|a, b| a.1.total_cmp(&b.1));

    if options.use_learning
        && let Some(learning) = &self.learning
    {
        let learning = learning.read();
        candidates = learning.rerank(&query_embedding, candidates, &self.vectors);
    }

    let mut results = Vec::new();
    for (id, distance) in candidates {
        if results.len() >= options.limit {
            break;
        }

        let similarity = 1.0 - distance;
        if similarity < options.min_similarity {
            continue;
        }

        let Some(stored) = self.entries.get(&id) else {
            continue;
        };

        if let Some(ref wanted) = options.category
            && stored.category.as_ref() != Some(wanted)
        {
            continue;
        }

        if !options.tags.is_empty()
            && !options
                .tags
                .iter()
                .any(|t| stored.tags.iter().any(|et| et == t))
        {
            continue;
        }

        // Clone only entries that survived every filter.
        results.push(SearchResult::new(stored.value().clone(), similarity, distance));
    }

    if options.diversity > 0.0 {
        // NOTE(review): `diversity` is passed straight through as the MMR `lambda`,
        // which weights *relevance*. That makes a tiny positive diversity the most
        // diversifying setting and 1.0 pure relevance — confirm this is intended
        // (the usual mapping would be `1.0 - diversity`).
        results = apply_mmr(results, options.diversity);
    }

    if let Some(learning) = &self.learning {
        let mut learning = learning.write();
        learning.record_query(&query_embedding, &results);
    }

    debug!("Search returned {} results", results.len());
    Ok(results)
}
/// Runs [`KnowledgeBase::search`] with `SearchOptions::new(limit)`.
///
/// # Errors
/// Returns whatever [`KnowledgeBase::search`] returns.
pub async fn search_simple(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
    let options = SearchOptions::new(limit);
    self.search(query, options).await
}
/// Records positive or negative feedback for an entry.
///
/// The entry's access is recorded with a weight of `1.1` (positive) or `0.95`
/// (negative). If the entry has a stored vector, the learning engine (when present)
/// is told about the feedback and the updated entry is saved. An unknown `entry_id`
/// is silently ignored.
///
/// No map guard is held while awaiting the storage write: the entry and its vector
/// are snapshotted first and the locks released.
///
/// # Errors
/// Fails if the updated entry cannot be saved.
#[instrument(skip(self))]
pub async fn record_feedback(&self, entry_id: Uuid, positive: bool) -> Result<()> {
    // Mutate and snapshot inside a block so the shard lock is released before `.await`.
    let snapshot = {
        let Some(mut entry) = self.entries.get_mut(&entry_id) else {
            return Ok(());
        };
        let boost = if positive { 0.1 } else { -0.05 };
        entry.record_access(1.0 + boost);
        let copy = entry.value().clone();
        copy
    };

    let Some(embedding) = self.vectors.get(&entry_id).map(|v| v.value().clone()) else {
        return Ok(());
    };

    if let Some(learning) = &self.learning {
        learning.write().record_feedback(&embedding, positive);
    }

    self.storage.save_entry(&snapshot, &embedding).await?;
    Ok(())
}
/// Returns up to `limit` entries linked from the entry stored under `id`.
///
/// Related ids that no longer resolve to an entry are skipped *before* the limit is
/// applied, so dangling links do not reduce the number of results. Returns an empty
/// vector when `id` is unknown.
pub fn get_related(&self, id: Uuid, limit: usize) -> Vec<KnowledgeEntry> {
    // Copy the id list so the guard on `id` is released before the nested lookups.
    let related_ids = match self.entries.get(&id) {
        Some(entry) => entry.related_entries.clone(),
        None => return Vec::new(),
    };

    related_ids
        .iter()
        .filter_map(|rel_id| self.entries.get(rel_id).map(|e| e.value().clone()))
        .take(limit)
        .collect()
}
#[allow(clippy::unused_async)]
pub async fn link_entries(&self, id1: Uuid, id2: Uuid) -> Result<()> {
if let Some(mut entry1) = self.entries.get_mut(&id1) {
if !entry1.related_entries.contains(&id2) {
entry1.related_entries.push(id2);
}
} else {
return Err(Error::not_found(id1.to_string()));
}
if let Some(mut entry2) = self.entries.get_mut(&id2)
&& !entry2.related_entries.contains(&id1)
{
entry2.related_entries.push(id1);
}
Ok(())
}
/// Returns a copy of every entry currently held in memory, in map iteration order.
pub fn all_entries(&self) -> Vec<KnowledgeEntry> {
    let mut snapshot = Vec::new();
    for item in self.entries.iter() {
        snapshot.push(item.value().clone());
    }
    snapshot
}
/// Summarises the knowledge base: entry count, number of distinct categories and
/// tags, summed access count, and the configured dimensions / learning flag.
pub fn stats(&self) -> KnowledgeBaseStats {
    let total_entries = self.len();

    let mut categories = std::collections::HashSet::new();
    let mut tags = std::collections::HashSet::new();
    let mut total_access_count: u64 = 0;

    for item in self.entries.iter() {
        if let Some(category) = item.category.clone() {
            categories.insert(category);
        }
        tags.extend(item.tags.iter().cloned());
        total_access_count += item.access_count;
    }

    KnowledgeBaseStats {
        total_entries,
        unique_categories: categories.len(),
        unique_tags: tags.len(),
        total_access_count,
        dimensions: self.config.dimensions,
        learning_enabled: self.config.learning_enabled,
    }
}
/// Asks the storage backend to flush and returns its result.
///
/// # Errors
/// Returns whatever error the storage backend's flush reports.
pub async fn flush(&self) -> Result<()> {
    let backend = &self.storage;
    backend.flush().await
}
}
/// Summary figures produced by `KnowledgeBase::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseStats {
/// Entry count as reported by `KnowledgeBase::len`.
pub total_entries: usize,
/// Number of distinct categories among entries that have one.
pub unique_categories: usize,
/// Number of distinct tags across all entries.
pub unique_tags: usize,
/// Sum of every entry's `access_count`.
pub total_access_count: u64,
/// `dimensions` value from the knowledge base configuration.
pub dimensions: usize,
/// `learning_enabled` value from the knowledge base configuration.
pub learning_enabled: bool,
}
/// Cosine distance (`1 - cosine similarity`) between two vectors.
///
/// The result is always a finite value in `[0.0, 2.0]`:
/// * `0.0` for vectors pointing the same way, `1.0` for orthogonal ones, `2.0` for
///   opposite ones;
/// * `1.0` ("unrelated") when either vector has zero length, when the two slices
///   differ in length, or when the inputs contain NaN / infinite values. Treating
///   these as unrelated keeps NaN out of ranking and avoids comparing a vector with
///   a truncated prefix of another.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    // `zip` would silently stop at the shorter slice while the norms covered the
    // full vectors, producing a meaningless number.
    if a.len() != b.len() {
        return 1.0;
    }

    let (dot, sq_a, sq_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let denom = sq_a.sqrt() * sq_b.sqrt();
    if denom == 0.0 || !denom.is_finite() || !dot.is_finite() {
        return 1.0;
    }

    // Rounding can push the ratio a hair outside [-1, 1]; clamp so the distance
    // never goes negative or above 2.
    1.0 - (dot / denom).clamp(-1.0, 1.0)
}
/// Re-orders `results` with a maximal-marginal-relevance style greedy selection.
///
/// The first result is always kept in front. After that, each step picks the
/// remaining candidate with the highest
/// `lambda * similarity - (1 - lambda) * max_sim`, where `max_sim` is the largest
/// value of `1 - |selected.score - candidate.score|` over the results already
/// selected (closeness of scores is used as the stand-in for similarity between two
/// results). Every input result appears exactly once in the output.
///
/// A NaN score no longer causes a panic: candidates whose MMR value is NaN are
/// simply never preferred over one with a real value.
fn apply_mmr(mut results: Vec<SearchResult>, lambda: f32) -> Vec<SearchResult> {
    if results.len() <= 1 {
        return results;
    }

    let mut selected = Vec::with_capacity(results.len());
    selected.push(results.remove(0));

    while !results.is_empty() {
        let mut best_idx = 0;
        let mut best_score = f32::NEG_INFINITY;

        for (i, candidate) in results.iter().enumerate() {
            let relevance = candidate.similarity;
            // `total_cmp` orders NaN too, so this cannot panic the way
            // `partial_cmp(..).unwrap()` does.
            let max_sim = selected
                .iter()
                .map(|s| 1.0 - (s.score - candidate.score).abs())
                .max_by(f32::total_cmp)
                .unwrap_or(0.0);

            let mmr = lambda * relevance - (1.0 - lambda) * max_sim;
            if mmr > best_score {
                best_score = mmr;
                best_idx = i;
            }
        }

        // `remove` (not `swap_remove`) keeps the remaining order, which decides ties.
        selected.push(results.remove(best_idx));
    }

    selected
}
// Unit tests for the knowledge base. The async tests each create a fresh store in a
// temporary directory.
#[cfg(test)]
mod tests {
use super::*;
use crate::core::KnowledgeEntry;
use tempfile::tempdir;
// Default configuration pointed at `path`, with 32 dimensions.
fn small_config(path: &Path) -> KnowledgeBaseConfig {
KnowledgeBaseConfig::default()
.with_path(path.to_string_lossy())
.with_dimensions(32)
}
// Identical vectors give distance 0; orthogonal and zero vectors give distance 1.
#[test]
fn test_cosine_distance() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert!((cosine_distance(&a, &b) - 0.0).abs() < 1e-6);
let c = vec![0.0, 1.0, 0.0];
assert!((cosine_distance(&a, &c) - 1.0).abs() < 1e-6);
let z = vec![0.0, 0.0, 0.0];
assert!((cosine_distance(&a, &z) - 1.0).abs() < 1e-6);
}
// Each builder method sets the field it is named after.
#[test]
fn config_builder_sets_fields() {
let cfg = KnowledgeBaseConfig::default()
.with_path("/tmp/x.db")
.with_dimensions(64)
.without_learning();
assert_eq!(cfg.storage_path, "/tmp/x.db");
assert_eq!(cfg.dimensions, 64);
assert!(!cfg.learning_enabled);
}
// `open` on a new path yields an empty store with the default dimensions.
#[tokio::test]
async fn open_creates_empty_kb() {
let dir = tempdir().unwrap();
let kb = KnowledgeBase::open(dir.path().join("kb.db")).await.unwrap();
assert_eq!(kb.len(), 0);
assert!(kb.is_empty());
assert_eq!(kb.config().dimensions, 384);
}
// Walks one entry through add, get, update and delete, checking `len` along the way.
#[tokio::test]
async fn add_get_update_delete_roundtrip() {
let dir = tempdir().unwrap();
let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
.await
.unwrap();
let entry = KnowledgeEntry::new("Title", "body text").with_category("docs");
let id = kb.add_entry(entry.clone()).await.unwrap();
assert_eq!(kb.len(), 1);
assert!(!kb.is_empty());
let fetched = kb.get(id).expect("entry should exist");
assert_eq!(fetched.title, "Title");
let mut updated = fetched;
updated.content = "new body".into();
kb.update_entry(updated.clone()).await.unwrap();
assert_eq!(kb.get(id).unwrap().content, "new body");
kb.delete_entry(id).await.unwrap();
assert_eq!(kb.len(), 0);
assert!(kb.get(id).is_none());
}
// Updating an entry that was never added is a not-found error.
#[tokio::test]
async fn update_missing_entry_errors() {
let dir = tempdir().unwrap();
let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
.await
.unwrap();
let stranger = KnowledgeEntry::new("ghost", "body");
let err = kb.update_entry(stranger).await.unwrap_err();
assert!(matches!(err, Error::NotFound(_)));
}
// Deleting an unknown id is a not-found error.
#[tokio::test]
async fn delete_missing_entry_errors() {
let dir = tempdir().unwrap();
let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
.await
.unwrap();
let err = kb.delete_entry(Uuid::new_v4()).await.unwrap_err();
assert!(matches!(err, Error::NotFound(_)));
}
// A batch add returns one id per entry, grows `len` accordingly, and can be flushed.
#[tokio::test]
async fn add_entries_batch_persists() {
let dir = tempdir().unwrap();
let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
.await
.unwrap();
let batch: Vec<_> = (0..5)
.map(|i| KnowledgeEntry::new(format!("t{i}"), format!("body {i}")))
.collect();
let ids = kb.add_entries(batch).await.unwrap();
assert_eq!(ids.len(), 5);
assert_eq!(kb.len(), 5);
kb.flush().await.unwrap();
}
// Exercises search with category, tag, diversity and min-similarity options.
#[tokio::test]
async fn search_filters_and_results() {
let dir = tempdir().unwrap();
let cfg = KnowledgeBaseConfig::default()
.with_path(dir.path().join("kb.db").to_string_lossy())
.with_dimensions(128);
let kb = KnowledgeBase::with_config(cfg).await.unwrap();
kb.add_entry(
KnowledgeEntry::new("rust ownership", "borrow checker introduction")
.with_category("rust")
.with_tags(["ownership"]),
)
.await
.unwrap();
kb.add_entry(
KnowledgeEntry::new("python decorators", "functions wrapping functions")
.with_category("python")
.with_tags(["meta"]),
)
.await
.unwrap();
// Only checks that an unfiltered search succeeds.
let _ = kb.search_simple("borrow", 10).await.unwrap();
let only_rust = kb
.search(
"wrapping",
SearchOptions::new(10)
.with_category("rust")
.without_learning(),
)
.await
.unwrap();
for r in &only_rust {
assert_eq!(r.entry.category.as_deref(), Some("rust"));
}
let by_tag = kb
.search("anything", SearchOptions::new(10).with_tags(["ownership"]))
.await
.unwrap();
for r in &by_tag {
assert!(r.entry.tags.iter().any(|t| t == "ownership"));
}
// Only checks that the diversity (MMR) path succeeds.
let _ = kb
.search("functions", SearchOptions::new(5).with_diversity(0.5))
.await
.unwrap();
// A minimum similarity of 1.0 is expected to exclude both entries.
let none = kb
.search("borrow", SearchOptions::new(10).with_min_similarity(1.0))
.await
.unwrap();
assert!(none.is_empty());
}
// Feedback on a known id is counted in the stats; feedback on an unknown id is accepted.
#[tokio::test]
async fn record_feedback_and_stats() {
let dir = tempdir().unwrap();
let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
.await
.unwrap();
let id = kb
.add_entry(
KnowledgeEntry::new("a", "alpha")
.with_category("c")
.with_tags(["t"]),
)
.await
.unwrap();
kb.record_feedback(id, true).await.unwrap();
kb.record_feedback(id, false).await.unwrap();
kb.record_feedback(Uuid::new_v4(), true).await.unwrap();
let stats = kb.stats();
assert_eq!(stats.total_entries, 1);
assert_eq!(stats.unique_categories, 1);
assert_eq!(stats.unique_tags, 1);
assert!(stats.learning_enabled);
assert_eq!(stats.dimensions, 32);
assert!(stats.total_access_count >= 2);
}
// Linking twice adds the link once; an unknown first id is a not-found error.
#[tokio::test]
async fn linking_and_related() {
let dir = tempdir().unwrap();
let kb = KnowledgeBase::with_config(small_config(&dir.path().join("kb.db")))
.await
.unwrap();
let a = kb.add_entry(KnowledgeEntry::new("a", "x")).await.unwrap();
let b = kb.add_entry(KnowledgeEntry::new("b", "y")).await.unwrap();
kb.link_entries(a, b).await.unwrap();
kb.link_entries(a, b).await.unwrap();
let related = kb.get_related(a, 5);
assert_eq!(related.len(), 1);
assert_eq!(related[0].id, b);
let err = kb.link_entries(Uuid::new_v4(), b).await.unwrap_err();
assert!(matches!(err, Error::NotFound(_)));
assert!(kb.get_related(Uuid::new_v4(), 5).is_empty());
assert_eq!(kb.all_entries().len(), 2);
}
// An entry added and flushed is present after the store is dropped and reopened.
#[tokio::test]
async fn reopens_with_existing_entries() {
let dir = tempdir().unwrap();
let path = dir.path().join("kb.db");
let kb = KnowledgeBase::with_config(small_config(&path))
.await
.unwrap();
kb.add_entry(KnowledgeEntry::new("persist", "me"))
.await
.unwrap();
kb.flush().await.unwrap();
drop(kb);
let kb2 = KnowledgeBase::with_config(small_config(&path))
.await
.unwrap();
assert_eq!(kb2.len(), 1);
assert_eq!(kb2.all_entries()[0].title, "persist");
}
// With learning disabled, add / search / feedback still succeed.
#[tokio::test]
async fn learning_disabled_skips_engine() {
let dir = tempdir().unwrap();
let cfg = small_config(&dir.path().join("kb.db")).without_learning();
let kb = KnowledgeBase::with_config(cfg).await.unwrap();
let id = kb.add_entry(KnowledgeEntry::new("t", "c")).await.unwrap();
let _ = kb.search_simple("t", 5).await.unwrap();
kb.record_feedback(id, true).await.unwrap();
assert!(!kb.stats().learning_enabled);
}
// MMR returns lists of length 0 and 1 with their length unchanged, and returns a
// non-empty result for a three-element list.
#[test]
fn mmr_short_circuits_short_lists() {
let entry = KnowledgeEntry::new("t", "c");
let r = SearchResult::new(entry, 0.5, 0.5);
let one = apply_mmr(vec![r.clone()], 0.5);
assert_eq!(one.len(), 1);
let empty: Vec<SearchResult> = apply_mmr(Vec::new(), 0.5);
assert!(empty.is_empty());
let mut many = Vec::new();
for i in 0..3 {
let e = KnowledgeEntry::new(format!("t{i}"), "c");
many.push(SearchResult::new(e, 0.9 - i as f32 * 0.1, 0.1 * i as f32));
}
let picked = apply_mmr(many, 0.7);
assert!(!picked.is_empty());
}
}