use std::collections::HashMap;
use super::EntityType;
use crate::storage::schema::Value;
use std::cmp::Ordering;
/// The chunks retrieved for one query, plus bookkeeping about where they came
/// from and how long retrieval took.
#[derive(Debug, Clone)]
pub struct RetrievalContext {
    /// The query text this context was built for.
    pub query: String,
    /// Retrieved chunks, kept in insertion order until explicitly sorted.
    pub chunks: Vec<ContextChunk>,
    /// Position-weighted mean of chunk relevances; `0.0` until it is computed.
    pub overall_relevance: f32,
    /// Distinct sources of the chunks, in first-seen order (maintained by `add_chunk`).
    pub sources_used: Vec<ChunkSource>,
    /// Retrieval time; presumably microseconds given the `_us` suffix — TODO confirm.
    pub retrieval_time_us: u64,
    /// Optional human-readable explanation attached via `with_explanation`.
    pub explanation: Option<String>,
}
impl RetrievalContext {
    /// Creates an empty context for `query`.
    ///
    /// `overall_relevance` and `retrieval_time_us` start at zero and
    /// `explanation` is `None`.
    pub fn new(query: impl Into<String>) -> Self {
        Self {
            query: query.into(),
            chunks: Vec::new(),
            overall_relevance: 0.0,
            sources_used: Vec::new(),
            retrieval_time_us: 0,
            explanation: None,
        }
    }

    /// Appends `chunk`.
    ///
    /// The chunk's source is recorded in `sources_used` only the first time
    /// that source is seen, so `sources_used` stays de-duplicated and in
    /// first-seen order.
    pub fn add_chunk(&mut self, chunk: ContextChunk) {
        if !self.sources_used.contains(&chunk.source) {
            self.sources_used.push(chunk.source.clone());
        }
        self.chunks.push(chunk);
    }

    /// Sorts chunks by descending relevance.
    ///
    /// Ties are broken by entity id (a missing id compares as `""`), then by
    /// source name, then by content, so the resulting order is deterministic.
    /// A NaN relevance ranks below every real score.
    pub fn sort_by_relevance(&mut self) {
        // Map NaN to -inf before comparing. With a raw
        // `partial_cmp(..).unwrap_or(Equal)`, NaN would be "equal" to every
        // score, which breaks transitivity. `sort_by` then yields an
        // unspecified order and may panic on Rust >= 1.81.
        fn rank(relevance: f32) -> f32 {
            if relevance.is_nan() {
                f32::NEG_INFINITY
            } else {
                relevance
            }
        }

        self.chunks.sort_by(|a, b| {
            rank(b.relevance)
                .partial_cmp(&rank(a.relevance))
                // Unreachable after `rank`, which removes every NaN.
                .unwrap_or(Ordering::Equal)
                .then_with(|| {
                    let a_entity = a.entity_id.as_deref().unwrap_or("");
                    let b_entity = b.entity_id.as_deref().unwrap_or("");
                    a_entity.cmp(b_entity)
                })
                .then_with(|| a.source.name().cmp(b.source.name()))
                .then_with(|| a.content.cmp(&b.content))
        });
    }

    /// Stores in `overall_relevance` the weighted mean of chunk relevances.
    ///
    /// The chunk at 0-based position `i` has weight `1 / (i + 1)`. Weights
    /// follow the *current* chunk order; call `sort_by_relevance` first if the
    /// most relevant chunks should dominate. An empty context yields `0.0`,
    /// and a NaN relevance propagates into the result.
    pub fn calculate_overall_relevance(&mut self) {
        if self.chunks.is_empty() {
            self.overall_relevance = 0.0;
            return;
        }
        let total_weight: f32 = (1..=self.chunks.len()).map(|i| 1.0 / i as f32).sum();
        let weighted_sum: f32 = self
            .chunks
            .iter()
            .enumerate()
            .map(|(i, c)| c.relevance * (1.0 / (i + 1) as f32))
            .sum();
        self.overall_relevance = weighted_sum / total_weight;
    }

