use std::time::Duration;
use datafusion::scalar::ScalarValue;
use smol_str::SmolStr;
use crate::errors::HookOutcome;
/// Query language / mode tag carried by [`ParseContext::query_type`].
///
/// `Cypher` is the `Default` variant, so contexts built with
/// [`ParseContext::new`] report `Cypher` unless overridden.
/// NOTE(review): the exact meaning of `Locy` and `Execute` is defined by the
/// session layer that constructs the context — confirm there.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum QueryType {
    /// Default variant.
    #[default]
    Cypher,
    /// Locy query.
    Locy,
    /// Execute-mode query.
    Execute,
}
/// Outcome of a commit, exposed to hooks through
/// [`CommitContext::with_commit_result`] / [`CommitContext::commit_result`].
///
/// `PartialEq`/`Eq` are derived so hook implementations and tests can compare
/// results directly; every field type already supports equality.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct PluginCommitResult {
    /// Mutation count reported for the commit.
    pub mutations: u64,
    /// Version number reported for the commit.
    pub version: u64,
    /// WAL log sequence number reported for the commit.
    /// NOTE(review): presumably the LSN of the commit record — confirm with the writer.
    pub wal_lsn: u64,
    /// Time reported for the commit.
    pub duration: Duration,
}
/// Optional callbacks for each named stage of a session's query lifecycle
/// (parse, analyze, plan, execute start/end, before/after commit, abort).
///
/// Every method has a default body, so implementors override only the stages
/// they care about:
/// - methods returning [`HookOutcome`] default to `HookOutcome::Continue`;
/// - the remaining notification-style methods default to doing nothing.
///
/// The `Send + Sync` supertrait bound allows one hook instance to be shared
/// across threads.
/// NOTE(review): when each method is invoked, and how non-`Continue` outcomes
/// are handled, is decided by the caller — confirm in the session driver.
pub trait SessionHook: Send + Sync {
    /// Receives a [`ParseContext`]. Default: `HookOutcome::Continue`.
    fn on_parse(&self, _context: &ParseContext<'_>) -> HookOutcome {
        HookOutcome::Continue
    }

    /// Receives an [`AnalyzeContext`]. Default: `HookOutcome::Continue`.
    fn on_analyze(&self, _context: &AnalyzeContext<'_>) -> HookOutcome {
        HookOutcome::Continue
    }

    /// Receives a [`PlanContext`]. Default: `HookOutcome::Continue`.
    fn on_plan(&self, _context: &PlanContext<'_>) -> HookOutcome {
        HookOutcome::Continue
    }

    /// Receives an [`ExecuteContext`] at execution start. Default: `HookOutcome::Continue`.
    fn on_execute_start(&self, _context: &ExecuteContext<'_>) -> HookOutcome {
        HookOutcome::Continue
    }

    /// Receives an [`ExecuteContext`] plus the [`QueryMetrics`] for the run.
    /// Notification only; the default does nothing.
    fn on_execute_end(&self, _context: &ExecuteContext<'_>, _query_metrics: &QueryMetrics) {}

    /// Receives a [`CommitContext`] before the commit. Default: `HookOutcome::Continue`.
    fn before_commit(&self, _context: &CommitContext<'_>) -> HookOutcome {
        HookOutcome::Continue
    }

    /// Receives a [`CommitContext`] after the commit. Notification only; the default does nothing.
    fn after_commit(&self, _context: &CommitContext<'_>) {}

    /// Receives an [`AbortContext`]. Notification only; the default does nothing.
    fn on_abort(&self, _context: &AbortContext<'_>) {}
}
#[derive(Debug)]
#[non_exhaustive]
pub struct ParseContext<'a> {
pub source: &'a str,
pub session_id: &'a str,
pub query_type: QueryType,
pub params: &'a [(SmolStr, ScalarValue)],
}
impl<'a> ParseContext<'a> {
#[must_use]
pub fn new(source: &'a str, session_id: &'a str) -> Self {
Self {
source,
session_id,
query_type: QueryType::default(),
params: &[],
}
}
#[must_use]
pub fn with_query_type(mut self, query_type: QueryType) -> Self {
self.query_type = query_type;
self
}
#[must_use]
pub fn with_params(mut self, params: &'a [(SmolStr, ScalarValue)]) -> Self {
self.params = params;
self
}
}
/// Context handed to [`SessionHook::on_analyze`]; currently carries only the session id.
#[derive(Debug)]
#[non_exhaustive]
pub struct AnalyzeContext<'a> {
    /// Identifier of the owning session.
    pub session_id: &'a str,
    /// Zero-sized lifetime marker.
    /// NOTE(review): redundant, because `session_id` already uses `'a`. It is kept
    /// because the field is `pub`, so removing it would be a breaking change.
    pub _marker: std::marker::PhantomData<&'a ()>,
}

impl<'a> AnalyzeContext<'a> {
    /// Builds a context for session `session_id`.
    #[must_use]
    pub fn new(session_id: &'a str) -> Self {
        Self {
            _marker: Default::default(),
            session_id,
        }
    }
}
/// Context handed to [`SessionHook::on_plan`]; currently carries only the session id.
#[derive(Debug)]
#[non_exhaustive]
pub struct PlanContext<'a> {
    /// Identifier of the owning session.
    pub session_id: &'a str,
    /// Zero-sized lifetime marker.
    /// NOTE(review): redundant, because `session_id` already uses `'a`. It is kept
    /// because the field is `pub`, so removing it would be a breaking change.
    pub _marker: std::marker::PhantomData<&'a ()>,
}

