1use crate::error::{LightningError, Result};
2use quinn::{RecvStream, SendStream};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Default upper bound, in bytes, on the payload of a single frame (64 MiB).
pub const DEFAULT_MAX_FRAME_PAYLOAD: usize = 64 << 20;

// Compile-time guards on the default: it must be at least 1 MiB, and it must be
// representable in the 4-byte big-endian `u32` length prefix of a frame header.
const _: () = assert!((1 << 20) <= DEFAULT_MAX_FRAME_PAYLOAD);
const _: () = assert!(DEFAULT_MAX_FRAME_PAYLOAD <= u32::MAX as usize);
10
/// Connection details for a remote axon reachable over QUIC.
///
/// NOTE(review): field order is kept as-is because serde encodings that are
/// positional (e.g. msgpack arrays) depend on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuicAxonInfo {
    /// Hotkey identifying the axon (format is not validated here).
    pub hotkey: String,
    /// IP address as a plain string; IPv4 or IPv6. `addr_key` passes it to
    /// `PeerAddr::new`, which brackets values containing `:`.
    pub ip: String,
    /// Port the axon listens on.
    pub port: u16,
    /// Protocol identifier. NOTE(review): its meaning is not defined in this
    /// file — confirm against the caller.
    pub protocol: u8,
}
23
24impl QuicAxonInfo {
25 pub fn new(hotkey: String, ip: String, port: u16, protocol: u8) -> Self {
26 Self {
27 hotkey,
28 ip,
29 port,
30 protocol,
31 }
32 }
33
34 pub fn addr_key(&self) -> PeerAddr {
35 PeerAddr::new(&self.ip, self.port)
36 }
37}
38
/// A `host:port` string used as a hashable key for a peer endpoint.
///
/// Construct it with `PeerAddr::new`, which brackets IPv6 hosts (e.g. `[::1]:8091`).
/// The inner string is private; read it through `Display` or `AsRef<str>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct PeerAddr(String);
41
42impl PeerAddr {
43 pub fn new(ip: &str, port: u16) -> Self {
44 if ip.contains(':') {
45 Self(format!("[{}]:{}", ip, port))
46 } else {
47 Self(format!("{}:{}", ip, port))
48 }
49 }
50}
51
52impl std::fmt::Display for PeerAddr {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 f.write_str(&self.0)
55 }
56}
57
58impl AsRef<str> for PeerAddr {
59 fn as_ref(&self) -> &str {
60 &self.0
61 }
62}
63
64pub(crate) fn hashmap_to_rmpv_map(data: HashMap<String, rmpv::Value>) -> rmpv::Value {
65 rmpv::Value::Map(
66 data.into_iter()
67 .map(|(k, v)| (rmpv::Value::String(k.into()), v))
68 .collect(),
69 )
70}
71
72pub(crate) fn serialize_to_rmpv_map<T: serde::Serialize>(
73 val: &T,
74) -> Result<HashMap<String, rmpv::Value>> {
75 let rmpv_val = val
76 .serialize(NamedSerializer)
77 .map_err(|e| LightningError::Serialization(e.to_string()))?;
78 match rmpv_val {
79 rmpv::Value::Map(entries) => entries
80 .into_iter()
81 .map(|(k, v)| {
82 let key = match k {
83 rmpv::Value::String(s) => s
84 .into_str()
85 .ok_or_else(|| LightningError::Serialization("non-UTF8 map key".into())),
86 other => Ok(other.to_string()),
87 };
88 key.map(|k| (k, v))
89 })
90 .collect(),
91 _ => Err(LightningError::Serialization(
92 "expected map from serialized struct".into(),
93 )),
94 }
95}
96
/// Builds the colon-separated handshake request string:
/// `handshake:<validator_hotkey>:<timestamp>:<nonce>:<cert_fp_b64>`.
///
/// Components are inserted verbatim (no escaping), so callers must not pass values that
/// themselves contain `:` if the string is to be split back apart.
pub(crate) fn handshake_request_message(
    validator_hotkey: &str,
    timestamp: u64,
    nonce: &str,
    cert_fp_b64: &str,
) -> String {
    let ts = timestamp.to_string();
    ["handshake", validator_hotkey, ts.as_str(), nonce, cert_fp_b64].join(":")
}
111
/// Builds the colon-separated handshake response string:
/// `handshake_response:<validator_hotkey>:<miner_hotkey>:<timestamp>:<nonce>:<cert_fp_b64>`.
///
/// Components are inserted verbatim (no escaping).
pub(crate) fn handshake_response_message(
    validator_hotkey: &str,
    miner_hotkey: &str,
    timestamp: u64,
    nonce: &str,
    cert_fp_b64: &str,
) -> String {
    let ts = timestamp.to_string();
    [
        "handshake_response",
        validator_hotkey,
        miner_hotkey,
        ts.as_str(),
        nonce,
        cert_fp_b64,
    ]
    .join(":")
}
124
125struct NamedSerializer;
126
127impl serde::Serializer for NamedSerializer {
128 type Ok = rmpv::Value;
129 type Error = rmpv::ext::Error;
130
131 type SerializeSeq = SerializeVec;
132 type SerializeTuple = SerializeVec;
133 type SerializeTupleStruct = SerializeVec;
134 type SerializeTupleVariant = SerializeTupleVariant;
135 type SerializeMap = SerializeMap;
136 type SerializeStruct = SerializeMap;
137 type SerializeStructVariant = SerializeStructVariant;
138
139 fn serialize_bool(self, v: bool) -> std::result::Result<rmpv::Value, Self::Error> {
140 Ok(rmpv::Value::Boolean(v))
141 }
142
143 fn serialize_i8(self, v: i8) -> std::result::Result<rmpv::Value, Self::Error> {
144 self.serialize_i64(v as i64)
145 }
146
147 fn serialize_i16(self, v: i16) -> std::result::Result<rmpv::Value, Self::Error> {
148 self.serialize_i64(v as i64)
149 }
150
151 fn serialize_i32(self, v: i32) -> std::result::Result<rmpv::Value, Self::Error> {
152 self.serialize_i64(v as i64)
153 }
154
155 fn serialize_i64(self, v: i64) -> std::result::Result<rmpv::Value, Self::Error> {
156 Ok(rmpv::Value::Integer(rmpv::Integer::from(v)))
157 }
158
159 fn serialize_u8(self, v: u8) -> std::result::Result<rmpv::Value, Self::Error> {
160 self.serialize_u64(v as u64)
161 }
162
163 fn serialize_u16(self, v: u16) -> std::result::Result<rmpv::Value, Self::Error> {
164 self.serialize_u64(v as u64)
165 }
166
167 fn serialize_u32(self, v: u32) -> std::result::Result<rmpv::Value, Self::Error> {
168 self.serialize_u64(v as u64)
169 }
170
171 fn serialize_u64(self, v: u64) -> std::result::Result<rmpv::Value, Self::Error> {
172 Ok(rmpv::Value::Integer(rmpv::Integer::from(v)))
173 }
174
175 fn serialize_f32(self, v: f32) -> std::result::Result<rmpv::Value, Self::Error> {
176 Ok(rmpv::Value::F32(v))
177 }
178
179 fn serialize_f64(self, v: f64) -> std::result::Result<rmpv::Value, Self::Error> {
180 Ok(rmpv::Value::F64(v))
181 }
182
183 fn serialize_char(self, v: char) -> std::result::Result<rmpv::Value, Self::Error> {
184 let mut s = String::new();
185 s.push(v);
186 self.serialize_str(&s)
187 }
188
189 fn serialize_str(self, v: &str) -> std::result::Result<rmpv::Value, Self::Error> {
190 Ok(rmpv::Value::String(rmpv::Utf8String::from(v)))
191 }
192
193 fn serialize_bytes(self, v: &[u8]) -> std::result::Result<rmpv::Value, Self::Error> {
194 Ok(rmpv::Value::Binary(v.to_vec()))
195 }
196
197 fn serialize_none(self) -> std::result::Result<rmpv::Value, Self::Error> {
198 Ok(rmpv::Value::Nil)
199 }
200
201 fn serialize_some<T: ?Sized + serde::Serialize>(
202 self,
203 value: &T,
204 ) -> std::result::Result<rmpv::Value, Self::Error> {
205 value.serialize(self)
206 }
207
208 fn serialize_unit(self) -> std::result::Result<rmpv::Value, Self::Error> {
209 Ok(rmpv::Value::Nil)
210 }
211
212 fn serialize_unit_struct(
213 self,
214 _name: &'static str,
215 ) -> std::result::Result<rmpv::Value, Self::Error> {
216 Ok(rmpv::Value::Nil)
217 }
218
219 fn serialize_unit_variant(
220 self,
221 _name: &'static str,
222 idx: u32,
223 _variant: &'static str,
224 ) -> std::result::Result<rmpv::Value, Self::Error> {
225 Ok(rmpv::Value::Integer(rmpv::Integer::from(idx)))
226 }
227
228 fn serialize_newtype_struct<T: ?Sized + serde::Serialize>(
229 self,
230 _name: &'static str,
231 value: &T,
232 ) -> std::result::Result<rmpv::Value, Self::Error> {
233 value.serialize(self)
234 }
235
236 fn serialize_newtype_variant<T: ?Sized + serde::Serialize>(
237 self,
238 _name: &'static str,
239 idx: u32,
240 _variant: &'static str,
241 value: &T,
242 ) -> std::result::Result<rmpv::Value, Self::Error> {
243 let inner = value.serialize(NamedSerializer)?;
244 Ok(rmpv::Value::Map(vec![(
245 rmpv::Value::Integer(rmpv::Integer::from(idx)),
246 inner,
247 )]))
248 }
249
250 fn serialize_seq(
251 self,
252 len: Option<usize>,
253 ) -> std::result::Result<Self::SerializeSeq, Self::Error> {
254 Ok(SerializeVec {
255 vec: Vec::with_capacity(len.unwrap_or(0)),
256 })
257 }
258
259 fn serialize_tuple(self, len: usize) -> std::result::Result<Self::SerializeTuple, Self::Error> {
260 self.serialize_seq(Some(len))
261 }
262
263 fn serialize_tuple_struct(
264 self,
265 _name: &'static str,
266 len: usize,
267 ) -> std::result::Result<Self::SerializeTupleStruct, Self::Error> {
268 self.serialize_seq(Some(len))
269 }
270
271 fn serialize_tuple_variant(
272 self,
273 _name: &'static str,
274 idx: u32,
275 _variant: &'static str,
276 len: usize,
277 ) -> std::result::Result<Self::SerializeTupleVariant, Self::Error> {
278 Ok(SerializeTupleVariant {
279 idx,
280 vec: Vec::with_capacity(len),
281 })
282 }
283
284 fn serialize_map(
285 self,
286 len: Option<usize>,
287 ) -> std::result::Result<Self::SerializeMap, Self::Error> {
288 Ok(SerializeMap {
289 entries: Vec::with_capacity(len.unwrap_or(0)),
290 cur_key: None,
291 })
292 }
293
294 fn serialize_struct(
295 self,
296 _name: &'static str,
297 len: usize,
298 ) -> std::result::Result<Self::SerializeStruct, Self::Error> {
299 self.serialize_map(Some(len))
300 }
301
302 fn serialize_struct_variant(
303 self,
304 _name: &'static str,
305 idx: u32,
306 _variant: &'static str,
307 len: usize,
308 ) -> std::result::Result<Self::SerializeStructVariant, Self::Error> {
309 Ok(SerializeStructVariant {
310 idx,
311 entries: Vec::with_capacity(len),
312 })
313 }
314}
315
316struct SerializeVec {
317 vec: Vec<rmpv::Value>,
318}
319
320impl serde::ser::SerializeSeq for SerializeVec {
321 type Ok = rmpv::Value;
322 type Error = rmpv::ext::Error;
323
324 fn serialize_element<T: ?Sized + serde::Serialize>(
325 &mut self,
326 value: &T,
327 ) -> std::result::Result<(), Self::Error> {
328 self.vec.push(value.serialize(NamedSerializer)?);
329 Ok(())
330 }
331
332 fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
333 Ok(rmpv::Value::Array(self.vec))
334 }
335}
336
337impl serde::ser::SerializeTuple for SerializeVec {
338 type Ok = rmpv::Value;
339 type Error = rmpv::ext::Error;
340
341 fn serialize_element<T: ?Sized + serde::Serialize>(
342 &mut self,
343 value: &T,
344 ) -> std::result::Result<(), Self::Error> {
345 serde::ser::SerializeSeq::serialize_element(self, value)
346 }
347
348 fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
349 serde::ser::SerializeSeq::end(self)
350 }
351}
352
353impl serde::ser::SerializeTupleStruct for SerializeVec {
354 type Ok = rmpv::Value;
355 type Error = rmpv::ext::Error;
356
357 fn serialize_field<T: ?Sized + serde::Serialize>(
358 &mut self,
359 value: &T,
360 ) -> std::result::Result<(), Self::Error> {
361 serde::ser::SerializeSeq::serialize_element(self, value)
362 }
363
364 fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
365 serde::ser::SerializeSeq::end(self)
366 }
367}
368
369struct SerializeTupleVariant {
370 idx: u32,
371 vec: Vec<rmpv::Value>,
372}
373
374impl serde::ser::SerializeTupleVariant for SerializeTupleVariant {
375 type Ok = rmpv::Value;
376 type Error = rmpv::ext::Error;
377
378 fn serialize_field<T: ?Sized + serde::Serialize>(
379 &mut self,
380 value: &T,
381 ) -> std::result::Result<(), Self::Error> {
382 self.vec.push(value.serialize(NamedSerializer)?);
383 Ok(())
384 }
385
386 fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
387 Ok(rmpv::Value::Map(vec![(
388 rmpv::Value::Integer(rmpv::Integer::from(self.idx)),
389 rmpv::Value::Array(self.vec),
390 )]))
391 }
392}
393
394struct SerializeMap {
395 entries: Vec<(rmpv::Value, rmpv::Value)>,
396 cur_key: Option<rmpv::Value>,
397}
398
399impl serde::ser::SerializeMap for SerializeMap {
400 type Ok = rmpv::Value;
401 type Error = rmpv::ext::Error;
402
403 fn serialize_key<T: ?Sized + serde::Serialize>(
404 &mut self,
405 key: &T,
406 ) -> std::result::Result<(), Self::Error> {
407 self.cur_key = Some(key.serialize(NamedSerializer)?);
408 Ok(())
409 }
410
411 fn serialize_value<T: ?Sized + serde::Serialize>(
412 &mut self,
413 value: &T,
414 ) -> std::result::Result<(), Self::Error> {
415 let key = self.cur_key.take().ok_or_else(|| {
416 <Self::Error as serde::ser::Error>::custom(
417 "serialize_value called before serialize_key",
418 )
419 })?;
420 self.entries.push((key, value.serialize(NamedSerializer)?));
421 Ok(())
422 }
423
424 fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
425 Ok(rmpv::Value::Map(self.entries))
426 }
427}
428
429impl serde::ser::SerializeStruct for SerializeMap {
430 type Ok = rmpv::Value;
431 type Error = rmpv::ext::Error;
432
433 fn serialize_field<T: ?Sized + serde::Serialize>(
434 &mut self,
435 key: &'static str,
436 value: &T,
437 ) -> std::result::Result<(), Self::Error> {
438 let k = rmpv::Value::String(rmpv::Utf8String::from(key));
439 let v = value.serialize(NamedSerializer)?;
440 self.entries.push((k, v));
441 Ok(())
442 }
443
444 fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
445 Ok(rmpv::Value::Map(self.entries))
446 }
447}
448
449struct SerializeStructVariant {
450 idx: u32,
451 entries: Vec<(rmpv::Value, rmpv::Value)>,
452}
453
454impl serde::ser::SerializeStructVariant for SerializeStructVariant {
455 type Ok = rmpv::Value;
456 type Error = rmpv::ext::Error;
457
458 fn serialize_field<T: ?Sized + serde::Serialize>(
459 &mut self,
460 key: &'static str,
461 value: &T,
462 ) -> std::result::Result<(), Self::Error> {
463 let k = rmpv::Value::String(rmpv::Utf8String::from(key));
464 let v = value.serialize(NamedSerializer)?;
465 self.entries.push((k, v));
466 Ok(())
467 }
468
469 fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
470 Ok(rmpv::Value::Map(vec![(
471 rmpv::Value::Integer(rmpv::Integer::from(self.idx)),
472 rmpv::Value::Map(self.entries),
473 )]))
474 }
475}
476
/// A request addressed to a named synapse handler.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuicRequest {
    /// Name of the synapse/handler the request targets.
    pub synapse_type: String,
    /// Request fields keyed by name. `QuicRequest::from_typed` fills this from a
    /// serializable struct.
    pub data: HashMap<String, rmpv::Value>,
}
485
486impl QuicRequest {
487 pub fn new(synapse_type: String, data: HashMap<String, rmpv::Value>) -> Self {
488 Self { synapse_type, data }
489 }
490
491 pub fn from_typed<T: serde::Serialize>(
493 synapse_type: impl Into<String>,
494 data: &T,
495 ) -> Result<Self> {
496 Ok(Self {
497 synapse_type: synapse_type.into(),
498 data: serialize_to_rmpv_map(data)?,
499 })
500 }
501}
502
/// A handler's reply to a `QuicRequest`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuicResponse {
    /// Whether the handler succeeded; `into_result` turns `false` into an error.
    pub success: bool,
    /// Response fields keyed by name; decode them with `deserialize_data`.
    pub data: HashMap<String, rmpv::Value>,
    /// Latency reported for the request. NOTE(review): presumably milliseconds as the
    /// name says; where it is measured is not shown here — confirm.
    pub latency_ms: f64,
    /// Error message on failure; defaults to `None` when absent from the encoded form.
    #[serde(default)]
    pub error: Option<String>,
}
516
517impl QuicResponse {
518 pub fn into_result(self) -> Result<Self> {
520 if self.success {
521 Ok(self)
522 } else {
523 Err(LightningError::Handler(
524 self.error.unwrap_or_else(|| "request failed".into()),
525 ))
526 }
527 }
528
529 pub fn deserialize_data<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
531 let map_value = hashmap_to_rmpv_map(self.data.clone());
532 rmpv::ext::from_value(map_value).map_err(|e| LightningError::Serialization(e.to_string()))
533 }
534}
535
/// Handshake request message, carried in a `MessageType::HandshakeRequest` frame.
///
/// NOTE(review): `signature` presumably signs the string built by
/// `handshake_request_message` — verify against the signing code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeRequest {
    /// Hotkey of the initiating validator.
    pub validator_hotkey: String,
    /// Timestamp of the request. NOTE(review): units (seconds vs. millis) not shown here.
    pub timestamp: u64,
    /// Nonce included in the signed handshake string.
    pub nonce: String,
    /// Encoded signature. NOTE(review): encoding (hex/base64) not shown here.
    pub signature: String,
}

