Skip to main content

graphrag_core/graph/incremental/
helpers.rs

1#![allow(unused_imports)]
2
3use crate::core::{
4    DocumentId, Entity, EntityId, GraphRAGError, KnowledgeGraph, Relationship, Result, TextChunk,
5};
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::collections::{HashMap, HashSet};
9use std::time::{Duration, Instant};
10
11#[cfg(feature = "incremental")]
12use std::sync::Arc;
13
14#[cfg(feature = "incremental")]
15use {
16    dashmap::DashMap,
17    parking_lot::{Mutex, RwLock},
18    tokio::sync::{broadcast, Semaphore},
19    uuid::Uuid,
20};
21
22use super::*;
23
/// Bookkeeping used to decide which cached data to invalidate after graph changes.
#[cfg(feature = "incremental")]
pub struct SelectiveInvalidation {
    /// Registered cache regions, keyed by their region id.
    cache_regions: DashMap<String, CacheRegion>,
    /// Reverse index from an entity id to the ids of the regions that list it.
    entity_to_regions: DashMap<EntityId, HashSet<String>>,
    /// Every strategy produced so far, paired with the time it was logged.
    invalidation_log: Mutex<Vec<(DateTime<Utc>, InvalidationStrategy)>>,
}
31
#[cfg(feature = "incremental")]
impl Default for SelectiveInvalidation {
    /// Same as [`SelectiveInvalidation::new`]: no regions, no mappings, empty log.
    fn default() -> Self {
        SelectiveInvalidation::new()
    }
}
38
#[cfg(feature = "incremental")]
impl SelectiveInvalidation {
    /// Builds a manager with no registered regions and an empty invalidation log.
    pub fn new() -> Self {
        let cache_regions = DashMap::new();
        let entity_to_regions = DashMap::new();
        let invalidation_log = Mutex::new(Vec::new());

        Self {
            cache_regions,
            entity_to_regions,
            invalidation_log,
        }
    }

    /// Records `region` so later changes to its entities can be mapped back to it.
    ///
    /// Each entity listed in the region gets the region id added to its reverse
    /// index entry; the region itself is then stored under its id, replacing any
    /// region previously stored under the same id.
    pub fn register_cache_region(&self, region: CacheRegion) {
        let id = region.region_id.clone();

        for member in &region.entity_ids {
            let mut owning_regions = self.entity_to_regions.entry(member.clone()).or_default();
            owning_regions.insert(id.clone());
        }

        self.cache_regions.insert(id, region);
    }

    /// Translates a batch of change records into the invalidation strategies to apply.
    ///
    /// * Entity added/updated/removed: one `Relational(entity, 2)` strategy, and the
    ///   regions registered for that entity are remembered for regional invalidation.
    /// * Relationship added/updated/removed: one `Relational(endpoint, 1)` strategy
    ///   for the source and one for the target, when the record carries relationship data.
    /// * Anything else: a `Selective` strategy with the generated cache keys, if any.
    ///
    /// `Regional` strategies for the remembered regions are appended after the
    /// per-change strategies (in set-iteration order). Every returned strategy is
    /// also appended to the invalidation log with its own timestamp.
    pub fn invalidate_for_changes(&self, changes: &[ChangeRecord]) -> Vec<InvalidationStrategy> {
        let mut touched_regions: HashSet<String> = HashSet::new();
        let mut plan: Vec<InvalidationStrategy> = Vec::new();

        for record in changes {
            match &record.change_type {
                ChangeType::EntityAdded | ChangeType::EntityUpdated | ChangeType::EntityRemoved => {
                    if let Some(id) = record.entity_id.as_ref() {
                        if let Some(region_ids) = self.entity_to_regions.get(id) {
                            touched_regions.extend(region_ids.iter().cloned());
                        }
                        plan.push(InvalidationStrategy::Relational(id.clone(), 2));
                    }
                },
                ChangeType::RelationshipAdded
                | ChangeType::RelationshipUpdated
                | ChangeType::RelationshipRemoved => {
                    // Both endpoints of the relationship are invalidated, source first.
                    if let ChangeData::Relationship(rel) = &record.data {
                        for endpoint in [&rel.source, &rel.target] {
                            plan.push(InvalidationStrategy::Relational(endpoint.clone(), 1));
                        }
                    }
                },
                _ => {
                    // Remaining change kinds are handled by explicit cache keys.
                    let keys = self.generate_cache_keys_for_change(record);
                    if !keys.is_empty() {
                        plan.push(InvalidationStrategy::Selective(keys));
                    }
                },
            }
        }

        // Regional strategies always follow the per-change strategies.
        plan.extend(
            touched_regions
                .into_iter()
                .map(InvalidationStrategy::Regional),
        );

        // Each log entry gets its own timestamp, taken as it is appended.
        let mut log = self.invalidation_log.lock();
        log.extend(plan.iter().map(|strategy| (Utc::now(), strategy.clone())));

        plan
    }

