use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Instant;
use crate::coherence::{CoherenceEngine, CoherenceEnergy};
use crate::execution::ComputeLane;
use crate::governance::PolicyBundle;
use super::config::LlmCoherenceConfig;
use super::error::{Result, RuvLlmIntegrationError};
/// Gate that evaluates LLM responses and produces an [`LlmGateDecision`].
///
/// Bundles a reference-counted [`CoherenceEngine`], the governing
/// [`PolicyBundle`], and the gate's own [`LlmCoherenceConfig`].
pub struct LlmCoherenceGate {
/// Coherence engine, held behind an `Arc` so clones of the gate share it.
engine: Arc<CoherenceEngine>,
/// Policy bundle associated with this gate; exposed via `policy()`.
policy: PolicyBundle,
/// Thresholds and feature toggles read during evaluation.
config: LlmCoherenceConfig,
}
/// Manual `Debug` implementation that prints only the policy and the
/// configuration; the engine field is left out and the output is marked as
/// non-exhaustive.
impl std::fmt::Debug for LlmCoherenceGate {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("LlmCoherenceGate");
        out.field("policy", &self.policy);
        out.field("config", &self.config);
        out.finish_non_exhaustive()
    }
}
/// Cloning a gate bumps the engine's reference count (the engine itself is
/// shared, not copied) and clones the policy and configuration.
impl Clone for LlmCoherenceGate {
    fn clone(&self) -> Self {
        let Self {
            engine,
            policy,
            config,
        } = self;
        Self {
            engine: Arc::clone(engine),
            policy: policy.clone(),
            config: config.clone(),
        }
    }
}
/// Outcome of a single gate evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LlmGateDecision {
/// Whether the response passed the gate.
pub allowed: bool,
/// Energy value attached to the decision by the gate.
pub energy: f64,
/// Compute lane the response is routed to.
pub lane: ComputeLane,
/// Why the gate reached this decision.
pub reason: LlmGateReason,
/// Score breakdown the decision was derived from.
pub analysis: CoherenceAnalysis,
/// Time spent evaluating, in microseconds.
pub processing_time_us: u64,
/// UTC timestamp recorded when the decision was created.
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// Reason attached to an [`LlmGateDecision`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LlmGateReason {
/// The response passed every check the gate applied.
Coherent,
/// Carries an `energy` value together with the `threshold` it is reported
/// against.
/// NOTE(review): not constructed by the gate code visible here; exact
/// comparison semantics should be confirmed with the producer.
BelowThreshold {
energy: f64,
threshold: f64,
},
/// The estimated hallucination probability was too high; `confidence`
/// holds that probability.
HallucinationDetected {
confidence: f64,
description: String,
},
/// The semantic score fell below the configured coherence threshold;
/// `description` is a human-readable summary.
SemanticInconsistency {
description: String,
},
/// Carries a list of citations.
/// NOTE(review): presumably the citations that failed verification; not
/// constructed by the gate code visible here, so confirm with the producer.
CitationFailure {
citations: Vec<String>,
},
/// Carries a free-form escalation reason.
/// NOTE(review): not constructed by the gate code visible here.
HumanEscalation {
reason: String,
},
/// The response was longer than the configured maximum; both lengths are
/// as measured by the gate (`String::len`, i.e. bytes).
LengthExceeded {
actual: usize,
maximum: usize,
},
}
/// Per-response score breakdown produced by the gate.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceAnalysis {
/// Cosine similarity between the context and response embeddings, floored
/// at 0.0; 1.0 when the check is disabled or cannot be computed.
pub semantic_score: f64,
/// Factual grounding score; defaults to 1.0.
pub factual_score: f64,
/// Citation verification score; defaults to 1.0.
pub citation_score: f64,
/// Estimated hallucination probability derived from the three scores and
/// the configured sensitivity; defaults to 0.0.
pub hallucination_prob: f64,
/// Number of related graph nodes supplied with the response.
pub affected_nodes: usize,
/// Defaults to 0.0.
/// NOTE(review): not written by the gate code visible here; presumably the
/// largest residual reported by the coherence engine — confirm.
pub max_residual: f64,
/// Energy used for lane selection and copied into the decision; defaults
/// to 0.0.
/// NOTE(review): not written by the gate code visible here.
pub subgraph_energy: f64,
}
/// The default analysis is the "nothing wrong" baseline: all three scores
/// are perfect (1.0) and every risk/extent metric is zero.
impl Default for CoherenceAnalysis {
    fn default() -> Self {
        let perfect: f64 = 1.0;
        let none: f64 = 0.0;
        Self {
            semantic_score: perfect,
            factual_score: perfect,
            citation_score: perfect,
            hallucination_prob: none,
            affected_nodes: 0,
            max_residual: none,
            subgraph_energy: none,
        }
    }
}
/// Input to the gate: an LLM response plus the data needed to score it.
#[derive(Debug, Clone)]
pub struct ResponseCoherence {
/// The response text; its byte length is checked against the configured
/// maximum.
pub response: String,
/// Embedding of the context the response is compared against.
pub context_embedding: Vec<f32>,
/// Embedding of the response itself.
/// NOTE(review): assumed to have the same dimension as `context_embedding`
/// — the gate does not verify this.
pub response_embedding: Vec<f32>,
/// Graph nodes related to the response; when empty the gate skips scoring
/// and uses the default analysis.
pub related_nodes: Vec<crate::NodeId>,
/// Optional session identifier; not read by the gate code visible here.
pub session_id: Option<String>,
}
impl LlmCoherenceGate {
    /// Creates a gate around an already shared [`CoherenceEngine`].
    ///
    /// No validation is performed, so this currently always returns `Ok`;
    /// the `Result` return type is kept so callers are unaffected.
    pub fn new(
        engine: Arc<CoherenceEngine>,
        policy: PolicyBundle,
        config: LlmCoherenceConfig,
    ) -> Result<Self> {
        Ok(Self {
            engine,
            policy,
            config,
        })
    }

