use std::collections::HashMap;
use std::sync::Arc;
use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;
use crate::checkpoint::Checkpointer;
use crate::error::Result;
use crate::state::{State, StateSchema};
use crate::stream::StreamEvent;
use super::error::FunctionalError;
use super::execution_log::ExecutionLog;
use super::schema::StateSchemaValidator;
/// Per-run context handed to functional-workflow tasks.
///
/// Bundles the working [`State`] with the services a task uses: checkpoint
/// persistence, stream-event broadcasting, the execution log, cancellation,
/// optional schema handling, per-task iteration counters and a pending route.
pub struct TaskContext {
    /// Thread identifier stamped on every checkpoint this context saves.
    thread_id: String,
    /// Working state: read by `state`/`get`, written by `set`, cloned into checkpoints.
    state: State,
    /// Backend used to persist checkpoints (interrupts and task completion/failure records).
    checkpointer: Arc<dyn Checkpointer>,
    /// Broadcast sender for stream events; `emit` discards send errors.
    event_tx: tokio::sync::broadcast::Sender<StreamEvent>,
    /// Task records and step counter. Held behind `Arc<RwLock<..>>`, so it may be
    /// shared with other holders of the same `Arc` (callers not visible here).
    execution_log: Arc<RwLock<ExecutionLog>>,
    /// Cancellation signal exposed through `cancel_token` / `is_cancelled`.
    cancel_token: CancellationToken,
    /// When present, `set` routes writes through `StateSchema::apply_update`
    /// instead of a plain insert.
    schema: Option<StateSchema>,
    /// Optional validator used by `validate_state` / `validate_task_output`;
    /// `None` until `with_schema_validator` attaches one.
    schema_validator: Option<StateSchemaValidator>,
    /// Per-task-name counters backing `iteration_key`; each value is the index
    /// the next generated key for that task will use.
    iteration_counters: HashMap<String, usize>,
    /// Targets recorded by `route_to`, taken (and cleared) by `take_pending_route`.
    pending_route: Option<Vec<String>>,
}
impl TaskContext {
    /// Creates a context for one run of a functional workflow on `thread_id`.
    ///
    /// No schema validator is attached (see [`Self::with_schema_validator`]),
    /// iteration counters start empty and no route is pending.
    pub fn new(
        thread_id: String,
        state: State,
        checkpointer: Arc<dyn Checkpointer>,
        event_tx: tokio::sync::broadcast::Sender<StreamEvent>,
        execution_log: Arc<RwLock<ExecutionLog>>,
        cancel_token: CancellationToken,
        schema: Option<StateSchema>,
    ) -> Self {
        Self {
            thread_id,
            state,
            checkpointer,
            event_tx,
            execution_log,
            cancel_token,
            schema,
            schema_validator: None,
            iteration_counters: HashMap::new(),
            pending_route: None,
        }
    }

    /// Returns the current working state.
    pub fn state(&self) -> &State {
        &self.state
    }

