use std::collections::HashMap;
use uni_common::{Result, Value};
use uni_query::QueryMetrics;
use crate::commit_result::CommitResult;
/// Kind of query a legacy [`SessionHook`] is being told about.
///
/// Mirrors the plugin-side query type one-to-one; see
/// `plugin_query_type_to_legacy` for the mapping.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum QueryType {
    /// A Cypher query.
    Cypher,
    /// A Locy query.
    Locy,
    /// An `Execute` request.
    Execute,
}
/// Per-query information handed to [`SessionHook::before_query`] and
/// [`SessionHook::after_query`].
#[derive(Clone, Debug)]
pub struct HookContext {
    /// Identifier of the session that issued the query.
    pub session_id: String,
    /// Source text of the query (may be empty when the caller has no text to supply).
    pub query_text: String,
    /// Which query language / entry point the query came through.
    pub query_type: QueryType,
    /// Query parameters keyed by parameter name.
    pub params: HashMap<String, Value>,
}
/// Per-transaction information handed to [`SessionHook::before_commit`] and
/// [`SessionHook::after_commit`].
#[derive(Clone, Debug)]
pub struct CommitHookContext {
    /// Identifier of the session performing the commit.
    pub session_id: String,
    /// Transaction identifier (may be empty when the caller has none to supply).
    pub tx_id: String,
    /// Number of mutations reported for the transaction.
    pub mutation_count: usize,
}
/// Legacy per-session callback interface.
///
/// Every method has a default no-op body, so implementors override only the
/// events they care about. When wrapped by `LegacyHookAdapter`, an `Err` from
/// a `before_*` callback is translated into `HookOutcome::Reject` carrying the
/// error's string form. The `after_*` callbacks are notifications and return
/// nothing.
pub trait SessionHook: Send + Sync {
    /// Called before a query runs; the default accepts every query.
    fn before_query(&self, _ctx: &HookContext) -> Result<()> {
        Ok(())
    }

    /// Called after a query has executed, with its metrics; the default does nothing.
    fn after_query(&self, _ctx: &HookContext, _metrics: &QueryMetrics) {}

    /// Called before a transaction commits; the default accepts every commit.
    fn before_commit(&self, _ctx: &CommitHookContext) -> Result<()> {
        Ok(())
    }

    /// Called after a transaction commits, with its result; the default does nothing.
    fn after_commit(&self, _ctx: &CommitHookContext, _result: &CommitResult) {}
}
use std::sync::Arc;
use datafusion::scalar::ScalarValue;
use uni_plugin::errors::HookOutcome;
use uni_plugin::traits::hook::{
AbortContext, AnalyzeContext, CommitContext as PluginCommitContext, ExecuteContext,
ParseContext, PlanContext, QueryMetrics as PluginQueryMetrics, QueryType as PluginQueryType,
SessionHook as PluginSessionHook,
};
/// Bridges a legacy [`SessionHook`] onto the phased plugin hook trait.
///
/// Each phased event is translated into the matching legacy callback.
pub struct LegacyHookAdapter {
    /// Human-readable label, exposed through `name()` and `Debug`.
    name: String,
    /// The wrapped legacy hook that receives the translated callbacks.
    inner: Arc<dyn SessionHook>,
}
impl LegacyHookAdapter {
    /// Wraps `inner` under the label `name`.
    #[must_use]
    pub fn new(name: impl Into<String>, inner: Arc<dyn SessionHook>) -> Self {
        let name: String = name.into();
        Self { name, inner }
    }

