#[cfg(not(target_arch = "wasm32"))]
mod native {
use ruvector_gnn::compress::TensorCompress;
use ruvector_gnn::ewc::ElasticWeightConsolidation;
use ruvector_gnn::replay::ReplayBuffer;
/// Minimum number of recorded feedback entries before consolidation is
/// considered worthwhile (see `SearchLearner::has_sufficient_data`).
const MIN_FEEDBACK_ENTRIES: usize = 32;
/// Online learner for search-result relevance feedback.
///
/// Buffers (query ‖ result) embedding pairs in a replay buffer and uses
/// elastic weight consolidation (EWC) to anchor the current weights when
/// new feedback is consolidated.
pub struct SearchLearner {
// Bounded store of past feedback samples, replayed during consolidation.
replay_buffer: ReplayBuffer,
// EWC state; Fisher information is recomputed from replayed gradients
// (see `consolidate`).
ewc: ElasticWeightConsolidation,
// Per-dimension scoring weights; length equals the `embedding_dim`
// passed to `new`, initialized to 1.0.
weights: Vec<f32>,
}
impl SearchLearner {
    /// Creates a learner with uniform unit weights of length `embedding_dim`
    /// and a replay buffer holding at most `replay_capacity` samples.
    ///
    /// The EWC regularization strength is fixed at 100.0.
    pub fn new(embedding_dim: usize, replay_capacity: usize) -> Self {
        Self {
            replay_buffer: ReplayBuffer::new(replay_capacity),
            ewc: ElasticWeightConsolidation::new(100.0),
            weights: vec![1.0; embedding_dim],
        }
    }

    /// Records one relevance judgement.
    ///
    /// The query and result embeddings are concatenated (query first) and
    /// stored in the replay buffer with label 1 for relevant, 0 otherwise.
    /// `consolidate` assumes each half is `weights.len()` long, so callers
    /// should pass embeddings of the dimension given to `new`.
    pub fn record_feedback(
        &mut self,
        query_embedding: Vec<f32>,
        result_embedding: Vec<f32>,
        relevant: bool,
    ) {
        let mut combined = query_embedding;
        combined.extend_from_slice(&result_embedding);
        // Clippy `bool_to_int_with_if`: prefer the `From<bool>` conversion
        // over an if/else producing 1/0.
        let positive_id = usize::from(relevant);
        self.replay_buffer.add(&combined, &[positive_id]);
    }

    /// Number of feedback samples currently buffered.
    pub fn replay_buffer_len(&self) -> usize {
        self.replay_buffer.len()
    }

    /// True once at least `MIN_FEEDBACK_ENTRIES` samples are buffered.
    pub fn has_sufficient_data(&self) -> bool {
        self.replay_buffer.len() >= MIN_FEEDBACK_ENTRIES
    }

    /// Recomputes the EWC Fisher information from up to 64 replayed samples
    /// and anchors the current weights.
    ///
    /// Samples whose stored vector is shorter than `2 * weights.len()` are
    /// skipped; if the buffer is empty or no sample qualifies, the EWC state
    /// is left untouched.
    pub fn consolidate(&mut self) {
        if self.replay_buffer.is_empty() {
            return;
        }
        let samples = self.replay_buffer.sample(self.replay_buffer.len().min(64));
        let dim = self.weights.len();
        // Proxy gradient per sample: element-wise difference between the
        // query half and the result half of the stored (query ‖ result)
        // vector.
        let gradients: Vec<Vec<f32>> = samples
            .iter()
            .filter_map(|entry| {
                if entry.query.len() >= dim * 2 {
                    let query_part = &entry.query[..dim];
                    let result_part = &entry.query[dim..dim * 2];
                    Some(
                        query_part
                            .iter()
                            .zip(result_part.iter())
                            .map(|(q, r)| q - r)
                            .collect(),
                    )
                } else {
                    None
                }
            })
            .collect();
        if gradients.is_empty() {
            return;
        }
        let grad_refs: Vec<&[f32]> = gradients.iter().map(|g| g.as_slice()).collect();
        self.ewc.compute_fisher(&grad_refs, grad_refs.len());
        self.ewc.consolidate(&self.weights);
    }

    /// EWC penalty of the current weights against the last consolidation
    /// anchor.
    pub fn ewc_penalty(&self) -> f32 {
        self.ewc.penalty(&self.weights)
    }
}
/// Age-aware lossy codec for stored embeddings: older embeddings are
/// compressed more aggressively (see `quantize_by_age`).
pub struct EmbeddingQuantizer {
// Backend compressor; its compression level is driven by an
// access-frequency score derived from the embedding's age.
compressor: TensorCompress,
}
impl Default for EmbeddingQuantizer {
fn default() -> Self {
Self::new()
}
}
impl EmbeddingQuantizer {
    /// Creates a quantizer with a fresh `TensorCompress` backend.
    pub fn new() -> Self {
        Self {
            compressor: TensorCompress::new(),
        }
    }

