use std::collections::HashMap;
use crate::common::metadata::Metadata;
use crate::graph::types::{CausalPath, Community, Entity, Relation};
/// Entity record returned by [`EntityExtractor::extract`].
///
/// This is a plain data carrier and is a separate type from the graph-level
/// `Entity` imported from `crate::graph::types`.
#[derive(Debug, Clone)]
pub struct ExtractedEntity {
/// Name of the entity.
/// NOTE(review): [`ExtractedRelation`] refers to its endpoints by name, so this is
/// presumably the key relations are matched on — confirm with the extractor implementations.
pub name: String,
/// Type label, carried as a free-form string (no fixed set is enforced by this type).
pub entity_type: String,
/// Optional free-text description.
pub description: Option<String>,
/// Additional properties, stored in the crate's `Metadata` container.
pub properties: Metadata,
}
/// Relation record returned by [`RelationExtractor::extract`].
///
/// Endpoints are identified by name rather than by id.
/// NOTE(review): the names presumably correspond to [`ExtractedEntity::name`] — confirm
/// with the extractor implementations.
#[derive(Debug, Clone)]
pub struct ExtractedRelation {
/// Name of the entity on the "from" side of the relation.
pub from_name: String,
/// Name of the entity on the "to" side of the relation.
pub to_name: String,
/// Type label of the relation, carried as a free-form string.
pub relation_type: String,
/// Optional free-text description.
pub description: Option<String>,
/// Optional confidence value. No range is enforced by this type
/// (presumably `0.0..=1.0` — TODO confirm with producers).
pub confidence: Option<f32>,
/// Additional properties, stored in the crate's `Metadata` container.
pub properties: Metadata,
}
/// Extracted entities and relations grouped into a single value.
///
/// The derived `Default` yields two empty vectors.
#[derive(Debug, Clone, Default)]
pub struct ExtractionResult {
/// Extracted entities.
pub entities: Vec<ExtractedEntity>,
/// Extracted relations.
pub relations: Vec<ExtractedRelation>,
}
/// Abstraction over a text-completion backend.
///
/// Implementors must be `Send + Sync`. The method is an ordinary (non-`async`) call.
pub trait LlmClient: Send + Sync {
/// Produces a completion for `prompt`.
///
/// # Errors
///
/// Returns the crate's `graphrag::error` error type on failure; the concrete
/// failure conditions are defined by the implementor.
fn complete(&self, prompt: &str) -> crate::graphrag::error::Result<String>;
}
/// Strategy for extracting entities from a piece of text.
///
/// Implementors must be `Send + Sync`.
pub trait EntityExtractor: Send + Sync {
/// Extracts entities from `text`.
///
/// # Errors
///
/// Returns the crate's `graphrag::error` error type on failure; the concrete
/// failure conditions are defined by the implementor.
fn extract(&self, text: &str) -> crate::graphrag::error::Result<Vec<ExtractedEntity>>;
}
/// Strategy for extracting relations from a piece of text, given already-extracted entities.
///
/// Implementors must be `Send + Sync`.
pub trait RelationExtractor: Send + Sync {
/// Extracts relations from `text`.
///
/// `entities` is the set of entities made available to the extractor
/// (presumably the output of an [`EntityExtractor`] run on the same text — confirm with callers).
///
/// # Errors
///
/// Returns the crate's `graphrag::error` error type on failure; the concrete
/// failure conditions are defined by the implementor.
fn extract(
&self,
text: &str,
entities: &[ExtractedEntity],
) -> crate::graphrag::error::Result<Vec<ExtractedRelation>>;
}
/// One context item attached to a [`SearchResult`].
///
/// Every part except `relations` and `metadata` is optional. Unlike [`SearchResult`],
/// this type does not derive `Default`.
#[derive(Debug, Clone)]
pub struct SearchContext {
/// Graph entity this context refers to, if any.
pub entity: Option<Entity>,
/// Graph relations carried with this context.
pub relations: Vec<Relation>,
/// Text associated with this context, if any.
pub text: Option<String>,
/// Optional score. Scale and ordering are not defined by this type
/// (presumably higher is better — TODO confirm with the search implementation).
pub score: Option<f32>,
/// Additional metadata, stored in the crate's `Metadata` container.
pub metadata: Metadata,
}
/// Aggregate result of a search.
///
/// The derived `Default` yields empty vectors and `None` for every optional field.
#[derive(Debug, Clone, Default)]
pub struct SearchResult {
/// Graph entities in the result.
pub entities: Vec<Entity>,
/// Context items in the result.
pub contexts: Vec<SearchContext>,
/// Graph communities in the result.
pub communities: Vec<Community>,
/// Causal paths in the result.
pub paths: Vec<CausalPath>,
/// Provenance records as free-form JSON maps; the key schema is not defined here.
pub provenance: Vec<HashMap<String, serde_json::Value>>,
/// Answer text, if one is available.
pub answer: Option<String>,
/// Average confidence, if available.
/// NOTE(review): has the same shape as the first value returned by
/// [`SearchResult::compute_confidence_metrics`]; presumably populated from it — confirm.
pub avg_confidence: Option<f32>,
/// Minimum confidence, if available.
/// NOTE(review): has the same shape as the second value returned by
/// [`SearchResult::compute_confidence_metrics`]; presumably populated from it — confirm.
pub min_confidence: Option<f32>,
}
impl SearchResult {
    /// Creates a result with no content; equivalent to `SearchResult::default()`.
    pub fn empty() -> Self {
        Default::default()
    }

