Skip to main content

opencode_sdk_rs/resources/
shared.rs

//! Shared domain types mirroring the JS SDK's `resources/shared.ts`.
2
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
// ---------------------------------------------------------------------------
// Individual error structs
// ---------------------------------------------------------------------------
10
11/// An error indicating the message was aborted.
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
13pub struct MessageAbortedError {
14    /// Structured error data.
15    pub data: MessageAbortedErrorData,
16}
17
18/// Data payload for [`MessageAbortedError`].
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
20pub struct MessageAbortedErrorData {
21    /// Optional human-readable error message.
22    #[serde(default, skip_serializing_if = "Option::is_none")]
23    pub message: Option<String>,
24}
25
26/// An error indicating a provider authentication failure.
27#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
28pub struct ProviderAuthError {
29    /// Structured error data.
30    pub data: ProviderAuthErrorData,
31}
32
33/// Data payload for [`ProviderAuthError`].
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
35pub struct ProviderAuthErrorData {
36    /// Human-readable error message.
37    pub message: String,
38    /// The identifier of the provider that rejected authentication.
39    #[serde(rename = "providerID")]
40    pub provider_id: String,
41}
42
43/// A generic / unknown error.
44#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
45pub struct UnknownError {
46    /// Structured error data.
47    pub data: UnknownErrorData,
48}
49
50/// Data payload for [`UnknownError`].
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
52pub struct UnknownErrorData {
53    /// Human-readable error message.
54    pub message: String,
55}
56
57/// An error indicating the message output exceeded the allowed length.
58#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
59pub struct MessageOutputLengthError {
60    /// Arbitrary payload (maps to `unknown` in the JS SDK).
61    pub data: Option<serde_json::Value>,
62}
63
64/// An error indicating that structured output validation failed.
65#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
66pub struct StructuredOutputError {
67    /// Structured error data.
68    pub data: StructuredOutputErrorData,
69}
70
71/// Data payload for [`StructuredOutputError`].
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
73pub struct StructuredOutputErrorData {
74    /// Human-readable error message.
75    pub message: String,
76    /// Number of retries attempted.
77    pub retries: f64,
78}
79
80/// An error indicating the context window was exceeded.
81#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
82pub struct ContextOverflowError {
83    /// Structured error data.
84    pub data: ContextOverflowErrorData,
85}
86
87/// Data payload for [`ContextOverflowError`].
88#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
89pub struct ContextOverflowErrorData {
90    /// Human-readable error message.
91    pub message: String,
92    /// Optional response body from the provider.
93    #[serde(default, skip_serializing_if = "Option::is_none", rename = "responseBody")]
94    pub response_body: Option<String>,
95}
96
97/// An error originating from the upstream API provider.
98#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
99pub struct ApiError {
100    /// Structured error data.
101    pub data: ApiErrorData,
102}
103
104/// Data payload for [`ApiError`].
105#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
106pub struct ApiErrorData {
107    /// Human-readable error message.
108    pub message: String,
109    /// HTTP status code returned by the provider, if available.
110    #[serde(default, skip_serializing_if = "Option::is_none", rename = "statusCode")]
111    pub status_code: Option<f64>,
112    /// Whether the error is retryable.
113    #[serde(rename = "isRetryable")]
114    pub is_retryable: bool,
115    /// Response headers from the provider, if available.
116    #[serde(default, skip_serializing_if = "Option::is_none", rename = "responseHeaders")]
117    pub response_headers: Option<HashMap<String, String>>,
118    /// Response body from the provider, if available.
119    #[serde(default, skip_serializing_if = "Option::is_none", rename = "responseBody")]
120    pub response_body: Option<String>,
121    /// Additional metadata about the error.
122    #[serde(default, skip_serializing_if = "Option::is_none")]
123    pub metadata: Option<HashMap<String, String>>,
124}
125
// ---------------------------------------------------------------------------
// Discriminated union
// ---------------------------------------------------------------------------
129
130/// A session-level error – one of the known error kinds.
131///
132/// Serialised with a `"name"` tag so the JSON representation matches the JS
133/// SDK's discriminated union: `{ "name": "ProviderAuthError", "data": … }`.
134#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
135#[serde(tag = "name")]
136pub enum SessionError {
137    /// The message was aborted by the user / system.
138    MessageAbortedError {
139        /// Structured error data.
140        data: MessageAbortedErrorData,
141    },
142    /// Provider authentication failed.
143    ProviderAuthError {
144        /// Structured error data.
145        data: ProviderAuthErrorData,
146    },
147    /// A generic / unknown error.
148    UnknownError {
149        /// Structured error data.
150        data: UnknownErrorData,
151    },
152    /// The message output exceeded the allowed length.
153    MessageOutputLengthError {
154        /// Arbitrary payload.
155        data: Option<serde_json::Value>,
156    },
157    /// Structured output validation failed.
158    StructuredOutputError {
159        /// Structured error data.
160        data: StructuredOutputErrorData,
161    },
162    /// The context window was exceeded.
163    ContextOverflowError {
164        /// Structured error data.
165        data: ContextOverflowErrorData,
166    },
167    /// An error originating from the upstream API provider.
168    #[allow(clippy::upper_case_acronyms)]
169    APIError {
170        /// Structured error data.
171        data: ApiErrorData,
172    },
173}
174
// ---------------------------------------------------------------------------
// Conversions from individual structs into the enum
// ---------------------------------------------------------------------------
178
179impl From<MessageAbortedError> for SessionError {
180    fn from(e: MessageAbortedError) -> Self {
181        Self::MessageAbortedError { data: e.data }
182    }
183}
184
185impl From<ProviderAuthError> for SessionError {
186    fn from(e: ProviderAuthError) -> Self {
187        Self::ProviderAuthError { data: e.data }
188    }
189}
190
191impl From<UnknownError> for SessionError {
192    fn from(e: UnknownError) -> Self {
193        Self::UnknownError { data: e.data }
194    }
195}
196
197impl From<MessageOutputLengthError> for SessionError {
198    fn from(e: MessageOutputLengthError) -> Self {
199        Self::MessageOutputLengthError { data: e.data }
200    }
201}
202
203impl From<StructuredOutputError> for SessionError {
204    fn from(e: StructuredOutputError) -> Self {
205        Self::StructuredOutputError { data: e.data }
206    }
207}
208
209impl From<ContextOverflowError> for SessionError {
210    fn from(e: ContextOverflowError) -> Self {
211        Self::ContextOverflowError { data: e.data }
212    }
213}
214
215impl From<ApiError> for SessionError {
216    fn from(e: ApiError) -> Self {
217        Self::APIError { data: e.data }
218    }
219}
220
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
224
#[cfg(test)]
mod tests {
    // Serde round-trip, tag-dispatch, and conversion tests for the shared
    // error types defined in the parent module.
    use serde_json::json;

