use std::sync::Arc;
use crate::task::TaskMetadata;
/// Identifiers and metadata describing the task that is currently executing.
///
/// A value is installed for the duration of a task body with
/// `with_thread_local_task_context` (or `with_task_context` when the `tokio`
/// feature is enabled). It is read back with `get_task_context` or the
/// `task_context!()` macro. The string fields are `Arc<str>`, so cloning them
/// only bumps reference counts.
#[derive(Clone, Debug)]
pub struct TaskExecutionContext {
/// Identifier of the workflow this task belongs to.
pub workflow_id: Arc<str>,
/// Identifier of the workflow instance this task runs in
/// (NOTE(review): presumably one execution/run of the workflow — confirm).
pub instance_id: Arc<str>,
/// Identifier of the task within the workflow.
pub task_id: Arc<str>,
/// Task-level metadata (`crate::task::TaskMetadata`).
pub metadata: TaskMetadata,
/// Optional workflow-level metadata as a JSON string
/// (NOTE(review): presumably mirrors `WorkflowContext::metadata_json` — confirm with the producer).
pub workflow_metadata_json: Option<Arc<str>>,
}
/// Context shared by everything running under one workflow: its identity,
/// a codec handle, and user metadata.
///
/// `Clone` is implemented by hand rather than derived, so cloning does not
/// require `C: Clone` or `M: Clone`. A clone copies `workflow_id` and bumps
/// the reference counts of the `Arc` fields.
pub struct WorkflowContext<C, M> {
/// Typed workflow identifier; `WorkflowContext::new` derives it from
/// `workflow_name` via `crate::WorkflowId::from`.
pub workflow_id: crate::WorkflowId,
/// Workflow name; this is also what `WorkflowContext::workflow_id()` returns.
pub workflow_name: Arc<str>,
/// Codec shared by all clones of this context
/// (NOTE(review): presumably used to (de)serialize task payloads — confirm).
pub codec: Arc<C>,
/// User-supplied workflow metadata, shared by all clones.
pub metadata: Arc<M>,
/// Optional JSON form of the metadata (presumably a serialization of
/// `metadata` — confirm); `WorkflowContext::new` leaves it `None`.
pub metadata_json: Option<Arc<str>>,
}
/// Hand-written `Clone`: `#[derive(Clone)]` would add `C: Clone` and `M: Clone`
/// bounds, even though every field is either `Copy` or reference-counted.
impl<C, M> Clone for WorkflowContext<C, M> {
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct becomes a
        // compile error here instead of a silently incomplete clone.
        let Self {
            workflow_id,
            workflow_name,
            codec,
            metadata,
            metadata_json,
        } = self;
        Self {
            workflow_id: *workflow_id,
            workflow_name: Arc::clone(workflow_name),
            codec: Arc::clone(codec),
            metadata: Arc::clone(metadata),
            metadata_json: metadata_json.clone(),
        }
    }
}
impl<C, M> WorkflowContext<C, M> {
    /// Creates a context for the workflow named `workflow_name`.
    ///
    /// `workflow_id` is derived from the name via `crate::WorkflowId::from(&str)`.
    /// `metadata_json` starts out as `None`.
    #[must_use]
    pub fn new(workflow_name: impl Into<Arc<str>>, codec: Arc<C>, metadata: Arc<M>) -> Self {
        let workflow_name: Arc<str> = workflow_name.into();
        let workflow_id = crate::WorkflowId::from(workflow_name.as_ref());
        Self {
            workflow_id,
            workflow_name,
            codec,
            metadata,
            metadata_json: None,
        }
    }

    /// Returns the workflow's string identifier, i.e. the `workflow_name` field.
    ///
    /// Despite the method name, this is the workflow *name*, not the typed
    /// `crate::WorkflowId`. Use [`Self::workflow_id_hash`] for the typed ID.
    #[must_use]
    pub fn workflow_id(&self) -> &str {
        &self.workflow_name
    }

