use crate::AletheiaDB;
use crate::core::error::{Error, Result, VectorError};
use crate::core::id::NodeId;
/// Vector arithmetic over node embeddings stored in an [`AletheiaDB`].
///
/// Each operation reads a vector-valued property from the given nodes,
/// combines the vectors element-wise, and hands the combined vector to
/// `AletheiaDB::search_vectors_in` as the query.
pub struct ConceptAlgebra<'a> {
/// Database that nodes are read from and that searches are run against.
db: &'a AletheiaDB,
/// Vector property to operate on. When `None`, the property of the first
/// entry returned by `AletheiaDB::list_vector_indexes` is used instead.
property_name: Option<String>,
}
impl<'a> ConceptAlgebra<'a> {
pub fn new(db: &'a AletheiaDB) -> Self {
Self {
db,
property_name: None,
}
}
pub fn with_property(mut self, name: impl Into<String>) -> Self {
self.property_name = Some(name.into());
self
}
fn get_property_name(&self) -> Result<String> {
if let Some(ref name) = self.property_name {
return Ok(name.clone());
}
let indexes = self.db.list_vector_indexes();
if let Some(idx_info) = indexes.first() {
Ok(idx_info.property_name.clone())
} else {
Err(Error::Vector(VectorError::IndexError(
"No vector indexes configured. Specify a property name or enable an index."
.to_string(),
)))
}
}
fn get_vector(&self, node_id: NodeId, property: &str) -> Result<Vec<f32>> {
let node = self.db.get_node(node_id)?;
let val = node.properties.get(property).ok_or_else(|| {
Error::Vector(VectorError::IndexError(format!(
"Node {} does not have vector property '{}'",
node_id, property
)))
})?;
if let Some(vec) = val.as_vector() {
Ok(vec.to_vec())
} else {
Err(Error::Vector(VectorError::IndexError(format!(
"Property '{}' on node {} is not a vector",
property, node_id
))))
}
}
pub fn add(&self, a: NodeId, b: NodeId, k: usize) -> Result<Vec<(NodeId, f32)>> {
let prop = self.get_property_name()?;
let vec_a = self.get_vector(a, &prop)?;
let vec_b = self.get_vector(b, &prop)?;
if vec_a.len() != vec_b.len() {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: vec_a.len(),
actual: vec_b.len(),
}));
}
let sum: Vec<f32> = vec_a.iter().zip(vec_b.iter()).map(|(x, y)| x + y).collect();
self.db.search_vectors_in(&prop, &sum, k)
}
pub fn subtract(&self, a: NodeId, b: NodeId, k: usize) -> Result<Vec<(NodeId, f32)>> {
let prop = self.get_property_name()?;
let vec_a = self.get_vector(a, &prop)?;
let vec_b = self.get_vector(b, &prop)?;
if vec_a.len() != vec_b.len() {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: vec_a.len(),
actual: vec_b.len(),
}));
}
let diff: Vec<f32> = vec_a.iter().zip(vec_b.iter()).map(|(x, y)| x - y).collect();
self.db.search_vectors_in(&prop, &diff, k)
}
pub fn analogy(&self, a: NodeId, b: NodeId, c: NodeId, k: usize) -> Result<Vec<(NodeId, f32)>> {
let prop = self.get_property_name()?;
let vec_a = self.get_vector(a, &prop)?;
let vec_b = self.get_vector(b, &prop)?;
let vec_c = self.get_vector(c, &prop)?;
if vec_a.len() != vec_b.len() || vec_a.len() != vec_c.len() {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: vec_a.len(),
actual: if vec_a.len() != vec_b.len() {
vec_b.len()
} else {
vec_c.len()
},
}));
}
let result: Vec<f32> = vec_a
.iter()
.zip(vec_b.iter())
.zip(vec_c.iter())
.map(|((x, y), z)| x - y + z)
.collect();
self.db.search_vectors_in(&prop, &result, k)
}
pub fn mean(&self, nodes: &[NodeId], k: usize) -> Result<Vec<(NodeId, f32)>> {
if nodes.is_empty() {
return Ok(Vec::new());
}
let prop = self.get_property_name()?;
let mut sum_vec: Option<Vec<f32>> = None;
let count = nodes.len() as f32;
for &node_id in nodes {
let vec = self.get_vector(node_id, &prop)?;
if let Some(ref mut sum) = sum_vec {
if sum.len() != vec.len() {
return Err(Error::Vector(VectorError::DimensionMismatch {
expected: sum.len(),
actual: vec.len(),
}));
}
for (s, v) in sum.iter_mut().zip(vec.iter()) {
*s += v;
}
} else {
sum_vec = Some(vec);
}
}
let centroid: Vec<f32> = sum_vec.unwrap().into_iter().map(|x| x / count).collect();
self.db.search_vectors_in(&prop, ¢roid, k)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::property::PropertyMapBuilder;
    use crate::index::vector::{DistanceMetric, HnswConfig};

    /// Builds a fresh database with a vector index on `"embedding"`, configured
    /// via `HnswConfig::new(2, DistanceMetric::Cosine)`.
    fn setup_db() -> AletheiaDB {
        let db = AletheiaDB::new().unwrap();
        db.enable_vector_index("embedding", HnswConfig::new(2, DistanceMetric::Cosine))
            .unwrap();
        db
    }

    /// `add` on [1, 0] and [0, 1] should return the node holding [1, 1] first.
    #[test]
    fn test_concept_addition() {
        let db = setup_db();
        let a = db
            .create_node(
                "Node",
                PropertyMapBuilder::new()
                    .insert("name", "A")
                    .insert_vector("embedding", &[1.0, 0.0])
                    .build(),
            )
            .unwrap();
        let b = db
            .create_node(
                "Node",
                PropertyMapBuilder::new()
                    .insert("name", "B")
                    .insert_vector("embedding", &[0.0, 1.0])
                    .build(),
            )
            .unwrap();
        let c = db
            .create_node(
                "Node",
                PropertyMapBuilder::new()
                    .insert("name", "C")
                    .insert_vector("embedding", &[1.0, 1.0])
                    .build(),
            )
            .unwrap();

        let algebra = ConceptAlgebra::new(&db).with_property("embedding");
        let hits = algebra.add(a, b, 1).unwrap();
        assert!(!hits.is_empty());
        let (nearest, _score) = hits[0];
        assert_eq!(nearest, c);
    }

    /// `king - man + woman` should return the `queen` node first.
    #[test]
    fn test_concept_analogy() {
        let db = setup_db();
        // Creates a "Concept" node carrying only the given embedding.
        let concept = |embedding: [f32; 2]| {
            db.create_node(
                "Concept",
                PropertyMapBuilder::new()
                    .insert_vector("embedding", &embedding)
                    .build(),
            )
            .unwrap()
        };
        let man = concept([1.0, 0.0]);
        let woman = concept([2.0, 0.0]);
        let king = concept([1.0, 1.0]);
        let queen = concept([2.0, 1.0]);

        // No explicit property: the algebra falls back to the first index.
        let algebra = ConceptAlgebra::new(&db);
        let hits = algebra.analogy(king, man, woman, 1).unwrap();
        assert!(!hits.is_empty());
        let (nearest, _score) = hits[0];
        assert_eq!(nearest, queen);
    }

    /// The mean of [1, 0] and [0, 1] should return the node holding [1, 1] first.
    #[test]
    fn test_concept_mean() {
        let db = setup_db();
        // Creates a "P" node carrying only the given embedding.
        let point = |embedding: [f32; 2]| {
            db.create_node(
                "P",
                PropertyMapBuilder::new()
                    .insert_vector("embedding", &embedding)
                    .build(),
            )
            .unwrap()
        };
        let n1 = point([1.0, 0.0]);
        let n2 = point([0.0, 1.0]);
        let target = point([1.0, 1.0]);

        let algebra = ConceptAlgebra::new(&db);
        let hits = algebra.mean(&[n1, n2], 1).unwrap();
        assert!(!hits.is_empty());
        let (nearest, _score) = hits[0];
        assert_eq!(nearest, target);
    }

    /// Covers two failure paths: no index configured, and a node without the
    /// vector property.
    #[test]
    fn test_error_handling() {
        let db = AletheiaDB::new().unwrap();
        let algebra = ConceptAlgebra::new(&db);
        // Creates an "N" node with no properties at all.
        let bare_node = || {
            db.create_node("N", PropertyMapBuilder::new().build())
                .unwrap()
        };

        // No index and no explicit property: the property name cannot be resolved.
        let n1 = bare_node();
        let n2 = bare_node();
        assert!(algebra.add(n1, n2, 1).is_err());

        // With an index enabled the name resolves, but the node has no embedding.
        db.enable_vector_index("embedding", HnswConfig::new(2, DistanceMetric::Cosine))
            .unwrap();
        let n3 = bare_node();
        let outcome = algebra.add(n3, n3, 1);
        assert!(outcome.is_err());
        match outcome {
            Err(Error::Vector(VectorError::IndexError(msg))) => {
                assert!(msg.contains("does not have vector property"));
            }
            _ => panic!("Expected IndexError"),
        }
    }
}