1use std::{fmt, str::FromStr};
11
12use serde::{Deserialize, Serialize};
13
/// Defines a fieldless, `Copy` SDK enum whose wire string is spelled out per variant.
///
/// Each `Variant => "WIRE"` entry produces:
/// - an enum variant that serde (de)serializes as the string `"WIRE"`,
/// - a `Display` arm that writes `"WIRE"`,
/// - a `FromStr` arm that accepts exactly `"WIRE"` (case-sensitive).
///
/// Any other input to `FromStr` yields `Err("Unrecognized <Name>: \"<input>\"")`.
/// `Default` is derived, so exactly one variant must be marked `#[default]`.
///
/// The serde name is taken from the `$wire` literal via `#[serde(rename = ...)]`
/// rather than a container-wide `rename_all`. That way serde, `Display`, and
/// `FromStr` share a single source of truth and cannot drift apart when a wire
/// string is not the SCREAMING_SNAKE_CASE form of its variant name.
macro_rules! define_sdk_enum {
    (
        $(#[$meta:meta])*
        $name:ident {
            $(
                $(#[$vmeta:meta])*
                $variant:ident => $wire:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
        pub enum $name {
            $(
                $(#[$vmeta])*
                #[serde(rename = $wire)]
                $variant,
            )+
        }

        impl fmt::Display for $name {
            /// Writes the variant's wire string.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let s = match self {
                    $( Self::$variant => $wire, )+
                };
                f.write_str(s)
            }
        }

        impl FromStr for $name {
            type Err = String;

            /// Parses an exact (case-sensitive) wire string.
            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $( $wire => Ok(Self::$variant), )+
                    other => Err(format!(concat!("Unrecognized ", stringify!($name), ": {:?}"), other)),
                }
            }
        }
    };
}
78
/// Defines a fieldless, `Copy` SDK enum whose serde name is set per variant
/// from the given wire literal (`#[serde(rename = "...")]`).
///
/// For every `Variant => "WIRE"` entry, serde, `Display`, and `FromStr` all use
/// `"WIRE"`. `FromStr` is case-sensitive and reports unknown input as
/// `Err("Unrecognized <Name>: \"<input>\"")`. `Default` is derived, so one
/// variant must carry `#[default]`.
macro_rules! define_sdk_enum_custom_serde {
    (
        $(#[$meta:meta])*
        $name:ident {
            $(
                $(#[$vmeta:meta])*
                $variant:ident => $wire:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        #[derive(Clone, Copy, Debug, Default, Eq, PartialEq, Serialize, Deserialize)]
        pub enum $name {
            $(
                $(#[$vmeta])*
                #[serde(rename = $wire)]
                $variant,
            )+
        }

        impl fmt::Display for $name {
            /// Emits the wire string associated with this variant.
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str(match *self {
                    $( $name::$variant => $wire, )+
                })
            }
        }

        impl FromStr for $name {
            type Err = String;

            /// Returns the first variant whose wire string equals `input` exactly.
            fn from_str(input: &str) -> Result<Self, Self::Err> {
                $(
                    if input == $wire {
                        return Ok($name::$variant);
                    }
                )+
                Err(format!(
                    concat!("Unrecognized ", stringify!($name), ": {:?}"),
                    input
                ))
            }
        }
    };
}
125
define_sdk_enum! {
    /// Kind of a [`Step`], carried in its JSON `"type"` field.
    ///
    /// serde, `Display`, and `FromStr` all use the string listed next to each
    /// variant. Unrecognized strings are rejected by `FromStr`; serde also
    /// rejects them, because no catch-all variant is declared.
    StepType {
        TextResponse => "TEXT_RESPONSE",
        ToolCall => "TOOL_CALL",
        SystemMessage => "SYSTEM_MESSAGE",
        Compaction => "COMPACTION",
        Finish => "FINISH",
        /// Default variant: a [`Step`] whose `"type"` key is absent gets this value.
        /// It is only parsed from the literal `"UNKNOWN"`.
        #[default]
        Unknown => "UNKNOWN",
    }
}
138
define_sdk_enum! {
    /// Originator of a [`Step`] (its `source` field).
    ///
    /// serde, `Display`, and `FromStr` all use the string listed next to each
    /// variant; other strings are rejected.
    StepSource {
        System => "SYSTEM",
        User => "USER",
        Model => "MODEL",
        /// Default variant, used when a [`Step`] has no `source` key.
        #[default]
        Unknown => "UNKNOWN",
    }
}
149
define_sdk_enum! {
    /// Status of a [`Step`] (its `status` field).
    ///
    /// serde, `Display`, and `FromStr` all use the string listed next to each
    /// variant; other strings are rejected.
    StepStatus {
        Active => "ACTIVE",
        Done => "DONE",
        WaitingForUser => "WAITING_FOR_USER",
        Error => "ERROR",
        Canceled => "CANCELED",
        /// Default variant, used when a [`Step`] has no `status` key.
        #[default]
        Unknown => "UNKNOWN",
    }
}
162
// Uses the custom-serde macro because the wire names carry a `TARGET_` prefix
// that cannot be derived from the variant identifiers.
define_sdk_enum_custom_serde! {
    /// Intended recipient of a [`Step`] (its `target` field).
    ///
    /// serde, `Display`, and `FromStr` all use the string listed next to each
    /// variant (e.g. `User` <-> `"TARGET_USER"`); other strings, including the
    /// unprefixed `"USER"`, are rejected.
    StepTarget {
        Model => "TARGET_MODEL",
        User => "TARGET_USER",
        Environment => "TARGET_ENVIRONMENT",
        /// Explicitly unspecified target, as sent on the wire.
        Unspecified => "TARGET_UNSPECIFIED",
        /// Default variant, used when a [`Step`] has no `target` key.
        /// Unlike the others, its wire string has no `TARGET_` prefix.
        #[default]
        Unknown => "UNKNOWN",
    }
}
183
/// A tool invocation, as listed in `Step::tool_calls`.
///
/// Only `name` is required when deserializing. The other fields default to
/// `Value::Null` / `None` when absent. `None` is serialized as `null`
/// (no `skip_serializing_if`), and fields serialize in declaration order.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallInfo {
    /// Name of the tool being invoked.
    pub name: String,
    /// Tool arguments as arbitrary JSON; `Value::Null` when absent.
    #[serde(default)]
    pub args: serde_json::Value,
    /// Call identifier, if provided.
    #[serde(default)]
    pub id: Option<String>,
    /// Path associated with the call, if any.
    // NOTE(review): presumably the normalized path of the file the tool acts on
    // (tests pair it with a `path` arg) — confirm with the producer.
    #[serde(default)]
    pub canonical_path: Option<String>,
}
199
/// Outcome of a tool invocation.
///
/// Only `name` is required when deserializing; `id`/`error` default to `None`
/// and `result` to `Value::Null`.
// NOTE(review): presumably correlated with a `ToolCallInfo` through `id` — confirm.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolResult {
    /// Name of the tool that produced this result.
    pub name: String,
    /// Identifier of the result/call, if provided.
    #[serde(default)]
    pub id: Option<String>,
    /// Tool output as arbitrary JSON; `Value::Null` when absent.
    #[serde(default)]
    pub result: serde_json::Value,
    /// Error message, if the tool reported one.
    #[serde(default)]
    pub error: Option<String>,
}
215
216#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
218pub struct UsageMetadata {
219 #[serde(default)]
221 pub prompt_token_count: Option<u64>,
222 #[serde(default)]
224 pub cached_content_token_count: Option<u64>,
225 #[serde(default)]
227 pub candidates_token_count: Option<u64>,
228 #[serde(default)]
230 pub thoughts_token_count: Option<u64>,
231 #[serde(default)]
233 pub total_token_count: Option<u64>,
234}
235
236#[non_exhaustive]
238#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
239#[serde(rename_all = "lowercase")]
240#[derive(Default)]
241pub enum MessageRole {
242 #[default]
244 User,
245 Model,
247 System,
249 #[serde(untagged)]
252 Unknown(String),
253}
254
255impl std::fmt::Display for MessageRole {
256 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
257 match self {
258 Self::User => f.write_str("user"),
259 Self::Model => f.write_str("model"),
260 Self::System => f.write_str("system"),
261 Self::Unknown(s) => f.write_str(s),
262 }
263 }
264}
265
266#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
271pub struct ConversationMessage {
272 #[serde(default)]
274 pub role: MessageRole,
275 #[serde(default)]
277 pub content: String,
278}
279
280#[non_exhaustive]
282#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
283pub struct Step {
284 #[serde(default)]
286 pub id: String,
287 #[serde(default)]
289 pub step_index: u32,
290 #[serde(default, rename = "type")]
292 pub step_type: StepType,
293 #[serde(default)]
295 pub source: StepSource,
296 #[serde(default)]
298 pub target: StepTarget,
299 #[serde(default)]
301 pub status: StepStatus,
302 #[serde(default)]
304 pub content: String,
305 #[serde(default)]
307 pub content_delta: String,
308 #[serde(default)]
310 pub thinking: String,
311 #[serde(default)]
313 pub thinking_delta: String,
314 #[serde(default)]
316 pub tool_calls: Vec<ToolCallInfo>,
317 #[serde(default)]
319 pub error: String,
320 #[serde(default)]
325 pub is_complete_response: Option<bool>,
326 #[serde(default)]
331 pub structured_output: Option<serde_json::Value>,
332 #[serde(default)]
334 pub usage_metadata: Option<UsageMetadata>,
335}
336
/// Implements `pyo3::FromPyObject` for each listed type.
///
/// The Python object is converted with `pythonize::depythonize`, which drives
/// the type's serde `Deserialize` impl. Any failure is reported to Python as a
/// `ValueError` that names the target type.
macro_rules! impl_from_py_object {
    ($($t:ty),+) => {
        $(
            impl<'py> pyo3::FromPyObject<'py> for $t {
                fn extract_bound(ob: &pyo3::Bound<'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
                    match pythonize::depythonize(ob) {
                        Ok(value) => Ok(value),
                        Err(err) => Err(pyo3::exceptions::PyValueError::new_err(format!(
                            "Failed to deserialize {} from Python dict: {}",
                            stringify!($t),
                            err,
                        ))),
                    }
                }
            }
        )+
    };
}
354
355impl_from_py_object!(
356 StepType,
357 StepSource,
358 StepStatus,
359 StepTarget,
360 ToolCallInfo,
361 ToolResult,
362 UsageMetadata,
363 MessageRole,
364 ConversationMessage,
365 Step
366);
367
#[cfg(test)]
mod tests {
    use super::*;

    // ---- StepType ----------------------------------------------------------

    #[test]
    fn test_step_type_roundtrip() {
        for (variant, expected_str) in [
            (StepType::TextResponse, "\"TEXT_RESPONSE\""),
            (StepType::ToolCall, "\"TOOL_CALL\""),
            (StepType::SystemMessage, "\"SYSTEM_MESSAGE\""),
            (StepType::Compaction, "\"COMPACTION\""),
            (StepType::Finish, "\"FINISH\""),
            (StepType::Unknown, "\"UNKNOWN\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(
                json, expected_str,
                "StepType serialization mismatch for {variant:?}"
            );
            let parsed: StepType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_step_type_parse() {
        assert_eq!(
            "TEXT_RESPONSE".parse::<StepType>().unwrap(),
            StepType::TextResponse
        );
        assert_eq!("TOOL_CALL".parse::<StepType>().unwrap(), StepType::ToolCall);
        assert_eq!(
            "SYSTEM_MESSAGE".parse::<StepType>().unwrap(),
            StepType::SystemMessage
        );
        assert_eq!(
            "COMPACTION".parse::<StepType>().unwrap(),
            StepType::Compaction
        );
        assert_eq!("FINISH".parse::<StepType>().unwrap(), StepType::Finish);
    }

    // ---- StepSource --------------------------------------------------------

    #[test]
    fn test_step_source_roundtrip() {
        for (variant, expected_str) in [
            (StepSource::System, "\"SYSTEM\""),
            (StepSource::User, "\"USER\""),
            (StepSource::Model, "\"MODEL\""),
            (StepSource::Unknown, "\"UNKNOWN\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, expected_str);
            let parsed: StepSource = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_step_source_parse() {
        assert_eq!("SYSTEM".parse::<StepSource>().unwrap(), StepSource::System);
        assert_eq!("USER".parse::<StepSource>().unwrap(), StepSource::User);
        assert_eq!("MODEL".parse::<StepSource>().unwrap(), StepSource::Model);
    }

    // ---- StepStatus --------------------------------------------------------

    #[test]
    fn test_step_status_roundtrip() {
        for (variant, expected_str) in [
            (StepStatus::Active, "\"ACTIVE\""),
            (StepStatus::Done, "\"DONE\""),
            (StepStatus::WaitingForUser, "\"WAITING_FOR_USER\""),
            (StepStatus::Error, "\"ERROR\""),
            (StepStatus::Canceled, "\"CANCELED\""),
            (StepStatus::Unknown, "\"UNKNOWN\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, expected_str);
            let parsed: StepStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_step_status_parse() {
        assert_eq!("ACTIVE".parse::<StepStatus>().unwrap(), StepStatus::Active);
        assert_eq!("DONE".parse::<StepStatus>().unwrap(), StepStatus::Done);
        assert_eq!(
            "WAITING_FOR_USER".parse::<StepStatus>().unwrap(),
            StepStatus::WaitingForUser
        );
        assert_eq!("ERROR".parse::<StepStatus>().unwrap(), StepStatus::Error);
        assert_eq!(
            "CANCELED".parse::<StepStatus>().unwrap(),
            StepStatus::Canceled
        );
    }

    #[test]
    fn test_step_type_parse_returns_err_for_unrecognized() {
        assert!("NONEXISTENT".parse::<StepType>().is_err());
    }

    #[test]
    fn test_step_source_parse_returns_err_for_unrecognized() {
        assert!("???".parse::<StepSource>().is_err());
    }

    #[test]
    fn test_step_status_parse_returns_err_for_unrecognized() {
        assert!("nope".parse::<StepStatus>().is_err());
    }

    // ---- ToolCallInfo / ToolResult / UsageMetadata -------------------------

    #[test]
    fn test_tool_call_info_roundtrip() {
        let tc = ToolCallInfo {
            name: "view_file".to_string(),
            args: serde_json::json!({"path": "/tmp/foo.rs", "line": 42}),
            id: Some("call_123".to_string()),
            canonical_path: Some("/tmp/foo.rs".to_string()),
        };
        let json = serde_json::to_string(&tc).unwrap();
        let parsed: ToolCallInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, tc);
    }

    #[test]
    fn test_tool_call_info_minimal() {
        let json = r#"{"name":"custom_tool"}"#;
        let parsed: ToolCallInfo = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.name, "custom_tool");
        assert_eq!(parsed.args, serde_json::Value::Null);
        assert!(parsed.id.is_none());
        assert!(parsed.canonical_path.is_none());
    }

    #[test]
    fn test_tool_result_roundtrip() {
        let tr = ToolResult {
            name: "run_command".to_string(),
            id: Some("result_456".to_string()),
            result: serde_json::json!({"output": "hello world"}),
            error: None,
        };
        let json = serde_json::to_string(&tr).unwrap();
        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, tr);
    }

    #[test]
    fn test_tool_result_with_error() {
        let tr = ToolResult {
            name: "create_file".to_string(),
            id: None,
            result: serde_json::Value::Null,
            error: Some("permission denied".to_string()),
        };
        let json = serde_json::to_string(&tr).unwrap();
        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.error.as_deref(), Some("permission denied"));
        // The whole record (including `id: None` and a null result) must survive.
        assert_eq!(parsed, tr);
    }

    #[test]
    fn test_usage_metadata_roundtrip() {
        let um = UsageMetadata {
            prompt_token_count: Some(100),
            cached_content_token_count: Some(20),
            candidates_token_count: Some(50),
            thoughts_token_count: Some(30),
            total_token_count: Some(180),
        };
        let json = serde_json::to_string(&um).unwrap();
        let parsed: UsageMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, um);
    }

    #[test]
    fn test_usage_metadata_defaults() {
        let um: UsageMetadata = serde_json::from_str("{}").unwrap();
        assert!(um.prompt_token_count.is_none());
        assert!(um.total_token_count.is_none());
    }

    // ---- Step --------------------------------------------------------------

    #[test]
    fn test_step_full_roundtrip() {
        let step = Step {
            id: "traj:0".to_string(),
            step_index: 3,
            step_type: StepType::ToolCall,
            source: StepSource::Model,
            target: StepTarget::Environment,
            status: StepStatus::Done,
            content: "Running command...".to_string(),
            content_delta: "Running".to_string(),
            thinking: "I should run the command".to_string(),
            thinking_delta: "I should".to_string(),
            tool_calls: vec![ToolCallInfo {
                name: "run_command".to_string(),
                args: serde_json::json!({"command": "ls -la"}),
                id: Some("call_1".to_string()),
                canonical_path: None,
            }],
            error: String::new(),
            is_complete_response: Some(false),
            structured_output: None,
            usage_metadata: Some(UsageMetadata {
                prompt_token_count: Some(500),
                cached_content_token_count: None,
                candidates_token_count: Some(100),
                thoughts_token_count: Some(50),
                total_token_count: Some(650),
            }),
        };

        let json = serde_json::to_string_pretty(&step).unwrap();
        let parsed: Step = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, step);
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].name, "run_command");
    }

    #[test]
    fn test_step_minimal_deserialization() {
        let json = r#"{"id":"s1"}"#;
        let step: Step = serde_json::from_str(json).unwrap();
        assert_eq!(step.id, "s1");
        assert_eq!(step.step_index, 0);
        assert_eq!(step.step_type, StepType::Unknown);
        assert_eq!(step.source, StepSource::Unknown);
        assert_eq!(step.target, StepTarget::Unknown);
        assert_eq!(step.status, StepStatus::Unknown);
        assert!(step.content.is_empty());
        assert!(step.content_delta.is_empty());
        assert!(step.thinking.is_empty());
        assert!(step.thinking_delta.is_empty());
        assert!(step.tool_calls.is_empty());
        assert!(step.error.is_empty());
        assert!(step.is_complete_response.is_none());
        assert!(step.structured_output.is_none());
        assert!(step.usage_metadata.is_none());
    }

    #[test]
    fn step_with_multiple_tool_calls() {
        let step = Step {
            id: "multi-tc".to_string(),
            step_index: 7,
            step_type: StepType::ToolCall,
            source: StepSource::Model,
            target: StepTarget::Environment,
            status: StepStatus::Done,
            content: String::new(),
            content_delta: String::new(),
            thinking: String::new(),
            thinking_delta: String::new(),
            tool_calls: vec![
                ToolCallInfo {
                    name: "view_file".to_string(),
                    args: serde_json::json!({"path": "/a.rs"}),
                    id: Some("tc1".to_string()),
                    canonical_path: Some("/a.rs".to_string()),
                },
                ToolCallInfo {
                    name: "run_command".to_string(),
                    args: serde_json::json!({"command": "cargo test"}),
                    id: Some("tc2".to_string()),
                    canonical_path: None,
                },
            ],
            error: String::new(),
            is_complete_response: None,
            structured_output: None,
            usage_metadata: None,
        };
        let json = serde_json::to_string(&step).unwrap();
        let parsed: Step = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.tool_calls.len(), 2);
        assert_eq!(parsed.tool_calls[0].name, "view_file");
        assert_eq!(parsed.tool_calls[1].name, "run_command");
        assert_eq!(
            parsed.tool_calls[0].canonical_path.as_deref(),
            Some("/a.rs")
        );
        assert!(parsed.tool_calls[1].canonical_path.is_none());
    }

    #[test]
    fn tool_call_info_with_complex_args() {
        let tc = ToolCallInfo {
            name: "run_command".to_string(),
            args: serde_json::json!({
                "command": "cargo test",
                "env": {"RUST_LOG": "debug"},
                "timeout": 300,
                "nested": [1, 2, {"deep": true}]
            }),
            id: None,
            canonical_path: None,
        };
        let json = serde_json::to_string(&tc).unwrap();
        let parsed: ToolCallInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.args["env"]["RUST_LOG"], "debug");
        assert_eq!(parsed.args["nested"][2]["deep"], true);
    }

    #[test]
    fn tool_result_with_complex_result() {
        let tr = ToolResult {
            name: "search_dir".to_string(),
            id: Some("r1".to_string()),
            result: serde_json::json!({
                "matches": [
                    {"file": "/src/main.rs", "line": 42},
                    {"file": "/src/lib.rs", "line": 10},
                ],
                "total": 2
            }),
            error: None,
        };
        let json = serde_json::to_string(&tr).unwrap();
        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.result["total"], 2);
        assert_eq!(parsed.result["matches"][0]["line"], 42);
    }

    #[test]
    fn usage_metadata_partial_fields() {
        let json = r#"{"prompt_token_count":100,"total_token_count":200}"#;
        let um: UsageMetadata = serde_json::from_str(json).unwrap();
        assert_eq!(um.prompt_token_count, Some(100));
        assert!(um.cached_content_token_count.is_none());
        assert!(um.candidates_token_count.is_none());
        assert!(um.thoughts_token_count.is_none());
        assert_eq!(um.total_token_count, Some(200));
    }

    // ---- StepTarget --------------------------------------------------------

    #[test]
    fn test_step_target_roundtrip() {
        for (variant, expected_str) in [
            (StepTarget::Model, "\"TARGET_MODEL\""),
            (StepTarget::User, "\"TARGET_USER\""),
            (StepTarget::Environment, "\"TARGET_ENVIRONMENT\""),
            (StepTarget::Unspecified, "\"TARGET_UNSPECIFIED\""),
            (StepTarget::Unknown, "\"UNKNOWN\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(
                json, expected_str,
                "StepTarget serialization mismatch for {variant:?}"
            );
            let parsed: StepTarget = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_step_target_parse() {
        assert_eq!(
            "TARGET_MODEL".parse::<StepTarget>().unwrap(),
            StepTarget::Model
        );
        assert_eq!(
            "TARGET_USER".parse::<StepTarget>().unwrap(),
            StepTarget::User
        );
        assert_eq!(
            "TARGET_ENVIRONMENT".parse::<StepTarget>().unwrap(),
            StepTarget::Environment
        );
        assert_eq!(
            "TARGET_UNSPECIFIED".parse::<StepTarget>().unwrap(),
            StepTarget::Unspecified
        );
        assert_eq!(
            "UNKNOWN".parse::<StepTarget>().unwrap(),
            StepTarget::Unknown
        );
    }

    #[test]
    fn test_step_target_parse_returns_err_for_unrecognized() {
        assert!("INVALID_TARGET".parse::<StepTarget>().is_err());
    }

    // ---- Display -----------------------------------------------------------

    #[test]
    fn test_step_type_display() {
        assert_eq!(StepType::TextResponse.to_string(), "TEXT_RESPONSE");
        assert_eq!(StepType::ToolCall.to_string(), "TOOL_CALL");
        assert_eq!(StepType::SystemMessage.to_string(), "SYSTEM_MESSAGE");
        assert_eq!(StepType::Compaction.to_string(), "COMPACTION");
        assert_eq!(StepType::Finish.to_string(), "FINISH");
        assert_eq!(StepType::Unknown.to_string(), "UNKNOWN");
    }

    #[test]
    fn test_step_source_display() {
        assert_eq!(StepSource::System.to_string(), "SYSTEM");
        assert_eq!(StepSource::User.to_string(), "USER");
        assert_eq!(StepSource::Model.to_string(), "MODEL");
        assert_eq!(StepSource::Unknown.to_string(), "UNKNOWN");
    }

    #[test]
    fn test_step_status_display() {
        assert_eq!(StepStatus::Active.to_string(), "ACTIVE");
        assert_eq!(StepStatus::Done.to_string(), "DONE");
        assert_eq!(StepStatus::WaitingForUser.to_string(), "WAITING_FOR_USER");
        assert_eq!(StepStatus::Error.to_string(), "ERROR");
        assert_eq!(StepStatus::Canceled.to_string(), "CANCELED");
        assert_eq!(StepStatus::Unknown.to_string(), "UNKNOWN");
    }

    #[test]
    fn test_step_target_display() {
        assert_eq!(StepTarget::Model.to_string(), "TARGET_MODEL");
        assert_eq!(StepTarget::User.to_string(), "TARGET_USER");
        assert_eq!(StepTarget::Environment.to_string(), "TARGET_ENVIRONMENT");
        assert_eq!(StepTarget::Unspecified.to_string(), "TARGET_UNSPECIFIED");
        assert_eq!(StepTarget::Unknown.to_string(), "UNKNOWN");
    }

    // ---- Display <-> FromStr <-> serde agreement ---------------------------

    #[test]
    fn test_step_type_display_from_str_roundtrip() {
        for variant in [
            StepType::TextResponse,
            StepType::ToolCall,
            StepType::SystemMessage,
            StepType::Compaction,
            StepType::Finish,
            StepType::Unknown,
        ] {
            let s = variant.to_string();
            let parsed: StepType = s.parse().unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
        }
    }

    #[test]
    fn test_step_source_display_from_str_roundtrip() {
        for variant in [
            StepSource::System,
            StepSource::User,
            StepSource::Model,
            StepSource::Unknown,
        ] {
            let s = variant.to_string();
            let parsed: StepSource = s.parse().unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
        }
    }

    #[test]
    fn test_step_status_display_from_str_roundtrip() {
        for variant in [
            StepStatus::Active,
            StepStatus::Done,
            StepStatus::WaitingForUser,
            StepStatus::Error,
            StepStatus::Canceled,
            StepStatus::Unknown,
        ] {
            let s = variant.to_string();
            let parsed: StepStatus = s.parse().unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
        }
    }

    #[test]
    fn test_step_target_display_from_str_roundtrip() {
        for variant in [
            StepTarget::Model,
            StepTarget::User,
            StepTarget::Environment,
            StepTarget::Unspecified,
            StepTarget::Unknown,
        ] {
            let s = variant.to_string();
            let parsed: StepTarget = s.parse().unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
        }
    }

    /// serde must use exactly the same wire string as `Display`/`FromStr` for
    /// every variant of every SDK enum.
    #[test]
    fn test_serde_wire_matches_display() {
        fn check<T>(variants: &[T])
        where
            T: Copy
                + fmt::Debug
                + fmt::Display
                + PartialEq
                + FromStr
                + Serialize
                + for<'de> Deserialize<'de>,
            <T as FromStr>::Err: fmt::Debug,
        {
            for &variant in variants {
                let wire = variant.to_string();
                let json = serde_json::to_string(&variant).unwrap();
                assert_eq!(
                    json,
                    format!("\"{wire}\""),
                    "serde/Display mismatch for {variant:?}"
                );
                let from_json: T = serde_json::from_str(&json).unwrap();
                assert_eq!(from_json, variant);
                assert_eq!(wire.parse::<T>().unwrap(), variant);
            }
        }

        check(&[
            StepType::TextResponse,
            StepType::ToolCall,
            StepType::SystemMessage,
            StepType::Compaction,
            StepType::Finish,
            StepType::Unknown,
        ]);
        check(&[
            StepSource::System,
            StepSource::User,
            StepSource::Model,
            StepSource::Unknown,
        ]);
        check(&[
            StepStatus::Active,
            StepStatus::Done,
            StepStatus::WaitingForUser,
            StepStatus::Error,
            StepStatus::Canceled,
            StepStatus::Unknown,
        ]);
        check(&[
            StepTarget::Model,
            StepTarget::User,
            StepTarget::Environment,
            StepTarget::Unspecified,
            StepTarget::Unknown,
        ]);
    }

    #[test]
    fn test_serde_rejects_unrecognized_wire_strings() {
        assert!(serde_json::from_str::<StepType>("\"NOPE\"").is_err());
        assert!(serde_json::from_str::<StepSource>("\"NOPE\"").is_err());
        assert!(serde_json::from_str::<StepStatus>("\"NOPE\"").is_err());
        // StepTarget wire names carry a `TARGET_` prefix; the bare name is invalid.
        assert!(serde_json::from_str::<StepTarget>("\"USER\"").is_err());
    }

    #[test]
    fn test_from_str_garbage_returns_err() {
        assert!("xyzzy".parse::<StepType>().is_err());
        assert!("xyzzy".parse::<StepSource>().is_err());
        assert!("xyzzy".parse::<StepStatus>().is_err());
        assert!("xyzzy".parse::<StepTarget>().is_err());
    }

    #[test]
    fn test_from_str_empty_returns_err() {
        assert!("".parse::<StepType>().is_err());
        assert!("".parse::<StepSource>().is_err());
        assert!("".parse::<StepStatus>().is_err());
        assert!("".parse::<StepTarget>().is_err());
    }

    #[test]
    fn test_from_str_case_sensitive() {
        assert!("text_response".parse::<StepType>().is_err());
        assert!("system".parse::<StepSource>().is_err());
        assert!("active".parse::<StepStatus>().is_err());
        assert!("target_user".parse::<StepTarget>().is_err());
    }

    // ---- MessageRole / ConversationMessage ---------------------------------

    #[test]
    fn test_message_role_roundtrip() {
        for (variant, expected_str) in [
            (MessageRole::User, "\"user\""),
            (MessageRole::Model, "\"model\""),
            (MessageRole::System, "\"system\""),
            (MessageRole::Unknown("custom".to_string()), "\"custom\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, expected_str);
            let parsed: MessageRole = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_message_role_display() {
        assert_eq!(MessageRole::User.to_string(), "user");
        assert_eq!(MessageRole::Model.to_string(), "model");
        assert_eq!(MessageRole::System.to_string(), "system");
        assert_eq!(
            MessageRole::Unknown("custom".to_string()).to_string(),
            "custom"
        );
    }

    #[test]
    fn test_conversation_message_roundtrip() {
        let msg = ConversationMessage {
            role: MessageRole::Model,
            content: "Hello!".to_string(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: ConversationMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, msg);
    }

    #[test]
    fn test_conversation_message_defaults() {
        let msg: ConversationMessage = serde_json::from_str("{}").unwrap();
        assert_eq!(msg.role, MessageRole::User);
        assert!(msg.content.is_empty());
    }

    // ---- pyo3 extraction ---------------------------------------------------

    #[test]
    fn test_pyo3_extract_roundtrip() {
        use pyo3::{prelude::*, types::PyDictMethods};
        pyo3::prepare_freethreaded_python();
        pyo3::Python::with_gil(|py| {
            let dict = pyo3::types::PyDict::new_bound(py);
            dict.set_item("id", "step-1").unwrap();
            dict.set_item("step_index", 42).unwrap();
            dict.set_item("type", "TEXT_RESPONSE").unwrap();
            dict.set_item("source", "MODEL").unwrap();
            dict.set_item("target", "TARGET_USER").unwrap();
            dict.set_item("status", "DONE").unwrap();

            let step: Step = dict.extract().expect("failed to extract Step");
            assert_eq!(step.id, "step-1");
            assert_eq!(step.step_index, 42);
            assert_eq!(step.step_type, StepType::TextResponse);
            assert_eq!(step.source, StepSource::Model);
            assert_eq!(step.target, StepTarget::User);
            assert_eq!(step.status, StepStatus::Done);

            // Enums are extracted directly from Python strings.
            let s = pyo3::types::PyString::new_bound(py, "SYSTEM_MESSAGE");
            let st: StepType = s.extract().unwrap();
            assert_eq!(st, StepType::SystemMessage);
        });
    }

    // ---- Step field coverage -----------------------------------------------

    #[test]
    fn step_with_deltas_and_thinking() {
        let step = Step {
            id: "s2".to_string(),
            step_index: 1,
            step_type: StepType::TextResponse,
            source: StepSource::Model,
            target: StepTarget::User,
            status: StepStatus::Active,
            content: "Hello world".to_string(),
            content_delta: "world".to_string(),
            thinking: "The user said hi".to_string(),
            thinking_delta: "said hi".to_string(),
            tool_calls: vec![],
            error: String::new(),
            is_complete_response: Some(true),
            structured_output: None,
            usage_metadata: None,
        };
        let json = serde_json::to_string(&step).unwrap();
        let parsed: Step = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.content_delta, "world");
        assert_eq!(parsed.thinking, "The user said hi");
        assert_eq!(parsed.thinking_delta, "said hi");
        assert_eq!(parsed.is_complete_response, Some(true));
        assert_eq!(parsed.target, StepTarget::User);
    }

    #[test]
    fn step_with_structured_output() {
        let payload = serde_json::json!({"answer": 42, "valid": true});
        let step = Step {
            id: "finish-1".to_string(),
            step_index: 5,
            step_type: StepType::Finish,
            source: StepSource::Model,
            target: StepTarget::User,
            status: StepStatus::Done,
            content: String::new(),
            content_delta: String::new(),
            thinking: String::new(),
            thinking_delta: String::new(),
            tool_calls: vec![],
            error: String::new(),
            is_complete_response: Some(true),
            structured_output: Some(payload.clone()),
            usage_metadata: None,
        };
        let json = serde_json::to_string(&step).unwrap();
        let parsed: Step = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.structured_output, Some(payload));
        assert_eq!(parsed.step_type, StepType::Finish);
    }
}