adk-graph 1.0.0

Graph-based workflow orchestration for ADK-Rust agents
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
//! TaskContext — runtime context for functional API tasks.
//!
//! Provides access to state, checkpointing, interrupts, and streaming.
//! Passed to `#[entrypoint]` and `#[task]` annotated functions.

use std::collections::HashMap;
use std::sync::Arc;

use serde::de::DeserializeOwned;
use serde_json::Value;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;

use crate::checkpoint::Checkpointer;
use crate::error::Result;
use crate::state::{State, StateSchema};
use crate::stream::StreamEvent;

use super::error::FunctionalError;
use super::execution_log::ExecutionLog;
use super::schema::StateSchemaValidator;

/// Runtime context passed to `#[entrypoint]` and `#[task]` functions.
///
/// Bundles everything a functional task needs at runtime:
///
/// - a working copy of the workflow [`State`] (read with `get`/`state`, written with `set`);
/// - a [`Checkpointer`] used to persist interrupt, completion and failure checkpoints;
/// - a broadcast channel for progress events (`emit`);
/// - a shared execution log used for step numbering and resume-skip;
/// - a cancellation token, optional schema/validator, loop-iteration counters
///   and dynamic routing targets.
///
/// Task functions receive a mutable reference to `TaskContext`, which lets
/// them read and write state, emit events, and request interrupts.
///
/// # Example
///
/// ```rust,ignore
/// use adk_graph::functional::TaskContext;
///
/// #[task]
/// async fn my_step(ctx: &mut TaskContext) -> Result<Value> {
///     // Read state
///     let count: i64 = ctx.get("counter").unwrap_or(0);
///
///     // Write state
///     ctx.set("counter", serde_json::json!(count + 1));
///
///     // Emit progress
///     ctx.emit(StreamEvent::custom("my_step", "progress", serde_json::json!({"count": count + 1})));
///
///     Ok(serde_json::json!({"new_count": count + 1}))
/// }
/// ```
pub struct TaskContext {
    /// Thread identifier; every checkpoint saved by this context is created with it.
    thread_id: String,
    /// Working copy of the workflow state. A clone of it is stored in every
    /// checkpoint this context saves.
    state: State,
    /// Persistence backend for interrupt, completion and failure checkpoints.
    checkpointer: Arc<dyn Checkpointer>,
    /// Broadcast sender used by `emit`; send errors (no receivers) are ignored.
    event_tx: tokio::sync::broadcast::Sender<StreamEvent>,
    /// Execution log behind an async `RwLock`, held via `Arc` so it can be
    /// shared (presumably with the executor / sibling tasks — confirm).
    /// Provides step numbering and completed-task lookup for skip-on-resume.
    execution_log: Arc<RwLock<ExecutionLog>>,
    /// Cancellation token exposed via `cancel_token` / `is_cancelled`.
    cancel_token: CancellationToken,
    /// Optional state schema. When present, `set` applies updates through the
    /// schema's reducers instead of overwriting the key.
    schema: Option<StateSchema>,
    /// Optional validator, `None` after `new`; installed with
    /// `with_schema_validator` and used by `validate_state` /
    /// `validate_task_output`.
    schema_validator: Option<StateSchemaValidator>,
    /// Per-task loop counters for `iteration_key`.
    /// Maps task_name -> index the next generated key will use.
    iteration_counters: HashMap<String, usize>,
    /// Dynamic route targets recorded by `route_to()`.
    /// Taken by the executor (`take_pending_route`) to populate `EventActions.route`.
    pending_route: Option<Vec<String>>,
}

impl TaskContext {
    /// Create a new `TaskContext`.
    ///
    /// Typically constructed by the macro-generated entrypoint, not by user code.
    /// The context starts with no schema validator (see
    /// [`Self::with_schema_validator`]), empty iteration counters, and no
    /// pending route.
    pub fn new(
        thread_id: String,
        state: State,
        checkpointer: Arc<dyn Checkpointer>,
        event_tx: tokio::sync::broadcast::Sender<StreamEvent>,
        execution_log: Arc<RwLock<ExecutionLog>>,
        cancel_token: CancellationToken,
        schema: Option<StateSchema>,
    ) -> Self {
        Self {
            thread_id,
            state,
            checkpointer,
            event_tx,
            execution_log,
            cancel_token,
            schema,
            schema_validator: None,
            iteration_counters: HashMap::new(),
            pending_route: None,
        }
    }

    // ─── Public Methods ──────────────────────────────────────────────────

    /// Get the current workflow state (read-only).
    pub fn state(&self) -> &State {
        &self.state
    }

