use crate::{
core::{GraphRAGError, KnowledgeGraph, Result},
ollama::OllamaClient,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// One recorded optimizer iteration: the three objective scores, their weighted
/// combination, and a snapshot of relationship confidences taken at that point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizationStep {
    /// Iteration index this step was recorded at.
    pub iteration: usize,
    /// Relevance score for this iteration.
    pub relevance_score: f32,
    /// Faithfulness score for this iteration.
    pub faithfulness_score: f32,
    /// Conciseness score for this iteration.
    pub conciseness_score: f32,
    /// Weighted sum of the three scores; set by [`OptimizationStep::calculate_combined`].
    pub combined_score: f32,
    /// Relationship confidences captured when the step was recorded.
    pub weights_snapshot: HashMap<String, f32>,
}

impl OptimizationStep {
    /// Creates a step for `iteration` with every score at zero and an empty snapshot.
    pub fn new(iteration: usize) -> Self {
        let weights_snapshot = HashMap::new();
        Self {
            iteration,
            relevance_score: 0.0,
            faithfulness_score: 0.0,
            conciseness_score: 0.0,
            combined_score: 0.0,
            weights_snapshot,
        }
    }

    /// Stores in `combined_score` the weighted sum of the three objective scores.
    ///
    /// The weights are applied exactly as given; no normalization happens here.
    pub fn calculate_combined(&mut self, weights: &ObjectiveWeights) {
        let relevance_term = self.relevance_score * weights.relevance;
        let faithfulness_term = self.faithfulness_score * weights.faithfulness;
        let conciseness_term = self.conciseness_score * weights.conciseness;
        self.combined_score = relevance_term + faithfulness_term + conciseness_term;
    }
}
/// Relative importance of relevance, faithfulness and conciseness when the three
/// scores are folded into one combined score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObjectiveWeights {
    /// Weight applied to the relevance score.
    pub relevance: f32,
    /// Weight applied to the faithfulness score.
    pub faithfulness: f32,
    /// Weight applied to the conciseness score.
    pub conciseness: f32,
}

impl Default for ObjectiveWeights {
    /// 0.4 relevance, 0.4 faithfulness, 0.2 conciseness.
    fn default() -> Self {
        ObjectiveWeights {
            conciseness: 0.2,
            faithfulness: 0.4,
            relevance: 0.4,
        }
    }
}

impl ObjectiveWeights {
    /// Rescales the three weights so they sum to 1.
    ///
    /// Leaves the weights untouched when their sum is not strictly positive
    /// (including when it is NaN).
    pub fn normalize(&mut self) {
        let total = self.relevance + self.faithfulness + self.conciseness;
        if total > 0.0 {
            for weight in [
                &mut self.relevance,
                &mut self.faithfulness,
                &mut self.conciseness,
            ] {
                *weight /= total;
            }
        }
    }

    /// Adds `boost` to the named objective (`"relevance"`, `"faithfulness"` or
    /// `"conciseness"`), then renormalizes.
    ///
    /// Unknown names leave every weight as-is apart from the renormalization.
    pub fn boost_objective(&mut self, objective: &str, boost: f32) {
        let target = match objective {
            "relevance" => Some(&mut self.relevance),
            "faithfulness" => Some(&mut self.faithfulness),
            "conciseness" => Some(&mut self.conciseness),
            _ => None,
        };
        if let Some(weight) = target {
            *weight += boost;
        }
        self.normalize();
    }
}
/// A query used to score the graph, paired with the answer the graph should support.
#[derive(Debug, Clone)]
pub struct TestQuery {
    /// Natural-language query text.
    pub query: String,
    /// Reference answer the retrieved information is judged against.
    pub expected_answer: String,
    /// Relative weight of this query when averaging scores over all queries.
    pub weight: f32,
}

impl TestQuery {
    /// Creates a query with the default weight of 1.0.
    pub fn new(query: String, expected_answer: String) -> Self {
        let weight = 1.0;
        TestQuery {
            query,
            expected_answer,
            weight,
        }
    }

    /// Returns the same query with its weight replaced by `weight`.
    pub fn with_weight(self, weight: f32) -> Self {
        TestQuery { weight, ..self }
    }
}
/// Tuning parameters for [`GraphWeightOptimizer`].
#[derive(Debug, Clone)]
pub struct OptimizerConfig {
    /// Step size used when scaling relationship confidences.
    pub learning_rate: f32,
    /// Maximum number of evaluate/adjust iterations per optimization run.
    pub max_iterations: usize,
    /// Distance, in steps, between the two samples used to estimate a metric's slope.
    pub slope_window: usize,
    /// Absolute slope below which a metric is treated as stagnating.
    pub stagnation_threshold: f32,
    /// Objective weights the optimizer starts from.
    pub objective_weights: ObjectiveWeights,
    /// Evaluate with the LLM when an Ollama client is attached; otherwise heuristics are used.
    pub use_llm_eval: bool,
}

