1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
13pub struct MessageAbortedError {
14 pub data: MessageAbortedErrorData,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
20pub struct MessageAbortedErrorData {
21 #[serde(default, skip_serializing_if = "Option::is_none")]
23 pub message: Option<String>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
28pub struct ProviderAuthError {
29 pub data: ProviderAuthErrorData,
31}
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
35pub struct ProviderAuthErrorData {
36 pub message: String,
38 #[serde(rename = "providerID")]
40 pub provider_id: String,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
45pub struct UnknownError {
46 pub data: UnknownErrorData,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
52pub struct UnknownErrorData {
53 pub message: String,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
59pub struct MessageOutputLengthError {
60 pub data: Option<serde_json::Value>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
66pub struct StructuredOutputError {
67 pub data: StructuredOutputErrorData,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
73pub struct StructuredOutputErrorData {
74 pub message: String,
76 pub retries: f64,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
82pub struct ContextOverflowError {
83 pub data: ContextOverflowErrorData,
85}
86
87#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
89pub struct ContextOverflowErrorData {
90 pub message: String,
92 #[serde(default, skip_serializing_if = "Option::is_none", rename = "responseBody")]
94 pub response_body: Option<String>,
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
99pub struct ApiError {
100 pub data: ApiErrorData,
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
106pub struct ApiErrorData {
107 pub message: String,
109 #[serde(default, skip_serializing_if = "Option::is_none", rename = "statusCode")]
111 pub status_code: Option<f64>,
112 #[serde(rename = "isRetryable")]
114 pub is_retryable: bool,
115 #[serde(default, skip_serializing_if = "Option::is_none", rename = "responseHeaders")]
117 pub response_headers: Option<HashMap<String, String>>,
118 #[serde(default, skip_serializing_if = "Option::is_none", rename = "responseBody")]
120 pub response_body: Option<String>,
121 #[serde(default, skip_serializing_if = "Option::is_none")]
123 pub metadata: Option<HashMap<String, String>>,
124}
125
126#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
135#[serde(tag = "name")]
136pub enum SessionError {
137 MessageAbortedError {
139 data: MessageAbortedErrorData,
141 },
142 ProviderAuthError {
144 data: ProviderAuthErrorData,
146 },
147 UnknownError {
149 data: UnknownErrorData,
151 },
152 MessageOutputLengthError {
154 data: Option<serde_json::Value>,
156 },
157 StructuredOutputError {
159 data: StructuredOutputErrorData,
161 },
162 ContextOverflowError {
164 data: ContextOverflowErrorData,
166 },
167 #[allow(clippy::upper_case_acronyms)]
169 APIError {
170 data: ApiErrorData,
172 },
173}
174
175impl From<MessageAbortedError> for SessionError {
180 fn from(e: MessageAbortedError) -> Self {
181 Self::MessageAbortedError { data: e.data }
182 }
183}
184
185impl From<ProviderAuthError> for SessionError {
186 fn from(e: ProviderAuthError) -> Self {
187 Self::ProviderAuthError { data: e.data }
188 }
189}
190
191impl From<UnknownError> for SessionError {
192 fn from(e: UnknownError) -> Self {
193 Self::UnknownError { data: e.data }
194 }
195}
196
197impl From<MessageOutputLengthError> for SessionError {
198 fn from(e: MessageOutputLengthError) -> Self {
199 Self::MessageOutputLengthError { data: e.data }
200 }
201}
202
203impl From<StructuredOutputError> for SessionError {
204 fn from(e: StructuredOutputError) -> Self {
205 Self::StructuredOutputError { data: e.data }
206 }
207}
208
209impl From<ContextOverflowError> for SessionError {
210 fn from(e: ContextOverflowError) -> Self {
211 Self::ContextOverflowError { data: e.data }
212 }
213}
214
215impl From<ApiError> for SessionError {
216 fn from(e: ApiError) -> Self {
217 Self::APIError { data: e.data }
218 }
219}
220
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Serializes `value` to a JSON string and parses it straight back.
    ///
    /// Panics (failing the calling test) if either direction fails.
    fn round_trip<T>(value: &T) -> T
    where
        T: Serialize + for<'de> Deserialize<'de>,
    {
        let text = serde_json::to_string(value).expect("value should serialize to JSON");
        serde_json::from_str(&text).expect("serialized JSON should deserialize back")
    }

    // ----- MessageAbortedError -----

    #[test]
    fn message_aborted_error_round_trip() {
        let err = MessageAbortedError {
            data: MessageAbortedErrorData { message: Some("user cancelled".into()) },
        };
        assert_eq!(round_trip(&err), err);
    }

    #[test]
    fn message_aborted_error_null_message() {
        let err = MessageAbortedError { data: MessageAbortedErrorData { message: None } };
        // `skip_serializing_if` drops the key entirely instead of writing `null`.
        assert_eq!(serde_json::to_string(&err).unwrap(), r#"{"data":{}}"#);
        assert_eq!(round_trip(&err), err);
    }

    #[test]
    fn message_aborted_error_from_empty_object() {
        let input = json!({"data": {}});
        let err: MessageAbortedError = serde_json::from_value(input).unwrap();
        assert_eq!(err.data.message, None);
    }

    // ----- ProviderAuthError -----

    #[test]
    fn provider_auth_error_round_trip() {
        let err = ProviderAuthError {
            data: ProviderAuthErrorData {
                message: "invalid token".into(),
                provider_id: "openai".into(),
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(json.contains("providerID"));
        let back: ProviderAuthError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn provider_auth_error_data_fields() {
        let data = ProviderAuthErrorData {
            message: "token expired".into(),
            provider_id: "azure-openai".into(),
        };
        let v = serde_json::to_value(&data).unwrap();
        assert_eq!(v["providerID"], "azure-openai");
        // The snake_case Rust field name must not leak onto the wire.
        assert!(v.get("provider_id").is_none());
        assert_eq!(v["message"], "token expired");
        let back: ProviderAuthErrorData = serde_json::from_value(v).unwrap();
        assert_eq!(data, back);
    }

    #[test]
    fn provider_auth_error_data_requires_camel_case_key() {
        // Only the renamed `providerID` key is accepted; `provider_id` is not an alias.
        let input = json!({"message": "nope", "provider_id": "openai"});
        assert!(serde_json::from_value::<ProviderAuthErrorData>(input).is_err());
    }

    // ----- UnknownError -----

    #[test]
    fn unknown_error_round_trip() {
        let err =
            UnknownError { data: UnknownErrorData { message: "something went wrong".into() } };
        assert_eq!(round_trip(&err), err);
    }

    // ----- MessageOutputLengthError -----

    #[test]
    fn message_output_length_error_round_trip() {
        let err = MessageOutputLengthError { data: Some(json!(42)) };
        assert_eq!(round_trip(&err), err);
    }

    #[test]
    fn message_output_length_error_null_data() {
        let err = MessageOutputLengthError { data: None };
        assert_eq!(round_trip(&err), err);
    }

    #[test]
    fn message_output_length_error_missing_data_defaults_to_none() {
        let err: MessageOutputLengthError = serde_json::from_value(json!({})).unwrap();
        assert_eq!(err.data, None);
    }

    // ----- StructuredOutputError -----

    #[test]
    fn structured_output_error_round_trip() {
        let err = StructuredOutputError {
            data: StructuredOutputErrorData { message: "schema mismatch".into(), retries: 3.0 },
        };
        assert_eq!(round_trip(&err), err);
    }

    #[test]
    fn structured_output_error_from_conversion() {
        let err = StructuredOutputError {
            data: StructuredOutputErrorData { message: "fail".into(), retries: 1.0 },
        };
        let session: SessionError = err.into();
        assert!(matches!(session, SessionError::StructuredOutputError { .. }));
    }

    // ----- ContextOverflowError -----

    #[test]
    fn context_overflow_error_round_trip() {
        let err = ContextOverflowError {
            data: ContextOverflowErrorData {
                message: "context too large".into(),
                response_body: Some("truncated".into()),
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(json.contains("responseBody"));
        let back: ContextOverflowError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn context_overflow_error_no_response_body() {
        let err = ContextOverflowError {
            data: ContextOverflowErrorData { message: "overflow".into(), response_body: None },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(!json.contains("responseBody"));
        let back: ContextOverflowError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn context_overflow_error_from_conversion() {
        let err = ContextOverflowError {
            data: ContextOverflowErrorData { message: "overflow".into(), response_body: None },
        };
        let session: SessionError = err.into();
        assert!(matches!(session, SessionError::ContextOverflowError { .. }));
    }

    // ----- ApiError -----

    #[test]
    fn api_error_round_trip() {
        let mut headers = HashMap::new();
        headers.insert("x-request-id".into(), "abc123".into());
        let err = ApiError {
            data: ApiErrorData {
                message: "rate limited".into(),
                status_code: Some(429.0),
                is_retryable: true,
                response_headers: Some(headers),
                response_body: Some("{\"error\": \"too many requests\"}".into()),
                metadata: None,
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(json.contains("statusCode"));
        assert!(json.contains("isRetryable"));
        assert!(json.contains("responseHeaders"));
        assert!(json.contains("responseBody"));
        let back: ApiError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn api_error_minimal() {
        let err = ApiError {
            data: ApiErrorData {
                message: "server error".into(),
                status_code: None,
                is_retryable: false,
                response_headers: None,
                response_body: None,
                metadata: None,
            },
        };
        let json = serde_json::to_string(&err).unwrap();
        assert!(!json.contains("statusCode"));
        assert!(!json.contains("responseHeaders"));
        assert!(!json.contains("responseBody"));
        assert!(!json.contains("metadata"));
        let back: ApiError = serde_json::from_str(&json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn api_error_from_conversion() {
        let err = ApiError {
            data: ApiErrorData {
                message: "oops".into(),
                status_code: None,
                is_retryable: false,
                response_headers: None,
                response_body: None,
                metadata: None,
            },
        };
        let session: SessionError = err.into();
        assert!(matches!(session, SessionError::APIError { .. }));
    }

    #[test]
    fn api_error_data_field_renames() {
        let data = ApiErrorData {
            message: "test".into(),
            status_code: Some(401.0),
            is_retryable: false,
            response_headers: None,
            response_body: None,
            metadata: None,
        };
        let v = serde_json::to_value(&data).unwrap();
        assert!(v.get("statusCode").is_some());
        assert!(v.get("status_code").is_none());
        assert!(v.get("isRetryable").is_some());
        assert!(v.get("is_retryable").is_none());
        let back: ApiErrorData = serde_json::from_value(v).unwrap();
        assert_eq!(data, back);
    }

    #[test]
    fn api_error_data_requires_is_retryable() {
        // `isRetryable` has no serde default, so omitting it must be a hard error.
        let input = json!({"message": "boom", "statusCode": 500});
        assert!(serde_json::from_value::<ApiErrorData>(input).is_err());
    }

    // ----- SessionError: tagged deserialization -----

    #[test]
    fn session_error_message_aborted() {
        let input = json!({
            "name": "MessageAbortedError",
            "data": {}
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::MessageAbortedError { data: MessageAbortedErrorData { message: None } }
        );
    }

    #[test]
    fn session_error_message_aborted_with_message() {
        let input = json!({
            "name": "MessageAbortedError",
            "data": { "message": "cancelled" }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::MessageAbortedError {
                data: MessageAbortedErrorData { message: Some("cancelled".into()) }
            }
        );
    }

    #[test]
    fn session_error_provider_auth() {
        let input = json!({
            "name": "ProviderAuthError",
            "data": {
                "message": "bad credentials",
                "providerID": "anthropic"
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::ProviderAuthError {
                data: ProviderAuthErrorData {
                    message: "bad credentials".into(),
                    provider_id: "anthropic".into(),
                }
            }
        );
    }

    #[test]
    fn session_error_unknown() {
        let input = json!({
            "name": "UnknownError",
            "data": {
                "message": "oops"
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::UnknownError { data: UnknownErrorData { message: "oops".into() } }
        );
    }

    #[test]
    fn session_error_message_output_length() {
        let input = json!({
            "name": "MessageOutputLengthError",
            "data": {"limit": 4096}
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::MessageOutputLengthError { data: Some(json!({"limit": 4096})) }
        );
    }

    #[test]
    fn session_error_structured_output() {
        let input = json!({
            "name": "StructuredOutputError",
            "data": {
                "message": "invalid schema",
                "retries": 2.0
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::StructuredOutputError {
                data: StructuredOutputErrorData { message: "invalid schema".into(), retries: 2.0 }
            }
        );
    }

    #[test]
    fn session_error_context_overflow() {
        let input = json!({
            "name": "ContextOverflowError",
            "data": {
                "message": "window exceeded",
                "responseBody": "partial response"
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::ContextOverflowError {
                data: ContextOverflowErrorData {
                    message: "window exceeded".into(),
                    response_body: Some("partial response".into()),
                }
            }
        );
    }

    #[test]
    fn session_error_api_error() {
        let input = json!({
            "name": "APIError",
            "data": {
                "message": "upstream failure",
                "statusCode": 500.0,
                "isRetryable": true
            }
        });
        let err: SessionError = serde_json::from_value(input).unwrap();
        assert_eq!(
            err,
            SessionError::APIError {
                data: ApiErrorData {
                    message: "upstream failure".into(),
                    status_code: Some(500.0),
                    is_retryable: true,
                    response_headers: None,
                    response_body: None,
                    metadata: None,
                }
            }
        );
    }

    #[test]
    fn session_error_rejects_unknown_name() {
        let input = json!({"name": "NotARealError", "data": {}});
        assert!(serde_json::from_value::<SessionError>(input).is_err());
    }

    #[test]
    fn session_error_rejects_missing_name() {
        let input = json!({"data": {"message": "untagged"}});
        assert!(serde_json::from_value::<SessionError>(input).is_err());
    }

    // ----- SessionError: serialize -> deserialize -----

    #[test]
    fn session_error_round_trip_serialization() {
        let err = SessionError::ProviderAuthError {
            data: ProviderAuthErrorData { message: "expired".into(), provider_id: "google".into() },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "ProviderAuthError");
        assert_eq!(json["data"]["providerID"], "google");

        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_message_aborted_round_trip_with_message() {
        let err = SessionError::MessageAbortedError {
            data: MessageAbortedErrorData { message: Some("user pressed ctrl-c".into()) },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageAbortedError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_message_aborted_round_trip_no_message() {
        let err =
            SessionError::MessageAbortedError { data: MessageAbortedErrorData { message: None } };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageAbortedError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_unknown_round_trip() {
        let err =
            SessionError::UnknownError { data: UnknownErrorData { message: "kaboom".into() } };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "UnknownError");
        assert_eq!(json["data"]["message"], "kaboom");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_output_length_round_trip_with_data() {
        let err = SessionError::MessageOutputLengthError {
            data: Some(json!({"limit": 8192, "actual": 10000})),
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageOutputLengthError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_output_length_round_trip_null_data() {
        let err = SessionError::MessageOutputLengthError { data: None };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "MessageOutputLengthError");
        assert_eq!(json["data"], serde_json::Value::Null);
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_structured_output_round_trip() {
        let err = SessionError::StructuredOutputError {
            data: StructuredOutputErrorData { message: "bad output".into(), retries: 5.0 },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "StructuredOutputError");
        assert_eq!(json["data"]["retries"], 5.0);
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_context_overflow_round_trip() {
        let err = SessionError::ContextOverflowError {
            data: ContextOverflowErrorData { message: "too big".into(), response_body: None },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "ContextOverflowError");
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    #[test]
    fn session_error_api_error_round_trip() {
        let mut meta = HashMap::new();
        meta.insert("region".into(), "us-east-1".into());
        let err = SessionError::APIError {
            data: ApiErrorData {
                message: "bad gateway".into(),
                status_code: Some(502.0),
                is_retryable: true,
                response_headers: None,
                response_body: None,
                metadata: Some(meta),
            },
        };
        let json = serde_json::to_value(&err).unwrap();
        assert_eq!(json["name"], "APIError");
        assert_eq!(json["data"]["statusCode"], 502.0);
        assert_eq!(json["data"]["isRetryable"], true);
        let back: SessionError = serde_json::from_value(json).unwrap();
        assert_eq!(err, back);
    }

    // ----- From conversions not covered above -----

    #[test]
    fn remaining_from_conversions_preserve_payload() {
        let aborted: SessionError =
            MessageAbortedError { data: MessageAbortedErrorData { message: None } }.into();
        assert_eq!(
            aborted,
            SessionError::MessageAbortedError { data: MessageAbortedErrorData { message: None } }
        );

        let auth: SessionError = ProviderAuthError {
            data: ProviderAuthErrorData { message: "denied".into(), provider_id: "openai".into() },
        }
        .into();
        assert_eq!(
            auth,
            SessionError::ProviderAuthError {
                data: ProviderAuthErrorData {
                    message: "denied".into(),
                    provider_id: "openai".into(),
                }
            }
        );

        let unknown: SessionError =
            UnknownError { data: UnknownErrorData { message: "???".into() } }.into();
        assert_eq!(
            unknown,
            SessionError::UnknownError { data: UnknownErrorData { message: "???".into() } }
        );

        let length: SessionError = MessageOutputLengthError { data: Some(json!(7)) }.into();
        assert_eq!(length, SessionError::MessageOutputLengthError { data: Some(json!(7)) });
    }
}