use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::sync::RwLock;
use uuid::Uuid;
/// Tag string associated with session references.
///
/// NOTE(review): this constant is not used in this file; presumably it marks a
/// serialized value as a registry reference elsewhere — confirm against callers.
pub const SESSION_REF_TAG: &str = "__blazen_session_ref__";

/// Upper bound on live entries in a single [`SessionRefRegistry`].
///
/// `SessionRefRegistry::insert_arc` refuses new entries once this many are
/// stored, returning `SessionRefError::CapacityExceeded`.
pub const MAX_SESSION_REFS_PER_RUN: usize = 10_000;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(transparent)]
pub struct RegistryKey(pub Uuid);
impl RegistryKey {
#[must_use]
pub fn new() -> Self {
Self(Uuid::new_v4())
}
pub fn parse(s: &str) -> Result<Self, uuid::Error> {
Uuid::parse_str(s).map(Self)
}
}
impl Default for RegistryKey {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for RegistryKey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
self.0.fmt(f)
}
}
/// Type-erased, reference-counted value as stored in the registry.
type AnyArc = Arc<dyn Any + Send + Sync>;

/// Map from [`RegistryKey`] to type-erased values, guarded by a
/// `tokio::sync::RwLock` so it can be shared across async tasks.
///
/// Holds at most [`MAX_SESSION_REFS_PER_RUN`] entries; further inserts fail
/// with [`SessionRefError::CapacityExceeded`]. Values are recovered by
/// downcasting in [`get`](Self::get).
#[derive(Default)]
pub struct SessionRefRegistry {
    inner: RwLock<HashMap<RegistryKey, AnyArc>>,
}

impl std::fmt::Debug for SessionRefRegistry {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Contents are type-erased and behind an async lock, so only the
        // type name is printed (`SessionRefRegistry { .. }`).
        f.debug_struct("SessionRefRegistry").finish_non_exhaustive()
    }
}

impl SessionRefRegistry {
    /// Creates an empty registry.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores an already type-erased value under a freshly generated key.
    ///
    /// # Errors
    ///
    /// Returns [`SessionRefError::CapacityExceeded`] if the registry already
    /// holds [`MAX_SESSION_REFS_PER_RUN`] entries; in that case nothing is stored.
    pub async fn insert_arc(&self, value: AnyArc) -> Result<RegistryKey, SessionRefError> {
        let mut g = self.inner.write().await;
        // Check under the same write guard as the insert so concurrent callers
        // cannot collectively overshoot the cap.
        if g.len() >= MAX_SESSION_REFS_PER_RUN {
            return Err(SessionRefError::CapacityExceeded {
                cap: MAX_SESSION_REFS_PER_RUN,
            });
        }
        let key = RegistryKey::new();
        g.insert(key, value);
        Ok(key)
    }

