use serde::{Deserialize, Serialize};
use std::fmt;
use uuid::Uuid;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct EmbeddingId(Uuid);
impl EmbeddingId {
#[inline]
pub fn new() -> Self {
Self(Uuid::new_v4())
}
#[inline]
pub const fn from_uuid(uuid: Uuid) -> Self {
Self(uuid)
}
pub fn parse(s: &str) -> Result<Self, uuid::Error> {
Ok(Self(Uuid::parse_str(s)?))
}
#[inline]
pub const fn as_uuid(&self) -> &Uuid {
&self.0
}
#[inline]
pub fn as_bytes(&self) -> &[u8; 16] {
self.0.as_bytes()
}
#[inline]
pub fn from_bytes(bytes: [u8; 16]) -> Self {
Self(Uuid::from_bytes(bytes))
}
}
impl Default for EmbeddingId {
fn default() -> Self {
Self::new()
}
}
impl fmt::Display for EmbeddingId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
impl From<Uuid> for EmbeddingId {
fn from(uuid: Uuid) -> Self {
Self(uuid)
}
}
impl From<EmbeddingId> for Uuid {
fn from(id: EmbeddingId) -> Self {
id.0
}
}
/// Wall-clock instant stored as whole milliseconds since the Unix epoch (UTC).
///
/// Ordering and equality compare the raw millisecond count.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
pub struct Timestamp(i64);

impl Timestamp {
    /// Returns the current UTC wall-clock time, truncated to millisecond precision.
    #[must_use]
    pub fn now() -> Self {
        Self(chrono::Utc::now().timestamp_millis())
    }

    /// Wraps a raw millisecond count (milliseconds since the Unix epoch).
    #[inline]
    #[must_use]
    pub const fn from_millis(millis: i64) -> Self {
        Self(millis)
    }

    /// Returns the raw millisecond count.
    #[inline]
    #[must_use]
    pub const fn as_millis(&self) -> i64 {
        self.0
    }

    /// Converts to a `chrono` UTC datetime.
    ///
    /// Millisecond values that `chrono` cannot represent fall back to the Unix
    /// epoch instead of panicking, so out-of-range timestamps silently display as
    /// `1970-01-01 00:00:00.000 UTC`.
    #[must_use]
    pub fn to_datetime(&self) -> chrono::DateTime<chrono::Utc> {
        // The fallback is a constant, so the eager `unwrap_or` form is appropriate.
        chrono::DateTime::from_timestamp_millis(self.0).unwrap_or(chrono::DateTime::UNIX_EPOCH)
    }
}

/// The default timestamp is the current time (via [`Timestamp::now`]), not the epoch.
impl Default for Timestamp {
    fn default() -> Self {
        Self::now()
    }
}

impl fmt::Display for Timestamp {
    /// Formats as `YYYY-MM-DD HH:MM:SS.mmm UTC`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.to_datetime().format("%Y-%m-%d %H:%M:%S%.3f UTC"))
    }
}
/// Construction and search parameters for an HNSW vector index.
///
/// Build one with a preset ([`HnswConfig::for_dimension`] and friends), adjust it
/// with the `with_*` builders, then call [`HnswConfig::validate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HnswConfig {
    /// HNSW `M` parameter; [`HnswConfig::validate`] requires `m >= 2`.
    pub m: usize,
    /// `ef` used while building the index; [`HnswConfig::validate`] requires it to be `>= m`.
    pub ef_construction: usize,
    /// `ef` used at query time. Not checked by [`HnswConfig::validate`].
    pub ef_search: usize,
    /// Maximum number of elements the index is configured for; must be non-zero.
    pub max_elements: usize,
    /// Dimensionality of the stored vectors; must be non-zero.
    pub dimensions: usize,
    /// Whether vectors should be normalized (presumably to unit length — confirm
    /// against the index implementation that consumes this flag).
    pub normalize: bool,
    /// Distance function used to compare vectors.
    pub distance_metric: DistanceMetric,
}