    /// Get a typed value from state.
    ///
    /// Attempts to deserialize the value stored at `key` into type `T`.
    /// Returns `None` if the key does not exist or deserialization fails.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let count: Option<i64> = ctx.get("counter");
    /// ```
    pub fn get<T: DeserializeOwned>(&self, key: &str) -> Option<T> {
        // `&Value` is itself a serde `Deserializer`, so deserialize from the
        // borrow instead of cloning the (possibly large) stored value first.
        self.state.get(key).and_then(|v| T::deserialize(v).ok())
    }

    /// Set a value in state.
    ///
    /// If a [`StateSchema`] is configured, the update is applied using the
    /// appropriate reducer for the key. Otherwise the value is set directly
    /// (overwrite semantics).
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// ctx.set("counter", serde_json::json!(42));
    /// ```
    pub fn set(&mut self, key: &str, value: impl Into<Value>) {
        let value = value.into();
        if let Some(schema) = &self.schema {
            schema.apply_update(&mut self.state, key, value);
        } else {
            self.state.insert(key.to_string(), value);
        }
    }

    /// Emit a progress event to stream listeners.
    ///
    /// Events are broadcast to all registered receivers. If no listeners
    /// are active the event is silently dropped.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// ctx.emit(StreamEvent::custom("my_task", "progress", json!({"pct": 50})));
    /// ```
    pub fn emit(&self, event: StreamEvent) {
        // Ignore send errors — they indicate no active receivers.
        let _ = self.event_tx.send(event);
    }

    /// Request an interrupt (human-in-the-loop pause).
    ///
    /// In order, this:
    /// 1. emits an interrupted event (`"functional_task"`) to stream listeners;
    /// 2. persists the current state as a checkpoint at the current step,
    ///    tagged with `interrupt_message` metadata;
    /// 3. records an `Interrupted` entry under the fixed key `"__interrupt__"`
    ///    in the execution log (an existing entry under that key is kept).
    ///
    /// This method does not suspend or resume by itself. Once the checkpoint
    /// is saved it always returns an error so the caller unwinds. Producing a
    /// resume value of type `T` is left to the macro-generated wrapper.
    ///
    /// # Errors
    ///
    /// Returns [`FunctionalError::CheckpointFailed`] if persisting the
    /// interrupt checkpoint fails. In that case the execution log is not
    /// modified, but the interrupted event has already been emitted.
    ///
    /// Otherwise always returns [`FunctionalError::InterruptTypeMismatch`]
    /// with the message `"workflow interrupted: <message>"`, used here as the
    /// interrupt signal.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let approval: bool = ctx.interrupt("Please approve this action").await?;
    /// ```
    pub async fn interrupt<T: DeserializeOwned>(&self, message: &str) -> Result<T> {
        // Emit the interrupt event for stream listeners.
        self.emit(StreamEvent::interrupted("functional_task", message));

        // Persist the interrupt checkpoint.
        let checkpoint = crate::state::Checkpoint::new(
            &self.thread_id,
            self.state.clone(),
            self.current_step().await,
            vec![],
        )
        .with_metadata("interrupt_message", Value::String(message.to_string()));

        self.checkpointer.save(&checkpoint).await.map_err(|e| {
            FunctionalError::CheckpointFailed {
                task: "interrupt".to_string(),
                message: e.to_string(),
            }
        })?;

        // Record the interrupt in the execution log under a reserved key.
        // `or_insert_with` only builds the record (and its timestamp) when
        // the key is not present yet.
        {
            let mut log = self.execution_log.write().await;
            log.tasks.entry("__interrupt__".to_string()).or_insert_with(|| {
                super::execution_log::TaskRecord {
                    status: super::execution_log::TaskStatus::Interrupted,
                    result: None,
                    error: None,
                    started_at: chrono::Utc::now().to_rfc3339(),
                    completed_at: None,
                    attempt: 1,
                }
            });
        }

        // In a real runtime the workflow executor would suspend here and
        // later provide the resume value. For now we return an error
        // indicating the interrupt was requested — the macro-generated
        // wrapper handles actual suspension/resumption.
        Err(FunctionalError::InterruptTypeMismatch {
            task: "interrupt".to_string(),
            message: format!("workflow interrupted: {message}"),
        }
        .into())
    }

    /// Get the thread identifier for this context.
    pub fn thread_id(&self) -> &str {
        &self.thread_id
    }

    /// Get a reference to the cancellation token.
    pub fn cancel_token(&self) -> &CancellationToken {
        &self.cancel_token
    }

