use anyhow::Result;
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use std::sync::Arc;
use tokio::sync::RwLock;
use tracing::{debug, info};
use crate::{ModelConfig, TrainingStats, Triple, Vector};
use uuid::Uuid;
type TemporalEntityEmbeddings = Arc<RwLock<HashMap<String, BTreeMap<DateTime<Utc>, Vector>>>>;
type TemporalRelationEmbeddings = Arc<RwLock<HashMap<String, BTreeMap<DateTime<Utc>, Vector>>>>;
#[derive(Debug, Clone, Default)]
pub struct TimeSeriesAnalyzer;
#[derive(Debug, Clone)]
pub enum ForecastMethod {
ExponentialSmoothing,
Arima,
Prophet,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum TemporalGranularity {
Second,
Minute,
Hour,
Day,
Week,
Month,
Year,
Custom(i64), }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TemporalScope {
Instant(DateTime<Utc>),
Interval {
start: DateTime<Utc>,
end: DateTime<Utc>,
},
Periodic {
start: DateTime<Utc>,
period: Duration,
count: Option<usize>,
},
Unbounded,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalTriple {
pub triple: Triple,
pub scope: TemporalScope,
pub confidence: f32,
pub source: Option<String>,
}
impl TemporalTriple {
pub fn new(triple: Triple, scope: TemporalScope) -> Self {
Self {
triple,
scope,
confidence: 1.0,
source: None,
}
}
pub fn is_valid_at(&self, time: &DateTime<Utc>) -> bool {
match &self.scope {
TemporalScope::Instant(instant) => instant == time,
TemporalScope::Interval { start, end } => time >= start && time <= end,
TemporalScope::Periodic {
start,
period,
count,
} => {
let elapsed = time.signed_duration_since(*start);
if elapsed < Duration::zero() {
return false;
}
if let Some(max_count) = count {
let num_periods = elapsed.num_seconds() / period.num_seconds();
num_periods < *max_count as i64
} else {
true
}
}
TemporalScope::Unbounded => true,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalEmbeddingConfig {
pub base_config: ModelConfig,
pub granularity: TemporalGranularity,
pub time_dim: usize,
pub enable_decay: bool,
pub decay_rate: f32,
pub enable_smoothing: bool,
pub smoothing_window: usize,
pub enable_forecasting: bool,
pub forecast_horizon: usize,
pub enable_event_detection: bool,
pub event_threshold: f32,
}
impl Default for TemporalEmbeddingConfig {
fn default() -> Self {
Self {
base_config: ModelConfig::default(),
granularity: TemporalGranularity::Day,
time_dim: 32,
enable_decay: true,
decay_rate: 0.9,
enable_smoothing: true,
smoothing_window: 7,
enable_forecasting: true,
forecast_horizon: 30,
enable_event_detection: false,
event_threshold: 0.7,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalEvent {
pub event_id: String,
pub event_type: String,
pub timestamp: DateTime<Utc>,
pub entities: Vec<String>,
pub relations: Vec<String>,
pub significance: f32,
pub description: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalForecast {
pub target: String,
pub timestamps: Vec<DateTime<Utc>>,
pub predictions: Vec<Vector>,
pub confidence_intervals: Vec<(Vector, Vector)>,
pub accuracy: Option<f32>,
}
pub struct TemporalEmbeddingModel {
config: TemporalEmbeddingConfig,
model_id: Uuid,
entity_embeddings: TemporalEntityEmbeddings,
relation_embeddings: TemporalRelationEmbeddings,
time_embeddings: Arc<RwLock<BTreeMap<DateTime<Utc>, Vector>>>,
temporal_triples: Arc<RwLock<Vec<TemporalTriple>>>,
events: Arc<RwLock<Vec<TemporalEvent>>>,
time_series_analyzer: Option<TimeSeriesAnalyzer>,
is_trained: Arc<RwLock<bool>>,
}
impl TemporalEmbeddingModel {
pub fn new(config: TemporalEmbeddingConfig) -> Self {
info!(
"Creating temporal embedding model with time_dim={}",
config.time_dim
);
Self {
model_id: Uuid::new_v4(),
time_series_analyzer: Some(TimeSeriesAnalyzer),
config,
entity_embeddings: Arc::new(RwLock::new(HashMap::new())),
relation_embeddings: Arc::new(RwLock::new(HashMap::new())),
time_embeddings: Arc::new(RwLock::new(BTreeMap::new())),
temporal_triples: Arc::new(RwLock::new(Vec::new())),
events: Arc::new(RwLock::new(Vec::new())),
is_trained: Arc::new(RwLock::new(false)),
}
}
pub async fn add_temporal_triple(&mut self, temporal_triple: TemporalTriple) -> Result<()> {
let mut triples = self.temporal_triples.write().await;
triples.push(temporal_triple);
Ok(())
}
pub async fn get_entity_embedding_at_time(
&self,
entity: &str,
time: &DateTime<Utc>,
) -> Result<Vector> {
let embeddings = self.entity_embeddings.read().await;
if let Some(time_series) = embeddings.get(entity) {
if let Some((_, embedding)) = time_series.range(..=time).next_back() {
return Ok(embedding.clone());
}
}
Err(anyhow::anyhow!(
"Entity '{}' not found at time {}",
entity,
time
))
}
pub async fn get_relation_embedding_at_time(
&self,
relation: &str,
time: &DateTime<Utc>,
) -> Result<Vector> {
let embeddings = self.relation_embeddings.read().await;
if let Some(time_series) = embeddings.get(relation) {
if let Some((_, embedding)) = time_series.range(..=time).next_back() {
return Ok(embedding.clone());
}
}
Err(anyhow::anyhow!(
"Relation '{}' not found at time {}",
relation,
time
))
}
pub async fn train_temporal(&mut self, epochs: usize) -> Result<TrainingStats> {
info!("Training temporal embeddings for {} epochs", epochs);
let start_time = std::time::Instant::now();
let mut loss_history = Vec::new();
for epoch in 0..epochs {
let loss = self.train_epoch(epoch).await?;
loss_history.push(loss);
if epoch % 10 == 0 {
debug!("Epoch {}/{}: loss={:.6}", epoch + 1, epochs, loss);
}
}
*self.is_trained.write().await = true;
let elapsed = start_time.elapsed().as_secs_f64();
let final_loss = *loss_history.last().unwrap_or(&0.0);
info!(
"Temporal training completed in {:.2}s, final loss: {:.6}",
elapsed, final_loss
);
Ok(TrainingStats {
epochs_completed: epochs,
final_loss,
training_time_seconds: elapsed,
convergence_achieved: final_loss < 0.01,
loss_history,
})
}
async fn train_epoch(&mut self, _epoch: usize) -> Result<f64> {
let triples = self.temporal_triples.read().await;
let dim = self.config.base_config.dimensions;
use scirs2_core::random::Random;
let mut rng = Random::default();
for temporal_triple in triples.iter() {
let embedding = Vector::new(
(0..dim)
.map(|_| rng.random_range(-1.0..1.0) as f32)
.collect(),
);
let timestamp = match &temporal_triple.scope {
TemporalScope::Instant(t) => *t,
TemporalScope::Interval { start, .. } => *start,
_ => Utc::now(),
};
let entity = temporal_triple.triple.subject.iri.clone();
let mut entity_embs = self.entity_embeddings.write().await;
entity_embs
.entry(entity)
.or_insert_with(BTreeMap::new)
.insert(timestamp, embedding);
}
Ok(0.1) }
pub async fn forecast(&self, entity: &str, horizon: usize) -> Result<TemporalForecast> {
info!(
"Forecasting {} time steps ahead for entity: {}",
horizon, entity
);
let embeddings = self.entity_embeddings.read().await;
if let Some(time_series) = embeddings.get(entity) {
let timestamps: Vec<DateTime<Utc>> = time_series.keys().cloned().collect();
let last_time = timestamps
.last()
.ok_or_else(|| anyhow::anyhow!("No temporal data for entity: {}", entity))?;
let time_step = match self.config.granularity {
TemporalGranularity::Second => Duration::seconds(1),
TemporalGranularity::Minute => Duration::minutes(1),
TemporalGranularity::Hour => Duration::hours(1),
TemporalGranularity::Day => Duration::days(1),
TemporalGranularity::Week => Duration::weeks(1),
TemporalGranularity::Month => Duration::days(30),
TemporalGranularity::Year => Duration::days(365),
TemporalGranularity::Custom(secs) => Duration::seconds(secs),
};
let mut future_timestamps = Vec::new();
let mut predictions = Vec::new();
let mut confidence_intervals = Vec::new();
for i in 1..=horizon {
let future_time = *last_time + time_step * i as i32;
future_timestamps.push(future_time);
let last_embedding = time_series
.values()
.last()
.expect("time_series should have at least one embedding");
let decay_factor = self.config.decay_rate.powi(i as i32);
let prediction = last_embedding.mapv(|v| v * decay_factor);
let std_dev = 0.1 * (1.0 - decay_factor);
let lower = last_embedding.mapv(|v| (v * decay_factor) - std_dev);
let upper = last_embedding.mapv(|v| (v * decay_factor) + std_dev);
predictions.push(prediction);
confidence_intervals.push((lower, upper));
}
Ok(TemporalForecast {
target: entity.to_string(),
timestamps: future_timestamps,
predictions,
confidence_intervals,
accuracy: None,
})
} else {
Err(anyhow::anyhow!("Entity '{}' not found", entity))
}
}
pub async fn detect_events(&mut self, threshold: f32) -> Result<Vec<TemporalEvent>> {
info!("Detecting temporal events with threshold: {}", threshold);
let entity_embeddings = self.entity_embeddings.read().await;
let mut detected_events = Vec::new();
for (entity, time_series) in entity_embeddings.iter() {
let mut prev_embedding: Option<&Vector> = None;
let mut prev_time: Option<&DateTime<Utc>> = None;
for (time, embedding) in time_series.iter() {
if let (Some(prev_emb), Some(prev_t)) = (prev_embedding, prev_time) {
let diff: Vec<f32> = embedding
.values
.iter()
.zip(prev_emb.values.iter())
.map(|(a, b)| (a - b).abs())
.collect();
let change_magnitude = diff.iter().sum::<f32>() / diff.len() as f32;
if change_magnitude > threshold {
let event = TemporalEvent {
event_id: format!("event_{}_{}", entity, time.timestamp()),
event_type: "embedding_shift".to_string(),
timestamp: *time,
entities: vec![entity.clone()],
relations: Vec::new(),
significance: change_magnitude,
description: Some(format!(
"Significant embedding change detected for '{}' between {} and {}",
entity, prev_t, time
)),
};
detected_events.push(event);
}
}
prev_embedding = Some(embedding);
prev_time = Some(time);
}
}
let mut events = self.events.write().await;
events.extend(detected_events.clone());
info!("Detected {} temporal events", detected_events.len());
Ok(detected_events)
}
pub async fn get_events(&self) -> Vec<TemporalEvent> {
self.events.read().await.clone()
}
pub async fn query_at_time(&self, time: &DateTime<Utc>) -> Vec<Triple> {
let triples = self.temporal_triples.read().await;
triples
.iter()
.filter(|tt| tt.is_valid_at(time))
.map(|tt| tt.triple.clone())
.collect()
}
pub async fn get_temporal_stats(&self) -> TemporalStats {
let entity_embeddings = self.entity_embeddings.read().await;
let relation_embeddings = self.relation_embeddings.read().await;
let triples = self.temporal_triples.read().await;
let events = self.events.read().await;
let all_times: Vec<DateTime<Utc>> = entity_embeddings
.values()
.flat_map(|ts| ts.keys().cloned())
.collect();
let (min_time, max_time) = if all_times.is_empty() {
(None, None)
} else {
(
all_times.iter().min().cloned(),
all_times.iter().max().cloned(),
)
};
TemporalStats {
num_temporal_triples: triples.len(),
num_entities: entity_embeddings.len(),
num_relations: relation_embeddings.len(),
num_time_points: all_times.len(),
num_events: events.len(),
time_span_start: min_time,
time_span_end: max_time,
granularity: self.config.granularity.clone(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TemporalStats {
pub num_temporal_triples: usize,
pub num_entities: usize,
pub num_relations: usize,
pub num_time_points: usize,
pub num_events: usize,
pub time_span_start: Option<DateTime<Utc>>,
pub time_span_end: Option<DateTime<Utc>>,
pub granularity: TemporalGranularity,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::NamedNode;
#[tokio::test]
async fn test_temporal_model_creation() {
let config = TemporalEmbeddingConfig::default();
let model = TemporalEmbeddingModel::new(config);
assert_eq!(model.config.time_dim, 32);
}
#[tokio::test]
async fn test_temporal_triple_validity() {
let triple = Triple::new(
NamedNode::new("http://example.org/alice").expect("should succeed"),
NamedNode::new("http://example.org/worksFor").expect("should succeed"),
NamedNode::new("http://example.org/company").expect("should succeed"),
);
let start = Utc::now();
let end = start + Duration::days(365);
let temporal_triple = TemporalTriple::new(triple, TemporalScope::Interval { start, end });
let now = Utc::now();
assert!(temporal_triple.is_valid_at(&now));
let future = now + Duration::days(400);
assert!(!temporal_triple.is_valid_at(&future));
}
#[tokio::test]
async fn test_temporal_embedding_add_triple() {
let config = TemporalEmbeddingConfig::default();
let mut model = TemporalEmbeddingModel::new(config);
let triple = Triple::new(
NamedNode::new("http://example.org/alice").expect("should succeed"),
NamedNode::new("http://example.org/knows").expect("should succeed"),
NamedNode::new("http://example.org/bob").expect("should succeed"),
);
let temporal_triple = TemporalTriple::new(triple, TemporalScope::Instant(Utc::now()));
model
.add_temporal_triple(temporal_triple)
.await
.expect("should succeed");
let stats = model.get_temporal_stats().await;
assert_eq!(stats.num_temporal_triples, 1);
}
#[tokio::test]
async fn test_temporal_training() {
let config = TemporalEmbeddingConfig::default();
let mut model = TemporalEmbeddingModel::new(config);
for i in 0..5 {
let triple = Triple::new(
NamedNode::new(&format!("http://example.org/entity_{}", i))
.expect("should succeed"),
NamedNode::new("http://example.org/relation").expect("should succeed"),
NamedNode::new("http://example.org/target").expect("should succeed"),
);
let temporal_triple = TemporalTriple::new(
triple,
TemporalScope::Instant(Utc::now() + Duration::days(i)),
);
model
.add_temporal_triple(temporal_triple)
.await
.expect("should succeed");
}
let stats = model.train_temporal(10).await.expect("should succeed");
assert_eq!(stats.epochs_completed, 10);
assert!(stats.final_loss >= 0.0);
}
#[tokio::test]
async fn test_temporal_forecasting() {
let config = TemporalEmbeddingConfig::default();
let mut model = TemporalEmbeddingModel::new(config);
let triple = Triple::new(
NamedNode::new("http://example.org/entity").expect("should succeed"),
NamedNode::new("http://example.org/relation").expect("should succeed"),
NamedNode::new("http://example.org/target").expect("should succeed"),
);
let temporal_triple = TemporalTriple::new(triple, TemporalScope::Instant(Utc::now()));
model
.add_temporal_triple(temporal_triple)
.await
.expect("should succeed");
model.train_temporal(5).await.expect("should succeed");
let forecast = model
.forecast("http://example.org/entity", 10)
.await
.expect("should succeed");
assert_eq!(forecast.predictions.len(), 10);
assert_eq!(forecast.timestamps.len(), 10);
}
#[tokio::test]
async fn test_event_detection() {
let config = TemporalEmbeddingConfig {
event_threshold: 0.3,
..Default::default()
};
let mut model = TemporalEmbeddingModel::new(config);
for i in 0..3 {
let triple = Triple::new(
NamedNode::new("http://example.org/entity").expect("should succeed"),
NamedNode::new("http://example.org/relation").expect("should succeed"),
NamedNode::new("http://example.org/target").expect("should succeed"),
);
let temporal_triple = TemporalTriple::new(
triple,
TemporalScope::Instant(Utc::now() + Duration::days(i)),
);
model
.add_temporal_triple(temporal_triple)
.await
.expect("should succeed");
}
model.train_temporal(5).await.expect("should succeed");
let _events = model.detect_events(0.3).await.expect("should succeed");
}
}