    /// Produces the cache keys touched by `change`, or an empty list when the
    /// change type has no key scheme or the required id is missing.
    fn generate_cache_keys_for_change(&self, change: &ChangeRecord) -> Vec<String> {
        match &change.change_type {
            ChangeType::EntityAdded | ChangeType::EntityUpdated => change
                .entity_id
                .as_ref()
                .map(|entity_id| {
                    vec![
                        format!("entity:{entity_id}"),
                        format!("entity_neighbors:{entity_id}"),
                    ]
                })
                .unwrap_or_default(),
            ChangeType::DocumentAdded | ChangeType::DocumentUpdated => change
                .document_id
                .as_ref()
                .map(|doc_id| {
                    vec![
                        format!("document:{doc_id}"),
                        format!("document_chunks:{doc_id}"),
                    ]
                })
                .unwrap_or_default(),
            ChangeType::EmbeddingAdded | ChangeType::EmbeddingUpdated => change
                .entity_id
                .as_ref()
                .map(|entity_id| {
                    vec![
                        format!("embedding:{entity_id}"),
                        format!("similarity:{entity_id}"),
                    ]
                })
                .unwrap_or_default(),
            _ => Vec::new(),
        }
    }

    /// Returns a snapshot of the log size, the number of registered regions, the
    /// number of entities with region mappings, and the time of the newest log entry.
    pub fn get_invalidation_stats(&self) -> InvalidationStats {
        let log = self.invalidation_log.lock();
        let last_invalidation = log.last().map(|entry| entry.0);

        InvalidationStats {
            total_invalidations: log.len(),
            cache_regions: self.cache_regions.len(),
            entity_mappings: self.entity_to_regions.len(),
            last_invalidation,
        }
    }
}
154
155/// Statistics about cache invalidations
156#[derive(Debug, Clone)]
157pub struct InvalidationStats {
158    /// Total number of invalidations performed
159    pub total_invalidations: usize,
160    /// Number of cache regions registered
161    pub cache_regions: usize,
162    /// Number of entity-to-region mappings
163    pub entity_mappings: usize,
164    /// Timestamp of last invalidation
165    pub last_invalidation: Option<DateTime<Utc>>,
166}
167
168// ============================================================================
169// Conflict Resolution
170// ============================================================================
171
172/// Conflict resolver with multiple strategies
173pub struct ConflictResolver {
174    pub(super) strategy: ConflictStrategy,
175    custom_resolvers: HashMap<String, ConflictResolverFn>,
176}
177
178// Reduce type complexity for custom resolver function type
179type ConflictResolverFn = Box<dyn Fn(&Conflict) -> Result<ConflictResolution> + Send + Sync>;
180
181impl ConflictResolver {
182    /// Creates a new conflict resolver with the given strategy
183    pub fn new(strategy: ConflictStrategy) -> Self {
184        Self {
185            strategy,
186            custom_resolvers: HashMap::new(),
187        }
188    }
189
190    /// Adds a custom resolver function by name
191    pub fn with_custom_resolver<F>(mut self, name: String, resolver: F) -> Self
192    where
193        F: Fn(&Conflict) -> Result<ConflictResolution> + Send + Sync + 'static,
194    {
195        self.custom_resolvers.insert(name, Box::new(resolver));
196        self
197    }
198
199    /// Resolves a conflict using the configured strategy
200    pub async fn resolve_conflict(&self, conflict: &Conflict) -> Result<ConflictResolution> {
201        match &self.strategy {
202            ConflictStrategy::KeepExisting => Ok(ConflictResolution {
203                strategy: ConflictStrategy::KeepExisting,
204                resolved_data: conflict.existing_data.clone(),
205                metadata: HashMap::new(),
206            }),
207            ConflictStrategy::KeepNew => Ok(ConflictResolution {
208                strategy: ConflictStrategy::KeepNew,
209                resolved_data: conflict.new_data.clone(),
210                metadata: HashMap::new(),
211            }),
212            ConflictStrategy::Merge => self.merge_conflict_data(conflict).await,
213            ConflictStrategy::Custom(resolver_name) => {
214                if let Some(resolver) = self.custom_resolvers.get(resolver_name) {
215                    resolver(conflict)
216                } else {
217                    Err(GraphRAGError::ConflictResolution {
218                        message: format!("Custom resolver '{resolver_name}' not found"),
219                    })
220                }
221            },
222            _ => Err(GraphRAGError::ConflictResolution {
223                message: "Conflict resolution strategy not implemented".to_string(),
224            }),
225        }
226    }
227
228    async fn merge_conflict_data(&self, conflict: &Conflict) -> Result<ConflictResolution> {
229        match (&conflict.existing_data, &conflict.new_data) {
230            (ChangeData::Entity(existing), ChangeData::Entity(new)) => {
231                let merged = self.merge_entities(existing, new)?;
232                Ok(ConflictResolution {
233                    strategy: ConflictStrategy::Merge,
234                    resolved_data: ChangeData::Entity(merged),
235                    metadata: [("merge_strategy".to_string(), "entity_merge".to_string())]
236                        .into_iter()
237                        .collect(),
238                })
239            },
240            (ChangeData::Relationship(existing), ChangeData::Relationship(new)) => {
241                let merged = self.merge_relationships(existing, new)?;
242                Ok(ConflictResolution {
243                    strategy: ConflictStrategy::Merge,
244                    resolved_data: ChangeData::Relationship(merged),
245                    metadata: [(
246                        "merge_strategy".to_string(),
247                        "relationship_merge".to_string(),
248                    )]
249                    .into_iter()
250                    .collect(),
251                })
252            },
253            _ => Err(GraphRAGError::ConflictResolution {
254                message: "Cannot merge incompatible data types".to_string(),
255            }),
256        }
257    }
258
259    pub(super) fn merge_entities(&self, existing: &Entity, new: &Entity) -> Result<Entity> {
260        let mut merged = existing.clone();
261
262        // Use higher confidence
263        if new.confidence > existing.confidence {
264            merged.confidence = new.confidence;
265            merged.name = new.name.clone();
266            merged.entity_type = new.entity_type.clone();
267        }
268
269        // Merge mentions
270        let mut all_mentions = existing.mentions.clone();
271        for new_mention in &new.mentions {
272            if !all_mentions.iter().any(|m| {
273                m.chunk_id == new_mention.chunk_id && m.start_offset == new_mention.start_offset
274            }) {
275                all_mentions.push(new_mention.clone());
276            }
277        }
278        merged.mentions = all_mentions;
279
280        // Prefer new embedding if available
281        if new.embedding.is_some() {
282            merged.embedding = new.embedding.clone();
283        }
284
285        Ok(merged)
286    }
287
288    fn merge_relationships(
289        &self,
290        existing: &Relationship,
291        new: &Relationship,
292    ) -> Result<Relationship> {
293        let mut merged = existing.clone();
294
295        // Use higher confidence
296        if new.confidence > existing.confidence {
297            merged.confidence = new.confidence;
298            merged.relation_type = new.relation_type.clone();
299        }
300
301        // Merge contexts
302        let mut all_contexts = existing.context.clone();
303        for new_context in &new.context {
304            if !all_contexts.contains(new_context) {
305                all_contexts.push(new_context.clone());
306            }
307        }
308        merged.context = all_contexts;
309
310        Ok(merged)
311    }
312}
313
314// ============================================================================
315// Update Monitor and Metrics
316// ============================================================================
317
/// Collects metrics, an operation log and aggregate statistics for update operations.
#[cfg(feature = "incremental")]
pub struct UpdateMonitor {
    /// Latest recorded metric per metric name.
    metrics: DashMap<String, UpdateMetric>,
    /// One entry per started operation, in start order.
    operations_log: Mutex<Vec<OperationLog>>,
    /// Aggregates recomputed from the operation log when an operation completes.
    performance_stats: RwLock<PerformanceStats>,
}
325
#[cfg(feature = "incremental")]
impl Default for UpdateMonitor {
    /// Same as [`UpdateMonitor::new`]: nothing recorded, all statistics zeroed.
    fn default() -> Self {
        UpdateMonitor::new()
    }
}
332
333/// Metric for tracking update operations
334#[derive(Debug, Clone)]
335pub struct UpdateMetric {
336    /// Name of the metric
337    pub name: String,
338    /// Metric value
339    pub value: f64,
340    /// When the metric was recorded
341    pub timestamp: DateTime<Utc>,
342    /// Tags for categorizing the metric
343    pub tags: HashMap<String, String>,
344}
345
346/// Log entry for an operation
347#[derive(Debug, Clone)]
348pub struct OperationLog {
349    /// Unique operation identifier
350    pub operation_id: UpdateId,
351    /// Type of operation performed
352    pub operation_type: String,
353    /// When the operation started
354    pub start_time: Instant,
355    /// When the operation ended
356    pub end_time: Option<Instant>,
357    /// Whether the operation succeeded
358    pub success: Option<bool>,
359    /// Error message if failed
360    pub error_message: Option<String>,
361    /// Number of entities affected
362    pub affected_entities: usize,
363    /// Number of relationships affected
364    pub affected_relationships: usize,
365}
366
/// Aggregate performance figures for monitored operations.
#[derive(Clone, Debug)]
pub struct PerformanceStats {
    /// Count of operations included in these statistics.
    pub total_operations: u64,
    /// Count of operations that succeeded.
    pub successful_operations: u64,
    /// Count of operations that failed.
    pub failed_operations: u64,
    /// Mean duration of a single operation.
    pub average_operation_time: Duration,
    /// Highest observed throughput, in operations per second.
    pub peak_operations_per_second: f64,
    /// Cache hit rate, from 0.0 to 1.0.
    pub cache_hit_rate: f64,
    /// Conflict resolution rate, from 0.0 to 1.0.
    pub conflict_resolution_rate: f64,
}
385
#[cfg(feature = "incremental")]
impl UpdateMonitor {
    /// Creates a monitor with no metrics, an empty operation log and zeroed statistics.
    pub fn new() -> Self {
        Self {
            metrics: DashMap::new(),
            operations_log: Mutex::new(Vec::new()),
            performance_stats: RwLock::new(PerformanceStats {
                total_operations: 0,
                successful_operations: 0,
                failed_operations: 0,
                average_operation_time: Duration::from_millis(0),
                peak_operations_per_second: 0.0,
                cache_hit_rate: 0.0,
                conflict_resolution_rate: 0.0,
            }),
        }
    }

