Skip to main content

polymarket_kernel/
lib.rs

1use core::pin::Pin;
2
/// Pair of sensitivities produced by the C kernel for a single logit input.
///
/// `#[repr(C)]` fixes the field order and layout: values of this type are written directly by the
/// C symbol `kernel_greeks_batch` through a `*mut GreekOut` (see `greeks_batch_pinned`), so the
/// Rust layout must match the C struct field-for-field.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct GreekOut {
    /// `delta_x` value reported by the kernel. Presumably the first derivative of the
    /// probability with respect to the logit `x`; NOTE(review): confirm against the C kernel.
    pub delta_x: f64,
    /// `gamma_x` value reported by the kernel. Presumably the second derivative with respect
    /// to `x`; NOTE(review): confirm against the C kernel.
    pub gamma_x: f64,
}
9
10unsafe extern "C" {
11    #[link_name = "kernel_sigmoid"]
12    fn ffi_kernel_sigmoid(x: f64) -> f64;
13    #[link_name = "kernel_logit"]
14    fn ffi_kernel_logit(p: f64) -> f64;
15    #[link_name = "kernel_sigmoid_batch"]
16    fn ffi_kernel_sigmoid_batch(x: *const f64, out_p: *mut f64, n: usize);
17    #[link_name = "kernel_logit_batch"]
18    fn ffi_kernel_logit_batch(p: *const f64, out_x: *mut f64, n: usize);
19    #[link_name = "kernel_greeks_from_logit"]
20    fn ffi_kernel_greeks_from_logit(x: f64, delta_x: *mut f64, gamma_x: *mut f64);
21    #[link_name = "kernel_greeks_batch"]
22    fn ffi_kernel_greeks_batch(x: *const f64, out: *mut GreekOut, n: usize);
23    #[link_name = "calculate_quotes_logit"]
24    fn ffi_calculate_quotes_logit(
25        x_t: *const f64,
26        q_t: *const f64,
27        sigma_b: *const f64,
28        gamma: *const f64,
29        tau: *const f64,
30        k: *const f64,
31        bid_p: *mut f64,
32        ask_p: *mut f64,
33        n: usize,
34    );
35}
36
37#[inline]
38pub fn sigmoid(x: f64) -> f64 {
39    unsafe { ffi_kernel_sigmoid(x) }
40}
41
42#[inline]
43pub fn logit(p: f64) -> f64 {
44    unsafe { ffi_kernel_logit(p) }
45}
46
47#[inline]
48pub fn greeks_from_logit(x: f64) -> GreekOut {
49    let mut out = GreekOut::default();
50    unsafe {
51        ffi_kernel_greeks_from_logit(x, &mut out.delta_x, &mut out.gamma_x);
52    }
53    out
54}
55
56pub fn sigmoid_batch(x: &[f64], out_p: &mut [f64]) {
57    sigmoid_batch_pinned(Pin::new(x), Pin::new(out_p));
58}
59
60pub fn logit_batch(p: &[f64], out_x: &mut [f64]) {
61    logit_batch_pinned(Pin::new(p), Pin::new(out_x));
62}
63
64pub fn greeks_batch(x: &[f64], out: &mut [GreekOut]) {
65    greeks_batch_pinned(Pin::new(x), Pin::new(out));
66}
67
68#[allow(clippy::too_many_arguments)]
69pub fn calculate_quotes_logit(
70    x_t: &[f64],
71    q_t: &[f64],
72    sigma_b: &[f64],
73    gamma: &[f64],
74    tau: &[f64],
75    k: &[f64],
76    bid_p: &mut [f64],
77    ask_p: &mut [f64],
78) {
79    calculate_quotes_logit_pinned(
80        Pin::new(x_t),
81        Pin::new(q_t),
82        Pin::new(sigma_b),
83        Pin::new(gamma),
84        Pin::new(tau),
85        Pin::new(k),
86        Pin::new(bid_p),
87        Pin::new(ask_p),
88    );
89}
90
91pub fn sigmoid_batch_pinned(x: Pin<&[f64]>, mut out_p: Pin<&mut [f64]>) {
92    let x_ref = x.get_ref();
93    let out_ref = out_p.as_mut().get_mut();
94    assert_eq!(
95        x_ref.len(),
96        out_ref.len(),
97        "sigmoid_batch: input and output lengths must match"
98    );
99
100    unsafe {
101        ffi_kernel_sigmoid_batch(x_ref.as_ptr(), out_ref.as_mut_ptr(), x_ref.len());
102    }
103}
104
105pub fn logit_batch_pinned(p: Pin<&[f64]>, mut out_x: Pin<&mut [f64]>) {
106    let p_ref = p.get_ref();
107    let out_ref = out_x.as_mut().get_mut();
108    assert_eq!(
109        p_ref.len(),
110        out_ref.len(),
111        "logit_batch: input and output lengths must match"
112    );
113
114    unsafe {
115        ffi_kernel_logit_batch(p_ref.as_ptr(), out_ref.as_mut_ptr(), p_ref.len());
116    }
117}
118
119pub fn greeks_batch_pinned(x: Pin<&[f64]>, mut out: Pin<&mut [GreekOut]>) {
120    let x_ref = x.get_ref();
121    let out_ref = out.as_mut().get_mut();
122    assert_eq!(
123        x_ref.len(),
124        out_ref.len(),
125        "greeks_batch: input and output lengths must match"
126    );
127
128    unsafe {
129        ffi_kernel_greeks_batch(x_ref.as_ptr(), out_ref.as_mut_ptr(), x_ref.len());
130    }
131}
132
133#[allow(clippy::too_many_arguments)]
134pub fn calculate_quotes_logit_pinned(
135    x_t: Pin<&[f64]>,
136    q_t: Pin<&[f64]>,
137    sigma_b: Pin<&[f64]>,
138    gamma: Pin<&[f64]>,
139    tau: Pin<&[f64]>,
140    k: Pin<&[f64]>,
141    mut bid_p: Pin<&mut [f64]>,
142    mut ask_p: Pin<&mut [f64]>,
143) {
144    let x_ref = x_t.get_ref();
145    let q_ref = q_t.get_ref();
146    let sigma_ref = sigma_b.get_ref();
147    let gamma_ref = gamma.get_ref();
148    let tau_ref = tau.get_ref();
149    let k_ref = k.get_ref();
150    let bid_ref = bid_p.as_mut().get_mut();
151    let ask_ref = ask_p.as_mut().get_mut();
152
153    let n = x_ref.len();
154    assert_eq!(q_ref.len(), n, "calculate_quotes_logit: q_t length mismatch");
155    assert_eq!(
156        sigma_ref.len(),
157        n,
158        "calculate_quotes_logit: sigma_b length mismatch"
159    );
160    assert_eq!(
161        gamma_ref.len(),
162        n,
163        "calculate_quotes_logit: gamma length mismatch"
164    );
165    assert_eq!(tau_ref.len(), n, "calculate_quotes_logit: tau length mismatch");
166    assert_eq!(k_ref.len(), n, "calculate_quotes_logit: k length mismatch");
167    assert_eq!(
168        bid_ref.len(),
169        n,
170        "calculate_quotes_logit: bid_p length mismatch"
171    );
172    assert_eq!(
173        ask_ref.len(),
174        n,
175        "calculate_quotes_logit: ask_p length mismatch"
176    );
177
178    unsafe {
179        ffi_calculate_quotes_logit(
180            x_ref.as_ptr(),
181            q_ref.as_ptr(),
182            sigma_ref.as_ptr(),
183            gamma_ref.as_ptr(),
184            tau_ref.as_ptr(),
185            k_ref.as_ptr(),
186            bid_ref.as_mut_ptr(),
187            ask_ref.as_mut_ptr(),
188            n,
189        );
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::{align_of, offset_of, size_of};

    /// `GreekOut` is written by the C kernel through a raw pointer, so its overall size and
    /// alignment must stay fixed.
    #[test]
    fn greek_layout_is_compact() {
        assert_eq!(size_of::<GreekOut>(), 16);
        assert_eq!(align_of::<GreekOut>(), 8);
    }

    /// Size and alignment alone don't pin field order; the C side presumably expects `delta_x`
    /// first and `gamma_x` second. NOTE(review): confirm against the C header.
    #[test]
    fn greek_field_offsets_follow_declaration_order() {
        assert_eq!(offset_of!(GreekOut, delta_x), 0);
        assert_eq!(offset_of!(GreekOut, gamma_x), 8);
    }
}