#![deny(clippy::all)]
mod runtime;
use napi::bindgen_prelude::*;
use napi_derive::napi;
use std::sync::{Arc, Mutex};
use stateset_nsr::nsr::{
GroundedInput, NSRConfig, NSRMachine, NSRMachineBuilder,
Program, Primitive, SemanticValue, TrainingExample,
};
use stateset_nsr::nsr::machine::presets;
use crate::runtime::get_runtime;
/// JavaScript handle for a [`GroundedInput`], the raw input value handed to an NSR machine.
///
/// Exposed to JS as `GroundedInput`; instances are created through the static factories
/// (`text`, `number`, `image`, `embedding`, `nil`).
#[napi(js_name = "GroundedInput")]
pub struct JsGroundedInput {
    inner: GroundedInput,
}

#[napi]
impl JsGroundedInput {
    /// Wraps a text value.
    #[napi(factory)]
    pub fn text(s: String) -> Self {
        let inner = GroundedInput::Text(s);
        Self { inner }
    }

    /// Wraps a numeric value.
    #[napi(factory)]
    pub fn number(n: f64) -> Self {
        let inner = GroundedInput::Number(n);
        Self { inner }
    }

    /// Wraps image data together with its dimensions.
    ///
    /// Values arrive from JS as `f64` and are narrowed to `f32`. `channels` defaults to 3 when
    /// omitted. The length of `data` is not checked against `width * height * channels` here.
    #[napi(factory)]
    pub fn image(data: Vec<f64>, width: u32, height: u32, channels: Option<u32>) -> Self {
        let pixels: Vec<f32> = data.iter().map(|&value| value as f32).collect();
        let channel_count = channels.unwrap_or(3);
        let inner = GroundedInput::ImageWithDims {
            data: pixels,
            width: width as usize,
            height: height as usize,
            channels: channel_count as usize,
        };
        Self { inner }
    }

    /// Wraps a dense embedding vector (narrowed from `f64` to `f32`).
    #[napi(factory)]
    pub fn embedding(data: Vec<f64>) -> Self {
        let components: Vec<f32> = data.iter().map(|&value| value as f32).collect();
        let inner = GroundedInput::Embedding(components);
        Self { inner }
    }

    /// The empty (`Nil`) input.
    #[napi(factory)]
    pub fn nil() -> Self {
        let inner = GroundedInput::Nil;
        Self { inner }
    }

    /// `true` if this input is a text value.
    #[napi]
    pub fn is_text(&self) -> bool {
        matches!(&self.inner, GroundedInput::Text(_))
    }

    /// `true` if this input is a numeric value.
    #[napi]
    pub fn is_number(&self) -> bool {
        matches!(&self.inner, GroundedInput::Number(_))
    }

    /// `true` if this input is `Nil`.
    #[napi]
    pub fn is_nil(&self) -> bool {
        matches!(&self.inner, GroundedInput::Nil)
    }

    /// The text payload, or `null`/`undefined` in JS for any other variant.
    #[napi]
    pub fn as_text(&self) -> Option<String> {
        match &self.inner {
            GroundedInput::Text(text) => Some(text.clone()),
            _ => None,
        }
    }

    /// The numeric payload, or `null`/`undefined` in JS for any other variant.
    #[napi]
    pub fn as_number(&self) -> Option<f64> {
        match &self.inner {
            GroundedInput::Number(value) => Some(*value),
            _ => None,
        }
    }

    /// The `Display` rendering of the wrapped input (exported to JS as `toString()`).
    #[napi]
    pub fn to_string(&self) -> String {
        self.inner.to_string()
    }
}
/// JavaScript handle for a [`SemanticValue`], the symbolic value an NSR machine produces.
///
/// Exposed to JS as `SemanticValue`; construct instances with the static factories.
#[napi(js_name = "SemanticValue")]
pub struct JsSemanticValue {
    inner: SemanticValue,
}

#[napi]
impl JsSemanticValue {
    /// Wraps a signed 64-bit integer.
    #[napi(factory)]
    pub fn integer(n: i64) -> Self {
        let inner = SemanticValue::Integer(n);
        Self { inner }
    }

    /// Wraps a floating-point number.
    #[napi(factory)]
    pub fn float(n: f64) -> Self {
        let inner = SemanticValue::Float(n);
        Self { inner }
    }