    /// Wraps `value` in an [`Arc`] and stores it; see [`insert_arc`](Self::insert_arc).
    ///
    /// Passing an `Arc<U>` here stores an `Arc<Arc<U>>`, which must then be read
    /// back with `get::<Arc<U>>`; use `insert_arc` to store an existing `Arc` as-is.
    ///
    /// # Errors
    ///
    /// Returns [`SessionRefError::CapacityExceeded`] when the registry is full.
    pub async fn insert<T: Any + Send + Sync + 'static>(
        &self,
        value: T,
    ) -> Result<RegistryKey, SessionRefError> {
        self.insert_arc(Arc::new(value)).await
    }

    /// Returns a clone of the type-erased handle stored under `key`, if any.
    pub async fn get_any(&self, key: RegistryKey) -> Option<AnyArc> {
        self.inner.read().await.get(&key).cloned()
    }

    /// Returns the value stored under `key`, downcast to `T`.
    ///
    /// Yields `None` both when `key` is absent and when the stored value is not a `T`.
    pub async fn get<T: Any + Send + Sync + 'static>(&self, key: RegistryKey) -> Option<Arc<T>> {
        let any = self.get_any(key).await?;
        Arc::downcast::<T>(any).ok()
    }

    /// Removes the entry for `key`, returning its value if it was present.
    pub async fn remove(&self, key: RegistryKey) -> Option<AnyArc> {
        self.inner.write().await.remove(&key)
    }

    /// Removes every entry and returns how many were removed.
    ///
    /// The map is swapped out while the write lock is held, and the removed
    /// values are dropped only after the lock is released. Destructors of
    /// last references therefore never run while other tasks wait on the registry.
    pub async fn drain(&self) -> usize {
        let removed = {
            let mut guard = self.inner.write().await;
            std::mem::take(&mut *guard)
        };
        let n = removed.len();
        drop(removed);
        n
    }

    /// Number of entries currently stored.
    pub async fn len(&self) -> usize {
        self.inner.read().await.len()
    }

    /// Whether the registry currently holds no entries.
    pub async fn is_empty(&self) -> bool {
        self.inner.read().await.is_empty()
    }

    /// Snapshot of the keys currently stored, in unspecified (hash-map) order.
    pub async fn keys(&self) -> Vec<RegistryKey> {
        self.inner.read().await.keys().copied().collect()
    }
}
/// Policy selector for session references when a run is paused.
///
/// NOTE(review): the scope ("on pause") and the per-variant meanings below are
/// inferred from the names only; the policy is applied outside this file —
/// confirm against the pause/snapshot code.
///
/// Serialized in snake_case: `"pickle_or_error"`, `"warn_drop"`, `"hard_error"`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionPausePolicy {
    /// The default. Presumably: try to serialize the value, else error.
    #[default]
    PickleOrError,
    /// Presumably: emit a warning and drop the reference.
    WarnDrop,
    /// Presumably: fail unconditionally.
    HardError,
}
#[derive(Debug, thiserror::Error)]
pub enum SessionRefError {
#[error(
"session ref registry capacity exceeded ({cap} entries) — \
too many live references in this workflow run"
)]
CapacityExceeded {
cap: usize,
},
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn insert_and_get_roundtrip() {
        let reg = SessionRefRegistry::new();
        let key = reg.insert(42_i32).await.unwrap();
        let got = reg.get::<i32>(key).await.unwrap();
        assert_eq!(*got, 42);
    }

    #[tokio::test]
    async fn get_wrong_type_returns_none() {
        let reg = SessionRefRegistry::new();
        let key = reg.insert(42_i32).await.unwrap();
        assert!(reg.get::<String>(key).await.is_none());
    }

    #[tokio::test]
    async fn get_any_and_keys_reflect_contents() {
        let reg = SessionRefRegistry::new();
        let key = reg.insert(7_u8).await.unwrap();
        assert_eq!(reg.keys().await, vec![key]);

        let any = reg.get_any(key).await.expect("value was just inserted");
        let value = Arc::downcast::<u8>(any).ok().expect("stored value is a u8");
        assert_eq!(*value, 7);

        // A key that was never inserted resolves to nothing.
        assert!(reg.get_any(RegistryKey::new()).await.is_none());
    }

    #[tokio::test]
    async fn remove_returns_value_and_clears() {
        let reg = SessionRefRegistry::new();
        let key = reg.insert("hello".to_owned()).await.unwrap();
        assert_eq!(reg.len().await, 1);
        let removed = reg.remove(key).await;
        assert!(removed.is_some());
        assert_eq!(reg.len().await, 0);
    }

    #[tokio::test]
    async fn drain_clears_everything() {
        let reg = SessionRefRegistry::new();
        let _ = reg.insert(1_i32).await.unwrap();
        let _ = reg.insert(2_i32).await.unwrap();
        let _ = reg.insert(3_i32).await.unwrap();
        assert_eq!(reg.drain().await, 3);
        assert!(reg.is_empty().await);
        // Draining an already-empty registry reports zero removals.
        assert_eq!(reg.drain().await, 0);
    }

    #[tokio::test]
    async fn capacity_cap_enforced() {
        let reg = SessionRefRegistry::new();
        // Fill the registry exactly to the cap; every insert must succeed.
        for i in 0..MAX_SESSION_REFS_PER_RUN {
            assert!(reg.insert(i).await.is_ok());
        }
        assert_eq!(reg.len().await, MAX_SESSION_REFS_PER_RUN);

        // One past the cap is rejected, reports the cap, and stores nothing.
        let err = reg.insert(0_usize).await.unwrap_err();
        let message = err.to_string();
        assert!(message.contains("capacity exceeded"));
        assert!(matches!(
            err,
            SessionRefError::CapacityExceeded { cap } if cap == MAX_SESSION_REFS_PER_RUN
        ));
        assert_eq!(reg.len().await, MAX_SESSION_REFS_PER_RUN);

        // Freeing one slot makes room for exactly one more insert.
        let key = reg.keys().await[0];
        assert!(reg.remove(key).await.is_some());
        assert!(reg.insert(0_usize).await.is_ok());
        assert!(reg.insert(1_usize).await.is_err());
    }

    #[test]
    fn registry_key_parse_roundtrip() {
        let k = RegistryKey::new();
        let s = k.to_string();
        let parsed = RegistryKey::parse(&s).unwrap();
        assert_eq!(k, parsed);
    }

    #[test]
    fn session_pause_policy_default_is_pickle_or_error() {
        assert_eq!(
            SessionPausePolicy::default(),
            SessionPausePolicy::PickleOrError
        );
    }

    #[test]
    fn session_pause_policy_serde_roundtrip() {
        let p = SessionPausePolicy::WarnDrop;
        let json = serde_json::to_string(&p).unwrap();
        assert_eq!(json, "\"warn_drop\"");
        let back: SessionPausePolicy = serde_json::from_str(&json).unwrap();
        assert_eq!(back, p);
    }
}