    /// Check if the workflow has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.cancel_token.is_cancelled()
    }

    /// Get the current step number from the execution log.
    ///
    /// Awaits a read lock on the shared execution log.
    pub async fn current_step(&self) -> usize {
        self.execution_log.read().await.current_step()
    }

    /// Set a [`StateSchemaValidator`] for this context.
    ///
    /// When set, the validator is used to validate initial state at
    /// workflow start and task output before applying reducers.
    #[must_use]
    pub fn with_schema_validator(mut self, validator: StateSchemaValidator) -> Self {
        self.schema_validator = Some(validator);
        self
    }

    /// Get the schema validator, if configured.
    pub fn schema_validator(&self) -> Option<&StateSchemaValidator> {
        self.schema_validator.as_ref()
    }

    /// Validate the current state against the schema validator.
    ///
    /// Called at workflow start to validate initial state. A no-op (`Ok`)
    /// when no validator is configured.
    ///
    /// # Errors
    ///
    /// Returns [`FunctionalError::SchemaValidation`] if validation fails.
    pub fn validate_state(&self) -> std::result::Result<(), FunctionalError> {
        if let Some(validator) = &self.schema_validator {
            validator.validate_state(&self.state)?;
        }
        Ok(())
    }

    /// Validate task output against the schema validator.
    ///
    /// Called after a task produces output, before applying reducers. A
    /// no-op (`Ok`) when no validator is configured.
    ///
    /// # Errors
    ///
    /// Returns [`FunctionalError::SchemaValidation`] if validation fails.
    pub fn validate_task_output(&self, output: &State) -> std::result::Result<(), FunctionalError> {
        if let Some(validator) = &self.schema_validator {
            validator.validate_task_output(output)?;
        }
        Ok(())
    }

    // ─── Loop Iteration Checkpoint Keying ────────────────────────────────

    /// Generate a unique checkpoint key for a task inside a loop.
    ///
    /// Each call to this method for the same `task_name` increments the
    /// iteration counter, producing keys like `"step_a::iter_0"`,
    /// `"step_a::iter_1"`, etc. Keys are deterministic from task name
    /// and iteration index.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// for item in items {
    ///     let key = ctx.iteration_key("process_item");
    ///     // key = "process_item::iter_0", "process_item::iter_1", ...
    /// }
    /// ```
    pub fn iteration_key(&mut self, task_name: &str) -> String {
        let counter = self.iteration_counters.entry(task_name.to_string()).or_insert(0);
        let key = format!("{task_name}::iter_{counter}");
        *counter += 1;
        key
    }

    /// Get the iteration counter for a task without incrementing it.
    ///
    /// The value is the index the *next* [`Self::iteration_key`] call will
    /// use, which equals the number of keys already generated for the task.
    /// For example, after one call that produced `"task::iter_0"`, this
    /// returns `Some(1)`.
    ///
    /// Returns `None` if `iteration_key` has not been called for the task
    /// (or its counter was reset).
    pub fn current_iteration(&self, task_name: &str) -> Option<usize> {
        self.iteration_counters.get(task_name).copied()
    }

    /// Reset the iteration counter for a task.
    ///
    /// Useful when re-entering a loop (e.g., nested loops or retry).
    pub fn reset_iteration(&mut self, task_name: &str) {
        self.iteration_counters.remove(task_name);
    }

    /// Reset all iteration counters.
    pub fn reset_all_iterations(&mut self) {
        self.iteration_counters.clear();
    }

    // ─── Route Dispatch ──────────────────────────────────────────────────

    /// Set dynamic route targets for this task's output.
    ///
    /// The execution framework will dispatch to these named tasks
    /// instead of following the declared order. The pending route is
    /// consumed by the executor after the task completes and used to
    /// populate `EventActions.route`. Calling this again replaces any
    /// previously set targets.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// ctx.route_to(&["process_a", "process_b"]);
    /// ```
    pub fn route_to(&mut self, targets: &[&str]) {
        self.pending_route = Some(targets.iter().map(|s| s.to_string()).collect());
    }

    /// Consume and return the pending route targets, if any.
    ///
    /// Used by the executor to retrieve the route set by `route_to()`
    /// and clear the pending state.
    pub fn take_pending_route(&mut self) -> Option<Vec<String>> {
        self.pending_route.take()
    }

    // ─── Internal Methods (pub(crate)) ───────────────────────────────────
    // These methods are used by the macro-generated task wrappers
    // (`#[entrypoint]` and `#[task]`), not directly by user code.

    /// Check if a task was already completed in a prior run (for resume-skip).
    ///
    /// Uses `try_read()` for synchronous non-blocking access. If the lock
    /// is held, conservatively returns `false` (the task will re-execute).
    /// For reliable resume-skip in async contexts, prefer [`Self::is_completed_async`].
    #[allow(dead_code)]
    #[doc(hidden)]
    pub fn is_completed(&self, task_id: &str) -> bool {
        // We need synchronous access here; use try_read to avoid blocking.
        // If the lock is held, conservatively return false (task will re-execute).
        match self.execution_log.try_read() {
            Ok(log) => log.is_completed(task_id),
            Err(_) => false,
        }
    }

    /// Async version of [`Self::is_completed`] for reliable resume-skip behavior.
    ///
    /// Awaits the read lock on the execution log to guarantee accurate
    /// completion status. Use this in async task wrappers where blocking
    /// is acceptable and correctness is required.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn is_completed_async(&self, task_id: &str) -> bool {
        self.execution_log.read().await.is_completed(task_id)
    }

    /// Get a cached result for a completed task (async).
    ///
    /// Returns a clone of the result the execution log stores for
    /// `task_id`, if any. Used by the resume-skip logic to return cached
    /// results without re-executing the task.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn get_cached_result(&self, task_id: &str) -> Option<Value> {
        self.execution_log.read().await.get_result(task_id).cloned()
    }

    /// Record task completion for checkpoint tracking.
    ///
    /// Marks the task as completed in the execution log, advances the step
    /// counter, and persists a checkpoint of the current state tagged with
    /// `completed_task` and a snapshot of the execution log. Each task gets
    /// its own checkpoint record regardless of sibling task status (parallel
    /// task independence).
    ///
    /// # Errors
    ///
    /// Returns [`FunctionalError::CheckpointFailed`] if persisting the
    /// checkpoint fails. The execution log has already been updated by then.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn record_completion(&self, task_id: &str, result: &Value) -> Result<()> {
        // Update the log and snapshot it under ONE write lock. Re-acquiring
        // the lock separately for the step and the snapshot would let a
        // concurrently recording sibling task interleave, so the persisted
        // step and persisted log could describe different states.
        let (step, log_snapshot) = {
            let mut log = self.execution_log.write().await;
            log.record_completion(task_id, result.clone());
            log.advance_step();
            // Serialization failure is tolerated (best effort): store `null`.
            (log.current_step(), serde_json::to_value(&*log).unwrap_or(Value::Null))
        };

        let checkpoint =
            crate::state::Checkpoint::new(&self.thread_id, self.state.clone(), step, vec![])
                .with_metadata("completed_task", Value::String(task_id.to_string()))
                .with_metadata("execution_log", log_snapshot);

        self.checkpointer.save(&checkpoint).await.map_err(|e| {
            FunctionalError::CheckpointFailed { task: task_id.to_string(), message: e.to_string() }
        })?;

        Ok(())
    }

    /// Record task failure.
    ///
    /// Marks the task as failed in the execution log (the step counter is
    /// not advanced) and persists a checkpoint of the current state tagged
    /// with `failed_task`, `error`, and a snapshot of the execution log.
    /// Each task failure is recorded independently (parallel task independence).
    ///
    /// # Errors
    ///
    /// Returns [`FunctionalError::CheckpointFailed`] if persisting the
    /// checkpoint fails. The execution log has already been updated by then.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn record_failure(&self, task_id: &str, error: &str) -> Result<()> {
        // Same single-critical-section snapshot as `record_completion`, so
        // the persisted step and log cannot be torn by a concurrent sibling.
        let (step, log_snapshot) = {
            let mut log = self.execution_log.write().await;
            log.record_failure(task_id, error);
            (log.current_step(), serde_json::to_value(&*log).unwrap_or(Value::Null))
        };

        let checkpoint =
            crate::state::Checkpoint::new(&self.thread_id, self.state.clone(), step, vec![])
                .with_metadata("failed_task", Value::String(task_id.to_string()))
                .with_metadata("error", Value::String(error.to_string()))
                .with_metadata("execution_log", log_snapshot);

        self.checkpointer.save(&checkpoint).await.map_err(|e| {
            FunctionalError::CheckpointFailed { task: task_id.to_string(), message: e.to_string() }
        })?;

        Ok(())
    }

    /// Record that a task has started executing.
    ///
    /// Marks the task as running in the execution log. Used by the
    /// macro-generated task wrapper before executing the task body.
    #[allow(dead_code)]
    #[doc(hidden)]
    pub async fn record_start(&self, task_id: &str) {
        let mut log = self.execution_log.write().await;
        log.record_start(task_id);
    }
}