    /// Returns the stored typed `crate::WorkflowId` (the `workflow_id` field).
    #[must_use]
    pub fn workflow_id_hash(&self) -> crate::WorkflowId {
        self.workflow_id
    }

    /// Returns another handle to the shared codec.
    ///
    /// This is a reference-count bump; `C` itself is not copied.
    #[must_use]
    pub fn codec(&self) -> Arc<C> {
        Arc::clone(&self.codec)
    }

    /// Returns another handle to the shared workflow metadata.
    ///
    /// This is a reference-count bump; `M` itself is not copied.
    #[must_use]
    pub fn metadata(&self) -> Arc<M> {
        Arc::clone(&self.metadata)
    }
}
use std::cell::RefCell;
std::thread_local! {
/// This thread's current task context.
///
/// It is `None` until `with_thread_local_task_context` installs a context, and
/// that function puts the previous value back when it finishes. The
/// `const { .. }` initializer lets std skip lazy-initialization checks on access.
static THREAD_LOCAL_TASK_CTX: RefCell<Option<TaskExecutionContext>> = const { RefCell::new(None) };
}
pub fn with_thread_local_task_context<R>(ctx: TaskExecutionContext, f: impl FnOnce() -> R) -> R {
THREAD_LOCAL_TASK_CTX.with(|cell| {
let prev = cell.borrow_mut().replace(ctx);
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f));
*cell.borrow_mut() = prev;
match result {
Ok(r) => r,
Err(e) => std::panic::resume_unwind(e),
}
})
}
/// Returns a clone of the context currently installed on this thread.
///
/// Returns `None` when no `with_thread_local_task_context` call is active on
/// this thread. The installed value stays in place; only a copy is handed out.
#[must_use]
pub fn get_thread_local_task_context() -> Option<TaskExecutionContext> {
    THREAD_LOCAL_TASK_CTX.with(|slot| {
        let installed = slot.borrow();
        Option::clone(&*installed)
    })
}
#[cfg(feature = "tokio")]
mod task_local_ctx {
    //! Tokio task-local propagation of `TaskExecutionContext`, with a
    //! thread-local fallback.

    use super::TaskExecutionContext;

    tokio::task_local! {
        /// Context installed for the duration of a future by `with_task_context`.
        static TASK_EXEC_CTX: Option<TaskExecutionContext>;
    }

    /// Drives `fut` to completion with `ctx` installed as the task-local
    /// context, and returns the future's output.
    pub async fn with_task_context<F>(ctx: TaskExecutionContext, fut: F) -> F::Output
    where
        F: std::future::Future,
    {
        TASK_EXEC_CTX.scope(Some(ctx), fut).await
    }