impl Default for OptimizerConfig {
    /// learning rate 0.1, 20 iterations, slope window 3, stagnation threshold 0.01,
    /// default objective weights, LLM evaluation enabled.
    fn default() -> Self {
        let objective_weights = ObjectiveWeights::default();
        OptimizerConfig {
            use_llm_eval: true,
            objective_weights,
            stagnation_threshold: 0.01,
            slope_window: 3,
            max_iterations: 20,
            learning_rate: 0.1,
        }
    }
}
/// Iteratively scores a [`KnowledgeGraph`] against test queries and records each result.
///
/// Scoring goes through an attached [`OllamaClient`] when LLM evaluation is enabled,
/// otherwise through string-matching heuristics. The objective weights used for the
/// combined score drift during a run as stagnating objectives are boosted.
pub struct GraphWeightOptimizer {
    /// Parameters supplied at construction.
    config: OptimizerConfig,
    /// One entry per evaluated iteration, appended in order and not cleared between runs.
    history: Vec<OptimizationStep>,
    /// Optional LLM client used for evaluation.
    ollama_client: Option<OllamaClient>,
    /// Weights currently used to compute combined scores.
    current_weights: ObjectiveWeights,
}
impl GraphWeightOptimizer {
    /// Minimum number of characters a token needs to take part in heuristic scoring.
    ///
    /// Shorter tokens ("a", "of", "is") are substrings of almost every entity name and
    /// would otherwise inflate the relevance and faithfulness scores.
    const MIN_TOKEN_CHARS: usize = 3;

    /// Creates an optimizer with [`OptimizerConfig::default`] and no LLM client.
    pub fn new() -> Self {
        Self::with_config(OptimizerConfig::default())
    }

    /// Creates an optimizer from an explicit configuration.
    ///
    /// The configured objective weights are normalized up front. Every stagnation boost
    /// renormalizes them anyway, so doing it here keeps combined scores on the same scale
    /// for the whole run.
    pub fn with_config(config: OptimizerConfig) -> Self {
        let mut current_weights = config.objective_weights.clone();
        current_weights.normalize();
        Self {
            config,
            history: Vec::new(),
            ollama_client: None,
            current_weights,
        }
    }

    /// Attaches the Ollama client used for evaluation when `config.use_llm_eval` is set.
    pub fn with_ollama_client(mut self, client: OllamaClient) -> Self {
        self.ollama_client = Some(client);
        self
    }

    /// Runs up to `config.max_iterations` evaluate → record → adjust iterations.
    ///
    /// Each iteration does the following:
    /// - scores `graph` against `test_queries`;
    /// - appends an [`OptimizationStep`] to the history;
    /// - boosts stagnating objectives once `slope_window` earlier steps exist;
    /// - calls `adjust_graph_weights`, except on the final iteration.
    ///
    /// The loop stops early when all three scores exceed 0.95.
    ///
    /// History and objective weights are not reset between calls, so a second call
    /// continues from the state left by the first.
    ///
    /// # Errors
    ///
    /// Returns [`GraphRAGError::Config`] when `test_queries` is empty. It also propagates
    /// any error raised while evaluating or adjusting the graph.
    #[cfg(feature = "async")]
    pub async fn optimize_weights(
        &mut self,
        graph: &mut KnowledgeGraph,
        test_queries: &[TestQuery],
    ) -> Result<()> {
        if test_queries.is_empty() {
            return Err(GraphRAGError::Config {
                message: "No test queries provided for optimization".to_string(),
            });
        }

        #[cfg(feature = "tracing")]
        tracing::info!(
            max_iterations = self.config.max_iterations,
            num_queries = test_queries.len(),
            "Starting graph weight optimization"
        );

        for iteration in 0..self.config.max_iterations {
            let (relevance, faithfulness, conciseness) =
                self.evaluate_graph_quality(graph, test_queries).await?;

            let mut step = OptimizationStep::new(iteration);
            step.relevance_score = relevance;
            step.faithfulness_score = faithfulness;
            step.conciseness_score = conciseness;
            step.calculate_combined(&self.current_weights);
            step.weights_snapshot = self.snapshot_weights(graph);

            #[cfg(feature = "tracing")]
            tracing::info!(
                iteration = iteration,
                relevance = step.relevance_score,
                faithfulness = step.faithfulness_score,
                conciseness = step.conciseness_score,
                combined = step.combined_score,
                "Optimization iteration complete"
            );

            let all_excellent = step.relevance_score > 0.95
                && step.faithfulness_score > 0.95
                && step.conciseness_score > 0.95;

            // Move the step into the history rather than cloning it: its snapshot holds one
            // entry per relationship and can be large.
            self.history.push(step);

            if iteration >= self.config.slope_window {
                self.detect_and_adjust_stagnation();
            }

            if all_excellent {
                #[cfg(feature = "tracing")]
                tracing::info!("Early stopping: all metrics excellent");
                break;
            }

            // Skip adjusting after the last evaluation: nothing would measure its effect.
            if iteration + 1 < self.config.max_iterations {
                if let Some(step) = self.history.last() {
                    self.adjust_graph_weights(graph, test_queries, step).await?;
                }
            }
        }

        #[cfg(feature = "tracing")]
        tracing::info!(
            iterations = self.history.len(),
            final_score = self.history.last().map(|s| s.combined_score).unwrap_or(0.0),
            "Optimization complete"
        );

        Ok(())
    }

