use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
use super::helpers;
/// Module start hook registered through `#[wasm_bindgen(start)]`.
///
/// When the crate is built with the `console_error_panic_hook` feature, this
/// installs that crate's panic hook (at most once). Without the feature the
/// function body is empty.
#[wasm_bindgen(start)]
pub fn init() {
    // The statement below is compiled out entirely when the feature is off.
    #[cfg(feature = "console_error_panic_hook")]
    console_error_panic_hook::set_once();
}
/// A single stored embedding plus the caller-supplied data recorded with it.
struct WasmEmbedding {
    /// Caller-supplied identifier; stored as given.
    id: String,
    /// Embedding components, copied from the slice passed at insert time.
    vector: Vec<f32>,
    /// Caller-supplied metadata string; stored verbatim and never parsed here.
    metadata: String,
    /// Caller-supplied timestamp. Units are not fixed by this file.
    timestamp: f64,
}
/// One entry of a search result, serialized to the JavaScript side.
///
/// Field order is kept stable because it determines the order in which keys
/// are emitted during serialization.
#[derive(Serialize, Deserialize)]
struct SearchHit {
    /// Identifier of the matching stored embedding.
    id: String,
    /// Similarity score of the match, widened from `f32` to `f64`.
    score: f64,
    /// Metadata string stored with the matching embedding.
    metadata: String,
    /// Timestamp stored with the matching embedding.
    timestamp: f64,
}
/// In-memory embedding store exported to JavaScript via `wasm_bindgen`.
///
/// Embeddings are kept in insertion order in a flat `Vec`; searches scan all
/// of them.
#[wasm_bindgen]
pub struct OsPipeWasm {
    /// Number of components every stored and queried vector must have.
    dimension: usize,
    /// All inserted embeddings, in insertion order.
    embeddings: Vec<WasmEmbedding>,
}
#[wasm_bindgen]
impl OsPipeWasm {
#[wasm_bindgen(constructor)]
pub fn new(dimension: usize) -> Self {
Self {
dimension,
embeddings: Vec::new(),
}
}
pub fn insert(
&mut self,
id: &str,
embedding: &[f32],
metadata: &str,
timestamp: f64,
) -> Result<(), JsValue> {
if embedding.len() != self.dimension {
return Err(JsValue::from_str(&format!(
"Embedding dimension mismatch: expected {}, got {}",
self.dimension,
embedding.len()
)));
}
self.embeddings.push(WasmEmbedding {
id: id.to_string(),
vector: embedding.to_vec(),
metadata: metadata.to_string(),
timestamp,
});
Ok(())
}
pub fn search(
&self,
query_embedding: &[f32],
k: usize,
) -> Result<JsValue, JsValue> {
if query_embedding.len() != self.dimension {
return Err(JsValue::from_str(&format!(
"Query dimension mismatch: expected {}, got {}",
self.dimension,
query_embedding.len()
)));
}
let mut scored: Vec<(usize, f32)> = self
.embeddings
.iter()
.enumerate()
.map(|(i, e)| (i, helpers::cosine_similarity(query_embedding, &e.vector)))
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let hits: Vec<SearchHit> = scored
.into_iter()
.take(k)
.map(|(i, score)| {
let e = &self.embeddings[i];
SearchHit {
id: e.id.clone(),
score: score as f64,
metadata: e.metadata.clone(),
timestamp: e.timestamp,
}
})
.collect();
serde_wasm_bindgen::to_value(&hits).map_err(|e| JsValue::from_str(&e.to_string()))
}
pub fn search_filtered(
&self,
query_embedding: &[f32],
k: usize,
start_time: f64,
end_time: f64,
) -> Result<JsValue, JsValue> {
if query_embedding.len() != self.dimension {
return Err(JsValue::from_str(&format!(
"Query dimension mismatch: expected {}, got {}",
self.dimension,
query_embedding.len()
)));
}
let mut scored: Vec<(usize, f32)> = self
.embeddings
.iter()
.enumerate()
.filter(|(_, e)| e.timestamp >= start_time && e.timestamp <= end_time)
.map(|(i, e)| (i, helpers::cosine_similarity(query_embedding, &e.vector)))
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let hits: Vec<SearchHit> = scored
.into_iter()
.take(k)
.map(|(i, score)| {
let e = &self.embeddings[i];
SearchHit {
id: e.id.clone(),
score: score as f64,
metadata: e.metadata.clone(),
timestamp: e.timestamp,
}
})
.collect();
serde_wasm_bindgen::to_value(&hits).map_err(|e| JsValue::from_str(&e.to_string()))
}
pub fn is_duplicate(&self, embedding: &[f32], threshold: f32) -> bool {
self.embeddings
.iter()
.any(|e| helpers::cosine_similarity(embedding, &e.vector) >= threshold)
}
pub fn len(&self) -> usize {
self.embeddings.len()
}
pub fn is_empty(&self) -> bool {
self.embeddings.is_empty()
}
pub fn stats(&self) -> String {
serde_json::json!({
"dimension": self.dimension,
"total_embeddings": self.embeddings.len(),
"memory_estimate_bytes": self.embeddings.len() * (self.dimension * 4 + 128),
})
.to_string()
}
pub fn embed_text(&self, text: &str) -> Vec<f32> {
helpers::hash_embed(text, self.dimension)
}
pub fn batch_embed(&self, texts: JsValue) -> Result<JsValue, JsValue> {
let text_list: Vec<String> = serde_wasm_bindgen::from_value(texts)
.map_err(|e| JsValue::from_str(&format!("Failed to deserialize texts: {e}")))?;
let results: Vec<Vec<f32>> = text_list
.iter()
.map(|t| helpers::hash_embed(t, self.dimension))
.collect();
serde_wasm_bindgen::to_value(&results)
.map_err(|e| JsValue::from_str(&e.to_string()))
}
pub fn safety_check(&self, content: &str) -> String {
helpers::safety_classify(content).to_string()
}
pub fn route_query(&self, query: &str) -> String {
helpers::route_query(query).to_string()
}
}