graphrag_core/graph/incremental/
helpers.rs1#![allow(unused_imports)]
2
3use crate::core::{
4 DocumentId, Entity, EntityId, GraphRAGError, KnowledgeGraph, Relationship, Result, TextChunk,
5};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::time::{Duration, Instant};
10
11#[cfg(feature = "incremental")]
12use std::sync::Arc;
13
14#[cfg(feature = "incremental")]
15use {
16 dashmap::DashMap,
17 parking_lot::{Mutex, RwLock},
18 tokio::sync::{broadcast, Semaphore},
19 uuid::Uuid,
20};
21
22use super::*;
23
/// Bookkeeping for selective cache invalidation.
///
/// Holds the registered cache regions, a reverse index from entities to the
/// regions that list them, and a log of every invalidation strategy issued.
#[cfg(feature = "incremental")]
pub struct SelectiveInvalidation {
    /// Registered cache regions, keyed by their `region_id`.
    cache_regions: DashMap<String, CacheRegion>,
    /// Reverse index: entity id -> ids of the regions that list that entity.
    entity_to_regions: DashMap<EntityId, HashSet<String>>,
    /// Issued strategies paired with the time they were logged, oldest first.
    invalidation_log: Mutex<Vec<(DateTime<Utc>, InvalidationStrategy)>>,
}
31
#[cfg(feature = "incremental")]
impl Default for SelectiveInvalidation {
    /// Same as [`SelectiveInvalidation::new`]: no regions, no mappings, empty log.
    fn default() -> Self {
        SelectiveInvalidation::new()
    }
}
38
#[cfg(feature = "incremental")]
impl SelectiveInvalidation {
    /// Creates a tracker with no regions, no entity mappings and an empty log.
    pub fn new() -> Self {
        Self {
            cache_regions: DashMap::new(),
            entity_to_regions: DashMap::new(),
            invalidation_log: Mutex::new(Vec::new()),
        }
    }

    /// Registers `region` and indexes it under every entity it lists.
    ///
    /// Registering a region whose `region_id` is already present replaces the
    /// stored region. Reverse-index entries are only ever added here, never
    /// removed.
    pub fn register_cache_region(&self, region: CacheRegion) {
        let region_id = region.region_id.clone();

        for entity_id in &region.entity_ids {
            self.entity_to_regions
                .entry(entity_id.clone())
                .or_default()
                .insert(region_id.clone());
        }

        self.cache_regions.insert(region_id, region);
    }

    /// Translates a batch of change records into invalidation strategies.
    ///
    /// * Entity added/updated/removed: one `Relational(entity_id, 2)` strategy,
    ///   and every region that lists the entity is marked as affected.
    /// * Relationship added/updated/removed: `Relational(source, 1)` and
    ///   `Relational(target, 1)` when the change carries relationship data.
    /// * Any other change type: a `Selective` strategy with the keys produced
    ///   by `generate_cache_keys_for_change`, if there are any.
    ///
    /// After all changes are processed, one `Regional` strategy is appended per
    /// affected region (in unspecified order, since they come from a `HashSet`).
    /// Every returned strategy is also appended to the invalidation log; all
    /// entries of one call share a single timestamp.
    ///
    /// NOTE(review): the numeric argument of `Relational` is presumably a
    /// traversal depth — confirm against `InvalidationStrategy`.
    pub fn invalidate_for_changes(&self, changes: &[ChangeRecord]) -> Vec<InvalidationStrategy> {
        let mut strategies = Vec::new();
        let mut affected_regions = HashSet::new();

        for change in changes {
            match &change.change_type {
                ChangeType::EntityAdded | ChangeType::EntityUpdated | ChangeType::EntityRemoved => {
                    if let Some(entity_id) = &change.entity_id {
                        if let Some(regions) = self.entity_to_regions.get(entity_id) {
                            // Copy only the ids; no need to clone the whole set first.
                            affected_regions.extend(regions.iter().cloned());
                        }
                        strategies.push(InvalidationStrategy::Relational(entity_id.clone(), 2));
                    }
                },
                ChangeType::RelationshipAdded
                | ChangeType::RelationshipUpdated
                | ChangeType::RelationshipRemoved => {
                    if let ChangeData::Relationship(rel) = &change.data {
                        strategies.push(InvalidationStrategy::Relational(rel.source.clone(), 1));
                        strategies.push(InvalidationStrategy::Relational(rel.target.clone(), 1));
                    }
                },
                _ => {
                    let cache_keys = self.generate_cache_keys_for_change(change);
                    if !cache_keys.is_empty() {
                        strategies.push(InvalidationStrategy::Selective(cache_keys));
                    }
                },
            }
        }

        strategies.extend(
            affected_regions
                .into_iter()
                .map(InvalidationStrategy::Regional),
        );

        // One timestamp for the whole batch: the strategies were produced by a
        // single call, so they should not appear to happen at different times.
        let logged_at = Utc::now();
        self.invalidation_log
            .lock()
            .extend(strategies.iter().cloned().map(|strategy| (logged_at, strategy)));

        strategies
    }