    /// Convenience constructor that wraps an owned engine in an `Arc` and
    /// delegates to [`Self::new`].
    pub fn from_engine(
        engine: CoherenceEngine,
        policy: PolicyBundle,
        config: LlmCoherenceConfig,
    ) -> Result<Self> {
        Self::new(Arc::new(engine), policy, config)
    }

    /// Evaluates a response and returns the gate's decision.
    ///
    /// Responses whose byte length exceeds `config.max_response_length` are
    /// rejected up front on the `Human` lane with a default analysis and an
    /// energy of 0.0. Otherwise the response is scored and the decision is
    /// derived from the resulting [`CoherenceAnalysis`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the coherence analysis step.
    pub fn evaluate(&self, response: &ResponseCoherence) -> Result<LlmGateDecision> {
        let start = Instant::now();

        // `String::len` counts bytes, not characters.
        let actual = response.response.len();
        let maximum = self.config.max_response_length;
        if actual > maximum {
            return Ok(self.create_decision(
                false,
                0.0,
                ComputeLane::Human,
                LlmGateReason::LengthExceeded { actual, maximum },
                CoherenceAnalysis::default(),
                Self::elapsed_micros(start),
            ));
        }

        let analysis = self.analyze_coherence(response)?;
        let (allowed, lane, reason) = self.determine_decision(&analysis);
        Ok(self.create_decision(
            allowed,
            analysis.subgraph_energy,
            lane,
            reason,
            analysis,
            Self::elapsed_micros(start),
        ))
    }

    /// Microseconds elapsed since `start`, saturating at `u64::MAX` rather
    /// than silently truncating the `u128` that `Duration::as_micros` returns.
    fn elapsed_micros(start: Instant) -> u64 {
        start.elapsed().as_micros().min(u128::from(u64::MAX)) as u64
    }

    /// Scores a response according to the checks enabled in the config.
    ///
    /// When the response has no related nodes, scoring is skipped and the
    /// default analysis is returned. `max_residual` and `subgraph_energy`
    /// are not populated here and keep their default of 0.0.
    fn analyze_coherence(&self, response: &ResponseCoherence) -> Result<CoherenceAnalysis> {
        let mut analysis = CoherenceAnalysis::default();
        if response.related_nodes.is_empty() {
            return Ok(analysis);
        }

        if self.config.semantic_consistency {
            analysis.semantic_score = self.compute_semantic_score(response);
        }
        if self.config.factual_grounding {
            analysis.factual_score = self.compute_factual_score(response);
        }
        if self.config.citation_verification {
            analysis.citation_score = self.compute_citation_score(response);
        }

        // Must run after the individual scores are filled in.
        analysis.hallucination_prob = self.estimate_hallucination_prob(&analysis);
        analysis.affected_nodes = response.related_nodes.len();
        Ok(analysis)
    }

    /// Semantic score in `[0.0, 1.0]`: the cosine similarity between the
    /// context and response embeddings, with negative similarity treated as 0.
    ///
    /// Returns 1.0 when the similarity cannot be computed (an embedding is
    /// empty or has zero magnitude). A NaN similarity (e.g. from NaN
    /// components) yields 0.0.
    fn compute_semantic_score(&self, response: &ResponseCoherence) -> f64 {
        let similarity =
            Self::cosine_similarity(&response.context_embedding, &response.response_embedding);
        // `max` then `min` (rather than `clamp`) so that NaN collapses to 0.0;
        // the upper bound guards against rounding pushing the cosine past 1.0,
        // which would make the hallucination estimate negative.
        similarity.map_or(1.0, |cos| cos.max(0.0).min(1.0))
    }