    /// Averages `(relevance, faithfulness, conciseness)` over `test_queries`, weighting
    /// each query by its `weight`.
    ///
    /// LLM evaluation is used when enabled and a client is attached; otherwise heuristics
    /// are used. Returns all zeros when the total weight is not positive.
    #[cfg(feature = "async")]
    async fn evaluate_graph_quality(
        &self,
        graph: &KnowledgeGraph,
        test_queries: &[TestQuery],
    ) -> Result<(f32, f32, f32)> {
        let mut total_relevance = 0.0;
        let mut total_faithfulness = 0.0;
        let mut total_conciseness = 0.0;
        let mut total_weight = 0.0;

        for test_query in test_queries {
            let (relevance, faithfulness, conciseness) =
                if self.config.use_llm_eval && self.ollama_client.is_some() {
                    self.evaluate_with_llm(graph, test_query).await?
                } else {
                    self.evaluate_with_heuristics(graph, test_query)?
                };
            total_relevance += relevance * test_query.weight;
            total_faithfulness += faithfulness * test_query.weight;
            total_conciseness += conciseness * test_query.weight;
            total_weight += test_query.weight;
        }

        if total_weight > 0.0 {
            Ok((
                total_relevance / total_weight,
                total_faithfulness / total_weight,
                total_conciseness / total_weight,
            ))
        } else {
            Ok((0.0, 0.0, 0.0))
        }
    }

    /// Splits `text` into lowercase tokens for matching against graph strings.
    ///
    /// Surrounding non-alphanumeric characters are stripped, so `"philosophy?"` becomes
    /// `"philosophy"`. Tokens with fewer than `min_chars` characters are dropped. A
    /// `min_chars` of at least 1 also drops empty tokens, which would otherwise match
    /// every string.
    fn tokenize(text: &str, min_chars: usize) -> Vec<String> {
        text.to_lowercase()
            .split_whitespace()
            .map(|t| t.trim_matches(|c: char| !c.is_alphanumeric()))
            .filter(|t| t.chars().count() >= min_chars)
            .map(String::from)
            .collect()
    }