    /// Builds the cache keys that a single change makes stale.
    ///
    /// Entity changes yield `entity:` / `entity_neighbors:` keys, document
    /// changes yield `document:` / `document_chunks:` keys, and embedding
    /// changes yield `embedding:` / `similarity:` keys. Other change types, or
    /// a change missing the relevant id, yield no keys.
    ///
    /// The only caller in this impl, `invalidate_for_changes`, handles entity
    /// changes itself, so the entity arm is not reached from there.
    fn generate_cache_keys_for_change(&self, change: &ChangeRecord) -> Vec<String> {
        let mut keys = Vec::new();

        match &change.change_type {
            ChangeType::EntityAdded | ChangeType::EntityUpdated => {
                if let Some(entity_id) = &change.entity_id {
                    keys.push(format!("entity:{entity_id}"));
                    keys.push(format!("entity_neighbors:{entity_id}"));
                }
            },
            ChangeType::DocumentAdded | ChangeType::DocumentUpdated => {
                if let Some(doc_id) = &change.document_id {
                    keys.push(format!("document:{doc_id}"));
                    keys.push(format!("document_chunks:{doc_id}"));
                }
            },
            ChangeType::EmbeddingAdded | ChangeType::EmbeddingUpdated => {
                if let Some(entity_id) = &change.entity_id {
                    keys.push(format!("embedding:{entity_id}"));
                    keys.push(format!("similarity:{entity_id}"));
                }
            },
            _ => {},
        }

        keys
    }

