// aws_ssm_bridge/errors.rs

//! Error types for aws-ssm-bridge
2
3use std::fmt;
4
5/// Result type alias for aws-ssm-bridge operations
6pub type Result<T> = std::result::Result<T, Error>;
7
8/// Main error type for the library
9#[derive(Debug, thiserror::Error)]
10pub enum Error {
11    /// AWS SDK errors
12    #[error("AWS SDK error: {0}")]
13    AwsSdk(String),
14
15    /// Session errors
16    #[error("Session error: {0}")]
17    Session(#[from] SessionError),
18
19    /// Protocol errors
20    #[error("Protocol error: {0}")]
21    Protocol(#[from] ProtocolError),
22
23    /// Transport errors
24    #[error("Transport error: {0}")]
25    Transport(#[from] TransportError),
26
27    /// Configuration errors
28    #[error("Configuration error: {0}")]
29    Config(String),
30
31    /// IO errors
32    #[error("IO error: {0}")]
33    Io(#[from] std::io::Error),
34
35    /// Serialization errors
36    #[error("Serialization error: {0}")]
37    Serialization(#[from] serde_json::Error),
38
39    /// Invalid state error
40    #[error("Invalid state: {0}")]
41    InvalidState(String),
42
43    /// Timeout error
44    #[error("Operation timed out")]
45    Timeout,
46
47    /// Cancelled error
48    #[error("Operation was cancelled")]
49    Cancelled,
50}
51
52/// Session-specific errors
53#[derive(Debug, thiserror::Error)]
54pub enum SessionError {
55    /// Session not found
56    #[error("Session not found: {0}")]
57    NotFound(String),
58
59    /// Session already exists
60    #[error("Session already exists: {0}")]
61    AlreadyExists(String),
62
63    /// Session terminated
64    #[error("Session terminated: {reason}")]
65    Terminated {
66        /// The reason for termination
67        reason: String,
68    },
69
70    /// Invalid session state
71    #[error("Invalid session state: expected {expected}, got {actual}")]
72    InvalidState {
73        /// The expected state
74        expected: String,
75        /// The actual state
76        actual: String,
77    },
78
79    /// Session initialization failed
80    #[error("Session initialization failed: {0}")]
81    InitializationFailed(String),
82}
83
84/// Protocol-specific errors
85#[derive(Debug, thiserror::Error)]
86pub enum ProtocolError {
87    /// Invalid message format
88    #[error("Invalid message format: {0}")]
89    InvalidMessage(String),
90
91    /// Unknown message type
92    #[error("Unknown message type: {0}")]
93    UnknownMessageType(String),
94
95    /// Invalid sequence number
96    #[error("Invalid sequence number: expected {expected}, got {actual}")]
97    InvalidSequence {
98        /// The expected sequence number
99        expected: u64,
100        /// The actual sequence number received
101        actual: u64,
102    },
103
104    /// Message framing error
105    #[error("Message framing error: {0}")]
106    Framing(String),
107
108    /// Unsupported protocol version
109    #[error("Unsupported protocol version: {0}")]
110    UnsupportedVersion(String),
111
112    /// Checksum mismatch
113    #[error("Checksum mismatch")]
114    ChecksumMismatch,
115}
116
117/// Transport-specific errors
118#[derive(Debug, thiserror::Error)]
119pub enum TransportError {
120    /// WebSocket error
121    #[error("WebSocket error: {0}")]
122    WebSocket(String),
123
124    /// Connection closed
125    #[error("Connection closed: {reason}")]
126    ConnectionClosed {
127        /// The reason for connection closure
128        reason: String,
129    },
130
131    /// Connection failed
132    #[error("Connection failed: {0}")]
133    ConnectionFailed(String),
134
135    /// Channel error
136    #[error("Channel error: {0}")]
137    Channel(String),
138
139    /// Heartbeat timeout
140    #[error("Heartbeat timeout")]
141    HeartbeatTimeout,
142}
143
144impl Error {
145    /// Check if error is retriable
146    pub fn is_retriable(&self) -> bool {
147        match self {
148            Error::Timeout => true,
149            Error::Transport(TransportError::HeartbeatTimeout) => true,
150            Error::Transport(TransportError::ConnectionFailed(_)) => true,
151            Error::Transport(TransportError::WebSocket(_)) => true,
152            Error::AwsSdk(_) => true, // Some AWS errors are transient
153            _ => false,
154        }
155    }
156
157    /// Check if error is fatal (session should be terminated)
158    pub fn is_fatal(&self) -> bool {
159        matches!(
160            self,
161            Error::Session(SessionError::Terminated { .. })
162                | Error::Transport(TransportError::ConnectionClosed { .. })
163                | Error::Session(SessionError::InvalidState { .. })
164        )
165    }
166}
167
168// Implement conversion from AWS SDK errors
169impl<E, R> From<aws_smithy_runtime_api::client::result::SdkError<E, R>> for Error
170where
171    E: fmt::Debug,
172    R: fmt::Debug,
173{
174    fn from(err: aws_smithy_runtime_api::client::result::SdkError<E, R>) -> Self {
175        Error::AwsSdk(format!("{:?}", err))
176    }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_is_retriable() {
        // Retriable errors
        assert!(Error::Timeout.is_retriable());
        assert!(Error::Transport(TransportError::HeartbeatTimeout).is_retriable());
        assert!(Error::Transport(TransportError::ConnectionFailed("test".into())).is_retriable());
        assert!(Error::Transport(TransportError::WebSocket("test".into())).is_retriable());
        assert!(Error::AwsSdk("transient".into()).is_retriable());

        // Non-retriable errors
        assert!(!Error::Cancelled.is_retriable());
        assert!(!Error::Config("bad config".into()).is_retriable());
        assert!(!Error::InvalidState("invalid".into()).is_retriable());
        assert!(!Error::Session(SessionError::NotFound("sess".into())).is_retriable());
        assert!(!Error::Protocol(ProtocolError::ChecksumMismatch).is_retriable());
        assert!(!Error::Transport(TransportError::Channel("send failed".into())).is_retriable());

        // A connection that was closed is not retried, unlike one that failed to open.
        let closed = Error::Transport(TransportError::ConnectionClosed {
            reason: "EOF".into(),
        });
        assert!(!closed.is_retriable());

        // I/O errors are currently never retried, whatever their kind.
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "pipe closed");
        assert!(!Error::Io(io_err).is_retriable());
    }

    #[test]
    fn test_error_is_fatal() {
        // Fatal errors
        let terminated = Error::Session(SessionError::Terminated {
            reason: "test".to_string(),
        });
        assert!(terminated.is_fatal());

        let conn_closed = Error::Transport(TransportError::ConnectionClosed {
            reason: "closed".to_string(),
        });
        assert!(conn_closed.is_fatal());

        let invalid_state = Error::Session(SessionError::InvalidState {
            expected: "Running".into(),
            actual: "Terminated".into(),
        });
        assert!(invalid_state.is_fatal());

        // Non-fatal errors
        assert!(!Error::Timeout.is_fatal());
        assert!(!Error::Cancelled.is_fatal());
        assert!(!Error::AwsSdk("error".into()).is_fatal());
        assert!(!Error::Session(SessionError::NotFound("sess".into())).is_fatal());
        assert!(!Error::Protocol(ProtocolError::ChecksumMismatch).is_fatal());
        assert!(!Error::Transport(TransportError::HeartbeatTimeout).is_fatal());
    }

    #[test]
    fn test_error_display() {
        assert_eq!(Error::Timeout.to_string(), "Operation timed out");
        assert_eq!(Error::Cancelled.to_string(), "Operation was cancelled");

        // Wrapped errors include the inner error's message after the outer prefix.
        let err = Error::Session(SessionError::NotFound("sess-123".into()));
        assert_eq!(err.to_string(), "Session error: Session not found: sess-123");

        let err = Error::Protocol(ProtocolError::ChecksumMismatch);
        assert_eq!(err.to_string(), "Protocol error: Checksum mismatch");
    }

    #[test]
    fn test_session_error_variants() {
        let err = SessionError::NotFound("sess-1".into());
        assert_eq!(err.to_string(), "Session not found: sess-1");

        let err = SessionError::AlreadyExists("sess-2".into());
        assert_eq!(err.to_string(), "Session already exists: sess-2");

        let err = SessionError::Terminated {
            reason: "idle".into(),
        };
        assert_eq!(err.to_string(), "Session terminated: idle");

        let err = SessionError::InvalidState {
            expected: "Running".into(),
            actual: "Terminated".into(),
        };
        assert_eq!(
            err.to_string(),
            "Invalid session state: expected Running, got Terminated"
        );

        let err = SessionError::InitializationFailed("handshake failed".into());
        assert_eq!(err.to_string(), "Session initialization failed: handshake failed");
    }

    #[test]
    fn test_protocol_error_variants() {
        let err = ProtocolError::InvalidMessage("bad header".into());
        assert_eq!(err.to_string(), "Invalid message format: bad header");

        let err = ProtocolError::UnknownMessageType("xyz".into());
        assert_eq!(err.to_string(), "Unknown message type: xyz");

        // Exact match: a `contains("5")`/`contains("3")` check would also accept
        // swapped or mislabelled numbers.
        let err = ProtocolError::InvalidSequence {
            expected: 5,
            actual: 3,
        };
        assert_eq!(err.to_string(), "Invalid sequence number: expected 5, got 3");

        let err = ProtocolError::Framing("truncated".into());
        assert_eq!(err.to_string(), "Message framing error: truncated");

        let err = ProtocolError::UnsupportedVersion("2.0".into());
        assert_eq!(err.to_string(), "Unsupported protocol version: 2.0");

        assert_eq!(ProtocolError::ChecksumMismatch.to_string(), "Checksum mismatch");
    }

    #[test]
    fn test_transport_error_variants() {
        let err = TransportError::WebSocket("connection reset".into());
        assert_eq!(err.to_string(), "WebSocket error: connection reset");

        let err = TransportError::ConnectionClosed {
            reason: "EOF".into(),
        };
        assert_eq!(err.to_string(), "Connection closed: EOF");

        let err = TransportError::ConnectionFailed("refused".into());
        assert_eq!(err.to_string(), "Connection failed: refused");

        let err = TransportError::Channel("send failed".into());
        assert_eq!(err.to_string(), "Channel error: send failed");

        let err = TransportError::HeartbeatTimeout;
        assert_eq!(err.to_string(), "Heartbeat timeout");
    }

    #[test]
    fn test_error_from_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let err: Error = io_err.into();
        // `#[from]` also exposes the wrapped error through `source()`.
        assert!(std::error::Error::source(&err).is_some());
        assert!(matches!(err, Error::Io(_)));
    }

    #[test]
    fn test_error_from_nested_errors() {
        let err: Error = SessionError::NotFound("sess".into()).into();
        assert!(matches!(err, Error::Session(SessionError::NotFound(_))));

        let err: Error = ProtocolError::ChecksumMismatch.into();
        assert!(matches!(err, Error::Protocol(ProtocolError::ChecksumMismatch)));

        let err: Error = TransportError::HeartbeatTimeout.into();
        assert!(matches!(err, Error::Transport(TransportError::HeartbeatTimeout)));

        let json_err = serde_json::from_str::<serde_json::Value>("{").unwrap_err();
        let err: Error = json_err.into();
        assert!(matches!(err, Error::Serialization(_)));
    }
}