Skip to main content

rtmp_rs/
error.rs

1//! Unified error types for rtmp-rs
2
3use std::fmt;
4use std::io;
5
6/// Result type alias using the library's Error type
7pub type Result<T> = std::result::Result<T, Error>;
8
9/// Unified error type for all RTMP operations
10#[derive(Debug)]
11pub enum Error {
12    /// I/O error during network operations
13    Io(io::Error),
14    /// RTMP protocol violation
15    Protocol(ProtocolError),
16    /// AMF encoding/decoding error
17    Amf(AmfError),
18    /// Handshake failure
19    Handshake(HandshakeError),
20    /// Media parsing error
21    Media(MediaError),
22    /// Connection rejected by peer or handler
23    Rejected(String),
24    /// Operation timed out
25    Timeout,
26    /// Connection was closed
27    ConnectionClosed,
28    /// Invalid configuration
29    Config(String),
30}
31
32impl fmt::Display for Error {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        match self {
35            Error::Io(e) => write!(f, "I/O error: {}", e),
36            Error::Protocol(e) => write!(f, "Protocol error: {}", e),
37            Error::Amf(e) => write!(f, "AMF error: {}", e),
38            Error::Handshake(e) => write!(f, "Handshake error: {}", e),
39            Error::Media(e) => write!(f, "Media error: {}", e),
40            Error::Rejected(msg) => write!(f, "Connection rejected: {}", msg),
41            Error::Timeout => write!(f, "Operation timed out"),
42            Error::ConnectionClosed => write!(f, "Connection closed"),
43            Error::Config(msg) => write!(f, "Configuration error: {}", msg),
44        }
45    }
46}
47
48impl std::error::Error for Error {
49    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
50        match self {
51            Error::Io(e) => Some(e),
52            _ => None,
53        }
54    }
55}
56
57impl From<io::Error> for Error {
58    fn from(err: io::Error) -> Self {
59        Error::Io(err)
60    }
61}
62
63impl From<ProtocolError> for Error {
64    fn from(err: ProtocolError) -> Self {
65        Error::Protocol(err)
66    }
67}
68
69impl From<AmfError> for Error {
70    fn from(err: AmfError) -> Self {
71        Error::Amf(err)
72    }
73}
74
75impl From<HandshakeError> for Error {
76    fn from(err: HandshakeError) -> Self {
77        Error::Handshake(err)
78    }
79}
80
81impl From<MediaError> for Error {
82    fn from(err: MediaError) -> Self {
83        Error::Media(err)
84    }
85}
86
/// Protocol-level errors: RTMP chunks or messages that violate the protocol.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can copy these errors and
/// compare them directly (e.g. with `assert_eq!`). Every payload is an integer
/// or a `String`, so the derives are always available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ProtocolError {
    /// A chunk header was invalid.
    InvalidChunkHeader,
    /// A message carried an unrecognized type ID; the payload is that ID.
    UnknownMessageType(u8),
    /// A message exceeded the permitted size.
    MessageTooLarge {
        /// Size of the offending message, in bytes.
        size: u32,
        /// Largest size permitted, in bytes.
        max: u32,
    },
    /// A chunk stream ID was invalid; the payload is that ID.
    InvalidChunkStreamId(u32),
    /// A message arrived that was not expected; the payload describes it.
    UnexpectedMessage(String),
    /// A required field was absent; the payload names the field.
    MissingField(String),
    /// A command was invalid; the payload identifies the command.
    InvalidCommand(String),
    /// No stream matched the given ID; the payload is that ID.
    StreamNotFound(u32),
}

impl fmt::Display for ProtocolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ProtocolError::InvalidChunkHeader => write!(f, "Invalid chunk header"),
            ProtocolError::UnknownMessageType(t) => write!(f, "Unknown message type: {}", t),
            ProtocolError::MessageTooLarge { size, max } => {
                write!(f, "Message too large: {} bytes (max {})", size, max)
            }
            ProtocolError::InvalidChunkStreamId(id) => {
                write!(f, "Invalid chunk stream ID: {}", id)
            }
            ProtocolError::UnexpectedMessage(msg) => write!(f, "Unexpected message: {}", msg),
            ProtocolError::MissingField(field) => write!(f, "Missing required field: {}", field),
            ProtocolError::InvalidCommand(cmd) => write!(f, "Invalid command: {}", cmd),
            ProtocolError::StreamNotFound(id) => write!(f, "Stream not found: {}", id),
        }
    }
}

