Skip to main content

opencode_sdk_rs/resources/
shared.rs

1//! Shared domain types mirroring the JS SDK's `resources/shared.ts`.
2
3use serde::{Deserialize, Serialize};
4
5// ---------------------------------------------------------------------------
6// Individual error structs
7// ---------------------------------------------------------------------------
8
9/// An error indicating the message was aborted.
10#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
11pub struct MessageAbortedError {
12    /// Arbitrary payload (maps to `unknown` in the JS SDK).
13    pub data: Option<serde_json::Value>,
14}
15
16/// An error indicating a provider authentication failure.
17#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
18pub struct ProviderAuthError {
19    /// Structured error data.
20    pub data: ProviderAuthErrorData,
21}
22
23/// Data payload for [`ProviderAuthError`].
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
25pub struct ProviderAuthErrorData {
26    /// Human-readable error message.
27    pub message: String,
28    /// The identifier of the provider that rejected authentication.
29    #[serde(rename = "providerID")]
30    pub provider_id: String,
31}
32
33/// A generic / unknown error.
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
35pub struct UnknownError {
36    /// Structured error data.
37    pub data: UnknownErrorData,
38}
39
40/// Data payload for [`UnknownError`].
41#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
42pub struct UnknownErrorData {
43    /// Human-readable error message.
44    pub message: String,
45}
46
47/// An error indicating the message output exceeded the allowed length.
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
49pub struct MessageOutputLengthError {
50    /// Arbitrary payload (maps to `unknown` in the JS SDK).
51    pub data: Option<serde_json::Value>,
52}
53
54// ---------------------------------------------------------------------------
55// Discriminated union
56// ---------------------------------------------------------------------------
57
58/// A session-level error – one of the four known error kinds.
59///
60/// Serialised with a `"name"` tag so the JSON representation matches the JS
61/// SDK's discriminated union: `{ "name": "ProviderAuthError", "data": … }`.
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
63#[serde(tag = "name")]
64pub enum SessionError {
65    /// The message was aborted by the user / system.
66    MessageAbortedError {
67        /// Arbitrary payload.
68        data: Option<serde_json::Value>,
69    },
70    /// Provider authentication failed.
71    ProviderAuthError {
72        /// Structured error data.
73        data: ProviderAuthErrorData,
74    },
75    /// A generic / unknown error.
76    UnknownError {
77        /// Structured error data.
78        data: UnknownErrorData,
79    },
80    /// The message output exceeded the allowed length.
81    MessageOutputLengthError {
82        /// Arbitrary payload.
83        data: Option<serde_json::Value>,
84    },
85}
86
87// ---------------------------------------------------------------------------
88// Conversions from individual structs into the enum
89// ---------------------------------------------------------------------------
90
91impl From<MessageAbortedError> for SessionError {
92    fn from(e: MessageAbortedError) -> Self {
93        Self::MessageAbortedError { data: e.data }
94    }
95}
96
97impl From<ProviderAuthError> for SessionError {
98    fn from(e: ProviderAuthError) -> Self {
99        Self::ProviderAuthError { data: e.data }
100    }
101}
102
103impl From<UnknownError> for SessionError {
104    fn from(e: UnknownError) -> Self {
105        Self::UnknownError { data: e.data }
106    }
107}
108
109impl From<MessageOutputLengthError> for SessionError {
110    fn from(e: MessageOutputLengthError) -> Self {
111        Self::MessageOutputLengthError { data: e.data }
112    }
113}
114
115// ---------------------------------------------------------------------------
116// Tests
117// ---------------------------------------------------------------------------
118
/// Unit tests for the shared error types.
///
/// Covers JSON round-trips of the standalone structs, (de)serialisation of
/// [`SessionError`] through its `"name"` tag, the `From` conversions from the
/// standalone structs into the enum, and rejection of payloads whose tag is
/// unknown or missing.
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    // -- Individual struct round-trips --

    #[test]
    fn message_aborted_error_round_trip() {
        let err = MessageAbortedError { data: Some(json!({"reason": "user cancelled"})) };
        let json = serde_json::to_string(&err).unwrap();
        let back: MessageAbortedError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn message_aborted_error_null_data() {
        let err = MessageAbortedError { data: None };
        let json = serde_json::to_string(&err).unwrap();
        let back: MessageAbortedError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn provider_auth_error_round_trip() {
        let err = ProviderAuthError {
            data: ProviderAuthErrorData {
                message: "invalid token".into(),
                provider_id: "openai".into(),
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(json.contains("providerID"));
        let back: ProviderAuthError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn unknown_error_round_trip() {
        let err =
            UnknownError { data: UnknownErrorData { message: "something went wrong".into() } };
        let json = serde_json::to_string(&err).unwrap();
        let back: UnknownError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn message_output_length_error_round_trip() {
        let err = MessageOutputLengthError { data: Some(json!(42)) };
        let json = serde_json::to_string(&err).unwrap();
        let back: MessageOutputLengthError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    // -- SessionError enum deserialisation via `name` tag --

    #[test]
    fn session_error_message_aborted() {
        let input = json!({
            "name": "MessageAbortedError",
            "data": null
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(err, SessionError::MessageAbortedError { data: None });
    }

    #[test]
    fn session_error_provider_auth() {
        let input = json!({
            "name": "ProviderAuthError",
            "data": {
                "message": "bad credentials",
                "providerID": "anthropic"
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::ProviderAuthError {
                data: ProviderAuthErrorData {
                    message: "bad credentials".into(),
                    provider_id: "anthropic".into(),
                }
            }
        );
    }

    #[test]
    fn session_error_unknown() {
        let input = json!({
            "name": "UnknownError",
            "data": {
                "message": "oops"
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::UnknownError { data: UnknownErrorData { message: "oops".into() } }
        );
    }

    #[test]
    fn session_error_message_output_length() {
        let input = json!({
            "name": "MessageOutputLengthError",
            "data": {"limit": 4096}
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::MessageOutputLengthError { data: Some(json!({"limit": 4096})) }
        );
    }

    #[test]
    fn session_error_round_trip_serialization() {
        let err = SessionError::ProviderAuthError {
            data: ProviderAuthErrorData { message: "expired".into(), provider_id: "google".into() },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "ProviderAuthError");
        assert_eq!(json["data"]["providerID"], "google");

        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    // -- Edge cases: full round-trip for every SessionError variant --

    #[test]
    fn session_error_message_aborted_round_trip_with_data() {
        let err = SessionError::MessageAbortedError {
            data: Some(json!({"reason": "user pressed ctrl-c"})),
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageAbortedError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_message_aborted_round_trip_null_data() {
        let err = SessionError::MessageAbortedError { data: None };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageAbortedError");
        assert_eq!(json["data"], serde_json::Value::Null);
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_unknown_round_trip() {
        let err =
            SessionError::UnknownError { data: UnknownErrorData { message: "kaboom".into() } };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "UnknownError");
        assert_eq!(json["data"]["message"], "kaboom");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_output_length_round_trip_with_data() {
        let err = SessionError::MessageOutputLengthError {
            data: Some(json!({"limit": 8192, "actual": 10000})),
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageOutputLengthError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_output_length_round_trip_null_data() {
        let err = SessionError::MessageOutputLengthError { data: None };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageOutputLengthError");
        assert_eq!(json["data"], serde_json::Value::Null);
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn provider_auth_error_data_fields() {
        let data = ProviderAuthErrorData {
            message: "token expired".into(),
            provider_id: "azure-openai".into(),
        };
        let v = serde_json::to_value(&data).unwrap();
        // Verify rename: Rust field is provider_id, JSON key is providerID
        assert_eq!(v["providerID"], "azure-openai");
        assert!(v.get("provider_id").is_none());
        assert_eq!(v["message"], "token expired");
        let back: ProviderAuthErrorData = serde_json::from_value(v).unwrap();
        assert_eq!(data, back);
    }

    #[test]
    fn message_output_length_error_null_data() {
        let err = MessageOutputLengthError { data: None };
        let json_str = serde_json::to_string(&err).unwrap();
        let back: MessageOutputLengthError = serde_json::from_str(&json_str).unwrap();
        assert_eq!(err, back);
    }

    // -- `From` conversions: each standalone struct maps onto its own variant --

    #[test]
    fn from_message_aborted_error() {
        let payload = Some(json!({"reason": "timeout"}));
        let err: SessionError = MessageAbortedError { data: payload.clone() }.into();
        assert_eq!(err, SessionError::MessageAbortedError { data: payload });
    }

    #[test]
    fn from_provider_auth_error() {
        let data =
            ProviderAuthErrorData { message: "denied".into(), provider_id: "openai".into() };
        let err: SessionError = ProviderAuthError { data: data.clone() }.into();
        assert_eq!(err, SessionError::ProviderAuthError { data });
    }

    #[test]
    fn from_unknown_error() {
        let data = UnknownErrorData { message: "boom".into() };
        let err: SessionError = UnknownError { data: data.clone() }.into();
        assert_eq!(err, SessionError::UnknownError { data });
    }

    #[test]
    fn from_message_output_length_error() {
        let err: SessionError = MessageOutputLengthError { data: None }.into();
        assert_eq!(err, SessionError::MessageOutputLengthError { data: None });
    }

    // -- Negative cases: the `name` tag must be present and recognised --

    #[test]
    fn session_error_rejects_unknown_name() {
        let input = json!({
            "name": "NotARealError",
            "data": null
        });
        assert!(serde_json::from_value::<SessionError>(input).is_err());
    }

    #[test]
    fn session_error_rejects_missing_name() {
        let input = json!({
            "data": null
        });
        assert!(serde_json::from_value::<SessionError>(input).is_err());
    }
}