use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::time::{Duration, Instant};
use tokio_util::sync::CancellationToken;
use tracing::instrument;
use uuid::Uuid;
use crate::api::UniInner;
use crate::api::hooks::{HookContext, QueryType, SessionHook};
use crate::api::impl_locy::LocyRuleRegistry;
use crate::api::locy_result::LocyResult;
use crate::api::transaction::{IsolationLevel, Transaction};
use uni_common::{Result, UniError, Value};
use uni_query::{ExplainOutput, ProfileOutput, QueryCursor, QueryResult, Row};
/// Hit/miss counters for the per-session plan cache. Shared across clones
/// of a `Session` (the clone impl reuses the same `Arc`).
pub(crate) struct PlanCacheMetrics {
    /// Queries served from a cached logical plan.
    pub(crate) hits: AtomicU64,
    /// Queries that had to be parsed and planned from scratch.
    pub(crate) misses: AtomicU64,
}
/// Point-in-time summary of what a session is currently allowed to do.
#[derive(Debug, Clone)]
pub struct SessionCapabilities {
    /// Whether writes are possible: a writer is configured and the session
    /// is not pinned to a historical snapshot.
    pub can_write: bool,
    /// Whether the session may pin itself to a snapshot/timestamp.
    pub can_pin: bool,
    /// Default isolation level transactions will use.
    pub isolation: IsolationLevel,
    /// Whether commit notifications (`watch`) are available.
    pub has_notifications: bool,
    /// Redacted description of the configured write lease, if any.
    pub write_lease: Option<WriteLeaseSummary>,
}
/// Redacted view of the database's write-lease configuration, exposed via
/// `SessionCapabilities` without leaking the lease implementation.
#[derive(Debug, Clone)]
pub enum WriteLeaseSummary {
    /// In-process lease.
    Local,
    /// Lease coordinated through a DynamoDB table.
    DynamoDB { table: String },
    /// User-supplied lease implementation; details intentionally opaque.
    Custom,
}
/// Lock-free counters backing `SessionMetrics`; updated with relaxed
/// atomics on the query/transaction hot paths.
pub(crate) struct SessionMetricsInner {
    pub(crate) queries_executed: AtomicU64,
    pub(crate) locy_evaluations: AtomicU64,
    /// Cumulative wall-clock query time, in microseconds.
    pub(crate) total_query_time_us: AtomicU64,
    pub(crate) transactions_committed: AtomicU64,
    pub(crate) transactions_rolled_back: AtomicU64,
    pub(crate) total_rows_returned: AtomicU64,
    pub(crate) total_rows_scanned: AtomicU64,
}
impl SessionMetricsInner {
    /// Creates a metrics holder with every counter zeroed.
    fn new() -> Self {
        Self {
            queries_executed: AtomicU64::default(),
            locy_evaluations: AtomicU64::default(),
            total_query_time_us: AtomicU64::default(),
            transactions_committed: AtomicU64::default(),
            transactions_rolled_back: AtomicU64::default(),
            total_rows_returned: AtomicU64::default(),
            total_rows_scanned: AtomicU64::default(),
        }
    }
}
/// Owned snapshot of a session's counters, produced by `Session::metrics`.
#[derive(Debug, Clone)]
pub struct SessionMetrics {
    pub session_id: String,
    /// When the session was created.
    pub active_since: Instant,
    pub queries_executed: u64,
    pub locy_evaluations: u64,
    /// Total wall-clock time spent in `query` calls.
    pub total_query_time: Duration,
    pub transactions_committed: u64,
    pub transactions_rolled_back: u64,
    pub total_rows_returned: u64,
    pub total_rows_scanned: u64,
    /// Plan-cache counters are shared with clones of this session.
    pub plan_cache_hits: u64,
    pub plan_cache_misses: u64,
    pub plan_cache_size: usize,
}
/// A database session: per-caller parameters, Locy rule registry, hooks,
/// plan cache and metrics layered over the shared `UniInner` handle.
pub struct Session {
    /// Database handle queries run against; swapped to a snapshot handle
    /// while the session is pinned (see `pin_to_version`).
    pub(crate) db: Arc<UniInner>,
    /// The live handle, saved only while pinned so `refresh` can restore
    /// it. `Some` <=> the session is pinned.
    original_db: Option<Arc<UniInner>>,
    /// Random UUID string identifying the session in logs and metrics.
    id: String,
    /// Session-scoped params, merged into each query under a "session" key.
    params: Arc<std::sync::RwLock<HashMap<String, Value>>>,
    /// Session-local copy of the Locy rule registry.
    rule_registry: Arc<std::sync::RwLock<LocyRuleRegistry>>,
    // NOTE(review): presumably toggled by Transaction while a write is
    // open to prevent overlap — confirm in the Transaction implementation.
    active_write_guard: Arc<AtomicBool>,
    pub(crate) metrics_inner: Arc<SessionMetricsInner>,
    created_at: Instant,
    /// Current cancellation token; `cancel` fires it and installs a fresh one.
    cancellation_token: Arc<std::sync::RwLock<CancellationToken>>,
    /// Cached logical plans, shared with clones of this session.
    plan_cache: Arc<std::sync::Mutex<PlanCache>>,
    plan_cache_metrics: Arc<PlanCacheMetrics>,
    /// Named before/after-query hooks.
    pub(crate) hooks: HashMap<String, Arc<dyn SessionHook>>,
    /// Default per-query timeout applied by `query_with`, if set.
    pub(crate) query_timeout: Option<Duration>,
    /// Default transaction timeout applied by `tx_with`, if set.
    pub(crate) transaction_timeout: Option<Duration>,
}
impl Session {
    /// Creates a session on `db`, seeding the session-local rule registry
    /// with a clone of the global one and bumping the active-session count.
    pub(crate) fn new(db: Arc<UniInner>) -> Self {
        // Clone the global registry under the read lock, then release the
        // lock before constructing the session.
        let global_registry = db.locy_rule_registry.read().unwrap();
        let session_registry = global_registry.clone();
        drop(global_registry);
        db.active_session_count.fetch_add(1, Ordering::Relaxed);
        Self {
            db,
            original_db: None,
            id: Uuid::new_v4().to_string(),
            params: Arc::new(std::sync::RwLock::new(HashMap::new())),
            rule_registry: Arc::new(std::sync::RwLock::new(session_registry)),
            active_write_guard: Arc::new(AtomicBool::new(false)),
            metrics_inner: Arc::new(SessionMetricsInner::new()),
            created_at: Instant::now(),
            cancellation_token: Arc::new(std::sync::RwLock::new(CancellationToken::new())),
            // Up to 1000 cached plans per session (shared with clones).
            plan_cache: Arc::new(std::sync::Mutex::new(PlanCache::new(1000))),
            plan_cache_metrics: Arc::new(PlanCacheMetrics {
                hits: AtomicU64::new(0),
                misses: AtomicU64::new(0),
            }),
            hooks: HashMap::new(),
            query_timeout: None,
            transaction_timeout: None,
        }
    }
    /// Builds a session from pre-assembled state (used by session
    /// templates): params, rule registry, hooks and timeouts are taken
    /// as-is instead of being derived from the database.
    pub(crate) fn new_from_template(
        db: Arc<UniInner>,
        params: HashMap<String, Value>,
        rule_registry: LocyRuleRegistry,
        hooks: HashMap<String, Arc<dyn SessionHook>>,
        query_timeout: Option<Duration>,
        transaction_timeout: Option<Duration>,
    ) -> Self {
        db.active_session_count.fetch_add(1, Ordering::Relaxed);
        Self {
            db,
            original_db: None,
            id: Uuid::new_v4().to_string(),
            params: Arc::new(std::sync::RwLock::new(params)),
            rule_registry: Arc::new(std::sync::RwLock::new(rule_registry)),
            active_write_guard: Arc::new(AtomicBool::new(false)),
            metrics_inner: Arc::new(SessionMetricsInner::new()),
            created_at: Instant::now(),
            cancellation_token: Arc::new(std::sync::RwLock::new(CancellationToken::new())),
            // Same per-session plan-cache bound as `Session::new`.
            plan_cache: Arc::new(std::sync::Mutex::new(PlanCache::new(1000))),
            plan_cache_metrics: Arc::new(PlanCacheMetrics {
                hits: AtomicU64::new(0),
                misses: AtomicU64::new(0),
            }),
            hooks,
            query_timeout,
            transaction_timeout,
        }
    }
    /// Returns a lightweight handle for reading/writing session parameters.
    pub fn params(&self) -> Params<'_> {
        Params {
            store: &self.params,
        }
    }
    /// Runs a read-only Cypher query with session params merged in,
    /// recording metrics and firing before/after-query hooks.
    /// Mutation clauses are rejected inside `execute_cached`.
    #[instrument(skip(self), fields(session_id = %self.id))]
    pub async fn query(&self, cypher: &str) -> Result<QueryResult> {
        let params = self.merge_params(HashMap::new());
        // A hook returning Err vetoes the query before any work happens.
        self.run_before_query_hooks(cypher, QueryType::Cypher, &params)?;
        let start = Instant::now();
        let result = self.execute_cached(cypher, params.clone()).await;
        // Counters are bumped whether or not execution succeeded.
        self.metrics_inner
            .queries_executed
            .fetch_add(1, Ordering::Relaxed);
        self.db.total_queries.fetch_add(1, Ordering::Relaxed);
        self.metrics_inner
            .total_query_time_us
            .fetch_add(start.elapsed().as_micros() as u64, Ordering::Relaxed);
        if let Ok(ref qr) = result {
            // Row accounting and after-hooks run only on success.
            self.metrics_inner
                .total_rows_returned
                .fetch_add(qr.len() as u64, Ordering::Relaxed);
            self.run_after_query_hooks(cypher, QueryType::Cypher, &params, qr.metrics());
        }
        result
    }
    /// Starts a builder for a read-only query, pre-seeded with the
    /// session's default query timeout.
    pub fn query_with(&self, cypher: &str) -> QueryBuilder<'_> {
        QueryBuilder {
            session: self,
            cypher: cypher.to_string(),
            params: HashMap::new(),
            timeout: self.query_timeout,
            max_memory: None,
            cancellation_token: None,
        }
    }
    /// Evaluates a Locy program with default options. Before-query hooks
    /// may veto the run; the evaluation counter is bumped whether or not
    /// the run succeeded.
    #[instrument(skip(self), fields(session_id = %self.id))]
    pub async fn locy(&self, program: &str) -> Result<LocyResult> {
        // Hooks see an empty param map: plain Locy runs take no params.
        self.run_before_query_hooks(program, QueryType::Locy, &HashMap::new())?;
        let result = self.locy_with(program).run().await;
        self.metrics_inner
            .locy_evaluations
            .fetch_add(1, Ordering::Relaxed);
        result
    }
    /// Starts a builder for a Locy evaluation with custom options.
    pub fn locy_with(&self, program: &str) -> crate::api::locy_builder::LocyBuilder<'_> {
        crate::api::locy_builder::LocyBuilder::new(self, program)
    }
    /// Returns a handle over the session-local Locy rule registry.
    pub fn rules(&self) -> super::rule_registry::RuleRegistry<'_> {
        super::rule_registry::RuleRegistry::new(&self.rule_registry)
    }