    /// Wraps a boolean.
    #[napi(factory)]
    pub fn boolean(b: bool) -> Self {
        let inner = SemanticValue::Boolean(b);
        Self { inner }
    }

    /// Wraps a string value.
    #[napi(factory)]
    pub fn string(s: String) -> Self {
        let inner = SemanticValue::String(s);
        Self { inner }
    }

    /// Wraps a named symbol.
    #[napi(factory)]
    pub fn symbol(s: String) -> Self {
        let inner = SemanticValue::Symbol(s);
        Self { inner }
    }

    /// Wraps a sequence of action names.
    #[napi(factory)]
    pub fn actions(actions: Vec<String>) -> Self {
        let inner = SemanticValue::ActionSequence(actions);
        Self { inner }
    }

    /// Builds a list value from copies of the given values.
    #[napi(factory)]
    pub fn list(items: Vec<&JsSemanticValue>) -> Self {
        let elements = items.iter().map(|item| item.inner.clone()).collect();
        let inner = SemanticValue::List(elements);
        Self { inner }
    }

    /// The `Null` value.
    #[napi(factory)]
    pub fn null() -> Self {
        let inner = SemanticValue::Null;
        Self { inner }
    }

    /// Integer view of the value, as reported by `SemanticValue::as_integer`.
    #[napi]
    pub fn as_integer(&self) -> Option<i64> {
        let value = &self.inner;
        value.as_integer()
    }

    /// Float view of the value, as reported by `SemanticValue::as_float`.
    #[napi]
    pub fn as_float(&self) -> Option<f64> {
        let value = &self.inner;
        value.as_float()
    }

    /// Owned copy of the string view reported by `SemanticValue::as_string`.
    #[napi]
    pub fn as_string(&self) -> Option<String> {
        let view = self.inner.as_string();
        view.map(|text| text.to_string())
    }

    /// Owned copy of the action list reported by `SemanticValue::as_actions`.
    #[napi]
    pub fn as_actions(&self) -> Option<Vec<String>> {
        let view = self.inner.as_actions();
        view.map(|steps| steps.to_vec())
    }

    /// Whether the core value reports itself as an error.
    #[napi]
    pub fn is_error(&self) -> bool {
        let value = &self.inner;
        value.is_error()
    }

    /// `true` if this is the `Null` value.
    #[napi]
    pub fn is_null(&self) -> bool {
        matches!(&self.inner, SemanticValue::Null)
    }

    /// The `Display` rendering of the wrapped value (exported to JS as `toString()`).
    #[napi]
    pub fn to_string(&self) -> String {
        self.inner.to_string()
    }
}
/// JavaScript handle for a built-in [`Primitive`] operation, exposed to JS as `Primitive`.
///
/// Instances are created with the static factories (`add`, `sub`, ..., `identity`).
#[napi(js_name = "Primitive")]
pub struct JsPrimitive {
    inner: Primitive,
}

#[napi]
impl JsPrimitive {
    /// The `Add` primitive.
    #[napi(factory)]
    pub fn add() -> Self {
        Self::from_core(Primitive::Add)
    }

    /// The `Sub` primitive.
    #[napi(factory)]
    pub fn sub() -> Self {
        Self::from_core(Primitive::Sub)
    }

    /// The `Mul` primitive.
    #[napi(factory)]
    pub fn mul() -> Self {
        Self::from_core(Primitive::Mul)
    }

    /// The `Div` primitive.
    #[napi(factory)]
    pub fn div() -> Self {
        Self::from_core(Primitive::Div)
    }

    /// The `Eq` primitive.
    #[napi(factory)]
    pub fn eq() -> Self {
        Self::from_core(Primitive::Eq)
    }

    /// The `Lt` primitive.
    #[napi(factory)]
    pub fn lt() -> Self {
        Self::from_core(Primitive::Lt)
    }

    /// The `Gt` primitive.
    #[napi(factory)]
    pub fn gt() -> Self {
        Self::from_core(Primitive::Gt)
    }

    /// The `And` primitive.
    #[napi(factory)]
    pub fn and() -> Self {
        Self::from_core(Primitive::And)
    }

    /// The `Or` primitive.
    #[napi(factory)]
    pub fn or() -> Self {
        Self::from_core(Primitive::Or)
    }

    /// The `Not` primitive.
    #[napi(factory)]
    pub fn not() -> Self {
        Self::from_core(Primitive::Not)
    }