    use super::*;

    // -- Individual struct round-trips --

    #[test]
    fn message_aborted_error_round_trip() {
        let err = MessageAbortedError {
            data: MessageAbortedErrorData { message: Some("user cancelled".into()) },
        };
        let json = serde_json::to_string(&err).unwrap();
        let back: MessageAbortedError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn message_aborted_error_null_message() {
        let err = MessageAbortedError { data: MessageAbortedErrorData { message: None } };
        let json = serde_json::to_string(&err).unwrap();
        let back: MessageAbortedError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn message_aborted_error_from_empty_object() {
        let input = json!({"data": {}});
        let err: MessageAbortedError = serde_json::from_value(input).unwrap();
        assert_eq!(err.data.message, None);
    }

    #[test]
    fn provider_auth_error_round_trip() {
        let err = ProviderAuthError {
            data: ProviderAuthErrorData {
                message: "invalid token".into(),
                provider_id: "openai".into(),
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(json.contains("providerID"));
        let back: ProviderAuthError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn unknown_error_round_trip() {
        let err =
            UnknownError { data: UnknownErrorData { message: "something went wrong".into() } };
        let json = serde_json::to_string(&err).unwrap();
        let back: UnknownError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn message_output_length_error_round_trip() {
        let err = MessageOutputLengthError { data: Some(json!(42)) };
        let json = serde_json::to_string(&err).unwrap();
        let back: MessageOutputLengthError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    // -- SessionError enum deserialisation via `name` tag --

    #[test]
    fn session_error_message_aborted() {
        let input = json!({
            "name": "MessageAbortedError",
            "data": {}
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::MessageAbortedError { data: MessageAbortedErrorData { message: None } }
        );
    }

    #[test]
    fn session_error_message_aborted_with_message() {
        let input = json!({
            "name": "MessageAbortedError",
            "data": { "message": "cancelled" }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::MessageAbortedError {
                data: MessageAbortedErrorData { message: Some("cancelled".into()) }
            }
        );
    }

    #[test]
    fn session_error_provider_auth() {
        let input = json!({
            "name": "ProviderAuthError",
            "data": {
                "message": "bad credentials",
                "providerID": "anthropic"
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::ProviderAuthError {
                data: ProviderAuthErrorData {
                    message: "bad credentials".into(),
                    provider_id: "anthropic".into(),
                }
            }
        );
    }

    #[test]
    fn session_error_unknown() {
        let input = json!({
            "name": "UnknownError",
            "data": {
                "message": "oops"
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::UnknownError { data: UnknownErrorData { message: "oops".into() } }
        );
    }

    #[test]
    fn session_error_message_output_length() {
        let input = json!({
            "name": "MessageOutputLengthError",
            "data": {"limit": 4096}
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::MessageOutputLengthError { data: Some(json!({"limit": 4096})) }
        );
    }

    #[test]
    fn session_error_round_trip_serialization() {
        let err = SessionError::ProviderAuthError {
            data: ProviderAuthErrorData { message: "expired".into(), provider_id: "google".into() },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "ProviderAuthError");
        assert_eq!(json["data"]["providerID"], "google");

        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    // -- Edge cases: full round-trip for every SessionError variant --

    #[test]
    fn session_error_message_aborted_round_trip_with_message() {
        let err = SessionError::MessageAbortedError {
            data: MessageAbortedErrorData { message: Some("user pressed ctrl-c".into()) },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageAbortedError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_message_aborted_round_trip_no_message() {
        let err =
            SessionError::MessageAbortedError { data: MessageAbortedErrorData { message: None } };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageAbortedError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_unknown_round_trip() {
        let err =
            SessionError::UnknownError { data: UnknownErrorData { message: "kaboom".into() } };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "UnknownError");
        assert_eq!(json["data"]["message"], "kaboom");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_output_length_round_trip_with_data() {
        let err = SessionError::MessageOutputLengthError {
            data: Some(json!({"limit": 8192, "actual": 10000})),
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageOutputLengthError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_output_length_round_trip_null_data() {
        let err = SessionError::MessageOutputLengthError { data: None };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageOutputLengthError");
        assert_eq!(json["data"], serde_json::Value::Null);
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn provider_auth_error_data_fields() {
        let data = ProviderAuthErrorData {
            message: "token expired".into(),
            provider_id: "azure-openai".into(),
        };
        let v = serde_json::to_value(&data).unwrap();
        // Verify rename: Rust field is provider_id, JSON key is providerID
        assert_eq!(v["providerID"], "azure-openai");
        assert!(v.get("provider_id").is_none());
        assert_eq!(v["message"], "token expired");
        let back: ProviderAuthErrorData = serde_json::from_value(v).unwrap();
        assert_eq!(data, back);
    }

    #[test]
    fn message_output_length_error_null_data() {
        let err = MessageOutputLengthError { data: None };
        let json_str = serde_json::to_string(&err).unwrap();
        let back: MessageOutputLengthError = serde_json::from_str(&json_str).unwrap();
        assert_eq!(err, back);
    }

    // -- StructuredOutputError tests --

    #[test]
    fn structured_output_error_round_trip() {
        let err = StructuredOutputError {
            data: StructuredOutputErrorData { message: "schema mismatch".into(), retries: 3.0 },
        };
        let json = serde_json::to_string(&err).unwrap();
        let back: StructuredOutputError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_structured_output() {
        let input = json!({
            "name": "StructuredOutputError",
            "data": {
                "message": "invalid schema",
                "retries": 2.0
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::StructuredOutputError {
                data: StructuredOutputErrorData { message: "invalid schema".into(), retries: 2.0 }
            }
        );
    }

    #[test]
    fn session_error_structured_output_round_trip() {
        let err = SessionError::StructuredOutputError {
            data: StructuredOutputErrorData { message: "bad output".into(), retries: 5.0 },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "StructuredOutputError");
        assert_eq!(json["data"]["retries"], 5.0);
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn structured_output_error_from_conversion() {
        let err = StructuredOutputError {
            data: StructuredOutputErrorData { message: "fail".into(), retries: 1.0 },
        };
        let session: SessionError = err.into();
        assert!(matches!(session, SessionError::StructuredOutputError { .. }));
    }

    // -- ContextOverflowError tests --

    #[test]
    fn context_overflow_error_round_trip() {
        let err = ContextOverflowError {
            data: ContextOverflowErrorData {
                message: "context too large".into(),
                response_body: Some("truncated".into()),
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(json.contains("responseBody"));
        let back: ContextOverflowError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn context_overflow_error_no_response_body() {
        let err = ContextOverflowError {
            data: ContextOverflowErrorData { message: "overflow".into(), response_body: None },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(!json.contains("responseBody"));
        let back: ContextOverflowError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_context_overflow() {
        let input = json!({
            "name": "ContextOverflowError",
            "data": {
                "message": "window exceeded",
                "responseBody": "partial response"
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::ContextOverflowError {
                data: ContextOverflowErrorData {
                    message: "window exceeded".into(),
                    response_body: Some("partial response".into()),
                }
            }
        );
    }

    #[test]
    fn session_error_context_overflow_round_trip() {
        let err = SessionError::ContextOverflowError {
            data: ContextOverflowErrorData { message: "too big".into(), response_body: None },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "ContextOverflowError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn context_overflow_error_from_conversion() {
        let err = ContextOverflowError {
            data: ContextOverflowErrorData { message: "overflow".into(), response_body: None },
        };
        let session: SessionError = err.into();
        assert!(matches!(session, SessionError::ContextOverflowError { .. }));
    }

    // -- APIError tests --

    #[test]
    fn api_error_round_trip() {
        let mut headers = HashMap::new();
        headers.insert("x-request-id".into(), "abc123".into());
        let err = ApiError {
            data: ApiErrorData {
                message: "rate limited".into(),
                status_code: Some(429.0),
                is_retryable: true,
                response_headers: Some(headers),
                response_body: Some("{\"error\": \"too many requests\"}".into()),
                metadata: None,
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(json.contains("statusCode"));
        assert!(json.contains("isRetryable"));
        assert!(json.contains("responseHeaders"));
        assert!(json.contains("responseBody"));
        let back: ApiError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn api_error_minimal() {
        let err = ApiError {
            data: ApiErrorData {
                message: "server error".into(),
                status_code: None,
                is_retryable: false,
                response_headers: None,
                response_body: None,
                metadata: None,
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(!json.contains("statusCode"));
        assert!(!json.contains("responseHeaders"));
        assert!(!json.contains("responseBody"));
        assert!(!json.contains("metadata"));
        let back: ApiError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_api_error() {
        let input = json!({
            "name": "APIError",
            "data": {
                "message": "upstream failure",
                "statusCode": 500.0,
                "isRetryable": true
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::APIError {
                data: ApiErrorData {
                    message: "upstream failure".into(),
                    status_code: Some(500.0),
                    is_retryable: true,
                    response_headers: None,
                    response_body: None,
                    metadata: None,
                }
            }
        );
    }

    #[test]
    fn session_error_api_error_round_trip() {
        let mut meta = HashMap::new();
        meta.insert("region".into(), "us-east-1".into());
        let err = SessionError::APIError {
            data: ApiErrorData {
                message: "bad gateway".into(),
                status_code: Some(502.0),
                is_retryable: true,
                response_headers: None,
                response_body: None,
                metadata: Some(meta),
            },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "APIError");
        assert_eq!(json["data"]["statusCode"], 502.0);
        assert_eq!(json["data"]["isRetryable"], true);
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn api_error_from_conversion() {
        let err = ApiError {
            data: ApiErrorData {
                message: "oops".into(),
                status_code: None,
                is_retryable: false,
                response_headers: None,
                response_body: None,
                metadata: None,
            },
        };
        let session: SessionError = err.into();
        assert!(matches!(session, SessionError::APIError { .. }));
    }

    #[test]
    fn api_error_data_field_renames() {
        let data = ApiErrorData {
            message: "test".into(),
            status_code: Some(401.0),
            is_retryable: false,
            response_headers: None,
            response_body: None,
            metadata: None,
        };
        let v = serde_json::to_value(&data).unwrap();
        assert!(v.get("statusCode").is_some());
        assert!(v.get("status_code").is_none());
        assert!(v.get("isRetryable").is_some());
        assert!(v.get("is_retryable").is_none());
        let back: ApiErrorData = serde_json::from_value(v).unwrap();
        assert_eq!(data, back);
    }

    // -- From conversions that were previously untested --

    #[test]
    fn message_aborted_error_from_conversion() {
        let data = MessageAbortedErrorData { message: Some("stop".into()) };
        let session: SessionError = MessageAbortedError { data: data.clone() }.into();
        assert_eq!(session, SessionError::MessageAbortedError { data });
    }

    #[test]
    fn provider_auth_error_from_conversion() {
        let data = ProviderAuthErrorData {
            message: "denied".into(),
            provider_id: "openai".into(),
        };
        let session: SessionError = ProviderAuthError { data: data.clone() }.into();
        assert_eq!(session, SessionError::ProviderAuthError { data });
    }

    #[test]
    fn unknown_error_from_conversion() {
        let data = UnknownErrorData { message: "???".into() };
        let session: SessionError = UnknownError { data: data.clone() }.into();
        assert_eq!(session, SessionError::UnknownError { data });
    }

    #[test]
    fn message_output_length_error_from_conversion() {
        let payload = Some(json!({"limit": 1}));
        let session: SessionError = MessageOutputLengthError { data: payload.clone() }.into();
        assert_eq!(session, SessionError::MessageOutputLengthError { data: payload });
    }

    // -- Tag handling and required-field edge cases --

    #[test]
    fn session_error_rejects_unknown_name() {
        let input = json!({ "name": "NotARealError", "data": {} });
        assert!(serde_json::from_value::<SessionError>(input).is_err());
    }

    #[test]
    fn session_error_rejects_missing_name() {
        let input = json!({ "data": { "message": "no tag" } });
        assert!(serde_json::from_value::<SessionError>(input).is_err());
    }

    #[test]
    fn session_error_api_error_requires_is_retryable() {
        let input = json!({ "name": "APIError", "data": { "message": "missing flag" } });
        assert!(serde_json::from_value::<SessionError>(input).is_err());
    }

    #[test]
    fn session_error_output_length_missing_data_is_none() {
        let input = json!({ "name": "MessageOutputLengthError" });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(err, SessionError::MessageOutputLengthError { data: None });
    }
}