#[instrument(skip(self), fields(session_id = %self.id))]
pub fn compile_locy(&self, program: &str) -> Result<uni_locy::CompiledProgram> {
let ast = uni_cypher::parse_locy(program).map_err(|e| UniError::Parse {
message: format!("LocyParseError: {e}"),
position: None,
line: None,
column: None,
context: None,
})?;
let registry = self.rule_registry.read().unwrap();
if registry.rules.is_empty() {
drop(registry);
uni_locy::compile(&ast).map_err(|e| UniError::Query {
message: format!("LocyCompileError: {e}"),
query: None,
})
} else {
let external_names: Vec<String> = registry.rules.keys().cloned().collect();
drop(registry);
uni_locy::compile_with_external_rules(&ast, &external_names).map_err(|e| {
UniError::Query {
message: format!("LocyCompileError: {e}"),
query: None,
}
})
}
}
#[instrument(skip(self), fields(session_id = %self.id))]
pub async fn tx(&self) -> Result<Transaction> {
if self.is_pinned() {
return Err(UniError::ReadOnly {
operation: "start_transaction".to_string(),
});
}
Transaction::new(self).await
}
    /// Starts a builder for a transaction, pre-seeded with the session's
    /// default transaction timeout and the default isolation level.
    pub fn tx_with(&self) -> TransactionBuilder<'_> {
        TransactionBuilder {
            session: self,
            timeout: self.transaction_timeout,
            isolation: IsolationLevel::default(),
        }
    }