    /// The `Cons` primitive.
    #[napi(factory)]
    pub fn cons() -> Self {
        Self::from_core(Primitive::Cons)
    }

    /// The `Car` primitive.
    #[napi(factory)]
    pub fn car() -> Self {
        Self::from_core(Primitive::Car)
    }

    /// The `Cdr` primitive.
    #[napi(factory)]
    pub fn cdr() -> Self {
        Self::from_core(Primitive::Cdr)
    }

    /// The `Identity` primitive.
    #[napi(factory)]
    pub fn identity() -> Self {
        Self::from_core(Primitive::Identity)
    }

    /// Value returned by `Primitive::arity`, narrowed to `u32` for JS.
    #[napi]
    pub fn arity(&self) -> u32 {
        let count = self.inner.arity();
        count as u32
    }

    /// The primitive's `Debug` representation.
    #[napi]
    pub fn name(&self) -> String {
        format!("{:?}", self.inner)
    }
}

impl JsPrimitive {
    /// Wraps a core primitive; shared by the JS factories above (not exported to JS).
    fn from_core(inner: Primitive) -> Self {
        Self { inner }
    }
}
/// JavaScript handle for a symbolic [`Program`] (expression tree), exposed to JS as `Program`.
///
/// Every factory copies the programs/values it is given, so the JS arguments stay usable.
#[napi(js_name = "Program")]
pub struct JsProgram {
    inner: Program,
}

#[napi]
impl JsProgram {
    /// A program that evaluates to a copy of `value`.
    #[napi(factory)]
    pub fn constant(value: &JsSemanticValue) -> Self {
        let literal = value.inner.clone();
        Self {
            inner: Program::constant(literal),
        }
    }

    /// A variable reference by index (see `Program::var`).
    #[napi(factory)]
    pub fn var(index: u32) -> Self {
        let slot = index as usize;
        Self {
            inner: Program::var(slot),
        }
    }

    /// A child reference by index (see `Program::child`).
    #[napi(factory)]
    pub fn child(index: u32) -> Self {
        let slot = index as usize;
        Self {
            inner: Program::child(slot),
        }
    }

    /// Applies the primitive `prim` to copies of `args`.
    #[napi(factory)]
    pub fn primitive(prim: &JsPrimitive, args: Vec<&JsProgram>) -> Self {
        let operator = prim.inner.clone();
        let operands = args.iter().map(|arg| arg.inner.clone()).collect();
        Self {
            inner: Program::primitive(operator, operands),
        }
    }

    /// A lambda of the given `arity` whose body is a copy of `body`.
    #[napi(factory)]
    pub fn lambda(arity: u32, body: &JsProgram) -> Self {
        let param_count = arity as usize;
        let body_program = body.inner.clone();
        Self {
            inner: Program::lambda(param_count, body_program),
        }
    }

    /// Applies a copy of `func` to copies of `args`.
    #[napi(factory)]
    pub fn apply(func: &JsProgram, args: Vec<&JsProgram>) -> Self {
        let callee = func.inner.clone();
        let operands = args.iter().map(|arg| arg.inner.clone()).collect();
        Self {
            inner: Program::apply(callee, operands),
        }
    }

    /// A conditional built from copies of the three branches.
    #[napi(factory)]
    pub fn if_then_else(cond: &JsProgram, then_branch: &JsProgram, else_branch: &JsProgram) -> Self {
        let (test, on_true, on_false) = (
            cond.inner.clone(),
            then_branch.inner.clone(),
            else_branch.inner.clone(),
        );
        Self {
            inner: Program::if_then_else(test, on_true, on_false),
        }
    }

    /// Value of `Program::depth`, narrowed to `u32`.
    #[napi]
    pub fn depth(&self) -> u32 {
        let levels = self.inner.depth();
        levels as u32
    }

    /// Value of `Program::size`, narrowed to `u32`.
    #[napi]
    pub fn size(&self) -> u32 {
        let nodes = self.inner.size();
        nodes as u32
    }