impl std::error::Error for ProtocolError {}
118
/// AMF encoding/decoding errors.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can copy these errors and
/// compare them directly (e.g. with `assert_eq!`). All payloads are plain
/// integers, so the derives are always available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AmfError {
    /// An AMF type marker was not recognized; the payload is the marker byte.
    UnknownMarker(u8),
    /// The AMF data ended unexpectedly.
    UnexpectedEof,
    /// An AMF string was not valid UTF-8.
    InvalidUtf8,
    /// An AMF reference was invalid; the payload is the reference index.
    InvalidReference(u16),
    /// AMF values were nested too deeply.
    NestingTooDeep,
    /// An object end marker was invalid.
    InvalidObjectEnd,
}

impl fmt::Display for AmfError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // `{:02x}` keeps single-digit markers zero-padded, e.g. `0x05`.
            AmfError::UnknownMarker(m) => write!(f, "Unknown AMF marker: 0x{:02x}", m),
            AmfError::UnexpectedEof => write!(f, "Unexpected end of AMF data"),
            AmfError::InvalidUtf8 => write!(f, "Invalid UTF-8 in AMF string"),
            AmfError::InvalidReference(idx) => write!(f, "Invalid AMF reference: {}", idx),
            AmfError::NestingTooDeep => write!(f, "AMF nesting too deep"),
            AmfError::InvalidObjectEnd => write!(f, "Invalid object end marker"),
        }
    }
}

impl std::error::Error for AmfError {}
144
/// Handshake-specific errors.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can copy these errors and
/// compare them directly (e.g. with `assert_eq!`). The only payload is a
/// `u8`, so the derives are always available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandshakeError {
    /// The RTMP version was invalid; the payload is the version byte.
    InvalidVersion(u8),
    /// A handshake digest did not match.
    DigestMismatch,
    /// The handshake was in an invalid state.
    InvalidState,
    /// A handshake response did not match.
    ResponseMismatch,
}

impl fmt::Display for HandshakeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            HandshakeError::InvalidVersion(v) => write!(f, "Invalid RTMP version: {}", v),
            HandshakeError::DigestMismatch => write!(f, "Handshake digest mismatch"),
            HandshakeError::InvalidState => write!(f, "Invalid handshake state"),
            HandshakeError::ResponseMismatch => write!(f, "Handshake response mismatch"),
        }
    }
}

impl std::error::Error for HandshakeError {}
166
/// Media parsing errors.
///
/// Derives `Clone`, `PartialEq` and `Eq` so callers can copy these errors and
/// compare them directly (e.g. with `assert_eq!`). The only payload is a
/// `String`, so the derives are always available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MediaError {
    /// An FLV tag was invalid.
    InvalidFlvTag,
    /// An AVC packet was invalid.
    InvalidAvcPacket,
    /// An AAC packet was invalid.
    InvalidAacPacket,
    /// A codec is not supported; the payload names it.
    UnsupportedCodec(String),
    /// A NAL unit was invalid.
    InvalidNalu,
    /// A sequence header was missing.
    MissingSequenceHeader,
    /// Invalid enhanced video packet (E-RTMP)
    InvalidEnhancedVideoPacket,
    /// Invalid enhanced audio packet (E-RTMP)
    InvalidEnhancedAudioPacket,
    /// Unsupported video codec FOURCC
    UnsupportedVideoCodec,
    /// Unsupported audio codec FOURCC
    UnsupportedAudioCodec,
}

