// agy_bridge/types.rs

//! Core types for the agent SDK bridge.
//!
//! This module defines the data structures that model an agent's execution
//! trajectory: individual [`Step`]s, [`ToolCallInfo`] requests,
//! [`ToolResult`] responses, [`UsageMetadata`] for token accounting, and the
//! [`ConversationMessage`]/[`MessageRole`] pair for conversation history. All
//! types derive `Serialize`/`Deserialize` for JSON interchange with the Python
//! SDK and implement `pyo3::FromPyObject` so they can be extracted directly
//! from Python values.

use std::{fmt, str::FromStr};

use serde::{Deserialize, Serialize};

// =============================================================================
// Step / ToolCall / ToolResult types (§1.6)
// =============================================================================

/// Define an SDK enum whose variants each map to one wire-format string, with
/// generated serde, `Display`, and `FromStr` support.
///
/// The wire literal written next to each variant is the single source of
/// truth for that variant:
///
/// * it is the variant's serde name (via a per-variant `#[serde(rename = ...)]`),
/// * it is what `Display` writes, and
/// * it is the only string `FromStr` accepts for that variant.
///
/// Unrecognized strings parse as `Err` via `FromStr` — they never panic.
///
/// The generated enum derives `Default`, so exactly one variant must carry
/// `#[default]`.
///
/// # Syntax
///
/// ```text
/// define_sdk_enum! {
///     /// Doc comment for the enum.
///     EnumName {
///         Variant1 => "WIRE_STRING_1",
///         Variant2 => "WIRE_STRING_2",
///         #[default]
///         Unknown => "UNKNOWN",
///     }
/// }
/// ```
macro_rules! define_sdk_enum {
    (
        $(#[$meta:meta])*
        $name:ident {
            $(
                $(#[$vmeta:meta])*
                $variant:ident => $wire:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        // Serde names come from the same literal as `Display`/`FromStr`
        // (instead of `rename_all`), so the JSON form cannot drift from the
        // string form when a literal is not the SCREAMING_SNAKE_CASE of its
        // variant name.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
        pub enum $name {
            $(
                $(#[$vmeta])*
                #[serde(rename = $wire)]
                $variant,
            )+
        }

        impl fmt::Display for $name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let s = match self {
                    $( Self::$variant => $wire, )+
                };
                f.write_str(s)
            }
        }

        impl FromStr for $name {
            type Err = String;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $( $wire => Ok(Self::$variant), )+
                    other => Err(format!(concat!("Unrecognized ", stringify!($name), ": {:?}"), other)),
                }
            }
        }
    };
}

/// Define an SDK enum whose serde names are taken verbatim from each
/// variant's wire literal through a per-variant `#[serde(rename = ...)]`.
///
/// Besides the enum itself, this generates:
///
/// * a `Display` impl that writes the wire literal, and
/// * a `FromStr` impl that accepts exactly those literals. Anything else
///   yields `Err`; it never panics.
///
/// The enum derives `Default`, so one variant must be marked `#[default]`.
///
/// This macro exists for enums such as [`StepTarget`], whose SDK strings carry
/// a `TARGET_` prefix that cannot be produced by case-converting the variant
/// names.
macro_rules! define_sdk_enum_custom_serde {
    (
        $(#[$enum_attr:meta])*
        $enum_name:ident {
            $(
                $(#[$variant_attr:meta])*
                $variant_name:ident => $wire_str:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$enum_attr])*
        #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
        pub enum $enum_name {
            $(
                $(#[$variant_attr])*
                #[serde(rename = $wire_str)]
                $variant_name,
            )+
        }

        impl fmt::Display for $enum_name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str(match *self {
                    $( Self::$variant_name => $wire_str, )+
                })
            }
        }

        impl FromStr for $enum_name {
            type Err = String;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                // Compare against each wire literal in declaration order; the
                // first hit wins, exactly like the arms of a `match`.
                $(
                    if s == $wire_str {
                        return Ok(Self::$variant_name);
                    }
                )+
                Err(format!(
                    concat!("Unrecognized ", stringify!($enum_name), ": {:?}"),
                    s
                ))
            }
        }
    };
}

126define_sdk_enum! {
127    /// The high-level type of a step in the agent trajectory.
128    StepType {
129        TextResponse => "TEXT_RESPONSE",
130        ToolCall => "TOOL_CALL",
131        SystemMessage => "SYSTEM_MESSAGE",
132        Compaction => "COMPACTION",
133        Finish => "FINISH",
134        #[default]
135        Unknown => "UNKNOWN",
136    }
137}
138
139define_sdk_enum! {
140    /// The source that generated a step.
141    StepSource {
142        System => "SYSTEM",
143        User => "USER",
144        Model => "MODEL",
145        #[default]
146        Unknown => "UNKNOWN",
147    }
148}
149
150define_sdk_enum! {
151    /// The execution status of a step.
152    StepStatus {
153        Active => "ACTIVE",
154        Done => "DONE",
155        WaitingForUser => "WAITING_FOR_USER",
156        Error => "ERROR",
157        Canceled => "CANCELED",
158        #[default]
159        Unknown => "UNKNOWN",
160    }
161}
162
163define_sdk_enum_custom_serde! {
164    /// Target of a step interaction, mirroring the Python SDK's `StepTarget`.
165    ///
166    /// The Python SDK uses `TARGET_` prefixed strings (e.g. `TARGET_USER`).
167    /// Uses per-variant `#[serde(rename)]` because the SDK's wire format has a
168    /// `TARGET_` prefix that doesn't follow `SCREAMING_SNAKE_CASE` of the enum name.
169    StepTarget {
170        /// Step is directed at the model.
171        Model => "TARGET_MODEL",
172        /// Step is directed at the user.
173        User => "TARGET_USER",
174        /// Step is directed at the environment (tool execution).
175        Environment => "TARGET_ENVIRONMENT",
176        /// Target is unspecified.
177        Unspecified => "TARGET_UNSPECIFIED",
178        /// Unknown target (fallback).
179        #[default]
180        Unknown => "UNKNOWN",
181    }
182}
183
184/// A tool call from the model, mirroring the Python SDK's `ToolCall`.
185#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
186pub struct ToolCallInfo {
187    /// Tool name — either a `BuiltinTools` string or a custom tool name.
188    pub name: String,
189    /// Arguments as a JSON value (typically an object/dict).
190    #[serde(default)]
191    pub args: serde_json::Value,
192    /// Optional unique identifier for the call.
193    #[serde(default)]
194    pub id: Option<String>,
195    /// Optional normalized filesystem path for file-related tools.
196    #[serde(default)]
197    pub canonical_path: Option<String>,
198}
199
200/// Result of a single tool execution, mirroring the Python SDK's `ToolResult`.
201#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
202pub struct ToolResult {
203    /// The name of the tool that was executed.
204    pub name: String,
205    /// Optional identifier correlating this result with a `ToolCallInfo.id`.
206    #[serde(default)]
207    pub id: Option<String>,
208    /// The tool's return value (any JSON-serializable value).
209    #[serde(default)]
210    pub result: serde_json::Value,
211    /// An error message if execution failed, or `None` on success.
212    #[serde(default)]
213    pub error: Option<String>,
214}
215
216/// Token usage metadata from the model API, mirroring the SDK's `UsageMetadata`.
217#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
218pub struct UsageMetadata {
219    /// Number of tokens in the prompt.
220    #[serde(default)]
221    pub prompt_token_count: Option<u64>,
222    /// Number of tokens from cached content (subset of prompt tokens).
223    #[serde(default)]
224    pub cached_content_token_count: Option<u64>,
225    /// Number of tokens in the generated candidates (excluding thinking).
226    #[serde(default)]
227    pub candidates_token_count: Option<u64>,
228    /// Number of tokens used for thinking/reasoning.
229    #[serde(default)]
230    pub thoughts_token_count: Option<u64>,
231    /// Sum of prompt + candidates + thinking tokens.
232    #[serde(default)]
233    pub total_token_count: Option<u64>,
234}
235
236/// The role of a message author in the conversation.
237#[non_exhaustive]
238#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
239#[serde(rename_all = "lowercase")]
240#[derive(Default)]
241pub enum MessageRole {
242    /// A user-authored message.
243    #[default]
244    User,
245    /// A model-generated message.
246    Model,
247    /// A system-level message.
248    System,
249    /// An unrecognized role — preserves the original string for forward
250    /// compatibility with new SDK roles.
251    #[serde(untagged)]
252    Unknown(String),
253}
254
255impl std::fmt::Display for MessageRole {
256    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257        match self {
258            Self::User => f.write_str("user"),
259            Self::Model => f.write_str("model"),
260            Self::System => f.write_str("system"),
261            Self::Unknown(s) => f.write_str(s),
262        }
263    }
264}
265
266/// A single message in the conversation history, mirroring the Python SDK's
267/// `ConversationMessage`.
268///
269/// Each message has a [`MessageRole`] and textual `content`.
270#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
271pub struct ConversationMessage {
272    /// The role of the message author.
273    #[serde(default)]
274    pub role: MessageRole,
275    /// The textual content of the message.
276    #[serde(default)]
277    pub content: String,
278}
279
280/// A single step in the agent trajectory, mirroring the SDK's `Step`.
281#[non_exhaustive]
282#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
283pub struct Step {
284    /// Unique string identifier for the step.
285    #[serde(default)]
286    pub id: String,
287    /// Integer index of the step in the trajectory.
288    #[serde(default)]
289    pub step_index: u32,
290    /// The high-level type of the step.
291    #[serde(default, rename = "type")]
292    pub step_type: StepType,
293    /// The source that generated the step.
294    #[serde(default)]
295    pub source: StepSource,
296    /// The target of the step interaction.
297    #[serde(default)]
298    pub target: StepTarget,
299    /// The status of the step.
300    #[serde(default)]
301    pub status: StepStatus,
302    /// The text content/output of the step.
303    #[serde(default)]
304    pub content: String,
305    /// Incremental text content added since the last update for this step.
306    #[serde(default)]
307    pub content_delta: String,
308    /// Full model reasoning/thinking text for planner responses.
309    #[serde(default)]
310    pub thinking: String,
311    /// Incremental thinking text added since the last update for this step.
312    #[serde(default)]
313    pub thinking_delta: String,
314    /// List of tool calls associated with the step.
315    #[serde(default)]
316    pub tool_calls: Vec<ToolCallInfo>,
317    /// Short error message if the step failed.
318    #[serde(default)]
319    pub error: String,
320    /// Whether this step is a completed model response directed at the user.
321    ///
322    /// Multiple steps per turn may have this flag set; consumers wanting only
323    /// the last response should iterate fully.
324    #[serde(default)]
325    pub is_complete_response: Option<bool>,
326    /// Structured output payload extracted from the FINISH step.
327    ///
328    /// This is `serde_json::Value` because it contains user-defined schema data
329    /// whose shape is not known at compile time.
330    #[serde(default)]
331    pub structured_output: Option<serde_json::Value>,
332    /// Token usage for this step's model invocation.
333    #[serde(default)]
334    pub usage_metadata: Option<UsageMetadata>,
335}
336
/// Implement `pyo3::FromPyObject` for each listed type by deserializing the
/// Python object through serde with `pythonize::depythonize`.
///
/// A deserialization failure is returned as a Python `ValueError` naming the
/// target Rust type. The message says "Python object" rather than "dict"
/// because this macro also covers the string-valued SDK enums (e.g.
/// `StepType`), which are not deserialized from dicts.
macro_rules! impl_from_py_object {
    ($($t:ty),+) => {
        $(
            impl<'py> pyo3::FromPyObject<'py> for $t {
                fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
                    pythonize::depythonize(ob).map_err(|e| {
                        pyo3::exceptions::PyValueError::new_err(format!(
                            "Failed to deserialize {} from Python object: {}",
                            stringify!($t),
                            e
                        ))
                    })
                }
            }
        )+
    };
}

