// Crate-wide lint configuration.
// `non_camel_case_types` is allowed for the whole crate; presumably this is for the
// lowercase names re-exported from `types` below (`bf16`, `f16`, `e4m3`, ...) — TODO confirm.
#![allow(non_camel_case_types)]
// Silences Clippy's `too_many_arguments` lint for the whole crate.
#![allow(clippy::too_many_arguments)]
// The crate is built as `no_std` unless it is compiled for tests or the `std`
// feature is enabled.
#![cfg_attr(all(not(test), not(feature = "std")), no_std)]

// Public modules of the crate.
pub mod capabilities;
pub mod cast;
pub mod curved;
pub mod each;
pub mod geospatial;
pub mod maxsim;
pub mod mesh;
pub mod probability;
pub mod reduce;
pub mod set;
pub mod sparse;
pub mod spatial;
pub mod matrix;
pub mod tensor;
pub mod types;
pub mod vector;

// Crate-root re-exports: every item below stays reachable through its module path
// and is additionally importable directly from the crate root.

// Re-exports from `types`.
pub use types::{
bf16, bf16c, e2m3, e3m2, e4m3, e5m2, f16, f16c, f32c, f64c, i4x2, is_close, u1x8, u4x2, DimMut,
DimRef, FloatConvertible, FloatLike, NumberLike, StorageElement,
};
// Re-exports from `spatial`.
pub use spatial::{Angular, Dot, Euclidean, Roots, SpatialSimilarity, VDot};
// Re-exports from `set`.
pub use set::{BinarySimilarity, Hamming, Jaccard};
// Re-exports from `probability`.
pub use probability::{JensenShannon, KullbackLeibler, ProbabilitySimilarity};
// Re-exports from `each`.
pub use each::{EachATan, EachBlend, EachCos, EachFMA, EachScale, EachSin, EachSum, Trigonometry};
// Re-exports from `reduce`.
pub use reduce::{ReduceMinMax, ReduceMoments, Reductions};
// Re-exports from `curved`.
pub use curved::{Bilinear, Mahalanobis};
// Re-exports from `mesh`.
pub use mesh::{MeshAlignment, MeshAlignmentResult};
// Re-exports from `geospatial`.
pub use geospatial::{Geospatial, Haversine, Vincenty};
// Re-exports from `sparse`.
pub use sparse::{SparseDot, SparseIntersect};
// Re-exports from `cast`.
pub use cast::{cast, CastDtype};
// Re-exports from `capabilities`: the `cap` item plus the free functions
// `available`, `configure_thread` and `uses_dynamic_dispatch`.
pub use capabilities::cap;
pub use capabilities::{available, configure_thread, uses_dynamic_dispatch};
// Re-exports from `tensor`, including the `DEFAULT_MAX_RANK` and `SIMD_ALIGNMENT` constants.
pub use tensor::{
AllCloseOps, Allocator, AxisIterator, AxisIteratorMut, BlendOps, CastOps, FmaOps, Global,
Matrix, MatrixSpan, MatrixView, MinMaxOps, MinMaxResult, MomentsOps, RangeStep, ScaleOps,
SliceArg, SliceRange, SliceSpec, SumOps, Tensor, TensorDims, TensorError,
TensorIterator, TensorMut, TensorRef, TensorSpan, TensorSpanDims, TensorSpanIterator,
TensorView, TensorViewDims, TensorViewIterator, TrigAtanOps, TrigCosOps, TrigSinOps,
DEFAULT_MAX_RANK, SIMD_ALIGNMENT,
};
// Re-exports from `matrix`.
pub use matrix::{
Angulars, Dots, Euclideans, Hammings, Jaccards, PackedMatrix, SymmetricAngulars, SymmetricDots,
SymmetricEuclideans, SymmetricHammings, SymmetricJaccards,
};
// Re-exports from `vector`.
pub use vector::{
Vector, VectorIndex, VectorIterator, VectorSpan, VectorSpanIterator, VectorView,
VectorViewIterator,
};
// Re-exports from `maxsim`.
pub use maxsim::{MaxSim, MaxSimPackedMatrix};
#[cfg(test)]
mod tests {
    //! Smoke tests that exercise a handful of crate-root re-exports end to end.
    use super::*;

    /// `dot` of `[1, 2, 3]` and `[4, 5, 6]` must be within 0.01 of 32.
    #[test]
    fn dot_smoke() {
        let lhs = [1.0_f32, 2.0, 3.0];
        let rhs = [4.0_f32, 5.0, 6.0];
        let product = <f32 as Dot>::dot(&lhs, &rhs).unwrap();
        let deviation = (product - 32.0).abs();
        assert!(deviation < 0.01);
    }

