Skip to main content

polymarket_kernel/
lib.rs

1use core::pin::Pin;
2
3pub mod analytics;
4pub mod ring_buffer;
5
/// Pair of sensitivities reported by the native greeks kernels.
///
/// `#[repr(C)]` fixes the field order and layout. Slices of this struct are
/// handed to the native `kernel_greeks_batch` routine as a raw
/// `*mut GreekOut`, so the layout must match what the C side writes. The
/// crate's tests expect 16 bytes with 8-byte alignment. Do not reorder, add,
/// or remove fields without changing the C definition in lockstep.
///
/// Equality follows plain `f64` semantics, so a value containing NaN never
/// compares equal, not even to itself.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct GreekOut {
    /// First value written by the kernel. By its name, presumably the first
    /// derivative with respect to the logit `x`; confirm against the C source.
    pub delta_x: f64,
    /// Second value written by the kernel. By its name, presumably the second
    /// derivative with respect to `x`; confirm against the C source.
    pub gamma_x: f64,
}
12
13unsafe extern "C" {
14    #[link_name = "kernel_sigmoid"]
15    fn ffi_kernel_sigmoid(x: f64) -> f64;
16    #[link_name = "kernel_logit"]
17    fn ffi_kernel_logit(p: f64) -> f64;
18    #[link_name = "kernel_sigmoid_batch"]
19    fn ffi_kernel_sigmoid_batch(x: *const f64, out_p: *mut f64, n: usize);
20    #[link_name = "kernel_logit_batch"]
21    fn ffi_kernel_logit_batch(p: *const f64, out_x: *mut f64, n: usize);
22    #[link_name = "kernel_greeks_from_logit"]
23    fn ffi_kernel_greeks_from_logit(x: f64, delta_x: *mut f64, gamma_x: *mut f64);
24    #[link_name = "kernel_greeks_batch"]
25    fn ffi_kernel_greeks_batch(x: *const f64, out: *mut GreekOut, n: usize);
26    #[link_name = "calculate_quotes_logit"]
27    fn ffi_calculate_quotes_logit(
28        x_t: *const f64,
29        q_t: *const f64,
30        sigma_b: *const f64,
31        gamma: *const f64,
32        tau: *const f64,
33        k: *const f64,
34        bid_p: *mut f64,
35        ask_p: *mut f64,
36        n: usize,
37    );
38}
39
40#[inline]
41pub fn sigmoid(x: f64) -> f64 {
42    unsafe { ffi_kernel_sigmoid(x) }
43}
44
45#[inline]
46pub fn logit(p: f64) -> f64 {
47    unsafe { ffi_kernel_logit(p) }
48}
49
50#[inline]
51pub fn greeks_from_logit(x: f64) -> GreekOut {
52    let mut out = GreekOut::default();
53    unsafe {
54        ffi_kernel_greeks_from_logit(x, &mut out.delta_x, &mut out.gamma_x);
55    }
56    out
57}
58
59pub fn sigmoid_batch(x: &[f64], out_p: &mut [f64]) {
60    sigmoid_batch_pinned(Pin::new(x), Pin::new(out_p));
61}
62
63pub fn logit_batch(p: &[f64], out_x: &mut [f64]) {
64    logit_batch_pinned(Pin::new(p), Pin::new(out_x));
65}
66
67pub fn greeks_batch(x: &[f64], out: &mut [GreekOut]) {
68    greeks_batch_pinned(Pin::new(x), Pin::new(out));
69}
70
71#[allow(clippy::too_many_arguments)]
72pub fn calculate_quotes_logit(
73    x_t: &[f64],
74    q_t: &[f64],
75    sigma_b: &[f64],
76    gamma: &[f64],
77    tau: &[f64],
78    k: &[f64],
79    bid_p: &mut [f64],
80    ask_p: &mut [f64],
81) {
82    calculate_quotes_logit_pinned(
83        Pin::new(x_t),
84        Pin::new(q_t),
85        Pin::new(sigma_b),
86        Pin::new(gamma),
87        Pin::new(tau),
88        Pin::new(k),
89        Pin::new(bid_p),
90        Pin::new(ask_p),
91    );
92}
93
94pub fn sigmoid_batch_pinned(x: Pin<&[f64]>, mut out_p: Pin<&mut [f64]>) {
95    let x_ref = x.get_ref();
96    let out_ref = out_p.as_mut().get_mut();
97    assert_eq!(
98        x_ref.len(),
99        out_ref.len(),
100        "sigmoid_batch: input and output lengths must match"
101    );
102
103    unsafe {
104        ffi_kernel_sigmoid_batch(x_ref.as_ptr(), out_ref.as_mut_ptr(), x_ref.len());
105    }
106}
107
108pub fn logit_batch_pinned(p: Pin<&[f64]>, mut out_x: Pin<&mut [f64]>) {
109    let p_ref = p.get_ref();
110    let out_ref = out_x.as_mut().get_mut();
111    assert_eq!(
112        p_ref.len(),
113        out_ref.len(),
114        "logit_batch: input and output lengths must match"
115    );
116
117    unsafe {
118        ffi_kernel_logit_batch(p_ref.as_ptr(), out_ref.as_mut_ptr(), p_ref.len());
119    }
120}
121
122pub fn greeks_batch_pinned(x: Pin<&[f64]>, mut out: Pin<&mut [GreekOut]>) {
123    let x_ref = x.get_ref();
124    let out_ref = out.as_mut().get_mut();
125    assert_eq!(
126        x_ref.len(),
127        out_ref.len(),
128        "greeks_batch: input and output lengths must match"
129    );
130
131    unsafe {
132        ffi_kernel_greeks_batch(x_ref.as_ptr(), out_ref.as_mut_ptr(), x_ref.len());
133    }
134}
135
136#[allow(clippy::too_many_arguments)]
137pub fn calculate_quotes_logit_pinned(
138    x_t: Pin<&[f64]>,
139    q_t: Pin<&[f64]>,
140    sigma_b: Pin<&[f64]>,
141    gamma: Pin<&[f64]>,
142    tau: Pin<&[f64]>,
143    k: Pin<&[f64]>,
144    mut bid_p: Pin<&mut [f64]>,
145    mut ask_p: Pin<&mut [f64]>,
146) {
147    let x_ref = x_t.get_ref();
148    let q_ref = q_t.get_ref();
149    let sigma_ref = sigma_b.get_ref();
150    let gamma_ref = gamma.get_ref();
151    let tau_ref = tau.get_ref();
152    let k_ref = k.get_ref();
153    let bid_ref = bid_p.as_mut().get_mut();
154    let ask_ref = ask_p.as_mut().get_mut();
155
156    let n = x_ref.len();
157    assert_eq!(
158        q_ref.len(),
159        n,
160        "calculate_quotes_logit: q_t length mismatch"
161    );
162    assert_eq!(
163        sigma_ref.len(),
164        n,
165        "calculate_quotes_logit: sigma_b length mismatch"
166    );
167    assert_eq!(
168        gamma_ref.len(),
169        n,
170        "calculate_quotes_logit: gamma length mismatch"
171    );
172    assert_eq!(
173        tau_ref.len(),
174        n,
175        "calculate_quotes_logit: tau length mismatch"
176    );
177    assert_eq!(k_ref.len(), n, "calculate_quotes_logit: k length mismatch");
178    assert_eq!(
179        bid_ref.len(),
180        n,
181        "calculate_quotes_logit: bid_p length mismatch"
182    );
183    assert_eq!(
184        ask_ref.len(),
185        n,
186        "calculate_quotes_logit: ask_p length mismatch"
187    );
188
189    unsafe {
190        ffi_calculate_quotes_logit(
191            x_ref.as_ptr(),
192            q_ref.as_ptr(),
193            sigma_ref.as_ptr(),
194            gamma_ref.as_ptr(),
195            tau_ref.as_ptr(),
196            k_ref.as_ptr(),
197            bid_ref.as_mut_ptr(),
198            ask_ref.as_mut_ptr(),
199            n,
200        );
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::{align_of, size_of};

    // Test-only binding to the portable quote routine. It serves as the
    // reference that the dispatched `calculate_quotes_logit` entry point must
    // match. Presumably this is the fallback used when AVX-512 is unavailable;
    // confirm against the C dispatcher.
    unsafe extern "C" {
        #[link_name = "pm_internal_calculate_quotes_logit_portable"]
        fn ffi_internal_calculate_quotes_logit_portable(
            x_t: *const f64,
            q_t: *const f64,
            sigma_b: *const f64,
            gamma: *const f64,
            tau: *const f64,
            k: *const f64,
            bid_p: *mut f64,
            ask_p: *mut f64,
            n: usize,
        );
    }

    /// Asserts `|actual - expected| <= tol`, printing all three values on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        let diff = (actual - expected).abs();
        assert!(
            diff <= tol,
            "expected {expected:.15}, got {actual:.15}, diff {diff:.15} > {tol:.15}"
        );
    }

    /// Runtime check for AVX-512F support; always `false` on non-x86 targets.
    fn has_avx512f() -> bool {
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            std::arch::is_x86_feature_detected!("avx512f")
        }

        #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
        {
            false
        }
    }

    /// Runs the portable reference quote routine and returns freshly allocated
    /// `(bid, ask)` vectors of length `x_t.len()`.
    ///
    /// # Panics
    ///
    /// Panics if any input's length differs from `x_t.len()`. Without this
    /// check the FFI call would read past the end of a shorter slice, which is
    /// undefined behavior reachable from a safe function.
    fn run_quote_portable(
        x_t: &[f64],
        q_t: &[f64],
        sigma_b: &[f64],
        gamma: &[f64],
        tau: &[f64],
        k: &[f64],
    ) -> (Vec<f64>, Vec<f64>) {
        let n = x_t.len();
        for (name, len) in [
            ("q_t", q_t.len()),
            ("sigma_b", sigma_b.len()),
            ("gamma", gamma.len()),
            ("tau", tau.len()),
            ("k", k.len()),
        ] {
            assert_eq!(len, n, "run_quote_portable: {name} length mismatch");
        }

        let mut bid_p = vec![0.0; n];
        let mut ask_p = vec![0.0; n];

        // SAFETY: every input slice was just checked to hold `n` elements, and
        // both outputs were allocated with length `n` and are uniquely owned.
        unsafe {
            ffi_internal_calculate_quotes_logit_portable(
                x_t.as_ptr(),
                q_t.as_ptr(),
                sigma_b.as_ptr(),
                gamma.as_ptr(),
                tau.as_ptr(),
                k.as_ptr(),
                bid_p.as_mut_ptr(),
                ask_p.as_mut_ptr(),
                n,
            );
        }

        (bid_p, ask_p)
    }

    #[test]
    fn greek_layout_is_compact() {
        assert_eq!(size_of::<GreekOut>(), 16);
        assert_eq!(align_of::<GreekOut>(), 8);
    }

    #[test]
    fn logit_sigmoid_roundtrip_is_stable() {
        for x in [-20.0, -10.0, -2.0, 0.0, 2.0, 10.0, 20.0] {
            let p = sigmoid(x);
            let roundtrip = logit(p);
            assert_close(roundtrip, x, 1e-7);
        }
    }

    #[test]
    fn sigmoid_batch_matches_scalar() {
        let x = [-20.0, -3.0, -0.5, 0.0, 0.5, 4.0, 12.0, 20.0, -15.0];
        let mut out = vec![0.0; x.len()];
        sigmoid_batch(&x, &mut out);

        for (actual, input) in out.iter().zip(x) {
            assert_close(*actual, sigmoid(input), 1e-15);
        }
    }

    #[test]
    fn greeks_batch_matches_scalar_reference() {
        let x = [-12.0, -1.5, 0.0, 1.5, 12.0, 20.0, -20.0];
        let mut out = vec![GreekOut::default(); x.len()];
        greeks_batch(&x, &mut out);

        for (actual, input) in out.iter().zip(x) {
            let expected = greeks_from_logit(input);
            assert_close(actual.delta_x, expected.delta_x, 1e-15);
            assert_close(actual.gamma_x, expected.gamma_x, 1e-15);
        }
    }

    #[test]
    fn calculate_quotes_dispatch_matches_portable_reference() {
        // Without AVX-512F there is presumably no separate vector path to
        // compare against. Report the skip instead of passing silently.
        if !has_avx512f() {
            eprintln!(
                "skipping calculate_quotes_dispatch_matches_portable_reference: \
                 AVX-512F not detected"
            );
            return;
        }

        // Lengths presumably cover a partial 8-lane vector, exactly one
        // vector, and several vectors plus a tail.
        for len in [3usize, 8, 17] {
            let x_t: Vec<_> = (0..len)
                .map(|i| -2.5 + (i as f64) * 0.4 + if i % 4 == 0 { 6.0 } else { 0.0 })
                .collect();
            let q_t: Vec<_> = (0..len).map(|i| (i as f64) - 3.0).collect();
            let sigma_b: Vec<_> = (0..len).map(|i| 0.1 + (i as f64) * 0.02).collect();
            // Some entries below are deliberately negative (gamma, tau) or
            // zero (k) so both paths see the same edge-case inputs.
            let gamma: Vec<_> = (0..len)
                .map(|i| {
                    if i % 5 == 0 {
                        -0.05
                    } else {
                        0.03 + (i as f64) * 0.01
                    }
                })
                .collect();
            let tau: Vec<_> = (0..len)
                .map(|i| {
                    if i % 6 == 0 {
                        -0.1
                    } else {
                        0.2 + (i as f64) * 0.03
                    }
                })
                .collect();
            let k: Vec<_> = (0..len)
                .map(|i| {
                    if i % 7 == 0 {
                        0.0
                    } else {
                        1.0 + (i as f64) * 0.1
                    }
                })
                .collect();

            let mut bid_p = vec![0.0; len];
            let mut ask_p = vec![0.0; len];
            calculate_quotes_logit(
                &x_t, &q_t, &sigma_b, &gamma, &tau, &k, &mut bid_p, &mut ask_p,
            );

            let (expected_bid, expected_ask) =
                run_quote_portable(&x_t, &q_t, &sigma_b, &gamma, &tau, &k);

            for i in 0..len {
                assert_close(bid_p[i], expected_bid[i], 1e-12);
                assert_close(ask_p[i], expected_ask[i], 1e-12);
            }
        }
    }
}