oxirs_embed/
temporal_embeddings.rs

1//! Temporal Embeddings Module for Time-Aware Knowledge Graph Embeddings
2//!
3//! This module provides temporal embedding capabilities for knowledge graphs
4//! that evolve over time, capturing temporal patterns and dynamics.
5//!
6//! ## Features
7//!
8//! - **Time-Aware Embeddings**: Embed entities/relations with temporal context
9//! - **Temporal Evolution**: Track how embeddings change over time
10//! - **Time Series Analysis**: Analyze temporal patterns in knowledge graphs
11//! - **Temporal Queries**: Support time-based queries and predictions
12//! - **Event Detection**: Detect significant temporal events
13//! - **Forecasting**: Predict future entity/relation states
14//!
15//! ## Temporal Models
16//!
17//! - **TTransE**: Temporal extension of TransE
18//! - **TA-DistMult**: Time-aware DistMult
19//! - **DE-SimplE**: Diachronic embedding model
20//! - **ChronoR**: Recurrent temporal embeddings
21//! - **TeMP**: Temporal message passing
22//!
23//! ## Use Cases
24//!
25//! - Historical data analysis
26//! - Event prediction
27//! - Temporal reasoning
28//! - Dynamic knowledge graphs
29//! - Time-series forecasting
30
31use anyhow::Result;
32use chrono::{DateTime, Duration, Utc};
33use serde::{Deserialize, Serialize};
34use std::collections::{BTreeMap, HashMap};
35use std::sync::Arc;
36use tokio::sync::RwLock;
37use tracing::{debug, info};
38
39use crate::{ModelConfig, TrainingStats, Triple, Vector};
40use uuid::Uuid;
41
42// Type aliases to simplify complex types
43type TemporalEntityEmbeddings = Arc<RwLock<HashMap<String, BTreeMap<DateTime<Utc>, Vector>>>>;
44type TemporalRelationEmbeddings = Arc<RwLock<HashMap<String, BTreeMap<DateTime<Utc>, Vector>>>>;
45
/// Stand-in for the time series analysis component.
///
/// Carries no state yet; the real implementation is planned on top of
/// scirs2-stats.
#[derive(Debug, Clone, Default)]
pub struct TimeSeriesAnalyzer;
49
/// Names of forecasting methods.
///
/// Nothing in this file selects between them yet.
#[derive(Debug, Clone)]
pub enum ForecastMethod {
    /// Exponential smoothing
    ExponentialSmoothing,
    /// ARIMA
    Arima,
    /// Prophet
    Prophet,
}
56
57/// Temporal granularity
58#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
59pub enum TemporalGranularity {
60    /// Second-level precision
61    Second,
62    /// Minute-level precision
63    Minute,
64    /// Hour-level precision
65    Hour,
66    /// Day-level precision
67    Day,
68    /// Week-level precision
69    Week,
70    /// Month-level precision
71    Month,
72    /// Year-level precision
73    Year,
74    /// Custom duration
75    Custom(i64), // seconds
76}
77
78/// Temporal scope
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub enum TemporalScope {
81    /// Point in time
82    Instant(DateTime<Utc>),
83    /// Time interval
84    Interval {
85        start: DateTime<Utc>,
86        end: DateTime<Utc>,
87    },
88    /// Periodic (e.g., every Monday)
89    Periodic {
90        start: DateTime<Utc>,
91        period: Duration,
92        count: Option<usize>,
93    },
94    /// Unbounded (always valid)
95    Unbounded,
96}
97
98/// Temporal triple with time information
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct TemporalTriple {
101    /// The RDF triple
102    pub triple: Triple,
103    /// Temporal scope when this triple is valid
104    pub scope: TemporalScope,
105    /// Confidence score (0.0-1.0)
106    pub confidence: f32,
107    /// Source/provenance information
108    pub source: Option<String>,
109}
110
111impl TemporalTriple {
112    /// Create a new temporal triple
113    pub fn new(triple: Triple, scope: TemporalScope) -> Self {
114        Self {
115            triple,
116            scope,
117            confidence: 1.0,
118            source: None,
119        }
120    }
121
122    /// Check if this triple is valid at the given time
123    pub fn is_valid_at(&self, time: &DateTime<Utc>) -> bool {
124        match &self.scope {
125            TemporalScope::Instant(instant) => instant == time,
126            TemporalScope::Interval { start, end } => time >= start && time <= end,
127            TemporalScope::Periodic {
128                start,
129                period,
130                count,
131            } => {
132                let elapsed = time.signed_duration_since(*start);
133                if elapsed < Duration::zero() {
134                    return false;
135                }
136                if let Some(max_count) = count {
137                    let num_periods = elapsed.num_seconds() / period.num_seconds();
138                    num_periods < *max_count as i64
139                } else {
140                    true
141                }
142            }
143            TemporalScope::Unbounded => true,
144        }
145    }
146}
147
148/// Temporal embedding configuration
149#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct TemporalEmbeddingConfig {
151    /// Base model configuration
152    pub base_config: ModelConfig,
153    /// Temporal granularity
154    pub granularity: TemporalGranularity,
155    /// Time embedding dimensions
156    pub time_dim: usize,
157    /// Enable temporal decay
158    pub enable_decay: bool,
159    /// Decay rate (for exponential decay)
160    pub decay_rate: f32,
161    /// Enable temporal smoothing
162    pub enable_smoothing: bool,
163    /// Smoothing window size
164    pub smoothing_window: usize,
165    /// Enable forecasting
166    pub enable_forecasting: bool,
167    /// Forecast horizon (number of time steps)
168    pub forecast_horizon: usize,
169    /// Enable event detection
170    pub enable_event_detection: bool,
171    /// Event threshold
172    pub event_threshold: f32,
173}
174
175impl Default for TemporalEmbeddingConfig {
176    fn default() -> Self {
177        Self {
178            base_config: ModelConfig::default(),
179            granularity: TemporalGranularity::Day,
180            time_dim: 32,
181            enable_decay: true,
182            decay_rate: 0.9,
183            enable_smoothing: true,
184            smoothing_window: 7,
185            enable_forecasting: true,
186            forecast_horizon: 30,
187            enable_event_detection: false,
188            event_threshold: 0.7,
189        }
190    }
191}
192
193/// Temporal event
194#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct TemporalEvent {
196    /// Event ID
197    pub event_id: String,
198    /// Event type
199    pub event_type: String,
200    /// Timestamp
201    pub timestamp: DateTime<Utc>,
202    /// Entities involved
203    pub entities: Vec<String>,
204    /// Relations involved
205    pub relations: Vec<String>,
206    /// Event significance score
207    pub significance: f32,
208    /// Event description
209    pub description: Option<String>,
210}
211
212/// Temporal forecast result
213#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct TemporalForecast {
215    /// Entity or relation being forecasted
216    pub target: String,
217    /// Forecast timestamps
218    pub timestamps: Vec<DateTime<Utc>>,
219    /// Predicted embeddings
220    pub predictions: Vec<Vector>,
221    /// Confidence intervals (lower, upper)
222    pub confidence_intervals: Vec<(Vector, Vector)>,
223    /// Forecast accuracy (if validation data available)
224    pub accuracy: Option<f32>,
225}
226
227/// Temporal embedding model
228pub struct TemporalEmbeddingModel {
229    config: TemporalEmbeddingConfig,
230    model_id: Uuid,
231
232    // Entity embeddings over time: entity -> time -> embedding
233    entity_embeddings: TemporalEntityEmbeddings,
234
235    // Relation embeddings over time: relation -> time -> embedding
236    relation_embeddings: TemporalRelationEmbeddings,
237
238    // Time embeddings: timestamp -> time embedding
239    time_embeddings: Arc<RwLock<BTreeMap<DateTime<Utc>, Vector>>>,
240
241    // Temporal triples
242    temporal_triples: Arc<RwLock<Vec<TemporalTriple>>>,
243
244    // Detected events
245    events: Arc<RwLock<Vec<TemporalEvent>>>,
246
247    // Time series analyzer
248    time_series_analyzer: Option<TimeSeriesAnalyzer>,
249
250    // Training state
251    is_trained: Arc<RwLock<bool>>,
252}
253
254impl TemporalEmbeddingModel {
255    /// Create a new temporal embedding model
256    pub fn new(config: TemporalEmbeddingConfig) -> Self {
257        info!(
258            "Creating temporal embedding model with time_dim={}",
259            config.time_dim
260        );
261
262        Self {
263            model_id: Uuid::new_v4(),
264            time_series_analyzer: Some(TimeSeriesAnalyzer),
265            config,
266            entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
267            relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
268            time_embeddings: Arc::new(RwLock::new(BTreeMap::new())),
269            temporal_triples: Arc::new(RwLock::new(Vec::new())),
270            events: Arc::new(RwLock::new(Vec::new())),
271            is_trained: Arc::new(RwLock::new(false)),
272        }
273    }
274
275    /// Add a temporal triple
276    pub async fn add_temporal_triple(&mut self, temporal_triple: TemporalTriple) -> Result<()> {
277        let mut triples = self.temporal_triples.write().await;
278        triples.push(temporal_triple);
279        Ok(())
280    }
281
282    /// Get entity embedding at a specific time
283    pub async fn get_entity_embedding_at_time(
284        &self,
285        entity: &str,
286        time: &DateTime<Utc>,
287    ) -> Result<Vector> {
288        let embeddings = self.entity_embeddings.read().await;
289
290        if let Some(time_series) = embeddings.get(entity) {
291            // Find the closest time point
292            if let Some((_, embedding)) = time_series.range(..=time).next_back() {
293                return Ok(embedding.clone());
294            }
295        }
296
297        Err(anyhow::anyhow!(
298            "Entity '{}' not found at time {}",
299            entity,
300            time
301        ))
302    }
303
304    /// Get relation embedding at a specific time
305    pub async fn get_relation_embedding_at_time(
306        &self,
307        relation: &str,
308        time: &DateTime<Utc>,
309    ) -> Result<Vector> {
310        let embeddings = self.relation_embeddings.read().await;
311
312        if let Some(time_series) = embeddings.get(relation) {
313            // Find the closest time point
314            if let Some((_, embedding)) = time_series.range(..=time).next_back() {
315                return Ok(embedding.clone());
316            }
317        }
318
319        Err(anyhow::anyhow!(
320            "Relation '{}' not found at time {}",
321            relation,
322            time
323        ))
324    }
325
326    /// Train temporal embeddings
327    pub async fn train_temporal(&mut self, epochs: usize) -> Result<TrainingStats> {
328        info!("Training temporal embeddings for {} epochs", epochs);
329
330        let start_time = std::time::Instant::now();
331        let mut loss_history = Vec::new();
332
333        for epoch in 0..epochs {
334            let loss = self.train_epoch(epoch).await?;
335            loss_history.push(loss);
336
337            if epoch % 10 == 0 {
338                debug!("Epoch {}/{}: loss={:.6}", epoch + 1, epochs, loss);
339            }
340        }
341
342        *self.is_trained.write().await = true;
343
344        let elapsed = start_time.elapsed().as_secs_f64();
345        let final_loss = *loss_history.last().unwrap_or(&0.0);
346
347        info!(
348            "Temporal training completed in {:.2}s, final loss: {:.6}",
349            elapsed, final_loss
350        );
351
352        Ok(TrainingStats {
353            epochs_completed: epochs,
354            final_loss,
355            training_time_seconds: elapsed,
356            convergence_achieved: final_loss < 0.01,
357            loss_history,
358        })
359    }
360
361    /// Train a single epoch
362    async fn train_epoch(&mut self, _epoch: usize) -> Result<f64> {
363        // Simplified training - in a real implementation, this would:
364        // 1. Sample temporal triples
365        // 2. Compute time-aware embeddings
366        // 3. Calculate temporal loss (considering time decay)
367        // 4. Update embeddings with gradient descent
368
369        // Initialize some sample embeddings for testing
370        let triples = self.temporal_triples.read().await;
371        let dim = self.config.base_config.dimensions;
372
373        use scirs2_core::random::Random;
374        let mut rng = Random::default();
375
376        for temporal_triple in triples.iter() {
377            let embedding = Vector::new(
378                (0..dim)
379                    .map(|_| rng.random_range(-1.0, 1.0) as f32)
380                    .collect(),
381            );
382
383            // Store embedding with timestamp
384            let timestamp = match &temporal_triple.scope {
385                TemporalScope::Instant(t) => *t,
386                TemporalScope::Interval { start, .. } => *start,
387                _ => Utc::now(),
388            };
389
390            let entity = temporal_triple.triple.subject.iri.clone();
391            let mut entity_embs = self.entity_embeddings.write().await;
392            entity_embs
393                .entry(entity)
394                .or_insert_with(BTreeMap::new)
395                .insert(timestamp, embedding);
396        }
397
398        Ok(0.1) // Simplified loss
399    }
400
401    /// Forecast future embeddings
402    pub async fn forecast(&self, entity: &str, horizon: usize) -> Result<TemporalForecast> {
403        info!(
404            "Forecasting {} time steps ahead for entity: {}",
405            horizon, entity
406        );
407
408        let embeddings = self.entity_embeddings.read().await;
409
410        if let Some(time_series) = embeddings.get(entity) {
411            let timestamps: Vec<DateTime<Utc>> = time_series.keys().cloned().collect();
412            let last_time = timestamps
413                .last()
414                .ok_or_else(|| anyhow::anyhow!("No temporal data for entity: {}", entity))?;
415
416            // Generate future timestamps based on granularity
417            let time_step = match self.config.granularity {
418                TemporalGranularity::Second => Duration::seconds(1),
419                TemporalGranularity::Minute => Duration::minutes(1),
420                TemporalGranularity::Hour => Duration::hours(1),
421                TemporalGranularity::Day => Duration::days(1),
422                TemporalGranularity::Week => Duration::weeks(1),
423                TemporalGranularity::Month => Duration::days(30),
424                TemporalGranularity::Year => Duration::days(365),
425                TemporalGranularity::Custom(secs) => Duration::seconds(secs),
426            };
427
428            let mut future_timestamps = Vec::new();
429            let mut predictions = Vec::new();
430            let mut confidence_intervals = Vec::new();
431
432            for i in 1..=horizon {
433                let future_time = *last_time + time_step * i as i32;
434                future_timestamps.push(future_time);
435
436                // Simple forecasting: use last known embedding with decay
437                let last_embedding = time_series.values().last().unwrap();
438                let decay_factor = self.config.decay_rate.powi(i as i32);
439
440                let prediction = last_embedding.mapv(|v| v * decay_factor);
441                let std_dev = 0.1 * (1.0 - decay_factor);
442
443                let lower = last_embedding.mapv(|v| (v * decay_factor) - std_dev);
444                let upper = last_embedding.mapv(|v| (v * decay_factor) + std_dev);
445
446                predictions.push(prediction);
447                confidence_intervals.push((lower, upper));
448            }
449
450            Ok(TemporalForecast {
451                target: entity.to_string(),
452                timestamps: future_timestamps,
453                predictions,
454                confidence_intervals,
455                accuracy: None,
456            })
457        } else {
458            Err(anyhow::anyhow!("Entity '{}' not found", entity))
459        }
460    }
461
462    /// Detect temporal events
463    pub async fn detect_events(&mut self, threshold: f32) -> Result<Vec<TemporalEvent>> {
464        info!("Detecting temporal events with threshold: {}", threshold);
465
466        let entity_embeddings = self.entity_embeddings.read().await;
467        let mut detected_events = Vec::new();
468
469        // Detect significant changes in embeddings over time
470        for (entity, time_series) in entity_embeddings.iter() {
471            let mut prev_embedding: Option<&Vector> = None;
472            let mut prev_time: Option<&DateTime<Utc>> = None;
473
474            for (time, embedding) in time_series.iter() {
475                if let (Some(prev_emb), Some(prev_t)) = (prev_embedding, prev_time) {
476                    // Calculate change magnitude
477                    let diff: Vec<f32> = embedding
478                        .values
479                        .iter()
480                        .zip(prev_emb.values.iter())
481                        .map(|(a, b)| (a - b).abs())
482                        .collect();
483                    let change_magnitude = diff.iter().sum::<f32>() / diff.len() as f32;
484
485                    if change_magnitude > threshold {
486                        let event = TemporalEvent {
487                            event_id: format!("event_{}_{}", entity, time.timestamp()),
488                            event_type: "embedding_shift".to_string(),
489                            timestamp: *time,
490                            entities: vec![entity.clone()],
491                            relations: Vec::new(),
492                            significance: change_magnitude,
493                            description: Some(format!(
494                                "Significant embedding change detected for '{}' between {} and {}",
495                                entity, prev_t, time
496                            )),
497                        };
498                        detected_events.push(event);
499                    }
500                }
501
502                prev_embedding = Some(embedding);
503                prev_time = Some(time);
504            }
505        }
506
507        // Store detected events
508        let mut events = self.events.write().await;
509        events.extend(detected_events.clone());
510
511        info!("Detected {} temporal events", detected_events.len());
512        Ok(detected_events)
513    }
514
515    /// Get all detected events
516    pub async fn get_events(&self) -> Vec<TemporalEvent> {
517        self.events.read().await.clone()
518    }
519
520    /// Query triples valid at a specific time
521    pub async fn query_at_time(&self, time: &DateTime<Utc>) -> Vec<Triple> {
522        let triples = self.temporal_triples.read().await;
523        triples
524            .iter()
525            .filter(|tt| tt.is_valid_at(time))
526            .map(|tt| tt.triple.clone())
527            .collect()
528    }
529
530    /// Get temporal statistics
531    pub async fn get_temporal_stats(&self) -> TemporalStats {
532        let entity_embeddings = self.entity_embeddings.read().await;
533        let relation_embeddings = self.relation_embeddings.read().await;
534        let triples = self.temporal_triples.read().await;
535        let events = self.events.read().await;
536
537        // Calculate time span
538        let all_times: Vec<DateTime<Utc>> = entity_embeddings
539            .values()
540            .flat_map(|ts| ts.keys().cloned())
541            .collect();
542
543        let (min_time, max_time) = if all_times.is_empty() {
544            (None, None)
545        } else {
546            (
547                all_times.iter().min().cloned(),
548                all_times.iter().max().cloned(),
549            )
550        };
551
552        TemporalStats {
553            num_temporal_triples: triples.len(),
554            num_entities: entity_embeddings.len(),
555            num_relations: relation_embeddings.len(),
556            num_time_points: all_times.len(),
557            num_events: events.len(),
558            time_span_start: min_time,
559            time_span_end: max_time,
560            granularity: self.config.granularity.clone(),
561        }
562    }
563}
564
565/// Temporal statistics
566#[derive(Debug, Clone, Serialize, Deserialize)]
567pub struct TemporalStats {
568    pub num_temporal_triples: usize,
569    pub num_entities: usize,
570    pub num_relations: usize,
571    pub num_time_points: usize,
572    pub num_events: usize,
573    pub time_span_start: Option<DateTime<Utc>>,
574    pub time_span_end: Option<DateTime<Utc>>,
575    pub granularity: TemporalGranularity,
576}
577
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    /// Assemble a triple from three IRIs; panics if any IRI is rejected.
    fn triple_of(subject: &str, predicate: &str, object: &str) -> Triple {
        Triple::new(
            NamedNode::new(subject).unwrap(),
            NamedNode::new(predicate).unwrap(),
            NamedNode::new(object).unwrap(),
        )
    }

    /// Model built from the default configuration.
    fn default_model() -> TemporalEmbeddingModel {
        TemporalEmbeddingModel::new(TemporalEmbeddingConfig::default())
    }

    #[tokio::test]
    async fn test_temporal_model_creation() {
        let model = default_model();
        assert_eq!(model.config.time_dim, 32);
    }

    #[tokio::test]
    async fn test_temporal_triple_validity() {
        let triple = triple_of(
            "http://example.org/alice",
            "http://example.org/worksFor",
            "http://example.org/company",
        );

        // Valid for one year starting now
        let start = Utc::now();
        let end = start + Duration::days(365);
        let temporal_triple = TemporalTriple::new(triple, TemporalScope::Interval { start, end });

        let now = Utc::now();
        assert!(temporal_triple.is_valid_at(&now));

        let future = now + Duration::days(400);
        assert!(!temporal_triple.is_valid_at(&future));
    }

    #[tokio::test]
    async fn test_temporal_embedding_add_triple() {
        let mut model = default_model();

        let triple = triple_of(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        );
        let temporal_triple = TemporalTriple::new(triple, TemporalScope::Instant(Utc::now()));
        model.add_temporal_triple(temporal_triple).await.unwrap();

        let stats = model.get_temporal_stats().await;
        assert_eq!(stats.num_temporal_triples, 1);
    }

    #[tokio::test]
    async fn test_temporal_training() {
        let mut model = default_model();

        // Five distinct subjects, one day apart
        for day in 0..5 {
            let subject = format!("http://example.org/entity_{}", day);
            let triple = triple_of(
                &subject,
                "http://example.org/relation",
                "http://example.org/target",
            );
            let scope = TemporalScope::Instant(Utc::now() + Duration::days(day));

            model
                .add_temporal_triple(TemporalTriple::new(triple, scope))
                .await
                .unwrap();
        }

        let stats = model.train_temporal(10).await.unwrap();
        assert_eq!(stats.epochs_completed, 10);
        assert!(stats.final_loss >= 0.0);
    }

    #[tokio::test]
    async fn test_temporal_forecasting() {
        let mut model = default_model();

        // A single observation is enough to forecast from
        let triple = triple_of(
            "http://example.org/entity",
            "http://example.org/relation",
            "http://example.org/target",
        );
        let temporal_triple = TemporalTriple::new(triple, TemporalScope::Instant(Utc::now()));

        model.add_temporal_triple(temporal_triple).await.unwrap();
        model.train_temporal(5).await.unwrap();

        let forecast = model
            .forecast("http://example.org/entity", 10)
            .await
            .unwrap();
        assert_eq!(forecast.predictions.len(), 10);
        assert_eq!(forecast.timestamps.len(), 10);
    }

    #[tokio::test]
    async fn test_event_detection() {
        let config = TemporalEmbeddingConfig {
            event_threshold: 0.3,
            ..Default::default()
        };
        let mut model = TemporalEmbeddingModel::new(config);

        // Same subject observed on three consecutive days
        for day in 0..3 {
            let triple = triple_of(
                "http://example.org/entity",
                "http://example.org/relation",
                "http://example.org/target",
            );
            let scope = TemporalScope::Instant(Utc::now() + Duration::days(day));

            model
                .add_temporal_triple(TemporalTriple::new(triple, scope))
                .await
                .unwrap();
        }

        model.train_temporal(5).await.unwrap();
        let _events = model.detect_events(0.3).await.unwrap();

        // Whether any event fires depends on the random initialization, so
        // this only checks that detection completes (the unwrap above).
    }
}