1use std::{fmt, str::FromStr};
11
12use serde::{Deserialize, Serialize};
13
/// Defines a fieldless, string-backed SDK enum from `Variant => "WIRE"` pairs.
///
/// For every pair the `"WIRE"` literal is the single source of truth for:
/// - the serde name of the variant (serialization and deserialization),
/// - the [`fmt::Display`] output,
/// - the accepted [`FromStr`] input (exact, case-sensitive match).
///
/// Generated items:
/// - `pub enum $name` deriving `Debug, Clone, Copy, PartialEq, Eq, Serialize,
///   Deserialize, Default` (so exactly one variant must be marked `#[default]`);
/// - `impl fmt::Display`, writing the wire literal and honouring width / fill /
///   alignment / precision flags like `str` does;
/// - `impl FromStr` with `Err = String`, returning
///   `Err("Unrecognized <EnumName>: \"<input>\"")` for any unlisted string.
macro_rules! define_sdk_enum {
    (
        $(#[$meta:meta])*
        $name:ident {
            $(
                $(#[$vmeta:meta])*
                $variant:ident => $wire:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
        pub enum $name {
            $(
                $(#[$vmeta])*
                // Take the serde name from the explicit wire literal instead of a
                // `rename_all` rule applied to the variant identifier, so serde can
                // never disagree with `Display`/`FromStr` below.
                #[serde(rename = $wire)]
                $variant,
            )+
        }

        impl fmt::Display for $name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let s = match self {
                    $( Self::$variant => $wire, )+
                };
                // `pad` (not `write_str`) so `{:<12}`-style format specs apply.
                f.pad(s)
            }
        }

        impl FromStr for $name {
            type Err = String;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $( $wire => Ok(Self::$variant), )+
                    other => Err(format!(concat!("Unrecognized ", stringify!($name), ": {:?}"), other)),
                }
            }
        }
    };
}
78
/// Defines a fieldless, string-backed SDK enum whose wire strings are not
/// derivable from the variant names (e.g. `Model => "TARGET_MODEL"`).
///
/// Each variant gets `#[serde(rename = $wire)]`, so the `"WIRE"` literal is the
/// serde name, the [`fmt::Display`] output and the only accepted [`FromStr`]
/// input (exact, case-sensitive match).
///
/// Generated items:
/// - `pub enum $name` deriving `Debug, Clone, Copy, PartialEq, Eq, Serialize,
///   Deserialize, Default` (so exactly one variant must be marked `#[default]`);
/// - `impl fmt::Display`, writing the wire literal and honouring width / fill /
///   alignment / precision flags like `str` does;
/// - `impl FromStr` with `Err = String`, returning
///   `Err("Unrecognized <EnumName>: \"<input>\"")` for any unlisted string.
macro_rules! define_sdk_enum_custom_serde {
    (
        $(#[$meta:meta])*
        $name:ident {
            $(
                $(#[$vmeta:meta])*
                $variant:ident => $wire:literal
            ),+ $(,)?
        }
    ) => {
        $(#[$meta])*
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
        pub enum $name {
            $(
                $(#[$vmeta])*
                #[serde(rename = $wire)]
                $variant,
            )+
        }

        impl fmt::Display for $name {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                let s = match self {
                    $( Self::$variant => $wire, )+
                };
                // `pad` (not `write_str`) so `{:<12}`-style format specs apply.
                f.pad(s)
            }
        }

        impl FromStr for $name {
            type Err = String;

            fn from_str(s: &str) -> Result<Self, Self::Err> {
                match s {
                    $( $wire => Ok(Self::$variant), )+
                    other => Err(format!(concat!("Unrecognized ", stringify!($name), ": {:?}"), other)),
                }
            }
        }
    };
}
125
define_sdk_enum! {
    /// Kind of a [`Step`]; carried on the wire as the step's `"type"` field.
    ///
    /// Serialized, displayed and parsed as the SCREAMING_SNAKE_CASE string
    /// listed next to each variant. Parsing an unlisted string (via `FromStr`
    /// or serde) is an error; it does not fall back to `Unknown`.
    StepType {
        /// Wire value `"TEXT_RESPONSE"`.
        TextResponse => "TEXT_RESPONSE",
        /// Wire value `"TOOL_CALL"`; presumably the step carries entries in
        /// [`Step::tool_calls`] — TODO confirm.
        ToolCall => "TOOL_CALL",
        /// Wire value `"SYSTEM_MESSAGE"`.
        SystemMessage => "SYSTEM_MESSAGE",
        /// Wire value `"COMPACTION"`.
        Compaction => "COMPACTION",
        /// Wire value `"FINISH"`.
        Finish => "FINISH",
        /// Wire value `"UNKNOWN"`. This is the `Default`, so a [`Step`] whose
        /// `"type"` key is missing deserializes with this value.
        #[default]
        Unknown => "UNKNOWN",
    }
}
144
define_sdk_enum! {
    /// Value of a [`Step`]'s `source` field.
    ///
    /// Serialized, displayed and parsed as the upper-case string listed next
    /// to each variant; an unlisted string fails to parse.
    StepSource {
        /// Wire value `"SYSTEM"`.
        System => "SYSTEM",
        /// Wire value `"USER"`.
        User => "USER",
        /// Wire value `"MODEL"`.
        Model => "MODEL",
        /// Wire value `"UNKNOWN"`. This is the `Default`, used when a
        /// [`Step`] omits `source`.
        #[default]
        Unknown => "UNKNOWN",
    }
}
159
define_sdk_enum! {
    /// Value of a [`Step`]'s `status` field.
    ///
    /// Serialized, displayed and parsed as the SCREAMING_SNAKE_CASE string
    /// listed next to each variant; an unlisted string fails to parse.
    StepStatus {
        /// Wire value `"ACTIVE"`.
        Active => "ACTIVE",
        /// Wire value `"DONE"`.
        Done => "DONE",
        /// Wire value `"WAITING_FOR_USER"`.
        WaitingForUser => "WAITING_FOR_USER",
        /// Wire value `"ERROR"`. Presumably accompanied by a non-empty
        /// [`Step::error`] — TODO confirm.
        Error => "ERROR",
        /// Wire value `"CANCELED"` (single "L" spelling).
        Canceled => "CANCELED",
        /// Wire value `"UNKNOWN"`. This is the `Default`, used when a
        /// [`Step`] omits `status`.
        #[default]
        Unknown => "UNKNOWN",
    }
}
178
define_sdk_enum_custom_serde! {
    /// Value of a [`Step`]'s `target` field.
    ///
    /// The wire strings carry a `TARGET_` prefix (all except `"UNKNOWN"`), so
    /// they are not derivable from the variant names; each variant's serde,
    /// `Display` and `FromStr` form is the explicit literal listed next to it.
    StepTarget {
        /// Wire value `"TARGET_MODEL"`.
        Model => "TARGET_MODEL",
        /// Wire value `"TARGET_USER"`.
        User => "TARGET_USER",
        /// Wire value `"TARGET_ENVIRONMENT"`.
        Environment => "TARGET_ENVIRONMENT",
        /// Wire value `"TARGET_UNSPECIFIED"`: an explicit value sent on the
        /// wire, distinct from the `Unknown` default below.
        Unspecified => "TARGET_UNSPECIFIED",
        /// Wire value `"UNKNOWN"` (no `TARGET_` prefix). This is the `Default`,
        /// used when a [`Step`] omits `target`.
        #[default]
        Unknown => "UNKNOWN",
    }
}
199
/// A single tool call attached to a [`Step`] (see [`Step::tool_calls`]).
///
/// Only `name` is required when deserializing; every other field defaults.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolCallInfo {
    /// Name of the tool being called (required).
    pub name: String,
    /// Tool arguments as arbitrary JSON; `serde_json::Value::Null` when absent.
    #[serde(default)]
    pub args: serde_json::Value,
    /// Optional call identifier. Presumably correlates with
    /// [`ToolResult::id`] — TODO confirm.
    #[serde(default)]
    pub id: Option<String>,
    /// Optional path associated with the call (the tests use the file a
    /// `view_file` call targets); `None` when absent.
    #[serde(default)]
    pub canonical_path: Option<String>,
}
215
/// Outcome of a tool call: a JSON `result` and/or an `error` message.
///
/// Only `name` is required when deserializing; every other field defaults.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolResult {
    /// Name of the tool that produced this result (required).
    pub name: String,
    /// Optional identifier; presumably matches [`ToolCallInfo::id`] of the
    /// originating call — TODO confirm.
    #[serde(default)]
    pub id: Option<String>,
    /// Tool output as arbitrary JSON; `serde_json::Value::Null` when absent.
    #[serde(default)]
    pub result: serde_json::Value,
    /// Error message, if the tool failed; `None` when absent.
    #[serde(default)]
    pub error: Option<String>,
}
231
/// Token-usage counters attached to a [`Step`].
///
/// Every counter is optional and deserializes to `None` when its key is
/// absent, so `{}` is a valid (all-`None`) value.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct UsageMetadata {
    /// Tokens in the prompt.
    #[serde(default)]
    pub prompt_token_count: Option<u64>,
    /// Tokens served from cached content.
    #[serde(default)]
    pub cached_content_token_count: Option<u64>,
    /// Tokens in the generated candidates.
    #[serde(default)]
    pub candidates_token_count: Option<u64>,
    /// Tokens spent on thinking.
    #[serde(default)]
    pub thoughts_token_count: Option<u64>,
    /// Total token count as reported by the producer (not computed here).
    #[serde(default)]
    pub total_token_count: Option<u64>,
}
251
252#[non_exhaustive]
254#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
255#[serde(rename_all = "lowercase")]
256#[derive(Default)]
257pub enum MessageRole {
258 #[default]
260 User,
261 Model,
263 System,
265 #[serde(untagged)]
268 Unknown(String),
269}
270
271impl std::fmt::Display for MessageRole {
272 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
273 match self {
274 Self::User => f.write_str("user"),
275 Self::Model => f.write_str("model"),
276 Self::System => f.write_str("system"),
277 Self::Unknown(s) => f.write_str(s),
278 }
279 }
280}
281
/// A single role-tagged text message.
///
/// Both fields default when their keys are absent: `role` becomes
/// `MessageRole::User` (its `#[default]` variant) and `content` becomes "".
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct ConversationMessage {
    /// Who authored the message.
    #[serde(default)]
    pub role: MessageRole,
    /// Message text.
    #[serde(default)]
    pub content: String,
}
295
/// A single step record as exchanged over the wire.
///
/// Every field is `#[serde(default)]`, so any subset of keys (even `{}`)
/// deserializes: strings become empty, `tool_calls` becomes empty, optional
/// fields become `None`, and the enum fields become their `Unknown` variant.
///
/// `#[non_exhaustive]`: outside this crate a `Step` cannot be built with a
/// struct literal; start from `Step::default()` and assign the pub fields.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)]
pub struct Step {
    /// Opaque step identifier.
    #[serde(default)]
    pub id: String,
    /// Position of the step; `0` when absent.
    #[serde(default)]
    pub step_index: u32,
    /// Kind of step. Serialized under the JSON key `"type"`, not `"step_type"`.
    #[serde(default, rename = "type")]
    pub step_type: StepType,
    /// Value of the `source` key.
    #[serde(default)]
    pub source: StepSource,
    /// Value of the `target` key.
    #[serde(default)]
    pub target: StepTarget,
    /// Value of the `status` key.
    #[serde(default)]
    pub status: StepStatus,
    /// Text content of the step; empty when absent.
    #[serde(default)]
    pub content: String,
    /// Presumably the newly appended portion of `content` since the previous
    /// update of this step (streaming) — TODO confirm.
    #[serde(default)]
    pub content_delta: String,
    /// Thinking text; empty when absent.
    #[serde(default)]
    pub thinking: String,
    /// Presumably the newly appended portion of `thinking` since the previous
    /// update of this step — TODO confirm.
    #[serde(default)]
    pub thinking_delta: String,
    /// Tool calls attached to this step; empty when absent.
    #[serde(default)]
    pub tool_calls: Vec<ToolCallInfo>,
    /// Error message; an empty string (not `None`) means "no error".
    #[serde(default)]
    pub error: String,
    /// Optional flag; `None` when the key is absent. Exact semantics are not
    /// established in this file — TODO document.
    #[serde(default)]
    pub is_complete_response: Option<bool>,
    /// Optional arbitrary JSON payload; `None` when absent.
    #[serde(default)]
    pub structured_output: Option<serde_json::Value>,
    /// Optional token-usage counters for this step.
    #[serde(default)]
    pub usage_metadata: Option<UsageMetadata>,
}
352
/// Implements `pyo3::FromPyObject` for each listed type by converting the
/// Python object with `pythonize::depythonize`, i.e. through the type's serde
/// `Deserialize` impl.
///
/// A conversion failure is raised to Python as `ValueError`, naming the target
/// Rust type and the underlying pythonize error. The input is not always a
/// dict (the string-backed enums are extracted from Python `str` objects), so
/// the message says "Python object".
///
/// Accepts a comma-separated list of types with an optional trailing comma.
macro_rules! impl_from_py_object {
    ($($t:ty),+ $(,)?) => {
        $(
            impl<'a, 'py> pyo3::FromPyObject<'a, 'py> for $t {
                type Error = pyo3::PyErr;

                fn extract(ob: pyo3::Borrowed<'a, 'py, pyo3::PyAny>) -> pyo3::PyResult<Self> {
                    pythonize::depythonize(&*ob).map_err(|e| {
                        pyo3::exceptions::PyValueError::new_err(format!(
                            "Failed to deserialize {} from Python object: {}",
                            stringify!($t),
                            e
                        ))
                    })
                }
            }
        )+
    };
}
372
373impl_from_py_object!(
374 StepType,
375 StepSource,
376 StepStatus,
377 StepTarget,
378 ToolCallInfo,
379 ToolResult,
380 UsageMetadata,
381 MessageRole,
382 ConversationMessage,
383 Step
384);
385
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `StepType` variant serializes to its wire string and back.
    #[test]
    fn test_step_type_roundtrip() {
        for (variant, expected_str) in [
            (StepType::TextResponse, "\"TEXT_RESPONSE\""),
            (StepType::ToolCall, "\"TOOL_CALL\""),
            (StepType::SystemMessage, "\"SYSTEM_MESSAGE\""),
            (StepType::Compaction, "\"COMPACTION\""),
            (StepType::Finish, "\"FINISH\""),
            (StepType::Unknown, "\"UNKNOWN\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(
                json, expected_str,
                "StepType serialization mismatch for {variant:?}"
            );
            let parsed: StepType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_step_type_parse() {
        assert_eq!(
            "TEXT_RESPONSE".parse::<StepType>().unwrap(),
            StepType::TextResponse
        );
        assert_eq!("TOOL_CALL".parse::<StepType>().unwrap(), StepType::ToolCall);
        assert_eq!(
            "SYSTEM_MESSAGE".parse::<StepType>().unwrap(),
            StepType::SystemMessage
        );
        assert_eq!(
            "COMPACTION".parse::<StepType>().unwrap(),
            StepType::Compaction
        );
        assert_eq!("FINISH".parse::<StepType>().unwrap(), StepType::Finish);
    }

    #[test]
    fn test_step_source_roundtrip() {
        for (variant, expected_str) in [
            (StepSource::System, "\"SYSTEM\""),
            (StepSource::User, "\"USER\""),
            (StepSource::Model, "\"MODEL\""),
            (StepSource::Unknown, "\"UNKNOWN\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, expected_str);
            let parsed: StepSource = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_step_source_parse() {
        assert_eq!("SYSTEM".parse::<StepSource>().unwrap(), StepSource::System);
        assert_eq!("USER".parse::<StepSource>().unwrap(), StepSource::User);
        assert_eq!("MODEL".parse::<StepSource>().unwrap(), StepSource::Model);
    }

    #[test]
    fn test_step_status_roundtrip() {
        for (variant, expected_str) in [
            (StepStatus::Active, "\"ACTIVE\""),
            (StepStatus::Done, "\"DONE\""),
            (StepStatus::WaitingForUser, "\"WAITING_FOR_USER\""),
            (StepStatus::Error, "\"ERROR\""),
            (StepStatus::Canceled, "\"CANCELED\""),
            (StepStatus::Unknown, "\"UNKNOWN\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, expected_str);
            let parsed: StepStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_step_status_parse() {
        assert_eq!("ACTIVE".parse::<StepStatus>().unwrap(), StepStatus::Active);
        assert_eq!("DONE".parse::<StepStatus>().unwrap(), StepStatus::Done);
        assert_eq!(
            "WAITING_FOR_USER".parse::<StepStatus>().unwrap(),
            StepStatus::WaitingForUser
        );
        assert_eq!("ERROR".parse::<StepStatus>().unwrap(), StepStatus::Error);
        assert_eq!(
            "CANCELED".parse::<StepStatus>().unwrap(),
            StepStatus::Canceled
        );
    }

    #[test]
    fn test_step_type_parse_returns_err_for_unrecognized() {
        assert!("NONEXISTENT".parse::<StepType>().is_err());
    }

    #[test]
    fn test_step_source_parse_returns_err_for_unrecognized() {
        assert!("???".parse::<StepSource>().is_err());
    }

    #[test]
    fn test_step_status_parse_returns_err_for_unrecognized() {
        assert!("nope".parse::<StepStatus>().is_err());
    }

    #[test]
    fn test_tool_call_info_roundtrip() {
        let tc = ToolCallInfo {
            name: "view_file".to_string(),
            args: serde_json::json!({"path": "/tmp/foo.rs", "line": 42}),
            id: Some("call_123".to_string()),
            canonical_path: Some("/tmp/foo.rs".to_string()),
        };
        let json = serde_json::to_string(&tc).unwrap();
        let parsed: ToolCallInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, tc);
    }

    /// Only `name` is required; the rest default.
    #[test]
    fn test_tool_call_info_minimal() {
        let json = r#"{"name":"custom_tool"}"#;
        let parsed: ToolCallInfo = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.name, "custom_tool");
        assert_eq!(parsed.args, serde_json::Value::Null);
        assert!(parsed.id.is_none());
        assert!(parsed.canonical_path.is_none());
    }

    #[test]
    fn test_tool_result_roundtrip() {
        let tr = ToolResult {
            name: "run_command".to_string(),
            id: Some("result_456".to_string()),
            result: serde_json::json!({"output": "hello world"}),
            error: None,
        };
        let json = serde_json::to_string(&tr).unwrap();
        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, tr);
    }

    #[test]
    fn test_tool_result_with_error() {
        let tr = ToolResult {
            name: "create_file".to_string(),
            id: None,
            result: serde_json::Value::Null,
            error: Some("permission denied".to_string()),
        };
        let json = serde_json::to_string(&tr).unwrap();
        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.error.as_deref(), Some("permission denied"));
    }

    #[test]
    fn test_usage_metadata_roundtrip() {
        let um = UsageMetadata {
            prompt_token_count: Some(100),
            cached_content_token_count: Some(20),
            candidates_token_count: Some(50),
            thoughts_token_count: Some(30),
            total_token_count: Some(180),
        };
        let json = serde_json::to_string(&um).unwrap();
        let parsed: UsageMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, um);
    }

    #[test]
    fn test_usage_metadata_defaults() {
        let um: UsageMetadata = serde_json::from_str("{}").unwrap();
        assert!(um.prompt_token_count.is_none());
        assert!(um.total_token_count.is_none());
    }

    #[test]
    fn test_step_full_roundtrip() {
        let step = Step {
            id: "traj:0".to_string(),
            step_index: 3,
            step_type: StepType::ToolCall,
            source: StepSource::Model,
            target: StepTarget::Environment,
            status: StepStatus::Done,
            content: "Running command...".to_string(),
            content_delta: "Running".to_string(),
            thinking: "I should run the command".to_string(),
            thinking_delta: "I should".to_string(),
            tool_calls: vec![ToolCallInfo {
                name: "run_command".to_string(),
                args: serde_json::json!({"command": "ls -la"}),
                id: Some("call_1".to_string()),
                canonical_path: None,
            }],
            error: String::new(),
            is_complete_response: Some(false),
            structured_output: None,
            usage_metadata: Some(UsageMetadata {
                prompt_token_count: Some(500),
                cached_content_token_count: None,
                candidates_token_count: Some(100),
                thoughts_token_count: Some(50),
                total_token_count: Some(650),
            }),
        };

        let json = serde_json::to_string_pretty(&step).unwrap();
        let parsed: Step = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, step);
        assert_eq!(parsed.tool_calls.len(), 1);
        assert_eq!(parsed.tool_calls[0].name, "run_command");
    }

    /// A step with only an `id` deserializes with every other field defaulted.
    #[test]
    fn test_step_minimal_deserialization() {
        let json = r#"{"id":"s1"}"#;
        let step: Step = serde_json::from_str(json).unwrap();
        assert_eq!(step.id, "s1");
        assert_eq!(step.step_index, 0);
        assert_eq!(step.step_type, StepType::Unknown);
        assert_eq!(step.source, StepSource::Unknown);
        assert_eq!(step.target, StepTarget::Unknown);
        assert_eq!(step.status, StepStatus::Unknown);
        assert!(step.content.is_empty());
        assert!(step.content_delta.is_empty());
        assert!(step.thinking.is_empty());
        assert!(step.thinking_delta.is_empty());
        assert!(step.tool_calls.is_empty());
        assert!(step.error.is_empty());
        assert!(step.is_complete_response.is_none());
        assert!(step.structured_output.is_none());
        assert!(step.usage_metadata.is_none());
    }

    /// Multiple tool calls survive a roundtrip in order.
    #[test]
    fn step_with_multiple_tool_calls() {
        let step = Step {
            id: "multi-tc".to_string(),
            step_index: 7,
            step_type: StepType::ToolCall,
            source: StepSource::Model,
            target: StepTarget::Environment,
            status: StepStatus::Done,
            content: String::new(),
            content_delta: String::new(),
            thinking: String::new(),
            thinking_delta: String::new(),
            tool_calls: vec![
                ToolCallInfo {
                    name: "view_file".to_string(),
                    args: serde_json::json!({"path": "/a.rs"}),
                    id: Some("tc1".to_string()),
                    canonical_path: Some("/a.rs".to_string()),
                },
                ToolCallInfo {
                    name: "run_command".to_string(),
                    args: serde_json::json!({"command": "cargo test"}),
                    id: Some("tc2".to_string()),
                    canonical_path: None,
                },
            ],
            error: String::new(),
            is_complete_response: None,
            structured_output: None,
            usage_metadata: None,
        };
        let json = serde_json::to_string(&step).unwrap();
        let parsed: Step = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.tool_calls.len(), 2);
        assert_eq!(parsed.tool_calls[0].name, "view_file");
        assert_eq!(parsed.tool_calls[1].name, "run_command");
        assert_eq!(
            parsed.tool_calls[0].canonical_path.as_deref(),
            Some("/a.rs")
        );
        assert!(parsed.tool_calls[1].canonical_path.is_none());
    }

    /// Nested JSON arguments are preserved.
    #[test]
    fn tool_call_info_with_complex_args() {
        let tc = ToolCallInfo {
            name: "run_command".to_string(),
            args: serde_json::json!({
                "command": "cargo test",
                "env": {"RUST_LOG": "debug"},
                "timeout": 300,
                "nested": [1, 2, {"deep": true}]
            }),
            id: None,
            canonical_path: None,
        };
        let json = serde_json::to_string(&tc).unwrap();
        let parsed: ToolCallInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.args["env"]["RUST_LOG"], "debug");
        assert_eq!(parsed.args["nested"][2]["deep"], true);
    }

    #[test]
    fn tool_result_with_complex_result() {
        let tr = ToolResult {
            name: "search_dir".to_string(),
            id: Some("r1".to_string()),
            result: serde_json::json!({
                "matches": [
                    {"file": "/src/main.rs", "line": 42},
                    {"file": "/src/lib.rs", "line": 10},
                ],
                "total": 2
            }),
            error: None,
        };
        let json = serde_json::to_string(&tr).unwrap();
        let parsed: ToolResult = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.result["total"], 2);
        assert_eq!(parsed.result["matches"][0]["line"], 42);
    }

    /// Missing counters deserialize as `None`.
    #[test]
    fn usage_metadata_partial_fields() {
        let json = r#"{"prompt_token_count":100,"total_token_count":200}"#;
        let um: UsageMetadata = serde_json::from_str(json).unwrap();
        assert_eq!(um.prompt_token_count, Some(100));
        assert!(um.cached_content_token_count.is_none());
        assert!(um.candidates_token_count.is_none());
        assert!(um.thoughts_token_count.is_none());
        assert_eq!(um.total_token_count, Some(200));
    }

    /// Every `StepTarget` variant (including `Model`) serializes to its
    /// prefixed wire string and back.
    #[test]
    fn test_step_target_roundtrip() {
        for (variant, expected_str) in [
            (StepTarget::Model, "\"TARGET_MODEL\""),
            (StepTarget::User, "\"TARGET_USER\""),
            (StepTarget::Environment, "\"TARGET_ENVIRONMENT\""),
            (StepTarget::Unspecified, "\"TARGET_UNSPECIFIED\""),
            (StepTarget::Unknown, "\"UNKNOWN\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(
                json, expected_str,
                "StepTarget serialization mismatch for {variant:?}"
            );
            let parsed: StepTarget = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_step_target_parse() {
        assert_eq!(
            "TARGET_MODEL".parse::<StepTarget>().unwrap(),
            StepTarget::Model
        );
        assert_eq!(
            "TARGET_USER".parse::<StepTarget>().unwrap(),
            StepTarget::User
        );
        assert_eq!(
            "TARGET_ENVIRONMENT".parse::<StepTarget>().unwrap(),
            StepTarget::Environment
        );
        assert_eq!(
            "TARGET_UNSPECIFIED".parse::<StepTarget>().unwrap(),
            StepTarget::Unspecified
        );
        assert_eq!(
            "UNKNOWN".parse::<StepTarget>().unwrap(),
            StepTarget::Unknown
        );
    }

    #[test]
    fn test_step_target_parse_returns_err_for_unrecognized() {
        assert!("INVALID_TARGET".parse::<StepTarget>().is_err());
    }

    #[test]
    fn test_step_type_display() {
        assert_eq!(StepType::TextResponse.to_string(), "TEXT_RESPONSE");
        assert_eq!(StepType::ToolCall.to_string(), "TOOL_CALL");
        assert_eq!(StepType::SystemMessage.to_string(), "SYSTEM_MESSAGE");
        assert_eq!(StepType::Compaction.to_string(), "COMPACTION");
        assert_eq!(StepType::Finish.to_string(), "FINISH");
        assert_eq!(StepType::Unknown.to_string(), "UNKNOWN");
    }

    #[test]
    fn test_step_source_display() {
        assert_eq!(StepSource::System.to_string(), "SYSTEM");
        assert_eq!(StepSource::User.to_string(), "USER");
        assert_eq!(StepSource::Model.to_string(), "MODEL");
        assert_eq!(StepSource::Unknown.to_string(), "UNKNOWN");
    }

    #[test]
    fn test_step_status_display() {
        assert_eq!(StepStatus::Active.to_string(), "ACTIVE");
        assert_eq!(StepStatus::Done.to_string(), "DONE");
        assert_eq!(StepStatus::WaitingForUser.to_string(), "WAITING_FOR_USER");
        assert_eq!(StepStatus::Error.to_string(), "ERROR");
        assert_eq!(StepStatus::Canceled.to_string(), "CANCELED");
        assert_eq!(StepStatus::Unknown.to_string(), "UNKNOWN");
    }

    #[test]
    fn test_step_target_display() {
        assert_eq!(StepTarget::Model.to_string(), "TARGET_MODEL");
        assert_eq!(StepTarget::User.to_string(), "TARGET_USER");
        assert_eq!(StepTarget::Environment.to_string(), "TARGET_ENVIRONMENT");
        assert_eq!(StepTarget::Unspecified.to_string(), "TARGET_UNSPECIFIED");
        assert_eq!(StepTarget::Unknown.to_string(), "UNKNOWN");
    }

    /// `Display` output always parses back via `FromStr`.
    #[test]
    fn test_step_type_display_from_str_roundtrip() {
        for variant in [
            StepType::TextResponse,
            StepType::ToolCall,
            StepType::SystemMessage,
            StepType::Compaction,
            StepType::Finish,
            StepType::Unknown,
        ] {
            let s = variant.to_string();
            let parsed: StepType = s.parse().unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
        }
    }

    #[test]
    fn test_step_source_display_from_str_roundtrip() {
        for variant in [
            StepSource::System,
            StepSource::User,
            StepSource::Model,
            StepSource::Unknown,
        ] {
            let s = variant.to_string();
            let parsed: StepSource = s.parse().unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
        }
    }

    #[test]
    fn test_step_status_display_from_str_roundtrip() {
        for variant in [
            StepStatus::Active,
            StepStatus::Done,
            StepStatus::WaitingForUser,
            StepStatus::Error,
            StepStatus::Canceled,
            StepStatus::Unknown,
        ] {
            let s = variant.to_string();
            let parsed: StepStatus = s.parse().unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
        }
    }

    #[test]
    fn test_step_target_display_from_str_roundtrip() {
        for variant in [
            StepTarget::Model,
            StepTarget::User,
            StepTarget::Environment,
            StepTarget::Unspecified,
            StepTarget::Unknown,
        ] {
            let s = variant.to_string();
            let parsed: StepTarget = s.parse().unwrap();
            assert_eq!(parsed, variant, "roundtrip failed for {variant:?}");
        }
    }

    /// For every generated enum variant, the serde wire string, the `Display`
    /// output and the `FromStr` input must be the same string. This guards
    /// against the serde representation drifting from `Display`/`FromStr`.
    #[test]
    fn test_serde_wire_matches_display_for_all_sdk_enums() {
        fn check<T>(variants: &[T])
        where
            T: Copy
                + PartialEq
                + std::fmt::Debug
                + std::fmt::Display
                + std::str::FromStr
                + serde::Serialize
                + serde::de::DeserializeOwned,
            <T as std::str::FromStr>::Err: std::fmt::Debug,
        {
            for &v in variants {
                let json = serde_json::to_string(&v).unwrap();
                assert_eq!(json, format!("\"{v}\""), "serde/Display mismatch for {v:?}");
                let from_json: T = serde_json::from_str(&json).unwrap();
                assert_eq!(from_json, v);
                assert_eq!(v.to_string().parse::<T>().unwrap(), v);
            }
        }

        check(&[
            StepType::TextResponse,
            StepType::ToolCall,
            StepType::SystemMessage,
            StepType::Compaction,
            StepType::Finish,
            StepType::Unknown,
        ]);
        check(&[
            StepSource::System,
            StepSource::User,
            StepSource::Model,
            StepSource::Unknown,
        ]);
        check(&[
            StepStatus::Active,
            StepStatus::Done,
            StepStatus::WaitingForUser,
            StepStatus::Error,
            StepStatus::Canceled,
            StepStatus::Unknown,
        ]);
        check(&[
            StepTarget::Model,
            StepTarget::User,
            StepTarget::Environment,
            StepTarget::Unspecified,
            StepTarget::Unknown,
        ]);
    }

    /// Unrecognized strings are rejected by every `FromStr` impl.
    #[test]
    fn test_from_str_garbage_returns_err() {
        assert!("xyzzy".parse::<StepType>().is_err());
        assert!("xyzzy".parse::<StepSource>().is_err());
        assert!("xyzzy".parse::<StepStatus>().is_err());
        assert!("xyzzy".parse::<StepTarget>().is_err());
    }

    #[test]
    fn test_from_str_empty_returns_err() {
        assert!("".parse::<StepType>().is_err());
        assert!("".parse::<StepSource>().is_err());
        assert!("".parse::<StepStatus>().is_err());
        assert!("".parse::<StepTarget>().is_err());
    }

    /// Parsing is case-sensitive.
    #[test]
    fn test_from_str_case_sensitive() {
        assert!("text_response".parse::<StepType>().is_err());
        assert!("system".parse::<StepSource>().is_err());
        assert!("active".parse::<StepStatus>().is_err());
        assert!("target_user".parse::<StepTarget>().is_err());
    }

    /// Known roles use lowercase names; unknown roles roundtrip verbatim.
    #[test]
    fn test_message_role_roundtrip() {
        for (variant, expected_str) in [
            (MessageRole::User, "\"user\""),
            (MessageRole::Model, "\"model\""),
            (MessageRole::System, "\"system\""),
            (MessageRole::Unknown("custom".to_string()), "\"custom\""),
        ] {
            let json = serde_json::to_string(&variant).unwrap();
            assert_eq!(json, expected_str);
            let parsed: MessageRole = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn test_message_role_display() {
        assert_eq!(MessageRole::User.to_string(), "user");
        assert_eq!(MessageRole::Model.to_string(), "model");
        assert_eq!(MessageRole::System.to_string(), "system");
        assert_eq!(
            MessageRole::Unknown("custom".to_string()).to_string(),
            "custom"
        );
    }

    #[test]
    fn test_conversation_message_roundtrip() {
        let msg = ConversationMessage {
            role: MessageRole::Model,
            content: "Hello!".to_string(),
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: ConversationMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, msg);
    }

    /// Both fields of `ConversationMessage` default when absent.
    #[test]
    fn test_conversation_message_defaults() {
        let parsed: ConversationMessage = serde_json::from_str("{}").unwrap();
        assert_eq!(parsed.role, MessageRole::User);
        assert!(parsed.content.is_empty());
    }

    /// `FromPyObject` works for a dict-shaped `Step` and a `str`-shaped enum.
    #[test]
    fn test_pyo3_extract_roundtrip() {
        use pyo3::{prelude::*, types::PyDictMethods};
        pyo3::Python::initialize();
        pyo3::Python::attach(|py| {
            let dict = pyo3::types::PyDict::new(py);
            dict.set_item("id", "step-1").unwrap();
            dict.set_item("step_index", 42).unwrap();
            dict.set_item("type", "TEXT_RESPONSE").unwrap();
            dict.set_item("source", "MODEL").unwrap();
            dict.set_item("target", "TARGET_USER").unwrap();
            dict.set_item("status", "DONE").unwrap();

            let step: Step = dict.extract().expect("failed to extract Step");
            assert_eq!(step.id, "step-1");
            assert_eq!(step.step_index, 42);
            assert_eq!(step.step_type, StepType::TextResponse);
            assert_eq!(step.source, StepSource::Model);
            assert_eq!(step.target, StepTarget::User);
            assert_eq!(step.status, StepStatus::Done);

            let s = pyo3::types::PyString::new(py, "SYSTEM_MESSAGE");
            let st: StepType = s.extract().unwrap();
            assert_eq!(st, StepType::SystemMessage);
        });
    }

    /// Delta and thinking fields survive a roundtrip.
    #[test]
    fn step_with_deltas_and_thinking() {
        let step = Step {
            id: "s2".to_string(),
            step_index: 1,
            step_type: StepType::TextResponse,
            source: StepSource::Model,
            target: StepTarget::User,
            status: StepStatus::Active,
            content: "Hello world".to_string(),
            content_delta: "world".to_string(),
            thinking: "The user said hi".to_string(),
            thinking_delta: "said hi".to_string(),
            tool_calls: vec![],
            error: String::new(),
            is_complete_response: Some(true),
            structured_output: None,
            usage_metadata: None,
        };
        let json = serde_json::to_string(&step).unwrap();
        let parsed: Step = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.content_delta, "world");
        assert_eq!(parsed.thinking, "The user said hi");
        assert_eq!(parsed.thinking_delta, "said hi");
        assert_eq!(parsed.is_complete_response, Some(true));
        assert_eq!(parsed.target, StepTarget::User);
    }

    #[test]
    fn step_with_structured_output() {
        let payload = serde_json::json!({"answer": 42, "valid": true});
        let step = Step {
            id: "finish-1".to_string(),
            step_index: 5,
            step_type: StepType::Finish,
            source: StepSource::Model,
            target: StepTarget::User,
            status: StepStatus::Done,
            content: String::new(),
            content_delta: String::new(),
            thinking: String::new(),
            thinking_delta: String::new(),
            tool_calls: vec![],
            error: String::new(),
            is_complete_response: Some(true),
            structured_output: Some(payload.clone()),
            usage_metadata: None,
        };
        let json = serde_json::to_string(&step).unwrap();
        let parsed: Step = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.structured_output, Some(payload));
        assert_eq!(parsed.step_type, StepType::Finish);
    }
}