use crate::graph::{GraphState, KgliteGraph};
use crate::result::{KgliteCypherResult, ResultState};
use crate::status::KgliteStatusCode;
use crate::strings::alloc_c_string;
use kglite::api::param::json_value_to_kglite_value;
use kglite::api::session::{execute_mut, execute_read, ExecuteOptions, Session};
use kglite::api::{Embedder, Value};
use std::collections::HashMap;
use std::ffi::{c_char, CStr};
use std::sync::Arc;
/// Opaque session handle as seen from C.
///
/// The type is zero-sized and carries no data of its own. A handle is really a pointer to
/// a boxed `SessionState` that has been cast to `*mut KgliteSession`. The
/// `PhantomData<(*mut u8, PhantomPinned)>` marker makes the type `!Send`, `!Sync` and
/// `!Unpin`, so Rust code cannot accidentally move it across threads or out of place.
#[repr(C)]
pub struct KgliteSession {
    _opaque: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
/// Heap-allocated state that every `KgliteSession` handle points to.
///
/// `SessionState::into_handle` boxes it and `SessionState::free_handle` reclaims it.
pub(crate) struct SessionState {
    /// The kglite session that snapshots and transactions are taken from.
    pub(crate) inner: Session,
    /// Optional embedder, cloned into `ExecuteOptions::embedder` for every query.
    /// `into_handle` initialises it to `None`.
    /// NOTE(review): presumably set elsewhere via `from_handle_mut` — confirm.
    pub(crate) embedder: Option<Arc<dyn Embedder>>,
}
impl SessionState {
    /// Moves `session` onto the heap, with no embedder attached, and returns it as an
    /// opaque FFI handle.
    ///
    /// The caller owns the allocation and must release it with
    /// [`SessionState::free_handle`].
    fn into_handle(session: Session) -> *mut KgliteSession {
        let state = Box::new(SessionState {
            inner: session,
            embedder: None,
        });
        Box::into_raw(state).cast()
    }

    /// Borrows the state behind `handle`.
    ///
    /// # Safety
    /// `handle` must be non-null, must come from [`SessionState::into_handle`], and must
    /// not have been freed. No mutable reference to the same state may exist for the
    /// caller-chosen lifetime `'a`.
    pub(crate) unsafe fn from_handle<'a>(handle: *const KgliteSession) -> &'a SessionState {
        let state = handle.cast::<SessionState>();
        // SAFETY: validity and aliasing are guaranteed by the caller (see `# Safety`).
        unsafe { &*state }
    }

    /// Mutably borrows the state behind `handle`.
    ///
    /// # Safety
    /// Same requirements as [`SessionState::from_handle`]. In addition, no other
    /// reference to the state may exist for the caller-chosen lifetime `'a`.
    pub(crate) unsafe fn from_handle_mut<'a>(handle: *mut KgliteSession) -> &'a mut SessionState {
        let state = handle.cast::<SessionState>();
        // SAFETY: validity and exclusivity are guaranteed by the caller (see `# Safety`).
        unsafe { &mut *state }
    }

