1use serde::{Deserialize, Serialize};
4use bytes::Bytes;
5use crate::ProtocolError;
6
7#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
9pub struct FrameFlags(pub u8);
10
11impl FrameFlags {
12 pub const NONE: Self = Self(0);
14 pub const END_STREAM: Self = Self(1);
16 pub const ERROR: Self = Self(2);
18 pub const FLOW_CONTROL: Self = Self(4);
20
21 pub fn has_flag(self, flag: FrameFlags) -> bool {
23 (self.0 & flag.0) != 0
24 }
25
26 pub fn set_flag(&mut self, flag: FrameFlags) {
28 self.0 |= flag.0;
29 }
30
31 pub fn clear_flag(&mut self, flag: FrameFlags) {
33 self.0 &= !flag.0;
34 }
35}
36
37#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct Frame {
40 pub stream_id: u32,
42 pub sequence: u32,
44 pub flags: FrameFlags,
46 pub payload: Bytes,
48}
49
50impl Frame {
51 pub fn new(stream_id: u32, sequence: u32, flags: FrameFlags, payload: Bytes) -> Self {
53 Self {
54 stream_id,
55 sequence,
56 flags,
57 payload,
58 }
59 }
60
61 pub fn data(stream_id: u32, sequence: u32, payload: Bytes) -> Self {
63 Self::new(stream_id, sequence, FrameFlags::NONE, payload)
64 }
65
66 pub fn end_stream(stream_id: u32, sequence: u32) -> Self {
68 Self::new(stream_id, sequence, FrameFlags::END_STREAM, Bytes::new())
69 }
70
71 pub fn error(stream_id: u32, sequence: u32, payload: Bytes) -> Self {
73 Self::new(stream_id, sequence, FrameFlags::ERROR, payload)
74 }
75
76 pub fn to_msgpack(&self) -> Result<Vec<u8>, ProtocolError> {
78 rmp_serde::to_vec(self)
79 .map_err(|e| ProtocolError::Serialization(e.to_string()))
80 }
81
82 pub fn from_msgpack(bytes: &[u8]) -> Result<Self, ProtocolError> {
84 rmp_serde::from_slice(bytes)
85 .map_err(|e| ProtocolError::Serialization(e.to_string()))
86 }
87
88 pub fn payload_size(&self) -> usize {
90 self.payload.len()
91 }
92
93 pub fn is_end_stream(&self) -> bool {
95 self.flags.has_flag(FrameFlags::END_STREAM)
96 }
97
98 pub fn is_error(&self) -> bool {
100 self.flags.has_flag(FrameFlags::ERROR)
101 }
102}
103
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Encodes `frame` to MessagePack and decodes it again, panicking if either step fails.
    fn roundtrip(frame: &Frame) -> Frame {
        let encoded = frame.to_msgpack().expect("frame should serialize");
        Frame::from_msgpack(&encoded).expect("serialized frame should deserialize")
    }

    /// Asserts that two frames agree field by field.
    fn assert_frames_match(expected: &Frame, actual: &Frame) {
        assert_eq!(expected.stream_id, actual.stream_id);
        assert_eq!(expected.sequence, actual.sequence);
        assert_eq!(expected.flags, actual.flags);
        assert_eq!(expected.payload, actual.payload);
    }

    #[test]
    fn test_frame_flags() {
        let mut flags = FrameFlags::NONE;
        assert!(!flags.has_flag(FrameFlags::END_STREAM));

        flags.set_flag(FrameFlags::END_STREAM);
        assert!(flags.has_flag(FrameFlags::END_STREAM));

        flags.clear_flag(FrameFlags::END_STREAM);
        assert!(!flags.has_flag(FrameFlags::END_STREAM));
    }

    #[test]
    fn test_frame_flags_combined() {
        let mut flags = FrameFlags::NONE;
        flags.set_flag(FrameFlags::END_STREAM);
        flags.set_flag(FrameFlags::FLOW_CONTROL);
        assert_eq!(flags, FrameFlags(0b101));
        assert!(flags.has_flag(FrameFlags::END_STREAM));
        assert!(flags.has_flag(FrameFlags::FLOW_CONTROL));
        assert!(!flags.has_flag(FrameFlags::ERROR));

        // Clearing one flag must leave the others untouched.
        flags.clear_flag(FrameFlags::END_STREAM);
        assert_eq!(flags, FrameFlags::FLOW_CONTROL);

        // Clearing a flag that is not set is a no-op.
        flags.clear_flag(FrameFlags::ERROR);
        assert_eq!(flags, FrameFlags::FLOW_CONTROL);

        // Setting an already-set flag is idempotent.
        flags.set_flag(FrameFlags::FLOW_CONTROL);
        assert_eq!(flags, FrameFlags::FLOW_CONTROL);
    }

    #[test]
    fn test_flag_constants_are_distinct_single_bits() {
        assert_eq!(FrameFlags::NONE.0, 0);
        let named = [
            FrameFlags::END_STREAM,
            FrameFlags::ERROR,
            FrameFlags::FLOW_CONTROL,
        ];
        let mut seen = 0u8;
        for flag in named.iter() {
            assert_eq!(flag.0.count_ones(), 1, "{:?} should be a single bit", flag);
            assert_eq!(seen & flag.0, 0, "{:?} overlaps another flag", flag);
            seen |= flag.0;
        }
    }

    #[test]
    fn test_frame_creation() {
        let payload = Bytes::from("test payload");
        let frame = Frame::data(1, 42, payload.clone());

        assert_eq!(frame.stream_id, 1);
        assert_eq!(frame.sequence, 42);
        assert_eq!(frame.flags, FrameFlags::NONE);
        assert_eq!(frame.payload, payload);
        assert_eq!(frame.payload_size(), payload.len());
        assert!(!frame.is_end_stream());
        assert!(!frame.is_error());
    }

    #[test]
    fn test_end_stream_frame() {
        let frame = Frame::end_stream(1, 42);
        assert_eq!(frame.flags, FrameFlags::END_STREAM);
        assert!(frame.is_end_stream());
        assert!(!frame.is_error());
        assert_eq!(frame.payload.len(), 0);
        assert_eq!(frame.payload_size(), 0);
    }

    #[test]
    fn test_error_frame() {
        let payload = Bytes::from("error message");
        let frame = Frame::error(1, 42, payload.clone());
        assert_eq!(frame.flags, FrameFlags::ERROR);
        assert!(!frame.is_end_stream());
        assert!(frame.is_error());
        assert_eq!(frame.payload, payload);
    }

    #[test]
    fn test_msgpack_serialization_roundtrip() {
        let original = Frame::data(123, 456, Bytes::from("test payload data"));
        let decoded = roundtrip(&original);
        assert_frames_match(&original, &decoded);
    }

    #[test]
    fn test_empty_payload_serialization() {
        let frame = Frame::end_stream(1, 1);
        let decoded = roundtrip(&frame);
        assert_frames_match(&frame, &decoded);
        assert!(decoded.is_end_stream());
    }

    #[test]
    fn test_from_msgpack_rejects_malformed_input() {
        // No bytes at all cannot encode a frame.
        assert!(Frame::from_msgpack(&[]).is_err());
        // 0xc1 is the one MessagePack byte reserved as "never used".
        assert!(Frame::from_msgpack(&[0xc1]).is_err());
        // A valid encoding cut short inside the payload must not decode.
        let encoded = Frame::data(7, 8, Bytes::from("truncate me"))
            .to_msgpack()
            .unwrap();
        assert!(Frame::from_msgpack(&encoded[..encoded.len() - 1]).is_err());
    }

    proptest! {
        #[test]
        fn test_frame_roundtrip_properties(
            stream_id in any::<u32>(),
            sequence in any::<u32>(),
            flags in any::<u8>(),
            payload in prop::collection::vec(any::<u8>(), 0..1024)
        ) {
            let frame = Frame::new(
                stream_id,
                sequence,
                FrameFlags(flags),
                Bytes::from(payload)
            );

            let decoded = roundtrip(&frame);

            prop_assert_eq!(frame.stream_id, decoded.stream_id);
            prop_assert_eq!(frame.sequence, decoded.sequence);
            prop_assert_eq!(frame.flags, decoded.flags);
            prop_assert_eq!(frame.payload, decoded.payload);
        }

        #[test]
        fn test_set_and_clear_flag_properties(bits in any::<u8>(), mask in any::<u8>()) {
            let mut flags = FrameFlags(bits);

            flags.set_flag(FrameFlags(mask));
            prop_assert_eq!(flags.0, bits | mask);
            prop_assert_eq!(flags.has_flag(FrameFlags(mask)), mask != 0);

            flags.clear_flag(FrameFlags(mask));
            prop_assert_eq!(flags.0, bits & !mask);
            prop_assert!(!flags.has_flag(FrameFlags(mask)));
        }
    }
}