omni_ffi/lib.rs
//! omni-ffi — zero-cost cxx bridge between the Rust orchestrator and the
//! OmniPulse Module-I WST math engine.
//!
//! On the default (CPU) build the bridge calls `cpu_wst_forward()` from
//! `cpu_wst_engine.h`: a real radix-2 Cooley-Tukey FFT + analytic Morlet
//! filter bank + depth-m scattering cascade. No mocks.
//!
//! When the `cuda` feature is enabled, the bridge links `cudart` and `cufft`
//! and (in the GPU build) dispatches to the templated
//! `WSTEngine<HopperTag, J, Q>` defined in `wst_kernel.cuh`.
//!
//! ## Memory ownership
//!
//! [`execute_fingerprint_pass`] (a thin wrapper over the crate-private bridge
//! function `ffi::run_wst_pipeline`) returns a [`WSTResult`] whose
//! `fingerprint_ptr` owns either:
//! * a `new float[]` heap allocation (CPU build), or
//! * a `cudaMalloc` device allocation (CUDA build).
//!
//! The Rust caller MUST release it by calling [`free_fingerprint`] (which
//! wraps `ffi::free_wst_result`) exactly once. Forgetting to do so leaks heap
//! or VRAM; calling it twice is undefined behavior. Because [`WSTResult`] is
//! `Copy`, "exactly once" counts every copy of the same result.
22
// cxx bridge between the Rust orchestrator and the C++ WST engine.
//
// Declares the POD `WSTResult` shared with C++ and the two C++ entry points
// pulled in from `wst_bridge.h`. The module is private: outside this crate
// the result type is reachable through the `WSTResult` re-export and the
// functions through the public `execute_fingerprint_pass` /
// `free_fingerprint` wrappers.
#[cxx::bridge]
mod ffi {
    /// Plain-old-data result struct shared with C++.
    ///
    /// Layout is locked by cxx: three `uint64_t` fields, in this order, no
    /// padding. The matching C++ definition is generated by cxx-build into
    /// `target/cxxbridge/omni-ffi/src/lib.rs.h` and re-used by
    /// `cpp/wst_bridge.h` (which `#include`s the generated header).
    ///
    /// Although the struct derives `Copy`, it acts as an owning handle:
    /// every copy names the same allocation, so across all copies
    /// [`free_wst_result`] may be called only once.
    #[derive(Clone, Copy, Debug)]
    struct WSTResult {
        /// Opaque pointer to the output scattering tensor.
        ///
        /// CPU build: a `float*` from `new float[]` — release with
        /// [`free_wst_result`] (`delete[]`).
        ///
        /// CUDA build: a `CUdeviceptr` from `cudaMalloc` — release with
        /// [`free_wst_result`] (`cudaFree`).
        ///
        /// `0` is the only value that [`free_wst_result`] treats as a no-op.
        fingerprint_ptr: u64,

        /// Number of `float32` coefficients in the output tensor
        /// (`signal_len * batch_size`).
        coeff_count: u64,

        /// Wall-clock execution time of the scattering cascade in
        /// microseconds. Used by the FinOps autoscaler for cost attribution.
        exec_time_us: u64,
    }

