Skip to main content

agy_bridge/
types.rs

//! Core types for the agent SDK bridge.
//!
//! This module defines the data structures that model an agent's execution
//! trajectory: individual [`Step`](crate::types::Step)s (classified by the
//! `StepType`, `StepSource`, `StepTarget`, and `StepStatus` enums),
//! [`ToolCallInfo`](crate::types::ToolCallInfo) requests,
//! [`ToolResult`](crate::types::ToolResult) responses, and
//! [`UsageMetadata`](crate::types::UsageMetadata) for token accounting, plus
//! `ConversationMessage`/`MessageRole` for conversation history. All types
//! derive `Serialize`/`Deserialize` for JSON interchange with the Python SDK.
9
10use std::{fmt, str::FromStr};
11
12use serde::{Deserialize, Serialize};
13
14// =============================================================================
15// Step / ToolCall / ToolResult types (§1.6)
16// =============================================================================
17
/// Define an SDK enum whose serde wire format, `Display`, and `FromStr` all use
/// the same per-variant wire string.
///
/// Each variant maps to exactly one wire-format string. That literal is applied
/// as `#[serde(rename = "...")]` on the variant, returned by `Display`, and
/// accepted by `FromStr`. Because all three come from one literal they cannot
/// drift apart, even if a wire string does not match the `SCREAMING_SNAKE_CASE`
/// spelling of the variant name.
///
/// Unrecognized strings parse as `Err` via `FromStr` — they never panic.
/// Attributes written before a variant (doc comments, `#[default]`,
/// `#[serde(...)]` such as `#[serde(other)]`) are forwarded onto that variant.
///
/// # Syntax
///
/// ```text
/// define_sdk_enum! {
///     /// Doc comment for the enum.
///     EnumName {
///         Variant1 => "WIRE_STRING_1",
///         Variant2 => "WIRE_STRING_2",
///         #[default]
///         Unknown => "UNKNOWN",
///     }
/// }
/// ```
macro_rules! define_sdk_enum {
    (
        $(#[$meta:meta])*
        $name:ident {
            $(
                $(#[$vmeta:meta])*
                $variant:ident => $wire:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
        pub enum $name {
            $(
                $(#[$vmeta])*
                // Serde uses the same literal as `Display`/`FromStr` so the
                // three representations can never disagree.
                #[serde(rename = $wire)]
                $variant,
            )+
        }

        impl fmt::Display for $name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let s = match self {
                    $( Self::$variant => $wire, )+
                };
                f.write_str(s)
            }
        }

        impl FromStr for $name {
            type Err = String;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $( $wire => Ok(Self::$variant), )+
                    other => Err(format!(concat!("Unrecognized ", stringify!($name), ": {:?}"), other)),
                }
            }
        }
    };
}
78
/// Define an SDK enum in which every variant carries an explicit
/// `#[serde(rename = "...")]` equal to its wire string, together with
/// `Display` and `FromStr` impls that use that same string.
///
/// This form exists for enums such as `StepTarget`, whose SDK strings carry a
/// `TARGET_` prefix that cannot be derived from the variant name alone.
/// `FromStr` reports unrecognized input as `Err` and never panics.
macro_rules! define_sdk_enum_custom_serde {
    (
        $(#[$enum_attr:meta])*
        $enum_name:ident {
            $(
                $(#[$variant_attr:meta])*
                $variant_name:ident => $wire_str:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$enum_attr])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
        pub enum $enum_name {
            $(
                $(#[$variant_attr])*
                #[serde(rename = $wire_str)]
                $variant_name,
            )+
        }

        impl fmt::Display for $enum_name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str(match *self {
                    $( $enum_name::$variant_name => $wire_str, )+
                })
            }
        }

        impl FromStr for $enum_name {
            type Err = String;

            fn from_str(input: &str) -> Result<Self, Self::Err> {
                // First wire string that matches wins, mirroring `match` order.
                $(
                    if input == $wire_str {
                        return Ok($enum_name::$variant_name);
                    }
                )+
                Err(format!(
                    concat!("Unrecognized ", stringify!($enum_name), ": {:?}"),
                    input
                ))
            }
        }
    };
}
125
126define_sdk_enum! {
127    /// The high-level type of a step in the agent trajectory.
128    StepType {
129        /// A textual response from the model.
130        TextResponse => "TEXT_RESPONSE",
131        /// A tool invocation requested by the model.
132        ToolCall => "TOOL_CALL",
133        /// A system-generated message (e.g. context injection).
134        SystemMessage => "SYSTEM_MESSAGE",
135        /// A context-window compaction event.
136        Compaction => "COMPACTION",
137        /// The agent has signaled task completion.
138        Finish => "FINISH",
139        /// Unrecognized step type (forward-compatibility fallback).
140        #[default]
141        Unknown => "UNKNOWN",
142    }
143}
144
145define_sdk_enum! {
146    /// The source that generated a step.
147    StepSource {
148        /// Generated by the system runtime.
149        System => "SYSTEM",
150        /// Provided by the user.
151        User => "USER",
152        /// Generated by the model.
153        Model => "MODEL",
154        /// Unrecognized source (forward-compatibility fallback).
155        #[default]
156        Unknown => "UNKNOWN",
157    }
158}
159
160define_sdk_enum! {
161    /// The execution status of a step.
162    StepStatus {
163        /// Step is currently executing.
164        Active => "ACTIVE",
165        /// Step completed successfully.
166        Done => "DONE",
167        /// Step is blocked waiting for user input.
168        WaitingForUser => "WAITING_FOR_USER",
169        /// Step failed with an error.
170        Error => "ERROR",
171        /// Step was canceled before completion.
172        Canceled => "CANCELED",
173        /// Unrecognized status (forward-compatibility fallback).
174        #[default]
175        Unknown => "UNKNOWN",
176    }
177}
178
179define_sdk_enum_custom_serde! {
180    /// Target of a step interaction, mirroring the Python SDK's `StepTarget`.
181    ///
182    /// The Python SDK uses `TARGET_` prefixed strings (e.g. `TARGET_USER`).
183    /// Uses per-variant `#[serde(rename)]` because the SDK's wire format has a
184    /// `TARGET_` prefix that doesn't follow `SCREAMING_SNAKE_CASE` of the enum name.
185    StepTarget {
186        /// Step is directed at the model.
187        Model => "TARGET_MODEL",
188        /// Step is directed at the user.
189        User => "TARGET_USER",
190        /// Step is directed at the environment (tool execution).
191        Environment => "TARGET_ENVIRONMENT",
192        /// Target is unspecified.
193        Unspecified => "TARGET_UNSPECIFIED",
194        /// Unknown target (fallback).
195        #[default]
196        Unknown => "UNKNOWN",
197    }
198}
199
200/// A tool call from the model, mirroring the Python SDK's `ToolCall`.
201#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
202pub struct ToolCallInfo {
203    /// Tool name — either a `BuiltinTools` string or a custom tool name.
204    pub name: String,
205    /// Arguments as a JSON value (typically an object/dict).
206    #[serde(default)]
207    pub args: serde_json::Value,
208    /// Optional unique identifier for the call.
209    #[serde(default)]
210    pub id: Option<String>,
211    /// Optional normalized filesystem path for file-related tools.
212    #[serde(default)]
213    pub canonical_path: Option<String>,
214}
215
216/// Result of a single tool execution, mirroring the Python SDK's `ToolResult`.
217#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
218pub struct ToolResult {
219    /// The name of the tool that was executed.
220    pub name: String,
221    /// Optional identifier correlating this result with a `ToolCallInfo.id`.
222    #[serde(default)]
223    pub id: Option<String>,
224    /// The tool's return value (any JSON-serializable value).
225    #[serde(default)]
226    pub result: serde_json::Value,
227    /// An error message if execution failed, or `None` on success.
228    #[serde(default)]
229    pub error: Option<String>,
230}
231
232/// Token usage metadata from the model API, mirroring the SDK's `UsageMetadata`.
233#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
234pub struct UsageMetadata {
235    /// Number of tokens in the prompt.
236    #[serde(default)]
237    pub prompt_token_count: Option<u64>,
238    /// Number of tokens from cached content (subset of prompt tokens).
239    #[serde(default)]
240    pub cached_content_token_count: Option<u64>,
241    /// Number of tokens in the generated candidates (excluding thinking).
242    #[serde(default)]
243    pub candidates_token_count: Option<u64>,
244    /// Number of tokens used for thinking/reasoning.
245    #[serde(default)]
246    pub thoughts_token_count: Option<u64>,
247    /// Sum of prompt + candidates + thinking tokens.
248    #[serde(default)]
249    pub total_token_count: Option<u64>,
250}
251
252/// The role of a message author in the conversation.
253#[non_exhaustive]
254#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
255#[serde(rename_all = "lowercase")]
256#[derive(Default)]
257pub enum MessageRole {
258    /// A user-authored message.
259    #[default]
260    User,
261    /// A model-generated message.
262    Model,
263    /// A system-level message.
264    System,
265    /// An unrecognized role — preserves the original string for forward
266    /// compatibility with new SDK roles.
267    #[serde(untagged)]
268    Unknown(String),
269}
270
271impl std::fmt::Display for MessageRole {
272    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273        match self {
274            Self::User => f.write_str("user"),
275            Self::Model => f.write_str("model"),
276            Self::System => f.write_str("system"),
277            Self::Unknown(s) => f.write_str(s),
278        }
279    }
280}
281
282/// A single message in the conversation history, mirroring the Python SDK's
283/// `ConversationMessage`.
284///
285/// Each message has a [`MessageRole`] and textual `content`.
286#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
287pub struct ConversationMessage {
288    /// The role of the message author.
289    #[serde(default)]
290    pub role: MessageRole,
291    /// The textual content of the message.
292    #[serde(default)]
293    pub content: String,
294}
295
296/// A single step in the agent trajectory, mirroring the SDK's `Step`.
297#[non_exhaustive]
298#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
299pub struct Step {
300    /// Unique string identifier for the step.
301    #[serde(default)]
302    pub id: String,
303    /// Integer index of the step in the trajectory.
304    #[serde(default)]
305    pub step_index: u32,
306    /// The high-level type of the step.
307    #[serde(default, rename = "type")]
308    pub step_type: StepType,
309    /// The source that generated the step.
310    #[serde(default)]
311    pub source: StepSource,
312    /// The target of the step interaction.
313    #[serde(default)]
314    pub target: StepTarget,
315    /// The status of the step.
316    #[serde(default)]
317    pub status: StepStatus,
318    /// The text content/output of the step.
319    #[serde(default)]
320    pub content: String,
321    /// Incremental text content added since the last update for this step.
322    #[serde(default)]
323    pub content_delta: String,
324    /// Full model reasoning/thinking text for planner responses.
325    #[serde(default)]
326    pub thinking: String,
327    /// Incremental thinking text added since the last update for this step.
328    #[serde(default)]
329    pub thinking_delta: String,
330    /// List of tool calls associated with the step.
331    #[serde(default)]
332    pub tool_calls: Vec<ToolCallInfo>,
333    /// Short error message if the step failed.
334    #[serde(default)]
335    pub error: String,
336    /// Whether this step is a completed model response directed at the user.
337    ///
338    /// Multiple steps per turn may have this flag set; consumers wanting only
339    /// the last response should iterate fully.
340    #[serde(default)]
341    pub is_complete_response: Option<bool>,
342    /// Structured output payload extracted from the FINISH step.
343    ///
344    /// This is `serde_json::Value` because it contains user-defined schema data
345    /// whose shape is not known at compile time.
346    #[serde(default)]
347    pub structured_output: Option<serde_json::Value>,
348    /// Token usage for this step's model invocation.
349    #[serde(default)]
350    pub usage_metadata: Option<UsageMetadata>,
351}
352
/// Implements `pyo3::FromPyObject` for each listed type.
///
/// The Python value is passed to `pythonize::depythonize`, which drives the
/// type's serde `Deserialize` impl. Any failure is raised in Python as a
/// `ValueError` naming the target type.
///
/// The message says "Python object" rather than "dict" because the macro is
/// also applied to enums, whose serde form is a plain string. A trailing comma
/// in the type list is accepted.
macro_rules! impl_from_py_object {
    ($($t:ty),+ $(,)?) => {
        $(
            impl<'a, 'py> pyo3::FromPyObject<'a, 'py> for $t {
                type Error = pyo3::PyErr;

                fn extract(ob: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
                    pythonize::depythonize(&*ob).map_err(|e| {
                        pyo3::exceptions::PyValueError::new_err(format!(
                            "Failed to deserialize {} from Python object: {}",
                            stringify!($t),
                            e
                        ))
                    })
                }
            }
        )+
    };
}
372
373impl_from_py_object!(
374    StepType,
375    StepSource,
376    StepStatus,
377    StepTarget,
378    ToolCallInfo,
379    ToolResult,
380    UsageMetadata,
381    MessageRole,
382    ConversationMessage,
383    Step
384);
385
386#[cfg(test)]
387mod tests {
388    use super::*;
389
390    // =========================================================================
391    // Step / ToolCall / ToolResult tests
392    // =========================================================================
393
394    #[test]
395    fn test_step_type_roundtrip() {
396        for (variant, expected_str) in [
397            (StepType::TextResponse, "\"TEXT_RESPONSE\""),
398            (StepType::ToolCall, "\"TOOL_CALL\""),
399            (StepType::SystemMessage, "\"SYSTEM_MESSAGE\""),
400            (StepType::Compaction, "\"COMPACTION\""),
401            (StepType::Finish, "\"FINISH\""),
402            (StepType::Unknown, "\"UNKNOWN\""),
403        ] {
404            let json = serde_json::to_string(&variant).unwrap();
405            assert_eq!(
406                json, expected_str,
407                "StepType serialization mismatch for {variant:?}"
408            );
409            let parsed: StepType = serde_json::from_str(&json).unwrap();
410            assert_eq!(parsed, variant);
411        }
412    }
413
414    #[test]
415    fn test_step_type_parse() {
416        assert_eq!(
417            "TEXT_RESPONSE".parse::<StepType>().unwrap(),
418            StepType::TextResponse
419        );
420        assert_eq!("TOOL_CALL".parse::<StepType>().unwrap(), StepType::ToolCall);
421        assert_eq!(
422            "SYSTEM_MESSAGE".parse::<StepType>().unwrap(),
423            StepType::SystemMessage
424        );
425        assert_eq!(
426            "COMPACTION".parse::<StepType>().unwrap(),
427            StepType::Compaction
428        );
429        assert_eq!("FINISH".parse::<StepType>().unwrap(), StepType::Finish);
430    }
431
432    #[test]
433    fn test_step_source_roundtrip() {
434        for (variant, expected_str) in [
435            (StepSource::System, "\"SYSTEM\""),
436            (StepSource::User, "\"USER\""),
437            (StepSource::Model, "\"MODEL\""),
438            (StepSource::Unknown, "\"UNKNOWN\""),
439        ] {
440            let json = serde_json::to_string(&variant).unwrap();
441            assert_eq!(json, expected_str);
442            let parsed: StepSource = serde_json::from_str(&json).unwrap();
443            assert_eq!(parsed, variant);
444        }
445    }
446
447    #[test]
448    fn test_step_source_parse() {
449        assert_eq!("SYSTEM".parse::<StepSource>().unwrap(), StepSource::System);
450        assert_eq!("USER".parse::<StepSource>().unwrap(), StepSource::User);
451        assert_eq!("MODEL".parse::<StepSource>().unwrap(), StepSource::Model);
452    }
453
454    #[test]
455    fn test_step_status_roundtrip() {
456        for (variant, expected_str) in [
457            (StepStatus::Active, "\"ACTIVE\""),
458            (StepStatus::Done, "\"DONE\""),
459            (StepStatus::WaitingForUser, "\"WAITING_FOR_USER\""),
460            (StepStatus::Error, "\"ERROR\""),
461            (StepStatus::Canceled, "\"CANCELED\""),
462            (StepStatus::Unknown, "\"UNKNOWN\""),
463        ] {
464            let json = serde_json::to_string(&variant).unwrap();
465            assert_eq!(json, expected_str);
466            let parsed: StepStatus = serde_json::from_str(&json).unwrap();
467            assert_eq!(parsed, variant);
468        }
469    }
470
471    #[test]
472    fn test_step_status_parse() {
473        assert_eq!("ACTIVE".parse::<StepStatus>().unwrap(), StepStatus::Active);
474        assert_eq!("DONE".parse::<StepStatus>().unwrap(), StepStatus::Done);
475        assert_eq!(
476            "WAITING_FOR_USER".parse::<StepStatus>().unwrap(),
477            StepStatus::WaitingForUser
478        );
479        assert_eq!("ERROR".parse::<StepStatus>().unwrap(), StepStatus::Error);
480        assert_eq!(
481            "CANCELED".parse::<StepStatus>().unwrap(),
482            StepStatus::Canceled
483        );
484    }
485
486    #[test]
487    fn test_step_type_parse_returns_err_for_unrecognized() {
488        assert!("NONEXISTENT".parse::<StepType>().is_err());
489    }
490
491    #[test]
492    fn test_step_source_parse_returns_err_for_unrecognized() {
493        assert!("???".parse::<StepSource>().is_err());
494    }
495
496    #[test]
497    fn test_step_status_parse_returns_err_for_unrecognized() {
498        assert!("nope".parse::<StepStatus>().is_err());
499    }
500
501    #[test]
502    fn test_tool_call_info_roundtrip() {
503        let tc = ToolCallInfo {
504            name: "view_file".to_string(),
505            args: serde_json::json!({"path": "/tmp/foo.rs", "line": 42}),
506            id: Some("call_123".to_string()),
507            canonical_path: Some("/tmp/foo.rs".to_string()),
508        };
509        let json = serde_json::to_string(&tc).unwrap();
510        let parsed: ToolCallInfo = serde_json::from_str(&json).unwrap();
511        assert_eq!(parsed, tc);
512    }
513
514    #[test]
515    fn test_tool_call_info_minimal() {
516        let json = r#"{"name":"custom_tool"}"#;
517        let parsed: ToolCallInfo = serde_json::from_str(json).unwrap();
518        assert_eq!(parsed.name, "custom_tool");
519        assert_eq!(parsed.args, serde_json::Value::Null);
520        assert!(parsed.id.is_none());
521        assert!(parsed.canonical_path.is_none());
522    }
523
524    #[test]
525    fn test_tool_result_roundtrip() {
526        let tr = ToolResult {
527            name: "run_command".to_string(),
528            id: Some("result_456".to_string()),
529            result: serde_json::json!({"output": "hello world"}),
530            error: None,
531        };
532        let json = serde_json::to_string(&tr).unwrap();
533        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
534        assert_eq!(parsed, tr);
535    }
536
537    #[test]
538    fn test_tool_result_with_error() {
539        let tr = ToolResult {
540            name: "create_file".to_string(),
541            id: None,
542            result: serde_json::Value::Null,
543            error: Some("permission denied".to_string()),
544        };
545        let json = serde_json::to_string(&tr).unwrap();
546        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
547        assert_eq!(parsed.error.as_deref(), Some("permission denied"));
548    }
549
550    #[test]
551    fn test_usage_metadata_roundtrip() {
552        let um = UsageMetadata {
553            prompt_token_count: Some(100),
554            cached_content_token_count: Some(20),
555            candidates_token_count: Some(50),
556            thoughts_token_count: Some(30),
557            total_token_count: Some(180),
558        };
559        let json = serde_json::to_string(&um).unwrap();
560        let parsed: UsageMetadata = serde_json::from_str(&json).unwrap();
561        assert_eq!(parsed, um);
562    }
563
564    #[test]
565    fn test_usage_metadata_defaults() {
566        let um: UsageMetadata = serde_json::from_str("{}").unwrap();
567        assert!(um.prompt_token_count.is_none());
568        assert!(um.total_token_count.is_none());
569    }
570
571    #[test]
572    fn test_step_full_roundtrip() {
573        let step = Step {
574            id: "traj:0".to_string(),
575            step_index: 3,
576            step_type: StepType::ToolCall,
577            source: StepSource::Model,
578            target: StepTarget::Environment,
579            status: StepStatus::Done,
580            content: "Running command...".to_string(),
581            content_delta: "Running".to_string(),
582            thinking: "I should run the command".to_string(),
583            thinking_delta: "I should".to_string(),
584            tool_calls: vec![ToolCallInfo {
585                name: "run_command".to_string(),
586                args: serde_json::json!({"command": "ls -la"}),
587                id: Some("call_1".to_string()),
588                canonical_path: None,
589            }],
590            error: String::new(),
591            is_complete_response: Some(false),
592            structured_output: None,
593            usage_metadata: Some(UsageMetadata {
594                prompt_token_count: Some(500),
595                cached_content_token_count: None,
596                candidates_token_count: Some(100),
597                thoughts_token_count: Some(50),
598                total_token_count: Some(650),
599            }),
600        };
601
602        let json = serde_json::to_string_pretty(&step).unwrap();
603        let parsed: Step = serde_json::from_str(&json).unwrap();
604        assert_eq!(parsed, step);
605        assert_eq!(parsed.tool_calls.len(), 1);
606        assert_eq!(parsed.tool_calls[0].name, "run_command");
607    }
608
609    #[test]
610    fn test_step_minimal_deserialization() {
611        // Should deserialize with all defaults.
612        let json = r#"{"id":"s1"}"#;
613        let step: Step = serde_json::from_str(json).unwrap();
614        assert_eq!(step.id, "s1");
615        assert_eq!(step.step_index, 0);
616        assert_eq!(step.step_type, StepType::Unknown);
617        assert_eq!(step.source, StepSource::Unknown);
618        assert_eq!(step.target, StepTarget::Unknown);
619        assert_eq!(step.status, StepStatus::Unknown);
620        assert!(step.content.is_empty());
621        assert!(step.content_delta.is_empty());
622        assert!(step.thinking.is_empty());
623        assert!(step.thinking_delta.is_empty());
624        assert!(step.tool_calls.is_empty());
625        assert!(step.error.is_empty());
626        assert!(step.is_complete_response.is_none());
627        assert!(step.structured_output.is_none());
628        assert!(step.usage_metadata.is_none());
629    }
630
631    // =========================================================================
632    // Step with multiple tool calls
633    // =========================================================================
634
635    #[test]
636    fn step_with_multiple_tool_calls() {
637        let step = Step {
638            id: "multi-tc".to_string(),
639            step_index: 7,
640            step_type: StepType::ToolCall,
641            source: StepSource::Model,
642            target: StepTarget::Environment,
643            status: StepStatus::Done,
644            content: String::new(),
645            content_delta: String::new(),
646            thinking: String::new(),
647            thinking_delta: String::new(),
648            tool_calls: vec![
649                ToolCallInfo {
650                    name: "view_file".to_string(),
651                    args: serde_json::json!({"path": "/a.rs"}),
652                    id: Some("tc1".to_string()),
653                    canonical_path: Some("/a.rs".to_string()),
654                },
655                ToolCallInfo {
656                    name: "run_command".to_string(),
657                    args: serde_json::json!({"command": "cargo test"}),
658                    id: Some("tc2".to_string()),
659                    canonical_path: None,
660                },
661            ],
662            error: String::new(),
663            is_complete_response: None,
664            structured_output: None,
665            usage_metadata: None,
666        };
667        let json = serde_json::to_string(&step).unwrap();
668        let parsed: Step = serde_json::from_str(&json).unwrap();
669        assert_eq!(parsed.tool_calls.len(), 2);
670        assert_eq!(parsed.tool_calls[0].name, "view_file");
671        assert_eq!(parsed.tool_calls[1].name, "run_command");
672        assert_eq!(
673            parsed.tool_calls[0].canonical_path.as_deref(),
674            Some("/a.rs")
675        );
676        assert!(parsed.tool_calls[1].canonical_path.is_none());
677    }
678
679    // =========================================================================
680    // ToolCallInfo / ToolResult edge cases
681    // =========================================================================
682
683    #[test]
684    fn tool_call_info_with_complex_args() {
685        let tc = ToolCallInfo {
686            name: "run_command".to_string(),
687            args: serde_json::json!({
688                "command": "cargo test",
689                "env": {"RUST_LOG": "debug"},
690                "timeout": 300,
691                "nested": [1, 2, {"deep": true}]
692            }),
693            id: None,
694            canonical_path: None,
695        };
696        let json = serde_json::to_string(&tc).unwrap();
697        let parsed: ToolCallInfo = serde_json::from_str(&json).unwrap();
698        assert_eq!(parsed.args["env"]["RUST_LOG"], "debug");
699        assert_eq!(parsed.args["nested"][2]["deep"], true);
700    }
701
702    #[test]
703    fn tool_result_with_complex_result() {
704        let tr = ToolResult {
705            name: "search_dir".to_string(),
706            id: Some("r1".to_string()),
707            result: serde_json::json!({
708                "matches": [
709                    {"file": "/src/main.rs", "line": 42},
710                    {"file": "/src/lib.rs", "line": 10},
711                ],
712                "total": 2
713            }),
714            error: None,
715        };
716        let json = serde_json::to_string(&tr).unwrap();
717        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
718        assert_eq!(parsed.result["total"], 2);
719        assert_eq!(parsed.result["matches"][0]["line"], 42);
720    }
721
722    // =========================================================================
723    // UsageMetadata partial fields
724    // =========================================================================
725
726    #[test]
727    fn usage_metadata_partial_fields() {
728        let json = r#"{"prompt_token_count":100,"total_token_count":200}"#;
729        let um: UsageMetadata = serde_json::from_str(json).unwrap();
730        assert_eq!(um.prompt_token_count, Some(100));
731        assert!(um.cached_content_token_count.is_none());
732        assert!(um.candidates_token_count.is_none());
733        assert!(um.thoughts_token_count.is_none());
734        assert_eq!(um.total_token_count, Some(200));
735    }
736
737    // =========================================================================
738    // StepTarget tests
739    // =========================================================================
740
741    #[test]
742    fn test_step_target_roundtrip() {
743        for (variant, expected_str) in [
744            (StepTarget::User, "\"TARGET_USER\""),
745            (StepTarget::Environment, "\"TARGET_ENVIRONMENT\""),
746            (StepTarget::Unspecified, "\"TARGET_UNSPECIFIED\""),
747            (StepTarget::Unknown, "\"UNKNOWN\""),
748        ] {
749            let json = serde_json::to_string(&variant).unwrap();
750            assert_eq!(
751                json, expected_str,
752                "StepTarget serialization mismatch for {variant:?}"
753            );
754            let parsed: StepTarget = serde_json::from_str(&json).unwrap();
755            assert_eq!(parsed, variant);
756        }
757    }
758
759    #[test]
760    fn test_step_target_parse() {
761        assert_eq!(
762            "TARGET_MODEL".parse::<StepTarget>().unwrap(),
763            StepTarget::Model
764        );
765        assert_eq!(
766            "TARGET_USER".parse::<StepTarget>().unwrap(),
767            StepTarget::User
768        );
769        assert_eq!(
770            "TARGET_ENVIRONMENT".parse::<StepTarget>().unwrap(),
771            StepTarget::Environment
772        );
773        assert_eq!(
774            "TARGET_UNSPECIFIED".parse::<StepTarget>().unwrap(),
775            StepTarget::Unspecified
776        );
777        assert_eq!(
778            "UNKNOWN".parse::<StepTarget>().unwrap(),
779            StepTarget::Unknown
780        );
781    }
782
783    #[test]
784    fn test_step_target_parse_returns_err_for_unrecognized() {
785        assert!("INVALID_TARGET".parse::<StepTarget>().is_err());
786    }
787
788    // =========================================================================
789    // Display trait tests
790    // =========================================================================
791
792    #[test]
793    fn test_step_type_display() {
794        assert_eq!(StepType::TextResponse.to_string(), "TEXT_RESPONSE");
795        assert_eq!(StepType::ToolCall.to_string(), "TOOL_CALL");
796        assert_eq!(StepType::SystemMessage.to_string(), "SYSTEM_MESSAGE");
797        assert_eq!(StepType::Compaction.to_string(), "COMPACTION");
798        assert_eq!(StepType::Finish.to_string(), "FINISH");
799        assert_eq!(StepType::Unknown.to_string(), "UNKNOWN");
800    }
801
802    #[test]
803    fn test_step_source_display() {
804        assert_eq!(StepSource::System.to_string(), "SYSTEM");
805        assert_eq!(StepSource::User.to_string(), "USER");
806        assert_eq!(StepSource::Model.to_string(), "MODEL");
807        assert_eq!(StepSource::Unknown.to_string(), "UNKNOWN");
808    }
809
810    #[test]
811    fn test_step_status_display() {
812        assert_eq!(StepStatus::Active.to_string(), "ACTIVE");
813        assert_eq!(StepStatus::Done.to_string(), "DONE");
814        assert_eq!(StepStatus::WaitingForUser.to_string(), "WAITING_FOR_USER");
815        assert_eq!(StepStatus::Error.to_string(), "ERROR");
816        assert_eq!(StepStatus::Canceled.to_string(), "CANCELED");
817        assert_eq!(StepStatus::Unknown.to_string(), "UNKNOWN");
818    }
819
820    #[test]
821    fn test_step_target_display() {
822        assert_eq!(StepTarget::User.to_string(), "TARGET_USER");
823        assert_eq!(StepTarget::Environment.to_string(), "TARGET_ENVIRONMENT");
824        assert_eq!(StepTarget::Unspecified.to_string(), "TARGET_UNSPECIFIED");
825        assert_eq!(StepTarget::Unknown.to_string(), "UNKNOWN");
826    }
827
828    // =========================================================================
829    // Display → FromStr roundtrip tests
830    // =========================================================================
831
832    #[test]
833    fn test_step_type_display_from_str_roundtrip() {
834        for variant in [
835            StepType::TextResponse,
836            StepType::ToolCall,
837            StepType::SystemMessage,
838            StepType::Compaction,
839            StepType::Finish,
840            StepType::Unknown,
841        ] {
842            let s = variant.to_string();
843            let parsed: StepType = s.parse().unwrap();
844            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
845        }
846    }
847
848    #[test]
849    fn test_step_source_display_from_str_roundtrip() {
850        for variant in [
851            StepSource::System,
852            StepSource::User,
853            StepSource::Model,
854            StepSource::Unknown,
855        ] {
856            let s = variant.to_string();
857            let parsed: StepSource = s.parse().unwrap();
858            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
859        }
860    }
861
862    #[test]
863    fn test_step_status_display_from_str_roundtrip() {
864        for variant in [
865            StepStatus::Active,
866            StepStatus::Done,
867            StepStatus::WaitingForUser,
868            StepStatus::Error,
869            StepStatus::Canceled,
870            StepStatus::Unknown,
871        ] {
872            let s = variant.to_string();
873            let parsed: StepStatus = s.parse().unwrap();
874            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
875        }
876    }
877
878    #[test]
879    fn test_step_target_display_from_str_roundtrip() {
880        for variant in [
881            StepTarget::Model,
882            StepTarget::User,
883            StepTarget::Environment,
884            StepTarget::Unspecified,
885            StepTarget::Unknown,
886        ] {
887            let s = variant.to_string();
888            let parsed: StepTarget = s.parse().unwrap();
889            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
890        }
891    }
892
893    // =========================================================================
894    // FromStr with garbage input tests
895    // =========================================================================
896
897    #[test]
898    fn test_from_str_garbage_returns_err() {
899        assert!("xyzzy".parse::<StepType>().is_err());
900        assert!("xyzzy".parse::<StepSource>().is_err());
901        assert!("xyzzy".parse::<StepStatus>().is_err());
902        assert!("xyzzy".parse::<StepTarget>().is_err());
903    }
904
905    #[test]
906    fn test_from_str_empty_returns_err() {
907        assert!("".parse::<StepType>().is_err());
908        assert!("".parse::<StepSource>().is_err());
909        assert!("".parse::<StepStatus>().is_err());
910        assert!("".parse::<StepTarget>().is_err());
911    }
912
913    #[test]
914    fn test_from_str_case_sensitive() {
915        // SDK strings are case-sensitive — lowercase should return Err.
916        assert!("text_response".parse::<StepType>().is_err());
917        assert!("system".parse::<StepSource>().is_err());
918        assert!("active".parse::<StepStatus>().is_err());
919        assert!("target_user".parse::<StepTarget>().is_err());
920    }
921
922    // =========================================================================
923    // MessageRole / ConversationMessage Tests
924    // =========================================================================
925
926    #[test]
927    fn test_message_role_roundtrip() {
928        for (variant, expected_str) in [
929            (MessageRole::User, "\"user\""),
930            (MessageRole::Model, "\"model\""),
931            (MessageRole::System, "\"system\""),
932            (MessageRole::Unknown("custom".to_string()), "\"custom\""),
933        ] {
934            let json = serde_json::to_string(&variant).unwrap();
935            assert_eq!(json, expected_str);
936            let parsed: MessageRole = serde_json::from_str(&json).unwrap();
937            assert_eq!(parsed, variant);
938        }
939    }
940
941    #[test]
942    fn test_conversation_message_roundtrip() {
943        let msg = ConversationMessage {
944            role: MessageRole::Model,
945            content: "Hello!".to_string(),
946        };
947        let json = serde_json::to_string(&msg).unwrap();
948        let parsed: ConversationMessage = serde_json::from_str(&json).unwrap();
949        assert_eq!(parsed, msg);
950    }
951
952    // =========================================================================
953    // PyO3 Extract Tests
954    // =========================================================================
955
956    #[test]
957    fn test_pyo3_extract_roundtrip() {
958        use pyo3::{prelude::*, types::PyDictMethods};
959        pyo3::Python::initialize();
960        pyo3::Python::attach(|py| {
961            let dict = pyo3::types::PyDict::new(py);
962            dict.set_item("id", "step-1").unwrap();
963            dict.set_item("step_index", 42).unwrap();
964            dict.set_item("type", "TEXT_RESPONSE").unwrap();
965            dict.set_item("source", "MODEL").unwrap();
966            dict.set_item("target", "TARGET_USER").unwrap();
967            dict.set_item("status", "DONE").unwrap();
968
969            let step: Step = dict.extract().expect("failed to extract Step");
970            assert_eq!(step.id, "step-1");
971            assert_eq!(step.step_index, 42);
972            assert_eq!(step.step_type, StepType::TextResponse);
973            assert_eq!(step.source, StepSource::Model);
974            assert_eq!(step.target, StepTarget::User);
975            assert_eq!(step.status, StepStatus::Done);
976
977            // Now test an enum
978            let s = pyo3::types::PyString::new(py, "SYSTEM_MESSAGE");
979            let st: StepType = s.extract().unwrap();
980            assert_eq!(st, StepType::SystemMessage);
981        });
982    }
983
984    // =========================================================================
    // Step delta / thinking / structured-output field tests
986    // =========================================================================
987
988    #[test]
989    fn step_with_deltas_and_thinking() {
990        let step = Step {
991            id: "s2".to_string(),
992            step_index: 1,
993            step_type: StepType::TextResponse,
994            source: StepSource::Model,
995            target: StepTarget::User,
996            status: StepStatus::Active,
997            content: "Hello world".to_string(),
998            content_delta: "world".to_string(),
999            thinking: "The user said hi".to_string(),
1000            thinking_delta: "said hi".to_string(),
1001            tool_calls: vec![],
1002            error: String::new(),
1003            is_complete_response: Some(true),
1004            structured_output: None,
1005            usage_metadata: None,
1006        };
1007        let json = serde_json::to_string(&step).unwrap();
1008        let parsed: Step = serde_json::from_str(&json).unwrap();
1009        assert_eq!(parsed.content_delta, "world");
1010        assert_eq!(parsed.thinking, "The user said hi");
1011        assert_eq!(parsed.thinking_delta, "said hi");
1012        assert_eq!(parsed.is_complete_response, Some(true));
1013        assert_eq!(parsed.target, StepTarget::User);
1014    }
1015
1016    #[test]
1017    fn step_with_structured_output() {
1018        let payload = serde_json::json!({"answer": 42, "valid": true});
1019        let step = Step {
1020            id: "finish-1".to_string(),
1021            step_index: 5,
1022            step_type: StepType::Finish,
1023            source: StepSource::Model,
1024            target: StepTarget::User,
1025            status: StepStatus::Done,
1026            content: String::new(),
1027            content_delta: String::new(),
1028            thinking: String::new(),
1029            thinking_delta: String::new(),
1030            tool_calls: vec![],
1031            error: String::new(),
1032            is_complete_response: Some(true),
1033            structured_output: Some(payload.clone()),
1034            usage_metadata: None,
1035        };
1036        let json = serde_json::to_string(&step).unwrap();
1037        let parsed: Step = serde_json::from_str(&json).unwrap();
1038        assert_eq!(parsed.structured_output, Some(payload));
1039        assert_eq!(parsed.step_type, StepType::Finish);
1040    }
1041}