    /// Compresses `embedding` with a quality level derived from `age_hours`
    /// (older ⇒ lower synthetic access frequency ⇒ heavier compression).
    ///
    /// Returns the JSON-serialized compressed tensor. If compression or
    /// serialization fails, falls back to the raw little-endian `f32` bytes
    /// so the embedding is never lost; `dequantize` understands both forms.
    pub fn quantize_by_age(&self, embedding: &[f32], age_hours: u64) -> Vec<u8> {
        let access_freq = Self::age_to_freq(age_hours);
        self.compressor
            .compress(embedding, access_freq)
            .ok()
            .and_then(|compressed| serde_json::to_vec(&compressed).ok())
            .unwrap_or_else(|| Self::raw_le_bytes(embedding))
    }

    /// Restores an embedding produced by `quantize_by_age`.
    ///
    /// Tries the compressed (JSON) representation first, then the raw
    /// little-endian fallback; returns an all-zero vector of `original_dim`
    /// when neither representation matches.
    pub fn dequantize(&self, data: &[u8], original_dim: usize) -> Vec<f32> {
        if let Ok(compressed) =
            serde_json::from_slice::<ruvector_gnn::compress::CompressedTensor>(data)
        {
            if let Ok(decompressed) = self.compressor.decompress(&compressed) {
                if decompressed.len() == original_dim {
                    return decompressed;
                }
            }
        }
        if data.len() == original_dim * 4 {
            return data
                .chunks_exact(4)
                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                .collect();
        }
        vec![0.0; original_dim]
    }

    /// Uncompressed fallback encoding: each `f32` as little-endian bytes.
    /// Shared by the two failure paths in `quantize_by_age`.
    fn raw_le_bytes(embedding: &[f32]) -> Vec<u8> {
        embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
    }

    /// Maps an embedding's age to the synthetic access-frequency score used
    /// as the compression knob: fresher ⇒ higher frequency ⇒ lighter
    /// compression.
    fn age_to_freq(age_hours: u64) -> f32 {
        match age_hours {
            0 => 1.0,
            1..=24 => 0.5,
            25..=168 => 0.2,
            _ => 0.005,
        }
    }
}
}
#[cfg(target_arch = "wasm32")]
mod wasm_stub {
    /// Threshold for `has_sufficient_data`; kept equal to the native
    /// implementation's `MIN_FEEDBACK_ENTRIES` (previously a magic `32`).
    const MIN_FEEDBACK_ENTRIES: usize = 32;

    /// WASM stand-in for the native `SearchLearner`: it only counts
    /// feedback entries and performs no learning.
    pub struct SearchLearner {
        // Number of feedback entries recorded so far.
        buffer_len: usize,
    }

    impl SearchLearner {
        /// Matches the native constructor signature; both arguments are
        /// ignored in the stub.
        pub fn new(_embedding_dim: usize, _replay_capacity: usize) -> Self {
            Self { buffer_len: 0 }
        }

        /// Counts the feedback entry; embeddings and label are discarded.
        pub fn record_feedback(
            &mut self,
            _query_embedding: Vec<f32>,
            _result_embedding: Vec<f32>,
            _relevant: bool,
        ) {
            self.buffer_len += 1;
        }

        /// Number of feedback entries recorded so far.
        pub fn replay_buffer_len(&self) -> usize {
            self.buffer_len
        }

        /// True once `MIN_FEEDBACK_ENTRIES` entries have been recorded.
        pub fn has_sufficient_data(&self) -> bool {
            self.buffer_len >= MIN_FEEDBACK_ENTRIES
        }

        /// No-op: the stub performs no consolidation.
        pub fn consolidate(&mut self) {}

        /// Always 0.0: the stub keeps no EWC state.
        pub fn ewc_penalty(&self) -> f32 {
            0.0
        }
    }

    /// WASM stand-in for the native `EmbeddingQuantizer`: stores embeddings
    /// uncompressed as little-endian `f32` bytes regardless of age.
    pub struct EmbeddingQuantizer;

    // Parity with the native type, which also implements `Default`, so
    // cross-target callers can use `EmbeddingQuantizer::default()`.
    impl Default for EmbeddingQuantizer {
        fn default() -> Self {
            Self::new()
        }
    }

    impl EmbeddingQuantizer {
        /// Creates the (stateless) stub quantizer.
        pub fn new() -> Self {
            Self
        }

        /// Encodes each `f32` as little-endian bytes; `age_hours` is
        /// ignored (no compression in the stub).
        pub fn quantize_by_age(&self, embedding: &[f32], _age_hours: u64) -> Vec<u8> {
            embedding.iter().flat_map(|f| f.to_le_bytes()).collect()
        }

        /// Inverse of `quantize_by_age`; returns an all-zero vector when
        /// `data` does not hold exactly `original_dim` little-endian `f32`s.
        pub fn dequantize(&self, data: &[u8], original_dim: usize) -> Vec<f32> {
            if data.len() == original_dim * 4 {
                data.chunks_exact(4)
                    .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                    .collect()
            } else {
                vec![0.0; original_dim]
            }
        }
    }
}
#[cfg(not(target_arch = "wasm32"))]
pub use native::{EmbeddingQuantizer, SearchLearner};
#[cfg(target_arch = "wasm32")]
pub use wasm_stub::{EmbeddingQuantizer, SearchLearner};