    /// Starts tracking a new operation and returns its freshly generated id.
    ///
    /// The log entry is created with the current instant as start time and with
    /// no end time, outcome or error; the affected counts start at zero.
    pub fn start_operation(&self, operation_type: &str) -> UpdateId {
        let operation_id = UpdateId::new();
        let log_entry = OperationLog {
            operation_id: operation_id.clone(),
            operation_type: operation_type.to_string(),
            start_time: Instant::now(),
            end_time: None,
            success: None,
            error_message: None,
            affected_entities: 0,
            affected_relationships: 0,
        };

        self.operations_log.lock().push(log_entry);
        operation_id
    }

    /// Marks the operation identified by `operation_id` as finished and refreshes
    /// the aggregate statistics.
    ///
    /// The first log entry with a matching id receives the current instant as end
    /// time together with `success`, `error` and the affected counts. An unknown
    /// id leaves the log untouched; the statistics are still recomputed.
    pub fn complete_operation(
        &self,
        operation_id: &UpdateId,
        success: bool,
        error: Option<String>,
        affected_entities: usize,
        affected_relationships: usize,
    ) {
        {
            let mut log = self.operations_log.lock();
            if let Some(entry) = log.iter_mut().find(|e| &e.operation_id == operation_id) {
                entry.end_time = Some(Instant::now());
                entry.success = Some(success);
                entry.error_message = error;
                entry.affected_entities = affected_entities;
                entry.affected_relationships = affected_relationships;
            }
            // The guard must be released at the end of this scope:
            // `update_performance_stats` locks `operations_log` again, and the mutex
            // is not reentrant, so holding it across that call would deadlock.
        }

        self.update_performance_stats();
    }

