use wasm_bindgen::prelude::*;
use ruvector_math::{
optimal_transport::{SlicedWasserstein, SinkhornSolver, GromovWasserstein},
information_geometry::{FisherInformation, NaturalGradient},
spherical::SphericalSpace,
product_manifold::ProductManifold,
};
/// Module entry point that wasm-bindgen runs automatically when the module is instantiated.
///
/// When the crate is built with the `console_error_panic_hook` feature, Rust panics are
/// forwarded to `console.error`. Without it this function does nothing.
#[wasm_bindgen(start)]
pub fn start() {
    #[cfg(feature = "console_error_panic_hook")]
    {
        console_error_panic_hook::set_once();
    }
}
/// JavaScript-facing wrapper around [`SlicedWasserstein`].
///
/// Point clouds cross the JS boundary as flat `f64` buffers. They are split into
/// `dim`-sized points before being handed to the wrapped estimator.
#[wasm_bindgen]
pub struct WasmSlicedWasserstein {
    inner: SlicedWasserstein,
}

#[wasm_bindgen]
impl WasmSlicedWasserstein {
    /// Creates an estimator configured with `num_projections` projections.
    #[wasm_bindgen(constructor)]
    pub fn new(num_projections: usize) -> Self {
        let inner = SlicedWasserstein::new(num_projections);
        Self { inner }
    }

    /// Returns the estimator with its power set to `p`. Consumes the JS object.
    #[wasm_bindgen(js_name = withPower)]
    pub fn with_power(self, p: f64) -> Self {
        let inner = self.inner.with_power(p);
        Self { inner }
    }

    /// Returns the estimator reconfigured with `seed`. Consumes the JS object.
    #[wasm_bindgen(js_name = withSeed)]
    pub fn with_seed(self, seed: u64) -> Self {
        let inner = self.inner.with_seed(seed);
        Self { inner }
    }

    /// Distance between two point clouds, each a flat buffer of `dim`-dimensional points.
    #[wasm_bindgen]
    pub fn distance(&self, source: &[f64], target: &[f64], dim: usize) -> f64 {
        // Imported locally so the trait's methods are callable on the wrapped estimator.
        use ruvector_math::optimal_transport::OptimalTransport;

        let (src, tgt) = (to_points(source, dim), to_points(target, dim));
        self.inner.distance(&src, &tgt)
    }

    /// Distance between two weighted point clouds.
    ///
    /// The weight slices are forwarded unchanged. Only the coordinate buffers are reshaped.
    #[wasm_bindgen(js_name = weightedDistance)]
    pub fn weighted_distance(
        &self,
        source: &[f64],
        source_weights: &[f64],
        target: &[f64],
        target_weights: &[f64],
        dim: usize,
    ) -> f64 {
        use ruvector_math::optimal_transport::OptimalTransport;

        let src = to_points(source, dim);
        let tgt = to_points(target, dim);
        self.inner
            .weighted_distance(&src, source_weights, &tgt, target_weights)
    }
}
/// JavaScript-facing wrapper around the regularized [`SinkhornSolver`].
///
/// When the solver fails, JavaScript sees a thrown `Error` carrying the solver's message.
#[wasm_bindgen]
pub struct WasmSinkhorn {
    inner: SinkhornSolver,
}

#[wasm_bindgen]
impl WasmSinkhorn {
    /// Creates a solver with the given regularization strength and iteration cap.
    #[wasm_bindgen(constructor)]
    pub fn new(regularization: f64, max_iterations: usize) -> Self {
        let inner = SinkhornSolver::new(regularization, max_iterations);
        Self { inner }
    }