#[instrument(skip(self), fields(session_id = %self.id))]
pub async fn pin_to_version(&mut self, snapshot_id: &str) -> Result<()> {
let pinned = self.live_db().at_snapshot(snapshot_id).await?;
if self.original_db.is_none() {
self.original_db = Some(self.db.clone());
}
self.db = Arc::new(pinned);
Ok(())
}
    /// Pins the session to the snapshot that was current at `ts`.
    #[instrument(skip(self), fields(session_id = %self.id))]
    pub async fn pin_to_timestamp(&mut self, ts: chrono::DateTime<chrono::Utc>) -> Result<()> {
        // Resolve the timestamp against the live db, then reuse the
        // snapshot-id pinning path.
        let snapshot_id = self.live_db().resolve_time_travel_timestamp(ts).await?;
        self.pin_to_version(&snapshot_id).await
    }
    /// Unpins the session, restoring the live database handle. No-op when
    /// the session is not pinned. (Async for API symmetry; the current
    /// body performs no awaiting.)
    pub async fn refresh(&mut self) -> Result<()> {
        if let Some(original) = self.original_db.take() {
            self.db = original;
        }
        Ok(())
    }
    /// True while the session is pinned to a historical snapshot.
    pub fn is_pinned(&self) -> bool {
        self.original_db.is_some()
    }
    /// The live (unpinned) database handle: the saved original while
    /// pinned, otherwise the current handle.
    fn live_db(&self) -> &Arc<UniInner> {
        self.original_db.as_ref().unwrap_or(&self.db)
    }
    /// Cancels all in-flight work that captured the current token, then
    /// installs a fresh token so later queries are unaffected.
    #[instrument(skip(self), fields(session_id = %self.id))]
    pub fn cancel(&self) {
        let mut token = self.cancellation_token.write().unwrap();
        token.cancel();
        *token = CancellationToken::new();
    }
    /// A clone of the session's current cancellation token.
    pub fn cancellation_token(&self) -> CancellationToken {
        self.cancellation_token.read().unwrap().clone()
    }
    /// Prepares a Cypher query against the session's current database
    /// handle (the pinned snapshot when the session is pinned).
    #[instrument(skip(self), fields(session_id = %self.id))]
    pub async fn prepare(&self, cypher: &str) -> Result<crate::api::prepared::PreparedQuery> {
        crate::api::prepared::PreparedQuery::new(self.db.clone(), cypher).await
    }
    /// Prepares a Locy program, sharing this session's rule registry.
    /// Note: construction itself does no awaiting.
    #[instrument(skip(self), fields(session_id = %self.id))]
    pub async fn prepare_locy(&self, program: &str) -> Result<crate::api::prepared::PreparedLocy> {
        crate::api::prepared::PreparedLocy::new(
            self.db.clone(),
            self.rule_registry.clone(),
            program,
        )
    }
