use std::sync::Arc;
use tokio::time::Instant;
use crate::audit::AuditSinkHandle;
use crate::cancellation::CancellationToken;
use crate::extensions::Extensions;
use crate::tenant_id::TenantId;
use crate::tools::{
CurrentToolInvocation, ToolProgress, ToolProgressSinkHandle, ToolProgressStatus,
};
/// Per-invocation execution context: a cancellation token, an optional deadline,
/// identity/correlation identifiers, an optional idempotency key, and a typed
/// extension map.
///
/// Built with `ExecutionContext::new` / `ExecutionContext::with_cancellation` plus
/// the `with_*` builder methods; `child()` derives a context that shares everything
/// except the cancellation token, which becomes a child token.
#[derive(Clone, Debug)]
// Every field is private, so downstream crates already cannot construct or
// exhaustively destructure this struct; the attribute makes that intent explicit.
#[non_exhaustive]
pub struct ExecutionContext {
    // Token checked by `is_cancelled`; `child()` derives a child token from it.
    cancellation: CancellationToken,
    // Optional deadline. This type only stores and propagates it; nothing here enforces it.
    deadline: Option<Instant>,
    // Optional thread identifier.
    thread_id: Option<String>,
    // Tenant the work runs under; `TenantId::default()` unless overridden.
    tenant_id: TenantId,
    // Optional run identifier; also copied into `ToolProgress::run_id` when reporting phases.
    run_id: Option<String>,
    // Optional parent run identifier (presumably the run that spawned this one — confirm with callers).
    parent_run_id: Option<String>,
    // Optional idempotency key; held as `Arc<str>` so clones/children share one allocation.
    idempotency_key: Option<Arc<str>>,
    // Values looked up by type (run budget, audit/progress sink handles, current tool invocation, ...).
    extensions: Extensions,
}
impl Default for ExecutionContext {
fn default() -> Self {
Self::new()
}
}
impl ExecutionContext {
    /// Creates a root context with a fresh [`CancellationToken`] and no other state.
    ///
    /// No deadline, thread id, run id, parent run id or idempotency key is set, the
    /// tenant is `TenantId::default()`, and the extension map is empty.
    #[must_use]
    pub fn new() -> Self {
        // Delegate so the "empty" field values are spelled out in exactly one place.
        Self::with_cancellation(CancellationToken::new())
    }

    /// Creates a root context that observes the supplied cancellation token.
    ///
    /// All other fields take the same empty/default values as [`Self::new`].
    #[must_use]
    pub fn with_cancellation(cancellation: CancellationToken) -> Self {
        Self {
            cancellation,
            deadline: None,
            thread_id: None,
            tenant_id: TenantId::default(),
            run_id: None,
            parent_run_id: None,
            idempotency_key: None,
            extensions: Extensions::new(),
        }
    }

    /// Sets the deadline, replacing any previous one.
    ///
    /// The context only stores the deadline and copies it into [`child`](Self::child)
    /// contexts; nothing in this type enforces it.
    #[must_use]
    pub const fn with_deadline(mut self, deadline: Instant) -> Self {
        self.deadline = Some(deadline);
        self
    }

    /// Sets the thread identifier, replacing any previous one.
    #[must_use]
    pub fn with_thread_id(mut self, thread_id: impl Into<String>) -> Self {
        self.thread_id = Some(thread_id.into());
        self
    }

    /// Sets the tenant, replacing the default (or previously set) tenant.
    #[must_use]
    pub fn with_tenant_id(mut self, tenant_id: TenantId) -> Self {
        self.tenant_id = tenant_id;
        self
    }

    /// Sets the run identifier, replacing any previous one.
    #[must_use]
    pub fn with_run_id(mut self, run_id: impl Into<String>) -> Self {
        self.run_id = Some(run_id.into());
        self
    }

    /// Sets the parent run identifier, replacing any previous one.
    #[must_use]
    pub fn with_parent_run_id(mut self, parent_run_id: impl Into<String>) -> Self {
        self.parent_run_id = Some(parent_run_id.into());
        self
    }

    /// Attaches a [`crate::RunBudget`] as a typed extension.
    ///
    /// Shorthand for `add_extension(budget)`; read it back with [`Self::run_budget`].
    #[must_use]
    pub fn with_run_budget(self, budget: crate::RunBudget) -> Self {
        self.add_extension(budget)
    }