    /// Whether the core program reports itself as a constant.
    #[napi]
    pub fn is_constant(&self) -> bool {
        let program = &self.inner;
        program.is_constant()
    }
}
/// Plain JS object mirroring the subset of [`NSRConfig`] exposed to JavaScript.
///
/// Integer fields are `u32` on the JS side and widened to `usize` when converted into
/// `NSRConfig`; any `NSRConfig` fields not listed here take their `NSRConfig::default()` values.
#[napi(object)]
#[derive(Clone)]
pub struct JsNSRConfig {
    /// Maps to `NSRConfig::embedding_dim`.
    #[napi(js_name = "embeddingDim")]
    pub embedding_dim: u32,
    /// Maps to `NSRConfig::hidden_size`.
    #[napi(js_name = "hiddenSize")]
    pub hidden_size: u32,
    /// Maps to `NSRConfig::max_seq_len`.
    #[napi(js_name = "maxSeqLen")]
    pub max_seq_len: u32,
    /// Maps to `NSRConfig::beam_width`.
    #[napi(js_name = "beamWidth")]
    pub beam_width: u32,
    /// Maps to `NSRConfig::enable_synthesis`.
    #[napi(js_name = "enableSynthesis")]
    pub enable_synthesis: bool,
    /// Maps to `NSRConfig::max_program_depth`.
    #[napi(js_name = "maxProgramDepth")]
    pub max_program_depth: u32,
}

impl Default for JsNSRConfig {
    /// Mirrors `NSRConfig::default()`, narrowing the `usize` fields to `u32`.
    fn default() -> Self {
        let NSRConfig {
            embedding_dim,
            hidden_size,
            max_seq_len,
            beam_width,
            enable_synthesis,
            max_program_depth,
            ..
        } = NSRConfig::default();
        Self {
            embedding_dim: embedding_dim as u32,
            hidden_size: hidden_size as u32,
            max_seq_len: max_seq_len as u32,
            beam_width: beam_width as u32,
            enable_synthesis,
            max_program_depth: max_program_depth as u32,
        }
    }
}

impl From<JsNSRConfig> for NSRConfig {
    /// Widens the JS fields to `usize`; every other field comes from `NSRConfig::default()`.
    fn from(js: JsNSRConfig) -> Self {
        let JsNSRConfig {
            embedding_dim,
            hidden_size,
            max_seq_len,
            beam_width,
            enable_synthesis,
            max_program_depth,
        } = js;
        Self {
            embedding_dim: embedding_dim as usize,
            hidden_size: hidden_size as usize,
            max_seq_len: max_seq_len as usize,
            beam_width: beam_width as usize,
            enable_synthesis,
            max_program_depth: max_program_depth as usize,
            ..Self::default()
        }
    }
}
/// JavaScript handle for a [`TrainingExample`]: grounded inputs paired with the expected
/// semantic output, plus a difficulty score. Exposed to JS as `TrainingExample`.
#[napi(js_name = "TrainingExample")]
pub struct JsTrainingExample {
    inner: TrainingExample,
}

#[napi]
impl JsTrainingExample {
    /// Builds an example from explicit inputs and an expected output.
    ///
    /// `difficulty` is applied with `with_difficulty` only when supplied. When it is omitted,
    /// the value chosen by `TrainingExample::new` is kept, which matches `fromText`/`fromTokens`.
    #[napi(constructor)]
    pub fn new(
        inputs: Vec<&JsGroundedInput>,
        output: &JsSemanticValue,
        difficulty: Option<f64>,
    ) -> Self {
        let example = TrainingExample::new(
            inputs.into_iter().map(|i| i.inner.clone()).collect(),
            output.inner.clone(),
        );
        // Only override the library default when the caller asks for it; forcing 0.0 here
        // made this constructor disagree with the other factories for the same data.
        let inner = match difficulty {
            Some(value) => example.with_difficulty(value as f32),
            None => example,
        };
        Self { inner }
    }

    /// Builds an example whose single input is `text`.
    #[napi(factory)]
    pub fn from_text(text: String, output: &JsSemanticValue) -> Self {
        Self {
            inner: TrainingExample::new(vec![GroundedInput::Text(text)], output.inner.clone()),
        }
    }

    /// Builds an example with one text input per token, in order.
    #[napi(factory)]
    pub fn from_tokens(tokens: Vec<String>, output: &JsSemanticValue) -> Self {
        Self {
            inner: TrainingExample::new(
                tokens.into_iter().map(GroundedInput::Text).collect(),
                output.inner.clone(),
            ),
        }
    }

    /// The example's difficulty, widened from `f32` to `f64`.
    #[napi(getter)]
    pub fn difficulty(&self) -> f64 {
        self.inner.difficulty as f64
    }

