#![allow(unsafe_code)]
use std::panic::catch_unwind;
use crate::{TurboCode, TurboQuantizer, TurboQuantError};
/// Maps a [`TurboQuantError`] to the integer status code handed back across the C ABI.
///
/// * `DimensionMismatch` maps to `-1`.
/// * `NonFiniteInput` maps to `-3`.
/// * Every other variant, including `DeserializationError`, maps to `-2`.
///
/// The match is intentionally exhaustive (no `_` arm) so that adding a new error variant
/// forces a decision about its status code at compile time.
fn err_code(e: &TurboQuantError) -> i32 {
    match e {
        TurboQuantError::DimensionMismatch { .. } => -1,
        TurboQuantError::NonFiniteInput { .. } => -3,
        TurboQuantError::ZeroDimension
        | TurboQuantError::OddDimension(_)
        | TurboQuantError::InvalidBitWidth(_)
        | TurboQuantError::ZeroProjections
        | TurboQuantError::DimensionTooLarge(_, _)
        | TurboQuantError::EmptyInput(_)
        | TurboQuantError::IndexOutOfBounds { .. }
        | TurboQuantError::DeserializationError { .. } => -2,
    }
}
/// Stores `code` into `*err_out`; a null `err_out` is silently ignored.
///
/// # Safety
///
/// `err_out` must either be null or be properly aligned and valid for writing one `i32`.
#[inline]
unsafe fn write_err(err_out: *mut i32, code: i32) {
    if err_out.is_null() {
        return;
    }
    // SAFETY: `err_out` is non-null here and the caller guarantees it is writable.
    unsafe { err_out.write(code) };
}
/// Creates a heap-allocated [`TurboQuantizer`] and returns an owning raw pointer to it.
///
/// On success the returned pointer is non-null and `0` is stored in `*err_out`.
/// On failure null is returned and `*err_out` receives the status from `err_code`,
/// or `-99` if construction panicked. A null `err_out` is accepted and simply not written.
///
/// The returned handle was produced by `Box::into_raw`; release it with `bitpolar_free`.
///
/// # Safety
///
/// `err_out` must be null or valid for writing one `i32`.
#[no_mangle]
pub unsafe extern "C" fn bitpolar_new(
    dim: u32,
    bits: u8,
    projections: u32,
    seed: u64,
    err_out: *mut i32,
) -> *mut TurboQuantizer {
    // Only the constructor runs under `catch_unwind`, so a panic inside it is turned into a
    // status code instead of unwinding into the C caller.
    let outcome =
        catch_unwind(|| TurboQuantizer::new(dim as usize, bits, projections as usize, seed));

    let (handle, status) = match outcome {
        Ok(Ok(quantizer)) => (Box::into_raw(Box::new(quantizer)), 0),
        Ok(Err(e)) => (std::ptr::null_mut(), err_code(&e)),
        Err(_) => (std::ptr::null_mut(), -99),
    };

    // SAFETY: the caller guarantees `err_out` is null or writable.
    unsafe { write_err(err_out, status) };
    handle
}
/// Destroys a quantizer previously returned by `bitpolar_new`. Passing null is a no-op.
///
/// # Safety
///
/// `ptr` must be null or a pointer obtained from `bitpolar_new` that has not been freed yet.
/// The pointer must not be used after this call.
#[no_mangle]
pub unsafe extern "C" fn bitpolar_free(ptr: *mut TurboQuantizer) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from `Box::into_raw` and is still owned
    // by the caller, so reconstituting the `Box` and dropping it is sound.
    let owned = unsafe { Box::from_raw(ptr) };
    drop(owned);
}
/// Encodes `dim` floats starting at `vector` and returns an owning pointer to the new
/// [`TurboCode`].
///
/// Status written to `*err_out` (when it is non-null):
/// * `0` — success; the returned pointer is non-null.
/// * `-4` — `q` or `vector` is null.
/// * `-2` — `dim` is zero.
/// * `-1` — `dim` differs from the quantizer's dimension.
/// * `-99` — encoding panicked.
/// * otherwise the value of `err_code` for the error returned by `encode`.
///
/// Null is returned for every non-zero status. Release a successful result with
/// `bitpolar_code_free`.
///
/// # Safety
///
/// * `q` must be null or point to a live `TurboQuantizer`.
/// * `vector` must be null or valid for reading `dim` consecutive `f32` values.
/// * `err_out` must be null or valid for writing one `i32`.
#[no_mangle]
pub unsafe extern "C" fn bitpolar_encode(
    q: *const TurboQuantizer,
    vector: *const f32,
    dim: u32,
    err_out: *mut i32,
) -> *mut TurboCode {
    let len = dim as usize;

    // Argument validation: the first failing check decides the status code.
    let rejection = if q.is_null() || vector.is_null() {
        Some(-4)
    } else if dim == 0 {
        Some(-2)
    } else {
        // SAFETY: `q` was checked to be non-null and the caller guarantees it is live.
        let quantizer = unsafe { &*q };
        if len != quantizer.dim() {
            Some(-1)
        } else {
            None
        }
    };
    if let Some(status) = rejection {
        // SAFETY: the caller guarantees `err_out` is null or writable.
        unsafe { write_err(err_out, status) };
        return std::ptr::null_mut();
    }

    let outcome = catch_unwind(|| {
        // SAFETY: both pointers are non-null and the caller guarantees `vector` covers
        // `dim` floats.
        let quantizer = unsafe { &*q };
        let input = unsafe { std::slice::from_raw_parts(vector, len) };
        quantizer.encode(input)
    });

    let (handle, status) = match outcome {
        Ok(Ok(code)) => (Box::into_raw(Box::new(code)), 0),
        Ok(Err(e)) => (std::ptr::null_mut(), err_code(&e)),
        Err(_) => (std::ptr::null_mut(), -99),
    };
    // SAFETY: the caller guarantees `err_out` is null or writable.
    unsafe { write_err(err_out, status) };
    handle
}
/// Destroys a [`TurboCode`] handle that was allocated by this library. Passing null is a no-op.
///
/// # Safety
///
/// `ptr` must be null or an owning pointer produced by `Box::into_raw` in this library
/// (e.g. from `bitpolar_encode` or `bitpolar_code_from_bytes`) that has not been freed yet.
/// The pointer must not be used after this call.
#[no_mangle]
pub unsafe extern "C" fn bitpolar_code_free(ptr: *mut TurboCode) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` originated from `Box::into_raw` and is still
    // owned by the caller.
    let owned = unsafe { Box::from_raw(ptr) };
    drop(owned);
}
/// Estimates the inner product between the vector stored in `code` and `query`, writing the
/// estimate to `*result_out`.
///
/// Returns:
/// * `0` — success; `*result_out` holds the estimate.
/// * `-4` — any of `q`, `code`, `query`, `result_out` is null.
/// * `-2` — `dim` is zero.
/// * `-1` — `dim` differs from the quantizer's dimension.
/// * `-99` — the computation panicked.
/// * otherwise the value of `err_code` for the error returned by `inner_product_estimate`.
///
/// `*result_out` is only written on success.
///
/// # Safety
///
/// * `q` and `code` must be null or point to live objects of their respective types.
/// * `query` must be null or valid for reading `dim` consecutive `f32` values.
/// * `result_out` must be null or valid for writing one `f32`.
#[no_mangle]
pub unsafe extern "C" fn bitpolar_inner_product(
    q: *const TurboQuantizer,
    code: *const TurboCode,
    query: *const f32,
    dim: u32,
    result_out: *mut f32,
) -> i32 {
    let all_present =
        !(q.is_null() || code.is_null() || query.is_null() || result_out.is_null());
    if !all_present {
        return -4;
    }
    if dim == 0 {
        return -2;
    }

    let len = dim as usize;
    // SAFETY: `q` is non-null and the caller guarantees it is live.
    let quantizer_dim = unsafe { &*q }.dim();
    if len != quantizer_dim {
        return -1;
    }

    let outcome = catch_unwind(|| {
        // SAFETY: all pointers are non-null and the caller guarantees `query` covers
        // `dim` floats.
        let quantizer = unsafe { &*q };
        let stored = unsafe { &*code };
        let query_slice = unsafe { std::slice::from_raw_parts(query, len) };
        quantizer.inner_product_estimate(stored, query_slice)
    });

    match outcome {
        Ok(Ok(estimate)) => {
            // SAFETY: `result_out` is non-null and the caller guarantees it is writable.
            unsafe { result_out.write(estimate) };
            0
        }
        Ok(Err(e)) => err_code(&e),
        Err(_) => -99,
    }
}
#[no_mangle]
pub unsafe extern "C" fn bitpolar_decode(
q: *const TurboQuantizer,
code: *const TurboCode,
out: *mut f32,
dim: u32,
) -> i32 {
if q.is_null() || code.is_null() || out.is_null() {
return -4;
}
let result = catch_unwind(|| {
let quantizer = unsafe { &*q };
let tc = unsafe { &*code };
quantizer.decode(tc)
});
match result {
Ok(decoded) => {
let n = dim as usize;
if decoded.len() != n {
return -1;
}
unsafe {
let dst = std::slice::from_raw_parts_mut(out, n);
dst.copy_from_slice(&decoded);
}
0
}
Err(_) => -99,
}
}
#[no_mangle]
pub unsafe extern "C" fn bitpolar_code_to_bytes(
code: *const TurboCode,
buf: *mut u8,
buf_len: u32,
) -> i32 {
if code.is_null() {
return -1;
}
let result = catch_unwind(|| {
let tc = unsafe { &*code };
tc.to_compact_bytes()
});
let bytes = match result {
Ok(b) => b,
Err(_) => return -1,
};
let needed = bytes.len();
if buf.is_null() {
return needed as i32;
}
if (buf_len as usize) < needed {
return -1;
}
unsafe {
let dst = std::slice::from_raw_parts_mut(buf, needed);
dst.copy_from_slice(&bytes);
}
needed as i32
}
/// Rebuilds a [`TurboCode`] from `len` bytes at `buf` and returns an owning pointer to it.
///
/// Status written to `*err_out` (when it is non-null):
/// * `0` — success; the returned pointer is non-null.
/// * `-4` — `buf` is null.
/// * `-99` — deserialisation panicked.
/// * otherwise the value of `err_code` for the error returned by `from_compact_bytes`.
///
/// Null is returned for every non-zero status. Release a successful result with
/// `bitpolar_code_free`.
///
/// # Safety
///
/// * `buf` must be null or valid for reading `len` bytes.
/// * `err_out` must be null or valid for writing one `i32`.
#[no_mangle]
pub unsafe extern "C" fn bitpolar_code_from_bytes(
    buf: *const u8,
    len: u32,
    err_out: *mut i32,
) -> *mut TurboCode {
    let (handle, status) = if buf.is_null() {
        (std::ptr::null_mut(), -4)
    } else {
        let outcome = catch_unwind(|| {
            // SAFETY: `buf` is non-null and the caller guarantees it covers `len` bytes.
            let input = unsafe { std::slice::from_raw_parts(buf, len as usize) };
            TurboCode::from_compact_bytes(input)
        });
        match outcome {
            Ok(Ok(code)) => (Box::into_raw(Box::new(code)), 0),
            Ok(Err(e)) => (std::ptr::null_mut(), err_code(&e)),
            Err(_) => (std::ptr::null_mut(), -99),
        }
    };

    // SAFETY: the caller guarantees `err_out` is null or writable.
    unsafe { write_err(err_out, status) };
    handle
}
/// Returns the dimension of the quantizer behind `q`, or `0` when `q` is null.
///
/// # Safety
///
/// `q` must be null or point to a live `TurboQuantizer`.
#[no_mangle]
pub unsafe extern "C" fn bitpolar_dim(q: *const TurboQuantizer) -> u32 {
    // SAFETY: the caller guarantees `q` is null or points to a live quantizer.
    match unsafe { q.as_ref() } {
        Some(quantizer) => quantizer.dim() as u32,
        None => 0,
    }
}
/// Returns `size_bytes()` of the code behind `code`, or `0` when `code` is null.
///
/// Sizes that do not fit in a `u32` are reported as `u32::MAX` instead of being wrapped.
///
/// # Safety
///
/// `code` must be null or point to a live `TurboCode`.
#[no_mangle]
pub unsafe extern "C" fn bitpolar_code_size(code: *const TurboCode) -> u32 {
    // SAFETY: the caller guarantees `code` is null or points to a live code.
    match unsafe { code.as_ref() } {
        // Clamp before narrowing: a bare `as u32` would wrap for sizes above `u32::MAX`
        // and could even yield 0, which callers would mistake for the null sentinel.
        Some(tc) => tc.size_bytes().min(u32::MAX as usize) as u32,
        None => 0,
    }
}
/// Estimates the inner product of `query` against each of the `n_codes` codes in `codes`,
/// writing the i-th estimate to `scores_out[i]`.
///
/// Returns:
/// * `0` — success (an empty batch, `n_codes == 0`, also succeeds if `query` passes the
///   finiteness check).
/// * `-4` — `q`, `query` or `scores_out` is null, or `codes` is null while `n_codes > 0`.
/// * `-2` — `dim` is zero while `n_codes > 0`, or one of the entries in `codes` is null.
/// * `-1` — `dim` differs from the quantizer's dimension (only checked when `n_codes > 0`).
/// * `-99` — the computation panicked.
/// * otherwise the value of `err_code` for the error returned by `validate_finite` or
///   `inner_product_estimate`.
///
/// Processing stops at the first failing code; scores written before the failure stay in
/// `scores_out`.
///
/// # Safety
///
/// * `q` must be null or point to a live `TurboQuantizer`.
/// * `codes` must be null or valid for reading `n_codes` pointers, each of which is null or
///   points to a live `TurboCode`.
/// * `query` must be null or valid for reading `dim` consecutive `f32` values.
/// * `scores_out` must be null or valid for writing `n_codes` consecutive `f32` values. The
///   buffer does not need to be initialised.
#[no_mangle]
pub unsafe extern "C" fn bitpolar_batch_inner_product(
    q: *const TurboQuantizer,
    codes: *const *const TurboCode,
    n_codes: u32,
    query: *const f32,
    dim: u32,
    scores_out: *mut f32,
) -> i32 {
    if q.is_null() || query.is_null() || scores_out.is_null() {
        return -4;
    }
    if n_codes > 0 && codes.is_null() {
        return -4;
    }
    if dim == 0 && n_codes > 0 {
        return -2;
    }
    if n_codes > 0 {
        // SAFETY: `q` is non-null and the caller guarantees it is live.
        let quantizer = unsafe { &*q };
        if dim as usize != quantizer.dim() {
            return -1;
        }
    }

    let result = catch_unwind(|| {
        // SAFETY: `q` and `query` are non-null; the caller guarantees `query` covers
        // `dim` floats.
        let quantizer = unsafe { &*q };
        let q_slice = unsafe { std::slice::from_raw_parts(query, dim as usize) };
        // Validate the query once up front rather than once per code.
        crate::error::validate_finite(q_slice)?;

        for i in 0..n_codes as usize {
            // SAFETY: `codes` is non-null whenever the loop body runs (`n_codes > 0`) and
            // the caller guarantees it holds `n_codes` entries.
            let code_ptr = unsafe { *codes.add(i) };
            if code_ptr.is_null() {
                return Err(TurboQuantError::EmptyInput("null code pointer in batch"));
            }
            // SAFETY: `code_ptr` is non-null and the caller guarantees it is live.
            let tc = unsafe { &*code_ptr };
            let score = quantizer.inner_product_estimate(tc, q_slice)?;
            // Write through the raw pointer instead of a `&mut [f32]` built over
            // `scores_out`: `slice::from_raw_parts_mut` requires initialised memory, which
            // a caller-allocated output buffer does not guarantee.
            // SAFETY: `scores_out` is non-null and has room for `n_codes` floats; `i` is
            // in range.
            unsafe { scores_out.add(i).write(score) };
        }
        Ok(())
    });

    match result {
        Ok(Ok(())) => 0,
        Ok(Err(e)) => err_code(&e),
        Err(_) => -99,
    }
}
#[cfg(test)]
mod tests {
    //! Exercises the C ABI surface end to end through raw pointers, the way a C caller would.

    use super::*;
    use std::ptr;

    /// Input vector reused by the round-trip tests.
    const RAMP: [f32; 8] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8];

    /// Builds the 8-dimensional quantizer shared by most tests and asserts that construction
    /// succeeded, so later failures cannot be caused by a silently null handle.
    fn new_quantizer() -> *mut TurboQuantizer {
        let mut err: i32 = -99;
        let q = unsafe { bitpolar_new(8, 4, 8, 42, &mut err) };
        assert_eq!(err, 0);
        assert!(!q.is_null());
        q
    }

    /// Encodes `vector` with `q` and asserts that encoding succeeded.
    fn encode_ok(q: *const TurboQuantizer, vector: &[f32]) -> *mut TurboCode {
        let mut err: i32 = -99;
        let code = unsafe { bitpolar_encode(q, vector.as_ptr(), vector.len() as u32, &mut err) };
        assert_eq!(err, 0);
        assert!(!code.is_null());
        code
    }

    /// Asserts that `bitpolar_new` rejects the given configuration with status `-2`.
    fn assert_new_rejected(dim: u32, bits: u8, projections: u32) {
        let mut err: i32 = 0;
        let q = unsafe { bitpolar_new(dim, bits, projections, 42, &mut err) };
        assert!(q.is_null());
        assert_eq!(err, -2);
    }

    /// Serialises `code` using the query-size-then-fill protocol of `bitpolar_code_to_bytes`.
    fn serialize(code: *const TurboCode) -> Vec<u8> {
        let needed = unsafe { bitpolar_code_to_bytes(code, ptr::null_mut(), 0) };
        assert!(needed > 0, "required size must be positive");
        let mut buf = vec![0u8; needed as usize];
        let written =
            unsafe { bitpolar_code_to_bytes(code, buf.as_mut_ptr(), buf.len() as u32) };
        assert_eq!(written, needed);
        buf
    }

    #[test]
    fn ffi_new_and_free() {
        let q = new_quantizer();
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_free_null_is_noop() {
        unsafe { bitpolar_free(ptr::null_mut()) };
    }

    #[test]
    fn ffi_new_invalid_dim_zero() {
        assert_new_rejected(0, 4, 8);
    }

    #[test]
    fn ffi_new_odd_dimension() {
        assert_new_rejected(3, 4, 8);
    }

    #[test]
    fn ffi_new_zero_projections() {
        assert_new_rejected(8, 4, 0);
    }

    #[test]
    fn ffi_encode_decode_roundtrip() {
        let q = new_quantizer();
        let code = encode_ok(q, &RAMP);

        let mut decoded = [0.0_f32; 8];
        let rc = unsafe { bitpolar_decode(q, code, decoded.as_mut_ptr(), 8) };
        assert_eq!(rc, 0);
        assert!(decoded.iter().all(|v| v.is_finite()));

        unsafe { bitpolar_code_free(code) };
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_decode_dim_mismatch() {
        let q = new_quantizer();
        let code = encode_ok(q, &RAMP);

        // The decoded vector has 8 elements, so claiming a 4-element output must be refused.
        let mut short = [0.0_f32; 4];
        let rc = unsafe { bitpolar_decode(q, code, short.as_mut_ptr(), 4) };
        assert_eq!(rc, -1);
        assert_eq!(short, [0.0_f32; 4], "output must be untouched on mismatch");

        unsafe { bitpolar_code_free(code) };
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_inner_product() {
        let q = new_quantizer();
        let code = encode_ok(q, &[1.0_f32; 8]);

        let query = [0.5_f32; 8];
        let mut result = 0.0_f32;
        let rc = unsafe { bitpolar_inner_product(q, code, query.as_ptr(), 8, &mut result) };
        assert_eq!(rc, 0);
        assert!(result.is_finite());

        unsafe { bitpolar_code_free(code) };
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_inner_product_null_returns_error() {
        let q = new_quantizer();
        let query = [0.5_f32; 8];
        let mut result = 0.0_f32;
        let rc = unsafe {
            bitpolar_inner_product(q, ptr::null(), query.as_ptr(), 8, &mut result)
        };
        assert_eq!(rc, -4);
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_inner_product_dim_mismatch() {
        let q = new_quantizer();
        let code = encode_ok(q, &RAMP);

        let query = [0.5_f32; 4];
        let mut result = 0.0_f32;
        let rc = unsafe { bitpolar_inner_product(q, code, query.as_ptr(), 4, &mut result) };
        assert_eq!(rc, -1);

        unsafe { bitpolar_code_free(code) };
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_encode_null_returns_null() {
        let mut err: i32 = 0;
        let vector = [0.0_f32; 8];
        let code = unsafe { bitpolar_encode(ptr::null(), vector.as_ptr(), 8, &mut err) };
        assert!(code.is_null());
        assert_eq!(err, -4);
    }

    #[test]
    fn ffi_encode_dim_mismatch() {
        let q = new_quantizer();
        let mut err: i32 = -99;
        let short = [0.0_f32; 4];
        let code = unsafe { bitpolar_encode(q, short.as_ptr(), 4, &mut err) };
        assert!(code.is_null());
        assert_eq!(err, -1);
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_encode_nan_rejected() {
        let q = new_quantizer();
        let mut err: i32 = -99;
        let mut v = [0.1_f32; 8];
        v[3] = f32::NAN;
        let code = unsafe { bitpolar_encode(q, v.as_ptr(), 8, &mut err) };
        assert!(code.is_null());
        assert_eq!(err, -3);
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_code_to_bytes_and_from_bytes() {
        let q = new_quantizer();
        let code = encode_ok(q, &RAMP);
        let buf = serialize(code);

        let mut err: i32 = -99;
        let code2 =
            unsafe { bitpolar_code_from_bytes(buf.as_ptr(), buf.len() as u32, &mut err) };
        assert_eq!(err, 0);
        assert!(!code2.is_null());

        let buf2 = serialize(code2);
        assert_eq!(buf.len(), buf2.len());
        assert_eq!(buf, buf2, "deserialized code must re-serialize identically");

        unsafe { bitpolar_code_free(code) };
        unsafe { bitpolar_code_free(code2) };
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_code_to_bytes_buf_too_small() {
        let q = new_quantizer();
        let code = encode_ok(q, &[0.1_f32; 8]);

        let mut tiny = [0u8; 1];
        let rc = unsafe { bitpolar_code_to_bytes(code, tiny.as_mut_ptr(), 1) };
        assert_eq!(rc, -1, "should return -1 when buffer is too small");

        unsafe { bitpolar_code_free(code) };
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_code_from_bytes_null_buf() {
        let mut err: i32 = 0;
        let code = unsafe { bitpolar_code_from_bytes(ptr::null(), 0, &mut err) };
        assert!(code.is_null());
        assert_eq!(err, -4);
    }

    #[test]
    fn ffi_code_from_bytes_corrupted() {
        let mut err: i32 = 0;
        let garbage = [0xFFu8; 16];
        let code = unsafe {
            bitpolar_code_from_bytes(garbage.as_ptr(), garbage.len() as u32, &mut err)
        };
        assert!(code.is_null());
        assert_ne!(err, 0);
    }

    #[test]
    fn ffi_dim() {
        let q = new_quantizer();
        assert_eq!(unsafe { bitpolar_dim(q) }, 8);
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_dim_null() {
        assert_eq!(unsafe { bitpolar_dim(ptr::null()) }, 0);
    }

    #[test]
    fn ffi_code_size() {
        let q = new_quantizer();
        let code = encode_ok(q, &[0.1_f32; 8]);

        let sz = unsafe { bitpolar_code_size(code) };
        assert!(sz > 0);

        unsafe { bitpolar_code_free(code) };
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_code_size_null() {
        assert_eq!(unsafe { bitpolar_code_size(ptr::null()) }, 0);
    }

    #[test]
    fn ffi_batch_inner_product() {
        let q = new_quantizer();
        let vecs = [
            RAMP,
            [0.8_f32, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
            [0.5_f32; 8],
        ];
        let codes: Vec<*const TurboCode> = vecs
            .iter()
            .map(|v| encode_ok(q, v) as *const TurboCode)
            .collect();

        let query = [0.3_f32; 8];
        let mut scores = [0.0_f32; 3];
        let rc = unsafe {
            bitpolar_batch_inner_product(
                q,
                codes.as_ptr(),
                codes.len() as u32,
                query.as_ptr(),
                8,
                scores.as_mut_ptr(),
            )
        };
        assert_eq!(rc, 0);
        assert!(scores.iter().all(|s| s.is_finite()));

        // The batch path must agree with scoring each code individually.
        for (i, &code_ptr) in codes.iter().enumerate() {
            let mut single = 0.0_f32;
            let rc2 = unsafe {
                bitpolar_inner_product(q, code_ptr, query.as_ptr(), 8, &mut single)
            };
            assert_eq!(rc2, 0);
            assert!(
                (scores[i] - single).abs() < 1e-5,
                "batch score[{i}]={} != sequential={single}",
                scores[i]
            );
        }

        for &c in &codes {
            unsafe { bitpolar_code_free(c as *mut TurboCode) };
        }
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_batch_inner_product_empty() {
        let q = new_quantizer();
        let query = [0.3_f32; 8];
        let mut scores: [f32; 0] = [];
        let rc = unsafe {
            bitpolar_batch_inner_product(
                q,
                ptr::null(),
                0,
                query.as_ptr(),
                8,
                scores.as_mut_ptr(),
            )
        };
        assert_eq!(rc, 0);
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_batch_inner_product_null_query() {
        let q = new_quantizer();
        let code = encode_ok(q, &[0.1_f32; 8]) as *const TurboCode;
        let codes = [code];
        let mut scores = [0.0_f32; 1];
        let rc = unsafe {
            bitpolar_batch_inner_product(
                q,
                codes.as_ptr(),
                1,
                ptr::null(),
                8,
                scores.as_mut_ptr(),
            )
        };
        assert_eq!(rc, -4);
        unsafe { bitpolar_code_free(code as *mut TurboCode) };
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_batch_inner_product_null_code() {
        let q = new_quantizer();
        let code = encode_ok(q, &RAMP) as *const TurboCode;
        // A null entry inside the batch is reported as an invalid-input status.
        let codes = [code, ptr::null()];
        let query = [0.3_f32; 8];
        let mut scores = [0.0_f32; 2];
        let rc = unsafe {
            bitpolar_batch_inner_product(
                q,
                codes.as_ptr(),
                codes.len() as u32,
                query.as_ptr(),
                8,
                scores.as_mut_ptr(),
            )
        };
        assert_eq!(rc, -2);
        unsafe { bitpolar_code_free(code as *mut TurboCode) };
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_err_out_null_is_safe() {
        let q = unsafe { bitpolar_new(8, 4, 8, 42, ptr::null_mut()) };
        assert!(!q.is_null());
        unsafe { bitpolar_free(q) };
    }

    #[test]
    fn ffi_new_null_errout() {
        let q = unsafe { bitpolar_new(0, 4, 8, 42, ptr::null_mut()) };
        assert!(q.is_null());
    }
}