impl HnswConfig {
    /// Returns a preset for `dim`-dimensional vectors.
    ///
    /// Uses `m = 32` when `dim >= 1024` and `m = 16` otherwise, together with
    /// `ef_construction = 200`, `ef_search = 128`, `max_elements = 1_000_000`,
    /// normalization enabled and [`DistanceMetric::Cosine`].
    #[must_use]
    pub fn for_dimension(dim: usize) -> Self {
        Self {
            m: if dim >= 1024 { 32 } else { 16 },
            ef_construction: 200,
            ef_search: 128,
            max_elements: 1_000_000,
            dimensions: dim,
            normalize: true,
            distance_metric: DistanceMetric::Cosine,
        }
    }

    /// Preset for 1536-dimensional vectors (`for_dimension(1536)`).
    #[must_use]
    pub fn for_openai_embeddings() -> Self {
        Self::for_dimension(1536)
    }

    /// Preset for 384-dimensional vectors (`for_dimension(384)`).
    #[must_use]
    pub fn for_sentence_transformers() -> Self {
        Self::for_dimension(384)
    }

    /// Returns the config with `m` replaced.
    #[must_use]
    pub fn with_m(mut self, m: usize) -> Self {
        self.m = m;
        self
    }

    /// Returns the config with `ef_construction` replaced.
    #[must_use]
    pub fn with_ef_construction(mut self, ef: usize) -> Self {
        self.ef_construction = ef;
        self
    }

    /// Returns the config with `ef_search` replaced.
    #[must_use]
    pub fn with_ef_search(mut self, ef: usize) -> Self {
        self.ef_search = ef;
        self
    }

    /// Returns the config with `max_elements` replaced.
    #[must_use]
    pub fn with_max_elements(mut self, max: usize) -> Self {
        self.max_elements = max;
        self
    }

    /// Returns the config with `distance_metric` replaced.
    #[must_use]
    pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
        self.distance_metric = metric;
        self
    }

    /// Returns the config with `normalize` replaced.
    #[must_use]
    pub fn with_normalize(mut self, normalize: bool) -> Self {
        self.normalize = normalize;
        self
    }

    /// Checks that the parameters are internally consistent.
    ///
    /// Checks run in the order listed below and the first failure is returned.
    /// `ef_search`, `normalize` and `distance_metric` are not checked.
    ///
    /// # Errors
    /// - [`ConfigValidationError::InvalidM`] if `m < 2`
    /// - [`ConfigValidationError::EfTooSmall`] if `ef_construction < m`
    /// - [`ConfigValidationError::ZeroDimensions`] if `dimensions == 0`
    /// - [`ConfigValidationError::ZeroMaxElements`] if `max_elements == 0`
    pub fn validate(&self) -> Result<(), ConfigValidationError> {
        if self.m < 2 {
            return Err(ConfigValidationError::InvalidM(self.m));
        }
        if self.ef_construction < self.m {
            return Err(ConfigValidationError::EfTooSmall {
                ef: self.ef_construction,
                m: self.m,
            });
        }
        if self.dimensions == 0 {
            return Err(ConfigValidationError::ZeroDimensions);
        }
        if self.max_elements == 0 {
            return Err(ConfigValidationError::ZeroMaxElements);
        }
        Ok(())
    }
}

/// Defaults to [`HnswConfig::for_openai_embeddings`], i.e. 1536 dimensions.
impl Default for HnswConfig {
    fn default() -> Self {
        Self::for_openai_embeddings()
    }
}
/// Distance function used to compare vectors.
///
/// Serialized in `snake_case`: `"cosine"`, `"euclidean"`, `"dot_product"`,
/// `"poincare"`. The default is [`DistanceMetric::Cosine`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum DistanceMetric {
    /// Cosine distance; the default metric.
    #[default]
    Cosine,
    /// Euclidean distance.
    Euclidean,
    /// Dot-product based distance.
    DotProduct,
    /// Poincaré distance (semantics defined by the index implementation — confirm there).
    Poincare,
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum ConfigValidationError {
#[error("M parameter must be >= 2, got {0}")]
InvalidM(usize),
#[error("ef_construction ({ef}) must be >= M ({m})")]
EfTooSmall { ef: usize, m: usize },
#[error("dimensions cannot be zero")]
ZeroDimensions,
#[error("max_elements cannot be zero")]
ZeroMaxElements,
}
/// Descriptive record for a named vector index and the HNSW settings it uses.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorIndex {
    /// Caller-supplied identifier.
    pub id: String,
    /// Human-readable name.
    pub name: String,
    /// Copied from `config.dimensions` at construction; not kept in sync afterwards.
    pub dimensions: usize,
    /// Element count as last reported via [`VectorIndex::update_size`].
    pub size: usize,
    /// HNSW parameters for this index.
    pub config: HnswConfig,
    /// Time the record was created.
    pub created_at: Timestamp,
    /// Time of the last [`VectorIndex::update_size`] call (initially `created_at`).
    pub updated_at: Timestamp,
    /// Optional free-form description.
    pub description: Option<String>,
}

