1use std::fmt;
17
18use crate::http::HttpError;
19
/// Errors returned by the REST client.
///
/// Lower-level failures are wrapped (`Io`, `Http`, and `Tls` when the `tls`
/// feature is enabled) and are exposed through `std::error::Error::source`;
/// the remaining variants carry no underlying error value.
#[derive(Debug)]
pub enum RestError {
    /// An underlying I/O error.
    ///
    /// With the `tls` feature enabled, `TlsError::Io` is also converted into
    /// this variant (not into `Tls`) by the `From<TlsError>` conversion.
    Io(std::io::Error),
    /// An error reported by the `crate::http` layer (wraps [`HttpError`]).
    Http(HttpError),
    /// The response body exceeded the maximum allowed size.
    BodyTooLarge {
        /// Response body size, in bytes.
        size: usize,
        /// Maximum allowed body size, in bytes.
        max: usize,
    },
    /// The request does not fit in the write buffer.
    RequestTooLarge {
        /// Write buffer capacity, in bytes.
        capacity: usize,
    },
    /// A header or query parameter contained a CR or LF character.
    ///
    /// As the variant name indicates, CR/LF in these fields could otherwise
    /// inject extra lines into the outgoing request.
    CrlfInjection,
    /// The connection is poisoned following an earlier I/O error.
    ConnectionPoisoned,
    /// Reading timed out while waiting for the response.
    ReadTimeout,
    /// The connection was found to be stale (its socket is dead).
    ConnectionStale,
    /// The connection was closed; the payload is a static description of
    /// when it happened (for example `"during body read"`).
    ConnectionClosed(&'static str),
    /// A URL was rejected; the payload is shown verbatim in the message
    /// (for example the offending URL itself, such as `"ftp://bad"`).
    InvalidUrl(String),
    /// An `https://` URL was used but the crate was built without the `tls`
    /// feature.
    TlsNotEnabled,
    /// A TLS error other than I/O (only present with the `tls` feature).
    ///
    /// `TlsError::Io` is flattened into `Io` by the `From` conversion, so this
    /// variant never holds a TLS-wrapped I/O error produced by that path.
    #[cfg(feature = "tls")]
    Tls(nexus_net::tls::TlsError),
}
72
73impl fmt::Display for RestError {
74 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 match self {
76 Self::Io(e) => write!(f, "I/O error: {e}"),
77 Self::Http(e) => write!(f, "HTTP error: {e}"),
78 Self::BodyTooLarge { size, max } => {
79 write!(f, "response body too large: {size} bytes (max: {max})")
80 }
81 Self::RequestTooLarge { capacity } => {
82 write!(
83 f,
84 "request exceeds write buffer capacity ({capacity} bytes)"
85 )
86 }
87 Self::CrlfInjection => {
88 write!(f, "header or query parameter contains CR/LF")
89 }
90 Self::ConnectionPoisoned => write!(f, "connection poisoned after I/O error"),
91 Self::ReadTimeout => write!(f, "read timed out waiting for response"),
92 Self::ConnectionStale => write!(f, "connection stale (dead socket)"),
93 Self::TlsNotEnabled => write!(f, "https:// requires the `tls` feature"),
94 Self::ConnectionClosed(ctx) => write!(f, "connection closed: {ctx}"),
95 Self::InvalidUrl(u) => write!(f, "invalid URL: {u}"),
96 #[cfg(feature = "tls")]
97 Self::Tls(e) => write!(f, "TLS error: {e}"),
98 }
99 }
100}
101
102impl std::error::Error for RestError {
103 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
104 match self {
105 Self::Io(e) => Some(e),
106 Self::Http(e) => Some(e),
107 #[cfg(feature = "tls")]
108 Self::Tls(e) => Some(e),
109 _ => None,
110 }
111 }
112}
113
114impl From<std::io::Error> for RestError {
115 fn from(e: std::io::Error) -> Self {
116 Self::Io(e)
117 }
118}
119
120impl From<HttpError> for RestError {
121 fn from(e: HttpError) -> Self {
122 Self::Http(e)
123 }
124}
125
#[cfg(feature = "tls")]
impl From<nexus_net::tls::TlsError> for RestError {
    /// Converts a TLS error into a [`RestError`].
    ///
    /// `TlsError::Io` is flattened into [`RestError::Io`], so I/O failures
    /// surface the same way whether or not TLS is in use; any other TLS error
    /// is kept intact in `RestError::Tls`.
    fn from(err: nexus_net::tls::TlsError) -> Self {
        use nexus_net::tls::TlsError;

        // Unwrap transport errors so they are reported as plain I/O.
        if let TlsError::Io(io_err) = err {
            return RestError::Io(io_err);
        }
        // Everything else is TLS-specific and is preserved as-is.
        RestError::Tls(err)
    }
}
135
#[cfg(test)]
mod tests {
    //! Unit tests for `RestError`: variant construction, `Display` text,
    //! `Error::source` chaining, and the `From` conversions.

    use super::*;
    use std::error::Error;