pub fn add_hook(&mut self, name: impl Into<String>, hook: impl SessionHook + 'static) {
self.hooks.insert(name.into(), Arc::new(hook));
}
pub fn remove_hook(&mut self, name: &str) -> bool {
self.hooks.remove(name).is_some()
}
pub fn list_hooks(&self) -> Vec<String> {
self.hooks.keys().cloned().collect()
}
pub fn clear_hooks(&mut self) {
self.hooks.clear();
}
pub(crate) fn run_before_query_hooks(
&self,
query_text: &str,
query_type: QueryType,
params: &HashMap<String, Value>,
) -> Result<()> {
if self.hooks.is_empty() {
return Ok(());
}
let ctx = HookContext {
session_id: self.id.clone(),
query_text: query_text.to_string(),
query_type,
params: params.clone(),
};
for hook in self.hooks.values() {
hook.before_query(&ctx)?;
}
Ok(())
}
    /// Runs after-query hooks, best-effort: a panicking hook is caught and
    /// logged instead of unwinding into the caller.
    pub(crate) fn run_after_query_hooks(
        &self,
        query_text: &str,
        query_type: QueryType,
        params: &HashMap<String, Value>,
        metrics: &uni_query::QueryMetrics,
    ) {
        if self.hooks.is_empty() {
            return;
        }
        let ctx = HookContext {
            session_id: self.id.clone(),
            query_text: query_text.to_string(),
            query_type,
            params: params.clone(),
        };
        for hook in self.hooks.values() {
            // AssertUnwindSafe: the closure only borrows ctx/metrics, and
            // nothing mutated inside it is used again after a panic.
            let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                hook.after_query(&ctx, metrics);
            }));
            if let Err(e) = result {
                tracing::error!("after_query hook panicked: {:?}", e);
            }
        }
    }
    /// Subscribes to commit notifications with default watch settings.
    pub fn watch(&self) -> crate::api::notifications::CommitStream {
        let rx = self.db.commit_tx.subscribe();
        crate::api::notifications::WatchBuilder::new(rx).build()
    }
    /// Returns a builder for a customized commit subscription.
    pub fn watch_with(&self) -> crate::api::notifications::WatchBuilder {
        let rx = self.db.commit_tx.subscribe();
        crate::api::notifications::WatchBuilder::new(rx)
    }
    /// The session's unique id (a random UUID string).
    pub fn id(&self) -> &str {
        &self.id
    }
    /// Reports what this session can currently do, summarizing the write
    /// lease without exposing implementation details.
    pub fn capabilities(&self) -> SessionCapabilities {
        use crate::api::multi_agent::WriteLease;
        let write_lease = self.db.write_lease.as_ref().map(|wl| match wl {
            WriteLease::Local => WriteLeaseSummary::Local,
            WriteLease::DynamoDB { table } => WriteLeaseSummary::DynamoDB {
                table: table.clone(),
            },
            WriteLease::Custom(_) => WriteLeaseSummary::Custom,
        });
        SessionCapabilities {
            // Writable only with a configured writer and while not pinned.
            can_write: self.db.writer.is_some() && !self.is_pinned(),
            can_pin: true,
            isolation: IsolationLevel::default(),
            has_notifications: true,
            write_lease,
        }
    }
    /// Snapshots the session's counters into an owned `SessionMetrics`.
    /// Loads are relaxed, so individual values may be slightly out of sync
    /// with each other.
    pub fn metrics(&self) -> SessionMetrics {
        let m = &self.metrics_inner;
        SessionMetrics {
            session_id: self.id.clone(),
            active_since: self.created_at,
            queries_executed: m.queries_executed.load(Ordering::Relaxed),
            locy_evaluations: m.locy_evaluations.load(Ordering::Relaxed),
            total_query_time: Duration::from_micros(m.total_query_time_us.load(Ordering::Relaxed)),
            transactions_committed: m.transactions_committed.load(Ordering::Relaxed),
            transactions_rolled_back: m.transactions_rolled_back.load(Ordering::Relaxed),
            total_rows_returned: m.total_rows_returned.load(Ordering::Relaxed),
            total_rows_scanned: m.total_rows_scanned.load(Ordering::Relaxed),
            plan_cache_hits: self.plan_cache_metrics.hits.load(Ordering::Relaxed),
            plan_cache_misses: self.plan_cache_metrics.misses.load(Ordering::Relaxed),
            // A poisoned cache mutex is reported as size 0, not a panic.
            plan_cache_size: self.plan_cache.lock().map(|c| c.len()).unwrap_or(0),
        }
    }
    /// Plans and executes a read-only query, caching logical plans keyed
    /// by a 64-bit hash of the query text and invalidated on schema-version
    /// changes.
    ///
    /// NOTE(review): the cache key is only the hash — two distinct query
    /// strings that collide would share a plan. Confirm this risk is
    /// acceptable, or store the text for an equality check on hit.
    pub(crate) async fn execute_cached(
        &self,
        cypher: &str,
        params: HashMap<String, Value>,
    ) -> Result<QueryResult> {
        let schema_version = self.db.schema.schema().schema_version;
        let cache_key = plan_cache_key(cypher);
        // Clone AST+plan out of the cache so the mutex is not held across
        // the await below; `get` rejects entries from older schema versions.
        let cached = self.plan_cache.lock().ok().and_then(|mut cache| {
            cache
                .get(cache_key, schema_version)
                .map(|entry| (entry.ast.clone(), entry.plan.clone()))
        });
        if let Some((_ast, plan)) = cached {
            self.plan_cache_metrics.hits.fetch_add(1, Ordering::Relaxed);
            // Cached plans passed the read-only check before insertion.
            return self
                .db
                .execute_plan_internal(plan, cypher, params, self.db.config.clone(), None)
                .await;
        }
        self.plan_cache_metrics
            .misses
            .fetch_add(1, Ordering::Relaxed);
        let ast = uni_cypher::parse(cypher).map_err(crate::api::impl_query::into_parse_error)?;
        // Session.query() must stay read-only; mutations need a transaction.
        uni_query::validate_read_only(&ast).map_err(|_| UniError::Query {
            message: "Session.query() is read-only. Mutation clauses (CREATE, MERGE, DELETE, SET, \
                REMOVE) require a transaction. Use session.tx() to start one."
                .to_string(),
            query: Some(cypher.to_string()),
        })?;
        // Time-travel queries take their own execution path; they are
        // neither planned here nor cached.
        if matches!(ast, uni_cypher::ast::Query::TimeTravel { .. }) {
            return self
                .db
                .execute_internal_with_config(cypher, params, self.db.config.clone())
                .await;
        }
        let planner = uni_query::QueryPlanner::new(self.db.schema.schema().clone())
            .with_params(params.clone());
        let plan = planner
            .plan(ast.clone())
            .map_err(|e| crate::api::impl_query::into_query_error(e, cypher))?;
        // Caching is best effort: a poisoned mutex just skips the insert.
        if let Ok(mut cache) = self.plan_cache.lock() {
            cache.insert(
                cache_key,
                PlanCacheEntry {
                    ast,
                    plan: plan.clone(),
                    schema_version,
                    hit_count: 0,
                },
            );
        }
        self.db
            .execute_plan_internal(plan, cypher, params, self.db.config.clone(), None)
            .await
    }
    /// The database handle queries currently run against.
    pub(crate) fn db(&self) -> &Arc<UniInner> {
        &self.db
    }
    /// The session-local Locy rule registry.
    pub(crate) fn rule_registry(&self) -> &Arc<std::sync::RwLock<LocyRuleRegistry>> {
        &self.rule_registry
    }
    // NOTE(review): presumably toggled by Transaction while a write is
    // open — confirm against the Transaction implementation.
    pub(crate) fn active_write_guard(&self) -> &Arc<AtomicBool> {
        &self.active_write_guard
    }