impl VectorIndex {
    /// Creates an empty index record; both timestamps are set to the same "now".
    pub fn new(id: impl Into<String>, name: impl Into<String>, config: HnswConfig) -> Self {
        let stamp = Timestamp::now();
        let dimensions = config.dimensions;
        Self {
            id: id.into(),
            name: name.into(),
            dimensions,
            size: 0,
            config,
            created_at: stamp,
            updated_at: stamp,
            description: None,
        }
    }

    /// Records a new element count and refreshes `updated_at`.
    pub fn update_size(&mut self, size: usize) {
        self.updated_at = Timestamp::now();
        self.size = size;
    }

    /// Returns the record with `description` set.
    pub fn with_description(self, desc: impl Into<String>) -> Self {
        Self {
            description: Some(desc.into()),
            ..self
        }
    }
}
/// Kind of relationship an edge represents.
///
/// Serialized and displayed in `snake_case` (e.g. `SameCluster` ↔ `"same_cluster"`).
/// The default is [`EdgeType::Similar`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum EdgeType {
    /// Default edge kind; used by `SimilarityEdge::new`.
    #[default]
    Similar,
    /// Used by `SimilarityEdge::sequential`.
    Sequential,
    SameCluster,
    SameSource,
    Custom,
}

impl EdgeType {
    /// Returns the `snake_case` label, identical to the serde representation.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Similar => "similar",
            Self::Sequential => "sequential",
            Self::SameCluster => "same_cluster",
            Self::SameSource => "same_source",
            Self::Custom => "custom",
        }
    }
}

impl fmt::Display for EdgeType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `pad` honours width/alignment flags (e.g. `{:<12}`), which `write!` with a
        // literal silently ignored; plain `{}` output is unchanged.
        f.pad(self.as_str())
    }
}
/// Typed edge from `from_id` to `to_id`, with a distance and optional extras.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimilarityEdge {
    /// Embedding the edge starts at.
    pub from_id: EmbeddingId,
    /// Embedding the edge points to.
    pub to_id: EmbeddingId,
    /// Raw distance; [`SimilarityEdge::similarity`] clamps it to `[0, 1]` before use.
    pub distance: f32,
    /// Kind of relationship.
    pub edge_type: EdgeType,
    /// Creation time.
    pub created_at: Timestamp,
    /// Optional caller-supplied weight (stored as given).
    pub weight: Option<f32>,
    /// Optional provenance metadata.
    pub metadata: Option<EdgeMetadata>,
}

impl SimilarityEdge {
    /// Builds an [`EdgeType::Similar`] edge stamped with the current time.
    pub fn new(from_id: EmbeddingId, to_id: EmbeddingId, distance: f32) -> Self {
        let created_at = Timestamp::now();
        Self {
            from_id,
            to_id,
            distance,
            edge_type: EdgeType::Similar,
            created_at,
            weight: None,
            metadata: None,
        }
    }

    /// Builds an [`EdgeType::Sequential`] edge with a distance of `0.0`.
    pub fn sequential(from_id: EmbeddingId, to_id: EmbeddingId) -> Self {
        Self::new(from_id, to_id, 0.0).with_type(EdgeType::Sequential)
    }

    /// Returns the edge with `edge_type` replaced.
    pub fn with_type(self, edge_type: EdgeType) -> Self {
        Self { edge_type, ..self }
    }

    /// Returns the edge with `weight` set.
    pub fn with_weight(self, weight: f32) -> Self {
        Self {
            weight: Some(weight),
            ..self
        }
    }

    /// Returns the edge with `metadata` set.
    pub fn with_metadata(self, metadata: EdgeMetadata) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Similarity in `[0, 1]`: `1 - distance`, after clamping `distance` into `[0, 1]`.
    ///
    /// A NaN distance yields NaN.
    #[inline]
    pub fn similarity(&self) -> f32 {
        let bounded = self.distance.clamp(0.0, 1.0);
        1.0 - bounded
    }

