#![cfg(feature = "wasm")]
use crate::{LearningSignal, SonaConfig, SonaEngine};
use parking_lot::RwLock;
use std::sync::Arc;
use wasm_bindgen::prelude::*;
/// JavaScript-facing handle to a [`SonaEngine`].
///
/// The engine is held behind `Arc<RwLock<..>>`. In this file every exported method
/// takes the read lock, except `setEnabled`, which takes the write lock.
#[wasm_bindgen]
pub struct WasmSonaEngine {
/// The wrapped engine, shared behind a `parking_lot` read/write lock.
inner: Arc<RwLock<SonaEngine>>,
}
#[wasm_bindgen]
impl WasmSonaEngine {
    /// Creates an engine via `SonaEngine::new(hidden_dim)`.
    ///
    /// Installs the console panic hook first when the `console_error_panic_hook`
    /// feature is enabled. Currently never returns `Err`.
    #[wasm_bindgen(constructor)]
    pub fn new(hidden_dim: usize) -> Result<WasmSonaEngine, JsValue> {
        #[cfg(feature = "console_error_panic_hook")]
        console_error_panic_hook::set_once();
        Ok(Self {
            inner: Arc::new(RwLock::new(SonaEngine::new(hidden_dim))),
        })
    }

    /// Creates an engine from a serialized [`SonaConfig`].
    ///
    /// # Errors
    /// Returns the error produced by `serde_wasm_bindgen::from_value` when `config`
    /// cannot be deserialized. With the JSON-string shim defined in this file, that
    /// means `config` must be a JSON string.
    #[wasm_bindgen(js_name = withConfig)]
    pub fn with_config(config: JsValue) -> Result<WasmSonaEngine, JsValue> {
        #[cfg(feature = "console_error_panic_hook")]
        console_error_panic_hook::set_once();
        let config: SonaConfig = serde_wasm_bindgen::from_value(config)?;
        Ok(Self {
            inner: Arc::new(RwLock::new(SonaEngine::with_config(config))),
        })
    }

    /// Begins a trajectory for `query_embedding` and returns a new trajectory id.
    ///
    /// Ids come from one counter shared by every engine instance and start at 1.
    ///
    /// NOTE(review): the builder returned by `begin_trajectory` is dropped when this
    /// call returns and is never associated with the returned id. `recordStep` and
    /// `endTrajectory` therefore cannot reach it; wiring that up needs an id -> builder map.
    #[wasm_bindgen(js_name = startTrajectory)]
    pub fn start_trajectory(&self, query_embedding: Vec<f32>) -> u64 {
        use std::sync::atomic::{AtomicU64, Ordering};
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);