    /// Cosine similarity of two embeddings, or `None` if either is empty or
    /// has zero magnitude.
    ///
    /// Sums are accumulated in `f64` so that large `f32` components cannot
    /// overflow the squared magnitude and tiny ones do not underflow to zero.
    ///
    /// NOTE(review): if the slices differ in length, the dot product covers
    /// only the common prefix while the magnitudes use each full slice; the
    /// caller is assumed to pass embeddings of equal dimension.
    fn cosine_similarity(a: &[f32], b: &[f32]) -> Option<f64> {
        fn norm(v: &[f32]) -> f64 {
            v.iter()
                .map(|&x| f64::from(x) * f64::from(x))
                .sum::<f64>()
                .sqrt()
        }

        if a.is_empty() || b.is_empty() {
            return None;
        }
        let (mag_a, mag_b) = (norm(a), norm(b));
        if mag_a == 0.0 || mag_b == 0.0 {
            return None;
        }
        let dot: f64 = a
            .iter()
            .zip(b)
            .map(|(&x, &y)| f64::from(x) * f64::from(y))
            .sum();
        Some(dot / (mag_a * mag_b))
    }

    /// Factual grounding score. Currently a placeholder that always returns 1.0.
    fn compute_factual_score(&self, _response: &ResponseCoherence) -> f64 {
        1.0
    }

    /// Citation verification score. Currently a placeholder that always
    /// returns 1.0.
    fn compute_citation_score(&self, _response: &ResponseCoherence) -> f64 {
        1.0
    }

    /// Hallucination estimate: the shortfall of the mean of the three scores
    /// from 1.0, scaled by `config.hallucination_sensitivity`.
    fn estimate_hallucination_prob(&self, analysis: &CoherenceAnalysis) -> f64 {
        let combined =
            (analysis.semantic_score + analysis.factual_score + analysis.citation_score) / 3.0;
        (1.0 - combined) * self.config.hallucination_sensitivity
    }

    /// Maps an analysis to `(allowed, lane, reason)`.
    ///
    /// Checks run in order: hallucination estimate, then semantic score
    /// against `config.coherence_threshold`; if both pass, the response is
    /// allowed and the lane is chosen from `subgraph_energy`.
    fn determine_decision(
        &self,
        analysis: &CoherenceAnalysis,
    ) -> (bool, ComputeLane, LlmGateReason) {
        // NOTE(review): `estimate_hallucination_prob` already multiplies by
        // `hallucination_sensitivity`, so with scores in [0, 1] and a
        // non-negative sensitivity the estimate can never exceed the
        // sensitivity and this branch cannot fire. The intended threshold is
        // not visible here, so the comparison is left unchanged — confirm.
        if analysis.hallucination_prob > self.config.hallucination_sensitivity {
            let reason = LlmGateReason::HallucinationDetected {
                confidence: analysis.hallucination_prob,
                description: "Response may contain hallucinated content".to_string(),
            };
            return (false, ComputeLane::Human, reason);
        }

        if analysis.semantic_score < self.config.coherence_threshold {
            let description = format!(
                "Semantic score {:.2} below threshold {:.2}",
                analysis.semantic_score, self.config.coherence_threshold
            );
            return (
                false,
                ComputeLane::Heavy,
                LlmGateReason::SemanticInconsistency { description },
            );
        }

        let lane = self.determine_lane(analysis.subgraph_energy);
        (true, lane, LlmGateReason::Coherent)
    }

    /// Picks the first lane whose threshold in `config.lane_thresholds` is
    /// strictly above `energy`, falling back to `Human` (which is also where
    /// a NaN energy ends up, since every `<` comparison is false).
    fn determine_lane(&self, energy: f64) -> ComputeLane {
        let limits = &self.config.lane_thresholds;
        match energy {
            e if e < limits.reflex => ComputeLane::Reflex,
            e if e < limits.retrieval => ComputeLane::Retrieval,
            e if e < limits.heavy => ComputeLane::Heavy,
            _ => ComputeLane::Human,
        }
    }

    /// Assembles an [`LlmGateDecision`] from its parts and stamps it with the
    /// current UTC time.
    fn create_decision(
        &self,
        allowed: bool,
        energy: f64,
        lane: ComputeLane,
        reason: LlmGateReason,
        analysis: CoherenceAnalysis,
        processing_time_us: u64,
    ) -> LlmGateDecision {
        LlmGateDecision {
            allowed,
            energy,
            lane,
            reason,
            analysis,
            processing_time_us,
            timestamp: chrono::Utc::now(),
        }
    }

    /// Returns the gate's configuration.
    pub fn config(&self) -> &LlmCoherenceConfig {
        &self.config
    }

    /// Returns the policy bundle associated with this gate.
    pub fn policy(&self) -> &PolicyBundle {
        &self.policy
    }

    /// Returns a reference to the shared coherence engine.
    pub fn engine(&self) -> &CoherenceEngine {
        &self.engine
    }
}
impl LlmGateDecision {
    /// Returns `true` when the gate allowed the response.
    pub fn is_allowed(&self) -> bool {
        let Self { allowed, .. } = self;
        *allowed
    }

    /// Returns `true` when the decision was routed to the `Human` lane.
    pub fn requires_escalation(&self) -> bool {
        let lane = &self.lane;
        matches!(lane, ComputeLane::Human)
    }
}