    /// `true` when [`SimilarityEdge::similarity`] is at least `threshold`.
    #[inline]
    pub fn is_strong(&self, threshold: f32) -> bool {
        threshold <= self.similarity()
    }
}
/// Optional provenance details attached to a [`SimilarityEdge`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct EdgeMetadata {
    /// Free-form origin label.
    pub source: Option<String>,
    /// Confidence value, stored as given (no range check here).
    pub confidence: Option<f32>,
    /// Arbitrary string key/value pairs.
    pub attributes: hashbrown::HashMap<String, String>,
}

impl EdgeMetadata {
    /// Empty metadata; identical to `EdgeMetadata::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the metadata with `source` set.
    pub fn with_source(self, source: impl Into<String>) -> Self {
        Self {
            source: Some(source.into()),
            ..self
        }
    }

    /// Returns the metadata with `confidence` set.
    pub fn with_confidence(self, confidence: f32) -> Self {
        Self {
            confidence: Some(confidence),
            ..self
        }
    }

    /// Adds one attribute, replacing any existing value under the same key.
    pub fn with_attribute(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let (k, v) = (key.into(), value.into());
        self.attributes.insert(k, v);
        self
    }
}
/// A raw vector stored under an [`EmbeddingId`], with creation time and optional metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredVector {
    /// Identifier of the vector.
    pub id: EmbeddingId,
    /// Vector components.
    pub vector: Vec<f32>,
    /// Time the record was created.
    pub created_at: Timestamp,
    /// Optional descriptive metadata.
    pub metadata: Option<VectorMetadata>,
}

impl StoredVector {
    /// Wraps `vector` under `id`, stamped with the current time and no metadata.
    pub fn new(id: EmbeddingId, vector: Vec<f32>) -> Self {
        let created_at = Timestamp::now();
        Self {
            id,
            vector,
            created_at,
            metadata: None,
        }
    }

    /// Returns the record with `metadata` set.
    pub fn with_metadata(self, metadata: VectorMetadata) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Number of components in the stored vector.
    #[inline]
    pub fn dimensions(&self) -> usize {
        let Self { vector, .. } = self;
        vector.len()
    }
}
/// Descriptive metadata attached to a [`StoredVector`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct VectorMetadata {
    /// Identifier of the originating item, if any.
    pub source_id: Option<String>,
    /// Timestamp associated with the source (units not fixed here — confirm with producers).
    pub source_timestamp: Option<f64>,
    /// Labels, in insertion order (duplicates are kept).
    pub labels: Vec<String>,
    /// Arbitrary JSON-valued attributes.
    pub attributes: hashbrown::HashMap<String, serde_json::Value>,
}

impl VectorMetadata {
    /// Empty metadata; identical to `VectorMetadata::default()`.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the metadata with `source_id` set.
    pub fn with_source_id(self, id: impl Into<String>) -> Self {
        Self {
            source_id: Some(id.into()),
            ..self
        }
    }

    /// Returns the metadata with `source_timestamp` set.
    pub fn with_source_timestamp(self, ts: f64) -> Self {
        Self {
            source_timestamp: Some(ts),
            ..self
        }
    }

    /// Appends one label.
    pub fn with_label(mut self, label: impl Into<String>) -> Self {
        let label = label.into();
        self.labels.push(label);
        self
    }

