1use anyhow::Result;
32use chrono::{DateTime, Duration, Utc};
33use serde::{Deserialize, Serialize};
34use std::collections::{BTreeMap, HashMap};
35use std::sync::Arc;
36use tokio::sync::RwLock;
37use tracing::{debug, info};
38
39use crate::{ModelConfig, TrainingStats, Triple, Vector};
40use uuid::Uuid;
41
42type TemporalEntityEmbeddings = Arc<RwLock<HashMap<String, BTreeMap<DateTime<Utc>, Vector>>>>;
44type TemporalRelationEmbeddings = Arc<RwLock<HashMap<String, BTreeMap<DateTime<Utc>, Vector>>>>;
45
/// Stateless time-series analysis component.
///
/// `TemporalEmbeddingModel::new` attaches one of these to every model; it currently
/// carries no data.
#[derive(Default, Clone, Debug)]
pub struct TimeSeriesAnalyzer;
49
/// Candidate forecasting strategies for temporal embeddings.
///
/// NOTE(review): nothing in this module dispatches on these variants yet; the model's
/// `forecast` method always applies a fixed decay. Confirm intended usage.
#[derive(Clone, Debug)]
pub enum ForecastMethod {
    /// Exponential smoothing.
    ExponentialSmoothing,
    /// Autoregressive integrated moving average.
    Arima,
    /// Prophet-style additive model.
    Prophet,
}
56
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
59pub enum TemporalGranularity {
60 Second,
62 Minute,
64 Hour,
66 Day,
68 Week,
70 Month,
72 Year,
74 Custom(i64), }
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub enum TemporalScope {
81 Instant(DateTime<Utc>),
83 Interval {
85 start: DateTime<Utc>,
86 end: DateTime<Utc>,
87 },
88 Periodic {
90 start: DateTime<Utc>,
91 period: Duration,
92 count: Option<usize>,
93 },
94 Unbounded,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct TemporalTriple {
101 pub triple: Triple,
103 pub scope: TemporalScope,
105 pub confidence: f32,
107 pub source: Option<String>,
109}
110
111impl TemporalTriple {
112 pub fn new(triple: Triple, scope: TemporalScope) -> Self {
114 Self {
115 triple,
116 scope,
117 confidence: 1.0,
118 source: None,
119 }
120 }
121
122 pub fn is_valid_at(&self, time: &DateTime<Utc>) -> bool {
124 match &self.scope {
125 TemporalScope::Instant(instant) => instant == time,
126 TemporalScope::Interval { start, end } => time >= start && time <= end,
127 TemporalScope::Periodic {
128 start,
129 period,
130 count,
131 } => {
132 let elapsed = time.signed_duration_since(*start);
133 if elapsed < Duration::zero() {
134 return false;
135 }
136 if let Some(max_count) = count {
137 let num_periods = elapsed.num_seconds() / period.num_seconds();
138 num_periods < *max_count as i64
139 } else {
140 true
141 }
142 }
143 TemporalScope::Unbounded => true,
144 }
145 }
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
150pub struct TemporalEmbeddingConfig {
151 pub base_config: ModelConfig,
153 pub granularity: TemporalGranularity,
155 pub time_dim: usize,
157 pub enable_decay: bool,
159 pub decay_rate: f32,
161 pub enable_smoothing: bool,
163 pub smoothing_window: usize,
165 pub enable_forecasting: bool,
167 pub forecast_horizon: usize,
169 pub enable_event_detection: bool,
171 pub event_threshold: f32,
173}
174
175impl Default for TemporalEmbeddingConfig {
176 fn default() -> Self {
177 Self {
178 base_config: ModelConfig::default(),
179 granularity: TemporalGranularity::Day,
180 time_dim: 32,
181 enable_decay: true,
182 decay_rate: 0.9,
183 enable_smoothing: true,
184 smoothing_window: 7,
185 enable_forecasting: true,
186 forecast_horizon: 30,
187 enable_event_detection: false,
188 event_threshold: 0.7,
189 }
190 }
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
195pub struct TemporalEvent {
196 pub event_id: String,
198 pub event_type: String,
200 pub timestamp: DateTime<Utc>,
202 pub entities: Vec<String>,
204 pub relations: Vec<String>,
206 pub significance: f32,
208 pub description: Option<String>,
210}
211
212#[derive(Debug, Clone, Serialize, Deserialize)]
214pub struct TemporalForecast {
215 pub target: String,
217 pub timestamps: Vec<DateTime<Utc>>,
219 pub predictions: Vec<Vector>,
221 pub confidence_intervals: Vec<(Vector, Vector)>,
223 pub accuracy: Option<f32>,
225}
226
227pub struct TemporalEmbeddingModel {
229 config: TemporalEmbeddingConfig,
230 model_id: Uuid,
231
232 entity_embeddings: TemporalEntityEmbeddings,
234
235 relation_embeddings: TemporalRelationEmbeddings,
237
238 time_embeddings: Arc<RwLock<BTreeMap<DateTime<Utc>, Vector>>>,
240
241 temporal_triples: Arc<RwLock<Vec<TemporalTriple>>>,
243
244 events: Arc<RwLock<Vec<TemporalEvent>>>,
246
247 time_series_analyzer: Option<TimeSeriesAnalyzer>,
249
250 is_trained: Arc<RwLock<bool>>,
252}
253
254impl TemporalEmbeddingModel {
255 pub fn new(config: TemporalEmbeddingConfig) -> Self {
257 info!(
258 "Creating temporal embedding model with time_dim={}",
259 config.time_dim
260 );
261
262 Self {
263 model_id: Uuid::new_v4(),
264 time_series_analyzer: Some(TimeSeriesAnalyzer),
265 config,
266 entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
267 relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
268 time_embeddings: Arc::new(RwLock::new(BTreeMap::new())),
269 temporal_triples: Arc::new(RwLock::new(Vec::new())),
270 events: Arc::new(RwLock::new(Vec::new())),
271 is_trained: Arc::new(RwLock::new(false)),
272 }
273 }
274
275 pub async fn add_temporal_triple(&mut self, temporal_triple: TemporalTriple) -> Result<()> {
277 let mut triples = self.temporal_triples.write().await;
278 triples.push(temporal_triple);
279 Ok(())
280 }
281
282 pub async fn get_entity_embedding_at_time(
284 &self,
285 entity: &str,
286 time: &DateTime<Utc>,
287 ) -> Result<Vector> {
288 let embeddings = self.entity_embeddings.read().await;
289
290 if let Some(time_series) = embeddings.get(entity) {
291 if let Some((_, embedding)) = time_series.range(..=time).next_back() {
293 return Ok(embedding.clone());
294 }
295 }
296
297 Err(anyhow::anyhow!(
298 "Entity '{}' not found at time {}",
299 entity,
300 time
301 ))
302 }
303
304 pub async fn get_relation_embedding_at_time(
306 &self,
307 relation: &str,
308 time: &DateTime<Utc>,
309 ) -> Result<Vector> {
310 let embeddings = self.relation_embeddings.read().await;
311
312 if let Some(time_series) = embeddings.get(relation) {
313 if let Some((_, embedding)) = time_series.range(..=time).next_back() {
315 return Ok(embedding.clone());
316 }
317 }
318
319 Err(anyhow::anyhow!(
320 "Relation '{}' not found at time {}",
321 relation,
322 time
323 ))
324 }
325
326 pub async fn train_temporal(&mut self, epochs: usize) -> Result<TrainingStats> {
328 info!("Training temporal embeddings for {} epochs", epochs);
329
330 let start_time = std::time::Instant::now();
331 let mut loss_history = Vec::new();
332
333 for epoch in 0..epochs {
334 let loss = self.train_epoch(epoch).await?;
335 loss_history.push(loss);
336
337 if epoch % 10 == 0 {
338 debug!("Epoch {}/{}: loss={:.6}", epoch + 1, epochs, loss);
339 }
340 }
341
342 *self.is_trained.write().await = true;
343
344 let elapsed = start_time.elapsed().as_secs_f64();
345 let final_loss = *loss_history.last().unwrap_or(&0.0);
346
347 info!(
348 "Temporal training completed in {:.2}s, final loss: {:.6}",
349 elapsed, final_loss
350 );
351
352 Ok(TrainingStats {
353 epochs_completed: epochs,
354 final_loss,
355 training_time_seconds: elapsed,
356 convergence_achieved: final_loss < 0.01,
357 loss_history,
358 })
359 }
360
361 async fn train_epoch(&mut self, _epoch: usize) -> Result<f64> {
363 let triples = self.temporal_triples.read().await;
371 let dim = self.config.base_config.dimensions;
372
373 use scirs2_core::random::Random;
374 let mut rng = Random::default();
375
376 for temporal_triple in triples.iter() {
377 let embedding = Vector::new(
378 (0..dim)
379 .map(|_| rng.random_range(-1.0..1.0) as f32)
380 .collect(),
381 );
382
383 let timestamp = match &temporal_triple.scope {
385 TemporalScope::Instant(t) => *t,
386 TemporalScope::Interval { start, .. } => *start,
387 _ => Utc::now(),
388 };
389
390 let entity = temporal_triple.triple.subject.iri.clone();
391 let mut entity_embs = self.entity_embeddings.write().await;
392 entity_embs
393 .entry(entity)
394 .or_insert_with(BTreeMap::new)
395 .insert(timestamp, embedding);
396 }
397
398 Ok(0.1) }
400
401 pub async fn forecast(&self, entity: &str, horizon: usize) -> Result<TemporalForecast> {
403 info!(
404 "Forecasting {} time steps ahead for entity: {}",
405 horizon, entity
406 );
407
408 let embeddings = self.entity_embeddings.read().await;
409
410 if let Some(time_series) = embeddings.get(entity) {
411 let timestamps: Vec<DateTime<Utc>> = time_series.keys().cloned().collect();
412 let last_time = timestamps
413 .last()
414 .ok_or_else(|| anyhow::anyhow!("No temporal data for entity: {}", entity))?;
415
416 let time_step = match self.config.granularity {
418 TemporalGranularity::Second => Duration::seconds(1),
419 TemporalGranularity::Minute => Duration::minutes(1),
420 TemporalGranularity::Hour => Duration::hours(1),
421 TemporalGranularity::Day => Duration::days(1),
422 TemporalGranularity::Week => Duration::weeks(1),
423 TemporalGranularity::Month => Duration::days(30),
424 TemporalGranularity::Year => Duration::days(365),
425 TemporalGranularity::Custom(secs) => Duration::seconds(secs),
426 };
427
428 let mut future_timestamps = Vec::new();
429 let mut predictions = Vec::new();
430 let mut confidence_intervals = Vec::new();
431
432 for i in 1..=horizon {
433 let future_time = *last_time + time_step * i as i32;
434 future_timestamps.push(future_time);
435
436 let last_embedding = time_series
438 .values()
439 .last()
440 .expect("time_series should have at least one embedding");
441 let decay_factor = self.config.decay_rate.powi(i as i32);
442
443 let prediction = last_embedding.mapv(|v| v * decay_factor);
444 let std_dev = 0.1 * (1.0 - decay_factor);
445
446 let lower = last_embedding.mapv(|v| (v * decay_factor) - std_dev);
447 let upper = last_embedding.mapv(|v| (v * decay_factor) + std_dev);
448
449 predictions.push(prediction);
450 confidence_intervals.push((lower, upper));
451 }
452
453 Ok(TemporalForecast {
454 target: entity.to_string(),
455 timestamps: future_timestamps,
456 predictions,
457 confidence_intervals,
458 accuracy: None,
459 })
460 } else {
461 Err(anyhow::anyhow!("Entity '{}' not found", entity))
462 }
463 }
464
465 pub async fn detect_events(&mut self, threshold: f32) -> Result<Vec<TemporalEvent>> {
467 info!("Detecting temporal events with threshold: {}", threshold);
468
469 let entity_embeddings = self.entity_embeddings.read().await;
470 let mut detected_events = Vec::new();
471
472 for (entity, time_series) in entity_embeddings.iter() {
474 let mut prev_embedding: Option<&Vector> = None;
475 let mut prev_time: Option<&DateTime<Utc>> = None;
476
477 for (time, embedding) in time_series.iter() {
478 if let (Some(prev_emb), Some(prev_t)) = (prev_embedding, prev_time) {
479 let diff: Vec<f32> = embedding
481 .values
482 .iter()
483 .zip(prev_emb.values.iter())
484 .map(|(a, b)| (a - b).abs())
485 .collect();
486 let change_magnitude = diff.iter().sum::<f32>() / diff.len() as f32;
487
488 if change_magnitude > threshold {
489 let event = TemporalEvent {
490 event_id: format!("event_{}_{}", entity, time.timestamp()),
491 event_type: "embedding_shift".to_string(),
492 timestamp: *time,
493 entities: vec![entity.clone()],
494 relations: Vec::new(),
495 significance: change_magnitude,
496 description: Some(format!(
497 "Significant embedding change detected for '{}' between {} and {}",
498 entity, prev_t, time
499 )),
500 };
501 detected_events.push(event);
502 }
503 }
504
505 prev_embedding = Some(embedding);
506 prev_time = Some(time);
507 }
508 }
509
510 let mut events = self.events.write().await;
512 events.extend(detected_events.clone());
513
514 info!("Detected {} temporal events", detected_events.len());
515 Ok(detected_events)
516 }
517
518 pub async fn get_events(&self) -> Vec<TemporalEvent> {
520 self.events.read().await.clone()
521 }
522
523 pub async fn query_at_time(&self, time: &DateTime<Utc>) -> Vec<Triple> {
525 let triples = self.temporal_triples.read().await;
526 triples
527 .iter()
528 .filter(|tt| tt.is_valid_at(time))
529 .map(|tt| tt.triple.clone())
530 .collect()
531 }
532
533 pub async fn get_temporal_stats(&self) -> TemporalStats {
535 let entity_embeddings = self.entity_embeddings.read().await;
536 let relation_embeddings = self.relation_embeddings.read().await;
537 let triples = self.temporal_triples.read().await;
538 let events = self.events.read().await;
539
540 let all_times: Vec<DateTime<Utc>> = entity_embeddings
542 .values()
543 .flat_map(|ts| ts.keys().cloned())
544 .collect();
545
546 let (min_time, max_time) = if all_times.is_empty() {
547 (None, None)
548 } else {
549 (
550 all_times.iter().min().cloned(),
551 all_times.iter().max().cloned(),
552 )
553 };
554
555 TemporalStats {
556 num_temporal_triples: triples.len(),
557 num_entities: entity_embeddings.len(),
558 num_relations: relation_embeddings.len(),
559 num_time_points: all_times.len(),
560 num_events: events.len(),
561 time_span_start: min_time,
562 time_span_end: max_time,
563 granularity: self.config.granularity.clone(),
564 }
565 }
566}
567
568#[derive(Debug, Clone, Serialize, Deserialize)]
570pub struct TemporalStats {
571 pub num_temporal_triples: usize,
572 pub num_entities: usize,
573 pub num_relations: usize,
574 pub num_time_points: usize,
575 pub num_events: usize,
576 pub time_span_start: Option<DateTime<Utc>>,
577 pub time_span_end: Option<DateTime<Utc>>,
578 pub granularity: TemporalGranularity,
579}
580
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NamedNode;

    #[tokio::test]
    async fn test_temporal_model_creation() {
        let config = TemporalEmbeddingConfig::default();
        let model = TemporalEmbeddingModel::new(config);
        assert_eq!(model.config.time_dim, 32);
    }

    #[tokio::test]
    async fn test_temporal_triple_validity() {
        let triple = Triple::new(
            NamedNode::new("http://example.org/alice").unwrap(),
            NamedNode::new("http://example.org/worksFor").unwrap(),
            NamedNode::new("http://example.org/company").unwrap(),
        );

        let start = Utc::now();
        let end = start + Duration::days(365);

        let temporal_triple = TemporalTriple::new(triple, TemporalScope::Interval { start, end });

        let now = Utc::now();
        assert!(temporal_triple.is_valid_at(&now));

        let future = now + Duration::days(400);
        assert!(!temporal_triple.is_valid_at(&future));
    }

    #[tokio::test]
    async fn test_periodic_triple_validity() {
        let triple = Triple::new(
            NamedNode::new("http://example.org/alice").unwrap(),
            NamedNode::new("http://example.org/attends").unwrap(),
            NamedNode::new("http://example.org/meeting").unwrap(),
        );

        let start = Utc::now();
        let temporal_triple = TemporalTriple::new(
            triple,
            TemporalScope::Periodic {
                start,
                period: Duration::days(1),
                count: Some(3),
            },
        );

        // Before the first occurrence.
        assert!(!temporal_triple.is_valid_at(&(start - Duration::hours(1))));
        // Within the first and third periods.
        assert!(temporal_triple.is_valid_at(&start));
        assert!(temporal_triple.is_valid_at(&(start + Duration::days(2))));
        // After all three periods have elapsed.
        assert!(!temporal_triple.is_valid_at(&(start + Duration::days(3))));
    }

    #[tokio::test]
    async fn test_temporal_embedding_add_triple() {
        let config = TemporalEmbeddingConfig::default();
        let mut model = TemporalEmbeddingModel::new(config);

        let triple = Triple::new(
            NamedNode::new("http://example.org/alice").unwrap(),
            NamedNode::new("http://example.org/knows").unwrap(),
            NamedNode::new("http://example.org/bob").unwrap(),
        );

        let temporal_triple = TemporalTriple::new(triple, TemporalScope::Instant(Utc::now()));

        model.add_temporal_triple(temporal_triple).await.unwrap();

        let stats = model.get_temporal_stats().await;
        assert_eq!(stats.num_temporal_triples, 1);
    }

    #[tokio::test]
    async fn test_embedding_lookup_at_time() {
        let config = TemporalEmbeddingConfig::default();
        let mut model = TemporalEmbeddingModel::new(config);

        let t = Utc::now();
        let triple = Triple::new(
            NamedNode::new("http://example.org/alice").unwrap(),
            NamedNode::new("http://example.org/knows").unwrap(),
            NamedNode::new("http://example.org/bob").unwrap(),
        );
        model
            .add_temporal_triple(TemporalTriple::new(triple, TemporalScope::Instant(t)))
            .await
            .unwrap();
        model.train_temporal(1).await.unwrap();

        // Lookups at or after the recorded instant see the embedding; earlier ones do not.
        assert!(model
            .get_entity_embedding_at_time("http://example.org/alice", &(t + Duration::hours(1)))
            .await
            .is_ok());
        assert!(model
            .get_entity_embedding_at_time("http://example.org/alice", &(t - Duration::hours(1)))
            .await
            .is_err());
        assert!(model
            .get_entity_embedding_at_time("http://example.org/unknown", &t)
            .await
            .is_err());
    }

    #[tokio::test]
    async fn test_temporal_training() {
        let config = TemporalEmbeddingConfig::default();
        let mut model = TemporalEmbeddingModel::new(config);

        for i in 0..5 {
            let triple = Triple::new(
                NamedNode::new(&format!("http://example.org/entity_{}", i)).unwrap(),
                NamedNode::new("http://example.org/relation").unwrap(),
                NamedNode::new("http://example.org/target").unwrap(),
            );

            let temporal_triple = TemporalTriple::new(
                triple,
                TemporalScope::Instant(Utc::now() + Duration::days(i)),
            );

            model.add_temporal_triple(temporal_triple).await.unwrap();
        }

        let stats = model.train_temporal(10).await.unwrap();
        assert_eq!(stats.epochs_completed, 10);
        assert!(stats.final_loss >= 0.0);
    }

    #[tokio::test]
    async fn test_temporal_forecasting() {
        let config = TemporalEmbeddingConfig::default();
        let mut model = TemporalEmbeddingModel::new(config);

        let triple = Triple::new(
            NamedNode::new("http://example.org/entity").unwrap(),
            NamedNode::new("http://example.org/relation").unwrap(),
            NamedNode::new("http://example.org/target").unwrap(),
        );

        let temporal_triple = TemporalTriple::new(triple, TemporalScope::Instant(Utc::now()));

        model.add_temporal_triple(temporal_triple).await.unwrap();
        model.train_temporal(5).await.unwrap();

        let forecast = model
            .forecast("http://example.org/entity", 10)
            .await
            .unwrap();
        assert_eq!(forecast.predictions.len(), 10);
        assert_eq!(forecast.timestamps.len(), 10);
        assert_eq!(forecast.confidence_intervals.len(), 10);
        // Timestamps must be strictly increasing.
        assert!(forecast.timestamps.windows(2).all(|w| w[0] < w[1]));
    }

    #[tokio::test]
    async fn test_event_detection() {
        let config = TemporalEmbeddingConfig {
            event_threshold: 0.3,
            ..Default::default()
        };
        let mut model = TemporalEmbeddingModel::new(config);

        for i in 0..3 {
            let triple = Triple::new(
                NamedNode::new("http://example.org/entity").unwrap(),
                NamedNode::new("http://example.org/relation").unwrap(),
                NamedNode::new("http://example.org/target").unwrap(),
            );

            let temporal_triple = TemporalTriple::new(
                triple,
                TemporalScope::Instant(Utc::now() + Duration::days(i)),
            );

            model.add_temporal_triple(temporal_triple).await.unwrap();
        }

        model.train_temporal(5).await.unwrap();
        let events = model.detect_events(0.3).await.unwrap();

        // Three time points for one entity yield at most two consecutive comparisons.
        assert!(events.len() <= 2);
        for event in &events {
            assert_eq!(event.event_type, "embedding_shift");
            assert!(event.significance > 0.3);
        }
        // Detected events are also recorded in the model's event log.
        assert_eq!(model.get_events().await.len(), events.len());
    }
}