    /// `angular` of the two perpendicular unit vectors must be within 0.01 of 1.
    #[test]
    fn angular_smoke() {
        let x_axis = [1.0_f32, 0.0];
        let y_axis = [0.0_f32, 1.0];
        let distance = f32::angular(&x_axis, &y_axis).unwrap();
        let deviation = (distance - 1.0).abs();
        assert!(deviation < 0.01);
    }

    /// `euclidean` between the origin and `(3, 4, 0)` must be within 0.1 of 5.
    #[test]
    fn euclidean_smoke() {
        let origin = [0.0_f32, 0.0, 0.0];
        let point = [3.0_f32, 4.0, 0.0];
        let distance = f32::euclidean(&origin, &point).unwrap();
        let deviation = (distance - 5.0).abs();
        assert!(deviation < 0.1);
    }

    /// Packs a 4x16 and an 8x16 tensor filled with 1.0, checks the packed
    /// dimensions, and requires the resulting score to be finite.
    #[test]
    fn maxsim_smoke() {
        capabilities::configure_thread();

        let query_tensor = Tensor::<f32>::try_full(&[4, 16], 1.0).unwrap();
        let document_tensor = Tensor::<f32>::try_full(&[8, 16], 1.0).unwrap();

        // The views are bound to locals so they stay alive for the packing calls below.
        let query_view = query_tensor.view();
        let document_view = document_tensor.view();

        let packed_queries = MaxSimPackedMatrix::try_pack(&query_view).unwrap();
        let packed_documents = MaxSimPackedMatrix::try_pack(&document_view).unwrap();
        assert_eq!(packed_queries.dims(), (4, 16));
        assert_eq!(packed_documents.dims(), (8, 16));

        let score = packed_queries.score(&packed_documents);
        assert!(
            score.is_finite(),
            "MaxSim score must be finite, got {score}"
        );
    }