impl fmt::Display for MediaError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            MediaError::InvalidFlvTag => write!(f, "Invalid FLV tag"),
            MediaError::InvalidAvcPacket => write!(f, "Invalid AVC packet"),
            MediaError::InvalidAacPacket => write!(f, "Invalid AAC packet"),
            MediaError::UnsupportedCodec(c) => write!(f, "Unsupported codec: {}", c),
            MediaError::InvalidNalu => write!(f, "Invalid NAL unit"),
            MediaError::MissingSequenceHeader => write!(f, "Missing sequence header"),
            MediaError::InvalidEnhancedVideoPacket => write!(f, "Invalid enhanced video packet"),
            MediaError::InvalidEnhancedAudioPacket => write!(f, "Invalid enhanced audio packet"),
            MediaError::UnsupportedVideoCodec => write!(f, "Unsupported video codec FOURCC"),
            MediaError::UnsupportedAudioCodec => write!(f, "Unsupported audio codec FOURCC"),
        }
    }
}

impl std::error::Error for MediaError {}
204
#[cfg(test)]
mod tests {
    use super::*;
    use std::error::Error as StdError;
    use std::io;

    /// Asserts that every value in `cases` renders exactly as its paired string.
    ///
    /// Exact comparison (rather than `contains`) distinguishes variants whose
    /// messages share words, e.g. "Unexpected end of AMF data" vs.
    /// "Invalid object end marker", and pins every payload's formatting.
    fn assert_displays<T: std::fmt::Display + std::fmt::Debug>(cases: &[(T, &str)]) {
        for (value, expected) in cases {
            assert_eq!(value.to_string(), *expected, "Display output for {:?}", value);
        }
    }

    #[test]
    fn test_error_display() {
        assert_displays(&[
            (
                Error::Io(io::Error::new(io::ErrorKind::ConnectionReset, "connection reset")),
                "I/O error: connection reset",
            ),
            (
                Error::Protocol(ProtocolError::InvalidChunkHeader),
                "Protocol error: Invalid chunk header",
            ),
            (
                Error::Amf(AmfError::UnknownMarker(0xFF)),
                "AMF error: Unknown AMF marker: 0xff",
            ),
            (
                Error::Handshake(HandshakeError::InvalidVersion(5)),
                "Handshake error: Invalid RTMP version: 5",
            ),
            (
                Error::Media(MediaError::UnsupportedCodec("VP9".into())),
                "Media error: Unsupported codec: VP9",
            ),
            (
                Error::Rejected("stream key invalid".into()),
                "Connection rejected: stream key invalid",
            ),
            (Error::Timeout, "Operation timed out"),
            (Error::ConnectionClosed, "Connection closed"),
            (Error::Config("invalid port".into()), "Configuration error: invalid port"),
        ]);
    }

    #[test]
    fn test_error_source() {
        // The I/O variant exposes the wrapped io::Error itself as its source.
        let err = Error::Io(io::Error::new(io::ErrorKind::NotFound, "file not found"));
        let kind = StdError::source(&err)
            .and_then(|source| source.downcast_ref::<io::Error>())
            .map(io::Error::kind);
        assert_eq!(kind, Some(io::ErrorKind::NotFound));

        // Every other variant reports no source.
        let others = [
            Error::Protocol(ProtocolError::InvalidChunkHeader),
            Error::Amf(AmfError::UnexpectedEof),
            Error::Handshake(HandshakeError::DigestMismatch),
            Error::Media(MediaError::InvalidFlvTag),
            Error::Rejected("nope".into()),
            Error::Timeout,
            Error::ConnectionClosed,
            Error::Config("bad".into()),
        ];
        for err in &others {
            assert!(StdError::source(err).is_none(), "unexpected source for {:?}", err);
        }
    }