    /// Transport distance between two point clouds given as flat buffers of `dim`-sized points.
    ///
    /// # Errors
    /// Returns a [`JsError`] holding the solver's error message if the solver fails.
    #[wasm_bindgen]
    pub fn distance(&self, source: &[f64], target: &[f64], dim: usize) -> Result<f64, JsError> {
        let src = to_points(source, dim);
        let tgt = to_points(target, dim);
        match self.inner.distance(&src, &tgt) {
            Ok(value) => Ok(value),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Solves transport for an explicit cost matrix.
    ///
    /// `cost_matrix` is read row-major as `n` rows of `m` entries. The returned plan is
    /// the solver's plan with its rows concatenated into one flat buffer.
    ///
    /// # Errors
    /// Returns a [`JsError`] holding the solver's error message if the solver fails.
    #[wasm_bindgen(js_name = solveTransport)]
    pub fn solve_transport(
        &self,
        cost_matrix: &[f64],
        source_weights: &[f64],
        target_weights: &[f64],
        n: usize,
        m: usize,
    ) -> Result<TransportResult, JsError> {
        let cost = to_matrix(cost_matrix, n, m);
        let solution = match self.inner.solve(&cost, source_weights, target_weights) {
            Ok(solution) => solution,
            Err(err) => return Err(JsError::new(&err.to_string())),
        };

        let plan: Vec<f64> = solution.plan.into_iter().flatten().collect();
        Ok(TransportResult {
            plan,
            cost: solution.cost,
            iterations: solution.iterations,
            converged: solution.converged,
        })
    }
}
/// Result of a Sinkhorn transport solve, exposed to JavaScript through read-only getters.
///
/// `Debug` and `Clone` are derived so Rust-side code (including tests) can inspect and
/// duplicate results.
#[wasm_bindgen]
#[derive(Debug, Clone)]
pub struct TransportResult {
    /// Transport plan with its rows concatenated into a single flat buffer.
    plan: Vec<f64>,
    /// Transport cost reported by the solver.
    cost: f64,
    /// Number of iterations reported by the solver.
    iterations: usize,
    /// Whether the solver reported convergence.
    converged: bool,
}

#[wasm_bindgen]
impl TransportResult {
    /// Flattened transport plan.
    ///
    /// Each access returns a fresh copy, because wasm-bindgen getters must return owned data.
    #[wasm_bindgen(getter)]
    pub fn plan(&self) -> Vec<f64> {
        self.plan.clone()
    }

    /// Transport cost reported by the solver.
    #[wasm_bindgen(getter)]
    pub fn cost(&self) -> f64 {
        self.cost
    }

    /// Number of iterations the solver reported.
    #[wasm_bindgen(getter)]
    pub fn iterations(&self) -> usize {
        self.iterations
    }

    /// Whether the solver reported convergence.
    #[wasm_bindgen(getter)]
    pub fn converged(&self) -> bool {
        self.converged
    }
}
/// JavaScript-facing wrapper around [`GromovWasserstein`].
///
/// Both point clouds are passed as flat buffers of `dim`-dimensional points.
#[wasm_bindgen]
pub struct WasmGromovWasserstein {
    inner: GromovWasserstein,
}

#[wasm_bindgen]
impl WasmGromovWasserstein {
    /// Creates a solver with the given regularization strength.
    #[wasm_bindgen(constructor)]
    pub fn new(regularization: f64) -> Self {
        let inner = GromovWasserstein::new(regularization);
        Self { inner }
    }

    /// Distance computed by the wrapped solver between the two point clouds.
    ///
    /// # Errors
    /// Returns a [`JsError`] holding the solver's error message if the solver fails.
    #[wasm_bindgen]
    pub fn distance(&self, source: &[f64], target: &[f64], dim: usize) -> Result<f64, JsError> {
        let (src, tgt) = (to_points(source, dim), to_points(target, dim));
        match self.inner.distance(&src, &tgt) {
            Ok(value) => Ok(value),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }
}
/// JavaScript-facing wrapper around [`FisherInformation`].
#[wasm_bindgen]
pub struct WasmFisherInformation {
    inner: FisherInformation,
}

/// Equivalent to [`WasmFisherInformation::new`].
///
/// Idiomatic for a type with a public zero-argument constructor (`clippy::new_without_default`).
impl Default for WasmFisherInformation {
    fn default() -> Self {
        Self::new()
    }
}

#[wasm_bindgen]
impl WasmFisherInformation {
    /// Creates an estimator via `FisherInformation::new()`.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        Self {
            inner: FisherInformation::new(),
        }
    }

    /// Returns the estimator with `damping` applied. Consumes the JS object.
    #[wasm_bindgen(js_name = withDamping)]
    pub fn with_damping(self, damping: f64) -> Self {
        Self {
            inner: self.inner.with_damping(damping),
        }
    }

    /// Diagonal Fisher information computed by the wrapped estimator from gradient samples.
    ///
    /// `gradients` holds the samples back to back, `dim` values each. `_num_samples` is
    /// currently ignored: the sample count is implied by `gradients.len()` and `dim`.
    ///
    /// # Errors
    /// Returns a [`JsError`] holding the estimator's error message if it fails.
    #[wasm_bindgen(js_name = diagonalFim)]
    pub fn diagonal_fim(
        &self,
        gradients: &[f64],
        _num_samples: usize,
        dim: usize,
    ) -> Result<Vec<f64>, JsError> {
        let grads = to_points(gradients, dim);
        self.inner
            .diagonal_fim(&grads)
            .map_err(|e| JsError::new(&e.to_string()))
    }

    /// Preconditions `gradient` element-wise as `gradient[i] / (fim_diag[i] + damping)`.
    ///
    /// This does not consult the wrapped estimator or its configured damping. The explicit
    /// `damping` argument is used instead.
    ///
    /// NOTE: the inputs are zipped, so the output has `min(gradient.len(), fim_diag.len())`
    /// entries. Callers must pass slices of equal length.
    #[wasm_bindgen(js_name = naturalGradient)]
    pub fn natural_gradient(&self, fim_diag: &[f64], gradient: &[f64], damping: f64) -> Vec<f64> {
        gradient
            .iter()
            .zip(fim_diag.iter())
            .map(|(&g, &f)| g / (f + damping))
            .collect()
    }
}
/// JavaScript-facing wrapper around the stateful [`NaturalGradient`] optimizer.
///
/// `step` mutates the wrapped optimizer, and `reset` delegates to its `reset`.
#[wasm_bindgen]
pub struct WasmNaturalGradient {
    inner: NaturalGradient,
}

#[wasm_bindgen]
impl WasmNaturalGradient {
    /// Creates an optimizer with the given learning rate.
    #[wasm_bindgen(constructor)]
    pub fn new(learning_rate: f64) -> Self {
        let inner = NaturalGradient::new(learning_rate);
        Self { inner }
    }

