#![deny(clippy::all)]
use napi::bindgen_prelude::*;
use napi_derive::napi;
use ruvector_gnn::{
compress::{
CompressedTensor as RustCompressedTensor, CompressionLevel as RustCompressionLevel,
TensorCompress as RustTensorCompress,
},
layer::RuvectorLayer as RustRuvectorLayer,
search::{
differentiable_search as rust_differentiable_search,
hierarchical_forward as rust_hierarchical_forward,
},
};
#[napi]
pub struct RuvectorLayer {
inner: RustRuvectorLayer,
}
#[napi]
impl RuvectorLayer {
#[napi(constructor)]
pub fn new(input_dim: u32, hidden_dim: u32, heads: u32, dropout: f64) -> Result<Self> {
let inner = RustRuvectorLayer::new(
input_dim as usize,
hidden_dim as usize,
heads as usize,
dropout as f32,
)
.map_err(|e| Error::new(Status::InvalidArg, e.to_string()))?;
Ok(Self { inner })
}
#[napi]
pub fn forward(
&self,
node_embedding: Float32Array,
neighbor_embeddings: Vec<Float32Array>,
edge_weights: Float32Array,
) -> Result<Float32Array> {
let node_slice = node_embedding.as_ref();
let neighbors_vec: Vec<Vec<f32>> = neighbor_embeddings
.into_iter()
.map(|arr| arr.to_vec())
.collect();
let weights_slice = edge_weights.as_ref();
let result = self
.inner
.forward(node_slice, &neighbors_vec, weights_slice);
Ok(Float32Array::new(result))
}
#[napi]
pub fn to_json(&self) -> Result<String> {
serde_json::to_string(&self.inner).map_err(|e| {
Error::new(
Status::GenericFailure,
format!("Serialization error: {}", e),
)
})
}
#[napi(factory)]
pub fn from_json(json: String) -> Result<Self> {
let inner: RustRuvectorLayer = serde_json::from_str(&json).map_err(|e| {
Error::new(
Status::GenericFailure,
format!("Deserialization error: {}", e),
)
})?;
Ok(Self { inner })
}
}
#[napi(object)]
pub struct CompressionLevelConfig {
pub level_type: String,
pub scale: Option<f64>,
pub subvectors: Option<u32>,
pub centroids: Option<u32>,
pub outlier_threshold: Option<f64>,
pub threshold: Option<f64>,
}
impl CompressionLevelConfig {
fn to_rust(&self) -> Result<RustCompressionLevel> {
match self.level_type.as_str() {
"none" => Ok(RustCompressionLevel::None),
"half" => Ok(RustCompressionLevel::Half {
scale: self.scale.unwrap_or(1.0) as f32,
}),
"pq8" => Ok(RustCompressionLevel::PQ8 {
subvectors: self.subvectors.unwrap_or(8) as u8,
centroids: self.centroids.unwrap_or(16) as u8,
}),
"pq4" => Ok(RustCompressionLevel::PQ4 {
subvectors: self.subvectors.unwrap_or(8) as u8,
outlier_threshold: self.outlier_threshold.unwrap_or(3.0) as f32,
}),
"binary" => Ok(RustCompressionLevel::Binary {
threshold: self.threshold.unwrap_or(0.0) as f32,
}),
_ => Err(Error::new(
Status::InvalidArg,
format!("Invalid compression level: {}", self.level_type),
)),
}
}
}
#[napi]
pub struct TensorCompress {
inner: RustTensorCompress,
}
#[napi]
impl TensorCompress {
#[napi(constructor)]
pub fn new() -> Self {
Self {
inner: RustTensorCompress::new(),
}
}
#[napi]
pub fn compress(&self, embedding: Float32Array, access_freq: f64) -> Result<String> {
let embedding_slice = embedding.as_ref();
let compressed = self
.inner
.compress(embedding_slice, access_freq as f32)
.map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
serde_json::to_string(&compressed).map_err(|e| {
Error::new(
Status::GenericFailure,
format!("Serialization error: {}", e),
)
})
}
#[napi]
pub fn compress_with_level(
&self,
embedding: Float32Array,
level: CompressionLevelConfig,
) -> Result<String> {
let embedding_slice = embedding.as_ref();
let rust_level = level.to_rust()?;
let compressed = self
.inner
.compress_with_level(embedding_slice, &rust_level)
.map_err(|e| Error::new(Status::GenericFailure, format!("Compression error: {}", e)))?;
serde_json::to_string(&compressed).map_err(|e| {
Error::new(
Status::GenericFailure,
format!("Serialization error: {}", e),
)
})
}
#[napi]
pub fn decompress(&self, compressed_json: String) -> Result<Float32Array> {
let compressed: RustCompressedTensor =
serde_json::from_str(&compressed_json).map_err(|e| {
Error::new(
Status::GenericFailure,
format!("Deserialization error: {}", e),
)
})?;
let result = self.inner.decompress(&compressed).map_err(|e| {
Error::new(
Status::GenericFailure,
format!("Decompression error: {}", e),
)
})?;
Ok(Float32Array::new(result))
}
}
/// Result of `differentiable_search`, returned to JavaScript as a plain object.
#[napi(object)]
pub struct SearchResult {
    /// Indices returned by the Rust search, converted to `u32` for JS
    /// (presumably positions in `candidate_embeddings` — confirm against ruvector_gnn).
    pub indices: Vec<u32>,
    /// Weights returned by the Rust search, converted to `f64` for JS
    /// (presumably one per entry of `indices` — confirm against ruvector_gnn).
    pub weights: Vec<f64>,
}
/// Runs `ruvector_gnn`'s differentiable search of `query` over the candidates.
///
/// Each candidate typed array is copied into an owned `Vec<f32>`; `k` is
/// widened to `usize` and `temperature` narrowed to `f32`. The returned
/// indices and weights are converted to `u32` / `f64` for JavaScript.
#[napi]
pub fn differentiable_search(
    query: Float32Array,
    candidate_embeddings: Vec<Float32Array>,
    k: u32,
    temperature: f64,
) -> Result<SearchResult> {
    let candidates: Vec<Vec<f32>> = candidate_embeddings
        .iter()
        .map(|candidate| candidate.to_vec())
        .collect();

    let (raw_indices, raw_weights) = rust_differentiable_search(
        query.as_ref(),
        &candidates,
        k as usize,
        temperature as f32,
    );

    let indices: Vec<u32> = raw_indices.iter().map(|idx| *idx as u32).collect();
    let weights: Vec<f64> = raw_weights.iter().map(|w| *w as f64).collect();
    Ok(SearchResult { indices, weights })
}
/// Runs `ruvector_gnn`'s hierarchical forward pass from JavaScript inputs.
///
/// `layer_embeddings` is copied into owned nested vectors. Each entry of
/// `gnn_layers_json` is parsed as a serialized `RustRuvectorLayer` (the format
/// written by `RuvectorLayer::to_json`).
///
/// # Errors
/// Returns `GenericFailure` ("Layer deserialization error: ...") for the first
/// layer JSON string that fails to parse.
#[napi]
pub fn hierarchical_forward(
    query: Float32Array,
    layer_embeddings: Vec<Vec<Float32Array>>,
    gnn_layers_json: Vec<String>,
) -> Result<Float32Array> {
    let per_layer: Vec<Vec<Vec<f32>>> = layer_embeddings
        .iter()
        .map(|nodes| nodes.iter().map(|node| node.to_vec()).collect())
        .collect();

    let mut gnn_layers: Vec<RustRuvectorLayer> = Vec::with_capacity(gnn_layers_json.len());
    for layer_json in &gnn_layers_json {
        let layer: RustRuvectorLayer = serde_json::from_str(layer_json).map_err(|err| {
            Error::new(
                Status::GenericFailure,
                format!("Layer deserialization error: {}", err),
            )
        })?;
        gnn_layers.push(layer);
    }

    let output = rust_hierarchical_forward(query.as_ref(), &per_layer, &gnn_layers);
    Ok(Float32Array::new(output))
}
/// Maps an access frequency to the name of a compression level.
///
/// Tiers, using strict `>` comparisons: above 0.8 -> `"none"`, above 0.4 ->
/// `"half"`, above 0.1 -> `"pq8"`, above 0.01 -> `"pq4"`, anything else
/// (including NaN) -> `"binary"`.
#[napi]
pub fn get_compression_level(access_freq: f64) -> String {
    let level = match access_freq {
        f if f > 0.8 => "none",
        f if f > 0.4 => "half",
        f if f > 0.1 => "pq8",
        f if f > 0.01 => "pq4",
        _ => "binary",
    };
    level.to_string()
}
/// Returns a fixed confirmation message that the native bindings loaded.
#[napi]
pub fn init() -> String {
    String::from("Ruvector GNN Node.js bindings initialized")
}