1use super::{KnowledgeEntry, SearchOptions, SearchResult};
4use crate::embedding::EmbeddingEngine;
5use crate::error::{Error, Result};
6use crate::learning::LearningEngine;
7use crate::storage::StorageBackend;
8
9use dashmap::DashMap;
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::path::Path;
13use std::sync::Arc;
14use tracing::{debug, info, instrument};
15use uuid::Uuid;
16
/// Tunable settings for a [`KnowledgeBase`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseConfig {
    /// Dimensionality of embedding vectors; passed to the embedding and
    /// learning engines at construction time.
    pub dimensions: usize,

    /// Filesystem path handed to the storage backend.
    pub storage_path: String,

    /// Whether the adaptive learning engine is created and consulted.
    pub learning_enabled: bool,

    /// Learning rate forwarded to the learning engine.
    pub learning_rate: f32,

    /// HNSW graph connectivity parameter `M`.
    /// NOTE(review): no HNSW index is visible in this file — presumably these
    /// three knobs are consumed elsewhere; confirm they are actually wired up.
    pub hnsw_m: usize,

    /// HNSW `ef` parameter used while building the index.
    pub hnsw_ef_construction: usize,

    /// HNSW `ef` parameter used at query time.
    pub hnsw_ef_search: usize,

    /// Number of entries embedded and persisted per chunk in bulk inserts.
    pub batch_size: usize,
}
44
45impl Default for KnowledgeBaseConfig {
46 fn default() -> Self {
47 Self {
48 dimensions: 384,
49 storage_path: "./knowledge.db".to_string(),
50 learning_enabled: true,
51 learning_rate: 0.01,
52 hnsw_m: 16,
53 hnsw_ef_construction: 200,
54 hnsw_ef_search: 100,
55 batch_size: 1000,
56 }
57 }
58}
59
60impl KnowledgeBaseConfig {
61 pub fn with_path(mut self, path: impl Into<String>) -> Self {
63 self.storage_path = path.into();
64 self
65 }
66
67 pub fn with_dimensions(mut self, dims: usize) -> Self {
69 self.dimensions = dims;
70 self
71 }
72
73 pub fn without_learning(mut self) -> Self {
75 self.learning_enabled = false;
76 self
77 }
78}
79
/// An embedding-backed knowledge store combining a persistent backend,
/// concurrent in-memory indexes, and an optional adaptive learning engine.
pub struct KnowledgeBase {
    /// Settings this instance was constructed with.
    config: KnowledgeBaseConfig,

    /// Durable backend; mutations are written through to it.
    storage: Arc<StorageBackend>,

    /// Produces embedding vectors from entry text and queries.
    embeddings: Arc<EmbeddingEngine>,

    /// Present only when `config.learning_enabled`; used for re-ranking and
    /// feedback recording.
    learning: Option<Arc<RwLock<LearningEngine>>>,

    /// In-memory entry index keyed by entry id.
    entries: DashMap<Uuid, KnowledgeEntry>,

    /// Embedding vector per entry id; kept in lockstep with `entries`.
    vectors: DashMap<Uuid, Vec<f32>>,