    /// Adds one attribute, replacing any existing value under the same key.
    pub fn with_attribute(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let key = key.into();
        self.attributes.insert(key, value);
        self
    }
}
/// Unit tests for the domain types in this module.
///
/// Only the original public API is exercised, so these tests hold regardless of
/// which optional convenience impls (e.g. `FromStr`) are present.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_embedding_id_creation() {
        let id1 = EmbeddingId::new();
        let id2 = EmbeddingId::new();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_embedding_id_parse() {
        let id = EmbeddingId::new();
        let parsed = EmbeddingId::parse(&id.to_string()).unwrap();
        assert_eq!(id, parsed);
    }

    #[test]
    fn test_embedding_id_parse_rejects_invalid() {
        assert!(EmbeddingId::parse("not-a-uuid").is_err());
        assert!(EmbeddingId::parse("").is_err());
    }

    #[test]
    fn test_embedding_id_conversions_roundtrip() {
        let id = EmbeddingId::new();
        assert_eq!(EmbeddingId::from_bytes(*id.as_bytes()), id);
        assert_eq!(Uuid::from(id), *id.as_uuid());
        assert_eq!(EmbeddingId::from(*id.as_uuid()), id);
    }

    #[test]
    fn test_hnsw_config_default() {
        let config = HnswConfig::default();
        assert_eq!(config.dimensions, 1536);
        assert_eq!(config.m, 32);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_hnsw_config_m_threshold() {
        assert_eq!(HnswConfig::for_dimension(1023).m, 16);
        assert_eq!(HnswConfig::for_dimension(1024).m, 32);
        assert_eq!(HnswConfig::for_sentence_transformers().dimensions, 384);
    }

    #[test]
    fn test_hnsw_config_validation() {
        let config = HnswConfig::default().with_m(1);
        assert!(matches!(config.validate(), Err(ConfigValidationError::InvalidM(1))));

        let config = HnswConfig::default().with_ef_construction(10);
        assert!(matches!(
            config.validate(),
            Err(ConfigValidationError::EfTooSmall { ef: 10, m: 32 })
        ));

        let mut config = HnswConfig::default();
        config.dimensions = 0;
        assert!(matches!(config.validate(), Err(ConfigValidationError::ZeroDimensions)));

        let config = HnswConfig::default().with_max_elements(0);
        assert!(matches!(config.validate(), Err(ConfigValidationError::ZeroMaxElements)));
    }

    #[test]
    fn test_similarity_edge() {
        let edge = SimilarityEdge::new(EmbeddingId::new(), EmbeddingId::new(), 0.2);
        // Compare with a tolerance rather than exact float equality.
        assert!((edge.similarity() - 0.8).abs() < f32::EPSILON);
        assert!(edge.is_strong(0.7));
        assert!(!edge.is_strong(0.9));
    }

    #[test]
    fn test_similarity_clamps_distance() {
        let (a, b) = (EmbeddingId::new(), EmbeddingId::new());
        assert_eq!(SimilarityEdge::new(a, b, 1.5).similarity(), 0.0);
        assert_eq!(SimilarityEdge::new(a, b, -0.5).similarity(), 1.0);

        let seq = SimilarityEdge::sequential(a, b);
        assert_eq!(seq.edge_type, EdgeType::Sequential);
        assert_eq!(seq.similarity(), 1.0);
    }

    #[test]
    fn test_edge_type_display_and_default() {
        assert_eq!(EdgeType::SameCluster.to_string(), "same_cluster");
        assert_eq!(EdgeType::Custom.to_string(), "custom");
        assert_eq!(EdgeType::default(), EdgeType::Similar);
        assert_eq!(DistanceMetric::default(), DistanceMetric::Cosine);
    }

    #[test]
    fn test_timestamp() {
        let ts1 = Timestamp::now();
        std::thread::sleep(std::time::Duration::from_millis(10));
        let ts2 = Timestamp::now();
        assert!(ts2 > ts1);
    }

    #[test]
    fn test_timestamp_millis_and_display() {
        let ts = Timestamp::from_millis(1_500);
        assert_eq!(ts.as_millis(), 1_500);
        assert!(Timestamp::from_millis(1) < Timestamp::from_millis(2));
        assert_eq!(ts.to_string(), "1970-01-01 00:00:01.500 UTC");
        // Out-of-range values fall back to the epoch instead of panicking.
        assert_eq!(
            Timestamp::from_millis(i64::MAX).to_string(),
            "1970-01-01 00:00:00.000 UTC"
        );
    }

    #[test]
    fn test_vector_index_lifecycle() {
        let mut index =
            VectorIndex::new("idx", "name", HnswConfig::for_dimension(8)).with_description("d");
        assert_eq!(index.dimensions, 8);
        assert_eq!(index.size, 0);
        assert_eq!(index.created_at, index.updated_at);
        assert_eq!(index.description.as_deref(), Some("d"));

        index.update_size(5);
        assert_eq!(index.size, 5);
        assert!(index.updated_at >= index.created_at);
    }
}