use base64::Engine as _;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::{Arc, Mutex, OnceLock};
/// Presumably selects where the Lua runtime's database provider comes from.
///
/// NOTE(review): the behavior behind each variant is implemented outside this
/// file — confirm against its users. Serialized in snake_case
/// (`"dynamic_library"`, `"host_callback"`, `"space_controller"`); the
/// `Default` is `DynamicLibrary`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum LuaRuntimeDatabaseProviderMode {
#[default]
DynamicLibrary,
HostCallback,
SpaceController,
}
/// Chooses which registered host callback a provider request is routed to.
///
/// `Standard` calls the typed callback with the request struct directly.
/// `Json` serializes the request to a JSON string, calls the string-based
/// callback, and parses its JSON reply. The `Default` is `Standard`;
/// serialized as `"standard"` / `"json"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
#[serde(rename_all = "snake_case")]
pub enum LuaRuntimeDatabaseCallbackMode {
#[default]
Standard,
Json,
}
/// Kind of database a binding targets.
///
/// Serialized in snake_case, i.e. `"sqlite"` and `"lance_db"`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeDatabaseKind {
Sqlite,
LanceDb,
}
/// Identifies the space/skill binding a database provider request is issued
/// for; it is embedded in every SQLite and LanceDB provider request.
///
/// Prefer `RuntimeDatabaseBindingContext::new`, which derives `binding_tag`
/// from `space_label` and `skill_id` so the three stay consistent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RuntimeDatabaseBindingContext {
pub space_label: String,
pub skill_id: String,
/// `"<space_label>-<skill_id>"` when built through `new`.
pub binding_tag: String,
pub root_name: String,
/// Filesystem path string; NOTE(review): presumably the space's database root — confirm.
pub space_root: String,
/// Filesystem path string of the skill directory (not validated here).
pub skill_dir: String,
pub skill_dir_name: String,
pub database_kind: RuntimeDatabaseKind,
/// NOTE(review): presumably the database used when a script names none — confirm.
pub default_database_path: String,
}
impl RuntimeDatabaseBindingContext {
    /// Builds a binding context from its parts.
    ///
    /// `binding_tag` is not taken as an argument: it is derived as
    /// `"<space_label>-<skill_id>"` so it can never disagree with those fields.
    pub fn new(
        space_label: impl Into<String>,
        skill_id: impl Into<String>,
        root_name: impl Into<String>,
        space_root: impl Into<String>,
        skill_dir: impl Into<String>,
        skill_dir_name: impl Into<String>,
        database_kind: RuntimeDatabaseKind,
        default_database_path: impl Into<String>,
    ) -> Self {
        let label: String = space_label.into();
        let skill: String = skill_id.into();
        let tag = [label.as_str(), skill.as_str()].join("-");

        // Remaining conversions run in the same order as the field list.
        Self {
            space_label: label,
            skill_id: skill,
            binding_tag: tag,
            root_name: root_name.into(),
            space_root: space_root.into(),
            skill_dir: skill_dir.into(),
            skill_dir_name: skill_dir_name.into(),
            database_kind,
            default_database_path: default_database_path.into(),
        }
    }
}
/// Operation requested from the SQLite provider; carried in the `action`
/// field of `RuntimeSqliteProviderRequest`.
///
/// Serialized in snake_case (e.g. `QueryJson` -> `"query_json"`). Operation
/// arguments presumably travel in the request's `input` JSON — the exact
/// per-action schema is defined by the provider, not in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeSqliteProviderAction {
ExecuteScript,
ExecuteBatch,
QueryJson,
QueryStream,
QueryStreamWaitMetrics,
QueryStreamChunk,
QueryStreamClose,
TokenizeText,
UpsertCustomWord,
RemoveCustomWord,
ListCustomWords,
EnsureFtsIndex,
RebuildFtsIndex,
UpsertFtsDocument,
DeleteFtsDocument,
SearchFts,
}
/// Operation requested from the LanceDB provider; carried in the `action`
/// field of `RuntimeLanceDbProviderRequest`.
///
/// Serialized in snake_case (e.g. `VectorSearch` -> `"vector_search"`). The
/// per-action `input` schema is defined by the provider, not in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeLanceDbProviderAction {
CreateTable,
VectorUpsert,
VectorSearch,
Delete,
DropTable,
}
/// One SQLite provider call: the action, the binding it runs against, and
/// action-specific JSON input.
///
/// In JSON callback mode this whole struct is serialized with `serde_json`
/// and handed to the callback as a string.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RuntimeSqliteProviderRequest {
pub action: RuntimeSqliteProviderAction,
pub binding: RuntimeDatabaseBindingContext,
pub input: Value,
}
/// One LanceDB provider call: the action, the binding it runs against, and
/// action-specific JSON input.
///
/// In JSON callback mode this whole struct is serialized with `serde_json`
/// and handed to the callback as a string.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RuntimeLanceDbProviderRequest {
pub action: RuntimeLanceDbProviderAction,
pub binding: RuntimeDatabaseBindingContext,
pub input: Value,
}
/// Typed SQLite host callback: receives the request and returns a JSON value,
/// or an error message.
pub type RuntimeSqliteProviderCallback =
Arc<dyn Fn(&RuntimeSqliteProviderRequest) -> Result<Value, String> + Send + Sync>;
/// Typed LanceDB host callback: returns JSON metadata plus a binary payload,
/// or an error message.
pub type RuntimeLanceDbProviderCallback = Arc<
dyn Fn(&RuntimeLanceDbProviderRequest) -> Result<RuntimeLanceDbProviderResult, String>
+ Send
+ Sync,
>;
/// JSON-mode SQLite host callback: receives the serialized request and returns
/// a JSON document, which the dispatcher parses into a `Value`.
pub type RuntimeSqliteProviderJsonCallback =
Arc<dyn Fn(&str) -> Result<String, String> + Send + Sync>;
/// JSON-mode LanceDB host callback: receives the serialized request and returns
/// a JSON object. Its `meta` field becomes the result metadata, and its optional
/// `data_base64` string is base64-decoded into the payload.
pub type RuntimeLanceDbProviderJsonCallback =
Arc<dyn Fn(&str) -> Result<String, String> + Send + Sync>;
/// Snapshot of the host provider callbacks, with one optional slot per
/// (database, callback mode) pair.
///
/// `Default` leaves every slot empty. Because the slots hold `Arc` handles, a
/// snapshot keeps dispatching to the callbacks it captured even after the
/// process-wide registries are changed.
#[derive(Clone, Default)]
pub(crate) struct RuntimeDatabaseProviderCallbacks {
sqlite_standard: Option<RuntimeSqliteProviderCallback>,
lancedb_standard: Option<RuntimeLanceDbProviderCallback>,
sqlite_json: Option<RuntimeSqliteProviderJsonCallback>,
lancedb_json: Option<RuntimeLanceDbProviderJsonCallback>,
}
impl RuntimeDatabaseProviderCallbacks {
pub(crate) fn capture_process_defaults() -> Result<Self, String> {
Ok(Self {
sqlite_standard: take_optional_callback(sqlite_provider_callback_registry())?,
lancedb_standard: take_optional_callback(lancedb_provider_callback_registry())?,
sqlite_json: take_optional_callback(sqlite_provider_json_callback_registry())?,
lancedb_json: take_optional_callback(lancedb_provider_json_callback_registry())?,
})
}
pub(crate) fn has_sqlite_provider_callback_for_mode(
&self,
callback_mode: LuaRuntimeDatabaseCallbackMode,
) -> bool {
match callback_mode {
LuaRuntimeDatabaseCallbackMode::Standard => self.sqlite_standard.is_some(),
LuaRuntimeDatabaseCallbackMode::Json => self.sqlite_json.is_some(),
}
}
pub(crate) fn has_lancedb_provider_callback_for_mode(
&self,
callback_mode: LuaRuntimeDatabaseCallbackMode,
) -> bool {
match callback_mode {
LuaRuntimeDatabaseCallbackMode::Standard => self.lancedb_standard.is_some(),
LuaRuntimeDatabaseCallbackMode::Json => self.lancedb_json.is_some(),
}
}
pub(crate) fn dispatch_sqlite_provider_request(
&self,
request: &RuntimeSqliteProviderRequest,
callback_mode: LuaRuntimeDatabaseCallbackMode,
) -> Result<Value, String> {
match callback_mode {
LuaRuntimeDatabaseCallbackMode::Standard => {
let callback = self.sqlite_standard.clone().ok_or_else(|| {
"SQLite host-callback mode requires one registered standard callback"
.to_string()
})?;
callback(request)
}
LuaRuntimeDatabaseCallbackMode::Json => {
let callback = self.sqlite_json.clone().ok_or_else(|| {
"SQLite host-callback JSON mode requires one registered JSON callback"
.to_string()
})?;
let request_json = serde_json::to_string(request).map_err(|error| {
format!("failed to encode sqlite provider request: {}", error)
})?;
let response_json = callback(&request_json)?;
serde_json::from_str::<Value>(&response_json).map_err(|error| {
format!("failed to parse sqlite provider response json: {}", error)
})
}
}
}
pub(crate) fn dispatch_lancedb_provider_request(
&self,
request: &RuntimeLanceDbProviderRequest,
callback_mode: LuaRuntimeDatabaseCallbackMode,
) -> Result<RuntimeLanceDbProviderResult, String> {
match callback_mode {
LuaRuntimeDatabaseCallbackMode::Standard => {
let callback = self.lancedb_standard.clone().ok_or_else(|| {
"LanceDB host-callback mode requires one registered standard callback"
.to_string()
})?;
callback(request)
}
LuaRuntimeDatabaseCallbackMode::Json => {
let callback = self.lancedb_json.clone().ok_or_else(|| {
"LanceDB host-callback JSON mode requires one registered JSON callback"
.to_string()
})?;
let request_json = serde_json::to_string(request).map_err(|error| {
format!("failed to encode lancedb provider request: {}", error)
})?;
let response_json = callback(&request_json)?;
let value: Value = serde_json::from_str(&response_json).map_err(|error| {
format!("failed to parse lancedb provider response json: {}", error)
})?;
let meta = value
.get("meta")
.cloned()
.unwrap_or_else(|| Value::Object(Default::default()));
let bytes = value
.get("data_base64")
.and_then(Value::as_str)
.map(|text| {
BASE64_STANDARD.decode(text.as_bytes()).map_err(|error| {
format!("failed to decode lancedb provider data_base64: {}", error)
})
})
.transpose()?
.unwrap_or_default();
Ok(RuntimeLanceDbProviderResult::binary(meta, bytes))
}
}
}
}
/// Result of a LanceDB provider call: JSON metadata plus a raw byte payload.
///
/// `bytes` is empty when the call produced no binary data (see
/// `RuntimeLanceDbProviderResult::json`).
#[derive(Debug, Clone, PartialEq)]
pub struct RuntimeLanceDbProviderResult {
pub meta: Value,
pub bytes: Vec<u8>,
}
impl RuntimeLanceDbProviderResult {
    /// Creates a result that carries only JSON metadata and an empty payload.
    pub fn json(meta: Value) -> Self {
        Self::binary(meta, Vec::new())
    }