    /// Returns the optimizer with `damping` applied. Consumes the JS object.
    #[wasm_bindgen(js_name = withDamping)]
    pub fn with_damping(self, damping: f64) -> Self {
        let inner = self.inner.with_damping(damping);
        Self { inner }
    }

    /// Returns the optimizer with the diagonal option set to `use_diagonal`. Consumes the JS object.
    #[wasm_bindgen(js_name = withDiagonal)]
    pub fn with_diagonal(self, use_diagonal: bool) -> Self {
        let inner = self.inner.with_diagonal(use_diagonal);
        Self { inner }
    }

    /// Runs one step of the wrapped optimizer and returns the vector it produces.
    ///
    /// `gradient_samples`, when supplied, is a flat buffer of `dim`-sized samples that is
    /// split into points before being passed on.
    ///
    /// # Errors
    /// Returns a [`JsError`] holding the optimizer's error message if the step fails.
    #[wasm_bindgen]
    pub fn step(
        &mut self,
        gradient: &[f64],
        gradient_samples: Option<Vec<f64>>,
        dim: usize,
    ) -> Result<Vec<f64>, JsError> {
        let samples = gradient_samples
            .as_deref()
            .map(|flat| to_points(flat, dim));
        match self.inner.step(gradient, samples.as_deref()) {
            Ok(output) => Ok(output),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Delegates to the wrapped optimizer's `reset`.
    #[wasm_bindgen]
    pub fn reset(&mut self) {
        self.inner.reset();
    }
}
/// JavaScript-facing wrapper around [`SphericalSpace`].
///
/// Points and vectors are plain `f64` slices. Any error from the wrapped space becomes a
/// [`JsError`] carrying the original message.
#[wasm_bindgen]
pub struct WasmSphericalSpace {
    inner: SphericalSpace,
}

#[wasm_bindgen]
impl WasmSphericalSpace {
    /// Creates a spherical space with the given ambient dimension.
    #[wasm_bindgen(constructor)]
    pub fn new(ambient_dim: usize) -> Self {
        let inner = SphericalSpace::new(ambient_dim);
        Self { inner }
    }

    /// Ambient dimension reported by the wrapped space (JS property `ambientDim`).
    #[wasm_bindgen(getter, js_name = ambientDim)]
    pub fn ambient_dim(&self) -> usize {
        self.inner.ambient_dim()
    }

    /// Projects `point` using the wrapped space's `project`.
    #[wasm_bindgen]
    pub fn project(&self, point: &[f64]) -> Result<Vec<f64>, JsError> {
        match self.inner.project(point) {
            Ok(projected) => Ok(projected),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Distance between `x` and `y` as computed by the wrapped space.
    #[wasm_bindgen]
    pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64, JsError> {
        match self.inner.distance(x, y) {
            Ok(value) => Ok(value),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Exponential map at base point `x` applied to `v` (wrapped space's `exp_map`).
    #[wasm_bindgen(js_name = expMap)]
    pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>, JsError> {
        match self.inner.exp_map(x, v) {
            Ok(mapped) => Ok(mapped),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Logarithmic map at base point `x` toward `y` (wrapped space's `log_map`).
    #[wasm_bindgen(js_name = logMap)]
    pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>, JsError> {
        match self.inner.log_map(x, y) {
            Ok(tangent) => Ok(tangent),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Point at parameter `t` on the geodesic between `x` and `y`.
    ///
    /// The valid range of `t` is whatever the wrapped space accepts (presumably `0..=1`).
    #[wasm_bindgen]
    pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>, JsError> {
        match self.inner.geodesic(x, y, t) {
            Ok(point) => Ok(point),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Fréchet mean of a flat buffer of `dim`-sized points.
    #[wasm_bindgen(js_name = frechetMean)]
    pub fn frechet_mean(&self, points: &[f64], dim: usize) -> Result<Vec<f64>, JsError> {
        let samples = to_points(points, dim);
        // `None` fills the optional second argument (presumably per-point weights; confirm).
        match self.inner.frechet_mean(&samples, None) {
            Ok(mean) => Ok(mean),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }
}
/// JavaScript-facing wrapper around [`ProductManifold`].
///
/// Batched operations split their flat point buffers using the manifold's own `dim()` as
/// the stride. Any error from the wrapped manifold becomes a [`JsError`].
#[wasm_bindgen]
pub struct WasmProductManifold {
    inner: ProductManifold,
}

#[wasm_bindgen]
impl WasmProductManifold {
    /// Creates a product manifold from Euclidean, hyperbolic and spherical component dimensions.
    #[wasm_bindgen(constructor)]
    pub fn new(euclidean_dim: usize, hyperbolic_dim: usize, spherical_dim: usize) -> Self {
        let inner = ProductManifold::new(euclidean_dim, hyperbolic_dim, spherical_dim);
        Self { inner }
    }

    /// Dimension reported by the wrapped manifold. Also used as the point stride below.
    #[wasm_bindgen(getter)]
    pub fn dim(&self) -> usize {
        self.inner.dim()
    }

    /// Projects `point` using the wrapped manifold's `project`.
    #[wasm_bindgen]
    pub fn project(&self, point: &[f64]) -> Result<Vec<f64>, JsError> {
        match self.inner.project(point) {
            Ok(projected) => Ok(projected),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Distance between `x` and `y` as computed by the wrapped manifold.
    #[wasm_bindgen]
    pub fn distance(&self, x: &[f64], y: &[f64]) -> Result<f64, JsError> {
        match self.inner.distance(x, y) {
            Ok(value) => Ok(value),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Exponential map at base point `x` applied to `v`.
    #[wasm_bindgen(js_name = expMap)]
    pub fn exp_map(&self, x: &[f64], v: &[f64]) -> Result<Vec<f64>, JsError> {
        match self.inner.exp_map(x, v) {
            Ok(mapped) => Ok(mapped),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Logarithmic map at base point `x` toward `y`.
    #[wasm_bindgen(js_name = logMap)]
    pub fn log_map(&self, x: &[f64], y: &[f64]) -> Result<Vec<f64>, JsError> {
        match self.inner.log_map(x, y) {
            Ok(tangent) => Ok(tangent),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Point at parameter `t` on the geodesic between `x` and `y`.
    #[wasm_bindgen]
    pub fn geodesic(&self, x: &[f64], y: &[f64], t: f64) -> Result<Vec<f64>, JsError> {
        match self.inner.geodesic(x, y, t) {
            Ok(point) => Ok(point),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Fréchet mean of a flat buffer of points.
    ///
    /// `_num_points` is ignored. The buffer is split using the manifold's `dim()`.
    #[wasm_bindgen(js_name = frechetMean)]
    pub fn frechet_mean(&self, points: &[f64], _num_points: usize) -> Result<Vec<f64>, JsError> {
        let samples = to_points(points, self.inner.dim());
        match self.inner.frechet_mean(&samples, None) {
            Ok(mean) => Ok(mean),
            Err(err) => Err(JsError::new(&err.to_string())),
        }
    }

    /// Indices of the `k` neighbors of `query` among `points`, as found by the wrapped manifold.
    #[wasm_bindgen]
    pub fn knn(&self, query: &[f64], points: &[f64], k: usize) -> Result<Vec<u32>, JsError> {
        let candidates = to_points(points, self.inner.dim());
        let neighbors = match self.inner.knn(query, &candidates, k) {
            Ok(found) => found,
            Err(err) => return Err(JsError::new(&err.to_string())),
        };
        // Keep only the first element (the point index) of each neighbor pair.
        let indices: Vec<u32> = neighbors
            .into_iter()
            .map(|(index, _)| index as u32)
            .collect();
        Ok(indices)
    }

    /// Pairwise distance matrix of `points`, with its rows concatenated into one flat buffer.
    #[wasm_bindgen(js_name = pairwiseDistances)]
    pub fn pairwise_distances(&self, points: &[f64]) -> Result<Vec<f64>, JsError> {
        let samples = to_points(points, self.inner.dim());
        let matrix = match self.inner.pairwise_distances(&samples) {
            Ok(matrix) => matrix,
            Err(err) => return Err(JsError::new(&err.to_string())),
        };
        Ok(matrix.into_iter().flatten().collect())
    }
}
/// Splits a flat, row-major buffer into points of `dim` coordinates each.
///
/// If `flat.len()` is not a multiple of `dim`, the trailing values form a shorter final
/// point rather than being dropped.
///
/// Returns an empty vector when `dim == 0`. No point can be formed in that case, and
/// `slice::chunks(0)` would otherwise panic, aborting the exported wasm call.
fn to_points(flat: &[f64], dim: usize) -> Vec<Vec<f64>> {
    if dim == 0 {
        return Vec::new();
    }
    flat.chunks(dim).map(<[f64]>::to_vec).collect()
}
/// Reshapes a flat, row-major buffer into at most `rows` rows of `cols` entries each.
///
/// Values beyond `rows * cols` are ignored. A buffer that is too short yields fewer rows,
/// possibly with a shorter last row.
///
/// When `cols == 0` the result is `rows` empty rows, the `rows × 0` matrix. Without this
/// guard, `slice::chunks(0)` would panic and abort the exported wasm call.
fn to_matrix(flat: &[f64], rows: usize, cols: usize) -> Vec<Vec<f64>> {
    if cols == 0 {
        return vec![Vec::new(); rows];
    }
    flat.chunks(cols).take(rows).map(<[f64]>::to_vec).collect()
}
// Extra TypeScript declarations that wasm-bindgen copies verbatim into the generated `.d.ts`.
//
// NOTE(review): no exported function in this file takes these interfaces; the constructors
// use positional arguments instead. `SinkhornOptions.threshold`,
// `ProductManifoldConfig.hyperbolicCurvature` and `ProductManifoldConfig.sphericalCurvature`
// have no counterpart in the visible constructors. Confirm whether they are still intended.
#[wasm_bindgen(typescript_custom_section)]
const TS_TYPES: &'static str = r#"
/** Sliced Wasserstein distance for comparing point cloud distributions */
export interface SlicedWassersteinOptions {
numProjections?: number;
power?: number;
seed?: number;
}
/** Sinkhorn optimal transport options */
export interface SinkhornOptions {
regularization?: number;
maxIterations?: number;
threshold?: number;
}
/** Product manifold configuration */
export interface ProductManifoldConfig {
euclideanDim: number;
hyperbolicDim: number;
sphericalDim: number;
hyperbolicCurvature?: number;
sphericalCurvature?: number;
}
"#;