#![allow(clippy::not_unsafe_ptr_arg_deref)]
use crate::sparse_vector::SparseVector;
use crate::sparse_vector_store::SparseVectorStore;
use once_cell::sync::Lazy;
use parking_lot::Mutex;
use std::collections::HashMap;
use std::ffi::{c_char, c_float, c_int, c_uint, CStr};
use std::ptr;
// Status codes returned across the C ABI by the `svs_*` functions in this file.
// Success is 1 and every named error is <= 0. Functions that return a count or an
// id (`svs_index`, `svs_len`) return these codes widened to `i64` on failure.

/// The operation completed successfully.
pub const SVS_SUCCESS: c_int = 1;
/// Unspecified failure.
/// NOTE(review): not referenced by name in the visible code; `svs_exists` returns a
/// literal `0` to mean "not registered" — confirm whether the two are meant to coincide.
pub const SVS_ERR_GENERIC: c_int = 0;
/// A required pointer argument was null.
pub const SVS_ERR_NULL_PTR: c_int = -1;
/// Presumably meant for string arguments that are not valid UTF-8 — confirm against the C header.
pub const SVS_ERR_INVALID_UTF8: c_int = -2;
/// No store is registered under the given path, or (for `svs_delete`) the store
/// reported that the key was not deleted.
pub const SVS_ERR_NOT_FOUND: c_int = -3;
/// A store is already registered under the given path (`svs_new`, `svs_open`).
pub const SVS_ERR_ALREADY_EXISTS: c_int = -4;
/// A panic was caught at the FFI boundary, or an underlying operation such as
/// saving or loading a store failed.
pub const SVS_ERR_INTERNAL: c_int = -100;
/// Process-wide table of open stores, keyed by the path string the caller supplied.
///
/// The map is created lazily on first access and starts out empty. Every `svs_*`
/// function that looks up, adds or removes a store takes this single mutex, so
/// operations on different stores are serialized against each other.
static SVS_REGISTRY: Lazy<Mutex<HashMap<String, SparseVectorStore>>> =
Lazy::new(|| Mutex::new(HashMap::new()));
/// Borrows a C string as a Rust `&str` without copying.
///
/// Returns `None` when `ptr` is null **or** when the bytes before the terminating
/// NUL are not valid UTF-8; a caller that needs to tell the two apart has to test
/// `ptr.is_null()` itself.
///
/// # Safety
///
/// A non-null `ptr` must point to a NUL-terminated byte string that is readable up
/// to and including the terminator. The returned slice is typed `'static`, but it
/// really borrows the caller's buffer: it must not be used after that buffer is
/// freed or modified.
unsafe fn cstr_to_str(ptr: *const c_char) -> Option<&'static str> {
    if ptr.is_null() {
        None
    } else {
        // SAFETY: `ptr` is non-null here, and the function's contract requires it to
        // reference a readable, NUL-terminated string.
        let borrowed: &'static CStr = unsafe { CStr::from_ptr(ptr) };
        borrowed.to_str().ok()
    }
}
#[no_mangle]
pub extern "C" fn svs_new(path: *const c_char) -> c_int {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
let mut registry = SVS_REGISTRY.lock();
if registry.contains_key(path) {
return SVS_ERR_ALREADY_EXISTS;
}
registry.insert(path.to_string(), SparseVectorStore::new());
SVS_SUCCESS
})
.unwrap_or(SVS_ERR_INTERNAL)
}
#[no_mangle]
pub extern "C" fn svs_close(path: *const c_char) -> c_int {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
let mut registry = SVS_REGISTRY.lock();
if registry.remove(path).is_some() {
SVS_SUCCESS
} else {
SVS_ERR_NOT_FOUND
}
})
.unwrap_or(SVS_ERR_INTERNAL)
}
#[no_mangle]
pub extern "C" fn svs_index(
path: *const c_char,
key: *const c_char,
term_ids: *const c_uint,
weights: *const c_float,
count: c_uint,
) -> i64 {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR as i64,
};
let key = match unsafe { cstr_to_str(key) } {
Some(k) => k,
None => return SVS_ERR_NULL_PTR as i64,
};
if term_ids.is_null() || weights.is_null() {
return SVS_ERR_NULL_PTR as i64;
}
let mut registry = SVS_REGISTRY.lock();
let store = match registry.get_mut(path) {
Some(s) => s,
None => return SVS_ERR_NOT_FOUND as i64,
};
let mut vec = SparseVector::new();
for i in 0..count as usize {
let term_id = unsafe { *term_ids.add(i) };
let weight = unsafe { *weights.add(i) };
vec.add(term_id, weight);
}
let doc_id = store.index_with_key(key, vec);
doc_id as i64
})
.unwrap_or(SVS_ERR_INTERNAL as i64)
}
#[no_mangle]
pub extern "C" fn svs_search(
path: *const c_char,
term_ids: *const c_uint,
weights: *const c_float,
count: c_uint,
k: c_uint,
out_keys: *mut *mut c_char,
out_scores: *mut c_float,
out_count: *mut c_uint,
) -> c_int {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
if term_ids.is_null()
|| weights.is_null()
|| out_keys.is_null()
|| out_scores.is_null()
|| out_count.is_null()
{
return SVS_ERR_NULL_PTR;
}
let registry = SVS_REGISTRY.lock();
let store = match registry.get(path) {
Some(s) => s,
None => return SVS_ERR_NOT_FOUND,
};
let mut query = SparseVector::new();
for i in 0..count as usize {
let term_id = unsafe { *term_ids.add(i) };
let weight = unsafe { *weights.add(i) };
query.add(term_id, weight);
}
let results = store.search(&query, k as usize);
unsafe {
*out_count = results.len() as c_uint;
for (i, result) in results.iter().enumerate() {
let key_cstr = std::ffi::CString::new(result.key.clone()).unwrap();
let key_ptr = libc::malloc(key_cstr.as_bytes_with_nul().len()) as *mut c_char;
if !key_ptr.is_null() {
ptr::copy_nonoverlapping(
key_cstr.as_ptr(),
key_ptr,
key_cstr.as_bytes_with_nul().len(),
);
}
*out_keys.add(i) = key_ptr;
*out_scores.add(i) = result.score;
}
}
SVS_SUCCESS
})
.unwrap_or(SVS_ERR_INTERNAL)
}
/// Releases a key string that `svs_search` stored in `out_keys`.
///
/// Passing null is allowed and does nothing.
///
/// # Caller contract
///
/// A non-null `key` must have been allocated with `malloc` (as `svs_search` does)
/// and must not be used or freed again after this call.
#[no_mangle]
pub extern "C" fn svs_free_key(key: *mut c_char) {
    if key.is_null() {
        return;
    }
    // SAFETY: `key` is non-null and, per the caller contract, is a live `malloc`
    // allocation owned by the caller.
    unsafe { libc::free(key.cast::<libc::c_void>()) };
}
#[no_mangle]
pub extern "C" fn svs_len(path: *const c_char) -> i64 {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR as i64,
};
let registry = SVS_REGISTRY.lock();
match registry.get(path) {
Some(store) => store.len() as i64,
None => SVS_ERR_NOT_FOUND as i64,
}
})
.unwrap_or(SVS_ERR_INTERNAL as i64)
}
#[no_mangle]
pub extern "C" fn svs_delete(path: *const c_char, key: *const c_char) -> c_int {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
let key = match unsafe { cstr_to_str(key) } {
Some(k) => k,
None => return SVS_ERR_NULL_PTR,
};
let mut registry = SVS_REGISTRY.lock();
let store = match registry.get_mut(path) {
Some(s) => s,
None => return SVS_ERR_NOT_FOUND,
};
if store.delete(key) {
SVS_SUCCESS
} else {
SVS_ERR_NOT_FOUND
}
})
.unwrap_or(SVS_ERR_INTERNAL)
}
#[no_mangle]
pub extern "C" fn svs_stats(
path: *const c_char,
out_num_docs: *mut c_uint,
out_num_terms: *mut c_uint,
out_num_postings: *mut c_uint,
out_avg_doc_len: *mut c_float,
) -> c_int {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
if out_num_docs.is_null()
|| out_num_terms.is_null()
|| out_num_postings.is_null()
|| out_avg_doc_len.is_null()
{
return SVS_ERR_NULL_PTR;
}
let registry = SVS_REGISTRY.lock();
let store = match registry.get(path) {
Some(s) => s,
None => return SVS_ERR_NOT_FOUND,
};
let stats = store.stats();
unsafe {
*out_num_docs = stats.num_documents as c_uint;
*out_num_terms = stats.num_terms as c_uint;
*out_num_postings = stats.num_postings as c_uint;
*out_avg_doc_len = stats.avg_doc_length;
}
SVS_SUCCESS
})
.unwrap_or(SVS_ERR_INTERNAL)
}
#[no_mangle]
pub extern "C" fn svs_exists(path: *const c_char) -> c_int {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
let registry = SVS_REGISTRY.lock();
if registry.contains_key(path) {
1
} else {
0
}
})
.unwrap_or(SVS_ERR_INTERNAL)
}
#[no_mangle]
pub extern "C" fn svs_save(path: *const c_char, file_path: *const c_char) -> c_int {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
let file_path = match unsafe { cstr_to_str(file_path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
let registry = SVS_REGISTRY.lock();
let store = match registry.get(path) {
Some(s) => s,
None => return SVS_ERR_NOT_FOUND,
};
match store.save(file_path) {
Ok(()) => SVS_SUCCESS,
Err(_) => SVS_ERR_INTERNAL,
}
})
.unwrap_or(SVS_ERR_INTERNAL)
}
#[no_mangle]
pub extern "C" fn svs_open(path: *const c_char, file_path: *const c_char) -> c_int {
std::panic::catch_unwind(|| {
let path = match unsafe { cstr_to_str(path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
let file_path = match unsafe { cstr_to_str(file_path) } {
Some(p) => p,
None => return SVS_ERR_NULL_PTR,
};
let mut registry = SVS_REGISTRY.lock();
if registry.contains_key(path) {
return SVS_ERR_ALREADY_EXISTS;
}
match SparseVectorStore::load(file_path) {
Ok(store) => {
registry.insert(path.to_string(), store);
SVS_SUCCESS
}
Err(_) => SVS_ERR_INTERNAL,
}
})
.unwrap_or(SVS_ERR_INTERNAL)
}