    /// Reads `key` from the state and deserializes it into `T`.
    ///
    /// Returns `None` both when the key is absent and when the stored value
    /// cannot be deserialized into `T`.
    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
        self.state.get(key).and_then(|v| serde_json::from_value(v.clone()).ok())
    }

    /// Writes `value` under `key`.
    ///
    /// With a schema configured, the write is delegated to
    /// `StateSchema::apply_update` (how it merges values is defined there);
    /// otherwise the key is simply overwritten.
    pub fn set(&mut self, key: &str, value: impl Into<Value>) {
        let value = value.into();
        if let Some(schema) = &self.schema {
            schema.apply_update(&mut self.state, key, value);
        } else {
            self.state.insert(key.to_string(), value);
        }
    }

    /// Broadcasts `event` to stream subscribers.
    ///
    /// The send result is deliberately discarded so emitting never fails
    /// (tokio's broadcast `send` errors when there are no active receivers).
    pub fn emit(&self, event: StreamEvent) {
        let _ = self.event_tx.send(event);
    }

    /// Suspends the workflow and persists a checkpoint so it can be resumed.
    ///
    /// Emits an `interrupted` stream event, saves a checkpoint of the current
    /// state at the current step (metadata `interrupt_message`), and records an
    /// `__interrupt__` entry with status `Interrupted` in the execution log if
    /// no such entry exists yet.
    ///
    /// # Errors
    ///
    /// Always returns an error and never produces a `T`: `CheckpointFailed`
    /// when the checkpoint cannot be saved, otherwise `InterruptTypeMismatch`
    /// carrying the message (presumably so `?` propagates the interrupt out
    /// of the task — confirm against the runner).
    pub async fn interrupt<T: DeserializeOwned>(&self, message: &str) -> Result<T> {
        self.emit(StreamEvent::interrupted("functional_task", message));
        let checkpoint = crate::state::Checkpoint::new(
            &self.thread_id,
            self.state.clone(),
            self.current_step().await,
            vec![],
        )
        .with_metadata("interrupt_message", Value::String(message.to_string()));
        self.checkpointer.save(&checkpoint).await.map_err(|e| {
            FunctionalError::CheckpointFailed {
                task: "interrupt".to_string(),
                message: e.to_string(),
            }
        })?;
        {
            let mut log = self.execution_log.write().await;
            // Lazily build the record: only construct it (and read the clock)
            // when no `__interrupt__` entry exists yet.
            log.tasks.entry("__interrupt__".to_string()).or_insert_with(|| {
                super::execution_log::TaskRecord {
                    status: super::execution_log::TaskStatus::Interrupted,
                    result: None,
                    error: None,
                    started_at: chrono::Utc::now().to_rfc3339(),
                    completed_at: None,
                    attempt: 1,
                }
            });
        }
        Err(FunctionalError::InterruptTypeMismatch {
            task: "interrupt".to_string(),
            message: format!("workflow interrupted: {message}"),
        }
        .into())
    }

    /// Returns the thread identifier this context checkpoints under.
    pub fn thread_id(&self) -> &str {
        &self.thread_id
    }

    /// Returns the cancellation token for this run.
    pub fn cancel_token(&self) -> &CancellationToken {
        &self.cancel_token
    }

    /// Returns `true` once the run's cancellation token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.cancel_token.is_cancelled()
    }

    /// Returns the execution log's current step (awaits a read lock).
    pub async fn current_step(&self) -> usize {
        self.execution_log.read().await.current_step()
    }

    /// Attaches `validator`, replacing any previously attached one.
    pub fn with_schema_validator(mut self, validator: StateSchemaValidator) -> Self {
        self.schema_validator = Some(validator);
        self
    }

    /// Returns the attached schema validator, if any.
    pub fn schema_validator(&self) -> Option<&StateSchemaValidator> {
        self.schema_validator.as_ref()
    }

    /// Validates the whole working state against the attached validator.
    ///
    /// # Errors
    ///
    /// Propagates the validator's error; always `Ok(())` when no validator is attached.
    pub fn validate_state(&self) -> std::result::Result<(), FunctionalError> {
        if let Some(validator) = &self.schema_validator {
            validator.validate_state(&self.state)?;
        }
        Ok(())
    }

    /// Validates a task's `output` against the attached validator.
    ///
    /// # Errors
    ///
    /// Propagates the validator's error; always `Ok(())` when no validator is attached.
    pub fn validate_task_output(&self, output: &State) -> std::result::Result<(), FunctionalError> {
        if let Some(validator) = &self.schema_validator {
            validator.validate_task_output(output)?;
        }
        Ok(())
    }

    /// Returns a unique key `"{task_name}::iter_{n}"` and advances the counter.
    ///
    /// `n` starts at 0 for each task name and increments on every call until
    /// reset with [`Self::reset_iteration`] or [`Self::reset_all_iterations`].
    pub fn iteration_key(&mut self, task_name: &str) -> String {
        let counter = self.iteration_counters.entry(task_name.to_string()).or_insert(0);
        let key = format!("{task_name}::iter_{counter}");
        *counter += 1;
        key
    }

    /// Returns the counter for `task_name`: the number of keys handed out so
    /// far, i.e. the index the next [`Self::iteration_key`] call will use.
    /// `None` if no key was generated since the last reset.
    pub fn current_iteration(&self, task_name: &str) -> Option<usize> {
        self.iteration_counters.get(task_name).copied()
    }

    /// Forgets the iteration counter for `task_name`, restarting it at 0.
    pub fn reset_iteration(&mut self, task_name: &str) {
        self.iteration_counters.remove(task_name);
    }

    /// Forgets all iteration counters.
    pub fn reset_all_iterations(&mut self) {
        self.iteration_counters.clear();
    }

    /// Records `targets` as the pending route, replacing any earlier one.
    pub fn route_to(&mut self, targets: &[&str]) {
        self.pending_route = Some(targets.iter().map(|s| s.to_string()).collect());
    }

    /// Takes the pending route, leaving none behind.
    pub fn take_pending_route(&mut self) -> Option<Vec<String>> {
        self.pending_route.take()
    }

    /// Non-blocking check whether `task_id` is recorded as completed.
    ///
    /// Uses `try_read`, so it returns `false` whenever the log is currently
    /// write-locked — a possible false negative. Prefer
    /// [`Self::is_completed_async`] from async code.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub fn is_completed(&self, task_id: &str) -> bool {
        match self.execution_log.try_read() {
            Ok(log) => log.is_completed(task_id),
            Err(_) => false,
        }
    }

    /// Returns whether `task_id` is recorded as completed (awaits a read lock).
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn is_completed_async(&self, task_id: &str) -> bool {
        self.execution_log.read().await.is_completed(task_id)
    }

    /// Returns a clone of the result recorded for `task_id`, if any.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn get_cached_result(&self, task_id: &str) -> Option<Value> {
        self.execution_log.read().await.get_result(task_id).cloned()
    }

    /// Records `task_id` as completed with `result`, advances the step, and
    /// saves a checkpoint (metadata `completed_task`, `execution_log`).
    ///
    /// # Errors
    ///
    /// `CheckpointFailed` when the checkpointer cannot save.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn record_completion(&self, task_id: &str, result: &Value) -> Result<()> {
        let (step, log_snapshot) = {
            let mut log = self.execution_log.write().await;
            log.record_completion(task_id, result.clone());
            log.advance_step();
            // Read the step and snapshot under the same guard as the update so a
            // concurrent writer cannot slip in and make the checkpoint describe a
            // different log state than the one this task produced.
            (log.current_step(), Self::log_snapshot(&log))
        };
        self.save_task_checkpoint(
            task_id,
            step,
            vec![
                ("completed_task", Value::String(task_id.to_string())),
                ("execution_log", log_snapshot),
            ],
        )
        .await
    }

    /// Records `task_id` as failed with `error` (the step is not advanced) and
    /// saves a checkpoint (metadata `failed_task`, `error`, `execution_log`).
    ///
    /// # Errors
    ///
    /// `CheckpointFailed` when the checkpointer cannot save.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn record_failure(&self, task_id: &str, error: &str) -> Result<()> {
        let (step, log_snapshot) = {
            let mut log = self.execution_log.write().await;
            log.record_failure(task_id, error);
            // Same-guard read keeps step and snapshot consistent with this update.
            (log.current_step(), Self::log_snapshot(&log))
        };
        self.save_task_checkpoint(
            task_id,
            step,
            vec![
                ("failed_task", Value::String(task_id.to_string())),
                ("error", Value::String(error.to_string())),
                ("execution_log", log_snapshot),
            ],
        )
        .await
    }

    /// Records that `task_id` has started (awaits the write lock).
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn record_start(&self, task_id: &str) {
        let mut log = self.execution_log.write().await;
        log.record_start(task_id);
    }

    /// Serializes `log` for embedding in checkpoint metadata.
    ///
    /// Best-effort: a serialization failure yields `Value::Null` instead of
    /// failing the checkpoint.
    // NOTE(review): a Null log may hamper resume; consider surfacing this error.
    fn log_snapshot(log: &ExecutionLog) -> Value {
        serde_json::to_value(log).unwrap_or(Value::Null)
    }

    /// Builds a checkpoint of the current state at `step`, attaches `metadata`
    /// in order, and saves it, mapping a backend failure to `CheckpointFailed`
    /// attributed to `task_id`.
    async fn save_task_checkpoint(
        &self,
        task_id: &str,
        step: usize,
        metadata: Vec<(&'static str, Value)>,
    ) -> Result<()> {
        let checkpoint = metadata.into_iter().fold(
            crate::state::Checkpoint::new(&self.thread_id, self.state.clone(), step, vec![]),
            |checkpoint, (key, value)| checkpoint.with_metadata(key, value),
        );
        self.checkpointer.save(&checkpoint).await.map_err(|e| {
            FunctionalError::CheckpointFailed { task: task_id.to_string(), message: e.to_string() }
        })?;
        Ok(())
    }
}