    /// Returns the label supplied to [`LegacyHookAdapter::new`].
    #[must_use]
    pub fn name(&self) -> &str {
        self.name.as_str()
    }
}
/// Prints only the adapter's name.
///
/// The wrapped hook is omitted because `dyn SessionHook` carries no `Debug`
/// bound. The output is marked non-exhaustive to signal the omitted field.
impl std::fmt::Debug for LegacyHookAdapter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("LegacyHookAdapter");
        builder.field("name", &self.name);
        builder.finish_non_exhaustive()
    }
}
impl PluginSessionHook for LegacyHookAdapter {
fn on_parse(&self, ctx: &ParseContext<'_>) -> HookOutcome {
let legacy_ctx = HookContext {
session_id: ctx.session_id.to_owned(),
query_text: ctx.source.to_owned(),
query_type: plugin_query_type_to_legacy(ctx.query_type),
params: params_to_legacy(ctx.params),
};
match self.inner.before_query(&legacy_ctx) {
Ok(()) => HookOutcome::Continue,
Err(e) => HookOutcome::Reject {
reason: e.to_string(),
},
}
}
fn on_execute_end(&self, ctx: &ExecuteContext<'_>, metrics: &PluginQueryMetrics) {
let legacy_ctx = HookContext {
session_id: ctx.session_id.to_owned(),
query_text: String::new(),
query_type: QueryType::Cypher,
params: HashMap::new(),
};
let legacy_metrics = uni_query::QueryMetrics {
total_time: metrics.elapsed,
rows_returned: metrics.rows_out as usize,
bytes_read: metrics.bytes_read as usize,
..Default::default()
};
self.inner.after_query(&legacy_ctx, &legacy_metrics);
}
fn before_commit(&self, ctx: &PluginCommitContext<'_>) -> HookOutcome {
let legacy_ctx = CommitHookContext {
session_id: ctx.session_id.to_owned(),
tx_id: String::new(),
mutation_count: 0,
};
match self.inner.before_commit(&legacy_ctx) {
Ok(()) => HookOutcome::Continue,
Err(e) => HookOutcome::Reject {
reason: e.to_string(),
},
}
}
fn after_commit(&self, ctx: &PluginCommitContext<'_>) {
let legacy_ctx = CommitHookContext {
session_id: ctx.session_id.to_owned(),
tx_id: String::new(),
mutation_count: ctx.commit_result.map(|r| r.mutations as usize).unwrap_or(0),
};
let result = ctx
.commit_result
.map(|r| CommitResult {
mutations_committed: r.mutations as usize,
version: r.version,
wal_lsn: r.wal_lsn,
duration: r.duration,
..CommitResult::default()
})
.unwrap_or_default();
self.inner.after_commit(&legacy_ctx, &result);
}
fn on_analyze(&self, _ctx: &AnalyzeContext<'_>) -> HookOutcome {
HookOutcome::Continue
}
fn on_plan(&self, _ctx: &PlanContext<'_>) -> HookOutcome {
HookOutcome::Continue
}
fn on_execute_start(&self, _ctx: &ExecuteContext<'_>) -> HookOutcome {
HookOutcome::Continue
}
fn on_abort(&self, _ctx: &AbortContext<'_>) {}
}
/// Maps the plugin-side query type onto the legacy [`QueryType`] with the same
/// name. The match is exhaustive, so a new plugin variant is a compile error here.
fn plugin_query_type_to_legacy(t: PluginQueryType) -> QueryType {
    let legacy = match t {
        PluginQueryType::Execute => QueryType::Execute,
        PluginQueryType::Locy => QueryType::Locy,
        PluginQueryType::Cypher => QueryType::Cypher,
    };
    legacy
}
/// Converts `(name, ScalarValue)` parameter pairs into the legacy map form.
///
/// Each value goes through `scalar_to_value`. If a name appears more than once,
/// the last occurrence wins.
fn params_to_legacy<S: AsRef<str>>(
    params: &[(S, ScalarValue)],
) -> HashMap<String, uni_common::Value> {
    let mut converted = HashMap::with_capacity(params.len());
    for (name, scalar) in params {
        let key = name.as_ref().to_owned();
        converted.insert(key, scalar_to_value(scalar));
    }
    converted
}
use std::sync::atomic::{AtomicU64, Ordering};
use uni_plugin::{
AbiRange, Capability, CapabilitySet, Determinism, Plugin, PluginError, PluginId,
PluginManifest, PluginRegistrar, ProvidedSurfaces, Scope, SideEffects as PluginSideEffects,
};
/// Plugin wrapper that registers a legacy [`SessionHook`] with the plugin
/// registry, using a [`LegacyHookAdapter`] as the registered hook.
pub struct BuiltinHookPlugin {
    /// Static description of this plugin (id, ABI range, capabilities, ...).
    manifest: PluginManifest,
    /// The adapter handed to the registrar in `register`.
    adapter: Arc<LegacyHookAdapter>,
}
/// Process-wide counter that gives each [`BuiltinHookPlugin`] a distinct id suffix.
static BUILTIN_HOOK_PLUGIN_SEQ: AtomicU64 = AtomicU64::new(0);
impl BuiltinHookPlugin {
    /// Wraps `hook` in a [`LegacyHookAdapter`] labelled `name`.
    ///
    /// The plugin id is `builtin.hook.<n>`, where `n` comes from a process-wide
    /// counter. It does not include `name`.
    ///
    /// # Panics
    ///
    /// The `expect` calls parse the string literals `"1.0.0"` and `"^1"`. They
    /// panic only if those literals fail to parse.
    #[must_use]
    pub fn new(name: impl Into<String>, hook: Arc<dyn SessionHook>) -> Self {
        let label: String = name.into();
        // Relaxed is enough: fetch_add is an atomic read-modify-write, so every
        // caller still gets a distinct value. No other memory is published
        // through this counter.
        let seq = BUILTIN_HOOK_PLUGIN_SEQ.fetch_add(1, Ordering::Relaxed);
        let adapter = Arc::new(LegacyHookAdapter::new(label, hook));
        Self {
            manifest: PluginManifest {
                id: PluginId::new(format!("builtin.hook.{seq}")),
                version: "1.0.0".parse().expect("static version parses"),
                abi: AbiRange::parse("^1").expect("static ABI range parses"),
                depends_on: Vec::new(),
                capabilities: CapabilitySet::from_iter_of([Capability::Hook]),
                determinism: Determinism::Nondeterministic,
                side_effects: PluginSideEffects::ReadOnly,
                scope: Scope::Instance,
                hash: None,
                signature: None,
                provides: ProvidedSurfaces::default(),
                docs: "BuiltinHookPlugin — legacy SessionHook adapter".to_owned(),
                metadata: std::collections::BTreeMap::new(),
            },
            adapter,
        }
    }
}
impl Plugin for BuiltinHookPlugin {
    /// Returns the manifest built in [`BuiltinHookPlugin::new`].
    fn manifest(&self) -> &PluginManifest {
        &self.manifest
    }