    #[test]
    fn test_from_conversions() {
        // Each conversion must pick the right variant AND keep the payload intact.
        let err: Error = io::Error::new(io::ErrorKind::TimedOut, "timeout").into();
        assert!(matches!(&err, Error::Io(inner) if inner.kind() == io::ErrorKind::TimedOut));

        let err: Error = ProtocolError::MessageTooLarge { size: 100, max: 50 }.into();
        assert!(matches!(
            err,
            Error::Protocol(ProtocolError::MessageTooLarge { size: 100, max: 50 })
        ));

        let err: Error = AmfError::UnexpectedEof.into();
        assert!(matches!(err, Error::Amf(AmfError::UnexpectedEof)));

        let err: Error = HandshakeError::DigestMismatch.into();
        assert!(matches!(err, Error::Handshake(HandshakeError::DigestMismatch)));

        let err: Error = MediaError::InvalidFlvTag.into();
        assert!(matches!(err, Error::Media(MediaError::InvalidFlvTag)));
    }

    #[test]
    fn test_question_mark_converts_into_error() {
        fn inner() -> std::result::Result<(), MediaError> {
            Err(MediaError::InvalidNalu)
        }

        // `?` must route the sub-error through the matching `From` impl.
        fn outer() -> Result<()> {
            inner()?;
            Ok(())
        }

        assert!(matches!(outer(), Err(Error::Media(MediaError::InvalidNalu))));
    }

    #[test]
    fn test_protocol_error_display() {
        assert_displays(&[
            (ProtocolError::InvalidChunkHeader, "Invalid chunk header"),
            (ProtocolError::UnknownMessageType(99), "Unknown message type: 99"),
            (
                ProtocolError::MessageTooLarge {
                    size: 1000,
                    max: 500,
                },
                "Message too large: 1000 bytes (max 500)",
            ),
            (ProtocolError::InvalidChunkStreamId(123), "Invalid chunk stream ID: 123"),
            (ProtocolError::UnexpectedMessage("test".into()), "Unexpected message: test"),
            (ProtocolError::MissingField("app".into()), "Missing required field: app"),
            (ProtocolError::InvalidCommand("bad".into()), "Invalid command: bad"),
            (ProtocolError::StreamNotFound(5), "Stream not found: 5"),
        ]);
    }

    #[test]
    fn test_amf_error_display() {
        assert_displays(&[
            (AmfError::UnknownMarker(0xAB), "Unknown AMF marker: 0xab"),
            // Markers below 0x10 are zero-padded to two hex digits.
            (AmfError::UnknownMarker(0x05), "Unknown AMF marker: 0x05"),
            (AmfError::UnexpectedEof, "Unexpected end of AMF data"),
            (AmfError::InvalidUtf8, "Invalid UTF-8 in AMF string"),
            (AmfError::InvalidReference(42), "Invalid AMF reference: 42"),
            (AmfError::NestingTooDeep, "AMF nesting too deep"),
            (AmfError::InvalidObjectEnd, "Invalid object end marker"),
        ]);
    }

    #[test]
    fn test_handshake_error_display() {
        assert_displays(&[
            (HandshakeError::InvalidVersion(10), "Invalid RTMP version: 10"),
            (HandshakeError::DigestMismatch, "Handshake digest mismatch"),
            (HandshakeError::InvalidState, "Invalid handshake state"),
            (HandshakeError::ResponseMismatch, "Handshake response mismatch"),
        ]);
    }

    #[test]
    fn test_media_error_display() {
        assert_displays(&[
            (MediaError::InvalidFlvTag, "Invalid FLV tag"),
            (MediaError::InvalidAvcPacket, "Invalid AVC packet"),
            (MediaError::InvalidAacPacket, "Invalid AAC packet"),
            (MediaError::UnsupportedCodec("HEVC".into()), "Unsupported codec: HEVC"),
            (MediaError::InvalidNalu, "Invalid NAL unit"),
            (MediaError::MissingSequenceHeader, "Missing sequence header"),
            (MediaError::InvalidEnhancedVideoPacket, "Invalid enhanced video packet"),
            (MediaError::InvalidEnhancedAudioPacket, "Invalid enhanced audio packet"),
            (MediaError::UnsupportedVideoCodec, "Unsupported video codec FOURCC"),
            (MediaError::UnsupportedAudioCodec, "Unsupported audio codec FOURCC"),
        ]);
    }
}