#![deny(missing_docs)]
use crate::{database::*, md2f::filter_where, onnx::*};
use kn0sys_lmdb_rs as lmdb;
use kn0sys_lmdb_rs::MdbError;
use kn0sys_nn::distance::L2Dist;
use kn0sys_nn::*;
use log::*;
use ndarray::*;
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, LazyLock, RwLock};
use thiserror::Error;
use uuid::Uuid;
use vecpac::HexNode;
use wincode::{SchemaRead, SchemaWrite};
/// Entry point to the Valentinus embedding database.
///
/// Wraps an LMDB environment together with an in-memory, thread-safe cache
/// of collections that have already been loaded from disk.
pub struct Valentinus {
    // LMDB environment plus database handle.
    db: DatabaseEnvironment,
    // Cache of loaded collections, keyed by each collection's storage key.
    collections: Arc<RwLock<HashMap<String, Arc<EmbeddingCollection>>>>,
}
/// A named set of documents together with their embeddings, hex-grid spatial
/// index, and per-document metadata, as persisted to the database.
#[derive(Clone, Debug, Default, Deserialize, Serialize, SchemaWrite, SchemaRead)]
pub struct EmbeddingCollection {
    // Raw documents, in insertion order.
    documents: Vec<String>,
    /// Spatial index: hex cube coordinates `(q, r, s)` mapped to the row
    /// indices of the embeddings whose 2D projection falls in that cell.
    pub hex_index: HashMap<(i32, i32, i32), Vec<usize>>,
    // Flattened row-major embedding matrix; reshape using `shape`.
    data: Vec<f32>,
    // (rows, cols) of the embedding matrix: (document count, embedding dim).
    shape: (usize, usize),
    // Metadata strings for each document, parallel to `documents`.
    metadata: Vec<Vec<String>>,
    // Path to the ONNX model used to embed documents and queries.
    model_path: String,
    // Which model family `model_path` refers to.
    model_type: ModelType,
    // Caller-supplied ids, parallel to `documents`.
    ids: Vec<String>,
    // Database storage key ("key-<uuid>").
    key: String,
    // Database view name ("view-<name>").
    view: String,
}
/// Embedding model families supported when creating a collection.
#[derive(Clone, Debug, Default, Deserialize, Serialize, SchemaWrite, SchemaRead)]
pub enum ModelType {
    /// The all-MiniLM-L12-v2 model.
    AllMiniLmL12V2,
    /// The all-MiniLM-L6-v2 model (default).
    #[default]
    AllMiniLmL6V2,
    /// A user-supplied ONNX model located at the collection's model path.
    Custom,
}
/// Result of [`Valentinus::cosine_query`]: documents, scores and metadata,
/// all parallel vectors sorted by descending similarity.
#[derive(Debug, Default, Deserialize, Serialize)]
pub struct CosineQueryResult {
    // Matching documents, best match first.
    documents: Vec<String>,
    // Similarity score for each document (parallel to `documents`).
    similarities: Vec<f32>,
    // Metadata for each document (parallel to `documents`).
    metadata: Vec<Vec<String>>,
}
/// Errors returned by the Valentinus API.
#[derive(Debug, Error)]
pub enum ValentinusError {
    /// The in-memory collection cache lock could not be acquired.
    #[error("Cache error: {0}")]
    CacheError(String),
    /// Wincode serialization or deserialization failed.
    #[error("Serialization/deserialization error: {0}")]
    WincodeError(String),
    /// No collection is registered under the given view name.
    #[error("Collection '{0}' not found")]
    CollectionNotFound(String),
    /// A cosine-similarity query failed.
    #[error("Cosine query failure: {0}")]
    CosineError(String),
    /// An underlying LMDB operation failed.
    #[error("Database error: {0}")]
    DatabaseError(#[from] MdbError),
    /// The requested view name is malformed or already taken.
    #[error("Invalid view name: {0}")]
    InvalidViewName(String),
    /// A metadata `where`-filter expression failed to evaluate.
    #[error("Metadata filter error")]
    Md2fsError,
    /// A nearest-neighbor (k-d tree) query failed.
    #[error("Nearest neighbors query failure: {0}")]
    NearestError(String),
    /// Embedding generation via ONNX failed.
    #[error("ONNX error")]
    OnnxError(OnnxError),
    /// The requested item does not exist.
    #[error("Not found: {0}")]
    NotFound(String),
    /// Used only by the test suite.
    #[error("Test failure")]
    TestError,
}
// Wincode schema wrapper around `EmbeddingCollection`; the extra level is
// what gets serialized to / deserialized from the database.
#[derive(SchemaWrite, SchemaRead)]
struct PreCollection {
    serde: EmbeddingCollection,
}
// A flat list of names (collection keys or view names) kept in the database
// so all collections can be enumerated.
#[derive(Debug, Default, Deserialize, Serialize, SchemaWrite, SchemaRead)]
struct KeyViewIndexer {
    values: Vec<String>,
}
// Wincode schema wrapper around `KeyViewIndexer`, mirroring `PreCollection`.
#[derive(Default, SchemaWrite, SchemaRead)]
struct KVIndexer {
    serde: KeyViewIndexer,
}
// Collection names may only contain ASCII alphanumerics and underscores.
static VIEWS_NAMING_CHECK: LazyLock<Regex> =
    LazyLock::new(|| Regex::new("^[a-zA-Z0-9_]+$").expect("regex should be valid"));
// Database key under which the list of all collection storage keys lives.
const VALENTINUS_KEYS: &str = "keys";
// Database key under which the list of all view names lives.
const VALENTINUS_VIEWS: &str = "views";
// Prefix for per-collection storage keys ("key-<uuid>").
const VALENTINUS_KEY: &str = "key";
// Prefix for per-collection view names ("view-<name>").
const VALENTINUS_VIEW: &str = "view";
impl Valentinus {
pub fn new(env: &str) -> Result<Self, ValentinusError> {
let db = DatabaseEnvironment::open(env)?;
Ok(Valentinus {
db,
collections: Arc::new(RwLock::new(HashMap::new())),
})
}
pub fn create_collection(
&self,
name: String,
documents: Vec<String>,
metadata: Vec<Vec<String>>,
ids: Vec<String>,
model_type: ModelType,
model_path: String,
) -> Result<(), ValentinusError> {
if !VIEWS_NAMING_CHECK.is_match(&name) {
return Err(ValentinusError::InvalidViewName(format!(
"Name '{}' must only contain alphanumerics and underscores.",
name
)));
}
info!("Generating embeddings for new collection '{}'", name);
let array_embeddings: Array2<f32> =
batch_embeddings(&model_path, &documents).map_err(ValentinusError::OnnxError)?;
let shape = (array_embeddings.nrows(), array_embeddings.ncols());
let data = array_embeddings.clone().into_raw_vec_and_offset().0;
info!("Quantizing embeddings to the Seed of Life grid...");
let mut hex_index: HashMap<(i32, i32, i32), Vec<usize>> = HashMap::new();
for (idx, row) in array_embeddings.axis_iter(Axis(0)).enumerate() {
let row_slice = row.to_slice().unwrap();
let (x, y) = Self::project_to_2d(row_slice);
let hex_node = HexNode::from_fractional(x, y);
let hex_tuple = (hex_node.q, hex_node.r, hex_node.s);
hex_index.entry(hex_tuple).or_default().push(idx);
}
let key = format!("{}-{}", VALENTINUS_KEY, Uuid::new_v4());
let view = format!("{}-{}", VALENTINUS_VIEW, name);
let collection = EmbeddingCollection {
documents,
hex_index,
data,
shape,
metadata,
model_path,
model_type,
ids,
key,
view,
};
info!("Saving new collection '{}' to database.", name);
let txn = self.db.env.new_transaction()?;
{
let db_handle = &self.db.handle;
let mut views_indexer = Self::get_indexer_mut(&txn, db_handle, VALENTINUS_VIEWS)?;
if views_indexer.serde.values.contains(&name) {
return Err(ValentinusError::InvalidViewName(format!(
"View name '{}' already exists.",
name
)));
}
views_indexer.serde.values.push(name.clone());
let mut keys_indexer = Self::get_indexer_mut(&txn, db_handle, VALENTINUS_KEYS)?;
keys_indexer.serde.values.push(collection.key.clone());
Self::write_indexer(&txn, db_handle, VALENTINUS_VIEWS, &views_indexer)?;
Self::write_indexer(&txn, db_handle, VALENTINUS_KEYS, &keys_indexer)?;
txn.bind(db_handle)
.set(&collection.view.as_bytes(), &collection.key.as_bytes())?;
let pre_collection = PreCollection {
serde: collection.clone(),
};
let encoded_collection = wincode::serialize(&pre_collection)
.map_err(|e| ValentinusError::WincodeError(e.to_string()))?;
write_chunks_in_txn(
&txn,
db_handle,
collection.key.as_bytes(),
&encoded_collection,
)?;
}
txn.commit()?;
Ok(())
}
pub fn get_collection(
&self,
view_name: &str,
) -> Result<Arc<EmbeddingCollection>, ValentinusError> {
{
let cache = self
.collections
.read()
.map_err(|e| ValentinusError::CacheError(e.to_string()))?;
if let Some(collection) = cache.values().find(|c| c.view.ends_with(view_name)) {
info!("Cache hit for collection '{}'", view_name);
return Ok(Arc::clone(collection));
}
}
let mut cache = self.collections.write().unwrap();
if let Some(collection) = cache.values().find(|c| c.view.ends_with(view_name)) {
info!("Cache hit for collection '{}' (after lock)", view_name);
return Ok(Arc::clone(collection));
}
info!(
"Cache miss. Loading collection '{}' from database.",
view_name
);
let key = self.get_key_for_view(view_name)?;
let collection_data = read(&self.db.env, &self.db.handle, &key.as_bytes().to_vec())?
.ok_or_else(|| ValentinusError::CollectionNotFound(view_name.to_string()))?;
let pre_collection: PreCollection = wincode::deserialize(&collection_data)
.map_err(|e| ValentinusError::WincodeError(e.to_string()))?;
let collection = Arc::new(pre_collection.serde);
cache.insert(key, Arc::clone(&collection));
Ok(collection)
}
pub fn delete_collection(&self, view_name: &str) -> Result<(), ValentinusError> {
info!("Deleting collection '{}'", view_name);
let txn = self.db.env.new_transaction()?;
let key_to_delete: String;
let full_view_name = format!("{}-{}", VALENTINUS_VIEW, view_name);
{
let db_handle = &self.db.handle;
let key_bytes = txn
.bind(db_handle)
.get::<Vec<u8>>(&full_view_name.as_bytes())
.map_err(|_| ValentinusError::CollectionNotFound(view_name.to_string()))?;
key_to_delete = String::from_utf8(key_bytes).unwrap_or_default();
if key_to_delete.is_empty() {
return Err(ValentinusError::CollectionNotFound(view_name.to_string()));
}
let mut views_indexer = Self::get_indexer_mut(&txn, db_handle, VALENTINUS_VIEWS)?;
views_indexer.serde.values.retain(|v| v != view_name);
Self::write_indexer(&txn, db_handle, VALENTINUS_VIEWS, &views_indexer)?;
let mut keys_indexer = Self::get_indexer_mut(&txn, db_handle, VALENTINUS_KEYS)?;
keys_indexer.serde.values.retain(|k| k != &key_to_delete);
Self::write_indexer(&txn, db_handle, VALENTINUS_KEYS, &keys_indexer)?;
delete_in_txn(&txn, db_handle, key_to_delete.as_bytes())?;
txn.bind(db_handle).del(&full_view_name.as_bytes())?;
}
txn.commit()?;
let mut cache = self.collections.write().unwrap();
cache.remove(&key_to_delete);
Ok(())
}
pub fn cosine_query(
&self,
query_string: String,
view_name: String,
num_results: usize,
f_where: Option<Vec<String>>,
) -> Result<CosineQueryResult, ValentinusError> {
info!("Starting cosine query on collection '{}'", view_name);
let collection = self.get_collection(&view_name)?;
let is_filtering = f_where.is_some();
let qv_string = vec![query_string];
let qv = batch_embeddings(&collection.model_path, &qv_string)
.map_err(ValentinusError::OnnxError)?;
let query_embedding = qv.index_axis(Axis(0), 0);
let mut results: Vec<(f32, String, Vec<String>)> = Vec::new();
let collection_embeddings =
Array2::from_shape_vec(collection.shape, collection.data.clone()).unwrap_or_default();
for (index, (cv, sentence)) in collection_embeddings
.axis_iter(Axis(0))
.zip(collection.documents.iter())
.enumerate()
{
let metadata = &collection.metadata[index];
let raw_f = f_where.as_deref().unwrap_or(&[]);
if !is_filtering
|| filter_where(raw_f, metadata).map_err(|_| ValentinusError::Md2fsError)?
{
let dot_product: f32 = query_embedding
.iter()
.zip(cv.iter())
.map(|(a, b)| a * b)
.sum();
results.push((dot_product, sentence.clone(), metadata.clone()));
}
}
results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
if num_results > 0 && results.len() > num_results {
results.truncate(num_results);
}
let (similarities, documents, metadata) = results.into_iter().fold(
(Vec::new(), Vec::new(), Vec::new()),
|(mut sims, mut docs, mut metas), (sim, doc, meta)| {
sims.push(sim);
docs.push(doc);
metas.push(meta);
(sims, docs, metas)
},
);
Ok(CosineQueryResult {
documents,
similarities,
metadata,
})
}
pub fn hex_nearest_query(
&self,
query_string: String,
view_name: String,
) -> Result<Vec<String>, ValentinusError> {
info!("Starting hex-packed query on collection '{}'", view_name);
let collection = self.get_collection(&view_name)?;
let qv_string = vec![query_string];
let qv = batch_embeddings(&collection.model_path, &qv_string)
.map_err(ValentinusError::OnnxError)?;
let query_embedding = qv.index_axis(Axis(0), 0).to_slice().unwrap();
let (x, y) = Self::project_to_2d(query_embedding);
let target_hex = HexNode::from_fractional(x, y);
let mut candidate_indices: Vec<usize> = Vec::new();
if let Some(indices) = collection
.hex_index
.get(&(target_hex.q, target_hex.r, target_hex.s))
{
candidate_indices.extend(indices);
}
for neighbor in target_hex.neighbors() {
if let Some(indices) = collection
.hex_index
.get(&(neighbor.q, neighbor.r, neighbor.s))
{
candidate_indices.extend(indices);
}
}
candidate_indices.sort_unstable();
candidate_indices.dedup();
if candidate_indices.is_empty() {
info!("Local hex neighborhood is empty. Falling back to global semantic scan...");
for indices in collection.hex_index.values() {
candidate_indices.extend(indices);
}
candidate_indices.sort_unstable();
candidate_indices.dedup();
}
if candidate_indices.is_empty() {
return Err(ValentinusError::NotFound(
"Database is completely empty.".to_string(),
));
}
let cols = collection.shape.1; let mut scored_candidates: Vec<(f32, usize)> = Vec::new();
for &idx in &candidate_indices {
let start = idx * cols;
let end = start + cols;
let candidate_vector = &collection.data[start..end];
let score = Self::lite_cos_sim(query_embedding, candidate_vector);
scored_candidates.push((score, idx));
}
scored_candidates
.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let results: Vec<String> = scored_candidates
.iter()
.map(|&(_, idx)| collection.documents[idx].clone())
.collect();
Ok(results)
}
pub fn nearest_query(
&self,
query_string: String,
view_name: String,
) -> Result<String, ValentinusError> {
info!("Starting nearest query on collection '{}'", view_name);
let collection = self.get_collection(&view_name)?;
let qv_string = vec![query_string];
let qv = batch_embeddings(&collection.model_path, &qv_string)
.map_err(ValentinusError::OnnxError)?;
let query_embedding = qv.index_axis(Axis(0), 0);
let collection_embeddings =
Array2::from_shape_vec(collection.shape, collection.data.clone()).unwrap_or_default();
let nn = CommonNearestNeighbour::KdTree
.batch(&collection_embeddings, L2Dist)
.map_err(|e| ValentinusError::NearestError(e.to_string()))?;
let nearest = nn
.k_nearest(query_embedding, 1)
.map_err(|e| ValentinusError::NearestError(e.to_string()))?;
if nearest.is_empty() {
return Err(ValentinusError::NotFound(
"No nearest neighbor found.".to_string(),
));
}
let nearest_embedding = nearest[0].0.to_vec();
let position = collection_embeddings
.axis_iter(Axis(0))
.position(|x| x.to_vec() == nearest_embedding);
match position {
Some(idx) => Ok(collection.documents[idx].clone()),
None => Err(ValentinusError::NotFound(
"Could not map nearest embedding back to a document.".to_string(),
)),
}
}
fn lite_cos_sim(a: &[f32], b: &[f32]) -> f32 {
let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a == 0.0 || norm_b == 0.0 {
0.0
} else {
dot / (norm_a * norm_b)
}
}
fn project_to_2d(embedding: &[f32]) -> (f64, f64) {
let mut x = 0.0;
let mut y = 0.0;
for (i, &val) in embedding.iter().enumerate() {
let theta = i as f64 * 0.1375; x += val as f64 * theta.cos();
y += val as f64 * theta.sin();
}
let scale_factor = 2.0;
(x * scale_factor, y * scale_factor)
}
fn get_key_for_view(&self, view_name: &str) -> Result<String, ValentinusError> {
let reader = self.db.env.get_reader()?;
let db = reader.bind(&self.db.handle);
let full_view_name = format!("{}-{}", VALENTINUS_VIEW, view_name);
let key_bytes = db
.get::<Vec<u8>>(&full_view_name.as_bytes())
.map_err(|_| ValentinusError::CollectionNotFound(view_name.to_string()))?;
String::from_utf8(key_bytes)
.map_err(|_| ValentinusError::CollectionNotFound("Invalid key format".to_string()))
}
fn get_indexer_mut(
txn: &lmdb::Transaction,
db_handle: &lmdb::DbHandle,
indexer_name: &str,
) -> Result<KVIndexer, ValentinusError> {
match txn.bind(db_handle).get::<Vec<u8>>(&indexer_name.as_bytes()) {
Ok(bytes) => Ok(wincode::deserialize(&bytes)
.map_err(|e| ValentinusError::WincodeError(e.to_string()))?),
Err(MdbError::NotFound) => Ok(KVIndexer::default()), Err(e) => Err(ValentinusError::DatabaseError(e)),
}
}
fn write_indexer(
txn: &lmdb::Transaction,
db_handle: &lmdb::DbHandle,
indexer_name: &str,
indexer: &KVIndexer,
) -> Result<(), ValentinusError> {
let encoded = wincode::serialize(indexer)
.map_err(|e| ValentinusError::WincodeError(e.to_string()))?;
txn.bind(db_handle)
.set(&indexer_name.as_bytes(), &encoded)?;
Ok(())
}
}
impl CosineQueryResult {
    /// Matching documents, sorted by descending similarity.
    pub fn get_docs(&self) -> &Vec<String> {
        &self.documents
    }
    /// Similarity scores, parallel to [`CosineQueryResult::get_docs`].
    pub fn get_similarities(&self) -> &Vec<f32> {
        &self.similarities
    }
    /// Per-document metadata, parallel to [`CosineQueryResult::get_docs`].
    pub fn get_metadata(&self) -> &Vec<Vec<String>> {
        &self.metadata
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use std::{fs, fs::File, path::Path};

    // One row of the Tesla review CSV fixture; all fields optional so
    // malformed rows deserialize to defaults instead of failing the run.
    #[derive(Default, Deserialize, SchemaWrite, SchemaRead)]
    struct Review {
        review: Option<String>,
        rating: Option<String>,
        vehicle_title: Option<String>,
    }

    // Deletes any leftover on-disk state for `env_name` and opens a fresh
    // instance.
    // NOTE(review): assumes the database lives under /home/$USER/.valentinus —
    // confirm this matches DatabaseEnvironment::open's path convention.
    fn setup_test_env(env_name: &str) -> Arc<Valentinus> {
        let user = std::env::var("USER").unwrap_or_else(|_| "user".to_string());
        let db_path = format!("/home/{}/.{}/{}", user, "valentinus", env_name);
        if Path::new(&db_path).exists() {
            fs::remove_dir_all(&db_path).unwrap();
        }
        Arc::new(Valentinus::new(env_name).unwrap())
    }

    #[test]
    fn test_hex_nearest_query() -> Result<(), ValentinusError> {
        let valentinus = setup_test_env("hex_query_test");
        let query_str = "Felines resting on rugs!".to_string();
        let collection_name = "hex_nearest_test_coll".to_string();
        let (docs, md, ids) = create_nearest_test_data();
        valentinus.create_collection(
            collection_name.clone(),
            docs.clone(),
            md,
            ids,
            ModelType::AllMiniLmL6V2,
            "all-MiniLM-L6-v2_onnx".to_string(),
        )?;
        // docs[5] ("The cat sat on the mat.") is the closest semantic match.
        let nearest_doc = valentinus.hex_nearest_query(query_str, collection_name.clone())?;
        assert!(nearest_doc.contains(&docs[5]));
        valentinus.delete_collection(&collection_name)?;
        // Deleted collections must no longer resolve.
        let res = valentinus.get_collection(&collection_name);
        assert!(matches!(res, Err(ValentinusError::CollectionNotFound(_))));
        Ok(())
    }

    // Verifies ranking quality when many documents crowd the same hex cells.
    #[test]
    fn test_dense_hex_population() -> Result<(), ValentinusError> {
        let valentinus = setup_test_env("dense_hex_test");
        let collection_name = "dense_hex_coll".to_string();
        let (docs, md, ids) = create_dense_test_data();
        valentinus.create_collection(
            collection_name.clone(),
            docs.clone(),
            md,
            ids,
            ModelType::AllMiniLmL6V2,
            "all-MiniLM-L6-v2_onnx".to_string(),
        )?;
        let query_str = "The fluffy cat sat comfortably on the soft mat.".to_string();
        let nearest_docs = valentinus.hex_nearest_query(query_str, collection_name.clone())?;
        // The best-ranked document must come from the cat cluster.
        assert!(nearest_docs[0].contains("fluffy cat"));
        valentinus.delete_collection(&collection_name)?;
        Ok(())
    }

    // End-to-end: CSV load -> create -> filtered and unfiltered cosine
    // queries -> nearest query -> delete.
    #[test]
    fn test_full_etl_and_query_workflow() -> Result<(), ValentinusError> {
        let valentinus = setup_test_env("full_workflow_test");
        let collection_name = "tesla_reviews".to_string();
        let (documents, metadata, ids) = load_test_csv_data();
        let expected_docs = documents.clone();
        valentinus.create_collection(
            collection_name.clone(),
            documents,
            metadata,
            ids,
            ModelType::AllMiniLmL6V2,
            "all-MiniLM-L6-v2_onnx".to_string(),
        )?;
        // Round-trip: stored documents and embeddings must survive reload.
        let collection = valentinus.get_collection(&collection_name)?;
        assert_eq!(collection.documents, expected_docs);
        assert!(!collection.data.is_empty());
        let query_string = "Find the best reviews.".to_string();
        // Filtered query: only 2017 reviews rated above 3.
        let result = valentinus.cosine_query(
            query_string.clone(),
            collection_name.clone(),
            10,
            Some(vec![
                r#"{ "Year": {"eq": 2017} }"#.to_string(),
                r#"{ "Rating": {"gt": 3} }"#.to_string(),
            ]),
        )?;
        assert_eq!(result.get_docs().len(), 10);
        // Spot-check that the filter actually held on the top result.
        let first_meta = &result.get_metadata()[0];
        let v_year: Value = serde_json::from_str(&first_meta[0]).unwrap();
        let v_rating: Value = serde_json::from_str(&first_meta[1]).unwrap();
        assert_eq!(v_year["Year"].as_u64().unwrap(), 2017);
        assert!(v_rating["Rating"].as_u64().unwrap() > 3);
        // Unfiltered query honors the num_results cap.
        let no_filter_result =
            valentinus.cosine_query(query_string, collection_name.clone(), 5, None)?;
        assert_eq!(no_filter_result.get_docs().len(), 5);
        let nearest_query_str = "Find me some delicious pizza!".to_string();
        let nearest_collection_name = "nearest_test_coll".to_string();
        let (docs, md, ids) = create_nearest_test_data();
        valentinus.create_collection(
            nearest_collection_name.clone(),
            docs.clone(),
            md,
            ids,
            ModelType::AllMiniLmL6V2,
            "all-MiniLM-L6-v2_onnx".to_string(),
        )?;
        // docs[3] is the Italian pizza document.
        let nearest_doc =
            valentinus.nearest_query(nearest_query_str, nearest_collection_name.clone())?;
        assert_eq!(nearest_doc, docs[3]);
        valentinus.delete_collection(&collection_name)?;
        valentinus.delete_collection(&nearest_collection_name)?;
        let res = valentinus.get_collection(&collection_name);
        assert!(matches!(res, Err(ValentinusError::CollectionNotFound(_))));
        Ok(())
    }

    // Loads the Tesla review CSV fixture into (documents, metadata, ids);
    // metadata is JSON per document: [{"Year": y}, {"Rating": r}].
    fn load_test_csv_data() -> (Vec<String>, Vec<Vec<String>>, Vec<String>) {
        let mut documents = Vec::new();
        let mut metadata = Vec::new();
        let file_path = Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("data")
            .join("Scraped_Car_Review_tesla.csv");
        let file = File::open(file_path).expect("csv file not found");
        let mut rdr = csv::Reader::from_reader(file);
        for result in rdr.deserialize() {
            let record: Review = result.unwrap_or_default();
            documents.push(record.review.unwrap_or_default());
            let rating = record
                .rating
                .unwrap_or_default()
                .parse::<u64>()
                .unwrap_or(0);
            // The vehicle title is expected to start with a 4-digit year.
            let year_str = record.vehicle_title.unwrap_or_default();
            let year = if year_str.len() >= 4 {
                year_str[0..4].to_string()
            } else {
                "0".to_string()
            };
            metadata.push(vec![
                format!(r#"{{"Year": {}}}"#, year),
                format!(r#"{{"Rating": {}}}"#, rating),
            ]);
        }
        let ids = (0..documents.len()).map(|i| format!("id{}", i)).collect();
        (documents, metadata, ids)
    }

    // Small, topically distinct corpus for nearest-neighbor tests.
    fn create_nearest_test_data() -> (Vec<String>, Vec<Vec<String>>, Vec<String>) {
        let docs = [
            "The latest iPhone model comes with impressive features and a powerful camera.",
            "Exploring the beautiful beaches and vibrant culture of Bali is a dream for many travelers.",
            "Einstein's theory of relativity revolutionized our understanding of space and time.",
            "Traditional Italian pizza is famous for its thin crust, fresh ingredients, and wood-fired ovens.",
            "The American Revolution had a profound impact on the birth of the United States as a nation.",
            "The cat sat on the mat.",
            "Dogs make great companions.",
            "Sacread geometry is the blueprint of reality."
        ]
        .iter()
        .map(|s| s.to_string())
        .collect::<Vec<_>>();
        let ids = (0..docs.len()).map(|i| format!("id{}", i)).collect();
        let metadata = vec![vec![]; docs.len()];
        (docs, metadata, ids)
    }

    // Three clusters of 50 near-duplicate documents each, to force many
    // embeddings into the same hex cells.
    fn create_dense_test_data() -> (Vec<String>, Vec<Vec<String>>, Vec<String>) {
        let mut docs = Vec::new();
        for i in 0..50 {
            docs.push(format!(
                "The fluffy cat sat comfortably on the soft mat. Variation {}",
                i
            ));
        }
        for i in 0..50 {
            docs.push(format!("The new smartphone features a high-resolution camera and fast processor. Iteration {}", i));
        }
        for i in 0..50 {
            docs.push(format!(
                "Black holes possess immense gravitational pull in deep space. Object {}",
                i
            ));
        }
        let ids = (0..docs.len()).map(|i| format!("dense_id_{}", i)).collect();
        let metadata = vec![vec![]; docs.len()];
        (docs, metadata, ids)
    }
}