use std::sync::Arc;
use rustvello_proto::identifiers::{InvocationId, RunnerId, TaskId};
use rustvello_proto::invocation::WorkflowIdentity;
use serde::{Deserialize, Serialize};
/// Returns a numeric identifier for the calling OS thread.
///
/// `ThreadId::as_u64` is still unstable, so the number is recovered from the
/// `Debug` rendering (`ThreadId(N)`). If that text ever fails to parse as a
/// `u64`, `0` is returned instead of panicking.
fn current_thread_id() -> u64 {
    let rendered = format!("{:?}", std::thread::current().id());
    rendered
        .trim_start_matches("ThreadId(")
        .trim_end_matches(')')
        .parse::<u64>()
        .unwrap_or_default()
}
/// Identity of the task invocation that is currently executing.
///
/// Made visible to running code either through the `INVOCATION_CTX` task-local
/// (installed with `INVOCATION_CTX.scope(..)`) or through the per-thread fallback
/// set by `set_thread_invocation_context`. Read back with `get_invocation_context`.
#[derive(Debug, Clone)]
pub struct InvocationContext {
    /// Identifier of this invocation.
    pub invocation_id: InvocationId,
    /// Task this invocation executes.
    pub task_id: TaskId,
    /// Workflow this invocation belongs to (the tests use `WorkflowIdentity::root(..)`
    /// for a top-level invocation and share it with nested ones).
    pub workflow: WorkflowIdentity,
    /// Invocation that spawned this one, if any. In the tests, nested invocations
    /// set this to the enclosing invocation's id; top-level ones use `None`.
    pub parent_invocation_id: Option<InvocationId>,
    /// Retry counter. Presumably the number of times this invocation has already
    /// been retried; TODO confirm against the runner that fills it in.
    pub num_retries: u32,
}
/// Describes the runner executing the current code.
///
/// For child runners, it also holds the chain of parent runner contexts.
///
/// Both `Arc<str>` fields are (de)serialized as plain strings through
/// `serialize_arc_str` / `deserialize_arc_str`. The derive therefore does not rely
/// on serde's optional `rc` feature.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunnerContext {
    /// Identifier of this runner.
    pub runner_id: RunnerId,
    /// Runner class/kind name (`"ExternalRunner"` for [`RunnerContext::external`]).
    // Same explicit (de)serializers as `app_id`. Serde only implements its traits for
    // `Arc<T>` behind the `rc` feature, so without this attribute the derive silently
    // depends on that feature and is inconsistent with `app_id`.
    #[serde(
        serialize_with = "serialize_arc_str",
        deserialize_with = "deserialize_arc_str"
    )]
    pub runner_cls: Arc<str>,
    /// Application identifier (`"external"` for [`RunnerContext::external`]).
    #[serde(
        serialize_with = "serialize_arc_str",
        deserialize_with = "deserialize_arc_str"
    )]
    pub app_id: Arc<str>,
    /// OS process id captured with `std::process::id()` when the context was built.
    pub pid: u32,
    /// Host name of the machine, or `"unknown"` when it could not be determined.
    pub hostname: String,
    /// Numeric id of the thread that built the context (`0` if it could not be derived).
    pub thread_id: u64,
    /// Parent runner context, set for contexts created with [`RunnerContext::new_child`].
    pub parent_ctx: Option<Box<RunnerContext>>,
}
/// Serde `serialize_with` helper: writes an `Arc<str>` as a plain string.
fn serialize_arc_str<S: serde::Serializer>(v: &Arc<str>, s: S) -> Result<S::Ok, S::Error> {
    let text: &str = v;
    s.serialize_str(text)
}
/// Serde `deserialize_with` helper: reads a string and stores it as an `Arc<str>`.
fn deserialize_arc_str<'de, D: serde::Deserializer<'de>>(d: D) -> Result<Arc<str>, D::Error> {
    String::deserialize(d).map(Arc::<str>::from)
}
impl RunnerContext {
    /// Builds a root context (no parent) for the current process and thread.
    ///
    /// `pid`, `hostname` and `thread_id` are captured from the calling environment.
    pub fn new(runner_id: RunnerId, app_id: Arc<str>, runner_cls: impl Into<Arc<str>>) -> Self {
        let runner_cls: Arc<str> = runner_cls.into();
        let hostname = Self::get_hostname();
        Self {
            runner_id,
            runner_cls,
            app_id,
            pid: std::process::id(),
            hostname,
            thread_id: current_thread_id(),
            parent_ctx: None,
        }
    }