/// Handshake response message, carried in a `MessageType::HandshakeResponse` frame.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeResponse {
    /// Hotkey of the responding miner.
    pub miner_hotkey: String,
    /// Timestamp of the response. NOTE(review): units not shown here.
    pub timestamp: u64,
    /// Encoded signature. NOTE(review): presumably over `handshake_response_message` — verify.
    pub signature: String,
    /// Whether the handshake was accepted.
    pub accepted: bool,
    /// Identifier assigned to the connection.
    pub connection_id: String,
    /// Certificate fingerprint; defaults to `None` when absent from the encoded form.
    #[serde(default)]
    pub cert_fingerprint: Option<String>,
}
566
/// Synapse invocation payload, carried in a `MessageType::SynapsePacket` frame
/// (presumably — the pairing with the frame type is not shown in this file).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SynapsePacket {
    pub synapse_type: String,
    pub data: HashMap<String, rmpv::Value>,
    pub timestamp: u64,
}

/// Reply to a `SynapsePacket`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SynapseResponse {
    pub success: bool,
    pub data: HashMap<String, rmpv::Value>,
    pub timestamp: u64,
    pub error: Option<String>,
}

/// One piece of a streamed response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    // `serde_bytes` makes serde encode this as a byte string rather than a sequence of
    // individual integers.
    #[serde(with = "serde_bytes")]
    pub data: Vec<u8>,
}