    /// Creates a result carrying JSON metadata plus a raw byte payload.
    pub fn binary(meta: Value, bytes: Vec<u8>) -> Self {
        Self { bytes, meta }
    }
}
/// Registers (or, with `None`, clears) the process-wide standard SQLite
/// provider callback.
///
/// The registry is read by `RuntimeDatabaseProviderCallbacks::capture_process_defaults`;
/// snapshots taken earlier keep the callback they already captured.
///
/// A poisoned registry lock is recovered instead of panicking. The slot is a
/// single `Option` that is always overwritten as a whole, so there is no
/// half-written state to protect. The poison flag itself is left set.
pub fn set_sqlite_provider_callback(callback: Option<RuntimeSqliteProviderCallback>) {
    let replaced = {
        let mut slot = sqlite_provider_callback_registry()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        std::mem::replace(&mut *slot, callback)
    };
    // Drop the old callback only after the lock is released, so its destructor
    // cannot poison or deadlock the registry.
    drop(replaced);
}
/// Registers (or, with `None`, clears) the process-wide standard LanceDB
/// provider callback.
///
/// The registry is read by `RuntimeDatabaseProviderCallbacks::capture_process_defaults`;
/// snapshots taken earlier keep the callback they already captured.
///
/// A poisoned registry lock is recovered instead of panicking. The slot is a
/// single `Option` that is always overwritten as a whole, so there is no
/// half-written state to protect. The poison flag itself is left set.
pub fn set_lancedb_provider_callback(callback: Option<RuntimeLanceDbProviderCallback>) {
    let replaced = {
        let mut slot = lancedb_provider_callback_registry()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        std::mem::replace(&mut *slot, callback)
    };
    // Drop the old callback only after the lock is released, so its destructor
    // cannot poison or deadlock the registry.
    drop(replaced);
}
/// Registers (or, with `None`, clears) the process-wide JSON-mode SQLite
/// provider callback.
///
/// The registry is read by `RuntimeDatabaseProviderCallbacks::capture_process_defaults`;
/// snapshots taken earlier keep the callback they already captured.
///
/// A poisoned registry lock is recovered instead of panicking. The slot is a
/// single `Option` that is always overwritten as a whole, so there is no
/// half-written state to protect. The poison flag itself is left set.
pub fn set_sqlite_provider_json_callback(callback: Option<RuntimeSqliteProviderJsonCallback>) {
    let replaced = {
        let mut slot = sqlite_provider_json_callback_registry()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        std::mem::replace(&mut *slot, callback)
    };
    // Drop the old callback only after the lock is released, so its destructor
    // cannot poison or deadlock the registry.
    drop(replaced);
}
/// Registers (or, with `None`, clears) the process-wide JSON-mode LanceDB
/// provider callback.
///
/// The registry is read by `RuntimeDatabaseProviderCallbacks::capture_process_defaults`;
/// snapshots taken earlier keep the callback they already captured.
///
/// A poisoned registry lock is recovered instead of panicking. The slot is a
/// single `Option` that is always overwritten as a whole, so there is no
/// half-written state to protect. The poison flag itself is left set.
pub fn set_lancedb_provider_json_callback(callback: Option<RuntimeLanceDbProviderJsonCallback>) {
    let replaced = {
        let mut slot = lancedb_provider_json_callback_registry()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        std::mem::replace(&mut *slot, callback)
    };
    // Drop the old callback only after the lock is released, so its destructor
    // cannot poison or deadlock the registry.
    drop(replaced);
}
/// Returns a clone of whatever is currently stored in `registry`.
///
/// Despite the name, nothing is taken out: the registry keeps its value. Any
/// borrowed registry is accepted; the process-wide `'static` registries are
/// just one case.
///
/// # Errors
/// Returns an error if the registry mutex is poisoned.
fn take_optional_callback<T: Clone>(registry: &Mutex<Option<T>>) -> Result<Option<T>, String> {
    let guard = registry
        .lock()
        .map_err(|_| "Database provider callback registry lock poisoned".to_string())?;
    Ok(guard.clone())
}
/// Process-wide slot for the standard SQLite provider callback.
///
/// Lazily created on first access and initially empty.
fn sqlite_provider_callback_registry() -> &'static Mutex<Option<RuntimeSqliteProviderCallback>> {
    static SLOT: OnceLock<Mutex<Option<RuntimeSqliteProviderCallback>>> = OnceLock::new();
    SLOT.get_or_init(Mutex::default)
}
/// Process-wide slot for the standard LanceDB provider callback.
///
/// Lazily created on first access and initially empty.
fn lancedb_provider_callback_registry() -> &'static Mutex<Option<RuntimeLanceDbProviderCallback>> {
    static SLOT: OnceLock<Mutex<Option<RuntimeLanceDbProviderCallback>>> = OnceLock::new();
    SLOT.get_or_init(Mutex::default)
}
/// Process-wide slot for the JSON-mode SQLite provider callback.
///
/// Lazily created on first access and initially empty.
fn sqlite_provider_json_callback_registry(
) -> &'static Mutex<Option<RuntimeSqliteProviderJsonCallback>> {
    static SLOT: OnceLock<Mutex<Option<RuntimeSqliteProviderJsonCallback>>> = OnceLock::new();
    SLOT.get_or_init(Mutex::default)
}
/// Process-wide slot for the JSON-mode LanceDB provider callback.
///
/// Lazily created on first access and initially empty.
fn lancedb_provider_json_callback_registry(
) -> &'static Mutex<Option<RuntimeLanceDbProviderJsonCallback>> {
    static SLOT: OnceLock<Mutex<Option<RuntimeLanceDbProviderJsonCallback>>> = OnceLock::new();
    SLOT.get_or_init(Mutex::default)
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Mutex, OnceLock};

    /// Serializes tests that mutate the process-wide callback registries.
    fn database_callback_test_lock() -> &'static Mutex<()> {
        static SERIAL: OnceLock<Mutex<()>> = OnceLock::new();
        SERIAL.get_or_init(Mutex::default)
    }

    /// Captures the registered callbacks on creation and re-installs them on
    /// drop, so a test leaves the process-wide registries as it found them.
    struct ProcessCallbackRestoreGuard {
        snapshot: RuntimeDatabaseProviderCallbacks,
    }

    impl ProcessCallbackRestoreGuard {
        fn capture() -> Self {
            let snapshot = RuntimeDatabaseProviderCallbacks::capture_process_defaults()
                .expect("capture callback snapshot");
            Self { snapshot }
        }
    }

    impl Drop for ProcessCallbackRestoreGuard {
        fn drop(&mut self) {
            let RuntimeDatabaseProviderCallbacks {
                sqlite_standard,
                lancedb_standard,
                sqlite_json,
                lancedb_json,
            } = &self.snapshot;
            set_sqlite_provider_callback(sqlite_standard.clone());
            set_lancedb_provider_callback(lancedb_standard.clone());
            set_sqlite_provider_json_callback(sqlite_json.clone());
            set_lancedb_provider_json_callback(lancedb_json.clone());
        }
    }

    fn sample_binding_context(database_kind: RuntimeDatabaseKind) -> RuntimeDatabaseBindingContext {
        RuntimeDatabaseBindingContext::new(
            "ROOT",
            "test-skill",
            "ROOT",
            "D:/runtime-test-root/__database",
            "D:/runtime-test-root/skills/test-skill",
            "test-skill",
            database_kind,
            "D:/runtime-test-root/__database/default.db",
        )
    }

    /// Installs one callback per (engine, mode) slot. Each callback reports a
    /// `source` ending in `tag`, so a dispatch can be traced back to the
    /// install that produced it.
    fn install_tagged_callbacks(tag: &'static str) {
        set_sqlite_provider_callback(Some(Arc::new(move |_| {
            Ok(json!({ "source": format!("sqlite-standard-{}", tag) }))
        })));
        set_sqlite_provider_json_callback(Some(Arc::new(move |_| {
            Ok(format!("{{\"source\":\"sqlite-json-{}\"}}", tag))
        })));
        set_lancedb_provider_callback(Some(Arc::new(move |_| {
            Ok(RuntimeLanceDbProviderResult::json(
                json!({ "source": format!("lancedb-standard-{}", tag) }),
            ))
        })));
        set_lancedb_provider_json_callback(Some(Arc::new(move |_| {
            Ok(format!("{{\"meta\":{{\"source\":\"lancedb-json-{}\"}}}}", tag))
        })));
    }

    #[test]
    fn captured_callback_snapshots_stay_engine_scoped() {
        let _serial_guard = database_callback_test_lock()
            .lock()
            .expect("lock callback test guard");
        let _restore_guard = ProcessCallbackRestoreGuard::capture();

        install_tagged_callbacks("a");
        let snapshot_a = RuntimeDatabaseProviderCallbacks::capture_process_defaults()
            .expect("capture callback snapshot A");
        install_tagged_callbacks("b");
        let snapshot_b = RuntimeDatabaseProviderCallbacks::capture_process_defaults()
            .expect("capture callback snapshot B");

        let sqlite_request = RuntimeSqliteProviderRequest {
            action: RuntimeSqliteProviderAction::QueryJson,
            binding: sample_binding_context(RuntimeDatabaseKind::Sqlite),
            input: json!({ "sql": "select 1" }),
        };
        let lancedb_request = RuntimeLanceDbProviderRequest {
            action: RuntimeLanceDbProviderAction::VectorSearch,
            binding: sample_binding_context(RuntimeDatabaseKind::LanceDb),
            input: json!({ "table": "demo" }),
        };

        // (snapshot, tag embedded in `source`, label used in failure messages)
        let snapshots = [(&snapshot_a, "a", "A"), (&snapshot_b, "b", "B")];
        let modes = [
            (LuaRuntimeDatabaseCallbackMode::Standard, "standard"),
            (LuaRuntimeDatabaseCallbackMode::Json, "json"),
        ];

        for (snapshot, tag, label) in snapshots {
            for (mode, mode_name) in modes {
                let response = snapshot
                    .dispatch_sqlite_provider_request(&sqlite_request, mode)
                    .unwrap_or_else(|error| {
                        panic!("dispatch sqlite {} {}: {:?}", mode_name, label, error)
                    });
                assert_eq!(
                    response,
                    json!({ "source": format!("sqlite-{}-{}", mode_name, tag) })
                );
            }
        }

        for (snapshot, tag, label) in snapshots {
            for (mode, mode_name) in modes {
                let meta = snapshot
                    .dispatch_lancedb_provider_request(&lancedb_request, mode)
                    .unwrap_or_else(|error| {
                        panic!("dispatch lancedb {} {}: {:?}", mode_name, label, error)
                    })
                    .meta;
                assert_eq!(
                    meta,
                    json!({ "source": format!("lancedb-{}-{}", mode_name, tag) })
                );
            }
        }
    }
}