    /// Scores `graph` against one test query with string matching only (no LLM).
    ///
    /// Returns `(relevance, faithfulness, conciseness)`:
    /// - **relevance**: the number of entities whose name contains a query token, divided
    ///   by `min(entity_count, 10)` and capped at 1.0. It is 0.0 for an empty graph.
    /// - **faithfulness**: the fraction of expected-answer tokens found in some entity
    ///   name, entity type or relation type. It is 0.5 when the answer has no usable tokens.
    /// - **conciseness**: `0.7 * mean relationship confidence + 0.3 * (1 - min(rel_count / 100, 1))`.
    ///   It does not depend on the query.
    ///
    /// Query and answer tokens both go through [`Self::tokenize`] with
    /// [`Self::MIN_TOKEN_CHARS`].
    fn evaluate_with_heuristics(
        &self,
        graph: &KnowledgeGraph,
        test_query: &TestQuery,
    ) -> Result<(f32, f32, f32)> {
        let query_tokens = Self::tokenize(&test_query.query, Self::MIN_TOKEN_CHARS);
        let answer_tokens = Self::tokenize(&test_query.expected_answer, Self::MIN_TOKEN_CHARS);

        // Lowercase every searchable string once instead of once per token.
        let entity_texts: Vec<(String, String)> = graph
            .entities()
            .map(|e| (e.name.to_lowercase(), e.entity_type.to_lowercase()))
            .collect();
        let relationships = graph.get_all_relationships();
        let relation_types: Vec<String> = relationships
            .iter()
            .map(|r| r.relation_type.to_lowercase())
            .collect();

        let total_entities = entity_texts.len();
        let matching_entities = entity_texts
            .iter()
            .filter(|(name, _)| query_tokens.iter().any(|t| name.contains(t.as_str())))
            .count();
        let relevance = if total_entities > 0 {
            (matching_entities as f32 / total_entities.min(10) as f32).min(1.0)
        } else {
            0.0
        };

        let answer_tokens_found = answer_tokens
            .iter()
            .filter(|token| {
                let token = token.as_str();
                entity_texts
                    .iter()
                    .any(|(name, ty)| name.contains(token) || ty.contains(token))
                    || relation_types.iter().any(|rt| rt.contains(token))
            })
            .count();
        let faithfulness = if answer_tokens.is_empty() {
            0.5
        } else {
            answer_tokens_found as f32 / answer_tokens.len() as f32
        };

        let relationship_count = relationships.len();
        let avg_confidence = relationships.iter().map(|r| r.confidence).sum::<f32>()
            / relationship_count.max(1) as f32;
        let complexity_penalty = (relationship_count as f32 / 100.0).min(1.0);
        let conciseness = (avg_confidence * 0.7) + ((1.0 - complexity_penalty) * 0.3);

        Ok((relevance, faithfulness, conciseness))
    }

    /// Asks the attached Ollama model to score the retrieved context for one query.
    ///
    /// # Errors
    ///
    /// - [`GraphRAGError::Config`] when no client is attached.
    /// - [`GraphRAGError::LanguageModel`] when the generation request fails.
    ///
    /// Replies that cannot be parsed fall back to 0.5 for every score.
    #[cfg(feature = "async")]
    async fn evaluate_with_llm(
        &self,
        graph: &KnowledgeGraph,
        test_query: &TestQuery,
    ) -> Result<(f32, f32, f32)> {
        let ollama_client = self
            .ollama_client
            .as_ref()
            .ok_or_else(|| GraphRAGError::Config {
                message: "LLM evaluation requested but no Ollama client available".to_string(),
            })?;

        let context = self.build_graph_context(graph, &test_query.query, 5);
        let prompt = format!(
            "Evaluate the quality of information retrieval for this query.\n\n\
             Query: {}\n\
             Expected Answer: {}\n\n\
             Retrieved Information:\n{}\n\n\
             Please evaluate on three dimensions (0.0-1.0 scale):\n\
             1. Relevance: How well does the retrieved information match the query?\n\
             2. Faithfulness: How accurate is the information compared to the expected answer?\n\
             3. Conciseness: How focused and non-redundant is the information?\n\n\
             Respond with JSON format:\n\
             {{\"relevance\": 0.8, \"faithfulness\": 0.7, \"conciseness\": 0.9}}",
            test_query.query, test_query.expected_answer, context
        );

        let response = ollama_client
            .generate(&prompt)
            .await
            .map_err(|e| GraphRAGError::LanguageModel {
                message: format!("LLM evaluation failed: {}", e),
            })?;

        self.parse_llm_evaluation(&response)
    }