    #[test]
    fn rest_error_io() {
        let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, "timeout");
        let err = RestError::from(io_err);
        assert!(matches!(err, RestError::Io(_)));
        assert_eq!(err.to_string(), "I/O error: timeout");
        // The original io::Error (including its kind) must be reachable
        // through the standard error chain.
        let source = err.source().expect("Io variant exposes a source");
        let inner = source
            .downcast_ref::<std::io::Error>()
            .expect("source is the wrapped io::Error");
        assert_eq!(inner.kind(), std::io::ErrorKind::TimedOut);
    }

    #[test]
    fn rest_error_http() {
        let http_err = HttpError::TooManyHeaders;
        let err = RestError::from(http_err);
        assert!(matches!(err, RestError::Http(_)));
        let msg = err.to_string();
        assert!(msg.starts_with("HTTP error: "));
        assert!(msg.contains("too many"));
        // The source must be the wrapped HttpError itself.
        let source = err.source().expect("Http variant exposes a source");
        assert!(source.is::<HttpError>());
    }

    #[test]
    fn rest_error_body_too_large() {
        let err = RestError::BodyTooLarge {
            size: 10_000,
            max: 4096,
        };
        assert!(matches!(
            err,
            RestError::BodyTooLarge {
                size: 10_000,
                max: 4096,
            }
        ));
        assert_eq!(
            err.to_string(),
            "response body too large: 10000 bytes (max: 4096)"
        );
    }

    #[test]
    fn rest_error_request_too_large() {
        let err = RestError::RequestTooLarge { capacity: 32768 };
        assert!(matches!(
            err,
            RestError::RequestTooLarge { capacity: 32768 }
        ));
        assert!(
            err.to_string()
                .contains("exceeds write buffer capacity (32768 bytes)")
        );
    }

    #[test]
    fn rest_error_crlf_injection() {
        let err = RestError::CrlfInjection;
        assert!(matches!(err, RestError::CrlfInjection));
        assert_eq!(err.to_string(), "header or query parameter contains CR/LF");
    }

    #[test]
    fn rest_error_connection_poisoned() {
        let err = RestError::ConnectionPoisoned;
        assert!(matches!(err, RestError::ConnectionPoisoned));
        assert_eq!(err.to_string(), "connection poisoned after I/O error");
    }

    #[test]
    fn rest_error_read_timeout() {
        let err = RestError::ReadTimeout;
        assert!(matches!(err, RestError::ReadTimeout));
        assert_eq!(err.to_string(), "read timed out waiting for response");
    }

    #[test]
    fn rest_error_connection_stale() {
        let err = RestError::ConnectionStale;
        assert!(matches!(err, RestError::ConnectionStale));
        assert_eq!(err.to_string(), "connection stale (dead socket)");
    }

    #[test]
    fn rest_error_connection_closed() {
        let err = RestError::ConnectionClosed("during body read");
        assert!(matches!(
            err,
            RestError::ConnectionClosed("during body read")
        ));
        assert_eq!(err.to_string(), "connection closed: during body read");
    }

    #[test]
    fn rest_error_invalid_url() {
        let err = RestError::InvalidUrl("ftp://bad".into());
        assert!(matches!(err, RestError::InvalidUrl(_)));
        assert_eq!(err.to_string(), "invalid URL: ftp://bad");
    }

    #[test]
    fn rest_error_tls_not_enabled() {
        let err = RestError::TlsNotEnabled;
        assert!(matches!(err, RestError::TlsNotEnabled));
        assert_eq!(err.to_string(), "https:// requires the `tls` feature");
    }

    #[test]
    fn rest_error_source_none_for_leaf_variants() {
        assert!(RestError::CrlfInjection.source().is_none());
        assert!(RestError::ConnectionPoisoned.source().is_none());
        assert!(RestError::ReadTimeout.source().is_none());
        assert!(RestError::ConnectionStale.source().is_none());
        assert!(RestError::TlsNotEnabled.source().is_none());
        assert!(RestError::InvalidUrl("x".into()).source().is_none());
        assert!(RestError::ConnectionClosed("x").source().is_none());
        assert!(
            RestError::BodyTooLarge { size: 1, max: 1 }
                .source()
                .is_none()
        );
        assert!(
            RestError::RequestTooLarge { capacity: 1 }
                .source()
                .is_none()
        );
    }

    #[cfg(feature = "tls")]
    #[test]
    fn rest_error_from_tls_io_flattens() {
        let io_err = std::io::Error::new(std::io::ErrorKind::BrokenPipe, "broken");
        let tls_err = nexus_net::tls::TlsError::Io(io_err);
        let rest_err = RestError::from(tls_err);
        // The io::Error must be unwrapped into RestError::Io with its kind intact.
        match rest_err {
            RestError::Io(e) => assert_eq!(e.kind(), std::io::ErrorKind::BrokenPipe),
            other => panic!("expected RestError::Io, got {other:?}"),
        }
    }

    #[cfg(feature = "tls")]
    #[test]
    fn rest_error_from_tls_non_io_preserves() {
        let tls_err = nexus_net::tls::TlsError::NoRootCerts;
        let rest_err = RestError::from(tls_err);
        assert!(matches!(rest_err, RestError::Tls(_)));
        assert!(rest_err.to_string().starts_with("TLS error: "));
        assert!(rest_err.source().is_some());
    }
}