1use core::pin::Pin;
2
/// Pair of greek values (`delta_x`, `gamma_x`) produced by the logit greek kernels.
///
/// `#[repr(C)]` fixes the field order and layout so the C kernel `kernel_greeks_batch` can
/// write values of this type directly through a `*mut GreekOut`.
/// NOTE(review): assumed to mirror a C struct `{ double delta_x; double gamma_x; }` — confirm
/// against the kernel header.
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct GreekOut {
    /// Value the kernel writes to its `delta_x` output (presumably the first derivative of
    /// the sigmoid with respect to the logit — TODO confirm).
    pub delta_x: f64,
    /// Value the kernel writes to its `gamma_x` output (presumably the second derivative of
    /// the sigmoid with respect to the logit — TODO confirm).
    pub gamma_x: f64,
}
9
// Raw bindings to the C kernels. Items in an `unsafe extern` block are `unsafe fn` unless
// marked `safe`, so each is only called from inside an `unsafe` block in a safe wrapper.
// NOTE(review): the compiler does not check these signatures against the C side — keep them
// in sync with the kernel header by hand.
unsafe extern "C" {
    /// Scalar kernel exported as `kernel_sigmoid`; wrapped by the safe `sigmoid`.
    #[link_name = "kernel_sigmoid"]
    fn ffi_kernel_sigmoid(x: f64) -> f64;
    /// Scalar kernel exported as `kernel_logit`; wrapped by the safe `logit`.
    #[link_name = "kernel_logit"]
    fn ffi_kernel_logit(p: f64) -> f64;
    /// Batch kernel. The Rust wrapper passes `x` with `n` readable `f64`s and `out_p` with `n`
    /// writable `f64`s, after asserting the two slices have equal length.
    #[link_name = "kernel_sigmoid_batch"]
    fn ffi_kernel_sigmoid_batch(x: *const f64, out_p: *mut f64, n: usize);
    /// Batch kernel. The Rust wrapper passes `p` with `n` readable `f64`s and `out_x` with `n`
    /// writable `f64`s, after asserting the two slices have equal length.
    #[link_name = "kernel_logit_batch"]
    fn ffi_kernel_logit_batch(p: *const f64, out_x: *mut f64, n: usize);
    /// Scalar greek kernel. `delta_x` and `gamma_x` are out-pointers, each to a single `f64`.
    #[link_name = "kernel_greeks_from_logit"]
    fn ffi_kernel_greeks_from_logit(x: f64, delta_x: *mut f64, gamma_x: *mut f64);
    /// Batch greek kernel. `x` has `n` readable `f64`s and `out` has `n` writable `GreekOut`s,
    /// so `GreekOut` must stay `#[repr(C)]`.
    #[link_name = "kernel_greeks_batch"]
    fn ffi_kernel_greeks_batch(x: *const f64, out: *mut GreekOut, n: usize);
    /// Quote kernel. Takes six read-only input arrays and two output arrays (`bid_p`, `ask_p`).
    /// The Rust wrapper asserts that all eight have the same length `n`.
    #[link_name = "calculate_quotes_logit"]
    fn ffi_calculate_quotes_logit(
        x_t: *const f64,
        q_t: *const f64,
        sigma_b: *const f64,
        gamma: *const f64,
        tau: *const f64,
        k: *const f64,
        bid_p: *mut f64,
        ask_p: *mut f64,
        n: usize,
    );
}
36
37#[inline]
38pub fn sigmoid(x: f64) -> f64 {
39 unsafe { ffi_kernel_sigmoid(x) }
40}
41
42#[inline]
43pub fn logit(p: f64) -> f64 {
44 unsafe { ffi_kernel_logit(p) }
45}
46
47#[inline]
48pub fn greeks_from_logit(x: f64) -> GreekOut {
49 let mut out = GreekOut::default();
50 unsafe {
51 ffi_kernel_greeks_from_logit(x, &mut out.delta_x, &mut out.gamma_x);
52 }
53 out
54}
55
56pub fn sigmoid_batch(x: &[f64], out_p: &mut [f64]) {
57 sigmoid_batch_pinned(Pin::new(x), Pin::new(out_p));
58}
59
60pub fn logit_batch(p: &[f64], out_x: &mut [f64]) {
61 logit_batch_pinned(Pin::new(p), Pin::new(out_x));
62}
63
64pub fn greeks_batch(x: &[f64], out: &mut [GreekOut]) {
65 greeks_batch_pinned(Pin::new(x), Pin::new(out));
66}
67
68#[allow(clippy::too_many_arguments)]
69pub fn calculate_quotes_logit(
70 x_t: &[f64],
71 q_t: &[f64],
72 sigma_b: &[f64],
73 gamma: &[f64],
74 tau: &[f64],
75 k: &[f64],
76 bid_p: &mut [f64],
77 ask_p: &mut [f64],
78) {
79 calculate_quotes_logit_pinned(
80 Pin::new(x_t),
81 Pin::new(q_t),
82 Pin::new(sigma_b),
83 Pin::new(gamma),
84 Pin::new(tau),
85 Pin::new(k),
86 Pin::new(bid_p),
87 Pin::new(ask_p),
88 );
89}
90
91pub fn sigmoid_batch_pinned(x: Pin<&[f64]>, mut out_p: Pin<&mut [f64]>) {
92 let x_ref = x.get_ref();
93 let out_ref = out_p.as_mut().get_mut();
94 assert_eq!(
95 x_ref.len(),
96 out_ref.len(),
97 "sigmoid_batch: input and output lengths must match"
98 );
99
100 unsafe {
101 ffi_kernel_sigmoid_batch(x_ref.as_ptr(), out_ref.as_mut_ptr(), x_ref.len());
102 }
103}
104
105pub fn logit_batch_pinned(p: Pin<&[f64]>, mut out_x: Pin<&mut [f64]>) {
106 let p_ref = p.get_ref();
107 let out_ref = out_x.as_mut().get_mut();
108 assert_eq!(
109 p_ref.len(),
110 out_ref.len(),
111 "logit_batch: input and output lengths must match"
112 );
113
114 unsafe {
115 ffi_kernel_logit_batch(p_ref.as_ptr(), out_ref.as_mut_ptr(), p_ref.len());
116 }
117}
118
119pub fn greeks_batch_pinned(x: Pin<&[f64]>, mut out: Pin<&mut [GreekOut]>) {
120 let x_ref = x.get_ref();
121 let out_ref = out.as_mut().get_mut();
122 assert_eq!(
123 x_ref.len(),
124 out_ref.len(),
125 "greeks_batch: input and output lengths must match"
126 );
127
128 unsafe {
129 ffi_kernel_greeks_batch(x_ref.as_ptr(), out_ref.as_mut_ptr(), x_ref.len());
130 }
131}
132
133#[allow(clippy::too_many_arguments)]
134pub fn calculate_quotes_logit_pinned(
135 x_t: Pin<&[f64]>,
136 q_t: Pin<&[f64]>,
137 sigma_b: Pin<&[f64]>,
138 gamma: Pin<&[f64]>,
139 tau: Pin<&[f64]>,
140 k: Pin<&[f64]>,
141 mut bid_p: Pin<&mut [f64]>,
142 mut ask_p: Pin<&mut [f64]>,
143) {
144 let x_ref = x_t.get_ref();
145 let q_ref = q_t.get_ref();
146 let sigma_ref = sigma_b.get_ref();
147 let gamma_ref = gamma.get_ref();
148 let tau_ref = tau.get_ref();
149 let k_ref = k.get_ref();
150 let bid_ref = bid_p.as_mut().get_mut();
151 let ask_ref = ask_p.as_mut().get_mut();
152
153 let n = x_ref.len();
154 assert_eq!(q_ref.len(), n, "calculate_quotes_logit: q_t length mismatch");
155 assert_eq!(
156 sigma_ref.len(),
157 n,
158 "calculate_quotes_logit: sigma_b length mismatch"
159 );
160 assert_eq!(
161 gamma_ref.len(),
162 n,
163 "calculate_quotes_logit: gamma length mismatch"
164 );
165 assert_eq!(tau_ref.len(), n, "calculate_quotes_logit: tau length mismatch");
166 assert_eq!(k_ref.len(), n, "calculate_quotes_logit: k length mismatch");
167 assert_eq!(
168 bid_ref.len(),
169 n,
170 "calculate_quotes_logit: bid_p length mismatch"
171 );
172 assert_eq!(
173 ask_ref.len(),
174 n,
175 "calculate_quotes_logit: ask_p length mismatch"
176 );
177
178 unsafe {
179 ffi_calculate_quotes_logit(
180 x_ref.as_ptr(),
181 q_ref.as_ptr(),
182 sigma_ref.as_ptr(),
183 gamma_ref.as_ptr(),
184 tau_ref.as_ptr(),
185 k_ref.as_ptr(),
186 bid_ref.as_mut_ptr(),
187 ask_ref.as_mut_ptr(),
188 n,
189 );
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use core::mem::{align_of, offset_of, size_of};

    /// `GreekOut` is written by C through a `*mut GreekOut`, so it must stay two packed
    /// `f64`s with no padding.
    #[test]
    fn greek_layout_is_compact() {
        assert_eq!(size_of::<GreekOut>(), 16);
        assert_eq!(align_of::<GreekOut>(), 8);
    }

    /// Size and alignment alone would not catch a field reorder. Under `#[repr(C)]` the C side
    /// expects `delta_x` first and `gamma_x` immediately after it.
    #[test]
    fn greek_field_offsets_match_declaration_order() {
        assert_eq!(offset_of!(GreekOut, delta_x), 0);
        assert_eq!(offset_of!(GreekOut, gamma_x), size_of::<f64>());
    }
}