    /// Derives a child context with a new `runner_id`.
    ///
    /// The child shares `runner_cls`, `app_id` and `hostname` with `self`, and
    /// records the current pid/thread. It keeps a copy of `self` (including
    /// `self`'s own ancestry) as its parent.
    pub fn new_child(&self, runner_id: RunnerId) -> Self {
        let parent = Box::new(self.clone());
        Self {
            runner_id,
            runner_cls: Arc::clone(&self.runner_cls),
            app_id: Arc::clone(&self.app_id),
            pid: std::process::id(),
            hostname: self.hostname.clone(),
            thread_id: current_thread_id(),
            parent_ctx: Some(parent),
        }
    }

    /// Returns the `runner_id` of the outermost ancestor.
    ///
    /// This is `self`'s own id when it has no parent.
    pub fn root_runner_id(&self) -> &RunnerId {
        let mut node = self;
        while let Some(parent) = node.parent_ctx.as_deref() {
            node = parent;
        }
        &node.runner_id
    }

    /// Builds the context used when code runs outside any runner.
    ///
    /// The runner id is `"{hostname}-{pid}"`, the class is `"ExternalRunner"` and
    /// the app id is `"external"`.
    pub fn external() -> Self {
        let pid = std::process::id();
        let hostname = Self::get_hostname();
        Self {
            // Fields are evaluated in order, so `hostname` is borrowed here before it
            // is moved into its own field below.
            runner_id: RunnerId::from_string(format!("{hostname}-{pid}")),
            runner_cls: Arc::from("ExternalRunner"),
            app_id: Arc::from("external"),
            pid,
            hostname,
            thread_id: current_thread_id(),
            parent_ctx: None,
        }
    }

    /// Returns the machine's host name, converted lossily to UTF-8.
    ///
    /// Falls back to `"unknown"` when the lookup fails.
    pub(crate) fn get_hostname() -> String {
        match hostname::get() {
            Ok(name) => name.to_string_lossy().into_owned(),
            Err(_) => "unknown".to_string(),
        }
    }
}
// Task-local contexts, installed for the lifetime of a future with
// `INVOCATION_CTX.scope(ctx, fut)` / `RUNNER_CTX.scope(ctx, fut)`.
// They are read through `get_invocation_context`, `get_runner_context` and the
// related `with_*` / `get_or_create_*` helpers.
tokio::task_local! {
    /// Invocation context of the currently executing task, when one is in scope.
    pub static INVOCATION_CTX: InvocationContext;
    /// Runner context of the currently executing task, when one is in scope.
    pub static RUNNER_CTX: RunnerContext;
}
// Per-OS-thread fallbacks for code running outside a task-local scope.
// They are written by the `set_thread_*` / `clear_thread_*` functions.
// `get_invocation_context`, `get_or_create_runner_id` and
// `get_or_create_runner_context` consult them only when the matching task-local
// is not set.
std::thread_local! {
    /// Fallback runner context for the current thread (`None` until set).
    static THREAD_RUNNER_CTX: std::cell::RefCell<Option<RunnerContext>> =
        const { std::cell::RefCell::new(None) };
    /// Fallback invocation context for the current thread (`None` until set).
    static THREAD_INVOCATION_CTX: std::cell::RefCell<Option<InvocationContext>> =
        const { std::cell::RefCell::new(None) };
}
/// Installs `ctx` as this thread's fallback runner context, replacing any previous one.
///
/// It is used by [`get_or_create_runner_id`] and [`get_or_create_runner_context`]
/// when no task-local `RUNNER_CTX` is in scope.
pub fn set_thread_runner_context(ctx: RunnerContext) {
    THREAD_RUNNER_CTX.with(|slot| slot.replace(Some(ctx)));
}
/// Removes this thread's fallback runner context, if any.
pub fn clear_thread_runner_context() {
    THREAD_RUNNER_CTX.with(|slot| slot.take());
}
/// Installs `ctx` as this thread's fallback invocation context, replacing any previous one.
///
/// It is used by [`get_invocation_context`] when no task-local `INVOCATION_CTX`
/// is in scope.
pub fn set_thread_invocation_context(ctx: InvocationContext) {
    THREAD_INVOCATION_CTX.with(|slot| slot.replace(Some(ctx)));
}
/// Removes this thread's fallback invocation context, if any.
pub fn clear_thread_invocation_context() {
    THREAD_INVOCATION_CTX.with(|slot| slot.take());
}
pub fn get_invocation_context() -> Option<InvocationContext> {
if let Ok(ctx) = INVOCATION_CTX.try_with(Clone::clone) {
return Some(ctx);
}
THREAD_INVOCATION_CTX.with(|cell| cell.borrow().clone())
}
/// Runs `f` on the active invocation context (same lookup as [`get_invocation_context`]).
///
/// Returns `None` without calling `f` when no context is set. `f` receives a
/// snapshot copy, so it may freely call the `set_thread_*` / `clear_thread_*`
/// functions.
pub fn with_invocation_context<F, R>(f: F) -> Option<R>
where
    F: FnOnce(&InvocationContext) -> R,
{
    let snapshot = get_invocation_context()?;
    Some(f(&snapshot))
}
pub fn get_runner_context() -> Option<RunnerContext> {
RUNNER_CTX.try_with(Clone::clone).ok()
}
/// Runs `f` on a borrow of the task-local `RUNNER_CTX`.
///
/// Returns `None` (without calling `f`) outside its scope. The per-thread
/// fallback is not consulted.
pub fn with_runner_context<F, R>(f: F) -> Option<R>
where
    F: FnOnce(&RunnerContext) -> R,
{
    let outcome = RUNNER_CTX.try_with(f);
    outcome.ok()
}
pub fn get_or_create_runner_id() -> RunnerId {
if let Some(rid) = with_runner_context(|ctx| ctx.runner_id.clone()) {
return rid;
}
if let Some(rid) =
THREAD_RUNNER_CTX.with(|cell| cell.borrow().as_ref().map(|ctx| ctx.runner_id.clone()))
{
return rid;
}
external_runner_id()
}
/// Returns a copy of the active runner context.
///
/// The lookup order is:
/// 1. the task-local `RUNNER_CTX`;
/// 2. this thread's fallback context;
/// 3. a freshly built [`RunnerContext::external`].
pub fn get_or_create_runner_context() -> RunnerContext {
    RUNNER_CTX
        .try_with(RunnerContext::clone)
        .ok()
        .or_else(|| THREAD_RUNNER_CTX.with(|slot| slot.borrow().clone()))
        .unwrap_or_else(RunnerContext::external)
}
/// Builds the `"{hostname}-{pid}"` runner id used when no runner context exists.
///
/// It matches the id produced by [`RunnerContext::external`].
fn external_runner_id() -> RunnerId {
    let label = format!("{}-{}", RunnerContext::get_hostname(), std::process::id());
    RunnerId::from_string(label)
}
#[cfg(test)]
mod tests {
    //! Tests for task-local scoping, the per-thread fallbacks, the `get_or_create_*`
    //! fallbacks, and parent/child runner contexts.
    use super::*;