    /// Sorts by relevance, then keeps at most the `n` most relevant chunks.
    ///
    /// `sources_used` is left untouched.
    pub fn limit(&mut self, n: usize) {
        self.sort_by_relevance();
        self.chunks.truncate(n);
    }

    /// Returns the chunks whose `entity_type` is `Some(entity_type)`.
    pub fn chunks_for_type(&self, entity_type: EntityType) -> Vec<&ContextChunk> {
        self.chunks
            .iter()
            .filter(|c| c.entity_type == Some(entity_type))
            .collect()
    }

    /// Returns the chunks whose source equals `source`, including the payload.
    pub fn chunks_from_source(&self, source: &ChunkSource) -> Vec<&ContextChunk> {
        self.chunks.iter().filter(|c| &c.source == source).collect()
    }

    /// Returns `true` when the context holds no chunks.
    pub fn is_empty(&self) -> bool {
        self.chunks.is_empty()
    }

    /// Returns the number of chunks.
    pub fn len(&self) -> usize {
        self.chunks.len()
    }

    /// Returns the first chunk.
    ///
    /// This is the most relevant chunk only if the context has been sorted.
    pub fn top_chunk(&self) -> Option<&ContextChunk> {
        self.chunks.first()
    }

    /// Renders every chunk as a `"[n] <chunk text>"` line.
    ///
    /// `n` is 1-based and each line ends in `'\n'`. An empty context yields
    /// `""`.
    pub fn to_context_string(&self) -> String {
        use std::fmt::Write as _;

        let mut s = String::new();
        for (i, chunk) in self.chunks.iter().enumerate() {
            // Formatting into a `String` never fails, so the `fmt::Result` is ignored.
            let _ = writeln!(s, "[{}] {}", i + 1, chunk.to_text());
        }
        s
    }

    /// Returns the entity ids of the chunks that have one, in chunk order.
    pub fn entity_ids(&self) -> Vec<&str> {
        self.chunks
            .iter()
            .filter_map(|c| c.entity_id.as_deref())
            .collect()
    }

    /// Appends every chunk of `other` via `add_chunk` and adds its retrieval time.
    ///
    /// The merged chunks are not re-sorted. `overall_relevance`, `query` and
    /// `explanation` of `self` are left unchanged.
    pub fn merge(&mut self, other: RetrievalContext) {
        for chunk in other.chunks {
            self.add_chunk(chunk);
        }
        self.retrieval_time_us += other.retrieval_time_us;
    }

    /// Sets the explanation, replacing any previous one, and returns `self`.
    pub fn with_explanation(mut self, explanation: impl Into<String>) -> Self {
        self.explanation = Some(explanation.into());
        self
    }
}
/// A single piece of retrieved content together with its provenance and score.
#[derive(Debug, Clone)]
pub struct ContextChunk {
    /// The retrieved text.
    pub content: String,
    /// Where the chunk came from.
    pub source: ChunkSource,
    /// Relevance score; higher scores sort first in `RetrievalContext::sort_by_relevance`.
    pub relevance: f32,
    /// Kind of entity the chunk describes, if known.
    pub entity_type: Option<EntityType>,
    /// Entity identifier as a string (vector id, row id or graph id), if known.
    pub entity_id: Option<String>,
    /// Extra key/value attributes added with `with_metadata`.
    pub metadata: HashMap<String, Value>,
    /// Raw distance for chunks built by `from_vector`.
    pub vector_distance: Option<f32>,
    /// Traversal depth for chunks built by `from_graph`.
    pub graph_depth: Option<u32>,
}
impl ContextChunk {
    /// Builds a chunk from content, source and score.
    ///
    /// Every optional field (entity info, metadata, distance, depth) starts empty.
    pub fn new(content: impl Into<String>, source: ChunkSource, relevance: f32) -> Self {
        let content = content.into();
        Self {
            content,
            source,
            relevance,
            entity_type: None,
            entity_id: None,
            metadata: HashMap::new(),
            vector_distance: None,
            graph_depth: None,
        }
    }