impl<'a> PlanContext<'a> {
    /// Builds a context for session `session_id`.
    #[must_use]
    pub fn new(session_id: &'a str) -> Self {
        Self {
            _marker: Default::default(),
            session_id,
        }
    }
}
/// Context handed to [`SessionHook::on_execute_start`] and
/// [`SessionHook::on_execute_end`]; currently carries only the session id.
#[derive(Debug)]
#[non_exhaustive]
pub struct ExecuteContext<'a> {
    /// Identifier of the owning session.
    pub session_id: &'a str,
    /// Zero-sized lifetime marker.
    /// NOTE(review): redundant, because `session_id` already uses `'a`. It is kept
    /// because the field is `pub`, so removing it would be a breaking change.
    pub _marker: std::marker::PhantomData<&'a ()>,
}

impl<'a> ExecuteContext<'a> {
    /// Builds a context for session `session_id`.
    #[must_use]
    pub fn new(session_id: &'a str) -> Self {
        Self {
            _marker: Default::default(),
            session_id,
        }
    }
}
/// Context handed to [`SessionHook::before_commit`] and [`SessionHook::after_commit`].
///
/// `commit_result` is `None` unless it was attached with
/// [`CommitContext::with_commit_result`].
#[derive(Debug)]
#[non_exhaustive]
pub struct CommitContext<'a> {
    /// Identifier of the owning session.
    pub session_id: &'a str,
    /// Commit outcome, if the caller attached one.
    pub commit_result: Option<&'a PluginCommitResult>,
}

impl<'a> CommitContext<'a> {
    /// Builds a context for session `session_id` with no commit result attached.
    #[must_use]
    pub fn new(session_id: &'a str) -> Self {
        Self {
            commit_result: None,
            session_id,
        }
    }

    /// Returns this context with `result` attached as `Some(result)`.
    #[must_use]
    pub fn with_commit_result(self, result: &'a PluginCommitResult) -> Self {
        Self {
            commit_result: Some(result),
            ..self
        }
    }
}
/// Context handed to [`SessionHook::on_abort`].
#[derive(Debug)]
#[non_exhaustive]
pub struct AbortContext<'a> {
    /// Identifier of the owning session.
    pub session_id: &'a str,
    /// Reason text supplied by the caller that triggered the abort.
    pub reason: &'a str,
}

impl<'a> AbortContext<'a> {
    /// Builds a context for session `session_id` carrying `reason`.
    #[must_use]
    pub fn new(session_id: &'a str, reason: &'a str) -> Self {
        Self { reason, session_id }
    }
}
/// Execution metrics passed to [`SessionHook::on_execute_end`].
///
/// The derive list follows the file's `Debug, Clone, ...` order. It also adds
/// `PartialEq`/`Eq`, so hook implementations and tests can compare metric
/// snapshots; every field type already supports equality.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct QueryMetrics {
    /// Elapsed time reported for the query.
    /// NOTE(review): the measurement start and end points are chosen by the caller — confirm.
    pub elapsed: Duration,
    /// Output row count reported for the query.
    pub rows_out: u64,
    /// Byte count reported as read by the query.
    pub bytes_read: u64,
}