quic_reverse_control/
codec.rs1use crate::ProtocolMessage;
21use serde::{de::DeserializeOwned, Serialize};
22use thiserror::Error;
23
24#[derive(Debug, Error)]
26pub enum CodecError {
27 #[error("serialization failed: {0}")]
29 Serialize(String),
30
31 #[error("deserialization failed: {0}")]
33 Deserialize(String),
34}
35
36pub trait Codec: Send + Sync + 'static {
41 fn encode<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CodecError>;
47
48 fn decode<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CodecError>;
54
55 fn encode_message(&self, message: &ProtocolMessage) -> Result<Vec<u8>, CodecError> {
63 self.encode(message)
64 }
65
66 fn decode_message(&self, data: &[u8]) -> Result<ProtocolMessage, CodecError> {
74 self.decode(data)
75 }
76}
77
/// [`Codec`] implementation backed by the `bincode` crate.
///
/// The type is a stateless unit struct, so it is `Copy` and can be obtained
/// from either [`BincodeCodec::new`] or [`Default::default`].
#[derive(Clone, Copy, Debug, Default)]
pub struct BincodeCodec;

impl BincodeCodec {
    /// Creates a `BincodeCodec`.
    ///
    /// Being a `const fn`, this can also initialise `const`/`static` items.
    #[must_use]
    pub const fn new() -> Self {
        BincodeCodec
    }
}
92
93impl Codec for BincodeCodec {
94 fn encode<T: Serialize>(&self, value: &T) -> Result<Vec<u8>, CodecError> {
95 bincode::serialize(value).map_err(|e| CodecError::Serialize(e.to_string()))
96 }
97
98 fn decode<T: DeserializeOwned>(&self, data: &[u8]) -> Result<T, CodecError> {
99 bincode::deserialize(data).map_err(|e| CodecError::Deserialize(e.to_string()))
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Features, Hello, OpenRequest, OpenResponse, OpenStatus, RejectCode, ServiceId};

    /// Encodes `message` with a fresh [`BincodeCodec`] and decodes it back,
    /// panicking if either direction fails.
    fn round_trip(message: &ProtocolMessage) -> ProtocolMessage {
        let codec = BincodeCodec::new();
        let encoded = codec
            .encode_message(message)
            .expect("encode should succeed");
        codec
            .decode_message(&encoded)
            .expect("decode should succeed")
    }

    #[test]
    fn bincode_hello_round_trip() {
        let hello =
            Hello::new(Features::PING_PONG | Features::STRUCTURED_METADATA).with_agent("test/1.0");

        match round_trip(&ProtocolMessage::Hello(hello.clone())) {
            ProtocolMessage::Hello(h) => {
                assert_eq!(h.protocol_version, hello.protocol_version);
                assert_eq!(h.features, hello.features);
                assert_eq!(h.agent, hello.agent);
            }
            _ => panic!("expected Hello message"),
        }
    }

    #[test]
    fn bincode_open_request_round_trip() {
        let request = OpenRequest::new(42, ServiceId::new("ssh"));

        match round_trip(&ProtocolMessage::OpenRequest(request.clone())) {
            ProtocolMessage::OpenRequest(r) => {
                assert_eq!(r.request_id, request.request_id);
                assert_eq!(r.service, request.service);
            }
            _ => panic!("expected OpenRequest message"),
        }
    }

    #[test]
    fn bincode_open_response_accepted_round_trip() {
        let response = OpenResponse::accepted(42, 100);

        match round_trip(&ProtocolMessage::OpenResponse(response.clone())) {
            ProtocolMessage::OpenResponse(r) => {
                assert_eq!(r.request_id, response.request_id);
                assert_eq!(r.status, OpenStatus::Accepted);
                assert_eq!(r.logical_stream_id, Some(100));
            }
            _ => panic!("expected OpenResponse message"),
        }
    }

    #[test]
    fn bincode_open_response_rejected_round_trip() {
        let response =
            OpenResponse::rejected(42, RejectCode::Unauthorized, Some("access denied".into()));

        match round_trip(&ProtocolMessage::OpenResponse(response)) {
            ProtocolMessage::OpenResponse(r) => {
                assert_eq!(r.request_id, 42);
                assert_eq!(r.status, OpenStatus::Rejected(RejectCode::Unauthorized));
                assert_eq!(r.reason.as_deref(), Some("access denied"));
                assert!(r.logical_stream_id.is_none());
            }
            _ => panic!("expected OpenResponse message"),
        }
    }

    #[test]
    fn bincode_decode_invalid_data() {
        let codec = BincodeCodec::new();
        let invalid_data: &[u8] = &[0xff, 0xff, 0xff, 0xff];
        let result: Result<ProtocolMessage, _> = codec.decode(invalid_data);
        // Garbage must surface as a deserialization error, not a panic or a
        // serialization error.
        assert!(matches!(result, Err(CodecError::Deserialize(_))));
    }

    #[test]
    fn bincode_decode_empty_input() {
        let codec = BincodeCodec::new();
        let result = codec.decode_message(&[]);
        assert!(matches!(result, Err(CodecError::Deserialize(_))));
    }

    #[test]
    fn bincode_decode_truncated_message() {
        let codec = BincodeCodec::new();
        let message = ProtocolMessage::OpenRequest(OpenRequest::new(42, ServiceId::new("ssh")));
        let encoded = codec
            .encode_message(&message)
            .expect("encode should succeed");

        // Every strict prefix of a valid encoding is missing bytes the decoder
        // needs, so decoding it must fail cleanly.
        for len in 0..encoded.len() {
            assert!(
                codec.decode_message(&encoded[..len]).is_err(),
                "decoding a {}-byte prefix of a {}-byte message should fail",
                len,
                encoded.len()
            );
        }
    }

    #[test]
    fn codec_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<BincodeCodec>();
    }
}