Skip to main content

akribes_sdk/
suspend.rs

//! SDK-facing mirror of `akribes_types::event::SuspendTrigger` and friends.
//!
//! `akribes-core` is the source of truth for the `Suspended` event wire shape.
//! It already defines [`akribes_types::event::SuspendTrigger`] with the four
//! variants the server emits (`DagPosition`, `ValidationExhausted`,
//! `AgentUnable`, `AgentVariant`). We re-mirror the shape at the SDK layer for
//! two reasons:
//!
//! 1. **Forward-compat.** The SDK mirror carries an [`Unknown`] catch-all
//!    via `#[serde(other)]` so a newer server emitting a future variant
//!    never crashes the SDK — the suspension surfaces as `Unknown` and the
//!    raw payload is still available on the wire for consumers to inspect
//!    via [`crate::WorkflowEvent::Other`] / raw [`akribes_types::event::EngineEvent`]
//!    access. The core enum is deliberately not marked `#[serde(other)]`
//!    because akribes-core tests exhaustiveness of its own variants; the
//!    forward-compat contract lives at the SDK boundary per the Wave-4
//!    tracker decisions.
//! 2. **Stable public surface.** SDK consumers don't have to reach into
//!    `akribes_types::*` for common wire types — [`SuspendTrigger`],
//!    [`UnableRecord`], and [`ValidationErrorWire`] are re-exported at the
//!    SDK crate root.
//!
//! Conversions from the core shape are provided so the SDK's
//! [`crate::WorkflowEvent::Checkpoint`] can carry a typed trigger without
//! leaking `akribes_types::event::*` imports to consumers.
//!
//! [`Unknown`]: SuspendTrigger::Unknown
27
28use akribes_types::event as core_event;
29use serde::{Deserialize, Serialize};
30
31/// Wire-format twin of [`akribes_types::validation::ValidationError`].
32///
33/// Owned + serializable; the `stage` discriminator is a string (`"parse"`,
34/// `"schema"`, `"custom:<rule>"`) so SDK consumers don't need to round-trip
35/// through the internal enum.
36#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
37pub struct ValidationErrorWire {
38    pub stage: String,
39    pub message: String,
40    pub path: Option<String>,
41}
42
43impl From<core_event::ValidationErrorWire> for ValidationErrorWire {
44    fn from(v: core_event::ValidationErrorWire) -> Self {
45        Self {
46            stage: v.stage,
47            message: v.message,
48            path: v.path,
49        }
50    }
51}
52
53/// Structured "I can't" payload from an agent with a `T | Unable` return
54/// type. Canonical wire envelope is `{ "unable": { reason, missing, category } }`;
55/// this record is the payload after the envelope key is stripped.
56///
57/// `category` is kept as a free-form `String` at the SDK layer so forward
58/// compat on category names doesn't require an SDK release. The core enum
59/// [`akribes_types::value::UnableCategory`] lists the current five canonical
60/// buckets (`input_missing`, `input_ambiguous`, `input_conflicts`,
61/// `capability`, `other`); consumers can parse into that enum if they want
62/// strict typing.
63#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
64pub struct UnableRecord {
65    pub reason: String,
66    #[serde(default)]
67    pub missing: Vec<String>,
68    pub category: String,
69}
70
71/// Why the engine suspended execution at a checkpoint.
72///
73/// Serde-tagged with an internal `"kind"` discriminator matching
74/// [`akribes_types::event::SuspendTrigger`]. Unknown discriminants deserialize
75/// to [`SuspendTrigger::Unknown`] (via `#[serde(other)]`) so the SDK is
76/// forward-compatible with future akribes-core / server additions without a
77/// new SDK release.
78#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
79#[serde(tag = "kind")]
80#[derive(Default)]
81pub enum SuspendTrigger {
82    /// The DAG reached an explicit `checkpoint cp(...)` call site.
83    #[default]
84    DagPosition,
85    /// `on_validation_exhausted:` fired — retries consumed without
86    /// producing a payload that passes parse → schema → custom validation.
87    ValidationExhausted {
88        task_name: String,
89        retry_count: u32,
90        last_attempt: String,
91        validation_errors: Vec<ValidationErrorWire>,
92    },
93    /// A task with a `T | Unable` return type produced an `Unable` value
94    /// and the flow routed it to a checkpoint via `on unable <cp>`.
95    AgentUnable {
96        task_name: String,
97        unable: UnableRecord,
98    },
99    /// A task with a discriminated-union return type
100    /// (`A | B | ... | Unable`) produced a non-Unable variant and the flow
101    /// routed it to a checkpoint via `on <Variant> <cp>`. `variant` is
102    /// the record name as declared in source (PascalCase); `payload` is
103    /// the parsed record (with `kind` stripped).
104    AgentVariant {
105        task_name: String,
106        variant: String,
107        payload: serde_json::Value,
108    },
109    /// Catch-all for discriminants the SDK doesn't recognize (e.g. a
110    /// variant added by a newer akribes-core). The raw `Suspended.trigger`
111    /// payload is not preserved here — consumers that need full-fidelity
112    /// unknown handling can read it off the raw
113    /// [`akribes_types::event::EngineEvent::Suspended`] instead of the
114    /// normalized [`crate::WorkflowEvent`].
115    #[serde(other)]
116    Unknown,
117}
118
119impl From<core_event::SuspendTrigger> for SuspendTrigger {
120    fn from(t: core_event::SuspendTrigger) -> Self {
121        match t {
122            core_event::SuspendTrigger::DagPosition => SuspendTrigger::DagPosition,
123            core_event::SuspendTrigger::ValidationExhausted {
124                task_name,
125                retry_count,
126                last_attempt,
127                validation_errors,
128            } => SuspendTrigger::ValidationExhausted {
129                task_name,
130                retry_count,
131                last_attempt,
132                validation_errors: validation_errors.into_iter().map(Into::into).collect(),
133            },
134            core_event::SuspendTrigger::AgentUnable { task_name, unable } => {
135                SuspendTrigger::AgentUnable {
136                    task_name,
137                    unable: UnableRecord {
138                        reason: unable.reason,
139                        missing: unable.missing,
140                        category: unable.category.as_wire_str().to_string(),
141                    },
142                }
143            }
144            core_event::SuspendTrigger::AgentVariant {
145                task_name,
146                variant,
147                payload,
148            } => SuspendTrigger::AgentVariant {
149                task_name,
150                variant,
151                payload,
152            },
153        }
154    }
155}
156
/// Serde round-trip and core-interop tests for the SDK suspend types.
///
/// Every `SuspendTrigger` variant gets a wire round-trip test and, for the
/// variants that exist in akribes-core, a core → SDK conversion test.
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    // ── DagPosition round-trip ───────────────────────────────────────────────

    #[test]
    fn dag_position_roundtrips_byte_identical() {
        let wire = r#"{"kind":"DagPosition"}"#;
        let parsed: SuspendTrigger = serde_json::from_str(wire).unwrap();
        assert!(matches!(parsed, SuspendTrigger::DagPosition));
        let reserialized = serde_json::to_string(&parsed).unwrap();
        assert_eq!(reserialized, wire);
    }

    // ── ValidationExhausted round-trip ───────────────────────────────────────

    #[test]
    fn validation_exhausted_roundtrips_byte_identical() {
        // Field order mirrors the SDK struct declaration order — serde_json
        // serializes in declaration order for derived Serialize, so a wire
        // sample built in the same order should be byte-identical.
        let wire = r#"{"kind":"ValidationExhausted","task_name":"decompose_claims","retry_count":3,"last_attempt":"{\"bad\":true}","validation_errors":[{"stage":"schema","message":"required property \"number\" missing","path":"/0"}]}"#;
        let parsed: SuspendTrigger = serde_json::from_str(wire).unwrap();
        match &parsed {
            SuspendTrigger::ValidationExhausted {
                task_name,
                retry_count,
                last_attempt,
                validation_errors,
            } => {
                assert_eq!(task_name, "decompose_claims");
                assert_eq!(*retry_count, 3);
                assert_eq!(last_attempt, r#"{"bad":true}"#);
                assert_eq!(validation_errors.len(), 1);
                assert_eq!(validation_errors[0].stage, "schema");
                assert_eq!(validation_errors[0].path.as_deref(), Some("/0"));
            }
            other => panic!("expected ValidationExhausted, got {other:?}"),
        }
        let reserialized = serde_json::to_string(&parsed).unwrap();
        assert_eq!(reserialized, wire);
    }

    // ── AgentUnable round-trip ───────────────────────────────────────────────

    #[test]
    fn agent_unable_roundtrips_byte_identical() {
        let wire = r#"{"kind":"AgentUnable","task_name":"escalate","unable":{"reason":"image too blurry to OCR","missing":["claim_text"],"category":"input_ambiguous"}}"#;
        let parsed: SuspendTrigger = serde_json::from_str(wire).unwrap();
        match &parsed {
            SuspendTrigger::AgentUnable { task_name, unable } => {
                assert_eq!(task_name, "escalate");
                assert_eq!(unable.reason, "image too blurry to OCR");
                assert_eq!(unable.missing, vec!["claim_text".to_string()]);
                assert_eq!(unable.category, "input_ambiguous");
            }
            other => panic!("expected AgentUnable, got {other:?}"),
        }
        let reserialized = serde_json::to_string(&parsed).unwrap();
        assert_eq!(reserialized, wire);
    }

    #[test]
    fn agent_unable_accepts_missing_field_default() {
        // `missing` defaults to `[]` so older/minimal payloads still parse.
        let wire = json!({
            "kind": "AgentUnable",
            "task_name": "t",
            "unable": { "reason": "x", "category": "other" },
        });
        let parsed: SuspendTrigger = serde_json::from_value(wire).unwrap();
        match parsed {
            SuspendTrigger::AgentUnable { unable, .. } => {
                assert!(unable.missing.is_empty());
            }
            other => panic!("expected AgentUnable, got {other:?}"),
        }
    }

    // ── AgentVariant round-trip ──────────────────────────────────────────────

    #[test]
    fn agent_variant_roundtrips_byte_identical() {
        // The payload object has a single key, so the output is identical
        // whichever map ordering serde_json is built with.
        let wire = r#"{"kind":"AgentVariant","task_name":"classify","variant":"NeedsReview","payload":{"reason":"low confidence"}}"#;
        let parsed: SuspendTrigger = serde_json::from_str(wire).unwrap();
        match &parsed {
            SuspendTrigger::AgentVariant {
                task_name,
                variant,
                payload,
            } => {
                assert_eq!(task_name, "classify");
                assert_eq!(variant, "NeedsReview");
                assert_eq!(payload, &json!({ "reason": "low confidence" }));
            }
            other => panic!("expected AgentVariant, got {other:?}"),
        }
        let reserialized = serde_json::to_string(&parsed).unwrap();
        assert_eq!(reserialized, wire);
    }

    // ── Unknown variant passthrough ──────────────────────────────────────────

    #[test]
    fn unknown_kind_deserializes_to_unknown_variant() {
        // A future akribes-core release might add a new discriminant. The SDK
        // must not crash — it forwards as `Unknown`.
        let wire = json!({
            "kind": "SomeFutureVariant",
            "extra_field": 42,
        });
        let parsed: SuspendTrigger = serde_json::from_value(wire).unwrap();
        assert!(matches!(parsed, SuspendTrigger::Unknown));
    }

    #[test]
    fn unknown_kind_with_no_extra_fields_still_parses() {
        let parsed: SuspendTrigger = serde_json::from_str(r#"{"kind":"Nope"}"#).unwrap();
        assert!(matches!(parsed, SuspendTrigger::Unknown));
    }

    // ── Interop with akribes-core ───────────────────────────────────────────────

    #[test]
    fn converts_from_core_dag_position() {
        let core = core_event::SuspendTrigger::DagPosition;
        let sdk: SuspendTrigger = core.into();
        assert!(matches!(sdk, SuspendTrigger::DagPosition));
    }

    #[test]
    fn converts_from_core_validation_exhausted() {
        let core = core_event::SuspendTrigger::ValidationExhausted {
            task_name: "t".into(),
            retry_count: 2,
            last_attempt: "{}".into(),
            validation_errors: vec![core_event::ValidationErrorWire {
                stage: "parse".into(),
                message: "bad json".into(),
                path: None,
            }],
        };
        let sdk: SuspendTrigger = core.into();
        match sdk {
            SuspendTrigger::ValidationExhausted {
                task_name,
                retry_count,
                validation_errors,
                ..
            } => {
                assert_eq!(task_name, "t");
                assert_eq!(retry_count, 2);
                assert_eq!(validation_errors[0].stage, "parse");
            }
            other => panic!("expected ValidationExhausted, got {other:?}"),
        }
    }

    #[test]
    fn converts_from_core_agent_unable() {
        let core = core_event::SuspendTrigger::AgentUnable {
            task_name: "escalate".into(),
            unable: akribes_types::value::UnableRecord {
                reason: "blurry".into(),
                missing: vec!["claim_text".into()],
                category: akribes_types::value::UnableCategory::InputAmbiguous,
            },
        };
        let sdk: SuspendTrigger = core.into();
        match sdk {
            SuspendTrigger::AgentUnable { task_name, unable } => {
                assert_eq!(task_name, "escalate");
                assert_eq!(unable.reason, "blurry");
                assert_eq!(unable.category, "input_ambiguous");
                assert_eq!(unable.missing, vec!["claim_text".to_string()]);
            }
            other => panic!("expected AgentUnable, got {other:?}"),
        }
    }

    #[test]
    fn converts_from_core_agent_variant() {
        // The three fields must be forwarded untouched, payload included.
        let core = core_event::SuspendTrigger::AgentVariant {
            task_name: "classify".into(),
            variant: "NeedsReview".into(),
            payload: json!({ "reason": "low confidence" }),
        };
        let sdk: SuspendTrigger = core.into();
        match sdk {
            SuspendTrigger::AgentVariant {
                task_name,
                variant,
                payload,
            } => {
                assert_eq!(task_name, "classify");
                assert_eq!(variant, "NeedsReview");
                assert_eq!(payload, json!({ "reason": "low confidence" }));
            }
            other => panic!("expected AgentVariant, got {other:?}"),
        }
    }

    #[test]
    fn default_is_dag_position() {
        assert!(matches!(
            SuspendTrigger::default(),
            SuspendTrigger::DagPosition
        ));
    }
}