    /// Registers the wrapped adapter as a phased session hook.
    ///
    /// Any error from the registrar is propagated to the caller.
    fn register(&self, r: &mut PluginRegistrar<'_>) -> std::result::Result<(), PluginError> {
        let hook = Arc::clone(&self.adapter) as Arc<dyn PluginSessionHook>;
        r.hook(hook)?;
        Ok(())
    }
}
/// Converts a DataFusion [`ScalarValue`] into the legacy [`uni_common::Value`]
/// used in [`HookContext::params`].
///
/// Conversions:
/// - any null, whether untyped `ScalarValue::Null` or a typed null such as
///   `Int64(None)` or `Utf8(None)` → `Value::Null`
/// - `Boolean` → `Value::Bool`
/// - signed and unsigned integers → `Value::Int`. A `UInt64` above `i64::MAX`
///   saturates to `i64::MAX`, and a warning is logged.
/// - `Float32` / `Float64` → `Value::Float`
/// - `Utf8` / `LargeUtf8` / `Utf8View` → `Value::String`
/// - `Binary` / `LargeBinary` / `BinaryView` → `Value::Bytes`
///
/// Any other variant is logged at `warn` level and surfaced as `Value::Null`.
fn scalar_to_value(v: &ScalarValue) -> uni_common::Value {
    use uni_common::Value;
    match v {
        // Typed nulls are ordinary null values, not unsupported variants. Map
        // them directly so they do not trigger the "unsupported" warning below.
        s if s.is_null() => Value::Null,
        ScalarValue::Boolean(Some(b)) => Value::Bool(*b),
        ScalarValue::Int8(Some(i)) => Value::Int(i64::from(*i)),
        ScalarValue::Int16(Some(i)) => Value::Int(i64::from(*i)),
        ScalarValue::Int32(Some(i)) => Value::Int(i64::from(*i)),
        ScalarValue::Int64(Some(i)) => Value::Int(*i),
        ScalarValue::UInt8(Some(i)) => Value::Int(i64::from(*i)),
        ScalarValue::UInt16(Some(i)) => Value::Int(i64::from(*i)),
        ScalarValue::UInt32(Some(i)) => Value::Int(i64::from(*i)),
        ScalarValue::UInt64(Some(i)) => Value::Int(i64::try_from(*i).unwrap_or_else(|_| {
            // Saturation is kept for compatibility, but the loss is now visible.
            tracing::warn!(
                "LegacyHookAdapter::params_to_legacy: UInt64 value {i} exceeds \
                 i64::MAX; saturating to i64::MAX."
            );
            i64::MAX
        })),
        ScalarValue::Float32(Some(f)) => Value::Float(f64::from(*f)),
        ScalarValue::Float64(Some(f)) => Value::Float(*f),
        ScalarValue::Utf8(Some(s))
        | ScalarValue::LargeUtf8(Some(s))
        | ScalarValue::Utf8View(Some(s)) => Value::String(s.clone()),
        ScalarValue::Binary(Some(b))
        | ScalarValue::LargeBinary(Some(b))
        | ScalarValue::BinaryView(Some(b)) => Value::Bytes(b.clone()),
        other => {
            tracing::warn!(
                "LegacyHookAdapter::params_to_legacy: unsupported ScalarValue \
                 variant {other:?}; surfacing as Value::Null. Hooks needing \
                 typed access should register against the phased trait."
            );
            Value::Null
        }
    }
}