    fn sample_invocation_ctx() -> InvocationContext {
        let inv_id = InvocationId::from_string("inv-1");
        let task_id = TaskId::new("mod", "my_task");
        InvocationContext {
            invocation_id: inv_id.clone(),
            task_id: task_id.clone(),
            workflow: WorkflowIdentity::root(inv_id, task_id),
            parent_invocation_id: None,
            num_retries: 0,
        }
    }

    fn sample_runner_ctx() -> RunnerContext {
        RunnerContext::new(
            RunnerId::from_string("runner-1"),
            Arc::from("test-app"),
            "TestRunner",
        )
    }

    #[tokio::test]
    async fn context_not_set_outside_scope() {
        assert!(get_invocation_context().is_none());
        assert!(get_runner_context().is_none());
    }

    #[tokio::test]
    async fn invocation_context_set_get() {
        let ctx = sample_invocation_ctx();
        INVOCATION_CTX
            .scope(ctx.clone(), async {
                let got = get_invocation_context().unwrap();
                assert_eq!(got.invocation_id, ctx.invocation_id);
                assert_eq!(got.task_id, ctx.task_id);
                assert!(got.parent_invocation_id.is_none());
            })
            .await;
    }

    #[tokio::test]
    async fn runner_context_set_get() {
        let ctx = sample_runner_ctx();
        RUNNER_CTX
            .scope(ctx, async {
                let got = get_runner_context().unwrap();
                assert_eq!(got.runner_id, RunnerId::from_string("runner-1"));
                assert_eq!(&*got.app_id, "test-app");
            })
            .await;
    }