    /// Returns `true` when the result holds neither entities nor contexts.
    ///
    /// Only `entities` and `contexts` are inspected; `communities`, `paths`,
    /// `provenance` and `answer` do not affect the outcome.
    pub fn is_empty(&self) -> bool {
        let no_entities = self.entities.is_empty();
        let no_contexts = self.contexts.is_empty();
        no_entities && no_contexts
    }

    /// Computes `(average, minimum)` of the overall confidence across `entities`.
    ///
    /// Entities whose `confidence.overall` is `None` are skipped. When no entity
    /// carries a value (including an empty slice) the result is `(None, None)`.
    pub fn compute_confidence_metrics(entities: &[Entity]) -> (Option<f32>, Option<f32>) {
        // Re-creatable view over the confidences that are actually present.
        let present = || entities.iter().filter_map(|entity| entity.confidence.overall);

        let count = present().count();
        if count == 0 {
            return (None, None);
        }

        let total = present().sum::<f32>();
        let lowest = present().fold(f32::INFINITY, f32::min);
        (Some(total / count as f32), Some(lowest))
    }
}
/// Coefficients applied to the three signals combined by
/// [`IngestObservation::information_gain`].
///
/// The weights are not validated or normalised by this type.
#[derive(Debug, Clone)]
pub struct InformationGainWeights {
/// Weight of the expansion signal (log of created entities and relations).
pub expansion: f32,
/// Weight of the non-negative confidence delta.
pub confidence: f32,
/// Weight of the binary "any contradiction flagged" signal.
pub contradiction: f32,
}
impl Default for InformationGainWeights {
fn default() -> Self {
Self { expansion: 0.7, confidence: 0.3, contradiction: 0.0 }
}
}
/// Counters and signals recorded for one ingest operation.
///
/// [`IngestObservation::information_gain`] reads `entities_created`,
/// `relations_created`, `confidence_delta` and `contradictions_flagged`;
/// the remaining fields are carried but not used by that computation.
#[derive(Debug, Clone)]
pub struct IngestObservation {
/// Optional identifier of the batch this observation belongs to.
pub batch_id: Option<String>,
/// Wall-clock time associated with the observation.
pub observed_at: std::time::SystemTime,
/// Optional surprisal value; units and scale are not defined here.
pub surprisal: Option<f32>,
/// Number of entities created.
pub entities_created: usize,
/// Number of relations created.
pub relations_created: usize,
/// Number of entities merged.
pub entities_merged: usize,
/// Change in confidence; may be negative (negative values are treated as zero
/// by `information_gain`).
pub confidence_delta: f32,
/// Number of contradictions flagged.
pub contradictions_flagged: usize,
/// Stressor flag. NOTE(review): its meaning is defined by the producer and is not
/// visible in this part of the file — confirm before relying on it.
pub is_stressor: bool,
}
impl IngestObservation {
    /// Scores how much new information this ingest contributed, in `[0.0, 1.0]`.
    ///
    /// The score is a weighted sum of three signals, clamped to the unit interval:
    ///
    /// * **expansion** — `ln(1 + entities_created + relations_created)`: zero when
    ///   nothing was created, growing logarithmically afterwards (not itself capped at 1).
    /// * **confidence** — `confidence_delta`, with negative deltas treated as `0.0`.
    /// * **contradiction** — `1.0` if at least one contradiction was flagged, else `0.0`.
    ///
    /// `entities_merged`, `surprisal` and `is_stressor` do not influence the result.
    ///
    /// The created-item count is accumulated with saturating arithmetic, so extreme
    /// counter values cannot overflow `usize`. Unchecked addition would panic in debug
    /// builds and, in release builds, could wrap to zero and silently score `0.0`.
    ///
    /// Weights are used as given. A NaN weight (or an infinite weight multiplied by a
    /// zero signal) yields NaN, because `f32::clamp` passes NaN through.
    pub fn information_gain(&self, weights: &InformationGainWeights) -> f32 {
        // Same value as `1 + entities + relations` whenever that sum fits in `usize`.
        let created_plus_one = self
            .entities_created
            .saturating_add(self.relations_created)
            .saturating_add(1);
        let expansion = (created_plus_one as f32).ln();

        // Only improvements count; `f32::max` also maps a NaN delta to 0.0.
        let confidence_gain = self.confidence_delta.max(0.0);

        // Binary signal: the number of contradictions beyond the first is irrelevant.
        let contradiction_signal = if self.contradictions_flagged > 0 { 1.0 } else { 0.0 };

        let weighted_sum = weights.expansion * expansion
            + weights.confidence * confidence_gain
            + weights.contradiction * contradiction_signal;
        weighted_sum.clamp(0.0, 1.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::SystemTime;

    /// Fixture: an observation with `created` new entities, no new relations, and the
    /// given confidence delta and contradiction count. All other fields are neutral.
    fn observation(created: usize, conf_delta: f32, contradictions: usize) -> IngestObservation {
        IngestObservation {
            entities_created: created,
            relations_created: 0,
            confidence_delta: conf_delta,
            contradictions_flagged: contradictions,
            entities_merged: 0,
            surprisal: None,
            is_stressor: false,
            batch_id: None,
            observed_at: SystemTime::now(),
        }
    }

    /// Information gain of `obs` under the default weights.
    fn gain_with_defaults(obs: &IngestObservation) -> f32 {
        obs.information_gain(&InformationGainWeights::default())
    }

    #[test]
    fn empty_result_is_empty() {
        let result = SearchResult::empty();
        assert!(result.is_empty());
    }

    #[test]
    fn confidence_metrics_empty_entities() {
        let metrics = SearchResult::compute_confidence_metrics(&[]);
        assert!(metrics.0.is_none());
        assert!(metrics.1.is_none());
    }

    #[test]
    fn information_gain_zero_creation_is_low() {
        let ig = gain_with_defaults(&observation(0, 0.0, 0));
        assert!(ig.abs() < 1e-5);
    }

    #[test]
    fn information_gain_positive_creation_is_positive() {
        let ig = gain_with_defaults(&observation(5, 0.0, 0));
        assert!(ig > 0.0, "ig={}", ig);
    }

    #[test]
    fn information_gain_is_clamped_to_one() {
        let ig = gain_with_defaults(&observation(10_000, 1.0, 1));
        assert!(ig <= 1.0 + 1e-6);
    }

    #[test]
    fn confidence_delta_contributes_positively() {
        let low = gain_with_defaults(&observation(1, 0.0, 0));
        let high = gain_with_defaults(&observation(1, 0.5, 0));
        assert!(high > low, "high={} low={}", high, low);
    }

    #[test]
    fn negative_confidence_delta_is_ignored() {
        let ig = gain_with_defaults(&observation(1, -0.5, 0));
        let baseline = gain_with_defaults(&observation(1, 0.0, 0));
        assert!((ig - baseline).abs() < 1e-5);
    }

    #[test]
    fn contradiction_signal_contributes_with_nonzero_weight() {
        // Default weights give contradictions zero weight, so use explicit ones here.
        let weights = InformationGainWeights {
            expansion: 0.5,
            confidence: 0.3,
            contradiction: 0.2,
        };
        let without = observation(1, 0.0, 0).information_gain(&weights);
        let with = observation(1, 0.0, 1).information_gain(&weights);
        assert!(with > without, "with={} without={}", with, without);
    }

    #[test]
    fn default_weights_sum_to_one() {
        let InformationGainWeights {
            expansion,
            confidence,
            contradiction,
        } = InformationGainWeights::default();
        let total = expansion + confidence + contradiction;
        assert!((total - 1.0).abs() < 1e-6);
    }
}