Skip to main content

rstmdb_protocol/
error.rs

1//! Protocol error types and error codes.
2
3use serde::{Deserialize, Serialize};
4use std::fmt;
5use thiserror::Error;
6
7/// Protocol-level errors that can occur during framing or message handling.
8#[derive(Debug, Error)]
9pub enum ProtocolError {
10    #[error("invalid magic bytes: expected 'RCPX', got {0:?}")]
11    InvalidMagic([u8; 4]),
12
13    #[error("unsupported protocol version: {0}")]
14    UnsupportedVersion(u16),
15
16    #[error("frame too large: {size} bytes (max {max})")]
17    FrameTooLarge { size: u32, max: u32 },
18
19    #[error("CRC mismatch: expected {expected:#x}, got {actual:#x}")]
20    CrcMismatch { expected: u32, actual: u32 },
21
22    #[error("invalid frame flags: {0:#x}")]
23    InvalidFlags(u16),
24
25    #[error("incomplete frame: need {needed} more bytes")]
26    IncompleteFrame { needed: usize },
27
28    #[error("JSON error: {0}")]
29    Json(#[from] serde_json::Error),
30
31    #[error("I/O error: {0}")]
32    Io(#[from] std::io::Error),
33
34    #[error("invalid UTF-8 in payload")]
35    InvalidUtf8,
36
37    #[error("missing required field: {0}")]
38    MissingField(&'static str),
39}
40
/// Stable error codes returned in error responses.
///
/// These codes are part of the protocol contract and must remain stable
/// across versions.
///
/// Serde writes and reads each variant as its `SCREAMING_SNAKE_CASE` name
/// (e.g. `MachineNotFound` <-> `"MACHINE_NOT_FOUND"`). The `Display` impl in
/// this module writes the same string. Renaming a variant therefore changes the
/// wire format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ErrorCode {
    // Protocol errors
    /// Wire name `UNSUPPORTED_PROTOCOL`.
    UnsupportedProtocol,
    /// Wire name `BAD_REQUEST`.
    BadRequest,

    // Authentication errors
    /// Wire name `UNAUTHORIZED`.
    Unauthorized,
    /// Wire name `AUTH_FAILED`.
    AuthFailed,

    // Resource errors
    /// Wire name `NOT_FOUND`.
    NotFound,
    /// Wire name `MACHINE_NOT_FOUND`.
    MachineNotFound,
    /// Wire name `MACHINE_VERSION_EXISTS`.
    MachineVersionExists,
    /// Wire name `MACHINE_VERSION_LIMIT_EXCEEDED`.
    MachineVersionLimitExceeded,
    /// Wire name `INSTANCE_NOT_FOUND`.
    InstanceNotFound,
    /// Wire name `INSTANCE_EXISTS`.
    InstanceExists,

    // State machine errors
    /// Wire name `INVALID_TRANSITION`.
    InvalidTransition,
    /// Wire name `GUARD_FAILED`.
    GuardFailed,
    /// Wire name `CONFLICT`.
    Conflict,

    // System errors
    /// Wire name `WAL_IO_ERROR`. Reported as retryable by `ErrorCode::is_retryable`.
    WalIoError,
    /// Wire name `INTERNAL_ERROR`. Reported as retryable by `ErrorCode::is_retryable`.
    InternalError,
    /// Wire name `RATE_LIMITED`. Reported as retryable by `ErrorCode::is_retryable`.
    RateLimited,
}
74
75impl ErrorCode {
76    /// Returns whether this error is potentially retryable.
77    pub fn is_retryable(&self) -> bool {
78        matches!(
79            self,
80            ErrorCode::WalIoError | ErrorCode::RateLimited | ErrorCode::InternalError
81        )
82    }
83}
84
85impl fmt::Display for ErrorCode {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        match self {
88            ErrorCode::UnsupportedProtocol => write!(f, "UNSUPPORTED_PROTOCOL"),
89            ErrorCode::BadRequest => write!(f, "BAD_REQUEST"),
90            ErrorCode::Unauthorized => write!(f, "UNAUTHORIZED"),
91            ErrorCode::AuthFailed => write!(f, "AUTH_FAILED"),
92            ErrorCode::NotFound => write!(f, "NOT_FOUND"),
93            ErrorCode::MachineNotFound => write!(f, "MACHINE_NOT_FOUND"),
94            ErrorCode::MachineVersionExists => write!(f, "MACHINE_VERSION_EXISTS"),
95            ErrorCode::MachineVersionLimitExceeded => write!(f, "MACHINE_VERSION_LIMIT_EXCEEDED"),
96            ErrorCode::InstanceNotFound => write!(f, "INSTANCE_NOT_FOUND"),
97            ErrorCode::InstanceExists => write!(f, "INSTANCE_EXISTS"),
98            ErrorCode::InvalidTransition => write!(f, "INVALID_TRANSITION"),
99            ErrorCode::GuardFailed => write!(f, "GUARD_FAILED"),
100            ErrorCode::Conflict => write!(f, "CONFLICT"),
101            ErrorCode::WalIoError => write!(f, "WAL_IO_ERROR"),
102            ErrorCode::InternalError => write!(f, "INTERNAL_ERROR"),
103            ErrorCode::RateLimited => write!(f, "RATE_LIMITED"),
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ErrorCode` variant, for properties that must hold across all codes.
    ///
    /// Keep this list in sync with the enum when adding a code.
    const ALL_CODES: &[ErrorCode] = &[
        ErrorCode::UnsupportedProtocol,
        ErrorCode::BadRequest,
        ErrorCode::Unauthorized,
        ErrorCode::AuthFailed,
        ErrorCode::NotFound,
        ErrorCode::MachineNotFound,
        ErrorCode::MachineVersionExists,
        ErrorCode::MachineVersionLimitExceeded,
        ErrorCode::InstanceNotFound,
        ErrorCode::InstanceExists,
        ErrorCode::InvalidTransition,
        ErrorCode::GuardFailed,
        ErrorCode::Conflict,
        ErrorCode::WalIoError,
        ErrorCode::InternalError,
        ErrorCode::RateLimited,
    ];

    #[test]
    fn test_error_code_retryable() {
        // Retryable errors
        assert!(ErrorCode::WalIoError.is_retryable());
        assert!(ErrorCode::RateLimited.is_retryable());
        assert!(ErrorCode::InternalError.is_retryable());

        // Non-retryable errors
        assert!(!ErrorCode::BadRequest.is_retryable());
        assert!(!ErrorCode::NotFound.is_retryable());
        assert!(!ErrorCode::InvalidTransition.is_retryable());
        assert!(!ErrorCode::GuardFailed.is_retryable());
        assert!(!ErrorCode::Conflict.is_retryable());
        assert!(!ErrorCode::Unauthorized.is_retryable());
        assert!(!ErrorCode::AuthFailed.is_retryable());

        // Pin the size of the retryable set so no other code becomes
        // retryable by accident.
        let retryable = ALL_CODES.iter().filter(|c| c.is_retryable()).count();
        assert_eq!(retryable, 3);
    }

    #[test]
    fn test_error_code_display() {
        assert_eq!(
            format!("{}", ErrorCode::UnsupportedProtocol),
            "UNSUPPORTED_PROTOCOL"
        );
        assert_eq!(format!("{}", ErrorCode::BadRequest), "BAD_REQUEST");
        assert_eq!(format!("{}", ErrorCode::Unauthorized), "UNAUTHORIZED");
        assert_eq!(format!("{}", ErrorCode::AuthFailed), "AUTH_FAILED");
        assert_eq!(format!("{}", ErrorCode::NotFound), "NOT_FOUND");
        assert_eq!(
            format!("{}", ErrorCode::MachineNotFound),
            "MACHINE_NOT_FOUND"
        );
        assert_eq!(
            format!("{}", ErrorCode::MachineVersionExists),
            "MACHINE_VERSION_EXISTS"
        );
        assert_eq!(
            format!("{}", ErrorCode::MachineVersionLimitExceeded),
            "MACHINE_VERSION_LIMIT_EXCEEDED"
        );
        assert_eq!(
            format!("{}", ErrorCode::InstanceNotFound),
            "INSTANCE_NOT_FOUND"
        );
        assert_eq!(format!("{}", ErrorCode::InstanceExists), "INSTANCE_EXISTS");
        assert_eq!(
            format!("{}", ErrorCode::InvalidTransition),
            "INVALID_TRANSITION"
        );
        assert_eq!(format!("{}", ErrorCode::GuardFailed), "GUARD_FAILED");
        assert_eq!(format!("{}", ErrorCode::Conflict), "CONFLICT");
        assert_eq!(format!("{}", ErrorCode::WalIoError), "WAL_IO_ERROR");
        assert_eq!(format!("{}", ErrorCode::InternalError), "INTERNAL_ERROR");
        assert_eq!(format!("{}", ErrorCode::RateLimited), "RATE_LIMITED");
    }

    #[test]
    fn test_error_code_serialization() {
        let code = ErrorCode::NotFound;
        let json = serde_json::to_string(&code).unwrap();
        assert_eq!(json, "\"NOT_FOUND\"");

        let parsed: ErrorCode = serde_json::from_str("\"CONFLICT\"").unwrap();
        assert_eq!(parsed, ErrorCode::Conflict);
    }

    /// The hand-written `Display` strings and the serde-derived names must never
    /// drift apart, and every code must survive a JSON round trip.
    #[test]
    fn test_error_code_display_matches_serde() {
        for &code in ALL_CODES {
            let json = serde_json::to_string(&code).unwrap();
            assert_eq!(
                json,
                format!("\"{}\"", code),
                "Display/serde mismatch for {:?}",
                code
            );
            let parsed: ErrorCode = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, code);
        }
    }

    #[test]
    fn test_error_code_rejects_unknown_names() {
        assert!(serde_json::from_str::<ErrorCode>("\"NOPE\"").is_err());
        // Wire names are case-sensitive.
        assert!(serde_json::from_str::<ErrorCode>("\"not_found\"").is_err());
    }

    #[test]
    fn test_protocol_error_display() {
        let err = ProtocolError::InvalidMagic(*b"XXXX");
        // Only check that the message identifies the problem and the expected
        // magic; the exact rendering of the offending bytes is not pinned here.
        let msg = err.to_string();
        assert!(msg.contains("magic"));
        assert!(msg.contains("RCPX"));

        let err = ProtocolError::UnsupportedVersion(99);
        assert!(err.to_string().contains("99"));

        let err = ProtocolError::FrameTooLarge { size: 100, max: 50 };
        assert!(err.to_string().contains("100"));

        // CRC uses hex format
        let err = ProtocolError::CrcMismatch {
            expected: 0xABC,
            actual: 0xDEF,
        };
        let msg = err.to_string();
        assert!(msg.contains("abc") || msg.contains("ABC"));

        let err = ProtocolError::IncompleteFrame { needed: 10 };
        assert!(err.to_string().contains("10"));

        let err = ProtocolError::InvalidUtf8;
        assert!(err.to_string().contains("UTF-8"));

        let err = ProtocolError::MissingField("test_field");
        assert!(err.to_string().contains("test_field"));

        let err = ProtocolError::InvalidFlags(0xFF);
        let msg = err.to_string();
        assert!(msg.contains("ff") || msg.contains("FF"));
    }
}