    /// Number of grounded inputs in the example.
    #[napi(getter)]
    pub fn input_count(&self) -> u32 {
        self.inner.inputs.len() as u32
    }

    /// Short summary: input count and difficulty to two decimals (exported as `toString()`).
    #[napi]
    pub fn to_string(&self) -> String {
        format!(
            "TrainingExample(inputs={}, difficulty={:.2})",
            self.inner.inputs.len(),
            self.inner.difficulty
        )
    }
}
/// Summary returned by `NSRMachine.train()` as a plain JS object.
#[napi(object)]
pub struct JsTrainingStats {
/// Number of examples passed to `train()`.
#[napi(js_name = "totalExamples")]
pub total_examples: u32,
/// `successful_abductions` as reported by the core training statistics, narrowed to `u32`.
#[napi(js_name = "successfulAbductions")]
pub successful_abductions: u32,
/// Elapsed wall-clock milliseconds for the `train()` call.
#[napi(js_name = "trainingTimeMs")]
pub training_time_ms: u32,
}
/// Snapshot of one inference run, exposed to JS as `InferenceResult`.
///
/// All values are copied out of the core result by `NSRMachine.infer()`, so this object holds
/// no reference to the machine.
#[napi(js_name = "InferenceResult")]
pub struct JsInferenceResult {
    /// `Display` rendering of the inferred output, if there was one.
    output_str: Option<String>,
    /// Confidence reported by the core result.
    confidence_val: f64,
    /// Symbol ids from the core result, narrowed to `u32`.
    symbols_vec: Vec<u32>,
    /// Length of the core result's `gss.nodes`.
    node_count_val: u32,
    /// Value of the core result's `gss.log_probability()`.
    log_prob_val: f64,
}

#[napi]
impl JsInferenceResult {
    /// Rendered output value, or `null`/`undefined` in JS when inference produced none.
    #[napi]
    pub fn output(&self) -> Option<String> {
        let rendered = &self.output_str;
        rendered.clone()
    }

    /// Confidence of the result.
    #[napi]
    pub fn confidence(&self) -> f64 {
        let value = self.confidence_val;
        value
    }

    /// A fresh copy of the symbol ids.
    #[napi]
    pub fn symbols(&self) -> Vec<u32> {
        self.symbols_vec.to_vec()
    }

    /// Number of nodes recorded for the result.
    #[napi]
    pub fn node_count(&self) -> u32 {
        let count = self.node_count_val;
        count
    }

    /// Log probability recorded for the result.
    #[napi]
    pub fn log_probability(&self) -> f64 {
        let log_prob = self.log_prob_val;
        log_prob
    }