    /// Recomputes the operation counts and the average duration from the log.
    ///
    /// Only entries that have both an end time and an outcome are counted. When
    /// there are none, the stored statistics are left as they are. The throughput
    /// and rate fields of the statistics are not modified here.
    fn update_performance_stats(&self) {
        // The log stays locked while the statistics are written so that concurrent
        // completions publish their results in a consistent order.
        let log = self.operations_log.lock();

        let mut total: u64 = 0;
        let mut succeeded: u64 = 0;
        let mut elapsed = Duration::ZERO;
        for op in log.iter() {
            if let (Some(end), Some(outcome)) = (op.end_time, op.success) {
                total += 1;
                succeeded += u64::from(outcome);
                elapsed += end.duration_since(op.start_time);
            }
        }

        if total == 0 {
            return;
        }

        let mut stats = self.performance_stats.write();
        stats.total_operations = total;
        stats.successful_operations = succeeded;
        stats.failed_operations = total - succeeded;

        // `Duration` divides by `u32`; saturate instead of truncating so an
        // oversized count can never wrap around to a zero divisor.
        let divisor = u32::try_from(total).unwrap_or(u32::MAX);
        stats.average_operation_time = elapsed / divisor;
    }

    /// Records `value` for the metric `name`, stamped with the current UTC time.
    ///
    /// Metrics are stored per name, so a new value replaces the previous one.
    pub fn record_metric(&self, name: &str, value: f64, tags: HashMap<String, String>) {
        let metric = UpdateMetric {
            name: name.to_string(),
            value,
            timestamp: Utc::now(),
            tags,
        };
        self.metrics.insert(name.to_string(), metric);
    }

    /// Returns a copy of the current performance statistics.
    pub fn get_performance_stats(&self) -> PerformanceStats {
        self.performance_stats.read().clone()
    }

    /// Returns copies of up to `limit` log entries, most recently started first.
    pub fn get_recent_operations(&self, limit: usize) -> Vec<OperationLog> {
        let log = self.operations_log.lock();
        log.iter().rev().take(limit).cloned().collect()
    }
}