use std::cell::RefCell;
use std::sync::{Arc, RwLock};
use uuid::Uuid;
use crate::api::runtime::callbacks::EventSubscriberFn;
use crate::api::scope::{ScopeHandle, ScopeType};
use crate::context::registries::ScopeLocalRegistries;
use crate::error::{FlowError, Result};
use crate::registry::SortedRegistry;
pub struct ScopeStack {
stack: Vec<ScopeHandle>,
scope_registries: std::collections::HashMap<Uuid, ScopeLocalRegistries>,
}
impl ScopeStack {
pub fn new() -> Self {
let root = ScopeHandle::builder()
.name("root")
.scope_type(ScopeType::Agent)
.build();
Self {
stack: vec![root],
scope_registries: std::collections::HashMap::new(),
}
}
pub fn push(&mut self, handle: ScopeHandle) {
self.stack.push(handle);
}
pub fn top(&self) -> &ScopeHandle {
self.stack
.last()
.expect("scope stack should never be empty")
}
pub fn top_mut(&mut self) -> &mut ScopeHandle {
self.stack
.last_mut()
.expect("scope stack should never be empty")
}
pub fn root_uuid(&self) -> Uuid {
self.stack
.first()
.expect("scope stack should never be empty")
.uuid
}
pub fn scopes(&self) -> &[ScopeHandle] {
&self.stack
}
pub fn find(&self, uuid: &Uuid) -> Option<&ScopeHandle> {
self.stack.iter().find(|handle| handle.uuid == *uuid)
}
pub fn remove(&mut self, uuid: &Uuid) -> Result<ScopeHandle> {
let top = self
.stack
.last()
.expect("scope stack should never be empty");
if top.uuid == *uuid {
if self.stack.len() == 1 {
return Err(FlowError::InvalidArgument(
"root scope cannot be removed".into(),
));
}
self.scope_registries.remove(uuid);
return Ok(self
.stack
.pop()
.expect("scope stack should contain a removable top scope"));
}
if self.stack.iter().any(|handle| handle.uuid == *uuid) {
return Err(FlowError::InvalidArgument(
"scope handle is not at the top of the stack".into(),
));
}
Err(FlowError::NotFound("scope handle not found".into()))
}
pub fn local_registries_mut(&mut self, uuid: &Uuid) -> Option<&mut ScopeLocalRegistries> {
if !self.stack.iter().any(|handle| handle.uuid == *uuid) {
return None;
}
Some(self.scope_registries.entry(*uuid).or_default())
}
pub fn collect_scope_local_registries<'a, T>(
&'a self,
field: impl Fn(&'a ScopeLocalRegistries) -> &'a SortedRegistry<T>,
) -> Vec<&'a SortedRegistry<T>> {
self.stack
.iter()
.filter_map(|handle| self.scope_registries.get(&handle.uuid))
.map(field)
.collect()
}
pub fn collect_scope_local_subscribers(&self) -> Vec<EventSubscriberFn> {
self.stack
.iter()
.filter_map(|handle| self.scope_registries.get(&handle.uuid))
.flat_map(|registries| registries.event_subscribers.values().cloned())
.collect()
}
pub fn scope_registries_get(&self, uuid: &Uuid) -> Option<&ScopeLocalRegistries> {
self.scope_registries.get(uuid)
}
}
impl std::fmt::Debug for ScopeStack {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ScopeStack")
.field("stack", &self.stack)
.field("scope_registries_count", &self.scope_registries.len())
.finish()
}
}
impl Default for ScopeStack {
fn default() -> Self {
Self::new()
}
}
pub type ScopeStackHandle = Arc<RwLock<ScopeStack>>;
#[derive(Clone)]
pub struct ThreadScopeStackBinding {
stack: ScopeStackHandle,
explicit: bool,
}
pub fn create_scope_stack() -> ScopeStackHandle {
Arc::new(RwLock::new(ScopeStack::new()))
}
tokio::task_local! {
pub static TASK_SCOPE_STACK: ScopeStackHandle;
}
thread_local! {
static THREAD_SCOPE_STACK: RefCell<ScopeStackHandle> = RefCell::new(create_scope_stack());
static THREAD_SCOPE_STACK_EXPLICIT: std::cell::Cell<bool> = const { std::cell::Cell::new(false) };
}
pub fn current_scope_stack() -> ScopeStackHandle {
TASK_SCOPE_STACK
.try_with(|stack| stack.clone())
.unwrap_or_else(|_| THREAD_SCOPE_STACK.with(|stack| stack.borrow().clone()))
}
pub fn set_thread_scope_stack(handle: ScopeStackHandle) {
THREAD_SCOPE_STACK.with(|stack| *stack.borrow_mut() = handle);
THREAD_SCOPE_STACK_EXPLICIT.with(|flag| flag.set(true));
}
pub fn capture_thread_scope_stack() -> ThreadScopeStackBinding {
let stack = THREAD_SCOPE_STACK.with(|stack| stack.borrow().clone());
let explicit = THREAD_SCOPE_STACK_EXPLICIT.with(|flag| flag.get());
ThreadScopeStackBinding { stack, explicit }
}
pub fn restore_thread_scope_stack(binding: ThreadScopeStackBinding) {
THREAD_SCOPE_STACK.with(|stack| *stack.borrow_mut() = binding.stack);
THREAD_SCOPE_STACK_EXPLICIT.with(|flag| flag.set(binding.explicit));
}
pub fn sync_thread_scope_stack(handle: ScopeStackHandle) {
THREAD_SCOPE_STACK.with(|stack| *stack.borrow_mut() = handle);
}
pub fn scope_stack_active() -> bool {
TASK_SCOPE_STACK
.try_with(|_| true)
.unwrap_or_else(|_| THREAD_SCOPE_STACK_EXPLICIT.with(|flag| flag.get()))
}
pub fn propagate_scope_to_thread() -> Result<ScopeStackHandle> {
if !scope_stack_active() {
return Err(FlowError::Internal(
"no active scope stack in current context; call create_scope_stack() and set_thread_scope_stack() first"
.into(),
));
}
Ok(current_scope_stack())
}
pub fn task_scope_top() -> ScopeHandle {
let stack = current_scope_stack();
let guard = stack.read().expect("scope stack lock poisoned");
guard.top().clone()
}
pub fn task_scope_push(handle: ScopeHandle) {
let stack = current_scope_stack();
let mut guard = stack.write().expect("scope stack lock poisoned");
guard.push(handle);
}
pub fn task_scope_remove(uuid: &Uuid) -> Result<ScopeHandle> {
let stack = current_scope_stack();
let mut guard = stack.write().expect("scope stack lock poisoned");
guard.remove(uuid)
}