use crate::error::{GraphError, Result};
use std::f64::consts::PI;
#[derive(Debug, Clone)]
pub struct TimeEncode {
pub omega: Vec<f64>,
pub time_dim: usize,
}
impl TimeEncode {
pub fn new(time_dim: usize) -> Result<Self> {
if time_dim == 0 {
return Err(GraphError::InvalidParameter {
param: "time_dim".to_string(),
value: "0".to_string(),
expected: "positive even integer".to_string(),
context: "TimeEncode::new".to_string(),
});
}
if time_dim % 2 != 0 {
return Err(GraphError::InvalidParameter {
param: "time_dim".to_string(),
value: format!("{}", time_dim),
expected: "even integer".to_string(),
context: "TimeEncode::new".to_string(),
});
}
let half = time_dim / 2;
let omega: Vec<f64> = (0..half)
.map(|i| {
let exponent = (2 * i) as f64 / time_dim as f64;
1.0 / 10000_f64.powf(exponent)
})
.collect();
Ok(TimeEncode { omega, time_dim })
}
pub fn encode(&self, t: f64) -> Vec<f64> {
let mut out = Vec::with_capacity(self.time_dim);
for &w in &self.omega {
out.push((w * t).cos());
out.push((w * t).sin());
}
out
}
pub fn encode_delta(&self, t_query: f64, t_edge: f64) -> Vec<f64> {
self.encode(t_query - t_edge)
}
pub fn update_omega(&mut self, grad: &[f64], lr: f64) -> Result<()> {
if grad.len() != self.omega.len() {
return Err(GraphError::InvalidParameter {
param: "grad".to_string(),
value: format!("len={}", grad.len()),
expected: format!("len={}", self.omega.len()),
context: "TimeEncode::update_omega".to_string(),
});
}
for (w, g) in self.omega.iter_mut().zip(grad.iter()) {
*w -= lr * g;
}
Ok(())
}
}
#[derive(Debug, Clone)]
pub struct PositionalTimeEncoding {
encoder: TimeEncode,
pub t_ref: f64,
}
impl PositionalTimeEncoding {
pub fn new(time_dim: usize, t_ref: f64) -> Result<Self> {
let encoder = TimeEncode::new(time_dim)?;
Ok(PositionalTimeEncoding { encoder, t_ref })
}
pub fn encode(&self, t: f64) -> Vec<f64> {
let mut out = self.encoder.encode(t);
let rel = self.encoder.encode(t - self.t_ref);
out.extend(rel);
out
}
pub fn output_dim(&self) -> usize {
2 * self.encoder.time_dim
}
}
pub(crate) fn scaled_dot_product(q: &[f64], keys: &[Vec<f64>]) -> Vec<f64> {
let d_k = q.len().max(1);
let scale = 1.0 / (d_k as f64).sqrt();
keys.iter()
.map(|k| k.iter().zip(q.iter()).map(|(ki, qi)| ki * qi).sum::<f64>() * scale)
.collect()
}
pub(crate) fn softmax(logits: &[f64]) -> Vec<f64> {
if logits.is_empty() {
return Vec::new();
}
let max_val = logits.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let exps: Vec<f64> = logits.iter().map(|x| (x - max_val).exp()).collect();
let sum = exps.iter().sum::<f64>().max(1e-12);
exps.iter().map(|e| e / sum).collect()
}
pub(crate) fn matvec(w: &[Vec<f64>], x: &[f64]) -> Vec<f64> {
w.iter()
.map(|row| row.iter().zip(x.iter()).map(|(wi, xi)| wi * xi).sum())
.collect()
}
pub(crate) fn sigmoid(x: f64) -> f64 {
1.0 / (1.0 + (-x).exp())
}
pub(crate) fn sigmoid_vec(v: &[f64]) -> Vec<f64> {
v.iter().map(|&x| sigmoid(x)).collect()
}
pub(crate) fn tanh_vec(v: &[f64]) -> Vec<f64> {
v.iter().map(|&x| x.tanh()).collect()
}
pub(crate) fn relu(x: f64) -> f64 {
x.max(0.0)
}
pub(crate) fn relu_vec(v: &[f64]) -> Vec<f64> {
v.iter().map(|&x| relu(x)).collect()
}
pub(crate) fn xavier_init(rows: usize, cols: usize, seed: u64) -> Vec<Vec<f64>> {
let mut state = seed.wrapping_add(1);
let limit = (6.0 / (rows + cols).max(1) as f64).sqrt();
let mut w = vec![vec![0.0f64; cols]; rows];
for row in w.iter_mut() {
for v in row.iter_mut() {
state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1_442_695_040_888_963_407);
let frac = (state >> 11) as f64 / (1u64 << 53) as f64; *v = frac * 2.0 * limit - limit; }
}
w
}
pub(crate) fn concat(a: &[f64], b: &[f64]) -> Vec<f64> {
let mut out = Vec::with_capacity(a.len() + b.len());
out.extend_from_slice(a);
out.extend_from_slice(b);
out
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_time_encode_dimension() {
let enc = TimeEncode::new(8).expect("should create encoder");
let v = enc.encode(1.5);
assert_eq!(v.len(), 8, "output must have time_dim elements");
}
#[test]
fn test_time_encode_at_zero() {
let enc = TimeEncode::new(8).expect("encoder");
let v = enc.encode(0.0);
for i in 0..4 {
assert!((v[2 * i] - 1.0).abs() < 1e-12, "cos(0) must be 1");
assert!(v[2 * i + 1].abs() < 1e-12, "sin(0) must be 0");
}
}
#[test]
fn test_time_encode_different_times() {
let enc = TimeEncode::new(8).expect("encoder");
let v1 = enc.encode(1.0);
let v2 = enc.encode(2.0);
let diff: f64 = v1.iter().zip(v2.iter()).map(|(a, b)| (a - b).abs()).sum();
assert!(diff > 1e-6, "encode(1.0) must differ from encode(2.0)");
}
#[test]
fn test_time_encode_odd_dim_returns_error() {
let result = TimeEncode::new(7);
assert!(result.is_err(), "odd time_dim must return error");
}
#[test]
fn test_positional_time_encoding_dim() {
let enc = PositionalTimeEncoding::new(8, 0.0).expect("positional encoder");
let v = enc.encode(3.0);
assert_eq!(v.len(), 16, "positional encoding must be 2*time_dim");
assert_eq!(enc.output_dim(), 16);
}
#[test]
fn test_softmax_sums_to_one() {
let logits = vec![1.0, 2.0, 3.0, 0.5];
let probs = softmax(&logits);
let sum: f64 = probs.iter().sum();
assert!((sum - 1.0).abs() < 1e-10, "softmax must sum to 1");
}
#[test]
fn test_softmax_empty() {
let probs = softmax(&[]);
assert!(probs.is_empty());
}
#[test]
fn test_xavier_init_shape() {
let w = xavier_init(4, 8, 42);
assert_eq!(w.len(), 4);
assert_eq!(w[0].len(), 8);
}
#[test]
fn test_sigmoid_range() {
let vals: Vec<f64> = vec![-100.0, -1.0, 0.0, 1.0, 100.0];
for v in vals {
let s = sigmoid(v);
assert!(s >= 0.0 && s <= 1.0, "sigmoid must be in [0,1]");
}
}
}