1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
7#[serde(tag = "type", rename_all = "kebab-case")]
8pub enum ControlMessage {
9 DecoderConfig(DecoderConfigMessage),
10 FrameChecksum(FrameChecksumMessage),
11 Status(StatusMessage),
12 SessionMetrics(SessionMetricsMessage),
13}
14
15impl ControlMessage {
16 pub fn to_bytes(&self) -> Result<Vec<u8>, serde_json::Error> {
18 serde_json::to_vec(self)
19 }
20
21 pub fn from_slice(bytes: &[u8]) -> Result<Self, serde_json::Error> {
23 serde_json::from_slice(bytes)
24 }
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
29pub struct DecoderConfigMessage {
30 pub codec: String,
31 #[serde(rename = "hardwareAcceleration")]
32 pub hardware_acceleration: String,
33 #[serde(rename = "optimizeForLatency")]
34 pub optimize_for_latency: bool,
35 #[serde(rename = "codedWidth", skip_serializing_if = "Option::is_none")]
36 pub coded_width: Option<u32>,
37 #[serde(rename = "codedHeight", skip_serializing_if = "Option::is_none")]
38 pub coded_height: Option<u32>,
39 #[serde(
40 rename = "descriptionBase64",
41 default,
42 skip_serializing_if = "Option::is_none",
43 with = "optional_base64_bytes"
44 )]
45 pub description: Option<Vec<u8>>,
46}
47
48impl DecoderConfigMessage {
49 pub fn low_latency(codec: impl Into<String>) -> Self {
51 Self {
52 codec: codec.into(),
53 hardware_acceleration: "prefer-hardware".to_owned(),
54 optimize_for_latency: true,
55 coded_width: None,
56 coded_height: None,
57 description: None,
58 }
59 }
60
61 pub fn with_dimensions(mut self, coded_width: u32, coded_height: u32) -> Self {
63 self.coded_width = Some(coded_width);
64 self.coded_height = Some(coded_height);
65 self
66 }
67
68 pub fn with_description(mut self, description: Vec<u8>) -> Self {
70 self.description = Some(description);
71 self
72 }
73}
74
75#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
77pub struct FrameChecksumMessage {
78 #[serde(rename = "frameId")]
79 pub frame_id: u32,
80 pub algorithm: String,
81 #[serde(rename = "hashHex")]
82 pub hash_hex: String,
83 #[serde(skip_serializing_if = "Option::is_none")]
84 pub width: Option<u32>,
85 #[serde(skip_serializing_if = "Option::is_none")]
86 pub height: Option<u32>,
87}
88
89impl FrameChecksumMessage {
90 pub fn rgba8_fnv1a64(frame_id: u32, hash_hex: impl Into<String>) -> Self {
91 Self {
92 frame_id,
93 algorithm: "fnv1a64-rgba8".to_owned(),
94 hash_hex: hash_hex.into(),
95 width: None,
96 height: None,
97 }
98 }
99
100 pub fn with_dimensions(mut self, width: u32, height: u32) -> Self {
101 self.width = Some(width);
102 self.height = Some(height);
103 self
104 }
105}
106
107#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
109pub struct StatusMessage {
110 pub message: String,
111}
112
113impl StatusMessage {
114 pub fn new(message: impl Into<String>) -> Self {
115 Self {
116 message: message.into(),
117 }
118 }
119}
120
121#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
123pub struct SessionMetricsMessage {
124 #[serde(rename = "encodeTimeUs", skip_serializing_if = "Option::is_none")]
125 pub encode_time_us: Option<u64>,
126 #[serde(rename = "transportRttMs", skip_serializing_if = "Option::is_none")]
127 pub transport_rtt_ms: Option<f32>,
128}
129
130impl SessionMetricsMessage {
131 pub fn new() -> Self {
132 Self::default()
133 }
134
135 pub fn with_encode_time_us(mut self, encode_time_us: u64) -> Self {
136 self.encode_time_us = Some(encode_time_us);
137 self
138 }
139
140 pub fn with_transport_rtt_ms(mut self, transport_rtt_ms: f32) -> Self {
141 self.transport_rtt_ms = Some(transport_rtt_ms);
142 self
143 }
144}
145
146mod optional_base64_bytes {
147 use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64};
148 use serde::{Deserialize, Deserializer, Serializer};
149
150 pub fn serialize<S>(value: &Option<Vec<u8>>, serializer: S) -> Result<S::Ok, S::Error>
151 where
152 S: Serializer,
153 {
154 match value {
155 Some(bytes) => serializer.serialize_some(&BASE64.encode(bytes)),
156 None => serializer.serialize_none(),
157 }
158 }
159
160 pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<Vec<u8>>, D::Error>
161 where
162 D: Deserializer<'de>,
163 {
164 let encoded = Option::<String>::deserialize(deserializer)?;
165 encoded
166 .map(|value| BASE64.decode(value).map_err(serde::de::Error::custom))
167 .transpose()
168 }
169}
170
#[cfg(test)]
mod tests {
    use super::{
        ControlMessage, DecoderConfigMessage, FrameChecksumMessage, SessionMetricsMessage,
        StatusMessage,
    };

