use crate::graph::pdg::{NodeId, ProgramDependenceGraph};
use std::collections::HashMap;
/// Hash of three consecutive `char`s, as produced by [`extract_trigrams`].
///
/// Being a 32-bit hash, distinct trigrams can map to the same value.
pub type Trigram = u32;
/// Format version written as the first four bytes (little-endian) of
/// [`TrigramIndex::serialize`] output; [`TrigramIndex::deserialize`] returns
/// `None` for any other value.
const TRIGRAM_INDEX_VERSION: u32 = 1;
/// Inverted index from character trigrams to the graph nodes whose name, id,
/// or file path contain them (all lowercased before indexing).
///
/// [`TrigramIndex::query`] only narrows the search to candidate nodes: trigrams
/// are hashed (so unrelated trigrams may collide), and a node containing every
/// trigram of a query need not contain the query as a contiguous substring.
/// NOTE(review): callers presumably re-check candidates against the real text —
/// confirm.
#[derive(Debug, Clone, Default)]
pub struct TrigramIndex {
    // Trigram hash -> node indices (`NodeId::index()` truncated to `u32`) indexed
    // under it. Lists are kept sorted ascending and free of duplicates, which the
    // merge-based intersection and the binary searches rely on.
    postings: HashMap<Trigram, Vec<u32>>,
    // Number of indexed nodes as tracked by build/add/remove; after
    // deserialization it is reconstructed from the largest stored node index.
    node_count: usize,
}
impl TrigramIndex {
pub fn new() -> Self {
Self::default()
}
pub fn build_from_pdg(pdg: &ProgramDependenceGraph) -> Self {
let mut index = Self::new();
index.node_count = pdg.node_count();
for node_id in pdg.node_indices() {
if let Some(node) = pdg.get_node(node_id) {
let node_idx = node_id.index() as u32;
let name_lower = node.name.to_lowercase();
for trigram in extract_trigrams(&name_lower) {
index.postings.entry(trigram).or_default().push(node_idx);
}
let id_lower = node.id.to_lowercase();
for trigram in extract_trigrams(&id_lower) {
index.postings.entry(trigram).or_default().push(node_idx);
}
let file_lower = node.file_path.to_lowercase();
for trigram in extract_trigrams(&file_lower) {
index.postings.entry(trigram).or_default().push(node_idx);
}
}
}
for posting_list in index.postings.values_mut() {
posting_list.sort_unstable();
posting_list.dedup();
}
index
}
pub fn query(&self, query_lower: &str) -> Option<Vec<u32>> {
let trigrams = extract_trigrams(query_lower);
if trigrams.is_empty() {
return None; }
let mut sorted_trigrams: Vec<&Vec<u32>> = trigrams
.iter()
.filter_map(|t| self.postings.get(t))
.collect();
if sorted_trigrams.is_empty() {
return Some(Vec::new());
}
if sorted_trigrams.len() < trigrams.len() {
return Some(Vec::new());
}
sorted_trigrams.sort_by_key(|list| list.len());
let mut result = sorted_trigrams[0].clone();
for posting_list in sorted_trigrams.iter().skip(1) {
result = intersect_sorted(&result, posting_list);
if result.is_empty() {
return Some(Vec::new());
}
}
Some(result)
}
pub fn add_node(
&mut self,
node_id: NodeId,
name: &str,
node_id_str: &str,
file_path: &str,
) {
let node_idx = node_id.index() as u32;
self.node_count += 1;
for trigram in extract_trigrams(&name.to_lowercase()) {
let list = self.postings.entry(trigram).or_default();
if let Err(pos) = list.binary_search(&node_idx) {
list.insert(pos, node_idx);
}
}
for trigram in extract_trigrams(&node_id_str.to_lowercase()) {
let list = self.postings.entry(trigram).or_default();
if let Err(pos) = list.binary_search(&node_idx) {
list.insert(pos, node_idx);
}
}
for trigram in extract_trigrams(&file_path.to_lowercase()) {
let list = self.postings.entry(trigram).or_default();
if let Err(pos) = list.binary_search(&node_idx) {
list.insert(pos, node_idx);
}
}
}
pub fn remove_node(
&mut self,
node_id: NodeId,
name: &str,
node_id_str: &str,
file_path: &str,
) {
let node_idx = node_id.index() as u32;
if self.node_count > 0 {
self.node_count -= 1;
}
let mut trigrams_to_clean: Vec<Trigram> = Vec::new();
trigrams_to_clean.extend(extract_trigrams(&name.to_lowercase()));
trigrams_to_clean.extend(extract_trigrams(&node_id_str.to_lowercase()));
trigrams_to_clean.extend(extract_trigrams(&file_path.to_lowercase()));
trigrams_to_clean.sort_unstable();
trigrams_to_clean.dedup();
for trigram in &trigrams_to_clean {
if let Some(posting_list) = self.postings.get_mut(trigram) {
if let Ok(pos) = posting_list.binary_search(&node_idx) {
posting_list.remove(pos);
}
}
}
self.postings.retain(|_, list| !list.is_empty());
}
pub fn trigram_count(&self) -> usize {
self.postings.len()
}
pub fn node_count(&self) -> usize {
self.node_count
}
pub fn is_empty(&self) -> bool {
self.postings.is_empty()
}
pub fn serialize(&self) -> Vec<u8> {
let entry_count = self.postings.len() as u32;
let total_size: usize = 4 + 4 + self.postings.iter().map(|(_, list)| 8 + list.len() * 4).sum::<usize>();
let mut buf = Vec::with_capacity(total_size);
buf.extend_from_slice(&TRIGRAM_INDEX_VERSION.to_le_bytes());
buf.extend_from_slice(&entry_count.to_le_bytes());
for (&trigram, posting_list) in &self.postings {
buf.extend_from_slice(&trigram.to_le_bytes());
let len = posting_list.len() as u32;
buf.extend_from_slice(&len.to_le_bytes());
for &node_idx in posting_list {
buf.extend_from_slice(&node_idx.to_le_bytes());
}
}
buf
}
pub fn deserialize(data: &[u8]) -> Option<Self> {
if data.len() < 8 {
return None;
}
let version = u32::from_le_bytes(data[0..4].try_into().ok()?);
if version != TRIGRAM_INDEX_VERSION {
return None;
}
let entry_count = u32::from_le_bytes(data[4..8].try_into().ok()?) as usize;
let mut offset = 8;
let mut postings = HashMap::with_capacity_and_hasher(
entry_count,
std::collections::hash_map::RandomState::default(),
);
for _ in 0..entry_count {
if offset + 8 > data.len() {
return None;
}
let trigram = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?);
offset += 4;
let list_len = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?) as usize;
offset += 4;
if offset + list_len * 4 > data.len() {
return None;
}
let mut posting_list = Vec::with_capacity(list_len);
for _ in 0..list_len {
let node_idx = u32::from_le_bytes(data[offset..offset + 4].try_into().ok()?);
offset += 4;
posting_list.push(node_idx);
}
postings.insert(trigram, posting_list);
}
let max_idx = postings
.values()
.flat_map(|v| v.iter().copied())
.max()
.unwrap_or(0);
let node_count = max_idx as usize + 1;
Some(Self {
postings,
node_count,
})
}
}
/// Hashes every window of three consecutive `char`s in `s` into a [`Trigram`].
///
/// Each window is folded with 32-bit FNV-1a, XOR-ing whole Unicode scalar
/// values rather than UTF-8 bytes. The result follows window order and may
/// contain repeats. Strings shorter than three chars yield an empty vector.
pub fn extract_trigrams(s: &str) -> Vec<Trigram> {
    const FNV_OFFSET_BASIS: u32 = 2_166_136_261;
    const FNV_PRIME: u32 = 16_777_619;

    let chars: Vec<char> = s.chars().collect();
    chars
        .windows(3)
        .map(|window| {
            window
                .iter()
                .fold(FNV_OFFSET_BASIS, |hash, &ch| (hash ^ ch as u32).wrapping_mul(FNV_PRIME))
        })
        .collect()
}
/// Returns the values present in both `a` and `b`, in ascending order.
///
/// Both inputs must already be sorted ascending. The merge always advances the
/// side holding the smaller head, so it runs in `O(a.len() + b.len())`. A value
/// repeated in both inputs is emitted once per matched pair.
fn intersect_sorted(a: &[u32], b: &[u32]) -> Vec<u32> {
    use std::cmp::Ordering;

    let mut common = Vec::with_capacity(a.len().min(b.len()));
    let mut left = a.iter().peekable();
    let mut right = b.iter().peekable();
    while let (Some(&&x), Some(&&y)) = (left.peek(), right.peek()) {
        match x.cmp(&y) {
            Ordering::Less => {
                left.next();
            }
            Ordering::Greater => {
                right.next();
            }
            Ordering::Equal => {
                common.push(x);
                left.next();
                right.next();
            }
        }
    }
    common
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index by pushing each `(text, node_idx)` pair onto the posting
    /// list of every trigram in `text`, in the order given.
    fn index_from(entries: &[(&str, u32)]) -> TrigramIndex {
        let mut index = TrigramIndex::new();
        for &(text, node_idx) in entries {
            for trigram in extract_trigrams(text) {
                index.postings.entry(trigram).or_default().push(node_idx);
            }
        }
        index
    }

    #[test]
    fn test_extract_trigrams() {
        // "hello" has the windows "hel", "ell", "llo".
        assert_eq!(extract_trigrams("hello").len(), 3);
    }

    #[test]
    fn test_extract_trigrams_short() {
        assert!(extract_trigrams("ab").is_empty());
    }

    #[test]
    fn test_intersect_sorted() {
        assert_eq!(
            intersect_sorted(&[1, 3, 5, 7, 9], &[3, 4, 5, 8, 9]),
            vec![3, 5, 9]
        );
    }

    #[test]
    fn test_intersect_empty() {
        assert!(intersect_sorted(&[1, 2, 3], &[]).is_empty());
    }

    #[test]
    fn test_trigram_index_query() {
        let index = index_from(&[("hello", 0), ("world", 1), ("help", 2)]);

        let exact = index.query("hello").unwrap();
        assert!(exact.contains(&0));

        let prefix = index.query("hel").unwrap();
        assert!(prefix.contains(&0) && prefix.contains(&2));

        assert!(index.query("xyz").unwrap().is_empty());
        // Fewer than three chars produce no trigrams, so the index cannot filter.
        assert!(index.query("ab").is_none());
    }

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let mut index = index_from(&[("hello", 0), ("world", 1)]);
        index.node_count = 2;

        let restored = TrigramIndex::deserialize(&index.serialize()).unwrap();
        assert_eq!(restored.postings.len(), index.postings.len());
        assert!(restored.query("hello").unwrap().contains(&0));
        assert!(restored.query("world").unwrap().contains(&1));
        assert!(restored.query("ab").is_none());
    }
}