    /// Builds a chunk from a vector-search hit.
    ///
    /// The score is `1 / (1 + distance)`: distance `0` maps to relevance `1`
    /// and larger distances decay toward `0`. The raw distance is kept in
    /// `vector_distance`, and `id` becomes the entity id.
    ///
    /// NOTE(review): a distance of `-1` yields an infinite score, and
    /// distances below `-1` a negative one. Confirm the metric never produces
    /// negative distances.
    pub fn from_vector(
        content: impl Into<String>,
        collection: impl Into<String>,
        distance: f32,
        id: u64,
    ) -> Self {
        let source = ChunkSource::Vector(collection.into());
        let score = 1.0 / (1.0 + distance);
        Self {
            vector_distance: Some(distance),
            entity_id: Some(id.to_string()),
            ..Self::new(content, source, score)
        }
    }

    /// Builds a chunk from a graph-traversal result.
    ///
    /// The score is `1 / (1 + depth)`, so nodes closer to the start rank
    /// higher.
    pub fn from_graph(
        content: impl Into<String>,
        depth: u32,
        entity_type: EntityType,
        entity_id: impl Into<String>,
    ) -> Self {
        let score = 1.0 / (1.0 + depth as f32);
        Self {
            graph_depth: Some(depth),
            entity_type: Some(entity_type),
            entity_id: Some(entity_id.into()),
            ..Self::new(content, ChunkSource::Graph, score)
        }
    }

    /// Builds a chunk from a table row with a caller-supplied score.
    ///
    /// `row_id` becomes the entity id.
    pub fn from_table(
        content: impl Into<String>,
        table: impl Into<String>,
        row_id: u64,
        relevance: f32,
    ) -> Self {
        let source = ChunkSource::Table(table.into());
        Self {
            entity_id: Some(row_id.to_string()),
            ..Self::new(content, source, relevance)
        }
    }

    /// Sets the entity type, overwriting any previous value.
    pub fn with_entity_type(self, entity_type: EntityType) -> Self {
        Self {
            entity_type: Some(entity_type),
            ..self
        }
    }

    /// Sets the entity id, overwriting any previous value.
    pub fn with_entity_id(self, id: impl Into<String>) -> Self {
        Self {
            entity_id: Some(id.into()),
            ..self
        }
    }

    /// Inserts one metadata entry, replacing an existing value under the same key.
    pub fn with_metadata(mut self, key: impl Into<String>, value: Value) -> Self {
        let key = key.into();
        self.metadata.insert(key, value);
        self
    }

    /// Renders the chunk as `"[source] <entity> relevance:X.XX: content"`.
    ///
    /// `<entity>` is `Type:id` when both type and id are set, and `id:<id>`
    /// when only the id is set. It is omitted when there is no id.
    pub fn to_text(&self) -> String {
        let tag = format!("[{}]", self.source.name());
        let score = format!("relevance:{:.2}", self.relevance);
        let header = match (&self.entity_id, self.entity_type) {
            (Some(id), Some(kind)) => format!("{} {:?}:{} {}", tag, kind, id, score),
            (Some(id), None) => format!("{} id:{} {}", tag, id, score),
            (None, _) => format!("{} {}", tag, score),
        };
        format!("{}: {}", header, self.content)
    }
}
/// Origin of a retrieved chunk.
///
/// `Vector` and `Table` carry the collection or table name they were read
/// from; the other variants carry no payload.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ChunkSource {
    /// Vector-search hit; the payload is the collection name.
    Vector(String),
    /// Graph-traversal result.
    Graph,
    /// Table row; the payload is the table name.
    Table(String),
    /// Cross-reference result (label `"cross-ref"`).
    CrossRef,
    /// Intelligence result (label `"intel"`).
    Intelligence,
    /// Cached result (label `"cache"`).
    Cache,
}

