Skip to main content

omni_ffi/
lib.rs

1//! omni-ffi — zero-cost cxx bridge between the Rust orchestrator and the
2//! OmniPulse Module-I WST math engine.
3//!
4//! On the default (CPU) build the bridge calls `cpu_wst_forward()` from
5//! `cpu_wst_engine.h`: a real Radix-2 Cooley-Tukey FFT + analytic Morlet
6//! filter bank + depth-m scattering cascade. No mocks.
7//!
8//! When the `cuda` feature is enabled the bridge links `cudart` and `cufft`
9//! and (in the GPU build) dispatches to the templated
10//! `WSTEngine<HopperTag, J, Q>` defined in `wst_kernel.cuh`.
11//!
12//! ## Memory ownership
13//!
14//! [`run_wst_pipeline`](ffi::run_wst_pipeline) returns a [`WSTResult`] whose
15//! `fingerprint_ptr` owns either:
16//!   * a `new float[]` heap allocation (CPU build), or
17//!   * a `cudaMalloc` device allocation (CUDA build).
18//!
19//! The Rust caller MUST release it by calling
20//! [`free_wst_result`](ffi::free_wst_result) exactly once. Forgetting to do
21//! so leaks heap or VRAM. Calling it twice is undefined behavior.
22
23#[cxx::bridge]
24mod ffi {
25    /// Plain-old-data result struct shared with C++.
26    ///
27    /// Layout is locked by cxx: three `uint64_t` fields, in this order, no
28    /// padding. The matching C++ definition is generated by cxx-build into
29    /// `target/cxxbridge/omni-ffi/src/lib.rs.h` and re-used by
30    /// `cpp/wst_bridge.h` (which `#include`s the generated header).
31    #[derive(Clone, Copy, Debug)]
32    struct WSTResult {
33        /// Opaque pointer to the output scattering tensor.
34        ///
35        /// CPU build: a `float*` from `new float[]` — release with
36        /// [`free_wst_result`] (`delete[]`).
37        ///
38        /// CUDA build: a `CUdeviceptr` from `cudaMalloc` — release with
39        /// [`free_wst_result`] (`cudaFree`).
40        fingerprint_ptr: u64,
41
42        /// Number of `float32` coefficients in the output tensor
43        /// (`signal_len * batch_size`).
44        coeff_count: u64,
45
46        /// Wall-clock execution time of the scattering cascade in
47        /// microseconds. Used by the FinOps autoscaler for cost attribution.
48        exec_time_us: u64,
49    }
50
51    unsafe extern "C++" {
52        include!("wst_bridge.h");
53
54        /// Run one WST/JTFS scattering pass against a Plasma-resident input
55        /// buffer.
56        ///
57        /// # Safety
58        ///
59        /// The caller MUST uphold every one of the following invariants —
60        /// any violation is undefined behavior:
61        ///
62        /// * `input_plasma_ptr` is a valid, host-readable pointer to a
63        ///   contiguous `f32` array of exactly `signal_len * batch_size`
64        ///   elements. In production this is the base address of an Apache
65        ///   Arrow Plasma `mmap` region. On CUDA builds the same address
66        ///   must already be registered with the CUDA driver via
67        ///   `cudaHostRegister` so it is reachable from device kernels via
68        ///   UVA.
69        /// * The Plasma object that backs `input_plasma_ptr` must remain
70        ///   live (sealed and not evicted) for the entire duration of this
71        ///   call.
72        /// * All integer parameters must be strictly positive. The C++ side
73        ///   throws `std::runtime_error` on non-positive values; cxx
74        ///   surfaces that as `Err(cxx::Exception)`.
75        ///
76        /// On success the returned [`WSTResult`] owns a heap (CPU) or
77        /// device (CUDA) allocation that MUST be released with
78        /// [`free_wst_result`] exactly once.
79        unsafe fn run_wst_pipeline(
80            input_plasma_ptr: u64,
81            signal_len: i32,
82            batch_size: i32,
83            J: i32,
84            Q: i32,
85            depth: i32,
86            use_jtfs: bool,
87        ) -> Result<WSTResult>;
88
89        /// Release the tensor allocation backing `result`.
90        ///
91        /// # Safety
92        ///
93        /// `result` must have been returned by a successful call to
94        /// [`run_wst_pipeline`] on the same process and must not have been
95        /// passed to this function previously. Calling this with any other
96        /// value, or twice with the same value, is undefined behavior
97        /// (double-free / use-after-free of heap or VRAM).
98        ///
99        /// Calling with a `WSTResult` whose `fingerprint_ptr == 0` is a
100        /// no-op — that's the only safe sentinel.
101        unsafe fn free_wst_result(result: WSTResult);
102    }
103}
104
105pub use ffi::WSTResult;
106
107/// Convenience wrapper that runs a single-batch, depth-2, plain-WST pass.
108///
109/// # Safety
110///
111/// `plasma_id` must satisfy the same invariants as
112/// [`ffi::run_wst_pipeline::input_plasma_ptr`](ffi::run_wst_pipeline): a
113/// live, contiguous `f32[signal_len]` Apache Arrow Plasma mmap base
114/// address. See the [`ffi::run_wst_pipeline`] docs for the full contract.
115///
116/// The returned [`WSTResult`] owns an allocation that must be released
117/// exactly once via [`free_fingerprint`].
118pub unsafe fn execute_fingerprint_pass(
119    plasma_id: u64,
120    signal_len: i32,
121    j: i32,
122    q: i32,
123) -> Result<WSTResult, cxx::Exception> {
124    // SAFETY: forwarded directly to ffi::run_wst_pipeline; the caller of
125    // this function has already promised the Plasma-pointer invariants.
126    unsafe { ffi::run_wst_pipeline(plasma_id, signal_len, 1, j, q, 2, false) }
127}
128
129/// Release the tensor allocation backing a [`WSTResult`].
130///
131/// # Safety
132///
133/// `result` must have come from [`execute_fingerprint_pass`] (or
134/// [`ffi::run_wst_pipeline`] directly) on this process and must not have
135/// been passed here already. See [`ffi::free_wst_result`] for the full
136/// contract.
137pub unsafe fn free_fingerprint(result: WSTResult) {
138    // SAFETY: contract delegated to the caller above.
139    unsafe { ffi::free_wst_result(result) }
140}