#![cfg(feature = "napi")]
use napi::bindgen_prelude::*;
use napi_derive::napi;
use crate::{
SonaEngine as RustSonaEngine,
SonaConfig,
TrajectoryBuilder as RustTrajectoryBuilder,
LearnedPattern,
PatternType,
};
/// JavaScript-facing engine class exported through napi-rs.
///
/// Thin wrapper around the core engine (`crate::SonaEngine`, imported here as
/// `RustSonaEngine`); the exported methods convert JS values (f64 numbers, strings)
/// and forward to `inner`.
#[napi]
pub struct SonaEngine {
/// The wrapped core engine that does all of the actual work.
inner: RustSonaEngine,
}
#[napi]
impl SonaEngine {
#[napi(constructor)]
pub fn new(hidden_dim: u32) -> Self {
Self {
inner: RustSonaEngine::new(hidden_dim as usize),
}
}
#[napi(factory)]
pub fn with_config(config: JsSonaConfig) -> Self {
let rust_config = SonaConfig {
hidden_dim: config.hidden_dim as usize,
embedding_dim: config.embedding_dim.unwrap_or(config.hidden_dim) as usize,
micro_lora_rank: config.micro_lora_rank.unwrap_or(1) as usize,
base_lora_rank: config.base_lora_rank.unwrap_or(8) as usize,
micro_lora_lr: config.micro_lora_lr.unwrap_or(0.001) as f32,
base_lora_lr: config.base_lora_lr.unwrap_or(0.0001) as f32,
ewc_lambda: config.ewc_lambda.unwrap_or(1000.0) as f32,
pattern_clusters: config.pattern_clusters.unwrap_or(50) as usize,
trajectory_capacity: config.trajectory_capacity.unwrap_or(10000) as usize,
background_interval_ms: config.background_interval_ms.unwrap_or(3600000) as u64,
quality_threshold: config.quality_threshold.unwrap_or(0.5) as f32,
enable_simd: config.enable_simd.unwrap_or(true),
};
Self {
inner: RustSonaEngine::with_config(rust_config),
}
}
#[napi]
pub fn begin_trajectory(&self, query_embedding: Vec<f64>) -> TrajectoryBuilder {
let embedding: Vec<f32> = query_embedding.iter().map(|&x| x as f32).collect();
let builder = self.inner.begin_trajectory(embedding);
TrajectoryBuilder { inner: builder }
}
#[napi]
pub fn end_trajectory(&self, mut builder: TrajectoryBuilder, quality: f64) {
let trajectory = builder.inner.build(quality as f32);
self.inner.submit_trajectory(trajectory);
}
#[napi]
pub fn apply_micro_lora(&self, input: Vec<f64>) -> Vec<f64> {
let input_f32: Vec<f32> = input.iter().map(|&x| x as f32).collect();
let mut output = vec![0.0f32; input_f32.len()];
self.inner.apply_micro_lora(&input_f32, &mut output);
output.iter().map(|&x| x as f64).collect()
}
#[napi]
pub fn apply_base_lora(&self, layer_idx: u32, input: Vec<f64>) -> Vec<f64> {
let input_f32: Vec<f32> = input.iter().map(|&x| x as f32).collect();
let mut output = vec![0.0f32; input_f32.len()];
self.inner.apply_base_lora(layer_idx as usize, &input_f32, &mut output);
output.iter().map(|&x| x as f64).collect()
}
#[napi]
pub fn tick(&self) -> Option<String> {
self.inner.tick()
}
#[napi]
pub fn force_learn(&self) -> String {
self.inner.force_learn()
}
#[napi]
pub fn flush(&self) {
self.inner.flush();
}
#[napi]
pub fn find_patterns(&self, query_embedding: Vec<f64>, k: u32) -> Vec<JsLearnedPattern> {
let query: Vec<f32> = query_embedding.iter().map(|&x| x as f32).collect();
self.inner.find_patterns(&query, k as usize)
.into_iter()
.map(JsLearnedPattern::from)
.collect()
}
#[napi]
pub fn get_stats(&self) -> String {
serde_json::to_string(&self.inner.stats()).unwrap_or_else(|e| {
format!("{{\"error\": \"{}\"}}", e)
})
}
#[napi]
pub fn save_state(&self) -> String {
self.inner.coordinator().serialize_state()
}
#[napi]
pub fn set_enabled(&mut self, enabled: bool) {
self.inner.set_enabled(enabled);
}
#[napi]
pub fn is_enabled(&self) -> bool {
self.inner.is_enabled()
}
}
/// JavaScript-facing handle for a trajectory under construction.
///
/// Created by `SonaEngine::begin_trajectory` and handed back to
/// `SonaEngine::end_trajectory`; wraps the core builder (imported as
/// `RustTrajectoryBuilder`).
#[napi]
pub struct TrajectoryBuilder {
/// The wrapped core trajectory builder.
inner: RustTrajectoryBuilder,
}
#[napi]
impl TrajectoryBuilder {
    /// Appends one step to the trajectory.
    ///
    /// Both JS arrays are narrowed from `f64` to `f32`, then forwarded with the narrowed
    /// `reward` to the core builder's `add_step`.
    #[napi]
    pub fn add_step(&mut self, activations: Vec<f64>, attention_weights: Vec<f64>, reward: f64) {
        let narrow = |values: Vec<f64>| -> Vec<f32> {
            values.into_iter().map(|v| v as f32).collect()
        };
        let narrowed_activations = narrow(activations);
        let narrowed_attention = narrow(attention_weights);
        self.inner
            .add_step(narrowed_activations, narrowed_attention, reward as f32);
    }

