1use core::pin::Pin;
2
3pub mod analytics;
4pub mod ring_buffer;
5
/// First- and second-order sensitivities for one logit input `x`.
///
/// The struct is `#[repr(C)]` so a `&mut [GreekOut]` can be handed straight to
/// the C batch kernel: two consecutive `f64`s, 16 bytes, 8-byte aligned.
#[repr(C)]
#[derive(Debug, Default, Clone, Copy)]
pub struct GreekOut {
    /// First sensitivity reported by the kernel (presumably dp/dx — confirm against the C source).
    pub delta_x: f64,
    /// Second sensitivity reported by the kernel (presumably d²p/dx² — confirm against the C source).
    pub gamma_x: f64,
}
12
// Raw bindings to the C kernel library. Every call through these is `unsafe`;
// the safe wrappers in this module are responsible for checking that every
// pointer refers to at least `n` elements before calling a batch kernel.
unsafe extern "C" {
    /// Scalar sigmoid kernel (by-value in, by-value out; no pointers).
    #[link_name = "kernel_sigmoid"]
    fn ffi_kernel_sigmoid(x: f64) -> f64;
    /// Scalar logit kernel, the inverse of `kernel_sigmoid` (by-value in, by-value out).
    #[link_name = "kernel_logit"]
    fn ffi_kernel_logit(p: f64) -> f64;
    /// Batch sigmoid: reads from `x` and writes to `out_p`, both sized by the single count `n`.
    #[link_name = "kernel_sigmoid_batch"]
    fn ffi_kernel_sigmoid_batch(x: *const f64, out_p: *mut f64, n: usize);
    /// Batch logit: reads from `p` and writes to `out_x`, both sized by the single count `n`.
    #[link_name = "kernel_logit_batch"]
    fn ffi_kernel_logit_batch(p: *const f64, out_x: *mut f64, n: usize);
    /// Scalar Greeks: writes one `f64` through each out-pointer.
    /// NOTE(review): assumes the kernel does not retain the pointers after returning — confirm.
    #[link_name = "kernel_greeks_from_logit"]
    fn ffi_kernel_greeks_from_logit(x: f64, delta_x: *mut f64, gamma_x: *mut f64);
    /// Batch Greeks into `n` `GreekOut` records. Relies on `GreekOut` being
    /// `#[repr(C)]` and matching the C-side struct layout (assumed two
    /// consecutive doubles — confirm against the C header).
    #[link_name = "kernel_greeks_batch"]
    fn ffi_kernel_greeks_batch(x: *const f64, out: *mut GreekOut, n: usize);
    /// Quote kernel: six input arrays and two output arrays (`bid_p`, `ask_p`),
    /// all sharing the one element count `n`.
    #[link_name = "calculate_quotes_logit"]
    fn ffi_calculate_quotes_logit(
        x_t: *const f64,
        q_t: *const f64,
        sigma_b: *const f64,
        gamma: *const f64,
        tau: *const f64,
        k: *const f64,
        bid_p: *mut f64,
        ask_p: *mut f64,
        n: usize,
    );
}
39
40#[inline]
41pub fn sigmoid(x: f64) -> f64 {
42 unsafe { ffi_kernel_sigmoid(x) }
43}
44
45#[inline]
46pub fn logit(p: f64) -> f64 {
47 unsafe { ffi_kernel_logit(p) }
48}
49
50#[inline]
51pub fn greeks_from_logit(x: f64) -> GreekOut {
52 let mut out = GreekOut::default();
53 unsafe {
54 ffi_kernel_greeks_from_logit(x, &mut out.delta_x, &mut out.gamma_x);
55 }
56 out
57}
58
59pub fn sigmoid_batch(x: &[f64], out_p: &mut [f64]) {
60 sigmoid_batch_pinned(Pin::new(x), Pin::new(out_p));
61}
62
63pub fn logit_batch(p: &[f64], out_x: &mut [f64]) {
64 logit_batch_pinned(Pin::new(p), Pin::new(out_x));
65}
66
67pub fn greeks_batch(x: &[f64], out: &mut [GreekOut]) {
68 greeks_batch_pinned(Pin::new(x), Pin::new(out));
69}
70
71#[allow(clippy::too_many_arguments)]
72pub fn calculate_quotes_logit(
73 x_t: &[f64],
74 q_t: &[f64],
75 sigma_b: &[f64],
76 gamma: &[f64],
77 tau: &[f64],
78 k: &[f64],
79 bid_p: &mut [f64],
80 ask_p: &mut [f64],
81) {
82 calculate_quotes_logit_pinned(
83 Pin::new(x_t),
84 Pin::new(q_t),
85 Pin::new(sigma_b),
86 Pin::new(gamma),
87 Pin::new(tau),
88 Pin::new(k),
89 Pin::new(bid_p),
90 Pin::new(ask_p),
91 );
92}
93
94pub fn sigmoid_batch_pinned(x: Pin<&[f64]>, mut out_p: Pin<&mut [f64]>) {
95 let x_ref = x.get_ref();
96 let out_ref = out_p.as_mut().get_mut();
97 assert_eq!(
98 x_ref.len(),
99 out_ref.len(),
100 "sigmoid_batch: input and output lengths must match"
101 );
102
103 unsafe {
104 ffi_kernel_sigmoid_batch(x_ref.as_ptr(), out_ref.as_mut_ptr(), x_ref.len());
105 }
106}
107
108pub fn logit_batch_pinned(p: Pin<&[f64]>, mut out_x: Pin<&mut [f64]>) {
109 let p_ref = p.get_ref();
110 let out_ref = out_x.as_mut().get_mut();
111 assert_eq!(
112 p_ref.len(),
113 out_ref.len(),
114 "logit_batch: input and output lengths must match"
115 );
116
117 unsafe {
118 ffi_kernel_logit_batch(p_ref.as_ptr(), out_ref.as_mut_ptr(), p_ref.len());
119 }
120}
121
122pub fn greeks_batch_pinned(x: Pin<&[f64]>, mut out: Pin<&mut [GreekOut]>) {
123 let x_ref = x.get_ref();
124 let out_ref = out.as_mut().get_mut();
125 assert_eq!(
126 x_ref.len(),
127 out_ref.len(),
128 "greeks_batch: input and output lengths must match"
129 );
130
131 unsafe {
132 ffi_kernel_greeks_batch(x_ref.as_ptr(), out_ref.as_mut_ptr(), x_ref.len());
133 }
134}
135
136#[allow(clippy::too_many_arguments)]
137pub fn calculate_quotes_logit_pinned(
138 x_t: Pin<&[f64]>,
139 q_t: Pin<&[f64]>,
140 sigma_b: Pin<&[f64]>,
141 gamma: Pin<&[f64]>,
142 tau: Pin<&[f64]>,
143 k: Pin<&[f64]>,
144 mut bid_p: Pin<&mut [f64]>,
145 mut ask_p: Pin<&mut [f64]>,
146) {
147 let x_ref = x_t.get_ref();
148 let q_ref = q_t.get_ref();
149 let sigma_ref = sigma_b.get_ref();
150 let gamma_ref = gamma.get_ref();
151 let tau_ref = tau.get_ref();
152 let k_ref = k.get_ref();
153 let bid_ref = bid_p.as_mut().get_mut();
154 let ask_ref = ask_p.as_mut().get_mut();
155
156 let n = x_ref.len();
157 assert_eq!(
158 q_ref.len(),
159 n,
160 "calculate_quotes_logit: q_t length mismatch"
161 );
162 assert_eq!(
163 sigma_ref.len(),
164 n,
165 "calculate_quotes_logit: sigma_b length mismatch"
166 );
167 assert_eq!(
168 gamma_ref.len(),
169 n,
170 "calculate_quotes_logit: gamma length mismatch"
171 );
172 assert_eq!(
173 tau_ref.len(),
174 n,
175 "calculate_quotes_logit: tau length mismatch"
176 );
177 assert_eq!(k_ref.len(), n, "calculate_quotes_logit: k length mismatch");
178 assert_eq!(
179 bid_ref.len(),
180 n,
181 "calculate_quotes_logit: bid_p length mismatch"
182 );
183 assert_eq!(
184 ask_ref.len(),
185 n,
186 "calculate_quotes_logit: ask_p length mismatch"
187 );
188
189 unsafe {
190 ffi_calculate_quotes_logit(
191 x_ref.as_ptr(),
192 q_ref.as_ptr(),
193 sigma_ref.as_ptr(),
194 gamma_ref.as_ptr(),
195 tau_ref.as_ptr(),
196 k_ref.as_ptr(),
197 bid_ref.as_mut_ptr(),
198 ask_ref.as_mut_ptr(),
199 n,
200 );
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::{align_of, size_of};

    unsafe extern "C" {
        // Internal reference implementation exported by the C library; the
        // dispatch test compares the public entry point against it.
        #[link_name = "pm_internal_calculate_quotes_logit_portable"]
        fn ffi_internal_calculate_quotes_logit_portable(
            x_t: *const f64,
            q_t: *const f64,
            sigma_b: *const f64,
            gamma: *const f64,
            tau: *const f64,
            k: *const f64,
            bid_p: *mut f64,
            ask_p: *mut f64,
            n: usize,
        );
    }

    /// Asserts `|actual - expected| <= tol`. A NaN difference always fails.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        let diff = (actual - expected).abs();
        assert!(
            diff <= tol,
            "expected {expected:.15}, got {actual:.15}, diff {diff:.15} > {tol:.15}"
        );
    }

    /// Runtime check for AVX-512F support; always `false` on non-x86 targets.
    fn has_avx512f() -> bool {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            std::arch::is_x86_feature_detected!("avx512f")
        }

        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
        {
            false
        }
    }

    /// Runs the portable reference quote kernel and returns `(bid, ask)`.
    ///
    /// Panics if the inputs do not all have the same length. The kernel is
    /// given the single count `x_t.len()` for every array, so a shorter slice
    /// would otherwise be read out of bounds.
    fn run_quote_portable(
        x_t: &[f64],
        q_t: &[f64],
        sigma_b: &[f64],
        gamma: &[f64],
        tau: &[f64],
        k: &[f64],
    ) -> (Vec<f64>, Vec<f64>) {
        let n = x_t.len();
        for input in [q_t, sigma_b, gamma, tau, k] {
            assert_eq!(input.len(), n, "run_quote_portable: input length mismatch");
        }
        let mut bid_p = vec![0.0; n];
        let mut ask_p = vec![0.0; n];

        // SAFETY:
        // - Every input was just checked to hold `n` elements, and both outputs
        //   were allocated with `n` elements.
        // - The outputs are distinct, exclusively owned buffers.
        // - All buffers outlive the call.
        unsafe {
            ffi_internal_calculate_quotes_logit_portable(
                x_t.as_ptr(),
                q_t.as_ptr(),
                sigma_b.as_ptr(),
                gamma.as_ptr(),
                tau.as_ptr(),
                k.as_ptr(),
                bid_p.as_mut_ptr(),
                ask_p.as_mut_ptr(),
                n,
            );
        }

        (bid_p, ask_p)
    }

    #[test]
    fn greek_layout_is_compact() {
        assert_eq!(size_of::<GreekOut>(), 16);
        assert_eq!(align_of::<GreekOut>(), 8);
    }

    #[test]
    fn logit_sigmoid_roundtrip_is_stable() {
        for x in [-20.0, -10.0, -2.0, 0.0, 2.0, 10.0, 20.0] {
            let p = sigmoid(x);
            let roundtrip = logit(p);
            assert_close(roundtrip, x, 1e-7);
        }
    }

    #[test]
    fn sigmoid_batch_matches_scalar() {
        let x = [-20.0, -3.0, -0.5, 0.0, 0.5, 4.0, 12.0, 20.0, -15.0];
        let mut out = vec![0.0; x.len()];
        sigmoid_batch(&x, &mut out);

        for (actual, input) in out.iter().zip(x) {
            assert_close(*actual, sigmoid(input), 1e-15);
        }
    }

    #[test]
    fn greeks_batch_matches_scalar_reference() {
        let x = [-12.0, -1.5, 0.0, 1.5, 12.0, 20.0, -20.0];
        let mut out = vec![GreekOut::default(); x.len()];
        greeks_batch(&x, &mut out);

        for (actual, input) in out.iter().zip(x) {
            let expected = greeks_from_logit(input);
            assert_close(actual.delta_x, expected.delta_x, 1e-15);
            assert_close(actual.gamma_x, expected.gamma_x, 1e-15);
        }
    }

    #[test]
    fn calculate_quotes_dispatch_matches_portable_reference() {
        // Only meaningful when the SIMD path can actually be selected; without
        // AVX-512F the dispatcher presumably falls back to the portable path
        // itself (confirm against the C dispatcher).
        if !has_avx512f() {
            return;
        }

        for len in [3usize, 8, 17] {
            let x_t: Vec<_> = (0..len)
                .map(|i| -2.5 + (i as f64) * 0.4 + if i % 4 == 0 { 6.0 } else { 0.0 })
                .collect();
            let q_t: Vec<_> = (0..len).map(|i| (i as f64) - 3.0).collect();
            let sigma_b: Vec<_> = (0..len).map(|i| 0.1 + (i as f64) * 0.02).collect();
            let gamma: Vec<_> = (0..len)
                .map(|i| {
                    if i % 5 == 0 {
                        -0.05
                    } else {
                        0.03 + (i as f64) * 0.01
                    }
                })
                .collect();
            let tau: Vec<_> = (0..len)
                .map(|i| {
                    if i % 6 == 0 {
                        -0.1
                    } else {
                        0.2 + (i as f64) * 0.03
                    }
                })
                .collect();
            let k: Vec<_> = (0..len)
                .map(|i| {
                    if i % 7 == 0 {
                        0.0
                    } else {
                        1.0 + (i as f64) * 0.1
                    }
                })
                .collect();

            let mut bid_p = vec![0.0; len];
            let mut ask_p = vec![0.0; len];
            calculate_quotes_logit(
                &x_t, &q_t, &sigma_b, &gamma, &tau, &k, &mut bid_p, &mut ask_p,
            );

            let (expected_bid, expected_ask) =
                run_quote_portable(&x_t, &q_t, &sigma_b, &gamma, &tau, &k);

            let actual = bid_p.iter().zip(&ask_p);
            let expected = expected_bid.iter().zip(&expected_ask);
            for ((&bid, &ask), (&want_bid, &want_ask)) in actual.zip(expected) {
                assert_close(bid, want_bid, 1e-12);
                assert_close(ask, want_ask, 1e-12);
            }
        }
    }
}