    /// Returns a snapshot of the tracker's counters.
    ///
    /// `total_invalidations` is the number of logged strategies and
    /// `last_invalidation` is the timestamp of the most recent log entry, or
    /// `None` when nothing has been logged yet.
    pub fn get_invalidation_stats(&self) -> InvalidationStats {
        let log = self.invalidation_log.lock();

        InvalidationStats {
            total_invalidations: log.len(),
            cache_regions: self.cache_regions.len(),
            entity_mappings: self.entity_to_regions.len(),
            last_invalidation: log.last().map(|(time, _)| *time),
        }
    }
}
154
155#[derive(Debug, Clone)]
157pub struct InvalidationStats {
158 pub total_invalidations: usize,
160 pub cache_regions: usize,
162 pub entity_mappings: usize,
164 pub last_invalidation: Option<DateTime<Utc>>,
166}
167
168pub struct ConflictResolver {
174 pub(super) strategy: ConflictStrategy,
175 custom_resolvers: HashMap<String, ConflictResolverFn>,
176}
177
178type ConflictResolverFn = Box<dyn Fn(&Conflict) -> Result<ConflictResolution> + Send + Sync>;
180
181impl ConflictResolver {
182 pub fn new(strategy: ConflictStrategy) -> Self {
184 Self {
185 strategy,
186 custom_resolvers: HashMap::new(),
187 }
188 }
189
190 pub fn with_custom_resolver<F>(mut self, name: String, resolver: F) -> Self
192 where
193 F: Fn(&Conflict) -> Result<ConflictResolution> + Send + Sync + 'static,
194 {
195 self.custom_resolvers.insert(name, Box::new(resolver));
196 self
197 }
198
199 pub async fn resolve_conflict(&self, conflict: &Conflict) -> Result<ConflictResolution> {
201 match &self.strategy {
202 ConflictStrategy::KeepExisting => Ok(ConflictResolution {
203 strategy: ConflictStrategy::KeepExisting,
204 resolved_data: conflict.existing_data.clone(),
205 metadata: HashMap::new(),
206 }),
207 ConflictStrategy::KeepNew => Ok(ConflictResolution {
208 strategy: ConflictStrategy::KeepNew,
209 resolved_data: conflict.new_data.clone(),
210 metadata: HashMap::new(),
211 }),
212 ConflictStrategy::Merge => self.merge_conflict_data(conflict).await,
213 ConflictStrategy::Custom(resolver_name) => {
214 if let Some(resolver) = self.custom_resolvers.get(resolver_name) {
215 resolver(conflict)
216 } else {
217 Err(GraphRAGError::ConflictResolution {
218 message: format!("Custom resolver '{resolver_name}' not found"),
219 })
220 }
221 },
222 _ => Err(GraphRAGError::ConflictResolution {
223 message: "Conflict resolution strategy not implemented".to_string(),
224 }),
225 }
226 }
227
228 async fn merge_conflict_data(&self, conflict: &Conflict) -> Result<ConflictResolution> {
229 match (&conflict.existing_data, &conflict.new_data) {
230 (ChangeData::Entity(existing), ChangeData::Entity(new)) => {
231 let merged = self.merge_entities(existing, new)?;
232 Ok(ConflictResolution {
233 strategy: ConflictStrategy::Merge,
234 resolved_data: ChangeData::Entity(merged),
235 metadata: [("merge_strategy".to_string(), "entity_merge".to_string())]
236 .into_iter()
237 .collect(),
238 })
239 },
240 (ChangeData::Relationship(existing), ChangeData::Relationship(new)) => {
241 let merged = self.merge_relationships(existing, new)?;
242 Ok(ConflictResolution {
243 strategy: ConflictStrategy::Merge,
244 resolved_data: ChangeData::Relationship(merged),
245 metadata: [(
246 "merge_strategy".to_string(),
247 "relationship_merge".to_string(),
248 )]
249 .into_iter()
250 .collect(),
251 })
252 },
253 _ => Err(GraphRAGError::ConflictResolution {
254 message: "Cannot merge incompatible data types".to_string(),
255 }),
256 }
257 }
258
259 pub(super) fn merge_entities(&self, existing: &Entity, new: &Entity) -> Result<Entity> {
260 let mut merged = existing.clone();
261
262 if new.confidence > existing.confidence {
264 merged.confidence = new.confidence;
265 merged.name = new.name.clone();
266 merged.entity_type = new.entity_type.clone();
267 }
268
269 let mut all_mentions = existing.mentions.clone();
271 for new_mention in &new.mentions {
272 if !all_mentions.iter().any(|m| {
273 m.chunk_id == new_mention.chunk_id && m.start_offset == new_mention.start_offset
274 }) {
275 all_mentions.push(new_mention.clone());
276 }
277 }
278 merged.mentions = all_mentions;
279
280 if new.embedding.is_some() {
282 merged.embedding = new.embedding.clone();
283 }
284
285 Ok(merged)
286 }
287
288 fn merge_relationships(
289 &self,
290 existing: &Relationship,
291 new: &Relationship,
292 ) -> Result<Relationship> {
293 let mut merged = existing.clone();
294
295 if new.confidence > existing.confidence {
297 merged.confidence = new.confidence;
298 merged.relation_type = new.relation_type.clone();
299 }
300
301 let mut all_contexts = existing.context.clone();
303 for new_context in &new.context {
304 if !all_contexts.contains(new_context) {
305 all_contexts.push(new_context.clone());
306 }
307 }
308 merged.context = all_contexts;
309
310 Ok(merged)
311 }
312}
313
/// Collects metrics, per-operation log entries and aggregate performance
/// statistics for update operations.
#[cfg(feature = "incremental")]
pub struct UpdateMonitor {
    /// Most recently recorded metric for each metric name.
    metrics: DashMap<String, UpdateMetric>,
    /// Every operation that was started, in start order.
    operations_log: Mutex<Vec<OperationLog>>,
    /// Aggregates recomputed from the log whenever an operation completes.
    performance_stats: RwLock<PerformanceStats>,
}
325
#[cfg(feature = "incremental")]
impl Default for UpdateMonitor {
    /// Same as [`UpdateMonitor::new`]: no metrics, empty log, zeroed statistics.
    fn default() -> Self {
        UpdateMonitor::new()
    }
}
332
333#[derive(Debug, Clone)]
335pub struct UpdateMetric {
336 pub name: String,
338 pub value: f64,
340 pub timestamp: DateTime<Utc>,
342 pub tags: HashMap<String, String>,
344}
345
346#[derive(Debug, Clone)]
348pub struct OperationLog {
349 pub operation_id: UpdateId,
351 pub operation_type: String,
353 pub start_time: Instant,
355 pub end_time: Option<Instant>,
357 pub success: Option<bool>,
359 pub error_message: Option<String>,
361 pub affected_entities: usize,
363 pub affected_relationships: usize,
365}
366
/// Aggregate statistics over completed update operations.
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    /// Number of completed operations.
    pub total_operations: u64,
    /// Number of completed operations that succeeded.
    pub successful_operations: u64,
    /// Number of completed operations that failed.
    pub failed_operations: u64,
    /// Mean wall-clock duration of a completed operation.
    pub average_operation_time: Duration,
    /// Highest observed throughput.
    /// NOTE(review): only ever initialised to 0.0 in this file — confirm who updates it.
    pub peak_operations_per_second: f64,
    /// Cache hit rate.
    /// NOTE(review): only ever initialised to 0.0 in this file; scale (0-1 vs percent) unconfirmed.
    pub cache_hit_rate: f64,
    /// Conflict resolution rate.
    /// NOTE(review): only ever initialised to 0.0 in this file; scale (0-1 vs percent) unconfirmed.
    pub conflict_resolution_rate: f64,
}
385
#[cfg(feature = "incremental")]
impl UpdateMonitor {
    /// Creates a monitor with no metrics, an empty operation log and zeroed
    /// performance statistics.
    pub fn new() -> Self {
        Self {
            metrics: DashMap::new(),
            operations_log: Mutex::new(Vec::new()),
            performance_stats: RwLock::new(PerformanceStats {
                total_operations: 0,
                successful_operations: 0,
                failed_operations: 0,
                average_operation_time: Duration::from_millis(0),
                peak_operations_per_second: 0.0,
                cache_hit_rate: 0.0,
                conflict_resolution_rate: 0.0,
            }),
        }
    }