    /// Records the model route on the core builder (`set_model_route`).
    #[napi]
    pub fn set_route(&mut self, route: String) {
        let route_ref = &route;
        self.inner.set_model_route(route_ref);
    }

    /// Adds a context identifier to the core builder (`add_context`).
    #[napi]
    pub fn add_context(&mut self, context_id: String) {
        let context_ref = &context_id;
        self.inner.add_context(context_ref);
    }
}
/// Plain JS configuration object accepted by `SonaEngine::with_config`.
///
/// Only `hidden_dim` is required. For every optional field, the default noted below is
/// the value `with_config` substitutes when the field is omitted. Floating-point fields
/// are narrowed to `f32` and integer fields are converted to the core `SonaConfig` types.
#[napi(object)]
pub struct JsSonaConfig {
/// Required; forwarded to `SonaConfig::hidden_dim`.
pub hidden_dim: u32,
/// Defaults to `hidden_dim` when unset.
pub embedding_dim: Option<u32>,
/// Defaults to 1.
pub micro_lora_rank: Option<u32>,
/// Defaults to 8.
pub base_lora_rank: Option<u32>,
/// Defaults to 0.001 (narrowed to f32).
pub micro_lora_lr: Option<f64>,
/// Defaults to 0.0001 (narrowed to f32).
pub base_lora_lr: Option<f64>,
/// Defaults to 1000.0 (narrowed to f32).
pub ewc_lambda: Option<f64>,
/// Defaults to 50.
pub pattern_clusters: Option<u32>,
/// Defaults to 10000.
pub trajectory_capacity: Option<u32>,
/// Defaults to 3600000 (one hour in milliseconds); should be non-negative since the
/// core config stores it as `u64`.
pub background_interval_ms: Option<i64>,
/// Defaults to 0.5 (narrowed to f32).
pub quality_threshold: Option<f64>,
/// Defaults to `true`.
pub enable_simd: Option<bool>,
}
/// JS representation of a core `LearnedPattern`, built by its `From` impl.
///
/// Values that are not JS-number friendly (`id`, timestamps, pattern type) are carried
/// as strings.
#[napi(object)]
pub struct JsLearnedPattern {
/// `LearnedPattern::id` rendered with `to_string()`.
pub id: String,
/// Centroid vector, each component cast to `f64`.
pub centroid: Vec<f64>,
/// `LearnedPattern::cluster_size` cast to `u32` (an `as` cast, so large values truncate).
pub cluster_size: u32,
/// `LearnedPattern::total_weight` cast to `f64`.
pub total_weight: f64,
/// `LearnedPattern::avg_quality` cast to `f64`.
pub avg_quality: f64,
/// `LearnedPattern::created_at` rendered with `to_string()`.
pub created_at: String,
/// `LearnedPattern::last_accessed` rendered with `to_string()`.
pub last_accessed: String,
/// Copied unchanged from `LearnedPattern::access_count`.
pub access_count: u32,
/// `Debug` formatting of `LearnedPattern::pattern_type`.
pub pattern_type: String,
}
impl From<LearnedPattern> for JsLearnedPattern {
fn from(pattern: LearnedPattern) -> Self {
Self {
id: pattern.id.to_string(),
centroid: pattern.centroid.iter().map(|&x| x as f64).collect(),
cluster_size: pattern.cluster_size as u32,
total_weight: pattern.total_weight as f64,
avg_quality: pattern.avg_quality as f64,
created_at: pattern.created_at.to_string(),
last_accessed: pattern.last_accessed.to_string(),
access_count: pattern.access_count,
pattern_type: format!("{:?}", pattern.pattern_type),
}
}
}
/// JS-facing mirror of the core `PatternType` enum.
///
/// Converted to the core type by the `From<JsPatternType> for PatternType` impl in this
/// module.
/// NOTE(review): napi-rs presumably exposes these as numeric enum values in declaration
/// order — confirm before reordering or inserting variants.
#[napi]
pub enum JsPatternType {
General,
Reasoning,
Factual,
Creative,
CodeGen,
Conversational,
}
impl From<JsPatternType> for PatternType {
    /// Maps each JS pattern-type variant one-to-one onto the core variant of the same name.
    fn from(value: JsPatternType) -> Self {
        match value {
            JsPatternType::General => Self::General,
            JsPatternType::Reasoning => Self::Reasoning,
            JsPatternType::Factual => Self::Factual,
            JsPatternType::Creative => Self::Creative,
            JsPatternType::CodeGen => Self::CodeGen,
            JsPatternType::Conversational => Self::Conversational,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed engine reports itself as enabled.
    #[test]
    fn test_napi_engine_creation() {
        let engine = SonaEngine::new(256);
        assert!(engine.is_enabled());
    }

    /// A trajectory can be started, given a step, and submitted without panicking.
    #[test]
    fn test_napi_trajectory() {
        let engine = SonaEngine::new(64);
        let mut builder = engine.begin_trajectory(vec![0.1; 64]);
        builder.add_step(vec![0.5; 64], vec![0.4; 32], 0.8);
        // `end_trajectory` takes the builder by value, so it is moved here; passing
        // `&builder` does not type-check against that signature.
        engine.end_trajectory(builder, 0.85);
    }
}