use serde::{Deserialize, Serialize};
use wasm_bindgen::prelude::*;
/// Highest search `level` accepted by `search`; larger values are rejected with an error.
const MAX_API_LEVEL: u32 = 5;
/// Largest `maxMatches` accepted by `search`; larger values are rejected with an error.
const MAX_API_MATCHES: usize = 10_000;
/// A single search result flattened into plain strings and numbers for JavaScript callers.
///
/// Instances are produced from `crate::search::Match` through the `From` impl in this file.
#[wasm_bindgen]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct WasmMatch {
/// Left-hand side expression rendered with `to_infix()`.
#[wasm_bindgen(getter_with_clone)]
pub lhs: String,
/// Right-hand side expression rendered with `to_infix()`.
#[wasm_bindgen(getter_with_clone)]
pub rhs: String,
/// Left-hand side expression rendered with `to_postfix()`.
#[wasm_bindgen(getter_with_clone)]
pub lhs_postfix: String,
/// Right-hand side expression rendered with `to_postfix()`.
#[wasm_bindgen(getter_with_clone)]
pub rhs_postfix: String,
/// `"x = <infix>"` when the solver returned an expression for this match, otherwise `None`.
#[wasm_bindgen(getter_with_clone)]
pub solve_for_x: Option<String>,
/// Postfix rendering of the solved expression, `None` when the solver returned nothing.
#[wasm_bindgen(getter_with_clone)]
pub solve_for_x_postfix: Option<String>,
/// `"<lhs key>=<rhs key>"` built from the canonical keys of both sides; falls back to the
/// two postfix strings when a canonical key is unavailable for either side.
#[wasm_bindgen(getter_with_clone)]
pub canonical_key: String,
/// Copied unchanged from the underlying match's `x_value`
/// (presumably the value of `x` that satisfies the equation; confirm against `crate::search`).
pub x_value: f64,
/// Copied unchanged from the underlying match's `error`; its magnitude decides `is_exact`.
pub error: f64,
/// Copied unchanged from the underlying match's `complexity`.
pub complexity: u32,
/// Sum of `operator_count()` over the left and right expressions.
pub operator_count: usize,
/// Larger of the two sides' `tree_depth()`.
pub tree_depth: usize,
/// `true` when `|error|` is below `crate::thresholds::EXACT_MATCH_TOLERANCE`.
pub is_exact: bool,
}
impl From<crate::search::Match> for WasmMatch {
fn from(m: crate::search::Match) -> Self {
let lhs_infix = m.lhs.expr.to_infix();
let rhs_infix = m.rhs.expr.to_infix();
let solved = crate::solver::solve_for_x_rhs_expression(&m.lhs.expr, &m.rhs.expr);
let solve_for_x = solved.as_ref().map(|e| format!("x = {}", e.to_infix()));
let solve_for_x_postfix = solved.as_ref().map(|e| e.to_postfix());
let canonical_key = crate::solver::canonical_expression_key(&m.lhs.expr)
.zip(crate::solver::canonical_expression_key(&m.rhs.expr))
.map(|(l, r)| format!("{}={}", l, r))
.unwrap_or_else(|| format!("{}={}", m.lhs.expr.to_postfix(), m.rhs.expr.to_postfix()));
Self {
lhs: lhs_infix,
rhs: rhs_infix,
lhs_postfix: m.lhs.expr.to_postfix(),
rhs_postfix: m.rhs.expr.to_postfix(),
solve_for_x,
solve_for_x_postfix,
canonical_key,
x_value: m.x_value,
error: m.error,
complexity: m.complexity,
operator_count: m.lhs.expr.operator_count() + m.rhs.expr.operator_count(),
tree_depth: m.lhs.expr.tree_depth().max(m.rhs.expr.tree_depth()),
is_exact: m.error.abs() < crate::thresholds::EXACT_MATCH_TOLERANCE,
}
}
}
#[wasm_bindgen]
impl WasmMatch {
    /// Renders the match as `"<lhs> = <rhs> [error: <e>] {<complexity>}"`, with the
    /// error in two-decimal scientific notation.
    // NOTE(review): presumably kept as an inherent method (rather than `Display`) so that
    // wasm-bindgen can export it as a JS method; confirm before removing the lint allowance.
    #[allow(clippy::inherent_to_string)]
    pub fn to_string(&self) -> String {
        let Self {
            lhs,
            rhs,
            error,
            complexity,
            ..
        } = self;
        format!("{} = {} [error: {:.2e}] {{{}}}", lhs, rhs, error, complexity)
    }

