batch_mode_batch_schema/
batch_response_body.rs

1// ---------------- [ File: batch-mode-batch-schema/src/batch_response_body.rs ]
2crate::ix!();
3
4#[derive(Clone,Debug,Serialize)]
5//#[serde(tag = "object", content = "data")]
6pub enum BatchResponseBody {
7
8    #[serde(rename = "chat.completion")]
9    Success(BatchSuccessResponseBody),
10
11    #[serde(rename = "error")]
12    Error(BatchErrorResponseBody),
13}
14
15impl<'de> Deserialize<'de> for BatchResponseBody {
16    fn deserialize<D>(deserializer: D) -> Result<BatchResponseBody, D::Error>
17    where
18        D: Deserializer<'de>,
19    {
20        let value: serde_json::Value = Deserialize::deserialize(deserializer)?;
21
22        if value.get("error").is_some() {
23            let error_body = BatchErrorResponseBody::deserialize(&value)
24                .map_err(serde::de::Error::custom)?;
25            Ok(BatchResponseBody::Error(error_body))
26        } else {
27            let success_body = BatchSuccessResponseBody::deserialize(&value)
28                .map_err(serde::de::Error::custom)?;
29            Ok(BatchResponseBody::Success(success_body))
30        }
31    }
32}
33
34impl BatchResponseBody {
35
36    pub fn mock_with_code_and_body(code: u16, body: &serde_json::Value) -> Self {
37        if code == 200 {
38            BatchResponseBody::Success(
39                serde_json::from_value(body.clone()).unwrap()
40            )
41        } else {
42            BatchResponseBody::Error(
43                serde_json::from_value(body.clone()).unwrap()
44            )
45        }
46    }
47
48    pub fn mock(custom_id: &str, code: u16) -> Self {
49        if code == 200 {
50            BatchResponseBody::Success(BatchSuccessResponseBody::mock())
51        } else {
52            BatchResponseBody::Error(BatchErrorResponseBody::mock(custom_id))
53        }
54    }
55
56    pub fn mock_error(custom_id: &str) -> Self {
57        BatchResponseBody::Error(BatchErrorResponseBody::mock(custom_id))
58    }
59
60    /// Returns `Some(&BatchSuccessResponseBody)` if the response is a success.
61    pub fn as_success(&self) -> Option<&BatchSuccessResponseBody> {
62        if let BatchResponseBody::Success(ref success_body) = *self {
63            Some(success_body)
64        } else {
65            None
66        }
67    }
68
69    /// Returns `Some(&BatchErrorResponseBody)` if the response is an error.
70    pub fn as_error(&self) -> Option<&BatchErrorResponseBody> {
71        if let BatchResponseBody::Error(ref error_body) = *self {
72            Some(error_body)
73        } else {
74            None
75        }
76    }
77
78    /// Retrieves the `id` if the response is successful.
79    pub fn id(&self) -> Option<&String> {
80        self.as_success().map(|body| body.id())
81    }
82
83    /// Retrieves the `object` if the response is successful.
84    pub fn object(&self) -> Option<&String> {
85        self.as_success().map(|body| body.object())
86    }
87
88    /// Retrieves the `model` if the response is successful.
89    pub fn model(&self) -> Option<&String> {
90        self.as_success().map(|body| body.model())
91    }
92
93    /// Retrieves the `choices` if the response is successful.
94    pub fn choices(&self) -> Option<&Vec<BatchChoice>> {
95        self.as_success().map(|body| body.choices())
96    }
97
98    /// Retrieves the `usage` if the response is successful.
99    pub fn usage(&self) -> Option<&BatchUsage> {
100        self.as_success().map(|body| body.usage())
101    }
102
103    /// Retrieves the `system_fingerprint` if the response is successful.
104    pub fn system_fingerprint(&self) -> Option<String> {
105        self.as_success().and_then(|body| body.system_fingerprint().clone())
106    }
107}
108
#[cfg(test)]
mod batch_response_body_tests {
    use super::*;
    use serde_json::json;

    /// A payload without an `"error"` key must decode as the `Success` variant.
    #[test]
    fn test_success_body_deserialization() {
        let payload = json!({
            "id": "chatcmpl-AVW7Z2Dd49g7Zq5eVExww6dlKA8T9",
            "object": "chat.completion",
            "created": 1732075005,
            "model": "gpt-4o-2024-08-06",
            "choices": [],
            "usage": {
                "prompt_tokens": 40,
                "completion_tokens": 360,
                "total_tokens": 400,
            },
            "system_fingerprint": "fp_7f6be3efb0"
        });

        let parsed: BatchResponseBody = serde_json::from_value(payload).unwrap();

        match parsed {
            BatchResponseBody::Error(_) => panic!("Expected success body"),
            BatchResponseBody::Success(ok) => {
                pretty_assert_eq!(ok.id(), "chatcmpl-AVW7Z2Dd49g7Zq5eVExww6dlKA8T9");
            }
        }
    }

    /// A payload with a top-level `"error"` key must decode as the `Error` variant.
    #[test]
    fn test_error_body_deserialization() {
        let payload = json!({
            "error": {
                "message": "An error occurred",
                "type": "server_error",
                "param": null,
                "code": null
            }
        });

        let parsed: BatchResponseBody = serde_json::from_value(payload).unwrap();

        match parsed {
            BatchResponseBody::Success(_) => panic!("Expected error body"),
            BatchResponseBody::Error(failure) => {
                pretty_assert_eq!(failure.error().message(), "An error occurred");
            }
        }
    }
}