    /// Debug-style summary with output, confidence (4 decimals), and symbol count.
    #[napi]
    pub fn to_string(&self) -> String {
        let symbol_count = self.symbols_vec.len();
        format!(
            "InferenceResult(output={:?}, confidence={:.4}, symbols={})",
            self.output_str, self.confidence_val, symbol_count
        )
    }
}
/// Exact-match evaluation summary returned by `NSRMachine.evaluate()` as a plain JS object.
#[napi(object)]
pub struct JsEvaluationResult {
/// `correct / total`, or `0.0` when `total` is zero.
pub accuracy: f64,
/// Number of examples whose inferred output equalled the expected output.
pub correct: u32,
/// Number of examples evaluated.
pub total: u32,
}
/// Machine-level counters returned by the `NSRMachine.statistics` getter as a plain JS object.
///
/// Each field is copied from the core machine's `statistics()` and narrowed to `u32`.
#[napi(object)]
pub struct JsNSRStats {
/// Copied from the core `training_examples` counter.
#[napi(js_name = "trainingExamples")]
pub training_examples: u32,
/// Copied from the core `successful_inferences` counter.
#[napi(js_name = "successfulInferences")]
pub successful_inferences: u32,
/// Copied from the core `programs_learned` counter.
#[napi(js_name = "programsLearned")]
pub programs_learned: u32,
/// Copied from the core `vocabulary_size` counter.
#[napi(js_name = "vocabularySize")]
pub vocabulary_size: u32,
}
#[napi(js_name = "NSRMachineBuilder")]
pub struct JsNSRMachineBuilder {
inner: Option<NSRMachineBuilder>,
}
#[napi]
impl JsNSRMachineBuilder {
#[napi(constructor)]
pub fn new() -> Self {
Self {
inner: Some(NSRMachineBuilder::new()),
}
}
#[napi]
pub fn embedding_dim(&mut self, dim: u32) -> &Self {
if let Some(builder) = self.inner.take() {
self.inner = Some(builder.embedding_dim(dim as usize));
}
self
}
#[napi]
pub fn hidden_size(&mut self, size: u32) -> &Self {
if let Some(builder) = self.inner.take() {
self.inner = Some(builder.hidden_size(size as usize));
}
self
}
#[napi]
pub fn max_seq_len(&mut self, len: u32) -> &Self {
if let Some(builder) = self.inner.take() {
self.inner = Some(builder.max_seq_len(len as usize));
}
self
}
#[napi]
pub fn beam_width(&mut self, width: u32) -> &Self {
if let Some(builder) = self.inner.take() {
self.inner = Some(builder.beam_width(width as usize));
}
self
}
#[napi]
pub fn add_symbol(&mut self, name: String) -> &Self {
if let Some(builder) = self.inner.take() {
self.inner = Some(builder.add_symbol(name));
}
self
}
#[napi]
pub fn enable_synthesis(&mut self, enable: bool) -> &Self {
if let Some(builder) = self.inner.take() {
self.inner = Some(builder.enable_synthesis(enable));
}
self
}
#[napi]
pub fn with_explainability(&mut self) -> &Self {
if let Some(builder) = self.inner.take() {
self.inner = Some(builder.with_explainability());
}
self
}
#[napi]
pub fn build(&mut self) -> Result<JsNSRMachine> {
let builder = self.inner.take().ok_or_else(|| {
Error::from_reason("Builder already consumed")
})?;
Ok(JsNSRMachine {
inner: Arc::new(Mutex::new(builder.build())),
})
}
}
#[napi(js_name = "NSRMachine")]
pub struct JsNSRMachine {
inner: Arc<Mutex<NSRMachine>>,
}
#[napi]
impl JsNSRMachine {
#[napi(constructor)]
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(NSRMachine::default())),
}
}
#[napi(factory)]
pub fn with_config(config: JsNSRConfig) -> Self {
Self {
inner: Arc::new(Mutex::new(NSRMachine::with_config(config.into()))),
}
}
#[napi]
pub fn infer(&self, inputs: Vec<&JsGroundedInput>) -> Result<JsInferenceResult> {
let rust_inputs: Vec<GroundedInput> = inputs.into_iter().map(|i| i.inner.clone()).collect();
let machine = self.inner.clone();
let runtime = get_runtime();
runtime.block_on(async {
let mut m = machine.lock().map_err(|e| Error::from_reason(e.to_string()))?;
let result = m.infer(&rust_inputs)
.await
.map_err(|e| Error::from_reason(e.to_string()))?;
Ok(JsInferenceResult {
output_str: result.output().map(|v| format!("{}", v)),
confidence_val: result.confidence(),
symbols_vec: result.symbols().into_iter().map(|s| s as u32).collect(),
node_count_val: result.gss.nodes.len() as u32,
log_prob_val: result.gss.log_probability(),
})
})
}
#[napi]
pub fn train(&self, examples: Vec<&JsTrainingExample>) -> Result<JsTrainingStats> {
let rust_examples: Vec<TrainingExample> =
examples.into_iter().map(|e| e.inner.clone()).collect();
let machine = self.inner.clone();
let total = rust_examples.len() as u32;
let runtime = get_runtime();
let start = std::time::Instant::now();
runtime.block_on(async {
let mut m = machine.lock().map_err(|e| Error::from_reason(e.to_string()))?;
let stats = m.train(&rust_examples)
.await
.map_err(|e| Error::from_reason(e.to_string()))?;
Ok(JsTrainingStats {
total_examples: total,
successful_abductions: stats.successful_abductions as u32,
training_time_ms: start.elapsed().as_millis() as u32,
})
})
}
#[napi]
pub fn evaluate(&self, examples: Vec<&JsTrainingExample>) -> Result<JsEvaluationResult> {
let rust_examples: Vec<TrainingExample> =
examples.into_iter().map(|e| e.inner.clone()).collect();
let machine = self.inner.clone();
let total = rust_examples.len() as u32;
let runtime = get_runtime();
runtime.block_on(async {
let mut m = machine.lock().map_err(|e| Error::from_reason(e.to_string()))?;
let mut correct = 0u32;
for example in &rust_examples {
if let Ok(result) = m.infer(&example.inputs).await {
if let Some(output) = result.output() {
if output == &example.output {
correct += 1;
}
}
}
}
Ok(JsEvaluationResult {
accuracy: if total > 0 { correct as f64 / total as f64 } else { 0.0 },
correct,
total,
})
})
}
#[napi]
pub fn add_symbol(&self, name: String) -> Result<u32> {
let mut m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
Ok(m.add_symbol(&name) as u32)
}
#[napi]
pub fn add_symbols(&self, names: Vec<String>) -> Result<Vec<u32>> {
let mut m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
Ok(names.iter().map(|n| m.add_symbol(n) as u32).collect())
}
#[napi]
pub fn get_symbol_name(&self, symbol_id: u32) -> Result<Option<String>> {
let m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
Ok(m.vocabulary().get_name(symbol_id as usize).map(|s| s.to_string()))
}
#[napi]
pub fn get_symbol_id(&self, name: String) -> Result<Option<u32>> {
let m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
Ok(m.vocabulary().get_by_name(&name).map(|id| id as u32))
}
#[napi]
pub fn get_all_symbols(&self) -> Result<Vec<String>> {
let m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
let vocab = m.vocabulary();
let mut names = Vec::new();
for i in 0..vocab.len() {
if let Some(name) = vocab.get_name(i) {
names.push(name.to_string());
}
}
Ok(names)
}
#[napi(getter)]
pub fn vocabulary_size(&self) -> Result<u32> {
let m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
Ok(m.vocabulary().len() as u32)
}
#[napi(getter)]
pub fn statistics(&self) -> Result<JsNSRStats> {
let m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
let stats = m.statistics();
Ok(JsNSRStats {
training_examples: stats.training_examples as u32,
successful_inferences: stats.successful_inferences as u32,
programs_learned: stats.programs_learned as u32,
vocabulary_size: stats.vocabulary_size as u32,
})
}
#[napi(getter)]
pub fn config(&self) -> Result<JsNSRConfig> {
let m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
let cfg = m.config();
Ok(JsNSRConfig {
embedding_dim: cfg.embedding_dim as u32,
hidden_size: cfg.hidden_size as u32,
max_seq_len: cfg.max_seq_len as u32,
beam_width: cfg.beam_width as u32,
enable_synthesis: cfg.enable_synthesis,
max_program_depth: cfg.max_program_depth as u32,
})
}
#[napi]
pub fn set_program(&self, symbol_id: u32, program: &JsProgram) -> Result<()> {
let mut m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
m.set_symbol_program(symbol_id as usize, program.inner.clone());
Ok(())
}
#[napi]
pub fn set_constant_program(&self, symbol_id: u32, value: &JsSemanticValue) -> Result<()> {
let mut m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
m.set_symbol_program(symbol_id as usize, Program::constant(value.inner.clone()));
Ok(())
}
#[napi]
pub fn setup_classification_programs(&self) -> Result<()> {
let mut m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
let vocab_len = m.vocabulary().len();
for symbol_id in 0..vocab_len {
if let Some(name) = m.vocabulary().get_name(symbol_id) {
let program = Program::constant(SemanticValue::Symbol(name.to_string()));
m.set_symbol_program(symbol_id, program);
}
}
Ok(())
}
#[napi]
pub fn get_program(&self, symbol_id: u32) -> Result<Option<JsProgram>> {
let m = self.inner.lock().map_err(|e| Error::from_reason(e.to_string()))?;
Ok(m.get_symbol_program(symbol_id as usize).map(|p| JsProgram {
inner: p.clone(),
}))
}
}
#[napi]
pub fn scan_machine() -> JsNSRMachine {
JsNSRMachine {
inner: Arc::new(Mutex::new(presets::scan_machine())),
}
}
#[napi]
pub fn pcfg_machine() -> JsNSRMachine {
JsNSRMachine {
inner: Arc::new(Mutex::new(presets::pcfg_machine())),
}
}
#[napi]
pub fn hint_machine() -> JsNSRMachine {
JsNSRMachine {
inner: Arc::new(Mutex::new(presets::hint_machine())),
}
}
#[napi]
pub fn cogs_machine() -> JsNSRMachine {
JsNSRMachine {
inner: Arc::new(Mutex::new(presets::cogs_machine())),
}
}
/// Returns this crate's version string, embedded from `CARGO_PKG_VERSION` at compile time.
#[napi]
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}