    #[tokio::test]
    async fn nested_invocation_scopes() {
        let outer = sample_invocation_ctx();
        let inner = InvocationContext {
            invocation_id: InvocationId::from_string("inv-inner"),
            task_id: TaskId::new("mod", "inner_task"),
            workflow: outer.workflow.clone(),
            parent_invocation_id: Some(outer.invocation_id.clone()),
            num_retries: 0,
        };
        INVOCATION_CTX
            .scope(outer.clone(), async {
                assert_eq!(
                    get_invocation_context().unwrap().invocation_id.as_str(),
                    "inv-1"
                );
                INVOCATION_CTX
                    .scope(inner, async {
                        let ctx = get_invocation_context().unwrap();
                        assert_eq!(ctx.invocation_id.as_str(), "inv-inner");
                        assert_eq!(ctx.parent_invocation_id.as_ref().unwrap().as_str(), "inv-1");
                    })
                    .await;
                assert_eq!(
                    get_invocation_context().unwrap().invocation_id.as_str(),
                    "inv-1"
                );
            })
            .await;
    }

    #[tokio::test]
    async fn both_contexts_together() {
        let inv_ctx = sample_invocation_ctx();
        let run_ctx = sample_runner_ctx();
        INVOCATION_CTX
            .scope(
                inv_ctx,
                RUNNER_CTX.scope(run_ctx, async {
                    assert!(get_invocation_context().is_some());
                    assert!(get_runner_context().is_some());
                }),
            )
            .await;
        assert!(get_invocation_context().is_none());
        assert!(get_runner_context().is_none());
    }

    /// The per-thread invocation fallback is visible outside any task-local scope
    /// and disappears once cleared.
    #[test]
    fn thread_invocation_context_fallback() {
        assert!(get_invocation_context().is_none());
        set_thread_invocation_context(sample_invocation_ctx());
        assert_eq!(
            get_invocation_context().unwrap().invocation_id.as_str(),
            "inv-1"
        );
        assert_eq!(
            with_invocation_context(|ctx| ctx.task_id.clone()),
            Some(TaskId::new("mod", "my_task"))
        );
        clear_thread_invocation_context();
        assert!(get_invocation_context().is_none());
    }

    /// A task-local invocation context shadows the per-thread fallback.
    #[tokio::test]
    async fn task_local_invocation_context_takes_precedence() {
        let mut thread_ctx = sample_invocation_ctx();
        thread_ctx.invocation_id = InvocationId::from_string("inv-thread");
        set_thread_invocation_context(thread_ctx);
        INVOCATION_CTX
            .scope(sample_invocation_ctx(), async {
                assert_eq!(
                    get_invocation_context().unwrap().invocation_id.as_str(),
                    "inv-1"
                );
            })
            .await;
        assert_eq!(
            get_invocation_context().unwrap().invocation_id.as_str(),
            "inv-thread"
        );
        clear_thread_invocation_context();
    }

    /// The per-thread runner fallback feeds the `get_or_create_*` helpers.
    #[test]
    fn thread_runner_context_fallback() {
        set_thread_runner_context(sample_runner_ctx());
        assert_eq!(get_or_create_runner_id(), RunnerId::from_string("runner-1"));
        assert_eq!(
            get_or_create_runner_context().runner_id,
            RunnerId::from_string("runner-1")
        );
        clear_thread_runner_context();
        assert_eq!(get_or_create_runner_id(), external_runner_id());
    }

    /// Without any context, the helpers fall back to the external runner identity.
    #[test]
    fn runner_fallbacks_without_context() {
        assert_eq!(get_or_create_runner_id(), external_runner_id());
        let ctx = get_or_create_runner_context();
        assert_eq!(ctx.runner_id, external_runner_id());
        assert_eq!(&*ctx.runner_cls, "ExternalRunner");
        assert_eq!(&*ctx.app_id, "external");
        assert!(ctx.parent_ctx.is_none());
    }

    /// Child contexts inherit shared fields and resolve back to the root runner.
    #[test]
    fn child_context_tracks_root() {
        let root = sample_runner_ctx();
        let child = root.new_child(RunnerId::from_string("runner-2"));
        let grandchild = child.new_child(RunnerId::from_string("runner-3"));
        assert_eq!(grandchild.runner_id, RunnerId::from_string("runner-3"));
        assert_eq!(
            grandchild.parent_ctx.as_ref().unwrap().runner_id,
            RunnerId::from_string("runner-2")
        );
        assert_eq!(grandchild.root_runner_id(), &RunnerId::from_string("runner-1"));
        assert_eq!(root.root_runner_id(), &RunnerId::from_string("runner-1"));
        assert_eq!(&*grandchild.app_id, "test-app");
        assert_eq!(&*grandchild.runner_cls, "TestRunner");
        assert_eq!(grandchild.hostname, root.hostname);
        // Canary: a change to `ThreadId`'s Debug format would make this 0.
        assert_ne!(current_thread_id(), 0);
        assert_eq!(grandchild.thread_id, current_thread_id());
    }
}