    /// Cached entry count so `len()` avoids scanning the DashMap shards.
    count: Arc<RwLock<usize>>,
}
103
104impl KnowledgeBase {
105 #[instrument(skip_all)]
107 pub async fn open(path: impl AsRef<Path>) -> Result<Self> {
108 let config = KnowledgeBaseConfig::default().with_path(path.as_ref().to_string_lossy());
109 Self::with_config(config).await
110 }
111
112 #[instrument(skip_all, fields(path = %config.storage_path))]
114 pub async fn with_config(config: KnowledgeBaseConfig) -> Result<Self> {
115 info!("Initializing knowledge base at {}", config.storage_path);
116
117 let storage = Arc::new(StorageBackend::open(&config.storage_path).await?);
118 let embeddings = Arc::new(EmbeddingEngine::new(config.dimensions));
119
120 let learning = if config.learning_enabled {
121 Some(Arc::new(RwLock::new(LearningEngine::new(
122 config.dimensions,
123 config.learning_rate,
124 ))))
125 } else {
126 None
127 };
128
129 let kb = Self {
130 config,
131 storage,
132 embeddings,
133 learning,
134 entries: DashMap::new(),
135 vectors: DashMap::new(),
136 count: Arc::new(RwLock::new(0)),
137 };
138
139 kb.load_entries().await?;
141
142 info!("Knowledge base initialized with {} entries", kb.len());
143 Ok(kb)
144 }
145
146 async fn load_entries(&self) -> Result<()> {
148 let stored = self.storage.load_all().await?;
149
150 for (entry, embedding) in stored {
151 self.entries.insert(entry.id, entry.clone());
152 self.vectors.insert(entry.id, embedding);
153 }
154
155 *self.count.write() = self.entries.len();
156 Ok(())
157 }
158
159 pub fn len(&self) -> usize {
161 *self.count.read()
162 }
163
164 pub fn is_empty(&self) -> bool {
166 self.len() == 0
167 }
168
169 pub fn config(&self) -> &KnowledgeBaseConfig {
171 &self.config
172 }
173
174 #[instrument(skip(self, entry), fields(title = %entry.title))]
176 pub async fn add_entry(&self, entry: KnowledgeEntry) -> Result<Uuid> {
177 let id = entry.id;
178
179 let text = entry.embedding_text();
181 let embedding = self.embeddings.embed(&text)?;
182
183 self.entries.insert(id, entry.clone());
185 self.vectors.insert(id, embedding.clone());
186
187 self.storage.save_entry(&entry, &embedding).await?;
189
190 *self.count.write() += 1;
191 debug!("Added entry {}", id);
192
193 Ok(id)
194 }
195
196 #[instrument(skip(self, entries), fields(count = entries.len()))]
198 pub async fn add_entries(&self, entries: Vec<KnowledgeEntry>) -> Result<Vec<Uuid>> {
199 let mut ids = Vec::with_capacity(entries.len());
200
201 for chunk in entries.chunks(self.config.batch_size) {
202 let batch: Vec<_> = chunk
203 .iter()
204 .map(|entry| {
205 let text = entry.embedding_text();
206 let embedding = self.embeddings.embed(&text)?;
207 Ok((entry.clone(), embedding))
208 })
209 .collect::<Result<Vec<_>>>()?;
210
211 for (entry, embedding) in &batch {
212 self.entries.insert(entry.id, entry.clone());
213 self.vectors.insert(entry.id, embedding.clone());
214 ids.push(entry.id);
215 }
216
217 self.storage.save_batch(&batch).await?;
218 }
219
220 *self.count.write() += ids.len();
221 info!("Added {} entries in batch", ids.len());
222
223 Ok(ids)
224 }
225
226 pub fn get(&self, id: Uuid) -> Option<KnowledgeEntry> {
228 self.entries.get(&id).map(|e| e.clone())
229 }
230
231 #[instrument(skip(self, entry), fields(id = %entry.id))]
233 pub async fn update_entry(&self, entry: KnowledgeEntry) -> Result<()> {
234 let id = entry.id;
235
236 if !self.entries.contains_key(&id) {
237 return Err(Error::not_found(id.to_string()));
238 }
239
240 let text = entry.embedding_text();
242 let embedding = self.embeddings.embed(&text)?;
243
244 self.entries.insert(id, entry.clone());
246 self.vectors.insert(id, embedding.clone());
247
248 self.storage.save_entry(&entry, &embedding).await?;
250
251 debug!("Updated entry {}", id);
252 Ok(())
253 }
254
255 #[instrument(skip(self), fields(id = %id))]
257 pub async fn delete_entry(&self, id: Uuid) -> Result<()> {
258 if self.entries.remove(&id).is_none() {
259 return Err(Error::not_found(id.to_string()));
260 }
261
262 self.vectors.remove(&id);
263 self.storage.delete_entry(id).await?;
264
265 *self.count.write() -= 1;
266 debug!("Deleted entry {}", id);
267
268 Ok(())
269 }
270
271 #[instrument(skip(self), fields(k = options.limit))]
273 pub async fn search(&self, query: &str, options: SearchOptions) -> Result<Vec<SearchResult>> {
274 let query_embedding = self.embeddings.embed(query)?;
276
277 let mut candidates: Vec<(Uuid, f32)> = self
280 .vectors
281 .iter()
282 .map(|entry| {
283 let id = *entry.key();
284 let distance = cosine_distance(&query_embedding, entry.value());
285 (id, distance)
286 })
287 .collect();
288
289 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
291
292 if options.use_learning {
294 if let Some(learning) = &self.learning {
295 let learning = learning.read();
296 candidates = learning.rerank(&query_embedding, candidates, &self.vectors);
297 }
298 }
299
300 let mut results = Vec::new();
302
303 for (id, distance) in candidates.into_iter().take(options.limit * 2) {
304 if let Some(entry) = self.entries.get(&id) {
305 let entry = entry.clone();
306
307 if let Some(ref cat) = options.category {
309 if entry.category.as_ref() != Some(cat) {
310 continue;
311 }
312 }
313
314 if !options.tags.is_empty()
315 && !options
316 .tags
317 .iter()
318 .any(|t| entry.tags.iter().any(|et| et == t))
319 {
320 continue;
321 }
322
323 let similarity = 1.0 - distance;
324 if similarity < options.min_similarity {
325 continue;
326 }
327
328 results.push(SearchResult::new(entry, similarity, distance));
329
330 if results.len() >= options.limit {
331 break;
332 }
333 }
334 }
335
336 if options.diversity > 0.0 {
338 results = apply_mmr(results, options.diversity);
339 }
340
341 if let Some(learning) = &self.learning {
343 let mut learning = learning.write();
344 learning.record_query(&query_embedding, &results);
345 }
346
347 debug!("Search returned {} results", results.len());
348 Ok(results)
349 }
350
351 pub async fn search_simple(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
353 self.search(query, SearchOptions::new(limit)).await
354 }
355
356 #[instrument(skip(self))]
358 pub async fn record_feedback(&self, entry_id: Uuid, positive: bool) -> Result<()> {
359 if let Some(mut entry) = self.entries.get_mut(&entry_id) {
360 let boost = if positive { 0.1 } else { -0.05 };
361 entry.record_access(1.0 + boost);
362
363 if let Some(learning) = &self.learning {
365 let mut learning = learning.write();
366 if let Some(embedding) = self.vectors.get(&entry_id) {
367 learning.record_feedback(&embedding, positive);
368 }
369 }
370
371 let entry = entry.clone();
373 if let Some(embedding) = self.vectors.get(&entry_id) {
374 self.storage.save_entry(&entry, &embedding).await?;
375 }
376 }
377
378 Ok(())
379 }
380
381 pub fn get_related(&self, id: Uuid, limit: usize) -> Vec<KnowledgeEntry> {
383 if let Some(entry) = self.entries.get(&id) {
384 entry
385 .related_entries
386 .iter()
387 .take(limit)
388 .filter_map(|rel_id| self.entries.get(rel_id).map(|e| e.clone()))
389 .collect()
390 } else {
391 Vec::new()
392 }
393 }
394
395 pub async fn link_entries(&self, id1: Uuid, id2: Uuid) -> Result<()> {
397 if let Some(mut entry1) = self.entries.get_mut(&id1) {
398 if !entry1.related_entries.contains(&id2) {
399 entry1.related_entries.push(id2);
400 }
401 } else {
402 return Err(Error::not_found(id1.to_string()));
403 }
404
405 if let Some(mut entry2) = self.entries.get_mut(&id2) {
406 if !entry2.related_entries.contains(&id1) {
407 entry2.related_entries.push(id1);
408 }
409 }
410
411 Ok(())
412 }
413
414 pub fn all_entries(&self) -> Vec<KnowledgeEntry> {
416 self.entries.iter().map(|e| e.value().clone()).collect()
417 }
418
419 pub fn stats(&self) -> KnowledgeBaseStats {
421 let total = self.len();
422 let categories: std::collections::HashSet<_> = self
423 .entries
424 .iter()
425 .filter_map(|e| e.category.clone())
426 .collect();
427
428 let tags: std::collections::HashSet<_> =
429 self.entries.iter().flat_map(|e| e.tags.clone()).collect();
430
431 let total_access: u64 = self.entries.iter().map(|e| e.access_count).sum();
432
433 KnowledgeBaseStats {
434 total_entries: total,
435 unique_categories: categories.len(),
436 unique_tags: tags.len(),
437 total_access_count: total_access,
438 dimensions: self.config.dimensions,
439 learning_enabled: self.config.learning_enabled,
440 }
441 }
442
443 pub async fn flush(&self) -> Result<()> {
445 self.storage.flush().await
446 }
447}
448
/// Aggregate, point-in-time statistics about a [`KnowledgeBase`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KnowledgeBaseStats {
    /// Total number of entries held.
    pub total_entries: usize,
    /// Number of distinct categories across all entries.
    pub unique_categories: usize,
    /// Number of distinct tags across all entries.
    pub unique_tags: usize,
    /// Sum of per-entry access counts.
    pub total_access_count: u64,
    /// Embedding dimensionality in effect.
    pub dimensions: usize,
    /// Whether adaptive learning was enabled in the configuration.
    pub learning_enabled: bool,
}
459
/// Cosine distance between two vectors: `1 - cos(a, b)`.
///
/// The dot product runs over the shorter of the two slices; each norm runs
/// over its full slice. If either vector has zero magnitude, the pair is
/// treated as maximally distant and `1.0` is returned.
fn cosine_distance(a: &[f32], b: &[f32]) -> f32 {
    let norm = |v: &[f32]| v.iter().fold(0.0f32, |acc, x| acc + x * x).sqrt();

    let dot = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| acc + x * y);
    let (norm_a, norm_b) = (norm(a), norm(b));

    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }
    1.0 - dot / (norm_a * norm_b)
}
472
473fn apply_mmr(mut results: Vec<SearchResult>, lambda: f32) -> Vec<SearchResult> {
475 if results.len() <= 1 {
476 return results;
477 }
478
479 let mut selected = vec![results.remove(0)];
480
481 while !results.is_empty() && selected.len() < results.len() + selected.len() {
482 let mut best_idx = 0;
483 let mut best_score = f32::NEG_INFINITY;
484
485 for (i, candidate) in results.iter().enumerate() {
486 let relevance = candidate.similarity;
488
489 let max_sim = selected
491 .iter()
492 .map(|s| {
493 1.0 - (s.score - candidate.score).abs()
495 })
496 .max_by(|a, b| a.partial_cmp(b).unwrap())
497 .unwrap_or(0.0);
498
499 let mmr = lambda * relevance - (1.0 - lambda) * max_sim;
501
502 if mmr > best_score {
503 best_score = mmr;
504 best_idx = i;
505 }
506 }
507
508 selected.push(results.remove(best_idx));
509 }
510
511 selected
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for float comparisons.
    const EPS: f32 = 1e-6;

    #[test]
    fn test_cosine_distance() {
        // Identical unit vectors are at distance zero.
        let x = vec![1.0, 0.0, 0.0];
        let same = vec![1.0, 0.0, 0.0];
        assert!(cosine_distance(&x, &same).abs() < EPS);

        // Orthogonal unit vectors are at distance one.
        let y = vec![0.0, 1.0, 0.0];
        assert!((cosine_distance(&x, &y) - 1.0).abs() < EPS);
    }
}