1use std::fmt;
4
5pub type Result<T> = std::result::Result<T, Error>;
7
8#[derive(Debug, thiserror::Error)]
10pub enum Error {
11 #[error("AWS SDK error: {0}")]
13 AwsSdk(String),
14
15 #[error("Session error: {0}")]
17 Session(#[from] SessionError),
18
19 #[error("Protocol error: {0}")]
21 Protocol(#[from] ProtocolError),
22
23 #[error("Transport error: {0}")]
25 Transport(#[from] TransportError),
26
27 #[error("Configuration error: {0}")]
29 Config(String),
30
31 #[error("IO error: {0}")]
33 Io(#[from] std::io::Error),
34
35 #[error("Serialization error: {0}")]
37 Serialization(#[from] serde_json::Error),
38
39 #[error("Invalid state: {0}")]
41 InvalidState(String),
42
43 #[error("Operation timed out")]
45 Timeout,
46
47 #[error("Operation was cancelled")]
49 Cancelled,
50}
51
52#[derive(Debug, thiserror::Error)]
54pub enum SessionError {
55 #[error("Session not found: {0}")]
57 NotFound(String),
58
59 #[error("Session already exists: {0}")]
61 AlreadyExists(String),
62
63 #[error("Session terminated: {reason}")]
65 Terminated {
66 reason: String,
68 },
69
70 #[error("Invalid session state: expected {expected}, got {actual}")]
72 InvalidState {
73 expected: String,
75 actual: String,
77 },
78
79 #[error("Session initialization failed: {0}")]
81 InitializationFailed(String),
82}
83
84#[derive(Debug, thiserror::Error)]
86pub enum ProtocolError {
87 #[error("Invalid message format: {0}")]
89 InvalidMessage(String),
90
91 #[error("Unknown message type: {0}")]
93 UnknownMessageType(String),
94
95 #[error("Invalid sequence number: expected {expected}, got {actual}")]
97 InvalidSequence {
98 expected: u64,
100 actual: u64,
102 },
103
104 #[error("Message framing error: {0}")]
106 Framing(String),
107
108 #[error("Unsupported protocol version: {0}")]
110 UnsupportedVersion(String),
111
112 #[error("Checksum mismatch")]
114 ChecksumMismatch,
115}
116
117#[derive(Debug, thiserror::Error)]
119pub enum TransportError {
120 #[error("WebSocket error: {0}")]
122 WebSocket(String),
123
124 #[error("Connection closed: {reason}")]
126 ConnectionClosed {
127 reason: String,
129 },
130
131 #[error("Connection failed: {0}")]
133 ConnectionFailed(String),
134
135 #[error("Channel error: {0}")]
137 Channel(String),
138
139 #[error("Heartbeat timeout")]
141 HeartbeatTimeout,
142}
143
144impl Error {
145 pub fn is_retriable(&self) -> bool {
147 match self {
148 Error::Timeout => true,
149 Error::Transport(TransportError::HeartbeatTimeout) => true,
150 Error::Transport(TransportError::ConnectionFailed(_)) => true,
151 Error::Transport(TransportError::WebSocket(_)) => true,
152 Error::AwsSdk(_) => true, _ => false,
154 }
155 }
156
157 pub fn is_fatal(&self) -> bool {
159 matches!(
160 self,
161 Error::Session(SessionError::Terminated { .. })
162 | Error::Transport(TransportError::ConnectionClosed { .. })
163 | Error::Session(SessionError::InvalidState { .. })
164 )
165 }
166}
167
168impl<E, R> From<aws_smithy_runtime_api::client::result::SdkError<E, R>> for Error
170where
171 E: fmt::Debug,
172 R: fmt::Debug,
173{
174 fn from(err: aws_smithy_runtime_api::client::result::SdkError<E, R>) -> Self {
175 Error::AwsSdk(format!("{:?}", err))
176 }
177}
178
#[cfg(test)]
mod tests {
    //! Unit tests for error classification (`is_retriable` / `is_fatal`),
    //! exact `Display` output, and `From` conversions.

    use super::*;

    /// One value of every top-level variant, plus every session and transport
    /// variant that the classifiers treat specially.
    fn sample_errors() -> Vec<Error> {
        vec![
            Error::AwsSdk("sdk".into()),
            Error::Session(SessionError::NotFound("sess".into())),
            Error::Session(SessionError::AlreadyExists("sess".into())),
            Error::Session(SessionError::Terminated {
                reason: "done".into(),
            }),
            Error::Session(SessionError::InvalidState {
                expected: "Running".into(),
                actual: "Terminated".into(),
            }),
            Error::Session(SessionError::InitializationFailed("init".into())),
            Error::Protocol(ProtocolError::ChecksumMismatch),
            Error::Transport(TransportError::WebSocket("ws".into())),
            Error::Transport(TransportError::ConnectionClosed {
                reason: "EOF".into(),
            }),
            Error::Transport(TransportError::ConnectionFailed("refused".into())),
            Error::Transport(TransportError::Channel("send".into())),
            Error::Transport(TransportError::HeartbeatTimeout),
            Error::Config("cfg".into()),
            Error::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "io")),
            Error::Serialization(serde_json::from_str::<serde_json::Value>("{").unwrap_err()),
            Error::InvalidState("state".into()),
            Error::Timeout,
            Error::Cancelled,
        ]
    }

    #[test]
    fn test_error_is_retriable() {
        // Transient conditions.
        assert!(Error::Timeout.is_retriable());
        assert!(Error::Transport(TransportError::HeartbeatTimeout).is_retriable());
        assert!(Error::Transport(TransportError::ConnectionFailed("test".into())).is_retriable());
        assert!(Error::Transport(TransportError::WebSocket("test".into())).is_retriable());
        assert!(Error::AwsSdk("transient".into()).is_retriable());

        // Everything else, including the remaining transport variants.
        assert!(!Error::Cancelled.is_retriable());
        assert!(!Error::Config("bad config".into()).is_retriable());
        assert!(!Error::InvalidState("invalid".into()).is_retriable());
        assert!(!Error::Session(SessionError::NotFound("sess".into())).is_retriable());
        assert!(!Error::Protocol(ProtocolError::ChecksumMismatch).is_retriable());
        assert!(!Error::Transport(TransportError::Channel("send failed".into())).is_retriable());
        assert!(!Error::Transport(TransportError::ConnectionClosed {
            reason: "EOF".into(),
        })
        .is_retriable());
        assert!(
            !Error::Io(std::io::Error::new(std::io::ErrorKind::BrokenPipe, "io")).is_retriable()
        );
    }

    #[test]
    fn test_error_is_fatal() {
        let terminated = Error::Session(SessionError::Terminated {
            reason: "test".to_string(),
        });
        assert!(terminated.is_fatal());

        let conn_closed = Error::Transport(TransportError::ConnectionClosed {
            reason: "closed".to_string(),
        });
        assert!(conn_closed.is_fatal());

        let invalid_state = Error::Session(SessionError::InvalidState {
            expected: "Running".into(),
            actual: "Terminated".into(),
        });
        assert!(invalid_state.is_fatal());

        assert!(!Error::Timeout.is_fatal());
        assert!(!Error::Cancelled.is_fatal());
        assert!(!Error::AwsSdk("error".into()).is_fatal());
        assert!(!Error::Session(SessionError::NotFound("sess".into())).is_fatal());
        assert!(!Error::Transport(TransportError::ConnectionFailed("refused".into())).is_fatal());
        assert!(!Error::Transport(TransportError::HeartbeatTimeout).is_fatal());
    }

    #[test]
    fn test_retriable_and_fatal_are_disjoint() {
        for err in sample_errors() {
            assert!(
                !(err.is_retriable() && err.is_fatal()),
                "{:?} is classified as both retriable and fatal",
                err
            );
        }
    }

    #[test]
    fn test_error_display() {
        assert_eq!(Error::Timeout.to_string(), "Operation timed out");
        assert_eq!(Error::Cancelled.to_string(), "Operation was cancelled");
        assert_eq!(
            Error::Session(SessionError::NotFound("sess-123".into())).to_string(),
            "Session error: Session not found: sess-123"
        );
        assert_eq!(
            Error::Protocol(ProtocolError::ChecksumMismatch).to_string(),
            "Protocol error: Checksum mismatch"
        );
        assert_eq!(
            Error::Transport(TransportError::HeartbeatTimeout).to_string(),
            "Transport error: Heartbeat timeout"
        );
        assert_eq!(Error::Config("bad".into()).to_string(), "Configuration error: bad");
        assert_eq!(Error::AwsSdk("boom".into()).to_string(), "AWS SDK error: boom");
        assert_eq!(Error::InvalidState("x".into()).to_string(), "Invalid state: x");
    }

    #[test]
    fn test_session_error_variants() {
        assert_eq!(
            SessionError::NotFound("sess-1".into()).to_string(),
            "Session not found: sess-1"
        );
        assert_eq!(
            SessionError::AlreadyExists("sess-2".into()).to_string(),
            "Session already exists: sess-2"
        );
        assert_eq!(
            SessionError::InitializationFailed("handshake failed".into()).to_string(),
            "Session initialization failed: handshake failed"
        );
        assert_eq!(
            SessionError::Terminated {
                reason: "shutdown".into(),
            }
            .to_string(),
            "Session terminated: shutdown"
        );
        assert_eq!(
            SessionError::InvalidState {
                expected: "Running".into(),
                actual: "Terminated".into(),
            }
            .to_string(),
            "Invalid session state: expected Running, got Terminated"
        );
    }

    #[test]
    fn test_protocol_error_variants() {
        assert_eq!(
            ProtocolError::InvalidMessage("bad header".into()).to_string(),
            "Invalid message format: bad header"
        );
        assert_eq!(
            ProtocolError::UnknownMessageType("xyz".into()).to_string(),
            "Unknown message type: xyz"
        );
        // Exact match so that swapping `expected`/`actual` would be caught.
        assert_eq!(
            ProtocolError::InvalidSequence {
                expected: 5,
                actual: 3,
            }
            .to_string(),
            "Invalid sequence number: expected 5, got 3"
        );
        assert_eq!(
            ProtocolError::Framing("truncated".into()).to_string(),
            "Message framing error: truncated"
        );
        assert_eq!(
            ProtocolError::UnsupportedVersion("2.0".into()).to_string(),
            "Unsupported protocol version: 2.0"
        );
        assert_eq!(ProtocolError::ChecksumMismatch.to_string(), "Checksum mismatch");
    }

    #[test]
    fn test_transport_error_variants() {
        assert_eq!(
            TransportError::WebSocket("connection reset".into()).to_string(),
            "WebSocket error: connection reset"
        );
        assert_eq!(
            TransportError::ConnectionClosed {
                reason: "EOF".into(),
            }
            .to_string(),
            "Connection closed: EOF"
        );
        assert_eq!(
            TransportError::ConnectionFailed("refused".into()).to_string(),
            "Connection failed: refused"
        );
        assert_eq!(
            TransportError::Channel("send failed".into()).to_string(),
            "Channel error: send failed"
        );
        assert_eq!(TransportError::HeartbeatTimeout.to_string(), "Heartbeat timeout");
    }

    #[test]
    fn test_error_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err: Error = io_err.into();
        assert_eq!(err.to_string(), "IO error: file not found");
        assert!(matches!(err, Error::Io(_)));
    }

    #[test]
    fn test_error_from_serde_json() {
        let json_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let err: Error = json_err.into();
        assert!(err.to_string().starts_with("Serialization error: "));
        assert!(matches!(err, Error::Serialization(_)));
    }

    #[test]
    fn test_error_from_nested_errors() {
        let err: Error = SessionError::NotFound("sess".into()).into();
        assert!(matches!(err, Error::Session(SessionError::NotFound(_))));

        let err: Error = ProtocolError::ChecksumMismatch.into();
        assert!(matches!(err, Error::Protocol(ProtocolError::ChecksumMismatch)));

        let err: Error = TransportError::HeartbeatTimeout.into();
        assert!(matches!(err, Error::Transport(TransportError::HeartbeatTimeout)));
    }
}