    #[test]
    fn roundtrips_decoder_config_message_with_base64_description() {
        let message = ControlMessage::DecoderConfig(
            DecoderConfigMessage::low_latency("hvc1.1.6.L153.B0")
                .with_dimensions(1920, 1080)
                .with_description(vec![0x01, 0x02, 0x03]),
        );

        let bytes = message.to_bytes().unwrap();
        let json = String::from_utf8(bytes.clone()).unwrap();
        assert!(json.contains("\"type\":\"decoder-config\""));
        assert!(json.contains("\"codec\":\"hvc1.1.6.L153.B0\""));
        assert!(json.contains("\"codedWidth\":1920"));
        assert!(json.contains("\"codedHeight\":1080"));
        assert!(json.contains("\"descriptionBase64\":\"AQID\""));

        let decoded = ControlMessage::from_slice(&bytes).unwrap();
        assert_eq!(decoded, message);
    }

    #[test]
    fn omits_unset_decoder_config_fields_and_roundtrips() {
        let message =
            ControlMessage::DecoderConfig(DecoderConfigMessage::low_latency("avc1.42E01F"));

        let bytes = message.to_bytes().unwrap();
        let json = String::from_utf8(bytes.clone()).unwrap();
        assert!(json.contains("\"hardwareAcceleration\":\"prefer-hardware\""));
        assert!(json.contains("\"optimizeForLatency\":true"));
        assert!(!json.contains("codedWidth"));
        assert!(!json.contains("codedHeight"));
        assert!(!json.contains("descriptionBase64"));

        // Absent optional fields must come back as `None`.
        let decoded = ControlMessage::from_slice(&bytes).unwrap();
        assert_eq!(decoded, message);
    }

    #[test]
    fn treats_null_description_as_absent() {
        let json = concat!(
            r#"{"type":"decoder-config","codec":"avc1.42E01F","#,
            r#""hardwareAcceleration":"prefer-hardware","optimizeForLatency":true,"#,
            r#""descriptionBase64":null}"#,
        );

        let decoded = ControlMessage::from_slice(json.as_bytes()).unwrap();
        assert_eq!(
            decoded,
            ControlMessage::DecoderConfig(DecoderConfigMessage::low_latency("avc1.42E01F"))
        );
    }

    #[test]
    fn rejects_invalid_base64_description() {
        let json = concat!(
            r#"{"type":"decoder-config","codec":"avc1.42E01F","#,
            r#""hardwareAcceleration":"prefer-hardware","optimizeForLatency":true,"#,
            r#""descriptionBase64":"not base64!"}"#,
        );

        assert!(ControlMessage::from_slice(json.as_bytes()).is_err());
    }

    #[test]
    fn rejects_unknown_message_type() {
        assert!(ControlMessage::from_slice(br#"{"type":"unknown","message":"x"}"#).is_err());
    }

    #[test]
    fn serializes_session_metrics_message() {
        let message = ControlMessage::SessionMetrics(
            SessionMetricsMessage::new()
                .with_encode_time_us(1_750)
                .with_transport_rtt_ms(2.5),
        );

        let bytes = message.to_bytes().unwrap();
        let json = String::from_utf8(bytes.clone()).unwrap();
        assert!(json.contains("\"type\":\"session-metrics\""));
        assert!(json.contains("\"encodeTimeUs\":1750"));
        assert!(json.contains("\"transportRttMs\":2.5"));

        // 2.5 is exactly representable as f32, so the roundtrip is lossless.
        let decoded = ControlMessage::from_slice(&bytes).unwrap();
        assert_eq!(decoded, message);
    }

    #[test]
    fn serializes_empty_session_metrics_as_bare_tag() {
        let message = ControlMessage::SessionMetrics(SessionMetricsMessage::new());

        let bytes = message.to_bytes().unwrap();
        assert_eq!(
            String::from_utf8(bytes.clone()).unwrap(),
            r#"{"type":"session-metrics"}"#
        );

        let decoded = ControlMessage::from_slice(&bytes).unwrap();
        assert_eq!(decoded, message);
    }

    #[test]
    fn serializes_frame_checksum_message() {
        let message = ControlMessage::FrameChecksum(
            FrameChecksumMessage::rgba8_fnv1a64(7, "0123456789abcdef").with_dimensions(1920, 1080),
        );

        let bytes = message.to_bytes().unwrap();
        let json = String::from_utf8(bytes.clone()).unwrap();
        assert!(json.contains("\"type\":\"frame-checksum\""));
        assert!(json.contains("\"frameId\":7"));
        assert!(json.contains("\"algorithm\":\"fnv1a64-rgba8\""));
        assert!(json.contains("\"hashHex\":\"0123456789abcdef\""));
        assert!(json.contains("\"width\":1920"));
        assert!(json.contains("\"height\":1080"));

        let decoded = ControlMessage::from_slice(&bytes).unwrap();
        assert_eq!(decoded, message);
    }

    #[test]
    fn parses_status_message_from_json() {
        let decoded =
            ControlMessage::from_slice(br#"{"type":"status","message":"ready"}"#).unwrap();

        assert_eq!(decoded, ControlMessage::Status(StatusMessage::new("ready")));
    }

    #[test]
    fn serializes_status_message_with_tag_first() {
        let message = ControlMessage::Status(StatusMessage::new("ready"));

        let json = String::from_utf8(message.to_bytes().unwrap()).unwrap();
        assert_eq!(json, r#"{"type":"status","message":"ready"}"#);
    }
}