batch_mode_batch_schema/
batch_message.rs

1// ---------------- [ File: batch-mode-batch-schema/src/batch_message.rs ]
2crate::ix!();
3
4#[derive(Default,Builder,Getters,Clone,Debug,Serialize,Deserialize)]
5#[builder(setter(into))]
6#[getset(get="pub")]
7pub struct BatchMessage {
8    role:    MessageRole,
9    content: BatchMessageContent,
10
11    #[builder(default)]
12    refusal: Option<String>,
13}
14
15/**
16  If you need to build `BatchMessage` from a `ChatCompletionResponseMessage`,
17  add an explicit `From<ChatCompletionResponseMessage>` conversion.
18
19  This fixes the error:
20    "the trait bound `BatchMessage: From<ChatCompletionResponseMessage>` is not satisfied"
21
22  so that e.g. `BatchChoiceBuilder::default().message(invalid_msg).build()?`
23  works (because `invalid_msg` is a `ChatCompletionResponseMessage`, and
24  we want to convert it into `BatchMessage`).
25*/
26impl From<ChatCompletionResponseMessage> for BatchMessage {
27    fn from(msg: ChatCompletionResponseMessage) -> Self {
28        // Map the Role::User / Role::Assistant / ... -> MessageRole
29        let mapped_role = match msg.role {
30            Role::System => MessageRole::System,
31            Role::Assistant => MessageRole::Assistant,
32            Role::User => MessageRole::User,
33            Role::Tool => MessageRole::Tool,
34            Role::Function => MessageRole::Function,
35        };
36        // Build the content
37        // (Take `msg.content.unwrap_or_default()` if content is an Option<String> in ChatCompletionResponseMessage)
38        let built_content = BatchMessageContentBuilder::default()
39            .content(msg.content.unwrap_or_default())
40            .build()
41            .unwrap();
42        // Now produce a `BatchMessage`
43        BatchMessageBuilder::default()
44            .role(mapped_role)
45            .content(built_content)
46            .refusal(None)  // If you have no direct "refusal" logic from msg
47            .build()
48            .unwrap()
49    }
50}
51
#[cfg(test)]
mod tests {
    use super::*;

    /// Exercises JSON deserialization of `BatchMessage`: fully populated
    /// messages, a refusal, an unrecognised role, an omitted `refusal`,
    /// empty content, and two malformed inputs that must be rejected.
    #[traced_test]
    fn test_batch_message_deserialization() {
        info!("Starting test: test_batch_message_deserialization");

        // Single parsing entry point shared by every case below.
        let parse = |raw: &str| serde_json::from_str::<BatchMessage>(raw);

        // Case 1: every field present, refusal explicitly null.
        let all_fields = parse(r#"{
            "role": "assistant",
            "content": "Hello, world!",
            "refusal": null
        }"#)
        .unwrap();
        pretty_assert_eq!(all_fields.role(), &MessageRole::Assistant);
        pretty_assert_eq!(all_fields.content(), "Hello, world!");
        pretty_assert_eq!(*all_fields.refusal(), None);

        // Case 2: refusal text present.
        let refused = parse(r#"{
            "role": "assistant",
            "content": "I'm sorry, but I cannot assist with that request.",
            "refusal": "Policy refusal"
        }"#)
        .unwrap();
        // The getter yields `&Option<String>`; dereference it to compare
        // against an owned `Option<String>`.
        pretty_assert_eq!(*refused.refusal(), Some("Policy refusal".to_string()));

        // Case 3: a role string that is not one of the known roles.
        let odd_role = parse(r#"{
            "role": "unknown_role",
            "content": "Content with unknown role",
            "refusal": null
        }"#)
        .unwrap();
        pretty_assert_eq!(
            odd_role.role(),
            &MessageRole::Unknown("unknown_role".to_string())
        );

        // Case 4: the refusal key is omitted entirely.
        let no_refusal_key = parse(r#"{
            "role": "assistant",
            "content": "Content without refusal"
        }"#)
        .unwrap();
        pretty_assert_eq!(*no_refusal_key.refusal(), None);

        // Case 5: empty content string.
        let blank = parse(r#"{
            "role": "assistant",
            "content": "",
            "refusal": null
        }"#)
        .unwrap();
        pretty_assert_eq!(blank.content(), "");

        // Case 6: role is a number rather than a string -> must fail.
        let numeric_role = parse(r#"{
            "role": 123,
            "content": "Invalid role",
            "refusal": null
        }"#);
        assert!(numeric_role.is_err());

        // Case 7: the content key is missing -> must fail.
        let no_content_key = parse(r#"{
            "role": "assistant",
            "refusal": null
        }"#);
        assert!(no_content_key.is_err());

        info!("Finished test: test_batch_message_deserialization");
    }
}