    /// Starts tracking a new operation and returns its freshly generated id.
    ///
    /// The log entry is created with the current instant as start time and no
    /// outcome; pass the returned id to `complete_operation` to finish it.
    pub fn start_operation(&self, operation_type: &str) -> UpdateId {
        let operation_id = UpdateId::new();
        let log_entry = OperationLog {
            operation_id: operation_id.clone(),
            operation_type: operation_type.to_string(),
            start_time: Instant::now(),
            end_time: None,
            success: None,
            error_message: None,
            affected_entities: 0,
            affected_relationships: 0,
        };

        self.operations_log.lock().push(log_entry);
        operation_id
    }

    /// Marks the operation identified by `operation_id` as finished and
    /// refreshes the aggregate performance statistics.
    ///
    /// If no log entry has that id, the log is left untouched; the statistics
    /// are still recomputed.
    pub fn complete_operation(
        &self,
        operation_id: &UpdateId,
        success: bool,
        error: Option<String>,
        affected_entities: usize,
        affected_relationships: usize,
    ) {
        {
            let mut log = self.operations_log.lock();
            if let Some(entry) = log.iter_mut().find(|e| &e.operation_id == operation_id) {
                entry.end_time = Some(Instant::now());
                entry.success = Some(success);
                entry.error_message = error;
                entry.affected_entities = affected_entities;
                entry.affected_relationships = affected_relationships;
            }
            // The guard must be released before recomputing the statistics:
            // `update_performance_stats` locks `operations_log` itself and the
            // mutex is not reentrant, so holding it here would deadlock.
        }

        self.update_performance_stats();
    }

    /// Recomputes the aggregate statistics from all completed log entries.
    ///
    /// An entry counts as completed when both its end time and its outcome are
    /// set. When there are no completed entries the statistics are left as
    /// they were.
    fn update_performance_stats(&self) {
        let log = self.operations_log.lock();

        let mut completed: u64 = 0;
        let mut succeeded: u64 = 0;
        let mut total_time = Duration::from_millis(0);
        for op in log.iter() {
            if let (Some(end), Some(outcome)) = (op.end_time, op.success) {
                completed += 1;
                if outcome {
                    succeeded += 1;
                }
                total_time += end.duration_since(op.start_time);
            }
        }

        if completed == 0 {
            return;
        }

        // `Duration` can only be divided by a `u32`; saturate instead of
        // truncating so an enormous log can never wrap the divisor to zero.
        let divisor = completed.min(u64::from(u32::MAX)) as u32;

        let mut stats = self.performance_stats.write();
        stats.total_operations = completed;
        stats.successful_operations = succeeded;
        stats.failed_operations = completed - succeeded;
        stats.average_operation_time = total_time / divisor;
    }

    /// Records `value` for the metric `name`, stamped with the current time.
    ///
    /// Only the latest value per name is kept; an earlier metric with the same
    /// name is replaced.
    pub fn record_metric(&self, name: &str, value: f64, tags: HashMap<String, String>) {
        let key = name.to_string();
        let metric = UpdateMetric {
            name: key.clone(),
            value,
            timestamp: Utc::now(),
            tags,
        };
        self.metrics.insert(key, metric);
    }

    /// Returns a copy of the current aggregate statistics.
    pub fn get_performance_stats(&self) -> PerformanceStats {
        self.performance_stats.read().clone()
    }

    /// Returns copies of up to `limit` log entries, most recently started first.
    pub fn get_recent_operations(&self, limit: usize) -> Vec<OperationLog> {
        let log = self.operations_log.lock();
        log.iter().rev().take(limit).cloned().collect()
    }
}