use std::collections::{HashMap, HashSet};
use crate::search::scoring::{Bm25Params, bm25_score, elbow_cutoff};
use crate::search::tokenizer;
use crate::agent::swarm::knowledge::types::{
Announcement, BlackboardEntry, ELBOW_THRESHOLD, EntryKind, FileSummary, IndexedEntry,
KnowledgeFact, PARAMS_B, PARAMS_K1, PreparedEntry,
};
/// Mutable state of the swarm knowledge module: recorded facts and file activity,
/// announcements, the blackboard, and the BM25 index searched by
/// `KnowledgeInner::rank_entries`.
#[derive(Default)]
pub(crate) struct KnowledgeInner {
    // ----- recorded knowledge -------------------------------------------------
    /// Facts keyed by a caller-chosen string (key meaning not visible here — TODO confirm).
    pub facts: HashMap<String, KnowledgeFact>,
    /// File summaries; presumably keyed by file path — verify against callers.
    pub files_read: HashMap<String, FileSummary>,
    /// Modification records grouped under a string key (presumably a file path — TODO confirm).
    pub files_modified:
        HashMap<String, Vec<crate::agent::swarm::knowledge::types::FileModification>>,
    /// Announcement records.
    pub announcements: Vec<Announcement>,

    // ----- BM25 search index ----------------------------------------------------
    /// Indexed entries. Positions in this vector are what `rank_entries` returns.
    pub index: Vec<IndexedEntry>,
    /// Id assigned to the next entry by `insert_prepared_entry`, which increments it.
    pub next_id: u64,
    /// For each token, how many indexed entries contain it at least once.
    pub doc_freq: HashMap<String, u32>,
    /// Sum of `token_count` over every indexed entry (used for the average entry length).
    pub total_tokens: u64,

    // ----- coordination ----------------------------------------------------------
    /// Blackboard entries keyed by string; `BTreeMap` keeps iteration in key order.
    pub blackboard: std::collections::BTreeMap<String, BlackboardEntry>,
    /// String key → set of strings; presumably agent id → files actually accessed — TODO confirm.
    pub actual_file_access: HashMap<String, HashSet<String>>,
}
impl KnowledgeInner {
    /// Appends a pre-tokenized entry to the BM25 index and returns its newly assigned id.
    ///
    /// The id is taken from `next_id`, which is then incremented. The corpus statistics
    /// used by [`KnowledgeInner::rank_entries`] are updated at the same time:
    /// - `total_tokens` grows by the entry's `token_count`;
    /// - `doc_freq` is bumped once for every distinct key in the entry's term-frequency map.
    pub fn insert_prepared_entry(&mut self, prepared: PreparedEntry) -> u64 {
        let id = self.next_id;
        self.next_id += 1;
        self.total_tokens += u64::from(prepared.token_count);
        // Document frequency counts entries, not occurrences: each key of `tf` is
        // counted once no matter how often the token appears in this entry.
        for key in prepared.tf.keys() {
            *self.doc_freq.entry(key.clone()).or_default() += 1;
        }
        self.index.push(IndexedEntry {
            id,
            kind: prepared.kind,
            tf: prepared.tf,
            token_count: prepared.token_count,
            display: prepared.display,
            paths: prepared.paths,
        });
        tracing::trace!(entry_id = id, "Knowledge: indexed new entry");
        id
    }

    /// Mean `token_count` per indexed entry.
    ///
    /// Returns `1.0` for an empty index so callers never divide by zero.
    fn avg_doc_len(&self) -> f64 {
        if self.index.is_empty() {
            1.0
        } else {
            self.total_tokens as f64 / self.index.len() as f64
        }
    }

    /// Scores every indexed entry against `query_tokens` with BM25 and returns
    /// `(position in self.index, score)` pairs, highest score first.
    ///
    /// * `query_tokens` — query tokens; an empty slice (or an empty index) yields no results.
    /// * `target_files` — paths of interest. An entry whose `paths` overlap one of these
    ///   (either string contains the other) has its score doubled. Empty strings are
    ///   ignored on both sides, because `""` would otherwise "overlap" every path.
    /// * `max_k` — passed to `elbow_cutoff` together with `ELBOW_THRESHOLD` to trim the
    ///   ranked list.
    ///
    /// Entries whose score is not strictly positive, or is NaN, are dropped. The returned
    /// `usize` is an index into `self.index`, not the entry's `id`.
    pub fn rank_entries(
        &self,
        query_tokens: &[String],
        target_files: &[String],
        max_k: usize,
    ) -> Vec<(usize, f64)> {
        if self.index.is_empty() || query_tokens.is_empty() {
            return Vec::new();
        }
        let n = self.index.len();
        let avg_dl = self.avg_doc_len();
        let params = Bm25Params {
            k1: PARAMS_K1,
            b: PARAMS_B,
        };
        let mut scored: Vec<(usize, f64)> = self
            .index
            .iter()
            .enumerate()
            .filter_map(|(i, entry)| {
                let score = bm25_score(
                    &entry.tf,
                    entry.token_count,
                    query_tokens,
                    &self.doc_freq,
                    n,
                    avg_dl,
                    &params,
                );
                // `NaN <= 0.0` is false, so NaN must be rejected explicitly; otherwise it
                // would survive the filter and corrupt the ordering below.
                if score.is_nan() || score <= 0.0 {
                    return None;
                }
                let score = if Self::paths_overlap(&entry.paths, target_files) {
                    score * 2.0
                } else {
                    score
                };
                Some((i, score))
            })
            .collect();
        // Descending by score. `sort_by` is stable, so equal scores keep index order.
        scored.sort_by(|a, b| b.1.total_cmp(&a.1));
        elbow_cutoff(&scored, max_k, ELBOW_THRESHOLD)
    }

    /// Returns `true` when some non-empty entry path and some non-empty target path
    /// overlap, i.e. one string contains the other. Either slice being empty yields `false`.
    fn paths_overlap(paths: &[String], targets: &[String]) -> bool {
        paths.iter().filter(|path| !path.is_empty()).any(|path| {
            targets
                .iter()
                .filter(|target| !target.is_empty())
                .any(|target| path.contains(target.as_str()) || target.contains(path.as_str()))
        })
    }
}
/// Tokenizes `index_text` and bundles the result with caller-supplied metadata into a
/// [`PreparedEntry`] for `KnowledgeInner::insert_prepared_entry`.
///
/// `index_text` is used only for tokenization. `kind`, `display` and `paths` are stored
/// unchanged. `token_count` is the sum of all term frequencies returned by the tokenizer.
pub(crate) fn prepare_index_entry(
    kind: EntryKind,
    index_text: &str,
    display: String,
    paths: Vec<String>,
) -> PreparedEntry {
    let term_freqs = tokenizer::tokenize_text(index_text);
    // Struct-literal fields evaluate in source order, so the borrow taken for the sum
    // ends before `term_freqs` is moved into `tf`.
    PreparedEntry {
        token_count: term_freqs.values().sum::<u32>(),
        tf: term_freqs,
        kind,
        display,
        paths,
    }
}