    /// Reclaims and drops the allocation behind `handle`. A null `handle` is ignored.
    ///
    /// # Safety
    /// A non-null `handle` must come from [`SessionState::into_handle`] and must not
    /// have been freed already. It must not be used after this call.
    unsafe fn free_handle(handle: *mut KgliteSession) {
        if !handle.is_null() {
            // SAFETY: non-null, and produced by `Box::into_raw` in `into_handle`.
            drop(unsafe { Box::from_raw(handle.cast::<SessionState>()) });
        }
    }
}
/// Creates a session from a graph handle, taking ownership of the graph.
///
/// On success the graph handle is consumed:
/// - Its box is reclaimed and the graph's shared `inner` value moves into the new session.
/// - `graph` must not be used or freed afterwards.
/// - The new session handle is written to `*out_session`. Release it with
///   `kglite_session_free`.
///
/// Returns `NullPointer` without consuming anything if either pointer is null.
///
/// # Safety
/// - `graph` must be a live handle whose pointee is a boxed `GraphState`.
///   NOTE(review): presumably one produced by this crate's graph constructor — confirm.
/// - `out_session` must be valid for writes.
#[no_mangle]
pub unsafe extern "C" fn kglite_session_new(
    graph: *mut KgliteGraph,
    out_session: *mut *mut KgliteSession,
) -> KgliteStatusCode {
    if graph.is_null() {
        return KgliteStatusCode::NullPointer;
    }
    if out_session.is_null() {
        return KgliteStatusCode::NullPointer;
    }
    // SAFETY: `graph` is non-null and, per the contract, points to a boxed `GraphState`.
    let owned_graph: Box<GraphState> = unsafe { Box::from_raw(graph.cast()) };
    let handle = SessionState::into_handle(Session::from_arc(owned_graph.inner));
    // SAFETY: `out_session` is non-null and valid for writes per the contract.
    unsafe { out_session.write(handle) };
    KgliteStatusCode::Ok
}
/// Runs a read-only Cypher query against a snapshot of the session's graph.
///
/// `params_json` may be null, an empty string, the JSON literal `null`, or a JSON
/// object whose entries become query parameters. The session's embedder, if any, is
/// forwarded to the executor.
///
/// On success a result handle is written to `*out_result` and `Ok` is returned.
///
/// On failure:
/// - `*out_result` is always null.
/// - When `out_error_msg` is non-null, `*out_error_msg` holds an engine error message
///   (allocated with `alloc_c_string`) or null for argument-level failures.
/// - NOTE(review): the caller presumably frees the message with the crate's string-free
///   function — confirm.
///
/// # Safety
/// - `session` must be a live handle from `kglite_session_new`.
/// - `query` must be a NUL-terminated string.
/// - `params_json` must be null or NUL-terminated.
/// - `out_result` must be valid for writes.
/// - `out_error_msg` must be null or valid for writes.
#[no_mangle]
pub unsafe extern "C" fn kglite_session_execute_read(
    session: *const KgliteSession,
    query: *const c_char,
    params_json: *const c_char,
    out_result: *mut *mut KgliteCypherResult,
    out_error_msg: *mut *const c_char,
) -> KgliteStatusCode {
    if session.is_null() || query.is_null() || out_result.is_null() {
        return KgliteStatusCode::NullPointer;
    }
    // Put the out-parameters into a defined "no result, no message" state first. This way
    // every failure below, including the UTF-8 and parameter errors, leaves them
    // initialised instead of holding whatever the caller passed in.
    // SAFETY: `out_result` is non-null (checked above), and `out_error_msg` is only
    // written when non-null. The caller guarantees both are valid for writes.
    unsafe {
        *out_result = std::ptr::null_mut();
        if !out_error_msg.is_null() {
            *out_error_msg = std::ptr::null();
        }
    }
    // SAFETY: `query` is non-null and the caller guarantees it is NUL-terminated.
    let query_str = match unsafe { CStr::from_ptr(query) }.to_str() {
        Ok(s) => s,
        Err(_) => return KgliteStatusCode::InvalidUtf8,
    };
    let params = match parse_params_json(params_json) {
        Ok(p) => p,
        Err(rc) => return rc,
    };
    // SAFETY: `session` is non-null and the caller guarantees it is a live handle.
    let session_state = unsafe { SessionState::from_handle(session) };
    let snapshot = session_state.inner.snapshot();
    let mut opts = ExecuteOptions::eager(&params);
    opts.embedder = session_state.embedder.clone();
    match execute_read(&snapshot, query_str, &opts) {
        Ok(outcome) => {
            // SAFETY: `out_result` is non-null and valid for writes (see above).
            unsafe { *out_result = ResultState::into_handle(outcome.result) };
            KgliteStatusCode::Ok
        }
        Err(err) => {
            if !out_error_msg.is_null() {
                // SAFETY: non-null, and valid for writes per the caller's contract.
                unsafe { *out_error_msg = alloc_c_string(&err.to_string()) };
            }
            KgliteStatusCode::from_kg_error_code(err.code())
        }
    }
}
/// Runs a mutating Cypher query inside a transaction and commits it.
///
/// The query executes against the transaction's working copy. The transaction is
/// committed only if execution succeeds. A commit failure is reported to the caller like
/// any other engine error, rather than being discarded. `params_json` follows the same
/// rules as for `kglite_session_execute_read`.
///
/// On success a result handle is written to `*out_result` and `Ok` is returned.
///
/// On failure:
/// - `*out_result` is always null.
/// - When `out_error_msg` is non-null, `*out_error_msg` holds an engine error message
///   (allocated with `alloc_c_string`) or null for argument-level failures.
/// - NOTE(review): the caller presumably frees the message with the crate's string-free
///   function — confirm.
///
/// # Safety
/// - `session` must be a live handle from `kglite_session_new`.
/// - `query` must be a NUL-terminated string.
/// - `params_json` must be null or NUL-terminated.
/// - `out_result` must be valid for writes.
/// - `out_error_msg` must be null or valid for writes.
#[no_mangle]
pub unsafe extern "C" fn kglite_session_execute_mut(
    session: *mut KgliteSession,
    query: *const c_char,
    params_json: *const c_char,
    out_result: *mut *mut KgliteCypherResult,
    out_error_msg: *mut *const c_char,
) -> KgliteStatusCode {
    if session.is_null() || query.is_null() || out_result.is_null() {
        return KgliteStatusCode::NullPointer;
    }
    // Put the out-parameters into a defined "no result, no message" state first, so every
    // failure below (including UTF-8 and parameter errors) leaves them initialised.
    // SAFETY: `out_result` is non-null (checked above), and `out_error_msg` is only
    // written when non-null. The caller guarantees both are valid for writes.
    unsafe {
        *out_result = std::ptr::null_mut();
        if !out_error_msg.is_null() {
            *out_error_msg = std::ptr::null();
        }
    }
    // Stores `message` for the caller (when a message slot was supplied) and returns
    // `code`.
    let fail = |code: KgliteStatusCode, message: String| -> KgliteStatusCode {
        if !out_error_msg.is_null() {
            // SAFETY: non-null, and valid for writes per the caller's contract.
            unsafe { *out_error_msg = alloc_c_string(&message) };
        }
        code
    };
    // SAFETY: `query` is non-null and the caller guarantees it is NUL-terminated.
    let query_str = match unsafe { CStr::from_ptr(query) }.to_str() {
        Ok(s) => s,
        Err(_) => return KgliteStatusCode::InvalidUtf8,
    };
    let params = match parse_params_json(params_json) {
        Ok(p) => p,
        Err(rc) => return rc,
    };
    // SAFETY: `session` is non-null and the caller guarantees it is a live handle.
    let session_state = unsafe { SessionState::from_handle(session) };
    let mut opts = ExecuteOptions::eager(&params);
    opts.embedder = session_state.embedder.clone();

    let mut tx = session_state.inner.begin();
    let exec_result = match tx.working_mut() {
        Ok(working) => execute_mut(working, query_str, &opts),
        Err(err) => {
            return fail(KgliteStatusCode::from_kg_error_code(err.code()), err.to_string());
        }
    };
    let outcome = match exec_result {
        Ok(outcome) => outcome,
        // `tx` is dropped on return without ever being committed.
        Err(err) => {
            return fail(KgliteStatusCode::from_kg_error_code(err.code()), err.to_string());
        }
    };
    // A failed commit must not be reported as success. The computed result is dropped
    // together with the failed transaction.
    // NOTE(review): the meaning of the `false` flag is not visible here — confirm.
    if let Err(err) = session_state.inner.commit(tx, false) {
        return fail(KgliteStatusCode::from_kg_error_code(err.code()), err.to_string());
    }
    // SAFETY: `out_result` is non-null and valid for writes (see above).
    unsafe { *out_result = ResultState::into_handle(outcome.result) };
    KgliteStatusCode::Ok
}
/// Releases a session handle created by `kglite_session_new`.
///
/// Passing null does nothing.
///
/// # Safety
/// `session` must be null, or a handle from `kglite_session_new` that has not already
/// been freed. It must not be used after this call.
#[no_mangle]
pub unsafe extern "C" fn kglite_session_free(session: *mut KgliteSession) {
    if session.is_null() {
        return;
    }
    // SAFETY: non-null, and the caller guarantees it is a live, unfreed session handle.
    unsafe { SessionState::free_handle(session) }
}
/// Decodes the optional JSON parameter object passed across the FFI boundary.
///
/// The following all yield an empty map:
/// - a null pointer,
/// - an empty string,
/// - the JSON literal `null`.
///
/// A JSON object is converted entry by entry with `json_value_to_kglite_value`.
///
/// # Errors
/// - `KgliteStatusCode::InvalidUtf8` if the bytes are not valid UTF-8.
/// - `KgliteStatusCode::InvalidArgument` if the text is not valid JSON, or is JSON other
///   than an object or `null`.
///
/// NOTE(review): this is a safe fn that dereferences `params_json`. Callers must pass
/// null or a pointer to a NUL-terminated string that stays valid for the call.
fn parse_params_json(
    params_json: *const c_char,
) -> Result<HashMap<String, Value>, KgliteStatusCode> {
    if params_json.is_null() {
        return Ok(HashMap::new());
    }
    let text = unsafe { CStr::from_ptr(params_json) }
        .to_str()
        .map_err(|_| KgliteStatusCode::InvalidUtf8)?;
    if text.is_empty() {
        return Ok(HashMap::new());
    }
    let json: serde_json::Value =
        serde_json::from_str(text).map_err(|_| KgliteStatusCode::InvalidArgument)?;
    match json {
        serde_json::Value::Null => Ok(HashMap::new()),
        serde_json::Value::Object(entries) => {
            let mut params = HashMap::with_capacity(entries.len());
            for (key, value) in entries {
                params.insert(key, json_value_to_kglite_value(&value));
            }
            Ok(params)
        }
        _ => Err(KgliteStatusCode::InvalidArgument),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::ffi::CString;

    #[test]
    fn parse_params_empty_string_is_empty_map() {
        let s = CString::new("").unwrap();
        let m = parse_params_json(s.as_ptr()).unwrap();
        assert!(m.is_empty());
    }

    #[test]
    fn parse_params_object_round_trips() {
        let s = CString::new(r#"{"x": 42, "y": "hello"}"#).unwrap();
        let m = parse_params_json(s.as_ptr()).unwrap();
        assert_eq!(m.get("x"), Some(&Value::Int64(42)));
        assert_eq!(m.get("y"), Some(&Value::String("hello".to_string())));
    }

    #[test]
    fn parse_params_null_pointer_is_empty_map() {
        let m = parse_params_json(std::ptr::null()).unwrap();
        assert!(m.is_empty());
    }

    #[test]
    fn parse_params_json_null_literal_is_empty_map() {
        let s = CString::new("null").unwrap();
        let m = parse_params_json(s.as_ptr()).unwrap();
        assert!(m.is_empty());
    }

    #[test]
    fn parse_params_array_is_invalid_argument() {
        let s = CString::new("[1, 2, 3]").unwrap();
        let err = parse_params_json(s.as_ptr()).unwrap_err();
        assert_eq!(err, KgliteStatusCode::InvalidArgument);
    }

    #[test]
    fn parse_params_scalar_is_invalid_argument() {
        let s = CString::new("42").unwrap();
        let err = parse_params_json(s.as_ptr()).unwrap_err();
        assert_eq!(err, KgliteStatusCode::InvalidArgument);
    }

    #[test]
    fn parse_params_malformed_json_is_invalid_argument() {
        let s = CString::new(r#"{"x": "#).unwrap();
        let err = parse_params_json(s.as_ptr()).unwrap_err();
        assert_eq!(err, KgliteStatusCode::InvalidArgument);
    }

    #[test]
    fn parse_params_invalid_utf8_is_rejected() {
        // 0xFF/0xFE contain no NUL, so this is a valid C string but invalid UTF-8.
        let s = CString::new(vec![0xffu8, 0xfe]).unwrap();
        let err = parse_params_json(s.as_ptr()).unwrap_err();
        assert_eq!(err, KgliteStatusCode::InvalidUtf8);
    }

    #[test]
    fn ffi_entry_points_reject_null_handles() {
        let mut out_session: *mut KgliteSession = std::ptr::null_mut();
        let rc = unsafe { kglite_session_new(std::ptr::null_mut(), &mut out_session) };
        assert_eq!(rc, KgliteStatusCode::NullPointer);
        assert!(out_session.is_null());

        let query = CString::new("RETURN 1").unwrap();
        let mut out_result: *mut KgliteCypherResult = std::ptr::null_mut();
        let rc = unsafe {
            kglite_session_execute_read(
                std::ptr::null(),
                query.as_ptr(),
                std::ptr::null(),
                &mut out_result,
                std::ptr::null_mut(),
            )
        };
        assert_eq!(rc, KgliteStatusCode::NullPointer);

        // Freeing a null session must be a no-op.
        unsafe { kglite_session_free(std::ptr::null_mut()) };
    }
}