use serde::{Deserialize, Serialize};
use crate::brain::Atlas;
use crate::error::{Result, RuvNeuralError};
use crate::topology::CognitiveState;
/// A vector embedding tagged with a timestamp and descriptive metadata.
///
/// `NeuralEmbedding::new` rejects empty vectors and sets `dimension` to
/// `vector.len()`. Every field is public and the type derives `Deserialize`,
/// so values built without going through `new` are not checked by it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NeuralEmbedding {
/// Components of the embedding.
pub vector: Vec<f64>,
/// Length of `vector` as recorded by `new`.
pub dimension: usize,
/// Time associated with this embedding.
/// NOTE(review): units and epoch are not visible here — presumably seconds; confirm.
pub timestamp: f64,
/// Descriptive information attached to this embedding.
pub metadata: EmbeddingMetadata,
}
impl NeuralEmbedding {
    /// Builds an embedding from `vector`, recording its length as `dimension`.
    ///
    /// # Errors
    ///
    /// Returns [`RuvNeuralError::Embedding`] when `vector` is empty.
    pub fn new(vector: Vec<f64>, timestamp: f64, metadata: EmbeddingMetadata) -> Result<Self> {
        if vector.is_empty() {
            return Err(RuvNeuralError::Embedding(
                "Embedding vector must not be empty".into(),
            ));
        }
        Ok(Self {
            // Read the length before `vector` is moved into the struct.
            dimension: vector.len(),
            vector,
            timestamp,
            metadata,
        })
    }

    /// Euclidean (L2) norm of the embedding vector.
    pub fn norm(&self) -> f64 {
        self.vector.iter().map(|x| x * x).sum::<f64>().sqrt()
    }

    /// Cosine similarity between `self` and `other`.
    ///
    /// Returns `0.0` when either vector has zero norm. Otherwise the value is
    /// clamped to `[-1.0, 1.0]`, because floating-point rounding can push the
    /// raw quotient marginally outside that range (which would make a later
    /// `acos` return NaN).
    ///
    /// # Errors
    ///
    /// Returns [`RuvNeuralError::DimensionMismatch`] when the two embeddings
    /// differ in `dimension` or in actual vector length.
    pub fn cosine_similarity(&self, other: &NeuralEmbedding) -> Result<f64> {
        self.check_same_dimension(other)?;

        let dot: f64 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| a * b)
            .sum();
        let norm_a = self.norm();
        let norm_b = other.norm();
        if norm_a == 0.0 || norm_b == 0.0 {
            return Ok(0.0);
        }
        Ok((dot / (norm_a * norm_b)).clamp(-1.0, 1.0))
    }

    /// Euclidean (L2) distance between `self` and `other`.
    ///
    /// # Errors
    ///
    /// Returns [`RuvNeuralError::DimensionMismatch`] when the two embeddings
    /// differ in `dimension` or in actual vector length.
    pub fn euclidean_distance(&self, other: &NeuralEmbedding) -> Result<f64> {
        self.check_same_dimension(other)?;

        let sum_sq: f64 = self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| {
                let delta = a - b;
                delta * delta
            })
            .sum();
        Ok(sum_sq.sqrt())
    }

    /// Verifies that `self` and `other` can be compared element-wise.
    ///
    /// The declared `dimension` fields are compared first. The real vector
    /// lengths are compared as well: `dimension` is a public, deserializable
    /// field that can drift from `vector.len()`, and `zip` would otherwise
    /// silently ignore the tail of the longer vector.
    fn check_same_dimension(&self, other: &NeuralEmbedding) -> Result<()> {
        if self.dimension != other.dimension {
            return Err(RuvNeuralError::DimensionMismatch {
                expected: self.dimension,
                got: other.dimension,
            });
        }
        if self.vector.len() != other.vector.len() {
            return Err(RuvNeuralError::DimensionMismatch {
                expected: self.vector.len(),
                got: other.vector.len(),
            });
        }
        Ok(())
    }
}
/// Descriptive information carried alongside a `NeuralEmbedding`.
///
/// NOTE(review): field meanings below are inferred from names and types only;
/// nothing in this file reads them — confirm against the producers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingMetadata {
/// Optional subject identifier.
pub subject_id: Option<String>,
/// Optional session identifier.
pub session_id: Option<String>,
/// Optional `CognitiveState` (from `crate::topology`) attached to the embedding.
pub cognitive_state: Option<CognitiveState>,
/// The `Atlas` (from `crate::brain`) the embedding was derived from — presumably
/// the parcellation of the source data; confirm.
pub source_atlas: Atlas,
/// Name of the method that produced the embedding (free-form string).
pub embedding_method: String,
}
/// A sequence of embeddings together with a separate list of timestamps.
///
/// `len` and `is_empty` look only at `embeddings`, while `duration_s` looks
/// only at `timestamps`. Nothing visible in this file enforces that the two
/// vectors have the same length or that `timestamps` is sorted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingTrajectory {
/// The embeddings making up the trajectory.
pub embeddings: Vec<NeuralEmbedding>,
/// Timestamps for the trajectory — presumably one per embedding, in the same
/// order and in seconds; confirm against the code that builds trajectories.
pub timestamps: Vec<f64>,
}
impl EmbeddingTrajectory {
    /// Number of embeddings stored in the trajectory.
    pub fn len(&self) -> usize {
        self.embeddings.len()
    }

    /// `true` when the trajectory holds no embeddings.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Difference between the last and the first entry of `timestamps`.
    ///
    /// Yields `0.0` when fewer than two timestamps are present. The entries
    /// are used as stored; no ordering is assumed or enforced here.
    pub fn duration_s(&self) -> f64 {
        match self.timestamps.as_slice() {
            // This pattern only matches slices with at least two elements.
            [first, .., last] => last - first,
            _ => 0.0,
        }
    }
}