    /// Returns the [`crate::RunBudget`] extension, if one was attached.
    #[must_use]
    pub fn run_budget(&self) -> Option<Arc<crate::RunBudget>> {
        self.extension::<crate::RunBudget>()
    }

    /// Sets the idempotency key, replacing any previous one.
    ///
    /// The key is stored as `Arc<str>`, so cloning the context or deriving a
    /// [`child`](Self::child) shares it rather than copying the string.
    #[must_use]
    pub fn with_idempotency_key(mut self, key: impl Into<String>) -> Self {
        self.idempotency_key = Some(Arc::from(key.into()));
        self
    }

    /// Returns the idempotency key, generating and storing one first if none is set.
    ///
    /// `generate` is called at most once, and only when no key is present; an
    /// existing key is returned unchanged and `generate` is dropped uncalled.
    pub fn ensure_idempotency_key<F>(&mut self, generate: F) -> &str
    where
        F: FnOnce() -> String,
    {
        // `get_or_insert_with` guarantees the option is populated afterwards, so no
        // fallback for a "missing" key is needed.
        let key = self
            .idempotency_key
            .get_or_insert_with(|| Arc::from(generate()));
        &**key
    }

    /// Attaches an [`AuditSinkHandle`] as a typed extension.
    #[must_use]
    pub fn with_audit_sink(self, handle: AuditSinkHandle) -> Self {
        self.add_extension(handle)
    }

    /// Returns the [`AuditSinkHandle`] extension, if one was attached.
    #[must_use]
    pub fn audit_sink(&self) -> Option<Arc<AuditSinkHandle>> {
        self.extension::<AuditSinkHandle>()
    }

    /// Attaches a [`ToolProgressSinkHandle`] as a typed extension.
    #[must_use]
    pub fn with_tool_progress_sink(self, handle: ToolProgressSinkHandle) -> Self {
        self.add_extension(handle)
    }

    /// Returns the [`ToolProgressSinkHandle`] extension, if one was attached.
    #[must_use]
    pub fn tool_progress_sink(&self) -> Option<Arc<ToolProgressSinkHandle>> {
        self.extension::<ToolProgressSinkHandle>()
    }

    /// Reports a tool phase transition without metadata.
    ///
    /// Equivalent to [`record_phase_with`](Self::record_phase_with) with
    /// `serde_json::Value::Null` as the metadata.
    pub async fn record_phase(&self, phase: impl Into<String> + Send, status: ToolProgressStatus) {
        self.record_phase_with(phase, status, serde_json::Value::Null)
            .await;
    }

    /// Reports a tool phase transition with attached metadata.
    ///
    /// This is a silent no-op unless the context carries both a
    /// [`ToolProgressSinkHandle`] and a [`CurrentToolInvocation`] extension.
    /// Otherwise it builds a [`ToolProgress`] from the current invocation (tool use
    /// id, tool name, dispatch-elapsed milliseconds), this context's run id (an empty
    /// string when unset) and the given phase/status/metadata. It then awaits the
    /// sink's `record_progress`.
    pub async fn record_phase_with(
        &self,
        phase: impl Into<String> + Send,
        status: ToolProgressStatus,
        metadata: serde_json::Value,
    ) {
        // Progress only makes sense inside a tool dispatch with a sink wired up.
        let Some(sink) = self.tool_progress_sink() else {
            return;
        };
        let Some(current) = self.extension::<CurrentToolInvocation>() else {
            return;
        };
        let progress = ToolProgress {
            run_id: self.run_id().map(str::to_owned).unwrap_or_default(),
            tool_use_id: current.tool_use_id().to_owned(),
            tool_name: current.tool_name().to_owned(),
            phase: phase.into(),
            status,
            dispatch_elapsed_ms: current.dispatch_elapsed_ms(),
            metadata,
        };
        sink.inner().record_progress(progress).await;
    }

    /// Returns a context with `value` added to the extension map.
    ///
    /// Consumes `self`; clones taken before the call do not see the new value.
    /// Whether an existing value of the same type is replaced is decided by
    /// `Extensions::inserted` (presumably it is — confirm there).
    #[must_use]
    pub fn add_extension<T>(mut self, value: T) -> Self
    where
        T: Send + Sync + 'static,
    {
        self.extensions = self.extensions.inserted(value);
        self
    }

    /// Returns the cancellation token this context observes.
    pub const fn cancellation(&self) -> &CancellationToken {
        &self.cancellation
    }