355impl_from_py_object!(
356    StepType,
357    StepSource,
358    StepStatus,
359    StepTarget,
360    ToolCallInfo,
361    ToolResult,
362    UsageMetadata,
363    MessageRole,
364    ConversationMessage,
365    Step
366);
367
368#[cfg(test)]
369mod tests {
370    use super::*;
371
372    // =========================================================================
373    // Step / ToolCall / ToolResult tests
374    // =========================================================================
375
376    #[test]
377    fn test_step_type_roundtrip() {
378        for (variant, expected_str) in [
379            (StepType::TextResponse, "\"TEXT_RESPONSE\""),
380            (StepType::ToolCall, "\"TOOL_CALL\""),
381            (StepType::SystemMessage, "\"SYSTEM_MESSAGE\""),
382            (StepType::Compaction, "\"COMPACTION\""),
383            (StepType::Finish, "\"FINISH\""),
384            (StepType::Unknown, "\"UNKNOWN\""),
385        ] {
386            let json = serde_json::to_string(&variant).unwrap();
387            assert_eq!(
388                json, expected_str,
389                "StepType serialization mismatch for {variant:?}"
390            );
391            let parsed: StepType = serde_json::from_str(&json).unwrap();
392            assert_eq!(parsed, variant);
393        }
394    }
395
396    #[test]
397    fn test_step_type_parse() {
398        assert_eq!(
399            "TEXT_RESPONSE".parse::<StepType>().unwrap(),
400            StepType::TextResponse
401        );
402        assert_eq!("TOOL_CALL".parse::<StepType>().unwrap(), StepType::ToolCall);
403        assert_eq!(
404            "SYSTEM_MESSAGE".parse::<StepType>().unwrap(),
405            StepType::SystemMessage
406        );
407        assert_eq!(
408            "COMPACTION".parse::<StepType>().unwrap(),
409            StepType::Compaction
410        );
411        assert_eq!("FINISH".parse::<StepType>().unwrap(), StepType::Finish);
412    }
413
414    #[test]
415    fn test_step_source_roundtrip() {
416        for (variant, expected_str) in [
417            (StepSource::System, "\"SYSTEM\""),
418            (StepSource::User, "\"USER\""),
419            (StepSource::Model, "\"MODEL\""),
420            (StepSource::Unknown, "\"UNKNOWN\""),
421        ] {
422            let json = serde_json::to_string(&variant).unwrap();
423            assert_eq!(json, expected_str);
424            let parsed: StepSource = serde_json::from_str(&json).unwrap();
425            assert_eq!(parsed, variant);
426        }
427    }
428
429    #[test]
430    fn test_step_source_parse() {
431        assert_eq!("SYSTEM".parse::<StepSource>().unwrap(), StepSource::System);
432        assert_eq!("USER".parse::<StepSource>().unwrap(), StepSource::User);
433        assert_eq!("MODEL".parse::<StepSource>().unwrap(), StepSource::Model);
434    }
435
436    #[test]
437    fn test_step_status_roundtrip() {
438        for (variant, expected_str) in [
439            (StepStatus::Active, "\"ACTIVE\""),
440            (StepStatus::Done, "\"DONE\""),
441            (StepStatus::WaitingForUser, "\"WAITING_FOR_USER\""),
442            (StepStatus::Error, "\"ERROR\""),
443            (StepStatus::Canceled, "\"CANCELED\""),
444            (StepStatus::Unknown, "\"UNKNOWN\""),
445        ] {
446            let json = serde_json::to_string(&variant).unwrap();
447            assert_eq!(json, expected_str);
448            let parsed: StepStatus = serde_json::from_str(&json).unwrap();
449            assert_eq!(parsed, variant);
450        }
451    }
452
453    #[test]
454    fn test_step_status_parse() {
455        assert_eq!("ACTIVE".parse::<StepStatus>().unwrap(), StepStatus::Active);
456        assert_eq!("DONE".parse::<StepStatus>().unwrap(), StepStatus::Done);
457        assert_eq!(
458            "WAITING_FOR_USER".parse::<StepStatus>().unwrap(),
459            StepStatus::WaitingForUser
460        );
461        assert_eq!("ERROR".parse::<StepStatus>().unwrap(), StepStatus::Error);
462        assert_eq!(
463            "CANCELED".parse::<StepStatus>().unwrap(),
464            StepStatus::Canceled
465        );
466    }
467
468    #[test]
469    fn test_step_type_parse_returns_err_for_unrecognized() {
470        assert!("NONEXISTENT".parse::<StepType>().is_err());
471    }
472
473    #[test]
474    fn test_step_source_parse_returns_err_for_unrecognized() {
475        assert!("???".parse::<StepSource>().is_err());
476    }
477
478    #[test]
479    fn test_step_status_parse_returns_err_for_unrecognized() {
480        assert!("nope".parse::<StepStatus>().is_err());
481    }
482
483    #[test]
484    fn test_tool_call_info_roundtrip() {
485        let tc = ToolCallInfo {
486            name: "view_file".to_string(),
487            args: serde_json::json!({"path": "/tmp/foo.rs", "line": 42}),
488            id: Some("call_123".to_string()),
489            canonical_path: Some("/tmp/foo.rs".to_string()),
490        };
491        let json = serde_json::to_string(&tc).unwrap();
492        let parsed: ToolCallInfo = serde_json::from_str(&json).unwrap();
493        assert_eq!(parsed, tc);
494    }
495
496    #[test]
497    fn test_tool_call_info_minimal() {
498        let json = r#"{"name":"custom_tool"}"#;
499        let parsed: ToolCallInfo = serde_json::from_str(json).unwrap();
500        assert_eq!(parsed.name, "custom_tool");
501        assert_eq!(parsed.args, serde_json::Value::Null);
502        assert!(parsed.id.is_none());
503        assert!(parsed.canonical_path.is_none());
504    }
505
506    #[test]
507    fn test_tool_result_roundtrip() {
508        let tr = ToolResult {
509            name: "run_command".to_string(),
510            id: Some("result_456".to_string()),
511            result: serde_json::json!({"output": "hello world"}),
512            error: None,
513        };
514        let json = serde_json::to_string(&tr).unwrap();
515        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
516        assert_eq!(parsed, tr);
517    }
518
519    #[test]
520    fn test_tool_result_with_error() {
521        let tr = ToolResult {
522            name: "create_file".to_string(),
523            id: None,
524            result: serde_json::Value::Null,
525            error: Some("permission denied".to_string()),
526        };
527        let json = serde_json::to_string(&tr).unwrap();
528        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
529        assert_eq!(parsed.error.as_deref(), Some("permission denied"));
530    }
531
532    #[test]
533    fn test_usage_metadata_roundtrip() {
534        let um = UsageMetadata {
535            prompt_token_count: Some(100),
536            cached_content_token_count: Some(20),
537            candidates_token_count: Some(50),
538            thoughts_token_count: Some(30),
539            total_token_count: Some(180),
540        };
541        let json = serde_json::to_string(&um).unwrap();
542        let parsed: UsageMetadata = serde_json::from_str(&json).unwrap();
543        assert_eq!(parsed, um);
544    }
545
546    #[test]
547    fn test_usage_metadata_defaults() {
548        let um: UsageMetadata = serde_json::from_str("{}").unwrap();
549        assert!(um.prompt_token_count.is_none());
550        assert!(um.total_token_count.is_none());
551    }
552
553    #[test]
554    fn test_step_full_roundtrip() {
555        let step = Step {
556            id: "traj:0".to_string(),
557            step_index: 3,
558            step_type: StepType::ToolCall,
559            source: StepSource::Model,
560            target: StepTarget::Environment,
561            status: StepStatus::Done,
562            content: "Running command...".to_string(),
563            content_delta: "Running".to_string(),
564            thinking: "I should run the command".to_string(),
565            thinking_delta: "I should".to_string(),
566            tool_calls: vec![ToolCallInfo {
567                name: "run_command".to_string(),
568                args: serde_json::json!({"command": "ls -la"}),
569                id: Some("call_1".to_string()),
570                canonical_path: None,
571            }],
572            error: String::new(),
573            is_complete_response: Some(false),
574            structured_output: None,
575            usage_metadata: Some(UsageMetadata {
576                prompt_token_count: Some(500),
577                cached_content_token_count: None,
578                candidates_token_count: Some(100),
579                thoughts_token_count: Some(50),
580                total_token_count: Some(650),
581            }),
582        };
583
584        let json = serde_json::to_string_pretty(&step).unwrap();
585        let parsed: Step = serde_json::from_str(&json).unwrap();
586        assert_eq!(parsed, step);
587        assert_eq!(parsed.tool_calls.len(), 1);
588        assert_eq!(parsed.tool_calls[0].name, "run_command");
589    }
590
591    #[test]
592    fn test_step_minimal_deserialization() {
593        // Should deserialize with all defaults.
594        let json = r#"{"id":"s1"}"#;
595        let step: Step = serde_json::from_str(json).unwrap();
596        assert_eq!(step.id, "s1");
597        assert_eq!(step.step_index, 0);
598        assert_eq!(step.step_type, StepType::Unknown);
599        assert_eq!(step.source, StepSource::Unknown);
600        assert_eq!(step.target, StepTarget::Unknown);
601        assert_eq!(step.status, StepStatus::Unknown);
602        assert!(step.content.is_empty());
603        assert!(step.content_delta.is_empty());
604        assert!(step.thinking.is_empty());
605        assert!(step.thinking_delta.is_empty());
606        assert!(step.tool_calls.is_empty());
607        assert!(step.error.is_empty());
608        assert!(step.is_complete_response.is_none());
609        assert!(step.structured_output.is_none());
610        assert!(step.usage_metadata.is_none());
611    }
612
613    // =========================================================================
614    // Step with multiple tool calls
615    // =========================================================================
616
617    #[test]
618    fn step_with_multiple_tool_calls() {
619        let step = Step {
620            id: "multi-tc".to_string(),
621            step_index: 7,
622            step_type: StepType::ToolCall,
623            source: StepSource::Model,
624            target: StepTarget::Environment,
625            status: StepStatus::Done,
626            content: String::new(),
627            content_delta: String::new(),
628            thinking: String::new(),
629            thinking_delta: String::new(),
630            tool_calls: vec![
631                ToolCallInfo {
632                    name: "view_file".to_string(),
633                    args: serde_json::json!({"path": "/a.rs"}),
634                    id: Some("tc1".to_string()),
635                    canonical_path: Some("/a.rs".to_string()),
636                },
637                ToolCallInfo {
638                    name: "run_command".to_string(),
639                    args: serde_json::json!({"command": "cargo test"}),
640                    id: Some("tc2".to_string()),
641                    canonical_path: None,
642                },
643            ],
644            error: String::new(),
645            is_complete_response: None,
646            structured_output: None,
647            usage_metadata: None,
648        };
649        let json = serde_json::to_string(&step).unwrap();
650        let parsed: Step = serde_json::from_str(&json).unwrap();
651        assert_eq!(parsed.tool_calls.len(), 2);
652        assert_eq!(parsed.tool_calls[0].name, "view_file");
653        assert_eq!(parsed.tool_calls[1].name, "run_command");
654        assert_eq!(
655            parsed.tool_calls[0].canonical_path.as_deref(),
656            Some("/a.rs")
657        );
658        assert!(parsed.tool_calls[1].canonical_path.is_none());
659    }
660
661    // =========================================================================
662    // ToolCallInfo / ToolResult edge cases
663    // =========================================================================
664
665    #[test]
666    fn tool_call_info_with_complex_args() {
667        let tc = ToolCallInfo {
668            name: "run_command".to_string(),
669            args: serde_json::json!({
670                "command": "cargo test",
671                "env": {"RUST_LOG": "debug"},
672                "timeout": 300,
673                "nested": [1, 2, {"deep": true}]
674            }),
675            id: None,
676            canonical_path: None,
677        };
678        let json = serde_json::to_string(&tc).unwrap();
679        let parsed: ToolCallInfo = serde_json::from_str(&json).unwrap();
680        assert_eq!(parsed.args["env"]["RUST_LOG"], "debug");
681        assert_eq!(parsed.args["nested"][2]["deep"], true);
682    }
683
684    #[test]
685    fn tool_result_with_complex_result() {
686        let tr = ToolResult {
687            name: "search_dir".to_string(),
688            id: Some("r1".to_string()),
689            result: serde_json::json!({
690                "matches": [
691                    {"file": "/src/main.rs", "line": 42},
692                    {"file": "/src/lib.rs", "line": 10},
693                ],
694                "total": 2
695            }),
696            error: None,
697        };
698        let json = serde_json::to_string(&tr).unwrap();
699        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
700        assert_eq!(parsed.result["total"], 2);
701        assert_eq!(parsed.result["matches"][0]["line"], 42);
702    }
703
704    // =========================================================================
705    // UsageMetadata partial fields
706    // =========================================================================
707
708    #[test]
709    fn usage_metadata_partial_fields() {
710        let json = r#"{"prompt_token_count":100,"total_token_count":200}"#;
711        let um: UsageMetadata = serde_json::from_str(json).unwrap();
712        assert_eq!(um.prompt_token_count, Some(100));
713        assert!(um.cached_content_token_count.is_none());
714        assert!(um.candidates_token_count.is_none());
715        assert!(um.thoughts_token_count.is_none());
716        assert_eq!(um.total_token_count, Some(200));
717    }
718
719    // =========================================================================
720    // StepTarget tests
721    // =========================================================================
722
723    #[test]
724    fn test_step_target_roundtrip() {
725        for (variant, expected_str) in [
726            (StepTarget::User, "\"TARGET_USER\""),
727            (StepTarget::Environment, "\"TARGET_ENVIRONMENT\""),
728            (StepTarget::Unspecified, "\"TARGET_UNSPECIFIED\""),
729            (StepTarget::Unknown, "\"UNKNOWN\""),
730        ] {
731            let json = serde_json::to_string(&variant).unwrap();
732            assert_eq!(
733                json, expected_str,
734                "StepTarget serialization mismatch for {variant:?}"
735            );
736            let parsed: StepTarget = serde_json::from_str(&json).unwrap();
737            assert_eq!(parsed, variant);
738        }
739    }
740
741    #[test]
742    fn test_step_target_parse() {
743        assert_eq!(
744            "TARGET_MODEL".parse::<StepTarget>().unwrap(),
745            StepTarget::Model
746        );
747        assert_eq!(
748            "TARGET_USER".parse::<StepTarget>().unwrap(),
749            StepTarget::User
750        );
751        assert_eq!(
752            "TARGET_ENVIRONMENT".parse::<StepTarget>().unwrap(),
753            StepTarget::Environment
754        );
755        assert_eq!(
756            "TARGET_UNSPECIFIED".parse::<StepTarget>().unwrap(),
757            StepTarget::Unspecified
758        );
759        assert_eq!(
760            "UNKNOWN".parse::<StepTarget>().unwrap(),
761            StepTarget::Unknown
762        );
763    }
764
765    #[test]
766    fn test_step_target_parse_returns_err_for_unrecognized() {
767        assert!("INVALID_TARGET".parse::<StepTarget>().is_err());
768    }
769
770    // =========================================================================
771    // Display trait tests
772    // =========================================================================
773
774    #[test]
775    fn test_step_type_display() {
776        assert_eq!(StepType::TextResponse.to_string(), "TEXT_RESPONSE");
777        assert_eq!(StepType::ToolCall.to_string(), "TOOL_CALL");
778        assert_eq!(StepType::SystemMessage.to_string(), "SYSTEM_MESSAGE");
779        assert_eq!(StepType::Compaction.to_string(), "COMPACTION");
780        assert_eq!(StepType::Finish.to_string(), "FINISH");
781        assert_eq!(StepType::Unknown.to_string(), "UNKNOWN");
782    }
783
784    #[test]
785    fn test_step_source_display() {
786        assert_eq!(StepSource::System.to_string(), "SYSTEM");
787        assert_eq!(StepSource::User.to_string(), "USER");
788        assert_eq!(StepSource::Model.to_string(), "MODEL");
789        assert_eq!(StepSource::Unknown.to_string(), "UNKNOWN");
790    }
791
792    #[test]
793    fn test_step_status_display() {
794        assert_eq!(StepStatus::Active.to_string(), "ACTIVE");
795        assert_eq!(StepStatus::Done.to_string(), "DONE");
796        assert_eq!(StepStatus::WaitingForUser.to_string(), "WAITING_FOR_USER");
797        assert_eq!(StepStatus::Error.to_string(), "ERROR");
798        assert_eq!(StepStatus::Canceled.to_string(), "CANCELED");
799        assert_eq!(StepStatus::Unknown.to_string(), "UNKNOWN");
800    }
801
802    #[test]
803    fn test_step_target_display() {
804        assert_eq!(StepTarget::User.to_string(), "TARGET_USER");
805        assert_eq!(StepTarget::Environment.to_string(), "TARGET_ENVIRONMENT");
806        assert_eq!(StepTarget::Unspecified.to_string(), "TARGET_UNSPECIFIED");
807        assert_eq!(StepTarget::Unknown.to_string(), "UNKNOWN");
808    }
809
810    // =========================================================================
811    // Display → FromStr roundtrip tests
812    // =========================================================================
813
814    #[test]
815    fn test_step_type_display_from_str_roundtrip() {
816        for variant in [
817            StepType::TextResponse,
818            StepType::ToolCall,
819            StepType::SystemMessage,
820            StepType::Compaction,
821            StepType::Finish,
822            StepType::Unknown,
823        ] {
824            let s = variant.to_string();
825            let parsed: StepType = s.parse().unwrap();
826            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
827        }
828    }
829
830    #[test]
831    fn test_step_source_display_from_str_roundtrip() {
832        for variant in [
833            StepSource::System,
834            StepSource::User,
835            StepSource::Model,
836            StepSource::Unknown,
837        ] {
838            let s = variant.to_string();
839            let parsed: StepSource = s.parse().unwrap();
840            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
841        }
842    }
843
844    #[test]
845    fn test_step_status_display_from_str_roundtrip() {
846        for variant in [
847            StepStatus::Active,
848            StepStatus::Done,
849            StepStatus::WaitingForUser,
850            StepStatus::Error,
851            StepStatus::Canceled,
852            StepStatus::Unknown,
853        ] {
854            let s = variant.to_string();
855            let parsed: StepStatus = s.parse().unwrap();
856            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
857        }
858    }
859
860    #[test]
861    fn test_step_target_display_from_str_roundtrip() {
862        for variant in [
863            StepTarget::Model,
864            StepTarget::User,
865            StepTarget::Environment,
866            StepTarget::Unspecified,
867            StepTarget::Unknown,
868        ] {
869            let s = variant.to_string();
870            let parsed: StepTarget = s.parse().unwrap();
871            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
872        }
873    }
874
875    // =========================================================================
876    // FromStr with garbage input tests
877    // =========================================================================
878
879    #[test]
880    fn test_from_str_garbage_returns_err() {
881        assert!("xyzzy".parse::<StepType>().is_err());
882        assert!("xyzzy".parse::<StepSource>().is_err());
883        assert!("xyzzy".parse::<StepStatus>().is_err());
884        assert!("xyzzy".parse::<StepTarget>().is_err());
885    }
886
887    #[test]
888    fn test_from_str_empty_returns_err() {
889        assert!("".parse::<StepType>().is_err());
890        assert!("".parse::<StepSource>().is_err());
891        assert!("".parse::<StepStatus>().is_err());
892        assert!("".parse::<StepTarget>().is_err());
893    }
894
895    #[test]
896    fn test_from_str_case_sensitive() {
897        // SDK strings are case-sensitive — lowercase should return Err.
898        assert!("text_response".parse::<StepType>().is_err());
899        assert!("system".parse::<StepSource>().is_err());
900        assert!("active".parse::<StepStatus>().is_err());
901        assert!("target_user".parse::<StepTarget>().is_err());
902    }
903
904    // =========================================================================
905    // MessageRole / ConversationMessage Tests
906    // =========================================================================
907
908    #[test]
909    fn test_message_role_roundtrip() {
910        for (variant, expected_str) in [
911            (MessageRole::User, "\"user\""),
912            (MessageRole::Model, "\"model\""),
913            (MessageRole::System, "\"system\""),
914            (MessageRole::Unknown("custom".to_string()), "\"custom\""),
915        ] {
916            let json = serde_json::to_string(&variant).unwrap();
917            assert_eq!(json, expected_str);
918            let parsed: MessageRole = serde_json::from_str(&json).unwrap();
919            assert_eq!(parsed, variant);
920        }
921    }
922
923    #[test]
924    fn test_conversation_message_roundtrip() {
925        let msg = ConversationMessage {
926            role: MessageRole::Model,
927            content: "Hello!".to_string(),
928        };
929        let json = serde_json::to_string(&msg).unwrap();
930        let parsed: ConversationMessage = serde_json::from_str(&json).unwrap();
931        assert_eq!(parsed, msg);
932    }
933
934    // =========================================================================
935    // PyO3 Extract Tests
936    // =========================================================================
937
938    #[test]
939    fn test_pyo3_extract_roundtrip() {
940        use pyo3::{prelude::*, types::PyDictMethods};
941        pyo3::prepare_freethreaded_python();
942        pyo3::Python::with_gil(|py| {
943            let dict = pyo3::types::PyDict::new_bound(py);
944            dict.set_item("id", "step-1").unwrap();
945            dict.set_item("step_index", 42).unwrap();
946            dict.set_item("type", "TEXT_RESPONSE").unwrap();
947            dict.set_item("source", "MODEL").unwrap();
948            dict.set_item("target", "TARGET_USER").unwrap();
949            dict.set_item("status", "DONE").unwrap();
950
951            let step: Step = dict.extract().expect("failed to extract Step");
952            assert_eq!(step.id, "step-1");
953            assert_eq!(step.step_index, 42);
954            assert_eq!(step.step_type, StepType::TextResponse);
955            assert_eq!(step.source, StepSource::Model);
956            assert_eq!(step.target, StepTarget::User);
957            assert_eq!(step.status, StepStatus::Done);
958
959            // Now test an enum
960            let s = pyo3::types::PyString::new_bound(py, "SYSTEM_MESSAGE");
961            let st: StepType = s.extract().unwrap();
962            assert_eq!(st, StepType::SystemMessage);
963        });
964    }
965
966    // =========================================================================
    // Step delta, thinking, and structured-output field tests
968    // =========================================================================
969
970    #[test]
971    fn step_with_deltas_and_thinking() {
972        let step = Step {
973            id: "s2".to_string(),
974            step_index: 1,
975            step_type: StepType::TextResponse,
976            source: StepSource::Model,
977            target: StepTarget::User,
978            status: StepStatus::Active,
979            content: "Hello world".to_string(),
980            content_delta: "world".to_string(),
981            thinking: "The user said hi".to_string(),
982            thinking_delta: "said hi".to_string(),
983            tool_calls: vec![],
984            error: String::new(),
985            is_complete_response: Some(true),
986            structured_output: None,
987            usage_metadata: None,
988        };
989        let json = serde_json::to_string(&step).unwrap();
990        let parsed: Step = serde_json::from_str(&json).unwrap();
991        assert_eq!(parsed.content_delta, "world");
992        assert_eq!(parsed.thinking, "The user said hi");
993        assert_eq!(parsed.thinking_delta, "said hi");
994        assert_eq!(parsed.is_complete_response, Some(true));
995        assert_eq!(parsed.target, StepTarget::User);
996    }
997
998    #[test]
999    fn step_with_structured_output() {
1000        let payload = serde_json::json!({"answer": 42, "valid": true});
1001        let step = Step {
1002            id: "finish-1".to_string(),
1003            step_index: 5,
1004            step_type: StepType::Finish,
1005            source: StepSource::Model,
1006            target: StepTarget::User,
1007            status: StepStatus::Done,
1008            content: String::new(),
1009            content_delta: String::new(),
1010            thinking: String::new(),
1011            thinking_delta: String::new(),
1012            tool_calls: vec![],
1013            error: String::new(),
1014            is_complete_response: Some(true),
1015            structured_output: Some(payload.clone()),
1016            usage_metadata: None,
1017        };
1018        let json = serde_json::to_string(&step).unwrap();
1019        let parsed: Step = serde_json::from_str(&json).unwrap();
1020        assert_eq!(parsed.structured_output, Some(payload));
1021        assert_eq!(parsed.step_type, StepType::Finish);
1022    }
1023}