use std::fmt;
pub type Result<T> = std::result::Result<T, Error>;
/// Top-level error type for this crate.
///
/// Each variant's `#[error(...)]` attribute defines its `Display` text. The
/// wrapping variants (`Session`, `Protocol`, `Transport`, `Io`,
/// `Serialization`) prefix the inner error's own message. Their `#[from]`
/// attribute generates a `From` impl, so `?` converts the inner error
/// automatically.
///
/// Use [`Error::is_retriable`] and [`Error::is_fatal`] to classify an error.
///
/// NOTE(review): with `thiserror`, `#[from]` also marks the field as the error
/// `source()`. Since the message also embeds `{0}`, reporters that print the
/// whole source chain will show the inner message twice — confirm whether
/// that is intended.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An AWS SDK call failed. Holds the `Debug` rendering of the SDK error, as
/// produced by the `From<SdkError<E, R>>` impl in this module.
#[error("AWS SDK error: {0}")]
AwsSdk(String),
/// A session-level failure; see [`SessionError`].
#[error("Session error: {0}")]
Session(#[from] SessionError),
/// A message-protocol failure; see [`ProtocolError`].
#[error("Protocol error: {0}")]
Protocol(#[from] ProtocolError),
/// A connection/transport failure; see [`TransportError`].
#[error("Transport error: {0}")]
Transport(#[from] TransportError),
/// A configuration problem, described by the message.
#[error("Configuration error: {0}")]
Config(String),
/// An underlying `std::io` error.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// A JSON (de)serialization error from `serde_json`.
#[error("Serialization error: {0}")]
Serialization(#[from] serde_json::Error),
/// A generic invalid-state condition, described by the message. This is
/// distinct from the structured [`SessionError::InvalidState`].
#[error("Invalid state: {0}")]
InvalidState(String),
/// A generic operation timeout. Heartbeat timeouts are reported separately
/// as [`TransportError::HeartbeatTimeout`].
#[error("Operation timed out")]
Timeout,
/// The operation was cancelled.
#[error("Operation was cancelled")]
Cancelled,
}
/// Errors concerning a session's existence, lifecycle and state.
///
/// Wrapped by [`Error::Session`], whose `Display` prefixes these messages with
/// `"Session error: "`.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
/// No session matched the given identifier (presumably a session id —
/// confirm at call sites).
#[error("Session not found: {0}")]
NotFound(String),
/// A session with the given identifier already exists.
#[error("Session already exists: {0}")]
AlreadyExists(String),
/// The session was terminated. [`Error::is_fatal`] treats this as fatal.
#[error("Session terminated: {reason}")]
Terminated {
/// Why the session was terminated; included in the `Display` message.
reason: String,
},
/// The session was in state `actual` where state `expected` was required.
/// [`Error::is_fatal`] treats this as fatal.
#[error("Invalid session state: expected {expected}, got {actual}")]
InvalidState {
/// Name of the state the operation required.
expected: String,
/// Name of the state the session was actually in.
actual: String,
},
/// Setting up the session failed; the message describes the cause.
#[error("Session initialization failed: {0}")]
InitializationFailed(String),
}
/// Message-protocol errors covering format, type, ordering, version and
/// integrity problems.
///
/// Wrapped by [`Error::Protocol`], whose `Display` prefixes these messages
/// with `"Protocol error: "`.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
/// A message did not have the expected format; the text gives details.
#[error("Invalid message format: {0}")]
InvalidMessage(String),
/// A message carried a type this code does not recognise.
#[error("Unknown message type: {0}")]
UnknownMessageType(String),
/// The sequence number `actual` did not match the `expected` one.
#[error("Invalid sequence number: expected {expected}, got {actual}")]
InvalidSequence {
/// Sequence number that was required.
expected: u64,
/// Sequence number that was observed.
actual: u64,
},
/// A message could not be framed correctly; the text gives details.
#[error("Message framing error: {0}")]
Framing(String),
/// The peer uses a protocol version that is not supported.
#[error("Unsupported protocol version: {0}")]
UnsupportedVersion(String),
/// A checksum did not match its expected value.
#[error("Checksum mismatch")]
ChecksumMismatch,
}
/// Errors from the underlying connection / transport layer.
///
/// Wrapped by [`Error::Transport`], whose `Display` prefixes these messages
/// with `"Transport error: "`.
#[derive(Debug, thiserror::Error)]
pub enum TransportError {
/// An error reported by the WebSocket layer, carried as text.
/// [`Error::is_retriable`] treats this as retriable.
#[error("WebSocket error: {0}")]
WebSocket(String),
/// The connection was closed. [`Error::is_fatal`] treats this as fatal.
#[error("Connection closed: {reason}")]
ConnectionClosed {
/// Why the connection closed; included in the `Display` message.
reason: String,
},
/// Establishing a connection failed. [`Error::is_retriable`] treats this
/// as retriable.
#[error("Connection failed: {0}")]
ConnectionFailed(String),
/// A channel error described by the message (presumably an internal
/// message-channel send/receive failure — confirm at call sites).
#[error("Channel error: {0}")]
Channel(String),
/// A heartbeat timed out; the timeout policy is not defined in this
/// module. [`Error::is_retriable`] treats this as retriable.
#[error("Heartbeat timeout")]
HeartbeatTimeout,
}
impl Error {
pub fn is_retriable(&self) -> bool {
match self {
Error::Timeout => true,
Error::Transport(TransportError::HeartbeatTimeout) => true,
Error::Transport(TransportError::ConnectionFailed(_)) => true,
Error::Transport(TransportError::WebSocket(_)) => true,
Error::AwsSdk(_) => true, _ => false,
}
}
pub fn is_fatal(&self) -> bool {
matches!(
self,
Error::Session(SessionError::Terminated { .. })
| Error::Transport(TransportError::ConnectionClosed { .. })
| Error::Session(SessionError::InvalidState { .. })
)
}
}
impl<E, R> From<aws_smithy_runtime_api::client::result::SdkError<E, R>> for Error
where
E: fmt::Debug,
R: fmt::Debug,
{
fn from(err: aws_smithy_runtime_api::client::result::SdkError<E, R>) -> Self {
Error::AwsSdk(format!("{:?}", err))
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the error types:
    // - retry/fatal classification;
    // - exact `Display` output (so message regressions are caught);
    // - the `From` conversions into `Error`.
    use super::*;

    /// Builds an `io::Error` with the given kind and message.
    fn io_error(kind: std::io::ErrorKind, msg: &str) -> std::io::Error {
        std::io::Error::new(kind, msg)
    }

    #[test]
    fn test_error_is_retriable() {
        // Transient failures that callers may retry.
        assert!(Error::Timeout.is_retriable());
        assert!(Error::Transport(TransportError::HeartbeatTimeout).is_retriable());
        assert!(Error::Transport(TransportError::ConnectionFailed("test".into())).is_retriable());
        assert!(Error::Transport(TransportError::WebSocket("test".into())).is_retriable());
        assert!(Error::AwsSdk("transient".into()).is_retriable());

        // Everything else is not retriable.
        assert!(!Error::Cancelled.is_retriable());
        assert!(!Error::Config("bad config".into()).is_retriable());
        assert!(!Error::InvalidState("invalid".into()).is_retriable());
        assert!(!Error::Session(SessionError::NotFound("sess".into())).is_retriable());
        assert!(!Error::Protocol(ProtocolError::ChecksumMismatch).is_retriable());
        assert!(!Error::Io(io_error(std::io::ErrorKind::Other, "disk")).is_retriable());
        assert!(!Error::Transport(TransportError::Channel("send failed".into())).is_retriable());
        // A closed connection is reported as fatal rather than retriable.
        assert!(!Error::Transport(TransportError::ConnectionClosed {
            reason: "closed".into(),
        })
        .is_retriable());
    }

    #[test]
    fn test_error_is_fatal() {
        let terminated = Error::Session(SessionError::Terminated {
            reason: "test".to_string(),
        });
        assert!(terminated.is_fatal());
        let conn_closed = Error::Transport(TransportError::ConnectionClosed {
            reason: "closed".to_string(),
        });
        assert!(conn_closed.is_fatal());
        let invalid_state = Error::Session(SessionError::InvalidState {
            expected: "Running".into(),
            actual: "Terminated".into(),
        });
        assert!(invalid_state.is_fatal());

        assert!(!Error::Timeout.is_fatal());
        assert!(!Error::Cancelled.is_fatal());
        assert!(!Error::AwsSdk("error".into()).is_fatal());
        assert!(!Error::Session(SessionError::NotFound("sess".into())).is_fatal());
        assert!(!Error::Transport(TransportError::HeartbeatTimeout).is_fatal());
        assert!(!Error::Transport(TransportError::ConnectionFailed("refused".into())).is_fatal());
        // The top-level `InvalidState` is distinct from the fatal session variant.
        assert!(!Error::InvalidState("invalid".into()).is_fatal());
    }

    #[test]
    fn test_error_display() {
        assert_eq!(Error::Timeout.to_string(), "Operation timed out");
        assert_eq!(Error::Cancelled.to_string(), "Operation was cancelled");
        // Wrapping variants prefix the inner error's own message.
        assert_eq!(
            Error::Session(SessionError::NotFound("sess-123".into())).to_string(),
            "Session error: Session not found: sess-123"
        );
        assert_eq!(
            Error::Protocol(ProtocolError::ChecksumMismatch).to_string(),
            "Protocol error: Checksum mismatch"
        );
        assert_eq!(
            Error::Transport(TransportError::HeartbeatTimeout).to_string(),
            "Transport error: Heartbeat timeout"
        );
        assert_eq!(
            Error::Config("bad config".into()).to_string(),
            "Configuration error: bad config"
        );
        assert_eq!(
            Error::InvalidState("invalid".into()).to_string(),
            "Invalid state: invalid"
        );
        assert_eq!(
            Error::AwsSdk("throttled".into()).to_string(),
            "AWS SDK error: throttled"
        );
    }

    #[test]
    fn test_session_error_variants() {
        assert_eq!(
            SessionError::NotFound("sess-1".into()).to_string(),
            "Session not found: sess-1"
        );
        assert_eq!(
            SessionError::AlreadyExists("sess-2".into()).to_string(),
            "Session already exists: sess-2"
        );
        assert_eq!(
            SessionError::InitializationFailed("handshake failed".into()).to_string(),
            "Session initialization failed: handshake failed"
        );
        assert_eq!(
            SessionError::Terminated {
                reason: "idle".into(),
            }
            .to_string(),
            "Session terminated: idle"
        );
        assert_eq!(
            SessionError::InvalidState {
                expected: "Running".into(),
                actual: "Terminated".into(),
            }
            .to_string(),
            "Invalid session state: expected Running, got Terminated"
        );
    }

    #[test]
    fn test_protocol_error_variants() {
        assert_eq!(
            ProtocolError::InvalidMessage("bad header".into()).to_string(),
            "Invalid message format: bad header"
        );
        assert_eq!(
            ProtocolError::UnknownMessageType("xyz".into()).to_string(),
            "Unknown message type: xyz"
        );
        // Exact match so a swapped expected/actual would be caught.
        assert_eq!(
            ProtocolError::InvalidSequence {
                expected: 5,
                actual: 3,
            }
            .to_string(),
            "Invalid sequence number: expected 5, got 3"
        );
        assert_eq!(
            ProtocolError::Framing("truncated".into()).to_string(),
            "Message framing error: truncated"
        );
        assert_eq!(
            ProtocolError::UnsupportedVersion("2.0".into()).to_string(),
            "Unsupported protocol version: 2.0"
        );
        assert_eq!(ProtocolError::ChecksumMismatch.to_string(), "Checksum mismatch");
    }

    #[test]
    fn test_transport_error_variants() {
        assert_eq!(
            TransportError::WebSocket("connection reset".into()).to_string(),
            "WebSocket error: connection reset"
        );
        assert_eq!(
            TransportError::ConnectionClosed {
                reason: "EOF".into(),
            }
            .to_string(),
            "Connection closed: EOF"
        );
        assert_eq!(
            TransportError::ConnectionFailed("refused".into()).to_string(),
            "Connection failed: refused"
        );
        assert_eq!(
            TransportError::Channel("send failed".into()).to_string(),
            "Channel error: send failed"
        );
        assert_eq!(TransportError::HeartbeatTimeout.to_string(), "Heartbeat timeout");
    }

    #[test]
    fn test_error_from_io() {
        let io_err = io_error(std::io::ErrorKind::NotFound, "file not found");
        let err: Error = io_err.into();
        assert!(matches!(&err, Error::Io(e) if e.kind() == std::io::ErrorKind::NotFound));
        assert_eq!(err.to_string(), "IO error: file not found");
        // `#[from]` also exposes the wrapped error as the `source()`.
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn test_error_from_serde_json() {
        let json_err = serde_json::from_str::<serde_json::Value>("{not json").unwrap_err();
        let err: Error = json_err.into();
        assert!(matches!(err, Error::Serialization(_)));
        assert!(err.to_string().starts_with("Serialization error: "));
    }

    #[test]
    fn test_error_from_nested_kinds() {
        let err: Error = SessionError::NotFound("s".into()).into();
        assert!(matches!(err, Error::Session(SessionError::NotFound(_))));

        let err: Error = ProtocolError::ChecksumMismatch.into();
        assert!(matches!(err, Error::Protocol(ProtocolError::ChecksumMismatch)));

        let err: Error = TransportError::HeartbeatTimeout.into();
        assert!(matches!(err, Error::Transport(TransportError::HeartbeatTimeout)));
        assert!(err.is_retriable());
    }
}