pub(crate) fn merge_params(
&self,
mut query_params: HashMap<String, Value>,
) -> HashMap<String, Value> {
let session_params = self.params.read().unwrap();
if !session_params.is_empty() {
let session_map: HashMap<String, Value> = session_params.clone();
if let Some(Value::Map(existing)) = query_params.get_mut("session") {
for (k, v) in session_map {
existing.entry(k).or_insert(v);
}
} else {
query_params.insert("session".to_string(), Value::Map(session_map));
}
}
query_params
}
}
/// Borrow-scoped accessor over a session's shared parameter map.
pub struct Params<'a> {
    store: &'a Arc<std::sync::RwLock<HashMap<String, Value>>>,
}
impl<'a> Params<'a> {
    /// Sets (or overwrites) a single session parameter.
    pub fn set<K: Into<String>, V: Into<Value>>(&self, key: K, value: V) {
        let mut store = self.store.write().unwrap();
        store.insert(key.into(), value.into());
    }
    /// Returns a clone of the value stored under `key`, if any.
    pub fn get(&self, key: &str) -> Option<Value> {
        let store = self.store.read().unwrap();
        store.get(key).cloned()
    }
    /// Removes `key`, returning the previous value if one existed.
    pub fn unset(&self, key: &str) -> Option<Value> {
        let mut store = self.store.write().unwrap();
        store.remove(key)
    }
    /// Snapshots every parameter into an owned map.
    pub fn get_all(&self) -> HashMap<String, Value> {
        let store = self.store.read().unwrap();
        store.clone()
    }
    /// Bulk-inserts parameters under a single write lock.
    pub fn set_all<I, K, V>(&self, params: I)
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<Value>,
    {
        let mut store = self.store.write().unwrap();
        store.extend(params.into_iter().map(|(k, v)| (k.into(), v.into())));
    }
    /// Hands out the shared store itself (same Arc, not a copy).
    pub fn clone_store_arc(&self) -> Arc<std::sync::RwLock<HashMap<String, Value>>> {
        self.store.clone()
    }
}
/// Fluent builder for one read-only query with optional per-query
/// overrides (timeout, memory cap, cancellation token).
pub struct QueryBuilder<'a> {
    session: &'a Session,
    cypher: String,
    params: HashMap<String, Value>,
    /// Seeded from the session's default `query_timeout`.
    timeout: Option<std::time::Duration>,
    max_memory: Option<usize>,
    cancellation_token: Option<CancellationToken>,
}
impl<'a> QueryBuilder<'a> {
    /// Binds one named parameter, overwriting any previous binding.
    pub fn param<K: Into<String>, V: Into<Value>>(mut self, key: K, value: V) -> Self {
        let (k, v) = (key.into(), value.into());
        self.params.insert(k, v);
        self
    }
    /// Binds a batch of parameters from any iterable of pairs.
    pub fn params<'p>(mut self, params: impl IntoIterator<Item = (&'p str, Value)>) -> Self {
        self.params
            .extend(params.into_iter().map(|(k, v)| (k.to_string(), v)));
        self
    }
    /// Overrides the query timeout for this call only.
    pub fn timeout(self, duration: std::time::Duration) -> Self {
        Self {
            timeout: Some(duration),
            ..self
        }
    }
    /// Caps the memory this query may use.
    pub fn max_memory(self, bytes: usize) -> Self {
        Self {
            max_memory: Some(bytes),
            ..self
        }
    }
    /// Ties this query to an external cancellation token.
    pub fn cancellation_token(self, token: CancellationToken) -> Self {
        Self {
            cancellation_token: Some(token),
            ..self
        }
    }
    /// Executes the query and collects the full result set.
    ///
    /// With no per-query overrides the call goes through the session's
    /// plan cache; any override (timeout, memory, token) takes a slower
    /// path that bypasses the cache but still enforces the read-only rule.
    pub async fn fetch_all(self) -> Result<QueryResult> {
        let has_overrides = self.timeout.is_some()
            || self.max_memory.is_some()
            || self.cancellation_token.is_some();
        if has_overrides {
            let ast = uni_cypher::parse(&self.cypher)
                .map_err(crate::api::impl_query::into_parse_error)?;
            uni_query::validate_read_only(&ast).map_err(|_| UniError::Query {
                message: "Session.query() is read-only. Mutation clauses (CREATE, MERGE, DELETE, \
                    SET, REMOVE) require a transaction. Use session.tx() to start one."
                    .to_string(),
                query: Some(self.cypher.clone()),
            })?;
            // Apply overrides to a copy of the db config for this call only.
            let mut db_config = self.session.db.config.clone();
            if let Some(t) = self.timeout {
                db_config.query_timeout = t;
            }
            if let Some(m) = self.max_memory {
                db_config.max_query_memory = m;
            }
            let params = self.session.merge_params(self.params);
            self.session
                .db
                .execute_internal_with_config_and_token(
                    &self.cypher,
                    params,
                    db_config,
                    self.cancellation_token,
                )
                .await
        } else {
            let params = self.session.merge_params(self.params);
            self.session.execute_cached(&self.cypher, params).await
        }
    }
