oxirs_embed/
temporal_embeddings.rs

1//! Temporal Embeddings Module for Time-Aware Knowledge Graph Embeddings
2//!
3//! This module provides temporal embedding capabilities for knowledge graphs
4//! that evolve over time, capturing temporal patterns and dynamics.
5//!
6//! ## Features
7//!
8//! - **Time-Aware Embeddings**: Embed entities/relations with temporal context
9//! - **Temporal Evolution**: Track how embeddings change over time
10//! - **Time Series Analysis**: Analyze temporal patterns in knowledge graphs
11//! - **Temporal Queries**: Support time-based queries and predictions
12//! - **Event Detection**: Detect significant temporal events
13//! - **Forecasting**: Predict future entity/relation states
14//!
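//! ## Temporal Models
//!
//! Target model families. The current implementation uses a simplified
//! placeholder training loop and does not yet implement these models:
//!
//! - **TTransE**: Temporal extension of TransE
//! - **TA-DistMult**: Time-aware DistMult
//! - **DE-SimplE**: Diachronic embedding model
//! - **ChronoR**: Rotation-based temporal embeddings
//! - **TeMP**: Temporal message passing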
22//!
23//! ## Use Cases
24//!
25//! - Historical data analysis
26//! - Event prediction
27//! - Temporal reasoning
28//! - Dynamic knowledge graphs
29//! - Time-series forecasting
30
31use anyhow::Result;
32use chrono::{DateTime, Duration, Utc};
33use serde::{Deserialize, Serialize};
34use std::collections::{BTreeMap, HashMap};
35use std::sync::Arc;
36use tokio::sync::RwLock;
37use tracing::{debug, info};
38
39use crate::{ModelConfig, TrainingStats, Triple, Vector};
40use uuid::Uuid;
41
42// Type aliases to simplify complex types
43type TemporalEntityEmbeddings = Arc<RwLock<HashMap<String, BTreeMap<DateTime<Utc>, Vector>>>>;
44type TemporalRelationEmbeddings = Arc<RwLock<HashMap<String, BTreeMap<DateTime<Utc>, Vector>>>>;
45
/// Stateless placeholder for time-series analysis.
///
/// Intended to be backed by `scirs2-stats` later; it carries no data today.
#[derive(Clone, Debug, Default)]
pub struct TimeSeriesAnalyzer;
49
/// Strategy for forecasting future embedding states.
///
/// A plain fieldless selector, so it is `Copy` and comparable/hashable like
/// the other configuration enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ForecastMethod {
    /// Exponential smoothing.
    ExponentialSmoothing,
    /// ARIMA (autoregressive integrated moving average).
    Arima,
    /// Prophet-style forecasting.
    Prophet,
}
56
/// Temporal granularity: the length of one time step.
///
/// `TemporalEmbeddingModel::forecast` uses it as the spacing between forecast
/// timestamps, approximating `Month` as 30 days and `Year` as 365 days.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum TemporalGranularity {
    /// Second-level precision
    Second,
    /// Minute-level precision
    Minute,
    /// Hour-level precision
    Hour,
    /// Day-level precision
    Day,
    /// Week-level precision
    Week,
    /// Month-level precision
    Month,
    /// Year-level precision
    Year,
    /// Custom step length, given in seconds
    Custom(i64), // seconds
}
77
/// Temporal scope: when a fact holds.
///
/// Validity is evaluated by `TemporalTriple::is_valid_at`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TemporalScope {
    /// Valid only at exactly this instant (compared with `==`)
    Instant(DateTime<Utc>),
    /// Valid from `start` to `end`, both ends inclusive
    Interval {
        start: DateTime<Utc>,
        end: DateTime<Utc>,
    },
    /// Recurring scope anchored at `start` with the given `period`
    /// (e.g., every Monday); `count` limits the number of periods, `None`
    /// meaning no limit.
    ///
    /// NOTE(review): `is_valid_at` treats this as continuously valid from
    /// `start` until `count` periods have elapsed, not only at each
    /// recurrence — confirm that is the intended semantics.
    Periodic {
        start: DateTime<Utc>,
        period: Duration,
        count: Option<usize>,
    },
    /// Unbounded (always valid)
    Unbounded,
}
97
/// An RDF triple annotated with the time span during which it holds.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalTriple {
    /// The RDF triple
    pub triple: Triple,
    /// Temporal scope when this triple is valid
    pub scope: TemporalScope,
    /// Confidence score, expected in 0.0-1.0 (not validated; `new` sets 1.0)
    pub confidence: f32,
    /// Source/provenance information (`new` leaves it `None`)
    pub source: Option<String>,
}
110
111impl TemporalTriple {
112    /// Create a new temporal triple
113    pub fn new(triple: Triple, scope: TemporalScope) -> Self {
114        Self {
115            triple,
116            scope,
117            confidence: 1.0,
118            source: None,
119        }
120    }
121
122    /// Check if this triple is valid at the given time
123    pub fn is_valid_at(&self, time: &DateTime<Utc>) -> bool {
124        match &self.scope {
125            TemporalScope::Instant(instant) => instant == time,
126            TemporalScope::Interval { start, end } => time >= start && time <= end,
127            TemporalScope::Periodic {
128                start,
129                period,
130                count,
131            } => {
132                let elapsed = time.signed_duration_since(*start);
133                if elapsed < Duration::zero() {
134                    return false;
135                }
136                if let Some(max_count) = count {
137                    let num_periods = elapsed.num_seconds() / period.num_seconds();
138                    num_periods < *max_count as i64
139                } else {
140                    true
141                }
142            }
143            TemporalScope::Unbounded => true,
144        }
145    }
146}
147
/// Configuration for a `TemporalEmbeddingModel`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalEmbeddingConfig {
    /// Base model configuration (its `dimensions` sets the size of the
    /// entity embeddings generated during training)
    pub base_config: ModelConfig,
    /// Temporal granularity (the step between forecast timestamps)
    pub granularity: TemporalGranularity,
    /// Time embedding dimensions
    pub time_dim: usize,
    /// Enable temporal decay
    pub enable_decay: bool,
    /// Per-step factor for exponential decay (forecasting scales the last
    /// known embedding by `decay_rate^step`)
    pub decay_rate: f32,
    /// Enable temporal smoothing
    pub enable_smoothing: bool,
    /// Smoothing window size
    pub smoothing_window: usize,
    /// Enable forecasting
    pub enable_forecasting: bool,
    /// Forecast horizon (number of time steps)
    pub forecast_horizon: usize,
    /// Enable event detection
    pub enable_event_detection: bool,
    /// Event threshold
    // NOTE(review): `detect_events` takes its threshold as an argument and
    // does not read this field — confirm callers pass it through.
    pub event_threshold: f32,
}
174
175impl Default for TemporalEmbeddingConfig {
176    fn default() -> Self {
177        Self {
178            base_config: ModelConfig::default(),
179            granularity: TemporalGranularity::Day,
180            time_dim: 32,
181            enable_decay: true,
182            decay_rate: 0.9,
183            enable_smoothing: true,
184            smoothing_window: 7,
185            enable_forecasting: true,
186            forecast_horizon: 30,
187            enable_event_detection: false,
188            event_threshold: 0.7,
189        }
190    }
191}
192
/// A significant temporal event found in the embedding history.
///
/// `TemporalEmbeddingModel::detect_events` emits `"embedding_shift"` events
/// when an entity's embedding changes sharply between consecutive snapshots.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalEvent {
    /// Event ID (`detect_events` uses `event_<entity>_<unix seconds>`)
    pub event_id: String,
    /// Event type (e.g. `"embedding_shift"`)
    pub event_type: String,
    /// Timestamp (for embedding shifts: the later of the two compared snapshots)
    pub timestamp: DateTime<Utc>,
    /// Entities involved
    pub entities: Vec<String>,
    /// Relations involved
    pub relations: Vec<String>,
    /// Event significance score (for embedding shifts: the mean absolute
    /// per-dimension change between the two snapshots)
    pub significance: f32,
    /// Human-readable event description
    pub description: Option<String>,
}
211
/// Result of forecasting future embeddings for one target.
///
/// `timestamps`, `predictions` and `confidence_intervals` are parallel:
/// element `i` of each describes the same future time step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalForecast {
    /// Entity or relation being forecasted
    pub target: String,
    /// Forecast timestamps
    pub timestamps: Vec<DateTime<Utc>>,
    /// Predicted embeddings
    pub predictions: Vec<Vector>,
    /// Per-step confidence band as (lower, upper) embeddings
    pub confidence_intervals: Vec<(Vector, Vector)>,
    /// Forecast accuracy, if validation data is available
    /// (`forecast` currently always sets `None`)
    pub accuracy: Option<f32>,
}
226
227/// Temporal embedding model
228pub struct TemporalEmbeddingModel {
229    config: TemporalEmbeddingConfig,
230    model_id: Uuid,
231
232    // Entity embeddings over time: entity -> time -> embedding
233    entity_embeddings: TemporalEntityEmbeddings,
234
235    // Relation embeddings over time: relation -> time -> embedding
236    relation_embeddings: TemporalRelationEmbeddings,
237
238    // Time embeddings: timestamp -> time embedding
239    time_embeddings: Arc<RwLock<BTreeMap<DateTime<Utc>, Vector>>>,
240
241    // Temporal triples
242    temporal_triples: Arc<RwLock<Vec<TemporalTriple>>>,
243
244    // Detected events
245    events: Arc<RwLock<Vec<TemporalEvent>>>,
246
247    // Time series analyzer
248    time_series_analyzer: Option<TimeSeriesAnalyzer>,
249
250    // Training state
251    is_trained: Arc<RwLock<bool>>,
252}
253
254impl TemporalEmbeddingModel {
255    /// Create a new temporal embedding model
256    pub fn new(config: TemporalEmbeddingConfig) -> Self {
257        info!(
258            "Creating temporal embedding model with time_dim={}",
259            config.time_dim
260        );
261
262        Self {
263            model_id: Uuid::new_v4(),
264            time_series_analyzer: Some(TimeSeriesAnalyzer),
265            config,
266            entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
267            relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
268            time_embeddings: Arc::new(RwLock::new(BTreeMap::new())),
269            temporal_triples: Arc::new(RwLock::new(Vec::new())),
270            events: Arc::new(RwLock::new(Vec::new())),
271            is_trained: Arc::new(RwLock::new(false)),
272        }
273    }
274
275    /// Add a temporal triple
276    pub async fn add_temporal_triple(&mut self, temporal_triple: TemporalTriple) -> Result<()> {
277        let mut triples = self.temporal_triples.write().await;
278        triples.push(temporal_triple);
279        Ok(())
280    }
281
282    /// Get entity embedding at a specific time
283    pub async fn get_entity_embedding_at_time(
284        &self,
285        entity: &str,
286        time: &DateTime<Utc>,
287    ) -> Result<Vector> {
288        let embeddings = self.entity_embeddings.read().await;
289
290        if let Some(time_series) = embeddings.get(entity) {
291            // Find the closest time point
292            if let Some((_, embedding)) = time_series.range(..=time).next_back() {
293                return Ok(embedding.clone());
294            }
295        }
296
297        Err(anyhow::anyhow!(
298            "Entity '{}' not found at time {}",
299            entity,
300            time
301        ))
302    }
303
304    /// Get relation embedding at a specific time
305    pub async fn get_relation_embedding_at_time(
306        &self,
307        relation: &str,
308        time: &DateTime<Utc>,
309    ) -> Result<Vector> {
310        let embeddings = self.relation_embeddings.read().await;
311
312        if let Some(time_series) = embeddings.get(relation) {
313            // Find the closest time point
314            if let Some((_, embedding)) = time_series.range(..=time).next_back() {
315                return Ok(embedding.clone());
316            }
317        }
318
319        Err(anyhow::anyhow!(
320            "Relation '{}' not found at time {}",
321            relation,
322            time
323        ))
324    }
325
326    /// Train temporal embeddings
327    pub async fn train_temporal(&mut self, epochs: usize) -> Result<TrainingStats> {
328        info!("Training temporal embeddings for {} epochs", epochs);
329
330        let start_time = std::time::Instant::now();
331        let mut loss_history = Vec::new();
332
333        for epoch in 0..epochs {
334            let loss = self.train_epoch(epoch).await?;
335            loss_history.push(loss);
336
337            if epoch % 10 == 0 {
338                debug!("Epoch {}/{}: loss={:.6}", epoch + 1, epochs, loss);
339            }
340        }
341
342        *self.is_trained.write().await = true;
343
344        let elapsed = start_time.elapsed().as_secs_f64();
345        let final_loss = *loss_history.last().unwrap_or(&0.0);
346
347        info!(
348            "Temporal training completed in {:.2}s, final loss: {:.6}",
349            elapsed, final_loss
350        );
351
352        Ok(TrainingStats {
353            epochs_completed: epochs,
354            final_loss,
355            training_time_seconds: elapsed,
356            convergence_achieved: final_loss < 0.01,
357            loss_history,
358        })
359    }
360
361    /// Train a single epoch
362    async fn train_epoch(&mut self, _epoch: usize) -> Result<f64> {
363        // Simplified training - in a real implementation, this would:
364        // 1. Sample temporal triples
365        // 2. Compute time-aware embeddings
366        // 3. Calculate temporal loss (considering time decay)
367        // 4. Update embeddings with gradient descent
368
369        // Initialize some sample embeddings for testing
370        let triples = self.temporal_triples.read().await;
371        let dim = self.config.base_config.dimensions;
372
373        use scirs2_core::random::Random;
374        let mut rng = Random::default();
375
376        for temporal_triple in triples.iter() {
377            let embedding = Vector::new(
378                (0..dim)
379                    .map(|_| rng.random_range(-1.0..1.0) as f32)
380                    .collect(),
381            );
382
383            // Store embedding with timestamp
384            let timestamp = match &temporal_triple.scope {
385                TemporalScope::Instant(t) => *t,
386                TemporalScope::Interval { start, .. } => *start,
387                _ => Utc::now(),
388            };
389
390            let entity = temporal_triple.triple.subject.iri.clone();
391            let mut entity_embs = self.entity_embeddings.write().await;
392            entity_embs
393                .entry(entity)
394                .or_insert_with(BTreeMap::new)
395                .insert(timestamp, embedding);
396        }
397
398        Ok(0.1) // Simplified loss
399    }
400
401    /// Forecast future embeddings
402    pub async fn forecast(&self, entity: &str, horizon: usize) -> Result<TemporalForecast> {
403        info!(
404            "Forecasting {} time steps ahead for entity: {}",
405            horizon, entity
406        );
407
408        let embeddings = self.entity_embeddings.read().await;
409
410        if let Some(time_series) = embeddings.get(entity) {
411            let timestamps: Vec<DateTime<Utc>> = time_series.keys().cloned().collect();
412            let last_time = timestamps
413                .last()
414                .ok_or_else(|| anyhow::anyhow!("No temporal data for entity: {}", entity))?;
415
416            // Generate future timestamps based on granularity
417            let time_step = match self.config.granularity {
418                TemporalGranularity::Second => Duration::seconds(1),
419                TemporalGranularity::Minute => Duration::minutes(1),
420                TemporalGranularity::Hour => Duration::hours(1),
421                TemporalGranularity::Day => Duration::days(1),
422                TemporalGranularity::Week => Duration::weeks(1),
423                TemporalGranularity::Month => Duration::days(30),
424                TemporalGranularity::Year => Duration::days(365),
425                TemporalGranularity::Custom(secs) => Duration::seconds(secs),
426            };
427
428            let mut future_timestamps = Vec::new();
429            let mut predictions = Vec::new();
430            let mut confidence_intervals = Vec::new();
431
432            for i in 1..=horizon {
433                let future_time = *last_time + time_step * i as i32;
434                future_timestamps.push(future_time);
435
436                // Simple forecasting: use last known embedding with decay
437                let last_embedding = time_series
438                    .values()
439                    .last()
440                    .expect("time_series should have at least one embedding");
441                let decay_factor = self.config.decay_rate.powi(i as i32);
442
443                let prediction = last_embedding.mapv(|v| v * decay_factor);
444                let std_dev = 0.1 * (1.0 - decay_factor);
445
446                let lower = last_embedding.mapv(|v| (v * decay_factor) - std_dev);
447                let upper = last_embedding.mapv(|v| (v * decay_factor) + std_dev);
448
449                predictions.push(prediction);
450                confidence_intervals.push((lower, upper));
451            }
452
453            Ok(TemporalForecast {
454                target: entity.to_string(),
455                timestamps: future_timestamps,
456                predictions,
457                confidence_intervals,
458                accuracy: None,
459            })
460        } else {
461            Err(anyhow::anyhow!("Entity '{}' not found", entity))
462        }
463    }
464
465    /// Detect temporal events
466    pub async fn detect_events(&mut self, threshold: f32) -> Result<Vec<TemporalEvent>> {
467        info!("Detecting temporal events with threshold: {}", threshold);
468
469        let entity_embeddings = self.entity_embeddings.read().await;
470        let mut detected_events = Vec::new();
471
472        // Detect significant changes in embeddings over time
473        for (entity, time_series) in entity_embeddings.iter() {
474            let mut prev_embedding: Option<&Vector> = None;
475            let mut prev_time: Option<&DateTime<Utc>> = None;
476
477            for (time, embedding) in time_series.iter() {
478                if let (Some(prev_emb), Some(prev_t)) = (prev_embedding, prev_time) {
479                    // Calculate change magnitude
480                    let diff: Vec<f32> = embedding
481                        .values
482                        .iter()
483                        .zip(prev_emb.values.iter())
484                        .map(|(a, b)| (a - b).abs())
485                        .collect();
486                    let change_magnitude = diff.iter().sum::<f32>() / diff.len() as f32;
487
488                    if change_magnitude > threshold {
489                        let event = TemporalEvent {
490                            event_id: format!("event_{}_{}", entity, time.timestamp()),
491                            event_type: "embedding_shift".to_string(),
492                            timestamp: *time,
493                            entities: vec![entity.clone()],
494                            relations: Vec::new(),
495                            significance: change_magnitude,
496                            description: Some(format!(
497                                "Significant embedding change detected for '{}' between {} and {}",
498                                entity, prev_t, time
499                            )),
500                        };
501                        detected_events.push(event);
502                    }
503                }
504
505                prev_embedding = Some(embedding);
506                prev_time = Some(time);
507            }
508        }
509
510        // Store detected events
511        let mut events = self.events.write().await;
512        events.extend(detected_events.clone());
513
514        info!("Detected {} temporal events", detected_events.len());
515        Ok(detected_events)
516    }
517
518    /// Get all detected events
519    pub async fn get_events(&self) -> Vec<TemporalEvent> {
520        self.events.read().await.clone()
521    }
522
523    /// Query triples valid at a specific time
524    pub async fn query_at_time(&self, time: &DateTime<Utc>) -> Vec<Triple> {
525        let triples = self.temporal_triples.read().await;
526        triples
527            .iter()
528            .filter(|tt| tt.is_valid_at(time))
529            .map(|tt| tt.triple.clone())
530            .collect()
531    }
532
533    /// Get temporal statistics
534    pub async fn get_temporal_stats(&self) -> TemporalStats {
535        let entity_embeddings = self.entity_embeddings.read().await;
536        let relation_embeddings = self.relation_embeddings.read().await;
537        let triples = self.temporal_triples.read().await;
538        let events = self.events.read().await;
539
540        // Calculate time span
541        let all_times: Vec<DateTime<Utc>> = entity_embeddings
542            .values()
543            .flat_map(|ts| ts.keys().cloned())
544            .collect();
545
546        let (min_time, max_time) = if all_times.is_empty() {
547            (None, None)
548        } else {
549            (
550                all_times.iter().min().cloned(),
551                all_times.iter().max().cloned(),
552            )
553        };
554
555        TemporalStats {
556            num_temporal_triples: triples.len(),
557            num_entities: entity_embeddings.len(),
558            num_relations: relation_embeddings.len(),
559            num_time_points: all_times.len(),
560            num_events: events.len(),
561            time_span_start: min_time,
562            time_span_end: max_time,
563            granularity: self.config.granularity.clone(),
564        }
565    }
566}
567
/// Summary statistics for a `TemporalEmbeddingModel`, as returned by
/// `get_temporal_stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalStats {
    /// Number of registered temporal triples
    pub num_temporal_triples: usize,
    /// Number of entities in the entity embedding store
    pub num_entities: usize,
    /// Number of relations in the relation embedding store
    pub num_relations: usize,
    /// Total entity embedding snapshots, summed over entities (a timestamp
    /// shared by several entities is counted once per entity)
    pub num_time_points: usize,
    /// Number of detected events stored on the model
    pub num_events: usize,
    /// Earliest entity snapshot timestamp, `None` when there are no snapshots
    pub time_span_start: Option<DateTime<Utc>>,
    /// Latest entity snapshot timestamp, `None` when there are no snapshots
    pub time_span_end: Option<DateTime<Utc>>,
    /// Configured temporal granularity
    pub granularity: TemporalGranularity,
}
580
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    /// Build a triple from three IRIs (all test IRIs are well-formed).
    fn make_triple(subject: &str, predicate: &str, object: &str) -> Triple {
        Triple::new(
            NamedNode::new(subject).unwrap(),
            NamedNode::new(predicate).unwrap(),
            NamedNode::new(object).unwrap(),
        )
    }

    #[tokio::test]
    async fn test_temporal_model_creation() {
        let config = TemporalEmbeddingConfig::default();
        let model = TemporalEmbeddingModel::new(config);
        assert_eq!(model.config.time_dim, 32);
    }

    #[tokio::test]
    async fn test_temporal_triple_validity() {
        let triple = make_triple(
            "http://example.org/alice",
            "http://example.org/worksFor",
            "http://example.org/company",
        );

        let start = Utc::now();
        let end = start + Duration::days(365);

        let temporal_triple = TemporalTriple::new(triple, TemporalScope::Interval { start, end });

        let now = Utc::now();
        assert!(temporal_triple.is_valid_at(&now));

        let future = now + Duration::days(400);
        assert!(!temporal_triple.is_valid_at(&future));
    }

    #[tokio::test]
    async fn test_periodic_triple_validity() {
        let start = Utc::now();
        let limited = TemporalTriple::new(
            make_triple(
                "http://example.org/alice",
                "http://example.org/attends",
                "http://example.org/meeting",
            ),
            TemporalScope::Periodic {
                start,
                period: Duration::days(1),
                count: Some(3),
            },
        );

        assert!(limited.is_valid_at(&start));
        assert!(limited.is_valid_at(&(start + Duration::days(2) + Duration::hours(1))));
        // Three whole periods have elapsed: no longer valid.
        assert!(!limited.is_valid_at(&(start + Duration::days(3))));
        assert!(!limited.is_valid_at(&(start - Duration::seconds(1))));

        let open_ended = TemporalTriple::new(
            make_triple(
                "http://example.org/alice",
                "http://example.org/attends",
                "http://example.org/meeting",
            ),
            TemporalScope::Periodic {
                start,
                period: Duration::days(1),
                count: None,
            },
        );
        assert!(open_ended.is_valid_at(&(start + Duration::days(1000))));
    }

    #[tokio::test]
    async fn test_temporal_embedding_add_triple() {
        let config = TemporalEmbeddingConfig::default();
        let mut model = TemporalEmbeddingModel::new(config);

        let triple = make_triple(
            "http://example.org/alice",
            "http://example.org/knows",
            "http://example.org/bob",
        );

        let temporal_triple = TemporalTriple::new(triple, TemporalScope::Instant(Utc::now()));

        model.add_temporal_triple(temporal_triple).await.unwrap();

        let stats = model.get_temporal_stats().await;
        assert_eq!(stats.num_temporal_triples, 1);
    }

    #[tokio::test]
    async fn test_temporal_training() {
        let config = TemporalEmbeddingConfig::default();
        let mut model = TemporalEmbeddingModel::new(config);
        let base = Utc::now();

        // Add some temporal triples
        for i in 0..5 {
            let temporal_triple = TemporalTriple::new(
                make_triple(
                    &format!("http://example.org/entity_{}", i),
                    "http://example.org/relation",
                    "http://example.org/target",
                ),
                TemporalScope::Instant(base + Duration::days(i)),
            );

            model.add_temporal_triple(temporal_triple).await.unwrap();
        }

        let stats = model.train_temporal(10).await.unwrap();
        assert_eq!(stats.epochs_completed, 10);
        assert_eq!(stats.loss_history.len(), 10);
        assert!(stats.final_loss >= 0.0);

        // Each subject has a snapshot at its instant and none before it.
        for i in 0..5 {
            let entity = format!("http://example.org/entity_{}", i);
            let at = base + Duration::days(i);
            assert!(model.get_entity_embedding_at_time(&entity, &at).await.is_ok());
            let before = at - Duration::seconds(1);
            assert!(model
                .get_entity_embedding_at_time(&entity, &before)
                .await
                .is_err());
        }
        assert_eq!(model.get_temporal_stats().await.num_entities, 5);
    }

    #[tokio::test]
    async fn test_temporal_forecasting() {
        let config = TemporalEmbeddingConfig::default();
        let mut model = TemporalEmbeddingModel::new(config);

        // Add temporal data
        let observed_at = Utc::now();
        let temporal_triple = TemporalTriple::new(
            make_triple(
                "http://example.org/entity",
                "http://example.org/relation",
                "http://example.org/target",
            ),
            TemporalScope::Instant(observed_at),
        );

        model.add_temporal_triple(temporal_triple).await.unwrap();
        model.train_temporal(5).await.unwrap();

        let forecast = model
            .forecast("http://example.org/entity", 10)
            .await
            .unwrap();
        assert_eq!(forecast.predictions.len(), 10);
        assert_eq!(forecast.timestamps.len(), 10);
        assert_eq!(forecast.confidence_intervals.len(), 10);

        // Default granularity is one day, starting one step after the snapshot.
        assert_eq!(forecast.timestamps[0], observed_at + Duration::days(1));
        for pair in forecast.timestamps.windows(2) {
            assert_eq!(pair[1].signed_duration_since(pair[0]), Duration::days(1));
        }

        assert!(model
            .forecast("http://example.org/missing", 3)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_event_detection() {
        let config = TemporalEmbeddingConfig {
            event_threshold: 0.3,
            ..Default::default()
        };
        let mut model = TemporalEmbeddingModel::new(config);

        // Add temporal triples and train
        for i in 0..3 {
            let temporal_triple = TemporalTriple::new(
                make_triple(
                    "http://example.org/entity",
                    "http://example.org/relation",
                    "http://example.org/target",
                ),
                TemporalScope::Instant(Utc::now() + Duration::days(i)),
            );

            model.add_temporal_triple(temporal_triple).await.unwrap();
        }

        model.train_temporal(5).await.unwrap();
        let threshold = model.config.event_threshold;
        let events = model.detect_events(threshold).await.unwrap();

        // The count depends on random initialisation, so check invariants instead.
        for event in &events {
            assert!(event.significance > threshold);
            assert_eq!(event.event_type, "embedding_shift");
            assert_eq!(
                event.entities,
                vec!["http://example.org/entity".to_string()]
            );
        }
        assert_eq!(model.get_events().await.len(), events.len());
    }
}