        let engine = self.inner.read();
        // Keep the builder bound (not `let _ = ...`) so it still lives until the end of
        // the call. The underscore only silences the `unused_variables` lint.
        let _builder = engine.begin_trajectory(query_embedding);
        // Relaxed ordering is sufficient: the counter only has to hand out distinct values.
        NEXT_ID.fetch_add(1, Ordering::Relaxed)
    }

    /// Logs a trajectory step to the browser console.
    ///
    /// NOTE(review): stub. Nothing is forwarded to the engine.
    #[wasm_bindgen(js_name = recordStep)]
    pub fn record_step(&self, trajectory_id: u64, node_id: u32, score: f32, latency_us: u64) {
        web_sys::console::log_1(
            &format!(
                "Recording step: traj={}, node={}, score={}, latency={}us",
                trajectory_id, node_id, score, latency_us
            )
            .into(),
        );
    }

    /// Logs the end of a trajectory to the browser console.
    ///
    /// NOTE(review): stub. Nothing is forwarded to the engine.
    #[wasm_bindgen(js_name = endTrajectory)]
    pub fn end_trajectory(&self, trajectory_id: u64, final_score: f32) {
        web_sys::console::log_1(
            &format!(
                "Ending trajectory: traj={}, score={}",
                trajectory_id, final_score
            )
            .into(),
        );
    }

    /// Computes a signed reward and logs it to the browser console.
    ///
    /// The reward is `quality` on success and `-quality` on failure.
    ///
    /// NOTE(review): stub. The reward is only logged and never reaches the engine.
    #[wasm_bindgen(js_name = learnFromFeedback)]
    pub fn learn_from_feedback(&self, success: bool, latency_ms: f32, quality: f32) {
        let reward = if success { quality } else { -quality };
        web_sys::console::log_1(
            &format!(
                "Feedback: success={}, latency={}ms, quality={}, reward={}",
                success, latency_ms, quality, reward
            )
            .into(),
        );
    }

    /// Runs `apply_micro_lora` on `input` and returns the output buffer.
    ///
    /// The output has the same length as `input` and is zero-initialised before the call.
    #[wasm_bindgen(js_name = applyLora)]
    pub fn apply_lora(&self, input: Vec<f32>) -> Vec<f32> {
        let mut output = vec![0.0; input.len()];
        let engine = self.inner.read();
        engine.apply_micro_lora(&input, &mut output);
        output
    }

    /// Runs `apply_base_lora` for `layer_idx` on `input` and returns the output buffer.
    ///
    /// The output has the same length as `input` and is zero-initialised before the call.
    #[wasm_bindgen(js_name = applyLoraLayer)]
    pub fn apply_lora_layer(&self, layer_idx: usize, input: Vec<f32>) -> Vec<f32> {
        let mut output = vec![0.0; input.len()];
        let engine = self.inner.read();
        engine.apply_base_lora(layer_idx, &input, &mut output);
        output
    }

    /// Calls `SonaEngine::flush` under the read lock.
    #[wasm_bindgen(js_name = runInstantCycle)]
    pub fn run_instant_cycle(&self) {
        let engine = self.inner.read();
        engine.flush();
    }

    /// Calls `SonaEngine::tick` and reports whether it returned `Some`.
    #[wasm_bindgen]
    pub fn tick(&self) -> bool {
        let engine = self.inner.read();
        engine.tick().is_some()
    }

    /// Calls `SonaEngine::force_learn` and returns its string result.
    #[wasm_bindgen(js_name = forceLearn)]
    pub fn force_learn(&self) -> String {
        let engine = self.inner.read();
        engine.force_learn()
    }

    /// Returns the serialized engine stats.
    ///
    /// Returns `null` if serialization fails; the error is logged to the console.
    #[wasm_bindgen(js_name = getStats)]
    pub fn get_stats(&self) -> JsValue {
        let engine = self.inner.read();
        let stats = engine.stats();
        serde_wasm_bindgen::to_value(&stats).unwrap_or_else(|err| {
            web_sys::console::error_2(&"getStats: failed to serialize stats:".into(), &err);
            JsValue::NULL
        })
    }

    /// Enables or disables the engine. This is the only method that takes the write lock.
    #[wasm_bindgen(js_name = setEnabled)]
    pub fn set_enabled(&self, enabled: bool) {
        let mut engine = self.inner.write();
        engine.set_enabled(enabled);
    }

    /// Reports `SonaEngine::is_enabled`.
    #[wasm_bindgen(js_name = isEnabled)]
    pub fn is_enabled(&self) -> bool {
        let engine = self.inner.read();
        engine.is_enabled()
    }

    /// Returns the serialized engine configuration.
    ///
    /// Returns `null` if serialization fails; the error is logged to the console.
    #[wasm_bindgen(js_name = getConfig)]
    pub fn get_config(&self) -> JsValue {
        let engine = self.inner.read();
        let config = engine.config();
        serde_wasm_bindgen::to_value(config).unwrap_or_else(|err| {
            web_sys::console::error_2(&"getConfig: failed to serialize config:".into(), &err);
            JsValue::NULL
        })
    }

    /// Returns the serialized result of `SonaEngine::find_patterns(query_embedding, k)`.
    ///
    /// Returns `null` if serialization fails; the error is logged to the console.
    #[wasm_bindgen(js_name = findPatterns)]
    pub fn find_patterns(&self, query_embedding: Vec<f32>, k: usize) -> JsValue {
        let engine = self.inner.read();
        let patterns = engine.find_patterns(&query_embedding, k);
        serde_wasm_bindgen::to_value(&patterns).unwrap_or_else(|err| {
            web_sys::console::error_2(&"findPatterns: failed to serialize patterns:".into(), &err);
            JsValue::NULL
        })
    }
}
#[wasm_bindgen(start)]
pub fn wasm_init() {
#[cfg(feature = "console_error_panic_hook")]
console_error_panic_hook::set_once();
web_sys::console::log_1(&"SONA WASM module initialized".into());
}
use crate::training::{
EphemeralAgent as RustEphemeralAgent, FederatedCoordinator as RustFederatedCoordinator,
FederatedTopology,
};
/// JavaScript-facing wrapper around `crate::training::EphemeralAgent`.
///
/// The agent is owned directly (no lock). Mutating methods take `&mut self`.
#[wasm_bindgen]
pub struct WasmEphemeralAgent {
/// The wrapped training agent.
inner: RustEphemeralAgent,
}
#[wasm_bindgen]
impl WasmEphemeralAgent {
#[wasm_bindgen(constructor)]
pub fn new(agent_id: &str) -> Result<WasmEphemeralAgent, JsValue> {
let config = SonaConfig::for_ephemeral();
Ok(Self {
inner: RustEphemeralAgent::new(agent_id, config),
})
}
#[wasm_bindgen(js_name = withConfig)]
pub fn with_config(agent_id: &str, config: JsValue) -> Result<WasmEphemeralAgent, JsValue> {
let config: SonaConfig = serde_wasm_bindgen::from_value(config)?;
Ok(Self {
inner: RustEphemeralAgent::new(agent_id, config),
})
}
#[wasm_bindgen(js_name = processTask)]
pub fn process_task(&mut self, embedding: Vec<f32>, quality: f32) {
self.inner.process_task(embedding, quality);
}
#[wasm_bindgen(js_name = processTaskWithRoute)]
pub fn process_task_with_route(&mut self, embedding: Vec<f32>, quality: f32, route: &str) {
self.inner
.process_task_with_route(embedding, quality, route);
}
#[wasm_bindgen(js_name = exportState)]
pub fn export_state(&self) -> JsValue {
let export = self.inner.export_state();
serde_wasm_bindgen::to_value(&export).unwrap_or(JsValue::NULL)
}
#[wasm_bindgen(js_name = getStats)]
pub fn get_stats(&self) -> JsValue {
let stats = self.inner.stats();
serde_wasm_bindgen::to_value(&stats).unwrap_or(JsValue::NULL)
}
#[wasm_bindgen(js_name = trajectoryCount)]
pub fn trajectory_count(&self) -> usize {
self.inner.trajectory_count()
}
#[wasm_bindgen(js_name = averageQuality)]
pub fn average_quality(&self) -> f32 {
self.inner.average_quality()
}
#[wasm_bindgen(js_name = uptimeSeconds)]
pub fn uptime_seconds(&self) -> u64 {
self.inner.uptime_seconds()
}
#[wasm_bindgen]
pub fn clear(&mut self) {
self.inner.clear();
}
#[wasm_bindgen(js_name = forceLearn)]
pub fn force_learn(&self) -> String {
self.inner.force_learn()
}
#[wasm_bindgen(js_name = getPatterns)]
pub fn get_patterns(&self) -> JsValue {
let patterns = self.inner.get_patterns();
serde_wasm_bindgen::to_value(&patterns).unwrap_or(JsValue::NULL)
}
}
/// JavaScript-facing wrapper around `crate::training::FederatedCoordinator`.
///
/// The coordinator is owned directly (no lock). Mutating methods take `&mut self`.
#[wasm_bindgen]
pub struct WasmFederatedCoordinator {
/// The wrapped federated coordinator.
inner: RustFederatedCoordinator,
}
#[wasm_bindgen]
impl WasmFederatedCoordinator {
#[wasm_bindgen(constructor)]
pub fn new(coordinator_id: &str) -> Result<WasmFederatedCoordinator, JsValue> {
let config = SonaConfig::for_coordinator();
Ok(Self {
inner: RustFederatedCoordinator::new(coordinator_id, config),
})
}
#[wasm_bindgen(js_name = withConfig)]
pub fn with_config(
coordinator_id: &str,
config: JsValue,
) -> Result<WasmFederatedCoordinator, JsValue> {
let config: SonaConfig = serde_wasm_bindgen::from_value(config)?;
Ok(Self {
inner: RustFederatedCoordinator::new(coordinator_id, config),
})
}
#[wasm_bindgen(js_name = setQualityThreshold)]
pub fn set_quality_threshold(&mut self, threshold: f32) {
self.inner.set_quality_threshold(threshold);
}
#[wasm_bindgen]
pub fn aggregate(&mut self, agent_export: JsValue) -> JsValue {
use crate::training::AgentExport;
match serde_wasm_bindgen::from_value::<AgentExport>(agent_export) {
Ok(export) => {
let result = self.inner.aggregate(export);
serde_wasm_bindgen::to_value(&result).unwrap_or(JsValue::NULL)
}
Err(e) => {
web_sys::console::error_1(&format!("Failed to parse agent export: {:?}", e).into());
JsValue::NULL
}
}
}
#[wasm_bindgen]
pub fn consolidate(&self) -> String {
self.inner.consolidate()
}
#[wasm_bindgen(js_name = getStats)]
pub fn get_stats(&self) -> JsValue {
let stats = self.inner.stats();
serde_wasm_bindgen::to_value(&stats).unwrap_or(JsValue::NULL)
}
#[wasm_bindgen(js_name = agentCount)]
pub fn agent_count(&self) -> usize {
self.inner.agent_count()
}
#[wasm_bindgen(js_name = totalTrajectories)]
pub fn total_trajectories(&self) -> usize {
self.inner.total_trajectories()
}
#[wasm_bindgen(js_name = getPatterns)]
pub fn get_patterns(&self) -> JsValue {
let patterns = self.inner.get_all_patterns();
serde_wasm_bindgen::to_value(&patterns).unwrap_or(JsValue::NULL)
}
#[wasm_bindgen(js_name = findPatterns)]
pub fn find_patterns(&self, query_embedding: Vec<f32>, k: usize) -> JsValue {
let patterns = self.inner.find_patterns(&query_embedding, k);
serde_wasm_bindgen::to_value(&patterns).unwrap_or(JsValue::NULL)
}
#[wasm_bindgen(js_name = applyLora)]
pub fn apply_lora(&self, input: Vec<f32>) -> Vec<f32> {
self.inner.apply_lora(&input)
}
#[wasm_bindgen]
pub fn clear(&mut self) {
self.inner.clear();
}
}
#[cfg(feature = "wasm")]
mod serde_wasm_bindgen {
    //! Minimal stand-in for the `serde_wasm_bindgen` API, built on `serde_json`.
    //!
    //! Values cross the JS boundary as JSON *strings*, not as structured JS objects:
    //! `to_value` yields a JS string, and `from_value` accepts only a JS string.
    use super::*;
    use serde::Serialize;

    /// Converts a `serde_json` error into a JS string holding its message.
    fn json_err_to_js(err: serde_json::Error) -> JsValue {
        JsValue::from_str(&err.to_string())
    }

    /// Serializes `value` to JSON and wraps the text in a JS string.
    ///
    /// # Errors
    /// Returns the `serde_json` error message as a JS string.
    pub fn to_value<T: Serialize>(value: &T) -> Result<JsValue, JsValue> {
        let json = serde_json::to_string(value).map_err(json_err_to_js)?;
        Ok(JsValue::from_str(&json))
    }

    /// Parses a JS string containing JSON into `T`.
    ///
    /// # Errors
    /// Returns `"Expected JSON string"` when `value` is not a JS string. Otherwise
    /// returns the `serde_json` error message as a JS string.
    pub fn from_value<T: serde::de::DeserializeOwned>(value: JsValue) -> Result<T, JsValue> {
        match value.as_string() {
            None => Err(JsValue::from_str("Expected JSON string")),
            Some(json) => serde_json::from_str(&json).map_err(json_err_to_js),
        }
    }
}