impl ChunkSource {
    /// Short, fixed label for this kind of source, e.g. `"vector"`.
    ///
    /// Every label is a string literal, so the result is `'static` and may
    /// outlive `self`.
    pub fn name(&self) -> &'static str {
        match self {
            Self::Vector(_) => "vector",
            Self::Graph => "graph",
            Self::Table(_) => "table",
            Self::CrossRef => "cross-ref",
            Self::Intelligence => "intel",
            Self::Cache => "cache",
        }
    }

    /// Collection name for `Vector` or table name for `Table`; `None` otherwise.
    pub fn collection(&self) -> Option<&str> {
        match self {
            Self::Vector(c) | Self::Table(c) => Some(c),
            // Listed explicitly so adding a variant forces a decision here.
            Self::Graph | Self::CrossRef | Self::Intelligence | Self::Cache => None,
        }
    }
}
/// Fluent builder for a `RetrievalContext`.
///
/// `build` sorts the chunks by relevance and computes the overall relevance
/// before returning the context.
pub struct ContextBuilder {
    context: RetrievalContext,
}

impl ContextBuilder {
    /// Starts a builder for `query` with an empty context.
    pub fn new(query: impl Into<String>) -> Self {
        let context = RetrievalContext::new(query);
        Self { context }
    }

    /// Adds an already-constructed chunk.
    pub fn chunk(mut self, chunk: ContextChunk) -> Self {
        self.context.add_chunk(chunk);
        self
    }

    /// Adds a vector-search hit built with `ContextChunk::from_vector`.
    pub fn vector_result(
        self,
        content: impl Into<String>,
        collection: impl Into<String>,
        distance: f32,
        id: u64,
    ) -> Self {
        let hit = ContextChunk::from_vector(content, collection, distance, id);
        self.chunk(hit)
    }

    /// Adds a graph result built with `ContextChunk::from_graph`.
    pub fn graph_result(
        self,
        content: impl Into<String>,
        depth: u32,
        entity_type: EntityType,
        entity_id: impl Into<String>,
    ) -> Self {
        let node = ContextChunk::from_graph(content, depth, entity_type, entity_id);
        self.chunk(node)
    }

    /// Adds a table row built with `ContextChunk::from_table`.
    pub fn table_result(
        self,
        content: impl Into<String>,
        table: impl Into<String>,
        row_id: u64,
        relevance: f32,
    ) -> Self {
        let row = ContextChunk::from_table(content, table, row_id, relevance);
        self.chunk(row)
    }

    /// Sets the recorded retrieval time, overwriting any previous value.
    pub fn time_us(self, time: u64) -> Self {
        let mut context = self.context;
        context.retrieval_time_us = time;
        Self { context }
    }

    /// Sets the explanation, overwriting any previous value.
    pub fn explanation(self, explanation: impl Into<String>) -> Self {
        Self {
            context: self.context.with_explanation(explanation),
        }
    }