    /// Builds the plain-text "retrieved information" section shown to the LLM judge.
    ///
    /// Entities are ranked by how many query tokens their lowercase name contains. Ties
    /// are broken by name so the prompt is deterministic. The top `top_k` entities with a
    /// non-zero score are listed, followed by up to `top_k` relationships whose source or
    /// target is one of the listed entities.
    fn build_graph_context(&self, graph: &KnowledgeGraph, query: &str, top_k: usize) -> String {
        // Keep short tokens here (e.g. "0" in "entity 0"): they only affect ranking.
        let query_tokens = Self::tokenize(query, 1);

        let mut ranked: Vec<_> = graph
            .entities()
            .map(|e| {
                let name_lower = e.name.to_lowercase();
                let score = query_tokens
                    .iter()
                    .filter(|t| name_lower.contains(t.as_str()))
                    .count();
                (e, score)
            })
            .filter(|(_, score)| *score > 0)
            .collect();
        ranked.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.name.cmp(&b.0.name)));
        ranked.truncate(top_k);

        let mut context = String::from("Entities:\n");
        for (entity, _) in &ranked {
            context.push_str(&format!("- {} ({})\n", entity.name, entity.entity_type));
        }

        // Only relationships touching the selected entities count as "retrieved" for this query.
        context.push_str("\nRelationships:\n");
        let selected_ids: Vec<&str> = ranked.iter().map(|(e, _)| e.id.0.as_str()).collect();
        for rel in graph
            .get_all_relationships()
            .iter()
            .filter(|r| {
                selected_ids
                    .iter()
                    .any(|id| *id == r.source.0 || *id == r.target.0)
            })
            .take(top_k)
        {
            context.push_str(&format!(
                "- {} --[{}]--> {}\n",
                rel.source.0, rel.relation_type, rel.target.0
            ));
        }

        context
    }

    /// Extracts `relevance`, `faithfulness` and `conciseness` from an LLM reply.
    ///
    /// The span from the first `{` to the last `}` is parsed as JSON. Each score may be a
    /// JSON number or a numeric string. Missing or non-finite values default to 0.5, and
    /// every score is clamped to `0.0..=1.0`. When no JSON object can be parsed, all three
    /// scores are 0.5.
    fn parse_llm_evaluation(&self, response: &str) -> Result<(f32, f32, f32)> {
        if let (Some(start), Some(end)) = (response.find('{'), response.rfind('}')) {
            if end > start {
                let json_str = &response[start..=end];
                if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(json_str) {
                    let score = |key: &str| -> f32 {
                        let value = &parsed[key];
                        value
                            .as_f64()
                            .or_else(|| value.as_str().and_then(|s| s.trim().parse::<f64>().ok()))
                            .filter(|v| v.is_finite())
                            .unwrap_or(0.5) as f32
                    };
                    return Ok((
                        score("relevance").clamp(0.0, 1.0),
                        score("faithfulness").clamp(0.0, 1.0),
                        score("conciseness").clamp(0.0, 1.0),
                    ));
                }
            }
        }

        #[cfg(feature = "tracing")]
        tracing::warn!("Failed to parse LLM evaluation, using default scores");

        Ok((0.5, 0.5, 0.5))
    }

    /// Records the current confidence of every relationship.
    ///
    /// Keys have the form `"{source}--[{relation_type}]-->{target}"`. The relation type is
    /// part of the key so that several relationships between the same pair of entities do
    /// not overwrite one another. The delimiters also keep ids containing underscores from
    /// colliding.
    fn snapshot_weights(&self, graph: &KnowledgeGraph) -> HashMap<String, f32> {
        graph
            .get_all_relationships()
            .iter()
            .map(|rel| {
                (
                    format!("{}--[{}]-->{}", rel.source.0, rel.relation_type, rel.target.0),
                    rel.confidence,
                )
            })
            .collect()
    }

    /// Boosts by 0.05 the weight of each objective whose slope over the last
    /// `slope_window` steps has an absolute value below `stagnation_threshold`.
    ///
    /// Objectives are checked in the order relevance, faithfulness, conciseness, and
    /// weights are renormalized after each boost. Does nothing when the window is 0 or
    /// the history is too short.
    fn detect_and_adjust_stagnation(&mut self) {
        let window_size = self.config.slope_window;
        // A zero-width window has no measurable slope, so detection is disabled.
        if window_size == 0 || self.history.len() <= window_size {
            return;
        }

        let relevance_slope = self.calculate_slope(window_size, |s| s.relevance_score);
        let faithfulness_slope = self.calculate_slope(window_size, |s| s.faithfulness_score);
        let conciseness_slope = self.calculate_slope(window_size, |s| s.conciseness_score);

        #[cfg(feature = "tracing")]
        tracing::debug!(
            relevance_slope = relevance_slope,
            faithfulness_slope = faithfulness_slope,
            conciseness_slope = conciseness_slope,
            threshold = self.config.stagnation_threshold,
            "Stagnation detection"
        );

        for &(objective, slope) in &[
            ("relevance", relevance_slope),
            ("faithfulness", faithfulness_slope),
            ("conciseness", conciseness_slope),
        ] {
            if slope.abs() < self.config.stagnation_threshold {
                self.current_weights.boost_objective(objective, 0.05);
                #[cfg(feature = "tracing")]
                tracing::info!("Boosting {} weight due to stagnation", objective);
            }
        }
    }

    /// Average per-step change of `metric_fn` between the step `window_size` positions
    /// before the latest one and the latest step.
    ///
    /// Returns 0.0 when `window_size` is 0 (which would otherwise divide by zero) or when
    /// the history holds `window_size` steps or fewer.
    fn calculate_slope<F>(&self, window_size: usize, metric_fn: F) -> f32
    where
        F: Fn(&OptimizationStep) -> f32,
    {
        let len = self.history.len();
        // `len <= window_size` rather than `len < window_size + 1` avoids overflow.
        if window_size == 0 || len <= window_size {
            return 0.0;
        }
        let first_value = metric_fn(&self.history[len - window_size - 1]);
        let last_value = metric_fn(&self.history[len - 1]);
        (last_value - first_value) / window_size as f32
    }

    /// Computes an adjusted confidence for every relationship from the current step's scores.
    ///
    /// The following factors are applied multiplicatively, and the result is clamped to
    /// `0.1..=1.0`:
    /// - relevance < 0.8: relationships with an embedding gain `learning_rate * 0.5`;
    /// - faithfulness < 0.8: relationships with temporal or causal data gain `learning_rate * 0.3`;
    /// - conciseness < 0.8: every relationship loses `learning_rate * 0.1`.
    ///
    /// NOTE(review): the adjusted value is never written back to `graph`, so this method
    /// currently leaves the graph unchanged and every iteration re-evaluates the same graph.
    /// Persisting it needs a relationship-confidence mutation API on `KnowledgeGraph` that
    /// is not visible from this module. TODO.
    #[cfg(feature = "async")]
    async fn adjust_graph_weights(
        &self,
        graph: &mut KnowledgeGraph,
        _test_queries: &[TestQuery],
        current_step: &OptimizationStep,
    ) -> Result<()> {
        let learning_rate = self.config.learning_rate;
        let needs_relevance = current_step.relevance_score < 0.8;
        let needs_faithfulness = current_step.faithfulness_score < 0.8;
        let needs_conciseness = current_step.conciseness_score < 0.8;

        // Iterate by reference: nothing is written back, so cloning every relationship
        // (as a `to_vec()` would) buys nothing.
        for rel in graph.get_all_relationships().iter() {
            let mut new_confidence = rel.confidence;
            if needs_relevance && rel.embedding.is_some() {
                new_confidence *= 1.0 + learning_rate * 0.5;
            }
            if needs_faithfulness
                && (rel.temporal_type.is_some() || rel.causal_strength.is_some())
            {
                new_confidence *= 1.0 + learning_rate * 0.3;
            }
            if needs_conciseness {
                new_confidence *= 1.0 - learning_rate * 0.1;
            }
            let _proposed_confidence = new_confidence.clamp(0.1, 1.0);
        }

        Ok(())
    }

    /// Every recorded step, oldest first.
    pub fn history(&self) -> &[OptimizationStep] {
        &self.history
    }

    /// `(relevance, faithfulness, conciseness, combined)` of the latest step, if any.
    pub fn final_metrics(&self) -> Option<(f32, f32, f32, f32)> {
        self.history.last().map(|step| {
            (
                step.relevance_score,
                step.faithfulness_score,
                step.conciseness_score,
                step.combined_score,
            )
        })
    }

    /// Combined score of the latest step minus that of the first step.
    ///
    /// Returns 0.0 when fewer than two steps have been recorded.
    pub fn total_improvement(&self) -> f32 {
        match (self.history.first(), self.history.last()) {
            (Some(first), Some(last)) if self.history.len() >= 2 => {
                last.combined_score - first.combined_score
            }
            _ => 0.0,
        }
    }
}
impl Default for GraphWeightOptimizer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{Entity, EntityId, Relationship};

    /// Builds a bare entity (no mentions, embedding or temporal data) for graph fixtures.
    fn entity(id: &str, name: &str, entity_type: &str) -> Entity {
        Entity {
            id: EntityId(id.to_string()),
            name: name.to_string(),
            entity_type: entity_type.to_string(),
            confidence: 0.9,
            mentions: vec![],
            embedding: None,
            first_mentioned: None,
            last_mentioned: None,
            temporal_validity: None,
        }
    }

    /// Pushes a step with the given combined score onto the optimizer's history.
    fn push_step(optimizer: &mut GraphWeightOptimizer, iteration: usize, combined: f32) {
        let mut step = OptimizationStep::new(iteration);
        step.combined_score = combined;
        optimizer.history.push(step);
    }

    #[test]
    fn test_optimization_step_creation() {
        let step = OptimizationStep::new(0);
        assert_eq!(step.iteration, 0);
        assert_eq!(step.relevance_score, 0.0);
        assert_eq!(step.combined_score, 0.0);
        assert!(step.weights_snapshot.is_empty());
    }

    #[test]
    fn test_objective_weights_normalization() {
        let mut weights = ObjectiveWeights {
            relevance: 2.0,
            faithfulness: 2.0,
            conciseness: 2.0,
        };
        weights.normalize();
        let sum = weights.relevance + weights.faithfulness + weights.conciseness;
        assert!((sum - 1.0).abs() < 0.001);
        assert!((weights.relevance - 1.0 / 3.0).abs() < 0.001);
    }

    #[test]
    fn test_objective_weights_boost() {
        let mut weights = ObjectiveWeights::default();
        let original_relevance = weights.relevance;
        weights.boost_objective("relevance", 0.1);
        assert!(weights.relevance > original_relevance);
        let sum = weights.relevance + weights.faithfulness + weights.conciseness;
        assert!((sum - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_objective_weights_boost_unknown_objective() {
        let mut weights = ObjectiveWeights::default();
        weights.boost_objective("speed", 0.5);
        assert!((weights.relevance - 0.4).abs() < 1e-6);
        assert!((weights.faithfulness - 0.4).abs() < 1e-6);
        assert!((weights.conciseness - 0.2).abs() < 1e-6);
    }

    #[test]
    fn test_test_query_creation() {
        let query = TestQuery::new("test query".to_string(), "expected".to_string());
        assert_eq!(query.weight, 1.0);
        let weighted = TestQuery::new("test".to_string(), "expected".to_string()).with_weight(2.0);
        assert_eq!(weighted.weight, 2.0);
    }

    #[test]
    fn test_optimizer_initialization() {
        let optimizer = GraphWeightOptimizer::new();
        assert_eq!(optimizer.history.len(), 0);
        assert_eq!(optimizer.config.max_iterations, 20);
    }

    #[test]
    fn test_optimizer_with_config_keeps_settings() {
        let config = OptimizerConfig {
            max_iterations: 7,
            slope_window: 2,
            ..OptimizerConfig::default()
        };
        let optimizer = GraphWeightOptimizer::with_config(config);
        assert_eq!(optimizer.config.max_iterations, 7);
        assert_eq!(optimizer.config.slope_window, 2);
        assert!(optimizer.history().is_empty());
    }

    #[test]
    fn test_slope_calculation() {
        let mut optimizer = GraphWeightOptimizer::new();
        for i in 0..5 {
            let mut step = OptimizationStep::new(i);
            step.relevance_score = 0.5 + (i as f32 * 0.1);
            optimizer.history.push(step);
        }
        let slope = optimizer.calculate_slope(3, |s| s.relevance_score);
        assert!(slope > 0.0);
        // Samples are history[1] = 0.6 and history[4] = 0.9, three steps apart.
        assert!((slope - 0.1).abs() < 1e-5);
    }

    #[test]
    fn test_slope_with_short_history_is_zero() {
        let mut optimizer = GraphWeightOptimizer::new();
        push_step(&mut optimizer, 0, 0.1);
        push_step(&mut optimizer, 1, 0.9);
        assert_eq!(optimizer.calculate_slope(3, |s| s.combined_score), 0.0);
    }

    #[test]
    fn test_total_improvement_and_final_metrics() {
        let mut optimizer = GraphWeightOptimizer::new();
        assert!(optimizer.final_metrics().is_none());
        assert_eq!(optimizer.total_improvement(), 0.0);

        push_step(&mut optimizer, 0, 0.25);
        assert_eq!(optimizer.total_improvement(), 0.0);

        push_step(&mut optimizer, 1, 0.75);
        assert!((optimizer.total_improvement() - 0.5).abs() < 1e-6);
        let (_, _, _, combined) = optimizer.final_metrics().unwrap();
        assert_eq!(combined, 0.75);
    }

    #[test]
    fn test_combined_score_calculation() {
        let weights = ObjectiveWeights {
            relevance: 0.5,
            faithfulness: 0.3,
            conciseness: 0.2,
        };
        let mut step = OptimizationStep::new(0);
        step.relevance_score = 0.8;
        step.faithfulness_score = 0.6;
        step.conciseness_score = 0.9;
        step.calculate_combined(&weights);
        let expected = 0.8 * 0.5 + 0.6 * 0.3 + 0.9 * 0.2;
        assert!((step.combined_score - expected).abs() < 0.001);
    }

    #[test]
    fn test_heuristic_evaluation() {
        let mut graph = KnowledgeGraph::new();
        graph.add_entity(entity("socrates", "Socrates", "PERSON")).unwrap();
        graph.add_entity(entity("philosophy", "Philosophy", "CONCEPT")).unwrap();
        let rel = Relationship::new(
            EntityId("socrates".to_string()),
            EntityId("philosophy".to_string()),
            "FOUNDED".to_string(),
            0.9,
        );
        graph.add_relationship(rel).unwrap();

        let optimizer = GraphWeightOptimizer::new();
        let query = TestQuery::new(
            "Tell me about Socrates and philosophy".to_string(),
            "Socrates founded philosophy".to_string(),
        );
        let (relevance, faithfulness, conciseness) =
            optimizer.evaluate_with_heuristics(&graph, &query).unwrap();

        for (label, value) in [
            ("Relevance", relevance),
            ("Faithfulness", faithfulness),
            ("Conciseness", conciseness),
        ] {
            assert!((0.0..=1.0).contains(&value), "{} out of range: {}", label, value);
        }
        // Both entity names contain a query token.
        assert!(
            (relevance - 1.0).abs() < 1e-6,
            "Both entities should be relevant (relevance={})",
            relevance
        );
        // "socrates" and "philosophy" appear in entity names.
        assert!(
            faithfulness >= 2.0 / 3.0 - 1e-6,
            "Should match expected answer (faithfulness={})",
            faithfulness
        );
    }

    #[test]
    fn test_heuristic_evaluation_empty_graph() {
        let graph = KnowledgeGraph::new();
        let optimizer = GraphWeightOptimizer::new();
        let query = TestQuery::new("test query".to_string(), "test answer".to_string());
        let (relevance, faithfulness, conciseness) =
            optimizer.evaluate_with_heuristics(&graph, &query).unwrap();
        assert_eq!(relevance, 0.0, "Empty graph should have zero relevance");
        assert!(faithfulness >= 0.0, "Faithfulness should be non-negative");
        assert!(conciseness >= 0.0, "Conciseness should be non-negative");
    }

    #[test]
    fn test_graph_context_building() {
        let mut graph = KnowledgeGraph::new();
        for i in 0..5 {
            graph
                .add_entity(entity(&format!("entity_{}", i), &format!("Entity {}", i), "TEST"))
                .unwrap();
        }
        for i in 0..4 {
            let rel = Relationship::new(
                EntityId(format!("entity_{}", i)),
                EntityId(format!("entity_{}", i + 1)),
                "RELATES_TO".to_string(),
                0.8,
            );
            graph.add_relationship(rel).unwrap();
        }

        let optimizer = GraphWeightOptimizer::new();
        let context = optimizer.build_graph_context(&graph, "entity 0", 3);
        assert!(context.contains("Entities:"), "Context should include entities");
        assert!(
            context.contains("Relationships:"),
            "Context should include relationships"
        );
        // "Entity 0" matches both query tokens, so it ranks first and is always listed.
        assert!(
            context.contains("- Entity 0 (TEST)"),
            "Best-matching entity should be listed: {}",
            context
        );
    }

    #[test]
    fn test_llm_evaluation_parse_json() {
        let optimizer = GraphWeightOptimizer::new();
        let response = r#"Here is my evaluation:
{"relevance": 0.8, "faithfulness": 0.7, "conciseness": 0.9}
That's my assessment."#;
        let (relevance, faithfulness, conciseness) =
            optimizer.parse_llm_evaluation(response).unwrap();
        assert!((relevance - 0.8).abs() < 0.001);
        assert!((faithfulness - 0.7).abs() < 0.001);
        assert!((conciseness - 0.9).abs() < 0.001);
    }

    #[test]
    fn test_llm_evaluation_parse_clamps_out_of_range() {
        let optimizer = GraphWeightOptimizer::new();
        let response = r#"{"relevance": 1.7, "faithfulness": -0.3, "conciseness": 0.4}"#;
        let (relevance, faithfulness, conciseness) =
            optimizer.parse_llm_evaluation(response).unwrap();
        assert_eq!(relevance, 1.0);
        assert_eq!(faithfulness, 0.0);
        assert!((conciseness - 0.4).abs() < 0.001);
    }

    #[test]
    fn test_llm_evaluation_parse_missing_keys_default() {
        let optimizer = GraphWeightOptimizer::new();
        let (relevance, faithfulness, conciseness) = optimizer
            .parse_llm_evaluation(r#"{"relevance": 0.9}"#)
            .unwrap();
        assert!((relevance - 0.9).abs() < 0.001);
        assert_eq!(faithfulness, 0.5);
        assert_eq!(conciseness, 0.5);
    }

    #[test]
    fn test_llm_evaluation_parse_fallback() {
        let optimizer = GraphWeightOptimizer::new();
        let response = "This is not JSON at all";
        let (relevance, faithfulness, conciseness) =
            optimizer.parse_llm_evaluation(response).unwrap();
        assert_eq!(relevance, 0.5);
        assert_eq!(faithfulness, 0.5);
        assert_eq!(conciseness, 0.5);
    }
}