use std::{
ffi::{c_char, CStr, CString},
sync::Arc,
};
use ailake_catalog::{HadoopCatalog, TableIdent};
use ailake_core::VectorMetric;
use ailake_query::{
search as rs_search, Chunk, ContextAssembler, ContextAssemblerConfig, SearchConfig,
SearchResult,
};
use ailake_store::LocalStore;
use serde::Serialize;
use tracing::{debug, info, warn};
/// One nearest-neighbour hit in plain Rust form.
///
/// Carries the same three fields that the JSON entry points in this file
/// serialise for each hit.
#[derive(Clone, Debug, PartialEq)]
pub struct RowResult {
    /// Row identifier of the hit.
    pub row_id: u64,
    /// Distance to the query as reported by the search.
    pub distance: f32,
    /// File path reported by the search for this hit (presumably the data
    /// file holding the row — confirm against `SearchResult`).
    pub file_path: String,
}
#[derive(Serialize)]
struct RowResultJson {
row_id: u64,
distance: f32,
file_path: String,
}
impl From<SearchResult> for RowResultJson {
fn from(r: SearchResult) -> Self {
Self {
row_id: r.row_id.as_u64(),
distance: r.distance,
file_path: r.file_path,
}
}
}
/// Process-wide Tokio runtime used to drive async library calls from the C ABI.
///
/// Built lazily on first call and kept for the life of the process. A
/// multi-threaded runtime is tried first; if building it returns an error, a
/// current-thread runtime is built instead.
///
/// # Panics
///
/// Panics if the current-thread fallback runtime also fails to build.
fn rt() -> &'static tokio::runtime::Runtime {
    use std::sync::OnceLock;

    /// Build the runtime, preferring multi-thread and falling back to current-thread.
    fn build() -> tokio::runtime::Runtime {
        let attempt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build();
        let failure = match attempt {
            Ok(runtime) => {
                info!("ailake-jni: Tokio multi-thread runtime initialised");
                return runtime;
            }
            Err(e) => e,
        };
        // NOTE(review): this branch runs only when the multi-thread build itself
        // errors; the "signal handler conflicts" wording in the log is the
        // original author's rationale and is not checked here.
        warn!(
            "ailake-jni: multi-thread Tokio runtime failed ({}); \
             falling back to single-threaded runtime to avoid JVM signal handler conflicts",
            failure
        );
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("ailake-jni: tokio runtime unavailable")
    }

    static RUNTIME: OnceLock<tokio::runtime::Runtime> = OnceLock::new();
    RUNTIME.get_or_init(build)
}
/// Map a user-supplied metric name to a [`VectorMetric`].
///
/// Surrounding whitespace is ignored and matching is ASCII case-insensitive, so
/// `"Euclidean"` and `" DOT_PRODUCT "` are recognised.
///
/// * `euclidean` → [`VectorMetric::Euclidean`]
/// * `dot_product` / `dotproduct` → [`VectorMetric::DotProduct`]
/// * anything else (including `cosine`) → [`VectorMetric::Cosine`]
fn parse_metric(s: &str) -> VectorMetric {
    // Normalise first: a case mismatch must not silently fall through to the
    // Cosine default and build an index with the wrong distance function.
    match s.trim().to_ascii_lowercase().as_str() {
        "euclidean" => VectorMetric::Euclidean,
        "dot_product" | "dotproduct" => VectorMetric::DotProduct,
        _ => VectorMetric::Cosine,
    }
}
/// Run a nearest-neighbour search on `namespace.table_name` under `warehouse`.
///
/// A fresh [`LocalStore`] and [`HadoopCatalog`] rooted at `warehouse` are
/// created for this call. The async search is driven to completion on the
/// shared runtime from `rt()`. The pruning threshold is `f32::INFINITY` and
/// `rerank_factor` is `None` (presumably: no pruning, no re-ranking — confirm
/// against `SearchConfig`).
// The eight scalars mirror the FFI request fields one-to-one; a wrapper struct
// would only relocate the same list.
#[allow(clippy::too_many_arguments)]
fn do_search(
    warehouse: String,
    namespace: &str,
    table_name: &str,
    vec_col: &str,
    dim: u32,
    query: Vec<f32>,
    top_k: u32,
    ef_search: u32,
) -> ailake_core::AilakeResult<Vec<SearchResult>> {
    let store: Arc<dyn ailake_store::Store> = Arc::new(LocalStore::new(&warehouse));
    let catalog = Arc::new(HadoopCatalog::new(Arc::clone(&store), &warehouse));
    let search_config = SearchConfig {
        top_k: top_k as usize,
        ef_search: ef_search as usize,
        pruning_threshold: f32::INFINITY,
        rerank_factor: None,
    };
    let ident = TableIdent::new(namespace, table_name);
    // Resolve the runtime before building the search future, matching the
    // receiver-then-argument evaluation order of `rt().block_on(..)`.
    let runtime = rt();
    runtime.block_on(rs_search(
        &ident,
        &query,
        search_config,
        vec_col,
        dim,
        catalog,
        store,
    ))
}
/// Assemble a context string from JSON-encoded chunks.
///
/// Each entry of `chunk_jsons` should be a JSON object with the keys
/// `document_id`, `chunk_index`, `chunk_text`, `document_title`,
/// `section_path`, `source_uri` and `distance`. All keys are optional:
/// missing `document_id`/`chunk_text` become `""`, missing optional strings
/// become `None`, and missing numbers become `0`. Entries that fail to parse,
/// or that parse to something other than a JSON object, are skipped.
///
/// `max_tokens` is forwarded to `ContextAssemblerConfig::max_tokens`. Values
/// that do not fit in `usize` saturate to `usize::MAX`. A `chunk_index` larger
/// than `u32::MAX` saturates to `u32::MAX` rather than wrapping.
// Not called by any exported entry point in this file; currently exercised
// only by this file's unit tests.
#[allow(dead_code)]
fn assemble_context(chunk_jsons: Vec<String>, max_tokens: u64) -> String {
    let config = ContextAssemblerConfig {
        max_tokens: usize::try_from(max_tokens).unwrap_or(usize::MAX),
        ..Default::default()
    };
    let ca = ContextAssembler::new(config);
    let chunks: Vec<Chunk> = chunk_jsons
        .iter()
        .filter_map(|json| {
            let v: serde_json::Value = serde_json::from_str(json).ok()?;
            // Non-object JSON (null, numbers, arrays, strings) carries no chunk
            // fields; skip it instead of emitting an empty chunk.
            let obj = v.as_object()?;
            let get_str = |key: &str| -> String {
                obj.get(key)
                    .and_then(|x| x.as_str())
                    .unwrap_or("")
                    .to_string()
            };
            let get_opt = |key: &str| -> Option<String> {
                obj.get(key).and_then(|x| x.as_str()).map(str::to_string)
            };
            // Saturate: `u64 as u32` would silently wrap (e.g. 2^32 -> 0).
            let chunk_index = obj
                .get("chunk_index")
                .and_then(|x| x.as_u64())
                .map_or(0, |n| u32::try_from(n).unwrap_or(u32::MAX));
            Some(Chunk {
                document_id: get_str("document_id"),
                chunk_index,
                chunk_text: get_str("chunk_text"),
                document_title: get_opt("document_title"),
                section_path: get_opt("section_path"),
                source_uri: get_opt("source_uri"),
                distance: obj.get("distance").and_then(|x| x.as_f64()).unwrap_or(0.0) as f32,
                embedding: None,
            })
        })
        .collect();
    ca.assemble_chunks(chunks).text
}
/// Heap-allocate the JSON text `[]` as a C string.
///
/// Ownership passes to the caller, who releases it with `ailake_free_string`.
fn cstr_empty_json() -> *mut c_char {
    let empty_array = String::from("[]");
    CString::new(empty_array)
        .expect("\"[]\" has no interior NUL")
        .into_raw()
}
/// Heap-allocate `{"ok":false,"error":"<msg>"}` as a C string.
///
/// `msg` is JSON-escaped, so errors containing quotes, backslashes, newlines
/// or other control characters (including NUL) still yield a valid JSON
/// document. Ownership passes to the caller, who releases it with
/// `ailake_free_string`.
fn cstr_err_json(msg: impl std::fmt::Display) -> *mut c_char {
    let s = format!(
        "{{\"ok\":false,\"error\":\"{}\"}}",
        json_escape(&msg.to_string())
    );
    // `json_escape` rewrites NUL as `\u0000`, so `s` never has an interior NUL.
    CString::new(s)
        .expect("escaped JSON contains no interior NUL")
        .into_raw()
}