    unsafe extern "C++" {
        // C++ declarations of the two functions below; this header also
        // `#include`s the cxx-generated definition of `WSTResult`.
        include!("wst_bridge.h");

        /// Run one WST/JTFS scattering pass against a Plasma-resident input
        /// buffer.
        ///
        /// # Safety
        ///
        /// The caller MUST uphold every one of the following invariants —
        /// any violation is undefined behavior:
        ///
        /// * `input_plasma_ptr` is a valid, host-readable pointer to a
        ///   contiguous `f32` array of exactly `signal_len * batch_size`
        ///   elements. In production this is the base address of an Apache
        ///   Arrow Plasma `mmap` region. On CUDA builds the same address
        ///   must already be registered with the CUDA driver via
        ///   `cudaHostRegister` so it is reachable from device kernels via
        ///   UVA.
        /// * The Plasma object that backs `input_plasma_ptr` must remain
        ///   live (sealed and not evicted) for the entire duration of this
        ///   call.
        /// * All integer parameters must be strictly positive. The C++ side
        ///   throws `std::runtime_error` on non-positive values; cxx
        ///   surfaces that as `Err(cxx::Exception)`.
        ///   NOTE(review): this bullet describes a checked error rather
        ///   than UB, which contradicts "any violation is undefined
        ///   behavior" above — confirm against `wst_bridge.h` and reword
        ///   one of the two.
        ///
        /// On success the returned [`WSTResult`] owns a heap (CPU) or
        /// device (CUDA) allocation that MUST be released with
        /// [`free_wst_result`] exactly once.
        unsafe fn run_wst_pipeline(
            // Base address of the contiguous `f32` input; see `# Safety`.
            input_plasma_ptr: u64,
            // Samples per signal; the input holds `signal_len * batch_size`
            // `f32` values.
            signal_len: i32,
            // Number of signals in the batch.
            batch_size: i32,
            // `J` and `Q` keep upper-case names, presumably to mirror the
            // `J`/`Q` template parameters of `WSTEngine<HopperTag, J, Q>` on
            // the CUDA build. NOTE(review): their exact meaning (e.g. octaves
            // and wavelets per octave) is defined on the C++ side — confirm.
            J: i32,
            Q: i32,
            // Depth of the scattering cascade (presumably the `m` of the
            // "depth-m scattering cascade" in the crate docs — confirm).
            depth: i32,
            // Selects the JTFS variant instead of plain WST; the public
            // convenience wrapper always passes `false`.
            use_jtfs: bool,
        ) -> Result<WSTResult>;

        /// Release the tensor allocation backing `result`.
        ///
        /// # Safety
        ///
        /// `result` must have been returned by a successful call to
        /// [`run_wst_pipeline`] on the same process and must not have been
        /// passed to this function previously. Calling this with any other
        /// value, or twice with the same value, is undefined behavior
        /// (double-free / use-after-free of heap or VRAM). Because
        /// `WSTResult` is `Copy`, "previously" includes any copy of it.
        ///
        /// Calling with a `WSTResult` whose `fingerprint_ptr == 0` is a
        /// no-op — that's the only safe sentinel.
        unsafe fn free_wst_result(result: WSTResult);
    }
}
104
105pub use ffi::WSTResult;
106
107/// Convenience wrapper that runs a single-batch, depth-2, plain-WST pass.
108///
109/// # Safety
110///
111/// `plasma_id` must satisfy the same invariants as
112/// [`ffi::run_wst_pipeline::input_plasma_ptr`](ffi::run_wst_pipeline): a
113/// live, contiguous `f32[signal_len]` Apache Arrow Plasma mmap base
114/// address. See the [`ffi::run_wst_pipeline`] docs for the full contract.
115///
116/// The returned [`WSTResult`] owns an allocation that must be released
117/// exactly once via [`free_fingerprint`].
118pub unsafe fn execute_fingerprint_pass(
119 plasma_id: u64,
120 signal_len: i32,
121 j: i32,
122 q: i32,
123) -> Result<WSTResult, cxx::Exception> {
124 // SAFETY: forwarded directly to ffi::run_wst_pipeline; the caller of
125 // this function has already promised the Plasma-pointer invariants.
126 unsafe { ffi::run_wst_pipeline(plasma_id, signal_len, 1, j, q, 2, false) }
127}
128
129/// Release the tensor allocation backing a [`WSTResult`].
130///
131/// # Safety
132///
133/// `result` must have come from [`execute_fingerprint_pass`] (or
134/// [`ffi::run_wst_pipeline`] directly) on this process and must not have
135/// been passed here already. See [`ffi::free_wst_result`] for the full
136/// contract.
137pub unsafe fn free_fingerprint(result: WSTResult) {
138 // SAFETY: contract delegated to the caller above.
139 unsafe { ffi::free_wst_result(result) }
140}