    /// Returns the deadline, if one was set.
    pub const fn deadline(&self) -> Option<Instant> {
        self.deadline
    }

    /// Returns the thread identifier, if one was set.
    pub fn thread_id(&self) -> Option<&str> {
        self.thread_id.as_deref()
    }

    /// Returns the tenant (`TenantId::default()` unless overridden).
    pub const fn tenant_id(&self) -> &TenantId {
        &self.tenant_id
    }

    /// Returns the run identifier, if one was set.
    pub fn run_id(&self) -> Option<&str> {
        self.run_id.as_deref()
    }

    /// Returns the parent run identifier, if one was set.
    pub fn parent_run_id(&self) -> Option<&str> {
        self.parent_run_id.as_deref()
    }

    /// Returns the idempotency key, if one is set.
    ///
    /// Use [`ensure_idempotency_key`](Self::ensure_idempotency_key) to create one lazily.
    pub fn idempotency_key(&self) -> Option<&str> {
        self.idempotency_key.as_deref()
    }

    /// Returns the whole extension map.
    pub const fn extensions(&self) -> &Extensions {
        &self.extensions
    }

    /// Looks up the extension of type `T`.
    ///
    /// The returned `Arc` is an independent handle and stays valid after the context
    /// is dropped.
    #[must_use]
    pub fn extension<T>(&self) -> Option<Arc<T>>
    where
        T: Send + Sync + 'static,
    {
        self.extensions.get::<T>()
    }

    /// Reports whether this context's cancellation token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.cancellation.is_cancelled()
    }

    /// Derives a child context.
    ///
    /// The child's token is `self.cancellation.child_token()`. Every other field is
    /// copied from `self`: deadline, thread/tenant/run/parent-run ids, idempotency key
    /// and extensions. `parent_run_id` is copied verbatim; it is not set to this
    /// context's `run_id`. How cancellation propagates between the two tokens is
    /// defined by `CancellationToken::child_token` (presumably parent → child only;
    /// confirm).
    #[must_use]
    pub fn child(&self) -> Self {
        Self {
            cancellation: self.cancellation.child_token(),
            deadline: self.deadline,
            thread_id: self.thread_id.clone(),
            tenant_id: self.tenant_id.clone(),
            run_id: self.run_id.clone(),
            parent_run_id: self.parent_run_id.clone(),
            idempotency_key: self.idempotency_key.clone(),
            extensions: self.extensions.clone(),
        }
    }
}
// Behavioural tests for the typed-extension surface of `ExecutionContext`.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod extension_tests {
    use super::*;

    // Minimal stand-in for a caller-defined extension payload.
    #[derive(Debug, PartialEq, Eq)]
    struct WorkspaceCtx {
        repo: &'static str,
    }

    // The fixture value shared by every test below.
    fn workspace() -> WorkspaceCtx {
        WorkspaceCtx { repo: "entelix" }
    }

    #[test]
    fn fresh_context_has_no_extensions() {
        let fresh = ExecutionContext::new();
        assert!(fresh.extensions().is_empty());
        assert_eq!(fresh.extension::<WorkspaceCtx>().as_deref(), None);
    }

    #[test]
    fn add_extension_threads_typed_value() {
        let with_ws = ExecutionContext::new().add_extension(workspace());
        assert_eq!(
            with_ws.extension::<WorkspaceCtx>().as_deref(),
            Some(&workspace())
        );
        assert_eq!(with_ws.extensions().len(), 1);
    }

    #[test]
    fn add_extension_is_copy_on_write() {
        let base = ExecutionContext::new();
        let derived = base.clone().add_extension(workspace());
        assert!(derived.extension::<WorkspaceCtx>().is_some());
        assert!(base.extension::<WorkspaceCtx>().is_none());
    }

    #[test]
    fn child_inherits_extensions() {
        let parent = ExecutionContext::new().add_extension(workspace());
        let inherited = parent.child();
        assert_eq!(
            inherited.extension::<WorkspaceCtx>().as_deref(),
            Some(&workspace())
        );
    }

    #[test]
    fn extension_arc_outlives_dropped_context() {
        let owner = ExecutionContext::new().add_extension(workspace());
        let held = owner.extension::<WorkspaceCtx>().unwrap();
        drop(owner);
        assert_eq!(*held, workspace());
    }
}