/// Escape `raw` for embedding inside a JSON string literal (RFC 8259 §7).
fn json_escape(raw: &str) -> String {
    use std::fmt::Write as _;
    let mut out = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            c if u32::from(c) < 0x20 => {
                // Writing into a `String` is infallible.
                let _ = write!(out, "\\u{:04x}", u32::from(c));
            }
            c => out.push(c),
        }
    }
    out
}
/// Return the crate version as a NUL-terminated C string.
///
/// The pointer refers to static data that lives for the whole process; it must
/// NOT be passed to `ailake_free_string`.
#[no_mangle]
pub extern "C" fn ailake_version() -> *const c_char {
    static VERSION_Z: &str = concat!(env!("CARGO_PKG_VERSION"), "\0");
    VERSION_Z.as_ptr().cast::<c_char>()
}
#[no_mangle]
pub unsafe extern "C" fn ailake_vector_search_json(
table_uri: *const c_char,
query_ptr: *const f32,
query_len: u32,
top_k: u32,
) -> *mut c_char {
if table_uri.is_null() || query_ptr.is_null() {
return cstr_empty_json();
}
let uri = match CStr::from_ptr(table_uri).to_str() {
Ok(s) => s.to_string(),
Err(_) => return cstr_empty_json(),
};
let query = std::slice::from_raw_parts(query_ptr, query_len as usize).to_vec();
let dim = query.len() as u32;
let results: Vec<RowResultJson> =
match do_search(uri, "default", "table", "embedding", dim, query, top_k, 50) {
Ok(v) => v.into_iter().map(RowResultJson::from).collect(),
Err(e) => return cstr_err_json(e),
};
let json = serde_json::to_string(&results).unwrap_or_else(|_| "[]".to_string());
CString::new(json)
.unwrap_or_else(|_| CString::new("[]").unwrap())
.into_raw()
}
/// Run a vector search described by a JSON request.
///
/// Request fields: `warehouse`, `table`, `dim`, `query` (required);
/// `namespace` (default `"default"`), `vec_col` (default `"embedding"`),
/// `top_k` (default 10), `ef_search` (default 50).
///
/// Returns a heap-allocated JSON C string the caller must free with
/// [`ailake_free_string`]: `{"ok":true,"results":[…]}` on success. Otherwise it
/// returns `{"ok":false,"error":…}` in these cases:
/// * null or non-UTF-8 request;
/// * malformed JSON;
/// * `query.len() != dim`;
/// * search error;
/// * a panic inside the search.
///
/// # Safety
///
/// `request_json` must be null or point to a valid NUL-terminated C string
/// that remains valid for the duration of the call.
#[no_mangle]
pub unsafe extern "C" fn ailake_search_json(request_json: *const c_char) -> *mut c_char {
    #[derive(serde::Deserialize)]
    struct Req {
        warehouse: String,
        #[serde(default = "default_ns")]
        namespace: String,
        table: String,
        #[serde(default = "default_col")]
        vec_col: String,
        dim: u32,
        query: Vec<f32>,
        #[serde(default = "default_topk")]
        top_k: u32,
        #[serde(default = "default_ef")]
        ef_search: u32,
    }
    fn default_ns() -> String {
        "default".into()
    }
    fn default_col() -> String {
        "embedding".into()
    }
    fn default_topk() -> u32 {
        10
    }
    fn default_ef() -> u32 {
        50
    }
    #[derive(serde::Serialize)]
    struct Resp {
        ok: bool,
        results: Vec<RowResultJson>,
    }

    if request_json.is_null() {
        return cstr_err_json("null request_json");
    }
    // SAFETY: non-null checked above; validity is the caller's contract.
    let json_str = match CStr::from_ptr(request_json).to_str() {
        Ok(s) => s,
        Err(e) => {
            warn!("ailake_search_json: invalid UTF-8 in request_json: {}", e);
            return cstr_err_json(e);
        }
    };
    let req: Req = match serde_json::from_str(json_str) {
        Ok(r) => r,
        Err(e) => {
            warn!("ailake_search_json: JSON parse error: {}", e);
            return cstr_err_json(e);
        }
    };
    let Req {
        warehouse,
        namespace,
        table,
        vec_col,
        dim,
        query,
        top_k,
        ef_search,
    } = req;
    debug!(
        "ailake_search_json: warehouse={} table={}.{} dim={} top_k={}",
        warehouse, namespace, table, dim, top_k
    );
    // Reject mismatched queries at the boundary rather than handing the search
    // a vector whose length disagrees with the declared dimension.
    if query.len() != dim as usize {
        let msg = format!("query.len()={} != dim={}", query.len(), dim);
        warn!("ailake_search_json: {}", msg);
        return cstr_err_json(msg);
    }
    // A panic must never unwind across `extern "C"`: that is UB on older
    // toolchains and a process abort (taking the host JVM down) on newer ones.
    let outcome = std::panic::catch_unwind(move || {
        do_search(
            warehouse, &namespace, &table, &vec_col, dim, query, top_k, ef_search,
        )
    });
    let results = match outcome {
        Ok(Ok(v)) => v,
        Ok(Err(e)) => {
            warn!("ailake_search_json: search failed: {}", e);
            return cstr_err_json(e);
        }
        Err(_) => {
            warn!("ailake_search_json: search panicked");
            return cstr_err_json("internal panic during search");
        }
    };
    let body = Resp {
        ok: true,
        results: results.into_iter().map(RowResultJson::from).collect(),
    };
    let json = serde_json::to_string(&body)
        .unwrap_or_else(|_| "{\"ok\":false,\"error\":\"serialize\"}".into());
    CString::new(json).unwrap_or_default().into_raw()
}
/// Append a batch of `(id, embedding)` rows to a table and commit a snapshot.
///
/// Request fields: `warehouse`, `table`, `dim`, `ids`, `embeddings` (required);
/// `namespace` (default `"default"`), `vec_col` (default `"embedding"`),
/// `metric` (default `"euclidean"`, parsed by `parse_metric`), `precision`
/// (`"f32"`, `"i8"`, anything else or absent → f16).
///
/// The table is obtained through `TableWriter::create_or_open`. Returns a
/// heap-allocated JSON C string the caller must free with
/// [`ailake_free_string`]: `{"ok":true,"snapshot_id":N}` on success. Otherwise
/// it returns `{"ok":false,"error":…}` in these cases:
/// * null or non-UTF-8 request;
/// * malformed JSON;
/// * mismatched lengths;
/// * write error;
/// * a panic inside the write.
///
/// # Safety
///
/// `request_json` must be null or point to a valid NUL-terminated C string
/// that remains valid for the duration of the call.
#[no_mangle]
pub unsafe extern "C" fn ailake_write_batch_json(request_json: *const c_char) -> *mut c_char {
    use ailake_core::{VectorPrecision, VectorStoragePolicy};
    use ailake_query::TableWriter;
    use arrow_array::{Int64Array, RecordBatch};
    use arrow_schema::{DataType, Field, Schema};
    #[derive(serde::Deserialize)]
    struct Req {
        warehouse: String,
        #[serde(default = "default_ns")]
        namespace: String,
        table: String,
        #[serde(default = "default_col")]
        vec_col: String,
        dim: u32,
        #[serde(default)]
        metric: Option<String>,
        #[serde(default)]
        precision: Option<String>,
        ids: Vec<i64>,
        embeddings: Vec<Vec<f32>>,
    }
    fn default_ns() -> String {
        "default".into()
    }
    fn default_col() -> String {
        "embedding".into()
    }
    #[derive(serde::Serialize)]
    struct Resp {
        ok: bool,
        snapshot_id: i64,
    }

    if request_json.is_null() {
        return cstr_err_json("null request_json");
    }
    // SAFETY: non-null checked above; validity is the caller's contract.
    let json_str = match CStr::from_ptr(request_json).to_str() {
        Ok(s) => s,
        Err(e) => {
            warn!(
                "ailake_write_batch_json: invalid UTF-8 in request_json: {}",
                e
            );
            return cstr_err_json(e);
        }
    };
    let req: Req = match serde_json::from_str(json_str) {
        Ok(r) => r,
        Err(e) => {
            warn!("ailake_write_batch_json: JSON parse error: {}", e);
            return cstr_err_json(e);
        }
    };
    let Req {
        warehouse,
        namespace,
        table: table_name,
        vec_col,
        dim,
        metric: metric_name,
        precision: precision_name,
        ids,
        embeddings,
    } = req;
    if ids.len() != embeddings.len() {
        warn!(
            "ailake_write_batch_json: ids.len()={} != embeddings.len()={}",
            ids.len(),
            embeddings.len()
        );
        return cstr_err_json("ids.len() != embeddings.len()");
    }
    // Validate every vector against the declared dimension up front so the
    // caller gets a precise error naming the offending row.
    if let Some(bad) = embeddings.iter().position(|e| e.len() != dim as usize) {
        let msg = format!(
            "embeddings[{}].len()={} != dim={}",
            bad,
            embeddings[bad].len(),
            dim
        );
        warn!("ailake_write_batch_json: {}", msg);
        return cstr_err_json(msg);
    }
    debug!(
        "ailake_write_batch_json: warehouse={} table={}.{} rows={}",
        warehouse,
        namespace,
        table_name,
        ids.len()
    );
    let metric = parse_metric(metric_name.as_deref().unwrap_or("euclidean"));
    let precision = match precision_name.as_deref().unwrap_or("f16") {
        "f32" => VectorPrecision::F32,
        "i8" => VectorPrecision::I8,
        _ => VectorPrecision::F16,
    };
    let policy = VectorStoragePolicy {
        column_name: vec_col,
        dim,
        metric,
        precision,
        pq: None,
        keep_raw_for_reranking: false,
        pre_normalize: false,
        hnsw_m: None,
        hnsw_ef_construction: None,
    };
    let table = ailake_catalog::TableIdent::new(&namespace, &table_name);
    let store: std::sync::Arc<dyn ailake_store::Store> =
        std::sync::Arc::new(LocalStore::new(&warehouse));
    let catalog = std::sync::Arc::new(HadoopCatalog::new(store.clone(), &warehouse));
    let schema = std::sync::Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, false)]));
    let batch =
        match RecordBatch::try_new(schema, vec![std::sync::Arc::new(Int64Array::from(ids))]) {
            Ok(b) => b,
            Err(e) => return cstr_err_json(e),
        };
    // A panic must never unwind across `extern "C"`: that is UB on older
    // toolchains and a process abort (taking the host JVM down) on newer ones.
    // AssertUnwindSafe is sound here: on panic every captured value is dropped
    // and none of it is observed afterwards.
    let outcome = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
        rt().block_on(async {
            let mut writer = TableWriter::create_or_open(catalog, store, policy, table).await?;
            writer.write_batch(&batch, &embeddings).await?;
            writer.commit().await
        })
    }));
    match outcome {
        Ok(Ok(snap)) => {
            info!(
                "ailake_write_batch_json: committed snapshot_id={} table={}.{}",
                snap, namespace, table_name
            );
            let json = serde_json::to_string(&Resp {
                ok: true,
                snapshot_id: snap,
            })
            .unwrap_or_else(|_| "{\"ok\":false,\"error\":\"serialize\"}".into());
            CString::new(json).unwrap_or_default().into_raw()
        }
        Ok(Err(e)) => {
            warn!("ailake_write_batch_json: write failed: {}", e);
            cstr_err_json(e)
        }
        Err(_) => {
            warn!("ailake_write_batch_json: write panicked");
            cstr_err_json("internal panic during write")
        }
    }
}
/// Release a string previously returned by one of the `ailake_*_json` functions.
///
/// Passing null is a no-op.
///
/// # Safety
///
/// `ptr` must be null, or a pointer produced by `CString::into_raw` in this
/// library (i.e. returned by an `ailake_*_json` call) that has not already been
/// freed. Never pass the static pointer returned by `ailake_version`.
#[no_mangle]
pub unsafe extern "C" fn ailake_free_string(ptr: *mut c_char) {
    if ptr.is_null() {
        return;
    }
    // SAFETY: per the contract above, `ptr` came from `CString::into_raw` and
    // ownership is handed back here; rebuilding and dropping frees it once.
    let reclaimed = CString::from_raw(ptr);
    drop(reclaimed);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Little-endian `f32` bytes round-trip through `from_le_bytes`.
    #[test]
    fn query_bytes_decode() {
        let original = vec![1.0f32, 2.0, 3.0];
        let mut raw: Vec<u8> = Vec::with_capacity(original.len() * 4);
        for x in &original {
            raw.extend_from_slice(&x.to_le_bytes());
        }
        let back: Vec<f32> = raw
            .chunks_exact(4)
            .map(|quad| {
                let arr: [u8; 4] = quad.try_into().unwrap();
                f32::from_le_bytes(arr)
            })
            .collect();
        assert_eq!(back, original);
    }

    /// With no chunks the output is either empty or a bare context wrapper.
    #[test]
    fn assemble_context_empty() {
        let text = assemble_context(Vec::new(), 1024);
        assert!(text.contains("<context") || text.is_empty());
    }

    /// A single well-formed chunk's text appears in the assembled context.
    #[test]
    fn assemble_context_one_chunk() {
        let encoded = serde_json::json!({
            "document_id": "doc-1",
            "chunk_index": 0,
            "chunk_text": "Hello world",
            "document_title": "Test",
        })
        .to_string();
        let text = assemble_context(vec![encoded], 4096);
        assert!(text.contains("Hello world"));
    }

    /// Null inputs to the legacy search entry point yield an owned `"[]"`.
    #[test]
    fn cabi_null_guard() {
        // SAFETY: null pointers are explicitly handled by the entry point.
        let out = unsafe { ailake_vector_search_json(std::ptr::null(), std::ptr::null(), 0, 10) };
        assert!(!out.is_null());
        // SAFETY: `out` is a valid NUL-terminated string owned until freed below.
        let text = unsafe { CStr::from_ptr(out) }.to_str().unwrap().to_owned();
        assert_eq!(text, "[]");
        unsafe { ailake_free_string(out) };
    }
}