pub async fn fetch_one(self) -> Result<Option<Row>> {
let result = self.fetch_all().await?;
Ok(result.into_rows().into_iter().next())
}
    /// Executes the query as a streaming cursor, honoring timeout/memory
    /// overrides but bypassing the plan cache.
    ///
    /// NOTE(review): unlike `fetch_all`, this path does not run the
    /// explicit read-only validation here — confirm the cursor path
    /// enforces it downstream.
    pub async fn cursor(self) -> Result<QueryCursor> {
        let mut db_config = self.session.db.config.clone();
        if let Some(t) = self.timeout {
            db_config.query_timeout = t;
        }
        if let Some(m) = self.max_memory {
            db_config.max_query_memory = m;
        }
        let params = self.session.merge_params(self.params);
        self.session
            .db
            .execute_cursor_internal_with_config(&self.cypher, params, db_config)
            .await
    }
    /// Returns the query plan without executing the query.
    /// Note: bound params are not forwarded to explain.
    pub async fn explain(self) -> Result<ExplainOutput> {
        self.session.db.explain_internal(&self.cypher).await
    }
    /// Executes the query, returning both its results and profiling data.
    pub async fn profile(self) -> Result<(QueryResult, ProfileOutput)> {
        let params = self.session.merge_params(self.params);
        self.session.db.profile_internal(&self.cypher, params).await
    }
}
/// Fluent builder for starting a transaction with a custom timeout and/or
/// isolation level.
pub struct TransactionBuilder<'a> {
    session: &'a Session,
    /// Seeded from the session's default `transaction_timeout`.
    timeout: Option<Duration>,
    isolation: IsolationLevel,
}
impl<'a> TransactionBuilder<'a> {
    /// Sets an explicit transaction timeout.
    pub fn timeout(self, d: Duration) -> Self {
        Self {
            timeout: Some(d),
            ..self
        }
    }
    /// Selects the isolation level for the transaction.
    pub fn isolation(self, level: IsolationLevel) -> Self {
        Self {
            isolation: level,
            ..self
        }
    }
    /// Starts the transaction; pinned (read-only) sessions are refused.
    pub async fn start(self) -> Result<Transaction> {
        if self.session.is_pinned() {
            return Err(UniError::ReadOnly {
                operation: "start_transaction".to_string(),
            });
        }
        Transaction::new_with_options(self.session, self.timeout, self.isolation).await
    }
}
impl Clone for Session {
    /// Cloning produces a sibling session: fresh id, metrics, write guard
    /// and cancellation token; snapshot copies of params and rules; but
    /// the plan cache (and its hit/miss counters) stays shared.
    fn clone(&self) -> Self {
        // NOTE(review): this bumps the counter on `self.db`, which is the
        // pinned snapshot handle while pinned, whereas `new()` bumps the
        // live db's counter — confirm the accounting pairs up for pinned
        // sessions.
        self.db.active_session_count.fetch_add(1, Ordering::Relaxed);
        Self {
            db: self.db.clone(),
            original_db: self.original_db.clone(),
            id: Uuid::new_v4().to_string(),
            params: Arc::new(std::sync::RwLock::new(self.params.read().unwrap().clone())),
            rule_registry: Arc::new(std::sync::RwLock::new(
                self.rule_registry.read().unwrap().clone(),
            )),
            active_write_guard: Arc::new(AtomicBool::new(false)),
            metrics_inner: Arc::new(SessionMetricsInner::new()),
            created_at: Instant::now(),
            cancellation_token: Arc::new(std::sync::RwLock::new(CancellationToken::new())),
            plan_cache: self.plan_cache.clone(),
            plan_cache_metrics: self.plan_cache_metrics.clone(),
            hooks: self.hooks.clone(),
            query_timeout: self.query_timeout,
            transaction_timeout: self.transaction_timeout,
        }
    }
}
impl Drop for Session {
    fn drop(&mut self) {
        // NOTE(review): decrements the counter on `self.db`; for a pinned
        // session that is the snapshot handle, not the live db whose
        // counter `new()` incremented — verify this pairing is intended.
        self.db.active_session_count.fetch_sub(1, Ordering::Relaxed);
    }
}
/// One cached plan: the parsed AST, its logical plan, the schema version
/// it was planned against, and how often it has been reused.
struct PlanCacheEntry {
    ast: uni_query::CypherQuery,
    plan: uni_query::LogicalPlan,
    /// Entries from older schema versions are evicted on lookup.
    schema_version: u32,
    /// Drives the least-hit eviction policy.
    hit_count: u64,
}
/// Bounded plan cache keyed by the 64-bit hash of the query text.
struct PlanCache {
    entries: HashMap<u64, PlanCacheEntry>,
    /// Capacity; the least-hit entry is evicted when full.
    max_entries: usize,
}
impl PlanCache {
    /// Creates an empty cache bounded at `max_entries` plans.
    fn new(max_entries: usize) -> Self {
        Self {
            entries: HashMap::new(),
            max_entries,
        }
    }
    /// Looks up a cached plan, evicting it if it was built against a
    /// different schema version. Bumps the entry's hit count on success.
    fn get(&mut self, key: u64, current_schema_version: u32) -> Option<&PlanCacheEntry> {
        if let Some(entry) = self.entries.get_mut(&key) {
            if entry.schema_version == current_schema_version {
                entry.hit_count += 1;
                // Re-borrow immutably: returning `entry` directly would
                // keep the mutable borrow alive past the early return.
                return self.entries.get(&key);
            }
            // Stale plan (schema changed since it was compiled): drop it.
            self.entries.remove(&key);
        }
        None
    }
    /// Inserts (or replaces) a plan. When the cache is full and `key` is
    /// new, the least-hit entry is evicted first.
    fn insert(&mut self, key: u64, entry: PlanCacheEntry) {
        // Only evict for genuinely new keys: replacing an existing key
        // does not grow the map, and evicting then would needlessly throw
        // away an unrelated live plan.
        if !self.entries.contains_key(&key) && self.entries.len() >= self.max_entries {
            if let Some((&evict_key, _)) = self.entries.iter().min_by_key(|(_, e)| e.hit_count) {
                self.entries.remove(&evict_key);
            }
        }
        self.entries.insert(key, entry);
    }
    /// Number of cached plans.
    fn len(&self) -> usize {
        self.entries.len()
    }
}
/// Derives the plan-cache key for a query string by hashing its exact text.
/// NOTE(review): the key is only a 64-bit hash; distinct queries that
/// collide would share a cache slot.
fn plan_cache_key(cypher: &str) -> u64 {
    use std::hash::{Hash, Hasher};
    let mut state = std::collections::hash_map::DefaultHasher::default();
    cypher.hash(&mut state);
    state.finish()
}