/// Terminator for a streamed response, reporting overall success.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamEnd {
    pub success: bool,
    pub error: Option<String>,
}
595
/// Frame type tag: the first byte of every frame header.
///
/// The discriminants are the on-wire values (`write_frame` writes `msg_type as u8`), so
/// they must never be renumbered.
#[repr(u8)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MessageType {
    HandshakeRequest = 0x01,
    HandshakeResponse = 0x02,
    SynapsePacket = 0x03,
    SynapseResponse = 0x04,
    StreamChunk = 0x05,
    StreamEnd = 0x06,
}
607
608impl TryFrom<u8> for MessageType {
609 type Error = LightningError;
610
611 fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
612 match value {
613 0x01 => Ok(MessageType::HandshakeRequest),
614 0x02 => Ok(MessageType::HandshakeResponse),
615 0x03 => Ok(MessageType::SynapsePacket),
616 0x04 => Ok(MessageType::SynapseResponse),
617 0x05 => Ok(MessageType::StreamChunk),
618 0x06 => Ok(MessageType::StreamEnd),
619 _ => Err(LightningError::Transport(format!(
620 "unknown message type: 0x{:02x}",
621 value
622 ))),
623 }
624 }
625}
626
/// Frame header length: one message-type byte plus a big-endian `u32` payload length.
const FRAME_HEADER_SIZE: usize = 1 + std::mem::size_of::<u32>();
628
629async fn read_exact_from_recv(recv: &mut RecvStream, buf: &mut [u8]) -> Result<()> {
630 let mut offset = 0;
631 while offset < buf.len() {
632 match recv.read(&mut buf[offset..]).await {
633 Ok(Some(n)) => offset += n,
634 Ok(None) => {
635 return Err(LightningError::Transport(format!(
636 "stream closed after {} of {} bytes",
637 offset,
638 buf.len()
639 )));
640 }
641 Err(e) => {
642 return Err(LightningError::Transport(format!("read error: {}", e)));
643 }
644 }
645 }
646 Ok(())
647}
648
/// Payloads larger than this (1 MiB) are read incrementally by `read_frame` instead of
/// being allocated in full up front.
const INCREMENTAL_READ_THRESHOLD: usize = 1 << 20;
/// Size of each read (64 KiB) on the incremental path.
const READ_CHUNK_SIZE: usize = 64 << 10;
651
652pub fn parse_frame_header(data: &[u8], max_payload: usize) -> Result<(MessageType, &[u8])> {
655 if data.len() < FRAME_HEADER_SIZE {
656 return Err(LightningError::Transport(
657 "insufficient data for frame header".to_string(),
658 ));
659 }
660 let msg_type = MessageType::try_from(data[0])?;
661 let payload_len = u32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
662 if payload_len > max_payload {
663 return Err(LightningError::Transport(format!(
664 "frame payload {} bytes exceeds maximum {}",
665 payload_len, max_payload
666 )));
667 }
668 if data.len() < FRAME_HEADER_SIZE + payload_len {
669 return Err(LightningError::Transport(format!(
670 "insufficient data for frame payload: have {}, need {}",
671 data.len() - FRAME_HEADER_SIZE,
672 payload_len
673 )));
674 }
675 Ok((
676 msg_type,
677 &data[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len],
678 ))
679}
680
681pub async fn read_frame(
682 recv: &mut RecvStream,
683 max_payload: usize,
684) -> Result<(MessageType, Vec<u8>)> {
685 let mut header = [0u8; FRAME_HEADER_SIZE];
686 read_exact_from_recv(recv, &mut header).await?;
687
688 let msg_type = MessageType::try_from(header[0])?;
689 let payload_len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
690
691 if payload_len > max_payload {
692 return Err(LightningError::Transport(format!(
693 "frame payload {} bytes exceeds maximum {}",
694 payload_len, max_payload
695 )));
696 }
697
698 if payload_len == 0 {
699 return Ok((msg_type, Vec::new()));
700 }
701
702 if payload_len <= INCREMENTAL_READ_THRESHOLD {
706 let mut payload = vec![0u8; payload_len];
707 read_exact_from_recv(recv, &mut payload).await?;
708 return Ok((msg_type, payload));
709 }
710
711 let mut payload = Vec::with_capacity(INCREMENTAL_READ_THRESHOLD);
712 let mut remaining = payload_len;
713 while remaining > 0 {
714 let next_capacity = payload
715 .capacity()
716 .saturating_mul(2)
717 .max(INCREMENTAL_READ_THRESHOLD)
718 .min(payload_len)
719 .min(max_payload);
720 if payload.capacity() < next_capacity {
721 payload.reserve(next_capacity - payload.len());
722 }
723 let chunk_size = remaining.min(READ_CHUNK_SIZE);
724 let start = payload.len();
725 payload.resize(start + chunk_size, 0);
726 read_exact_from_recv(recv, &mut payload[start..]).await?;
727 remaining -= chunk_size;
728 }
729 Ok((msg_type, payload))
730}
731
732pub async fn write_frame(
733 send: &mut SendStream,
734 msg_type: MessageType,
735 payload: &[u8],
736) -> Result<()> {
737 let payload_len: u32 = payload.len().try_into().map_err(|_| {
738 LightningError::Transport(format!(
739 "frame payload {} bytes exceeds u32::MAX",
740 payload.len()
741 ))
742 })?;
743
744 let mut header = [0u8; FRAME_HEADER_SIZE];
745 header[0] = msg_type as u8;
746 header[1..5].copy_from_slice(&payload_len.to_be_bytes());
747
748 send.write_all(&header)
749 .await
750 .map_err(|e| LightningError::Transport(format!("failed to write frame header: {}", e)))?;
751 if !payload.is_empty() {
752 send.write_all(payload).await.map_err(|e| {
753 LightningError::Transport(format!("failed to write frame payload: {}", e))
754 })?;
755 }
756 Ok(())
757}
758
759pub async fn write_frame_and_finish(
760 send: &mut SendStream,
761 msg_type: MessageType,
762 payload: &[u8],
763) -> Result<()> {
764 write_frame(send, msg_type, payload).await?;
765 send.finish()
766 .map_err(|e| LightningError::Transport(format!("failed to finish stream: {}", e)))?;
767 Ok(())
768}
769
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn quic_request_from_typed_serializes_struct() {
        #[derive(serde::Serialize)]
        struct MyReq {
            name: String,
            count: u32,
        }

        let req = QuicRequest::from_typed(
            "test_synapse",
            &MyReq {
                name: "hello".into(),
                count: 42,
            },
        )
        .unwrap();

        assert_eq!(req.synapse_type, "test_synapse");
        assert_eq!(
            req.data.get("name").unwrap(),
            &rmpv::Value::String("hello".into())
        );
        assert_eq!(
            req.data.get("count").unwrap(),
            &rmpv::Value::Integer(42.into())
        );
    }

    #[test]
    fn quic_response_into_result_ok_on_success() {
        let resp = QuicResponse {
            success: true,
            data: HashMap::new(),
            latency_ms: 1.0,
            error: None,
        };
        assert!(resp.into_result().is_ok());
    }

    #[test]
    fn quic_response_into_result_err_on_failure() {
        let resp = QuicResponse {
            success: false,
            data: HashMap::new(),
            latency_ms: 1.0,
            error: Some("bad request".into()),
        };
        let err = resp.into_result().unwrap_err();
        assert!(err.to_string().contains("bad request"));
    }

    #[test]
    fn quic_response_into_result_uses_default_message() {
        let resp = QuicResponse {
            success: false,
            data: HashMap::new(),
            latency_ms: 1.0,
            error: None,
        };
        let err = resp.into_result().unwrap_err();
        assert!(err.to_string().contains("request failed"));
    }

    #[test]
    fn quic_response_deserialize_data_roundtrips() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct MyResp {
            value: i32,
            label: String,
        }

        let original = MyResp {
            value: 99,
            label: "test".into(),
        };

        let data = serialize_to_rmpv_map(&original).unwrap();

        let resp = QuicResponse {
            success: true,
            data,
            latency_ms: 1.0,
            error: None,
        };

        let deserialized: MyResp = resp.deserialize_data().unwrap();
        assert_eq!(deserialized, original);
    }

    #[test]
    fn serialize_to_rmpv_map_handles_nested_structs() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct Inner {
            x: i32,
            y: String,
        }

        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct Outer {
            name: String,
            inner: Inner,
            values: Vec<u32>,
        }

        let original = Outer {
            name: "test".into(),
            inner: Inner {
                x: 42,
                y: "nested".into(),
            },
            values: vec![1, 2, 3],
        };

        let map = serialize_to_rmpv_map(&original).unwrap();
        assert_eq!(
            map.get("name").unwrap(),
            &rmpv::Value::String("test".into())
        );
        assert!(matches!(map.get("inner").unwrap(), rmpv::Value::Map(_)));
        assert!(matches!(map.get("values").unwrap(), rmpv::Value::Array(_)));

        let resp = QuicResponse {
            success: true,
            data: map,
            latency_ms: 0.0,
            error: None,
        };
        let deserialized: Outer = resp.deserialize_data().unwrap();
        assert_eq!(deserialized, original);
    }

    #[test]
    fn handshake_request_message_format() {
        let msg = handshake_request_message("5GrwvaEF", 1234567890, "abc123", "fp_b64");
        assert_eq!(msg, "handshake:5GrwvaEF:1234567890:abc123:fp_b64");
    }

    #[test]
    fn handshake_response_message_format() {
        let msg =
            handshake_response_message("5GrwvaEF", "5FHneW46", 1234567890, "abc123", "fp_b64");
        assert_eq!(
            msg,
            "handshake_response:5GrwvaEF:5FHneW46:1234567890:abc123:fp_b64"
        );
    }

    #[test]
    fn peer_addr_brackets_ipv6_only() {
        assert_eq!(PeerAddr::new("10.0.0.1", 8091).to_string(), "10.0.0.1:8091");
        assert_eq!(PeerAddr::new("::1", 8091).to_string(), "[::1]:8091");
    }

    #[test]
    fn message_type_roundtrips_through_u8() {
        let all = [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::SynapsePacket,
            MessageType::SynapseResponse,
            MessageType::StreamChunk,
            MessageType::StreamEnd,
        ];
        for t in all {
            assert_eq!(MessageType::try_from(t as u8).unwrap(), t);
        }
        assert!(MessageType::try_from(0x00).is_err());
        assert!(MessageType::try_from(0x07).is_err());
    }

    #[test]
    fn parse_frame_header_valid() {
        let mut data = vec![0x01];
        data.extend_from_slice(&5u32.to_be_bytes());
        data.extend_from_slice(b"hello");
        let (msg_type, payload) = parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).unwrap();
        assert_eq!(msg_type, MessageType::HandshakeRequest);
        assert_eq!(payload, b"hello");
    }

    #[test]
    fn parse_frame_header_zero_length_payload() {
        let mut data = vec![0x06];
        data.extend_from_slice(&0u32.to_be_bytes());
        let (msg_type, payload) = parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).unwrap();
        assert_eq!(msg_type, MessageType::StreamEnd);
        assert!(payload.is_empty());
    }

    #[test]
    fn parse_frame_header_ignores_trailing_bytes() {
        let mut data = vec![0x03];
        data.extend_from_slice(&2u32.to_be_bytes());
        data.extend_from_slice(b"hiEXTRA");
        let (msg_type, payload) = parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).unwrap();
        assert_eq!(msg_type, MessageType::SynapsePacket);
        assert_eq!(payload, b"hi");
    }

    #[test]
    fn parse_frame_header_insufficient_header() {
        assert!(parse_frame_header(&[0x01, 0x00], DEFAULT_MAX_FRAME_PAYLOAD).is_err());
    }

    #[test]
    fn parse_frame_header_insufficient_payload() {
        let mut data = vec![0x01];
        data.extend_from_slice(&10u32.to_be_bytes());
        data.extend_from_slice(b"short");
        assert!(parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).is_err());
    }

    #[test]
    fn parse_frame_header_oversized_payload() {
        let mut data = vec![0x01];
        data.extend_from_slice(&(DEFAULT_MAX_FRAME_PAYLOAD as u32 + 1).to_be_bytes());
        assert!(parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).is_err());
    }

    #[test]
    fn parse_frame_header_invalid_message_type() {
        let mut data = vec![0xFF];
        data.extend_from_slice(&0u32.to_be_bytes());
        assert!(parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).is_err());
    }

    use proptest::prelude::*;

    /// Non-container `rmpv` values (proptest's default float strategy excludes NaN, so
    /// equality-based round-trip checks are sound).
    fn arb_rmpv_leaf() -> impl Strategy<Value = rmpv::Value> {
        prop_oneof![
            Just(rmpv::Value::Nil),
            any::<bool>().prop_map(rmpv::Value::Boolean),
            any::<i64>().prop_map(|v| rmpv::Value::Integer(rmpv::Integer::from(v))),
            any::<u64>().prop_map(|v| rmpv::Value::Integer(rmpv::Integer::from(v))),
            any::<f32>().prop_map(rmpv::Value::F32),
            any::<f64>().prop_map(rmpv::Value::F64),
            "[a-zA-Z0-9_ ]{0,32}"
                .prop_map(|s| rmpv::Value::String(rmpv::Utf8String::from(s.as_str()))),
            proptest::collection::vec(any::<u8>(), 0..64).prop_map(rmpv::Value::Binary),
        ]
    }

    /// Arbitrary `rmpv` values, nesting arrays and maps up to depth 3.
    fn arb_rmpv_value() -> impl Strategy<Value = rmpv::Value> {
        arb_rmpv_leaf().prop_recursive(3, 32, 8, |inner| {
            prop_oneof![
                proptest::collection::vec(inner.clone(), 0..8).prop_map(rmpv::Value::Array),
                proptest::collection::vec((inner.clone(), inner), 0..8).prop_map(rmpv::Value::Map),
            ]
        })
    }

    proptest! {
        #[test]
        fn msgpack_encode_decode_roundtrip(value in arb_rmpv_value()) {
            let bytes = rmp_serde::to_vec(&value).unwrap();
            let decoded: rmpv::Value = rmp_serde::from_slice(&bytes).unwrap();
            prop_assert_eq!(value, decoded);
        }

        // Exercises `hashmap_to_rmpv_map` followed by a msgpack encode/decode round trip.
        #[test]
        fn serialize_to_rmpv_map_roundtrip(
            keys in proptest::collection::vec("[a-z]{1,8}", 1..8),
            vals in proptest::collection::vec(arb_rmpv_leaf(), 1..8),
        ) {
            let mut map = HashMap::new();
            for (k, v) in keys.into_iter().zip(vals.into_iter()) {
                map.insert(k, v);
            }
            let rmpv_map = hashmap_to_rmpv_map(map);
            let bytes = rmp_serde::to_vec(&rmpv_map).unwrap();
            let decoded: rmpv::Value = rmp_serde::from_slice(&bytes).unwrap();
            prop_assert_eq!(rmpv_map, decoded);
        }

        #[test]
        fn from_typed_roundtrip(
            name in "[a-zA-Z]{1,16}",
            count in any::<u32>(),
            label in "[a-zA-Z0-9 ]{0,32}",
        ) {
            #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
            struct TestStruct {
                name: String,
                count: u32,
                label: String,
            }

            let original = TestStruct {
                name: name.clone(),
                count,
                label: label.clone(),
            };

            let req = QuicRequest::from_typed("test", &original).unwrap();
            let resp = QuicResponse {
                success: true,
                data: req.data,
                latency_ms: 0.0,
                error: None,
            };
            let deserialized: TestStruct = resp.deserialize_data().unwrap();
            prop_assert_eq!(original, deserialized);
        }

        #[test]
        fn parse_frame_header_never_panics(data in proptest::collection::vec(any::<u8>(), 0..256)) {
            let _ = parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD);
        }
    }
}