    /// Finishes the context: sorts the chunks by relevance, then computes
    /// `overall_relevance`.
    pub fn build(self) -> RetrievalContext {
        let mut context = self.context;
        context.sort_by_relevance();
        context.calculate_overall_relevance();
        context
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_context_builder() {
        let context = ContextBuilder::new("test query")
            .vector_result(
                "CVE-2024-1234: SQL injection vulnerability",
                "vulns",
                0.1,
                1,
            )
            .vector_result("CVE-2024-5678: XSS vulnerability", "vulns", 0.3, 2)
            .graph_result("Host 192.168.1.1 runs nginx", 1, EntityType::Host, "h1")
            .time_us(1000)
            .build();
        assert_eq!(context.len(), 3);
        assert!(context.overall_relevance > 0.0);
        let top = context.top_chunk().unwrap();
        assert!(matches!(top.source, ChunkSource::Vector(_)));
    }

    #[test]
    fn test_relevance_calculation() {
        let mut context = RetrievalContext::new("test");
        context.add_chunk(ContextChunk::new("A", ChunkSource::Graph, 1.0));
        context.add_chunk(ContextChunk::new("B", ChunkSource::Graph, 0.5));
        context.add_chunk(ContextChunk::new("C", ChunkSource::Graph, 0.25));
        context.calculate_overall_relevance();
        assert!(context.overall_relevance > 0.25);
        assert!(context.overall_relevance < 1.0);
    }

    #[test]
    fn test_context_filtering() {
        let mut context = RetrievalContext::new("test");
        context.add_chunk(
            ContextChunk::new("Host info", ChunkSource::Graph, 0.9)
                .with_entity_type(EntityType::Host),
        );
        context.add_chunk(
            ContextChunk::new("Vuln info", ChunkSource::Graph, 0.8)
                .with_entity_type(EntityType::Vulnerability),
        );
        let hosts = context.chunks_for_type(EntityType::Host);
        assert_eq!(hosts.len(), 1);
        assert!(hosts[0].content.contains("Host"));
    }

    #[test]
    fn test_context_merge() {
        let mut context1 = RetrievalContext::new("test");
        context1.add_chunk(ContextChunk::new("A", ChunkSource::Graph, 0.9));
        context1.retrieval_time_us = 100;
        let mut context2 = RetrievalContext::new("test");
        context2.add_chunk(ContextChunk::new(
            "B",
            ChunkSource::Vector("v".to_string()),
            0.8,
        ));
        context2.retrieval_time_us = 200;
        context1.merge(context2);
        assert_eq!(context1.len(), 2);
        assert_eq!(context1.retrieval_time_us, 300);
    }

    #[test]
    fn test_to_context_string() {
        let context = ContextBuilder::new("test")
            .vector_result("Important finding", "vulns", 0.1, 1)
            .build();
        let text = context.to_context_string();
        assert!(text.contains("[1]"));
        assert!(text.contains("vector"));
        assert!(text.contains("Important finding"));
    }

    /// `sources_used` records each distinct source once, in first-seen order.
    #[test]
    fn test_sources_deduplicated() {
        let mut context = RetrievalContext::new("test");
        context.add_chunk(ContextChunk::new("A", ChunkSource::Graph, 0.9));
        context.add_chunk(ContextChunk::new("B", ChunkSource::Graph, 0.8));
        context.add_chunk(ContextChunk::from_table("C", "hosts", 7, 0.7));
        assert_eq!(
            context.sources_used,
            vec![ChunkSource::Graph, ChunkSource::Table("hosts".to_string())]
        );
        assert_eq!(context.chunks_from_source(&ChunkSource::Graph).len(), 2);
    }

    /// `limit` keeps the `n` highest-relevance chunks, in descending order.
    #[test]
    fn test_limit_keeps_most_relevant() {
        let mut context = RetrievalContext::new("test");
        context.add_chunk(ContextChunk::new("low", ChunkSource::Graph, 0.1));
        context.add_chunk(ContextChunk::new("high", ChunkSource::Graph, 0.9));
        context.add_chunk(ContextChunk::new("mid", ChunkSource::Graph, 0.5));
        context.limit(2);
        let kept: Vec<&str> = context.chunks.iter().map(|c| c.content.as_str()).collect();
        assert_eq!(kept, vec!["high", "mid"]);
    }

    /// Equal relevance falls back to entity id ordering, so sorting is deterministic.
    #[test]
    fn test_sort_tie_break_is_deterministic() {
        let mut context = RetrievalContext::new("test");
        context.add_chunk(ContextChunk::new("b", ChunkSource::Graph, 0.5).with_entity_id("e2"));
        context.add_chunk(ContextChunk::new("a", ChunkSource::Graph, 0.5).with_entity_id("e1"));
        context.sort_by_relevance();
        assert_eq!(context.entity_ids(), vec!["e1", "e2"]);
    }

    /// An empty context reports zero relevance and renders as an empty string.
    #[test]
    fn test_empty_context() {
        let mut context = RetrievalContext::new("nothing");
        context.calculate_overall_relevance();
        assert!(context.is_empty());
        assert_eq!(context.overall_relevance, 0.0);
        assert!(context.top_chunk().is_none());
        assert_eq!(context.to_context_string(), "");
    }

    /// A zero vector distance maps to relevance 1.0, and the id and collection are kept.
    #[test]
    fn test_from_vector_scoring() {
        let chunk = ContextChunk::from_vector("x", "vulns", 0.0, 42);
        assert_eq!(chunk.relevance, 1.0);
        assert_eq!(chunk.vector_distance, Some(0.0));
        assert_eq!(chunk.entity_id.as_deref(), Some("42"));
        assert_eq!(chunk.source.collection(), Some("vulns"));
    }
}