    /// Multiplies a 2x4 tensor of ones against a packed 3x4 tensor of ones and
    /// expects a `[2, 3]` result whose first entry is within 0.01 of 4.
    #[test]
    fn tensor_dots_smoke() {
        capabilities::configure_thread();

        let query_tensor = Tensor::<f32>::try_full(&[2, 4], 1.0).unwrap();
        let target_tensor = Tensor::<f32>::try_full(&[3, 4], 1.0).unwrap();
        let packed = PackedMatrix::try_pack(&target_tensor).unwrap();

        let products = query_tensor.dots_packed(&packed);
        assert_eq!(products.shape(), &[2, 3]);

        let first_entry = products.as_slice()[0];
        assert!((first_entry - 4.0).abs() < 0.01);
    }
}
#[cfg(all(test, feature = "wasm-runtime"))]
mod wasm_runtime_tests {
use std::fs;
use std::path::Path;
use wasmtime::{
Config, Engine, Extern, ExternType, Linker, Memory, MemoryType, Module, SharedMemory, Store,
};
use wasmtime_wasi::WasiCtx;
/// Locates the WASI test module on disk.
///
/// Lookup order:
/// 1. the path in the `NK_WASI_MODULE` environment variable, if it is set,
///    readable as Unicode, and names an existing path;
/// 2. `build-wasi/nk_test.wasm`;
/// 3. `build-wasi/test.wasm`.
///
/// Returns `None` when none of the candidates exists.
fn resolve_wasi_module() -> Option<String> {
    // Relative fallback locations, probed in this order.
    const FALLBACK_PATHS: [&str; 2] = ["build-wasi/nk_test.wasm", "build-wasi/test.wasm"];

    // An override that points at a missing path is ignored rather than reported.
    let from_env = std::env::var("NK_WASI_MODULE")
        .ok()
        .filter(|candidate| Path::new(candidate).exists());

    from_env.or_else(|| {
        FALLBACK_PATHS
            .iter()
            .find(|candidate| Path::new(candidate).exists())
            .map(|candidate| candidate.to_string())
    })
}
/// Runs the prebuilt WASI test module under Wasmtime and checks that it exits cleanly.
///
/// The test is a no-op (prints build instructions and returns `Ok(())`) when
/// `resolve_wasi_module` finds no module on disk.
///
/// Otherwise it:
/// 1. builds an engine with SIMD, relaxed SIMD, threads and shared memory enabled;
/// 2. registers the WASI preview-1 imports plus the host functions `env.nk_has_v128`,
///    `env.nk_has_relaxed` (both return 1) and `wasi.thread-spawn` (always returns -1);
/// 3. satisfies an `env.memory` import, if the module declares one, with a freshly
///    created shared or non-shared memory matching the declared limits;
/// 4. instantiates the module and calls `_start`.
///
/// # Errors
///
/// Returns any Wasmtime or I/O error raised while loading, linking, instantiating or
/// running the module, except for an `I32Exit`, which is checked with an assertion
/// instead.
///
/// # Panics
///
/// Panics if the module exits with a non-zero WASI exit code.
#[test]
fn wasi_with_wasmtime() -> wasmtime::Result<()> {
    let Some(wasm_path) = resolve_wasi_module() else {
        // Missing build artefacts are treated as "skip", not as a failure.
        eprintln!("WASI build not found. Run:");
        eprintln!(" export WASI_SDK_PATH=~/wasi-sdk");
        eprintln!(" cmake -B build-wasi -DCMAKE_TOOLCHAIN_FILE=cmake/toolchain-wasi.cmake -DNK_BUILD_TEST=ON");
        eprintln!(" cmake --build build-wasi --target nk_test");
        return Ok(());
    };
    println!("Loading WASI module from {}", wasm_path);

    let mut config = Config::new();
    config.wasm_simd(true);
    config.wasm_relaxed_simd(true);
    config.wasm_threads(true);
    config.shared_memory(true);
    let engine = Engine::new(&config)?;

    let mut linker = Linker::new(&engine);
    let wasi = WasiCtx::builder().inherit_stdio().inherit_env().build_p1();
    let mut store = Store::new(&engine, wasi);
    wasmtime_wasi::p1::add_to_linker_sync(&mut linker, |s| s)?;

    // Host-provided capability probes: both report the feature as available.
    linker.func_wrap("env", "nk_has_v128", || -> i32 {
        println!(" nk_has_v128() called from WASM -> returning 1");
        1
    })?;
    linker.func_wrap("env", "nk_has_relaxed", || -> i32 {
        println!(" nk_has_relaxed() called from WASM -> returning 1");
        1
    })?;
    // Thread spawning is stubbed out: every request is answered with -1.
    linker.func_wrap("wasi", "thread-spawn", |_start_arg: i32| -> i32 { -1 })?;

    let wasm_bytes = fs::read(&wasm_path)?;
    let module = Module::new(&engine, wasm_bytes)?;

    // If the module imports `env.memory`, create a memory with the declared limits.
    for import in module.imports() {
        if import.module() != "env" || import.name() != "memory" {
            continue;
        }
        let ExternType::Memory(memory_ty) = import.ty() else {
            continue;
        };
        let minimum = u32::try_from(memory_ty.minimum()).map_err(|_| {
            wasmtime::Error::msg(format!(
                "memory minimum {} does not fit in u32",
                memory_ty.minimum()
            ))
        })?;
        // The maximum is optional here; only a *declared* maximum has to fit in `u32`.
        let maximum = memory_ty
            .maximum()
            .map(|maximum| {
                u32::try_from(maximum).map_err(|_| {
                    wasmtime::Error::msg(format!("memory maximum {maximum} does not fit in u32"))
                })
            })
            .transpose()?;
        if memory_ty.is_shared() {
            // `MemoryType::shared` takes a mandatory maximum, so its absence is an
            // error on this branch only.
            let maximum = maximum.ok_or_else(|| {
                wasmtime::Error::msg("shared memory import is missing a maximum")
            })?;
            let memory = SharedMemory::new(&engine, MemoryType::shared(minimum, maximum))?;
            linker.define(&store, "env", "memory", Extern::from(memory))?;
        } else {
            // A non-shared memory may be unbounded, so `None` is passed through as-is.
            let memory = Memory::new(&mut store, MemoryType::new(minimum, maximum))?;
            linker.define(&store, "env", "memory", Extern::from(memory))?;
        }
    }

    println!("Instantiating WASM module...");
    let instance = linker.instantiate(&mut store, &module)?;
    let start = instance.get_typed_func::<(), ()>(&mut store, "_start")?;

    println!("Running WASM tests...");
    match start.call(&mut store, ()) {
        Ok(()) => {}
        Err(e) => {
            // An `I32Exit` error carries the module's exit code; only 0 is accepted.
            if let Some(exit) = e.downcast_ref::<wasmtime_wasi::I32Exit>() {
                assert_eq!(exit.0, 0, "WASI tests failed with exit code {}", exit.0);
            } else {
                return Err(e);
            }
        }
    }
    println!("WASM tests completed successfully");
    Ok(())
}
/// Checks that the two capability host functions can be registered on a linker.
///
/// Only registration is exercised: no module is loaded or instantiated here.
#[test]
fn capability_imports() -> wasmtime::Result<()> {
    println!("Testing capability import mechanism...");
    let engine = Engine::default();
    let mut host_imports: Linker<()> = Linker::new(&engine);

    // Each entry is (import name, constant the host function returns), registered
    // under the `env` module in this order.
    for (import_name, reported) in [("nk_has_v128", 1_i32), ("nk_has_relaxed", 0_i32)] {
        host_imports.func_wrap("env", import_name, move || -> i32 { reported })?;
    }

    println!(" ✓ Capability imports defined successfully");
    Ok(())
}
}