#![cfg(feature = "ffi")]
use std::ffi::c_char;
use std::os::raw::c_int;
use crate::index::{BonsaiConfig, BonsaiIndex};
use crate::types::{BBox, BackendKind, EntryId, Point};
/// Opaque handle for the C API: owns a 2-D, `f64`-coordinate [`BonsaiIndex`] whose entries
/// carry an [`OpaquePayload`].
///
/// Allocated by `bonsai_new` (via `Box::into_raw`) and released by `bonsai_free`; every
/// other FFI function in this file receives it only as a raw pointer.
pub struct BonsaiHandle(BonsaiIndex<OpaquePayload, f64, 2>);
// SAFETY: NOTE(review): these impls assert that the wrapped index may be moved to and
// shared between threads. That is only sound if `BonsaiIndex` holds no thread-unsafe
// interior state (e.g. `Cell`, `RefCell`, `Rc`) — confirm against its definition.
// Independently of these impls, the FFI functions here create `&mut` references from the
// raw handle pointer, so C callers must serialize all access to a given handle themselves.
unsafe impl Send for BonsaiHandle {}
unsafe impl Sync for BonsaiHandle {}
/// Caller-owned, untyped pointer stored with each entry as its payload.
///
/// The wrapped pointer is private to this module and is never dereferenced here: the C API
/// only stores it when an entry is inserted. Whatever it points to remains owned and
/// managed by the C caller.
#[derive(Debug, Clone, Copy)]
pub struct OpaquePayload(
    // Never read back in this file (queries ignore payloads), hence the lint allowance;
    // the pointer is kept so it stays associated with its entry inside the index.
    #[allow(dead_code)] *mut std::ffi::c_void,
);
// SAFETY: the pointer is handled purely as an opaque token and is never dereferenced in
// this file, so moving or sharing the token between threads cannot race on the pointee
// here. Any thread-safety requirements of the pointee are the C caller's responsibility.
unsafe impl Send for OpaquePayload {}
unsafe impl Sync for OpaquePayload {}
/// C-compatible snapshot of index statistics, filled in by `bonsai_stats`.
///
/// `backend_kind` uses the same integer codes that `bonsai_force_backend` accepts:
/// `0` = R-tree, `1` = KD-tree, `2` = quadtree, `3` = grid.
///
/// `Default` yields an all-zero value, convenient for initializing an out-parameter.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct BonsaiStats {
    /// Number of points reported by the index.
    pub point_count: u64,
    /// Query counter reported by the index.
    pub query_count: u64,
    /// Number of backend migrations reported by the index.
    pub migration_count: u64,
    /// Non-zero while the index reports a backend migration in progress.
    pub migrating: c_int,
    /// Active backend code (see the type-level docs for the mapping).
    pub backend_kind: c_int,
}
/// Allocates a new 2-D index built from `BonsaiConfig::default()` and returns an owning
/// handle to it.
///
/// The returned pointer must eventually be released with [`bonsai_free`]. Returns null if
/// constructing the index panicked; the panic is caught here instead of unwinding across
/// the `extern "C"` boundary.
///
/// # Safety
///
/// This function has no preconditions of its own; it is `unsafe` for consistency with the
/// rest of the C API.
#[no_mangle]
pub unsafe extern "C" fn bonsai_new() -> *mut BonsaiHandle {
    // Unwinding out of an `extern "C"` function is not permitted (UB on older toolchains,
    // an abort on newer ones), so turn a construction panic into the null failure value.
    // The closure captures nothing, so it is trivially `UnwindSafe`.
    let built = std::panic::catch_unwind(|| -> BonsaiIndex<OpaquePayload, f64, 2> {
        BonsaiIndex::from_config(BonsaiConfig::default())
    });
    built.map_or(std::ptr::null_mut(), |index| {
        Box::into_raw(Box::new(BonsaiHandle(index)))
    })
}
/// Destroys a handle obtained from [`bonsai_new`], dropping the index it owns.
///
/// Passing a null pointer is allowed and does nothing.
///
/// # Safety
///
/// `handle` must be null or a pointer returned by [`bonsai_new`] that has not been freed
/// yet. The pointer must not be used again after this call.
#[no_mangle]
pub unsafe extern "C" fn bonsai_free(handle: *mut BonsaiHandle) {
    if !handle.is_null() {
        // SAFETY: non-null, and per the contract above it came from `Box::into_raw` in
        // `bonsai_new` and is owned uniquely by this call, so rebuilding the Box is sound.
        let owned = Box::from_raw(handle);
        drop(owned);
    }
}
/// Inserts the point `(x, y)` with an opaque caller-owned `payload` and returns the new
/// entry's id.
///
/// Returns `u64::MAX` if `handle` is null, or if either coordinate is NaN or infinite.
/// Non-finite coordinates coming from C are rejected here so they never reach the index.
///
/// The payload pointer is stored as-is and never dereferenced in this file; ownership of
/// the pointee stays with the caller.
///
/// # Safety
///
/// `handle` must be null or a live pointer from [`bonsai_new`], and no other access to
/// the same handle may happen for the duration of the call.
#[no_mangle]
pub unsafe extern "C" fn bonsai_insert_2d(
    handle: *mut BonsaiHandle,
    x: f64,
    y: f64,
    payload: *mut std::ffi::c_void,
) -> u64 {
    if handle.is_null() || !x.is_finite() || !y.is_finite() {
        return u64::MAX;
    }
    // SAFETY: non-null and, per the contract above, exclusively ours for this call.
    let index = &mut (*handle).0;
    let id = index.insert(Point::new([x, y]), OpaquePayload(payload));
    id.0
}
/// Removes the entry whose id is `entry_id`.
///
/// Returns `1` if the index removed an entry, `0` if it reported no such entry or
/// `handle` is null.
///
/// # Safety
///
/// `handle` must be null or a live pointer from [`bonsai_new`], and no other access to
/// the same handle may happen for the duration of the call.
#[no_mangle]
pub unsafe extern "C" fn bonsai_remove(handle: *mut BonsaiHandle, entry_id: u64) -> c_int {
    // `as_mut` maps a null pointer to `None`, folding the null check into the match.
    match handle.as_mut() {
        Some(wrapper) => c_int::from(wrapper.0.remove(EntryId(entry_id)).is_some()),
        None => 0,
    }
}
/// Queries the axis-aligned box spanned by `(min_x, min_y)` and `(max_x, max_y)` and
/// writes the ids of the matching entries into `out_ids`.
///
/// At most `capacity` ids are written, in the order the index returns them. Any further
/// matches are dropped without indication. Returns the number of ids written, or `0` if
/// `handle` or `out_ids` is null.
///
/// # Safety
///
/// `handle` must be null or a live pointer from [`bonsai_new`] with no other access in
/// progress, and `out_ids` must be null or valid for writes of `capacity` `u64` values.
#[no_mangle]
pub unsafe extern "C" fn bonsai_range_query_2d(
    handle: *mut BonsaiHandle,
    min_x: f64,
    min_y: f64,
    max_x: f64,
    max_y: f64,
    out_ids: *mut u64,
    capacity: usize,
) -> usize {
    let missing_pointer = handle.is_null() || out_ids.is_null();
    if missing_pointer {
        return 0;
    }
    let lower = Point::new([min_x, min_y]);
    let upper = Point::new([max_x, max_y]);
    let query_box = BBox::new(lower, upper);
    let hits = (*handle).0.range_query(&query_box);
    let written = capacity.min(hits.len());
    for (slot, (id, _)) in hits.iter().take(written).enumerate() {
        // SAFETY: slot < written <= capacity, and the caller provides `capacity` slots.
        out_ids.add(slot).write(id.0);
    }
    written
}
/// Finds up to `k` nearest entries to `(qx, qy)`.
///
/// For each neighbour `i` (in the order the index returns them), writes its id to
/// `out_ids[i]` and the distance reported by the index to `out_dist[i]`. Returns the
/// number of neighbours written, or `0` if any pointer argument is null.
///
/// # Safety
///
/// `handle` must be null or a live pointer from [`bonsai_new`] that is not being mutated
/// concurrently. `out_ids` and `out_dist` must each be null or valid for writes of `k`
/// elements.
#[no_mangle]
pub unsafe extern "C" fn bonsai_knn_query_2d(
    handle: *const BonsaiHandle,
    qx: f64,
    qy: f64,
    k: usize,
    out_ids: *mut u64,
    out_dist: *mut f64,
) -> usize {
    let any_null = handle.is_null() || out_ids.is_null() || out_dist.is_null();
    if any_null {
        return 0;
    }
    let origin = Point::new([qx, qy]);
    let neighbours = (*handle).0.knn_query(&origin, k);
    let written = k.min(neighbours.len());
    for (slot, (dist, id, _)) in neighbours.iter().take(written).enumerate() {
        // SAFETY: slot < written <= k, and the caller provides `k` slots in each array.
        out_ids.add(slot).write(id.0);
        out_dist.add(slot).write(*dist);
    }
    written
}
/// Copies the index's current statistics into `*out`.
///
/// `backend_kind` is encoded as `0` = R-tree, `1` = KD-tree, `2` = quadtree, `3` = grid.
/// Returns `1` on success, or `0` (writing nothing) if `handle` or `out` is null.
///
/// # Safety
///
/// `handle` must be null or a live pointer from [`bonsai_new`] that is not being mutated
/// concurrently. `out` must be null or valid for a write of one `BonsaiStats`; it does
/// not need to be initialized.
#[no_mangle]
pub unsafe extern "C" fn bonsai_stats(handle: *const BonsaiHandle, out: *mut BonsaiStats) -> c_int {
    if out.is_null() || handle.is_null() {
        return 0;
    }
    let snapshot = (*handle).0.stats();
    let kind_code = match snapshot.backend {
        BackendKind::RTree => 0,
        BackendKind::KDTree => 1,
        BackendKind::Quadtree => 2,
        BackendKind::Grid => 3,
    };
    let filled = BonsaiStats {
        point_count: snapshot.point_count as u64,
        query_count: snapshot.query_count,
        migration_count: snapshot.migrations,
        migrating: snapshot.migrating as c_int,
        backend_kind: kind_code,
    };
    // `write` never reads or drops the previous contents, so `out` may be uninitialized.
    out.write(filled);
    1
}
/// Asks the index to switch to the backend identified by `backend`.
///
/// Codes: `0` = R-tree, `1` = KD-tree, `2` = quadtree, `3` = grid. Returns `1` if the index
/// accepted the request, and `0` if `handle` is null, the code is unknown, or the index
/// returned an error.
///
/// # Safety
///
/// `handle` must be null or a live pointer from [`bonsai_new`], and no other access to
/// the same handle may happen for the duration of the call.
#[no_mangle]
pub unsafe extern "C" fn bonsai_force_backend(handle: *mut BonsaiHandle, backend: c_int) -> c_int {
    if handle.is_null() {
        return 0;
    }
    let requested = match backend {
        0 => BackendKind::RTree,
        1 => BackendKind::KDTree,
        2 => BackendKind::Quadtree,
        3 => BackendKind::Grid,
        _ => return 0,
    };
    let outcome = (*handle).0.force_backend(requested);
    c_int::from(outcome.is_ok())
}
/// Writes the active backend's name (`"rtree"`, `"kdtree"`, `"quadtree"` or `"grid"`)
/// into `out_buf` as a NUL-terminated C string.
///
/// Returns the number of name bytes written, excluding the terminating NUL. A name longer
/// than `buf_len - 1` bytes is truncated without indication. Returns `usize::MAX` if
/// `handle` or `out_buf` is null, or if `buf_len` is `0`.
///
/// # Safety
///
/// `handle` must be null or a live pointer from [`bonsai_new`] that is not being mutated
/// concurrently, and `out_buf` must be null or valid for writes of `buf_len` bytes.
#[no_mangle]
pub unsafe extern "C" fn bonsai_backend_name(
    handle: *const BonsaiHandle,
    out_buf: *mut c_char,
    buf_len: usize,
) -> usize {
    let unusable = handle.is_null() || out_buf.is_null() || buf_len == 0;
    if unusable {
        return usize::MAX;
    }
    let label: &str = match (*handle).0.stats().backend {
        BackendKind::RTree => "rtree",
        BackendKind::KDTree => "kdtree",
        BackendKind::Quadtree => "quadtree",
        BackendKind::Grid => "grid",
    };
    // One byte of the buffer is always reserved for the terminator.
    let written = usize::min(label.len(), buf_len - 1);
    std::ptr::copy_nonoverlapping(label.as_ptr().cast::<c_char>(), out_buf, written);
    out_buf.add(written).write(0);
    written
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every name `bonsai_backend_name` can report.
    const KNOWN_BACKEND_NAMES: [&[u8]; 4] = [b"rtree", b"kdtree", b"quadtree", b"grid"];

    /// An all-zero stats out-parameter.
    fn zeroed_stats() -> BonsaiStats {
        BonsaiStats {
            point_count: 0,
            query_count: 0,
            migration_count: 0,
            migrating: 0,
            backend_kind: 0,
        }
    }

    #[test]
    fn bonsai_free_null_is_noop() {
        unsafe { bonsai_free(std::ptr::null_mut()) };
    }

    #[test]
    fn round_trip_insert_range_query() {
        unsafe {
            let h = bonsai_new();
            assert!(!h.is_null());
            let id0 = bonsai_insert_2d(h, 1.0, 1.0, std::ptr::null_mut());
            let id1 = bonsai_insert_2d(h, 5.0, 5.0, std::ptr::null_mut());
            let _id2 = bonsai_insert_2d(h, 9.0, 9.0, std::ptr::null_mut());
            let mut out_ids = [0u64; 8];
            let count =
                bonsai_range_query_2d(h, 0.0, 0.0, 6.0, 6.0, out_ids.as_mut_ptr(), out_ids.len());
            assert_eq!(count, 2);
            let found: std::collections::HashSet<u64> = out_ids[..count].iter().copied().collect();
            assert!(found.contains(&id0));
            assert!(found.contains(&id1));
            bonsai_free(h);
        }
    }

    #[test]
    fn range_query_truncates_to_capacity() {
        unsafe {
            let h = bonsai_new();
            assert!(!h.is_null());
            for i in 0..4 {
                bonsai_insert_2d(h, f64::from(i), f64::from(i), std::ptr::null_mut());
            }
            let mut out_ids = [u64::MAX; 4];
            let count = bonsai_range_query_2d(h, -1.0, -1.0, 10.0, 10.0, out_ids.as_mut_ptr(), 2);
            assert_eq!(count, 2);
            // Slots beyond `capacity` must be left untouched.
            assert_eq!(out_ids[2], u64::MAX);
            assert_eq!(out_ids[3], u64::MAX);
            bonsai_free(h);
        }
    }

    #[test]
    fn round_trip_knn_query() {
        unsafe {
            let h = bonsai_new();
            let id0 = bonsai_insert_2d(h, 0.0, 0.0, std::ptr::null_mut());
            let _id1 = bonsai_insert_2d(h, 10.0, 10.0, std::ptr::null_mut());
            let mut out_ids = [0u64; 2];
            let mut out_dist = [0.0f64; 2];
            let count =
                bonsai_knn_query_2d(h, 0.0, 0.0, 1, out_ids.as_mut_ptr(), out_dist.as_mut_ptr());
            assert_eq!(count, 1);
            assert_eq!(out_ids[0], id0);
            assert!(out_dist[0].abs() < 1e-9);
            bonsai_free(h);
        }
    }

    #[test]
    fn knn_query_with_k_larger_than_population() {
        unsafe {
            let h = bonsai_new();
            let id0 = bonsai_insert_2d(h, 0.0, 0.0, std::ptr::null_mut());
            let id1 = bonsai_insert_2d(h, 10.0, 10.0, std::ptr::null_mut());
            let mut out_ids = [0u64; 4];
            let mut out_dist = [0.0f64; 4];
            let count =
                bonsai_knn_query_2d(h, 0.0, 0.0, 4, out_ids.as_mut_ptr(), out_dist.as_mut_ptr());
            assert_eq!(count, 2);
            let found: std::collections::HashSet<u64> = out_ids[..count].iter().copied().collect();
            assert!(found.contains(&id0));
            assert!(found.contains(&id1));
            bonsai_free(h);
        }
    }

    #[test]
    fn remove_entry() {
        unsafe {
            let h = bonsai_new();
            let id = bonsai_insert_2d(h, 3.0, 3.0, std::ptr::null_mut());
            assert_eq!(bonsai_remove(h, id), 1);
            let mut out_ids = [0u64; 4];
            let count =
                bonsai_range_query_2d(h, 0.0, 0.0, 10.0, 10.0, out_ids.as_mut_ptr(), out_ids.len());
            assert_eq!(count, 0);
            bonsai_free(h);
        }
    }

    #[test]
    fn removing_twice_reports_missing() {
        unsafe {
            let h = bonsai_new();
            let id = bonsai_insert_2d(h, 3.0, 3.0, std::ptr::null_mut());
            assert_eq!(bonsai_remove(h, id), 1);
            assert_eq!(bonsai_remove(h, id), 0);
            bonsai_free(h);
        }
    }

    #[test]
    fn stats_round_trip() {
        unsafe {
            let h = bonsai_new();
            bonsai_insert_2d(h, 1.0, 2.0, std::ptr::null_mut());
            bonsai_insert_2d(h, 3.0, 4.0, std::ptr::null_mut());
            let mut s = zeroed_stats();
            let ok = bonsai_stats(h, &mut s as *mut BonsaiStats);
            assert_eq!(ok, 1);
            assert_eq!(s.point_count, 2);
            assert!((0..=3).contains(&s.backend_kind));
            bonsai_free(h);
        }
    }

    #[test]
    fn force_backend_ok() {
        unsafe {
            let h = bonsai_new();
            assert_eq!(bonsai_force_backend(h, 0), 1);
            bonsai_free(h);
        }
    }

    #[test]
    fn force_backend_rejects_unknown_code() {
        unsafe {
            let h = bonsai_new();
            assert_eq!(bonsai_force_backend(h, -1), 0);
            assert_eq!(bonsai_force_backend(h, 4), 0);
            bonsai_free(h);
        }
    }

    #[test]
    fn backend_name_is_nul_terminated_and_known() {
        unsafe {
            let h = bonsai_new();
            assert!(!h.is_null());
            let mut buf = [0x7f as std::ffi::c_char; 32];
            let len = bonsai_backend_name(h, buf.as_mut_ptr(), buf.len());
            assert!(len < buf.len());
            assert_eq!(buf[len], 0);
            let name: Vec<u8> = buf[..len].iter().map(|&c| c as u8).collect();
            assert!(
                KNOWN_BACKEND_NAMES.contains(&name.as_slice()),
                "unexpected backend name {:?}",
                name
            );
            bonsai_free(h);
        }
    }

    #[test]
    fn backend_name_truncates_to_buffer() {
        unsafe {
            let h = bonsai_new();
            // Every backend name is longer than two bytes, so only two fit before the NUL.
            let mut small = [0x7f as std::ffi::c_char; 3];
            assert_eq!(bonsai_backend_name(h, small.as_mut_ptr(), small.len()), 2);
            assert_eq!(small[2], 0);
            // A one-byte buffer only has room for the terminator.
            let mut single = [0x7f as std::ffi::c_char; 1];
            assert_eq!(bonsai_backend_name(h, single.as_mut_ptr(), single.len()), 0);
            assert_eq!(single[0], 0);
            bonsai_free(h);
        }
    }

    #[test]
    fn null_handle_returns_safe_defaults() {
        unsafe {
            assert_eq!(
                bonsai_insert_2d(std::ptr::null_mut(), 0.0, 0.0, std::ptr::null_mut()),
                u64::MAX
            );
            assert_eq!(bonsai_remove(std::ptr::null_mut(), 0), 0);
            let mut out_ids = [0u64; 4];
            assert_eq!(
                bonsai_range_query_2d(
                    std::ptr::null_mut(),
                    0.0,
                    0.0,
                    1.0,
                    1.0,
                    out_ids.as_mut_ptr(),
                    4,
                ),
                0
            );
            let mut out_ids2 = [0u64; 1];
            let mut out_dist = [0.0f64; 1];
            assert_eq!(
                bonsai_knn_query_2d(
                    std::ptr::null(),
                    0.0,
                    0.0,
                    1,
                    out_ids2.as_mut_ptr(),
                    out_dist.as_mut_ptr(),
                ),
                0
            );
            assert_eq!(bonsai_force_backend(std::ptr::null_mut(), 0), 0);
            let mut s = zeroed_stats();
            assert_eq!(bonsai_stats(std::ptr::null(), &mut s), 0);
            let mut buf = [0 as std::ffi::c_char; 8];
            assert_eq!(
                bonsai_backend_name(std::ptr::null(), buf.as_mut_ptr(), buf.len()),
                usize::MAX
            );
        }
    }

    #[test]
    fn null_outputs_return_safe_defaults() {
        unsafe {
            let h = bonsai_new();
            bonsai_insert_2d(h, 1.0, 1.0, std::ptr::null_mut());
            assert_eq!(
                bonsai_range_query_2d(h, 0.0, 0.0, 2.0, 2.0, std::ptr::null_mut(), 4),
                0
            );
            let mut out_dist = [0.0f64; 1];
            assert_eq!(
                bonsai_knn_query_2d(h, 0.0, 0.0, 1, std::ptr::null_mut(), out_dist.as_mut_ptr()),
                0
            );
            assert_eq!(bonsai_stats(h, std::ptr::null_mut()), 0);
            assert_eq!(
                bonsai_backend_name(h, std::ptr::null_mut(), 8),
                usize::MAX
            );
            let mut buf = [0 as std::ffi::c_char; 8];
            assert_eq!(bonsai_backend_name(h, buf.as_mut_ptr(), 0), usize::MAX);
            bonsai_free(h);
        }
    }
}