    /// Returns the current task context, if any.
    ///
    /// The task-local value set by `with_task_context` wins. Outside such a
    /// scope, or when the task-local holds `None`, this falls back to
    /// `super::get_thread_local_task_context`.
    #[must_use]
    pub fn get_task_context() -> Option<TaskExecutionContext> {
        match TASK_EXEC_CTX.try_with(Option::clone) {
            Ok(Some(ctx)) => Some(ctx),
            Ok(None) | Err(_) => super::get_thread_local_task_context(),
        }
    }
}
#[cfg(feature = "tokio")]
pub use task_local_ctx::{get_task_context, with_task_context};
#[cfg(not(feature = "tokio"))]
#[must_use]
pub fn get_task_context() -> Option<TaskExecutionContext> {
get_thread_local_task_context()
}
/// Expands to the current task context as an `Option<TaskExecutionContext>`.
///
/// Shorthand for calling `get_task_context()` in this crate's `context` module.
#[macro_export]
macro_rules! task_context {
    () => {{
        $crate::context::get_task_context()
    }};
}
#[cfg(all(test, feature = "tokio"))]
#[allow(clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use crate::task::TaskMetadata;

    /// Builds a context for workflow `wf-1`, instance `inst-1`, and the given task id.
    fn make_task_ctx_for(task_id: &str) -> TaskExecutionContext {
        TaskExecutionContext {
            workflow_id: Arc::from("wf-1"),
            instance_id: Arc::from("inst-1"),
            task_id: Arc::from(task_id),
            metadata: TaskMetadata::default(),
            workflow_metadata_json: None,
        }
    }

    /// The default fixture used by most tests (task id `task-a`).
    fn make_task_ctx() -> TaskExecutionContext {
        make_task_ctx_for("task-a")
    }

    /// A current-thread runtime, so `block_on` polls futures on the test thread
    /// and thread-local contexts set there are observable from inside them.
    fn current_thread_runtime() -> tokio::runtime::Runtime {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap()
    }

    #[test]
    fn thread_local_roundtrip() {
        assert!(get_thread_local_task_context().is_none());
        let result = with_thread_local_task_context(make_task_ctx(), || {
            let inner = get_thread_local_task_context().unwrap();
            assert_eq!(&*inner.workflow_id, "wf-1");
            assert_eq!(&*inner.instance_id, "inst-1");
            assert_eq!(&*inner.task_id, "task-a");
            42
        });
        assert_eq!(result, 42);
        assert!(get_thread_local_task_context().is_none());
    }

    #[test]
    fn thread_local_restores_on_panic() {
        let ctx = make_task_ctx();
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            with_thread_local_task_context(ctx, || {
                panic!("boom");
            })
        }));
        assert!(result.is_err());
        assert!(get_thread_local_task_context().is_none());
    }

    #[test]
    fn thread_local_nested_scopes_restore_outer_context() {
        with_thread_local_task_context(make_task_ctx_for("outer"), || {
            let inner = with_thread_local_task_context(
                make_task_ctx_for("inner"),
                get_thread_local_task_context,
            );
            assert_eq!(&*inner.unwrap().task_id, "inner");
            let restored = get_thread_local_task_context().unwrap();
            assert_eq!(&*restored.task_id, "outer");
        });
        assert!(get_thread_local_task_context().is_none());
    }

    #[test]
    fn thread_local_panic_restores_outer_context() {
        with_thread_local_task_context(make_task_ctx_for("outer"), || {
            let result: std::thread::Result<()> =
                std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                    with_thread_local_task_context(make_task_ctx_for("inner"), || {
                        panic!("boom");
                    })
                }));
            assert!(result.is_err());
            let restored = get_thread_local_task_context().unwrap();
            assert_eq!(&*restored.task_id, "outer");
        });
        assert!(get_thread_local_task_context().is_none());
    }

    #[test]
    fn task_local_roundtrip() {
        current_thread_runtime().block_on(async {
            assert!(get_task_context().is_none());
            let inner = with_task_context(make_task_ctx(), async {
                let c = get_task_context().unwrap();
                assert_eq!(&*c.task_id, "task-a");
                c
            })
            .await;
            assert_eq!(&*inner.workflow_id, "wf-1");
        });
    }

    #[test]
    fn task_local_falls_back_to_thread_local() {
        current_thread_runtime().block_on(async {
            let result = with_thread_local_task_context(make_task_ctx(), get_task_context);
            assert!(result.is_some());
            assert_eq!(&*result.unwrap().instance_id, "inst-1");
        });
    }

    #[test]
    fn task_local_takes_precedence_over_thread_local() {
        current_thread_runtime().block_on(async {
            let seen = with_task_context(make_task_ctx_for("task-local"), async {
                with_thread_local_task_context(make_task_ctx_for("thread-local"), get_task_context)
            })
            .await;
            assert_eq!(&*seen.unwrap().task_id, "task-local");
        });
    }

    #[test]
    fn macro_works() {
        with_thread_local_task_context(make_task_ctx(), || {
            let c = task_context!().unwrap();
            assert_eq!(&*c.task_id, "task-a");
        });
    }
}