    /// Serializes this match into a JavaScript value via `serde_wasm_bindgen`.
    ///
    /// # Errors
    /// Returns the serializer's message as a JS string when conversion fails.
    pub fn to_json(&self) -> Result<JsValue, JsValue> {
        match serde_wasm_bindgen::to_value(self) {
            Ok(value) => Ok(value),
            Err(err) => Err(JsValue::from_str(&err.to_string())),
        }
    }
}
/// Search options exposed to JavaScript as a constructible class.
///
/// `SearchOptions::new()` starts from `level = 2`, `max_matches = 16` and no preset.
#[wasm_bindgen]
#[derive(Clone, Debug, Serialize)]
pub struct SearchOptions {
/// Requested search level.
pub level: u32,
/// Requested maximum number of matches.
pub max_matches: usize,
/// Optional preset name; `None` means no preset was chosen.
#[wasm_bindgen(getter_with_clone)]
pub preset: Option<String>,
}
#[wasm_bindgen]
impl SearchOptions {
    /// Creates options with the defaults: level 2, at most 16 matches, no preset.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self {
            preset: None,
            level: 2,
            max_matches: 16,
        }
    }

    // NOTE(review): the builder methods below share their names with the public fields,
    // for which wasm-bindgen also generates accessors. Confirm the generated JS glue
    // exposes both the accessors and these methods as intended.

    /// Returns the options with `level` replaced.
    pub fn level(self, level: u32) -> Self {
        Self { level, ..self }
    }

    /// Returns the options with `max_matches` replaced.
    pub fn max_matches(self, max_matches: usize) -> Self {
        Self {
            max_matches,
            ..self
        }
    }

    /// Returns the options with `preset` set to the given name.
    pub fn preset(self, preset: String) -> Self {
        Self {
            preset: Some(preset),
            ..self
        }
    }

    /// Serializes the options into a JavaScript value via `serde_wasm_bindgen`.
    ///
    /// # Errors
    /// Returns the serializer's message as a JS string when conversion fails.
    pub fn to_json(&self) -> Result<JsValue, JsValue> {
        match serde_wasm_bindgen::to_value(self) {
            Ok(value) => Ok(value),
            Err(err) => Err(JsValue::from_str(&err.to_string())),
        }
    }
}
impl Default for SearchOptions {
fn default() -> Self {
Self::new()
}
}
/// Deserialization target for the options value passed to `search`.
///
/// Keys are read in camelCase, with the snake_case spelling accepted as an alias.
/// Because of `#[serde(default)]`, any key that is absent takes its value from the
/// `Default` impl of this struct.
#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
struct SearchOptionsInput {
// Search level; `search` rejects values above `MAX_API_LEVEL`.
level: u32,
// Number of matches to return; `search` rejects values above `MAX_API_MATCHES`.
#[serde(rename = "maxMatches", alias = "max_matches")]
max_matches: usize,
// Optional preset name, resolved by `search` through `crate::presets::Preset::from_str`.
preset: Option<String>,
// Optional ranking mode name; parsed by `parse_ranking_mode` ("complexity" or "parity").
#[serde(rename = "rankingMode", alias = "ranking_mode")]
ranking_mode: Option<String>,
// When true, `search` derives the error bound from the digits of the target instead of
// using 1% of its magnitude.
#[serde(rename = "matchAllDigits", alias = "match_all_digits")]
match_all_digits: bool,
// Requesting PSLQ makes `search` return an error in this build.
#[serde(rename = "usePslq", alias = "use_pslq")]
use_pslq: bool,
}
impl Default for SearchOptionsInput {
fn default() -> Self {
Self {
level: 2,
max_matches: 16,
preset: None,
ranking_mode: None,
match_all_digits: false,
use_pslq: false,
}
}
}
/// Converts the optional JS options value into [`SearchOptionsInput`].
///
/// A missing argument, `null` and `undefined` all yield the default options.
///
/// # Errors
/// Returns `"Invalid search options: <reason>"` as a JS string when the value cannot
/// be deserialized.
fn parse_search_options(options: Option<JsValue>) -> Result<SearchOptionsInput, JsValue> {
    // Treat JS `null` / `undefined` exactly like an omitted argument.
    let provided = options.filter(|value| !(value.is_null() || value.is_undefined()));
    match provided {
        Some(raw) => serde_wasm_bindgen::from_value(raw)
            .map_err(|err| JsValue::from_str(&format!("Invalid search options: {}", err))),
        None => Ok(SearchOptionsInput::default()),
    }
}
fn build_symbol_table(profile: &crate::profile::Profile) -> crate::symbol_table::SymbolTable {
crate::symbol_table::SymbolTable::from_profile(profile)
}
fn parse_ranking_mode(value: Option<&str>) -> Result<crate::pool::RankingMode, JsValue> {
match value.unwrap_or("complexity") {
"complexity" => Ok(crate::pool::RankingMode::Complexity),
"parity" => Ok(crate::pool::RankingMode::Parity),
other => Err(JsValue::from_str(&format!(
"Unknown rankingMode '{}'. Supported values: 'complexity', 'parity'.",
other
))),
}
}
/// Derives the error bound that corresponds to "all digits of `target` must match".
///
/// The bound is half a unit in the last decimal place of the target's shortest
/// round-trip decimal representation: `2.5` gives `0.05`, `123.456` gives `0.0005`,
/// and a whole number gives `0.5`. The result is never smaller than `1e-15`, which is
/// also what a target of exactly zero returns.
///
/// The decimals are counted on the `Display` form of the float rather than on a
/// fixed 15-decimal rendering. A fixed rendering prints digits beyond f64 precision
/// once the integer part has a few digits (`123.456` prints as
/// `123.456000000000003`), which would collapse the tolerance to the `1e-15` floor.
///
/// Non-finite inputs render without a decimal point and therefore yield `0.5`.
fn compute_significant_digits_tolerance(target: f64) -> f64 {
    const MIN_TOLERANCE: f64 = 1e-15;

    if target == 0.0 {
        return MIN_TOLERANCE;
    }

    // `Display` for f64 prints the shortest decimal string that round-trips, never in
    // exponent form and never with trailing zeros, so everything after the '.' is a
    // digit that belongs to the value.
    let rendered = target.to_string();
    let digits_after_decimal = rendered
        .find('.')
        .map_or(0, |dot| rendered.len() - dot - 1);

    // The rendered fraction of an f64 is at most a few hundred characters long, so the
    // cast to i32 cannot truncate.
    (0.5 * 10_f64.powi(-(digits_after_decimal as i32))).max(MIN_TOLERANCE)
}
/// Assembles the generator configuration for a search.
///
/// Starts from the built-in constant, unary and binary symbol sets, then appends one
/// symbol per user-defined constant and function in `profile` (at most 16 of each).
/// Both sides are generated, with the complexity limits taken from the arguments.
fn build_gen_config(
    max_lhs_complexity: u32,
    max_rhs_complexity: u32,
    profile: &crate::profile::Profile,
) -> crate::gen::GenConfig {
    use crate::symbol::{NumType, Symbol};
    use std::collections::HashMap;
    use std::sync::Arc;

    // Only the first 16 user entries of each kind get a symbol.
    let user_constant_slots = profile.constants.len().min(16);
    let user_function_slots = profile.functions.len().min(16);

    // NOTE(review): bytes 128.. and 144.. are presumably the ranges reserved for
    // user constants and user functions; confirm against `Symbol::from_byte`.
    let mut constants: Vec<Symbol> = Symbol::constants().to_vec();
    constants.extend(
        (0..user_constant_slots).filter_map(|slot| Symbol::from_byte(128 + slot as u8)),
    );

    let mut unary_ops: Vec<Symbol> = Symbol::unary_ops().to_vec();
    unary_ops.extend(
        (0..user_function_slots).filter_map(|slot| Symbol::from_byte(144 + slot as u8)),
    );

    let binary_ops: Vec<Symbol> = Symbol::binary_ops().to_vec();
    let symbol_table = Arc::new(build_symbol_table(profile));

    crate::gen::GenConfig {
        max_lhs_complexity,
        max_rhs_complexity,
        max_length: 21,
        constants,
        unary_ops,
        binary_ops,
        rhs_constants: None,
        rhs_unary_ops: None,
        rhs_binary_ops: None,
        symbol_max_counts: HashMap::new(),
        rhs_symbol_max_counts: None,
        min_num_type: NumType::Transcendental,
        generate_lhs: true,
        generate_rhs: true,
        user_constants: profile.constants.clone(),
        user_functions: profile.functions.clone(),
        show_pruned_arith: false,
        symbol_table,
    }
}
/// Searches for equations whose solution approximates `target`.
///
/// `options` is an optional plain JS object (see `SearchOptionsInput` for the accepted
/// keys); a missing, `null` or `undefined` value selects the defaults. At most
/// `maxMatches` results are returned, each converted to a [`WasmMatch`].
///
/// # Errors
/// Returns a JS string error when:
/// - `target` is `NaN` or infinite,
/// - `options` cannot be deserialized,
/// - `usePslq` is requested,
/// - `level` exceeds `MAX_API_LEVEL` or `maxMatches` exceeds `MAX_API_MATCHES`,
/// - `rankingMode` or `preset` names an unknown value.
#[wasm_bindgen]
pub fn search(target: f64, options: Option<JsValue>) -> Result<Vec<WasmMatch>, JsValue> {
    // A non-finite target cannot be approximated, and it would also corrupt the error
    // bound below (an infinite target gives an infinite `max_error`, while `f64::max`
    // silently replaces a NaN bound with 1e-12). Reject it up front.
    if !target.is_finite() {
        return Err(JsValue::from_str(&format!(
            "Invalid target {}. The target must be a finite number.",
            target
        )));
    }

    let opts = parse_search_options(options)?;
    if opts.use_pslq {
        return Err(JsValue::from_str(
            "usePslq is not supported in the WebAssembly build yet.",
        ));
    }
    if opts.level > MAX_API_LEVEL {
        return Err(JsValue::from_str(&format!(
            "Invalid level {}. Supported range is 0..={}.",
            opts.level, MAX_API_LEVEL
        )));
    }
    if opts.max_matches > MAX_API_MATCHES {
        return Err(JsValue::from_str(&format!(
            "maxMatches {} is too large. Maximum supported value is {}.",
            opts.max_matches, MAX_API_MATCHES
        )));
    }

    // The engine is asked for twice as many matches as will be returned; the surplus
    // is dropped by `take` at the end.
    let internal_max_matches = opts
        .max_matches
        .checked_mul(2)
        .ok_or_else(|| JsValue::from_str("maxMatches is too large"))?;
    let ranking_mode = parse_ranking_mode(opts.ranking_mode.as_deref())?;
    let (max_lhs_complexity, max_rhs_complexity) = crate::search::level_to_complexity(opts.level);

    // Start from an empty profile and merge the requested preset on top, if any.
    let mut profile = crate::profile::Profile::new();
    if let Some(preset_name) = opts.preset.as_deref() {
        let parsed = crate::presets::Preset::from_str(preset_name).ok_or_else(|| {
            JsValue::from_str(&format!(
                "Unknown preset '{}'. Use listPresets() for available options.",
                preset_name
            ))
        })?;
        profile = profile.merge(parsed.to_profile());
    }

    let gen_config = build_gen_config(max_lhs_complexity, max_rhs_complexity, &profile);

    // Either match every digit the caller supplied, or allow 1% of the target's
    // magnitude (never less than 1e-12).
    let max_error = if opts.match_all_digits {
        compute_significant_digits_tolerance(target)
    } else {
        (target.abs() * 0.01).max(1e-12)
    };

    let search_config = crate::search::SearchConfig {
        target,
        max_matches: internal_max_matches,
        max_error,
        stop_at_exact: false,
        stop_below: None,
        zero_value_threshold: 1e-4,
        newton_iterations: 15,
        user_constants: gen_config.user_constants.clone(),
        user_functions: gen_config.user_functions.clone(),
        trig_argument_scale: crate::eval::DEFAULT_TRIG_ARGUMENT_SCALE,
        refine_with_newton: true,
        rhs_allowed_symbols: None,
        rhs_excluded_symbols: None,
        show_newton: false,
        show_match_checks: false,
        show_pruned_arith: false,
        show_pruned_range: false,
        show_db_adds: false,
        match_all_digits: opts.match_all_digits,
        derivative_margin: crate::thresholds::DEGENERATE_DERIVATIVE,
        ranking_mode,
    };

    // Use the parallel search only when the crate is built with thread support.
    let (matches, _stats) = {
        #[cfg(feature = "wasm-threads")]
        {
            crate::search::search_parallel_with_stats_and_config(&gen_config, &search_config)
        }
        #[cfg(not(feature = "wasm-threads"))]
        {
            crate::search::search_with_stats_and_config(&gen_config, &search_config)
        }
    };

    Ok(matches
        .into_iter()
        .take(opts.max_matches)
        .map(WasmMatch::from)
        .collect())
}
/// Returns the available presets as a JS value mapping preset name to description.
///
/// # Errors
/// Returns the serializer's message as a JS string when conversion fails.
#[wasm_bindgen(js_name = listPresets)]
pub fn list_presets() -> Result<JsValue, JsValue> {
    // A BTreeMap keeps the entries ordered by preset name.
    let mut by_name: std::collections::BTreeMap<String, String> =
        std::collections::BTreeMap::new();
    for preset in crate::presets::Preset::all().iter() {
        by_name.insert(preset.name().to_string(), preset.description().to_string());
    }

    match serde_wasm_bindgen::to_value(&by_name) {
        Ok(value) => Ok(value),
        Err(err) => Err(JsValue::from_str(&err.to_string())),
    }
}
/// Snake_case JS export (`list_presets`) that forwards to [`list_presets`], which is
/// exported to JS as `listPresets`.
#[wasm_bindgen(js_name = list_presets)]
pub fn list_presets_compat() -> Result<JsValue, JsValue> {
list_presets()
}
/// Returns the crate version recorded by Cargo at compile time (`CARGO_PKG_VERSION`).
#[wasm_bindgen]
pub fn version() -> String {
    String::from(env!("CARGO_PKG_VERSION"))
}
/// Installs the `console_error_panic_hook` panic hook.
///
/// Uses `set_once`, so calling this more than once is harmless.
#[wasm_bindgen]
pub fn init() {
console_error_panic_hook::set_once();
}
#[cfg(feature = "wasm-threads")]
pub use wasm_bindgen_rayon::init_thread_pool;