use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use chrono::{DateTime, Utc};
use uuid::Uuid;
use crate::session::JammiSession;
use crate::store::mutable::definition::MutableTableId;
use crate::tenant::TenantId;
use super::event::{SessionLifecycleEvent, SessionLifecycleRecord};
use super::topic;
/// How often the background sweeper checks for expired sessions when no interval is given
/// explicitly (one minute).
pub const DEFAULT_SCAN_INTERVAL: Duration = Duration::from_millis(60_000);
/// Point-in-time record of one registered session, as stored in `ActiveSessions`.
#[derive(Clone)]
struct SessionSnapshot {
    /// Instant at or after which the session is considered expired.
    deadline: DateTime<Utc>,
    /// Tenant the session is associated with.
    tenant: TenantId,
    /// Mutable tables to drop when the session is force-closed.
    tables: Vec<MutableTableId>,
}
/// Thread-safe registry of live ephemeral sessions, keyed by session id.
///
/// Cloning is cheap: every clone shares the same underlying map through an `Arc`.
#[derive(Default, Clone)]
pub struct ActiveSessions {
    inner: Arc<Mutex<HashMap<Uuid, SessionSnapshot>>>,
}
impl ActiveSessions {
pub fn new() -> Self {
Self::default()
}
pub(super) fn upsert(
&self,
session_id: Uuid,
tenant: TenantId,
deadline: DateTime<Utc>,
tables: Vec<MutableTableId>,
) {
if let Ok(mut map) = self.inner.lock() {
map.insert(
session_id,
SessionSnapshot {
tenant,
deadline,
tables,
},
);
}
}
pub(super) fn remove(&self, session_id: &Uuid) -> bool {
self.inner
.lock()
.map(|mut map| map.remove(session_id).is_some())
.unwrap_or(false)
}
fn take_expired(&self, now: DateTime<Utc>) -> Vec<(Uuid, SessionSnapshot)> {
let Ok(mut map) = self.inner.lock() else {
return Vec::new();
};
let expired: Vec<Uuid> = map
.iter()
.filter(|(_, snap)| now >= snap.deadline)
.map(|(id, _)| *id)
.collect();
expired
.into_iter()
.filter_map(|id| map.remove(&id).map(|snap| (id, snap)))
.collect()
}
pub fn len(&self) -> usize {
self.inner.lock().map(|m| m.len()).unwrap_or(0)
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
}
/// Force-closes every session in `active` whose deadline has passed.
///
/// Expired sessions are removed from the registry up front. Each is then passed to
/// `delete_tables_and_emit` with [`SessionLifecycleEvent::TimedOut`], one at a time. An
/// error for one session is logged as a warning and does not stop the others from being
/// processed.
pub async fn scan(parent: &Arc<JammiSession>, active: &ActiveSessions) {
    let expired = active.take_expired(Utc::now());
    for (session_id, snapshot) in expired {
        let result = delete_tables_and_emit(
            parent,
            session_id,
            snapshot.tenant,
            &snapshot.tables,
            SessionLifecycleEvent::TimedOut,
        )
        .await;
        if let Err(error) = result {
            tracing::warn!(
                session_id = %session_id,
                error = %error,
                "ephemeral timeout force-close completed with errors",
            );
        }
    }
}
/// Result of a table-deletion pass that completed (including publishing its lifecycle record).
#[derive(Debug)]
pub(super) struct DeletionOutcome {
    /// Ids of tables whose drop failed and which therefore still exist.
    pub failures: Vec<String>,
}
/// Drops every table in `tables` and publishes one lifecycle record describing the outcome.
///
/// Each table's rows are counted *before* it is dropped, because a dropped table can no
/// longer be counted. That count contributes to `deleted_row_count` only if the drop then
/// succeeds, so rows in surviving tables are never reported as deleted. A count or drop
/// failure for one table is logged and does not stop the remaining tables from being
/// processed.
///
/// If any drop fails, the published event is
/// [`SessionLifecycleEvent::PartialDeletionFailure`]. Its `details` list the surviving
/// table ids and the originally requested `event`. Otherwise `event` itself is published.
///
/// # Errors
///
/// Propagates the failure if `topic::publish_lifecycle` cannot publish the record. Per-table
/// drop failures are not errors; they are reported in [`DeletionOutcome::failures`].
pub(super) async fn delete_tables_and_emit(
    parent: &Arc<JammiSession>,
    session_id: Uuid,
    tenant: TenantId,
    tables: &[MutableTableId],
    event: SessionLifecycleEvent,
) -> Result<DeletionOutcome, super::error::EphemeralError> {
    let mut deleted_rows: u64 = 0;
    let mut failures: Vec<String> = Vec::new();
    for id in tables {
        // Count first: once the table is dropped its rows can no longer be counted.
        let row_count = match count_rows(parent, tenant, id).await {
            Ok(n) => n,
            Err(e) => {
                tracing::warn!(
                    session_id = %session_id,
                    table = %id.as_str(),
                    error = %e,
                    "ephemeral row count failed before drop",
                );
                0
            }
        };
        match parent.drop_mutable_table(id).await {
            // Only rows of tables that were actually removed count as deleted.
            Ok(_) => deleted_rows = deleted_rows.saturating_add(row_count),
            Err(e) => {
                tracing::warn!(
                    session_id = %session_id,
                    table = %id.as_str(),
                    error = %e,
                    "ephemeral table drop failed",
                );
                failures.push(id.as_str().to_string());
            }
        }
    }
    let (emit_event, details) = if failures.is_empty() {
        (event, None)
    } else {
        (
            SessionLifecycleEvent::PartialDeletionFailure,
            Some(serde_json::json!({
                "surviving_tables": failures,
                "attempted_event": event,
            })),
        )
    };
    let record = SessionLifecycleRecord {
        session_id,
        tenant_id: tenant.to_string(),
        event: emit_event,
        occurred_at: Utc::now(),
        ephemeral_table_count: tables.len(),
        deleted_row_count: deleted_rows,
        details,
    };
    topic::publish_lifecycle(parent, tenant, &record).await?;
    Ok(DeletionOutcome { failures })
}
/// Counts the rows currently in mutable table `id`, running the query scoped to `tenant`.
///
/// Returns 0 when the result carries no usable `n` column. The result is not expected to
/// contain one, since `COUNT(*)` yields a single row.
///
/// # Errors
///
/// Returns `EphemeralError::Storage` if the query fails.
async fn count_rows(
    parent: &Arc<JammiSession>,
    tenant: TenantId,
    id: &MutableTableId,
) -> Result<u64, super::error::EphemeralError> {
    // Double any embedded `"` so the id cannot terminate the quoted identifier early.
    let quoted = id.as_str().replace('"', "\"\"");
    let sql = format!("SELECT COUNT(*) AS n FROM mutable.public.\"{quoted}\"");
    let batches = parent
        .with_tenant_scoped(tenant, move |scope| async move { scope.sql(&sql).await })
        .await
        .map_err(|e| super::error::EphemeralError::Storage(e.to_string()))?;
    // Take the first non-empty `n` column, skipping any empty leading batches.
    let count = batches
        .iter()
        .filter_map(|b| b.column_by_name("n"))
        .filter_map(|c| c.as_any().downcast_ref::<arrow::array::Int64Array>())
        .find(|a| !a.is_empty())
        .map(|a| a.value(0))
        .unwrap_or(0);
    // A negative count is impossible for COUNT(*); clamp defensively rather than wrap.
    Ok(u64::try_from(count).unwrap_or(0))
}
pub fn spawn(
parent: Arc<JammiSession>,
active: ActiveSessions,
interval: Duration,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let mut ticker = tokio::time::interval(interval);
ticker.tick().await;
loop {
ticker.tick().await;
scan(&parent, &active).await;
}
})
}