Skip to main content

btlightning/
types.rs

1use crate::error::{LightningError, Result};
2use quinn::{RecvStream, SendStream};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5
/// Default upper bound, in bytes, on a single frame's payload: 64 MiB (`64 << 20`).
pub const DEFAULT_MAX_FRAME_PAYLOAD: usize = 64 << 20;
// Compile-time guards on the default: it must be at least 1 MiB, and it must be
// representable in the 32-bit big-endian length field of the frame header.
const _: () = assert!(DEFAULT_MAX_FRAME_PAYLOAD >= 1 << 20);
const _: () = assert!(DEFAULT_MAX_FRAME_PAYLOAD <= u32::MAX as usize);
10
11/// Network address and identity of a Bittensor miner's QUIC axon endpoint.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct QuicAxonInfo {
14    /// SS58-encoded sr25519 public key of the miner.
15    pub hotkey: String,
16    /// IPv4 address the miner advertises on-chain.
17    pub ip: String,
18    /// QUIC port the miner listens on.
19    pub port: u16,
20    /// Axon protocol identifier (4 = QUIC).
21    pub protocol: u8,
22}
23
24impl QuicAxonInfo {
25    pub fn new(hotkey: String, ip: String, port: u16, protocol: u8) -> Self {
26        Self {
27            hotkey,
28            ip,
29            port,
30            protocol,
31        }
32    }
33
34    pub fn addr_key(&self) -> PeerAddr {
35        PeerAddr::new(&self.ip, self.port)
36    }
37}
38
39#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
40pub struct PeerAddr(String);
41
42impl PeerAddr {
43    pub fn new(ip: &str, port: u16) -> Self {
44        if ip.contains(':') {
45            Self(format!("[{}]:{}", ip, port))
46        } else {
47            Self(format!("{}:{}", ip, port))
48        }
49    }
50}
51
52impl std::fmt::Display for PeerAddr {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        f.write_str(&self.0)
55    }
56}
57
58impl AsRef<str> for PeerAddr {
59    fn as_ref(&self) -> &str {
60        &self.0
61    }
62}
63
64pub(crate) fn hashmap_to_rmpv_map(data: HashMap<String, rmpv::Value>) -> rmpv::Value {
65    rmpv::Value::Map(
66        data.into_iter()
67            .map(|(k, v)| (rmpv::Value::String(k.into()), v))
68            .collect(),
69    )
70}
71
72pub(crate) fn serialize_to_rmpv_map<T: serde::Serialize>(
73    val: &T,
74) -> Result<HashMap<String, rmpv::Value>> {
75    let rmpv_val = val
76        .serialize(NamedSerializer)
77        .map_err(|e| LightningError::Serialization(e.to_string()))?;
78    match rmpv_val {
79        rmpv::Value::Map(entries) => entries
80            .into_iter()
81            .map(|(k, v)| {
82                let key = match k {
83                    rmpv::Value::String(s) => s
84                        .into_str()
85                        .ok_or_else(|| LightningError::Serialization("non-UTF8 map key".into())),
86                    other => Ok(other.to_string()),
87                };
88                key.map(|k| (k, v))
89            })
90            .collect(),
91        _ => Err(LightningError::Serialization(
92            "expected map from serialized struct".into(),
93        )),
94    }
95}
96
// ":" is a safe delimiter here: SS58 hotkeys are base58, nonces are hex, cert
// fingerprints are base64 (the standard alphabet uses +/= but never ":"), and
// timestamps are decimal integers. If any field format changes to permit ":",
// switch to a structured encoding.
/// Builds the canonical string a validator signs when initiating a handshake:
/// `handshake:<validator_hotkey>:<timestamp>:<nonce>:<cert_fp_b64>`.
pub(crate) fn handshake_request_message(
    validator_hotkey: &str,
    timestamp: u64,
    nonce: &str,
    cert_fp_b64: &str,
) -> String {
    let timestamp = timestamp.to_string();
    [
        "handshake",
        validator_hotkey,
        timestamp.as_str(),
        nonce,
        cert_fp_b64,
    ]
    .join(":")
}
111
/// Builds the ":"-delimited string for the miner's handshake reply:
/// `handshake_response:<validator_hotkey>:<miner_hotkey>:<timestamp>:<nonce>:<cert_fp_b64>`.
pub(crate) fn handshake_response_message(
    validator_hotkey: &str,
    miner_hotkey: &str,
    timestamp: u64,
    nonce: &str,
    cert_fp_b64: &str,
) -> String {
    format!(
        "handshake_response:{validator_hotkey}:{miner_hotkey}:{timestamp}:{nonce}:{cert_fp_b64}"
    )
}
124
125struct NamedSerializer;
126
127impl serde::Serializer for NamedSerializer {
128    type Ok = rmpv::Value;
129    type Error = rmpv::ext::Error;
130
131    type SerializeSeq = SerializeVec;
132    type SerializeTuple = SerializeVec;
133    type SerializeTupleStruct = SerializeVec;
134    type SerializeTupleVariant = SerializeTupleVariant;
135    type SerializeMap = SerializeMap;
136    type SerializeStruct = SerializeMap;
137    type SerializeStructVariant = SerializeStructVariant;
138
139    fn serialize_bool(self, v: bool) -> std::result::Result<rmpv::Value, Self::Error> {
140        Ok(rmpv::Value::Boolean(v))
141    }
142
143    fn serialize_i8(self, v: i8) -> std::result::Result<rmpv::Value, Self::Error> {
144        self.serialize_i64(v as i64)
145    }
146
147    fn serialize_i16(self, v: i16) -> std::result::Result<rmpv::Value, Self::Error> {
148        self.serialize_i64(v as i64)
149    }
150
151    fn serialize_i32(self, v: i32) -> std::result::Result<rmpv::Value, Self::Error> {
152        self.serialize_i64(v as i64)
153    }
154
155    fn serialize_i64(self, v: i64) -> std::result::Result<rmpv::Value, Self::Error> {
156        Ok(rmpv::Value::Integer(rmpv::Integer::from(v)))
157    }
158
159    fn serialize_u8(self, v: u8) -> std::result::Result<rmpv::Value, Self::Error> {
160        self.serialize_u64(v as u64)
161    }
162
163    fn serialize_u16(self, v: u16) -> std::result::Result<rmpv::Value, Self::Error> {
164        self.serialize_u64(v as u64)
165    }
166
167    fn serialize_u32(self, v: u32) -> std::result::Result<rmpv::Value, Self::Error> {
168        self.serialize_u64(v as u64)
169    }
170
171    fn serialize_u64(self, v: u64) -> std::result::Result<rmpv::Value, Self::Error> {
172        Ok(rmpv::Value::Integer(rmpv::Integer::from(v)))
173    }
174
175    fn serialize_f32(self, v: f32) -> std::result::Result<rmpv::Value, Self::Error> {
176        Ok(rmpv::Value::F32(v))
177    }
178
179    fn serialize_f64(self, v: f64) -> std::result::Result<rmpv::Value, Self::Error> {
180        Ok(rmpv::Value::F64(v))
181    }
182
183    fn serialize_char(self, v: char) -> std::result::Result<rmpv::Value, Self::Error> {
184        let mut s = String::new();
185        s.push(v);
186        self.serialize_str(&s)
187    }
188
189    fn serialize_str(self, v: &str) -> std::result::Result<rmpv::Value, Self::Error> {
190        Ok(rmpv::Value::String(rmpv::Utf8String::from(v)))
191    }
192
193    fn serialize_bytes(self, v: &[u8]) -> std::result::Result<rmpv::Value, Self::Error> {
194        Ok(rmpv::Value::Binary(v.to_vec()))
195    }
196
197    fn serialize_none(self) -> std::result::Result<rmpv::Value, Self::Error> {
198        Ok(rmpv::Value::Nil)
199    }
200
201    fn serialize_some<T: ?Sized + serde::Serialize>(
202        self,
203        value: &T,
204    ) -> std::result::Result<rmpv::Value, Self::Error> {
205        value.serialize(self)
206    }
207
208    fn serialize_unit(self) -> std::result::Result<rmpv::Value, Self::Error> {
209        Ok(rmpv::Value::Nil)
210    }
211
212    fn serialize_unit_struct(
213        self,
214        _name: &'static str,
215    ) -> std::result::Result<rmpv::Value, Self::Error> {
216        Ok(rmpv::Value::Nil)
217    }
218
219    fn serialize_unit_variant(
220        self,
221        _name: &'static str,
222        idx: u32,
223        _variant: &'static str,
224    ) -> std::result::Result<rmpv::Value, Self::Error> {
225        Ok(rmpv::Value::Integer(rmpv::Integer::from(idx)))
226    }
227
228    fn serialize_newtype_struct<T: ?Sized + serde::Serialize>(
229        self,
230        _name: &'static str,
231        value: &T,
232    ) -> std::result::Result<rmpv::Value, Self::Error> {
233        value.serialize(self)
234    }
235
236    fn serialize_newtype_variant<T: ?Sized + serde::Serialize>(
237        self,
238        _name: &'static str,
239        idx: u32,
240        _variant: &'static str,
241        value: &T,
242    ) -> std::result::Result<rmpv::Value, Self::Error> {
243        let inner = value.serialize(NamedSerializer)?;
244        Ok(rmpv::Value::Map(vec![(
245            rmpv::Value::Integer(rmpv::Integer::from(idx)),
246            inner,
247        )]))
248    }
249
250    fn serialize_seq(
251        self,
252        len: Option<usize>,
253    ) -> std::result::Result<Self::SerializeSeq, Self::Error> {
254        Ok(SerializeVec {
255            vec: Vec::with_capacity(len.unwrap_or(0)),
256        })
257    }
258
259    fn serialize_tuple(self, len: usize) -> std::result::Result<Self::SerializeTuple, Self::Error> {
260        self.serialize_seq(Some(len))
261    }
262
263    fn serialize_tuple_struct(
264        self,
265        _name: &'static str,
266        len: usize,
267    ) -> std::result::Result<Self::SerializeTupleStruct, Self::Error> {
268        self.serialize_seq(Some(len))
269    }
270
271    fn serialize_tuple_variant(
272        self,
273        _name: &'static str,
274        idx: u32,
275        _variant: &'static str,
276        len: usize,
277    ) -> std::result::Result<Self::SerializeTupleVariant, Self::Error> {
278        Ok(SerializeTupleVariant {
279            idx,
280            vec: Vec::with_capacity(len),
281        })
282    }
283
284    fn serialize_map(
285        self,
286        len: Option<usize>,
287    ) -> std::result::Result<Self::SerializeMap, Self::Error> {
288        Ok(SerializeMap {
289            entries: Vec::with_capacity(len.unwrap_or(0)),
290            cur_key: None,
291        })
292    }
293
294    fn serialize_struct(
295        self,
296        _name: &'static str,
297        len: usize,
298    ) -> std::result::Result<Self::SerializeStruct, Self::Error> {
299        self.serialize_map(Some(len))
300    }
301
302    fn serialize_struct_variant(
303        self,
304        _name: &'static str,
305        idx: u32,
306        _variant: &'static str,
307        len: usize,
308    ) -> std::result::Result<Self::SerializeStructVariant, Self::Error> {
309        Ok(SerializeStructVariant {
310            idx,
311            entries: Vec::with_capacity(len),
312        })
313    }
314}
315
316struct SerializeVec {
317    vec: Vec<rmpv::Value>,
318}
319
320impl serde::ser::SerializeSeq for SerializeVec {
321    type Ok = rmpv::Value;
322    type Error = rmpv::ext::Error;
323
324    fn serialize_element<T: ?Sized + serde::Serialize>(
325        &mut self,
326        value: &T,
327    ) -> std::result::Result<(), Self::Error> {
328        self.vec.push(value.serialize(NamedSerializer)?);
329        Ok(())
330    }
331
332    fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
333        Ok(rmpv::Value::Array(self.vec))
334    }
335}
336
337impl serde::ser::SerializeTuple for SerializeVec {
338    type Ok = rmpv::Value;
339    type Error = rmpv::ext::Error;
340
341    fn serialize_element<T: ?Sized + serde::Serialize>(
342        &mut self,
343        value: &T,
344    ) -> std::result::Result<(), Self::Error> {
345        serde::ser::SerializeSeq::serialize_element(self, value)
346    }
347
348    fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
349        serde::ser::SerializeSeq::end(self)
350    }
351}
352
353impl serde::ser::SerializeTupleStruct for SerializeVec {
354    type Ok = rmpv::Value;
355    type Error = rmpv::ext::Error;
356
357    fn serialize_field<T: ?Sized + serde::Serialize>(
358        &mut self,
359        value: &T,
360    ) -> std::result::Result<(), Self::Error> {
361        serde::ser::SerializeSeq::serialize_element(self, value)
362    }
363
364    fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
365        serde::ser::SerializeSeq::end(self)
366    }
367}
368
369struct SerializeTupleVariant {
370    idx: u32,
371    vec: Vec<rmpv::Value>,
372}
373
374impl serde::ser::SerializeTupleVariant for SerializeTupleVariant {
375    type Ok = rmpv::Value;
376    type Error = rmpv::ext::Error;
377
378    fn serialize_field<T: ?Sized + serde::Serialize>(
379        &mut self,
380        value: &T,
381    ) -> std::result::Result<(), Self::Error> {
382        self.vec.push(value.serialize(NamedSerializer)?);
383        Ok(())
384    }
385
386    fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
387        Ok(rmpv::Value::Map(vec![(
388            rmpv::Value::Integer(rmpv::Integer::from(self.idx)),
389            rmpv::Value::Array(self.vec),
390        )]))
391    }
392}
393
394struct SerializeMap {
395    entries: Vec<(rmpv::Value, rmpv::Value)>,
396    cur_key: Option<rmpv::Value>,
397}
398
399impl serde::ser::SerializeMap for SerializeMap {
400    type Ok = rmpv::Value;
401    type Error = rmpv::ext::Error;
402
403    fn serialize_key<T: ?Sized + serde::Serialize>(
404        &mut self,
405        key: &T,
406    ) -> std::result::Result<(), Self::Error> {
407        self.cur_key = Some(key.serialize(NamedSerializer)?);
408        Ok(())
409    }
410
411    fn serialize_value<T: ?Sized + serde::Serialize>(
412        &mut self,
413        value: &T,
414    ) -> std::result::Result<(), Self::Error> {
415        let key = self.cur_key.take().ok_or_else(|| {
416            <Self::Error as serde::ser::Error>::custom(
417                "serialize_value called before serialize_key",
418            )
419        })?;
420        self.entries.push((key, value.serialize(NamedSerializer)?));
421        Ok(())
422    }
423
424    fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
425        Ok(rmpv::Value::Map(self.entries))
426    }
427}
428
429impl serde::ser::SerializeStruct for SerializeMap {
430    type Ok = rmpv::Value;
431    type Error = rmpv::ext::Error;
432
433    fn serialize_field<T: ?Sized + serde::Serialize>(
434        &mut self,
435        key: &'static str,
436        value: &T,
437    ) -> std::result::Result<(), Self::Error> {
438        let k = rmpv::Value::String(rmpv::Utf8String::from(key));
439        let v = value.serialize(NamedSerializer)?;
440        self.entries.push((k, v));
441        Ok(())
442    }
443
444    fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
445        Ok(rmpv::Value::Map(self.entries))
446    }
447}
448
449struct SerializeStructVariant {
450    idx: u32,
451    entries: Vec<(rmpv::Value, rmpv::Value)>,
452}
453
454impl serde::ser::SerializeStructVariant for SerializeStructVariant {
455    type Ok = rmpv::Value;
456    type Error = rmpv::ext::Error;
457
458    fn serialize_field<T: ?Sized + serde::Serialize>(
459        &mut self,
460        key: &'static str,
461        value: &T,
462    ) -> std::result::Result<(), Self::Error> {
463        let k = rmpv::Value::String(rmpv::Utf8String::from(key));
464        let v = value.serialize(NamedSerializer)?;
465        self.entries.push((k, v));
466        Ok(())
467    }
468
469    fn end(self) -> std::result::Result<rmpv::Value, Self::Error> {
470        Ok(rmpv::Value::Map(vec![(
471            rmpv::Value::Integer(rmpv::Integer::from(self.idx)),
472            rmpv::Value::Map(self.entries),
473        )]))
474    }
475}
476
477/// Client-side request sent to a miner via [`LightningClient::query_axon`](crate::LightningClient::query_axon).
478#[derive(Debug, Clone, Serialize, Deserialize)]
479pub struct QuicRequest {
480    /// Handler name the server dispatches on (e.g. `"MyQuery"`).
481    pub synapse_type: String,
482    /// Arbitrary MessagePack key-value payload deserialized by the handler.
483    pub data: HashMap<String, rmpv::Value>,
484}
485
486impl QuicRequest {
487    pub fn new(synapse_type: String, data: HashMap<String, rmpv::Value>) -> Self {
488        Self { synapse_type, data }
489    }
490
491    /// Constructs a request by serializing a typed struct into the data map.
492    pub fn from_typed<T: serde::Serialize>(
493        synapse_type: impl Into<String>,
494        data: &T,
495    ) -> Result<Self> {
496        Ok(Self {
497            synapse_type: synapse_type.into(),
498            data: serialize_to_rmpv_map(data)?,
499        })
500    }
501}
502
503/// Response returned from a miner after processing a synapse request.
504#[derive(Debug, Clone, Serialize, Deserialize)]
505pub struct QuicResponse {
506    /// Whether the handler completed without error.
507    pub success: bool,
508    /// Handler return payload as MessagePack key-value pairs.
509    pub data: HashMap<String, rmpv::Value>,
510    /// Client-measured round-trip latency in milliseconds.
511    pub latency_ms: f64,
512    /// Error message when `success` is false.
513    #[serde(default)]
514    pub error: Option<String>,
515}
516
517impl QuicResponse {
518    /// Converts to `Result`, returning `Err` when `success` is false.
519    pub fn into_result(self) -> Result<Self> {
520        if self.success {
521            Ok(self)
522        } else {
523            Err(LightningError::Handler(
524                self.error.unwrap_or_else(|| "request failed".into()),
525            ))
526        }
527    }
528
529    /// Deserializes the `data` map into a typed struct `T`.
530    pub fn deserialize_data<T: serde::de::DeserializeOwned>(&self) -> Result<T> {
531        let map_value = hashmap_to_rmpv_map(self.data.clone());
532        rmpv::ext::from_value(map_value).map_err(|e| LightningError::Serialization(e.to_string()))
533    }
534}
535
/// Validator-to-miner handshake initiation sent over a new QUIC stream.
///
/// Field order and names are part of the serialized wire format; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeRequest {
    /// SS58 hotkey of the connecting validator.
    pub validator_hotkey: String,
    /// UNIX epoch seconds at signing time.
    pub timestamp: u64,
    /// 128-bit hex-encoded cryptographic nonce (replay protection).
    pub nonce: String,
    /// Base64-encoded sr25519 signature over `handshake:<hotkey>:<ts>:<nonce>:<cert_fp>`,
    /// the layout produced by `handshake_request_message`.
    pub signature: String,
}

/// Miner-to-validator handshake reply confirming or rejecting authentication.
///
/// Field order and names are part of the serialized wire format; do not reorder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeResponse {
    /// SS58 hotkey of the miner.
    pub miner_hotkey: String,
    /// UNIX epoch seconds at signing time.
    pub timestamp: u64,
    /// Base64-encoded sr25519 signature over the response message.
    /// NOTE(review): presumably the string built by `handshake_response_message`; confirm at
    /// the signing/verification call sites.
    pub signature: String,
    /// Whether the miner accepted the validator's handshake.
    pub accepted: bool,
    /// Opaque identifier for this authenticated connection.
    pub connection_id: String,
    /// Base64-encoded BLAKE2-256 hash of the TLS certificate DER bytes.
    /// `#[serde(default)]` makes this `None` when the field is absent from the encoded message.
    #[serde(default)]
    pub cert_fingerprint: Option<String>,
}
566
567/// Wire-level synapse request frame sent after handshake authentication.
568#[derive(Debug, Clone, Serialize, Deserialize)]
569pub struct SynapsePacket {
570    pub synapse_type: String,
571    pub data: HashMap<String, rmpv::Value>,
572    pub timestamp: u64,
573}
574
575/// Wire-level synapse response frame returned by the miner handler.
576#[derive(Debug, Clone, Serialize, Deserialize)]
577pub struct SynapseResponse {
578    pub success: bool,
579    pub data: HashMap<String, rmpv::Value>,
580    pub timestamp: u64,
581    pub error: Option<String>,
582}
583
584#[derive(Debug, Clone, Serialize, Deserialize)]
585pub struct StreamChunk {
586    #[serde(with = "serde_bytes")]
587    pub data: Vec<u8>,
588}
589
590#[derive(Debug, Clone, Serialize, Deserialize)]
591pub struct StreamEnd {
592    pub success: bool,
593    pub error: Option<String>,
594}
595
596/// Single-byte discriminant in the 5-byte frame header identifying the payload kind.
597#[repr(u8)]
598#[derive(Debug, Clone, Copy, PartialEq)]
599pub enum MessageType {
600    HandshakeRequest = 0x01,
601    HandshakeResponse = 0x02,
602    SynapsePacket = 0x03,
603    SynapseResponse = 0x04,
604    StreamChunk = 0x05,
605    StreamEnd = 0x06,
606}
607
608impl TryFrom<u8> for MessageType {
609    type Error = LightningError;
610
611    fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
612        match value {
613            0x01 => Ok(MessageType::HandshakeRequest),
614            0x02 => Ok(MessageType::HandshakeResponse),
615            0x03 => Ok(MessageType::SynapsePacket),
616            0x04 => Ok(MessageType::SynapseResponse),
617            0x05 => Ok(MessageType::StreamChunk),
618            0x06 => Ok(MessageType::StreamEnd),
619            _ => Err(LightningError::Transport(format!(
620                "unknown message type: 0x{:02x}",
621                value
622            ))),
623        }
624    }
625}
626
627const FRAME_HEADER_SIZE: usize = 5;
628
629async fn read_exact_from_recv(recv: &mut RecvStream, buf: &mut [u8]) -> Result<()> {
630    let mut offset = 0;
631    while offset < buf.len() {
632        match recv.read(&mut buf[offset..]).await {
633            Ok(Some(n)) => offset += n,
634            Ok(None) => {
635                return Err(LightningError::Transport(format!(
636                    "stream closed after {} of {} bytes",
637                    offset,
638                    buf.len()
639                )));
640            }
641            Err(e) => {
642                return Err(LightningError::Transport(format!("read error: {}", e)));
643            }
644        }
645    }
646    Ok(())
647}
648
649const INCREMENTAL_READ_THRESHOLD: usize = 1_048_576;
650const READ_CHUNK_SIZE: usize = 65_536;
651
652/// Parses the 5-byte frame header (`[msg_type, payload_len_be32]`) from a byte slice
653/// and returns the message type and payload sub-slice.
654pub fn parse_frame_header(data: &[u8], max_payload: usize) -> Result<(MessageType, &[u8])> {
655    if data.len() < FRAME_HEADER_SIZE {
656        return Err(LightningError::Transport(
657            "insufficient data for frame header".to_string(),
658        ));
659    }
660    let msg_type = MessageType::try_from(data[0])?;
661    let payload_len = u32::from_be_bytes([data[1], data[2], data[3], data[4]]) as usize;
662    if payload_len > max_payload {
663        return Err(LightningError::Transport(format!(
664            "frame payload {} bytes exceeds maximum {}",
665            payload_len, max_payload
666        )));
667    }
668    if data.len() < FRAME_HEADER_SIZE + payload_len {
669        return Err(LightningError::Transport(format!(
670            "insufficient data for frame payload: have {}, need {}",
671            data.len() - FRAME_HEADER_SIZE,
672            payload_len
673        )));
674    }
675    Ok((
676        msg_type,
677        &data[FRAME_HEADER_SIZE..FRAME_HEADER_SIZE + payload_len],
678    ))
679}
680
681pub async fn read_frame(
682    recv: &mut RecvStream,
683    max_payload: usize,
684) -> Result<(MessageType, Vec<u8>)> {
685    let mut header = [0u8; FRAME_HEADER_SIZE];
686    read_exact_from_recv(recv, &mut header).await?;
687
688    let msg_type = MessageType::try_from(header[0])?;
689    let payload_len = u32::from_be_bytes([header[1], header[2], header[3], header[4]]) as usize;
690
691    if payload_len > max_payload {
692        return Err(LightningError::Transport(format!(
693            "frame payload {} bytes exceeds maximum {}",
694            payload_len, max_payload
695        )));
696    }
697
698    if payload_len == 0 {
699        return Ok((msg_type, Vec::new()));
700    }
701
702    // Payloads up to 1MB are allocated in one shot. Larger payloads are read
703    // incrementally in 64KB chunks to bound peak memory when the declared
704    // payload_len has not yet been validated against actual stream data.
705    if payload_len <= INCREMENTAL_READ_THRESHOLD {
706        let mut payload = vec![0u8; payload_len];
707        read_exact_from_recv(recv, &mut payload).await?;
708        return Ok((msg_type, payload));
709    }
710
711    let mut payload = Vec::with_capacity(INCREMENTAL_READ_THRESHOLD);
712    let mut remaining = payload_len;
713    while remaining > 0 {
714        let next_capacity = payload
715            .capacity()
716            .saturating_mul(2)
717            .max(INCREMENTAL_READ_THRESHOLD)
718            .min(payload_len)
719            .min(max_payload);
720        if payload.capacity() < next_capacity {
721            payload.reserve(next_capacity - payload.len());
722        }
723        let chunk_size = remaining.min(READ_CHUNK_SIZE);
724        let start = payload.len();
725        payload.resize(start + chunk_size, 0);
726        read_exact_from_recv(recv, &mut payload[start..]).await?;
727        remaining -= chunk_size;
728    }
729    Ok((msg_type, payload))
730}
731
732pub async fn write_frame(
733    send: &mut SendStream,
734    msg_type: MessageType,
735    payload: &[u8],
736) -> Result<()> {
737    let payload_len: u32 = payload.len().try_into().map_err(|_| {
738        LightningError::Transport(format!(
739            "frame payload {} bytes exceeds u32::MAX",
740            payload.len()
741        ))
742    })?;
743
744    let mut header = [0u8; FRAME_HEADER_SIZE];
745    header[0] = msg_type as u8;
746    header[1..5].copy_from_slice(&payload_len.to_be_bytes());
747
748    send.write_all(&header)
749        .await
750        .map_err(|e| LightningError::Transport(format!("failed to write frame header: {}", e)))?;
751    if !payload.is_empty() {
752        send.write_all(payload).await.map_err(|e| {
753            LightningError::Transport(format!("failed to write frame payload: {}", e))
754        })?;
755    }
756    Ok(())
757}
758
759pub async fn write_frame_and_finish(
760    send: &mut SendStream,
761    msg_type: MessageType,
762    payload: &[u8],
763) -> Result<()> {
764    write_frame(send, msg_type, payload).await?;
765    send.finish()
766        .map_err(|e| LightningError::Transport(format!("failed to finish stream: {}", e)))?;
767    Ok(())
768}
769
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn quic_request_from_typed_serializes_struct() {
        #[derive(serde::Serialize)]
        struct MyReq {
            name: String,
            count: u32,
        }

        let req = QuicRequest::from_typed(
            "test_synapse",
            &MyReq {
                name: "hello".into(),
                count: 42,
            },
        )
        .unwrap();

        assert_eq!(req.synapse_type, "test_synapse");
        assert_eq!(
            req.data.get("name").unwrap(),
            &rmpv::Value::String("hello".into())
        );
        assert_eq!(
            req.data.get("count").unwrap(),
            &rmpv::Value::Integer(42.into())
        );
    }

    #[test]
    fn quic_response_into_result_ok_on_success() {
        let resp = QuicResponse {
            success: true,
            data: HashMap::new(),
            latency_ms: 1.0,
            error: None,
        };
        assert!(resp.into_result().is_ok());
    }

    #[test]
    fn quic_response_into_result_err_on_failure() {
        let resp = QuicResponse {
            success: false,
            data: HashMap::new(),
            latency_ms: 1.0,
            error: Some("bad request".into()),
        };
        let err = resp.into_result().unwrap_err();
        assert!(err.to_string().contains("bad request"));
    }

    #[test]
    fn quic_response_into_result_uses_default_message() {
        let resp = QuicResponse {
            success: false,
            data: HashMap::new(),
            latency_ms: 1.0,
            error: None,
        };
        let err = resp.into_result().unwrap_err();
        assert!(err.to_string().contains("request failed"));
    }

    #[test]
    fn quic_response_deserialize_data_roundtrips() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct MyResp {
            value: i32,
            label: String,
        }

        let original = MyResp {
            value: 99,
            label: "test".into(),
        };

        let data = serialize_to_rmpv_map(&original).unwrap();

        let resp = QuicResponse {
            success: true,
            data,
            latency_ms: 1.0,
            error: None,
        };

        let deserialized: MyResp = resp.deserialize_data().unwrap();
        assert_eq!(deserialized, original);
    }

    #[test]
    fn serialize_to_rmpv_map_handles_nested_structs() {
        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct Inner {
            x: i32,
            y: String,
        }

        #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
        struct Outer {
            name: String,
            inner: Inner,
            values: Vec<u32>,
        }

        let original = Outer {
            name: "test".into(),
            inner: Inner {
                x: 42,
                y: "nested".into(),
            },
            values: vec![1, 2, 3],
        };

        let map = serialize_to_rmpv_map(&original).unwrap();
        assert_eq!(
            map.get("name").unwrap(),
            &rmpv::Value::String("test".into())
        );
        assert!(matches!(map.get("inner").unwrap(), rmpv::Value::Map(_)));
        assert!(matches!(map.get("values").unwrap(), rmpv::Value::Array(_)));

        let resp = QuicResponse {
            success: true,
            data: map,
            latency_ms: 0.0,
            error: None,
        };
        let deserialized: Outer = resp.deserialize_data().unwrap();
        assert_eq!(deserialized, original);
    }

    /// Only values that serialize to a map can be flattened into a field map.
    #[test]
    fn serialize_to_rmpv_map_rejects_non_map_values() {
        assert!(serialize_to_rmpv_map(&42u32).is_err());
        assert!(serialize_to_rmpv_map(&vec![1u8, 2, 3]).is_err());
    }

    #[test]
    fn handshake_request_message_format() {
        let msg = handshake_request_message("5GrwvaEF", 1234567890, "abc123", "fp_b64");
        assert_eq!(msg, "handshake:5GrwvaEF:1234567890:abc123:fp_b64");
    }

    #[test]
    fn handshake_response_message_format() {
        let msg =
            handshake_response_message("5GrwvaEF", "5FHneW46", 1234567890, "abc123", "fp_b64");
        assert_eq!(
            msg,
            "handshake_response:5GrwvaEF:5FHneW46:1234567890:abc123:fp_b64"
        );
    }

    /// IPv4 hosts are joined as-is; hosts containing ':' (IPv6) are bracketed.
    #[test]
    fn peer_addr_formats_ipv4_and_brackets_ipv6() {
        assert_eq!(PeerAddr::new("10.0.0.1", 8091).to_string(), "10.0.0.1:8091");
        let v6 = PeerAddr::new("::1", 8091);
        let v6_str: &str = v6.as_ref();
        assert_eq!(v6_str, "[::1]:8091");

        let info = QuicAxonInfo::new("5GrwvaEF".into(), "127.0.0.1".into(), 9000, 4);
        assert_eq!(info.addr_key(), PeerAddr::new("127.0.0.1", 9000));
    }

    /// Every variant's wire byte maps back to the same variant; unknown bytes are rejected.
    #[test]
    fn message_type_roundtrips_through_u8() {
        for ty in [
            MessageType::HandshakeRequest,
            MessageType::HandshakeResponse,
            MessageType::SynapsePacket,
            MessageType::SynapseResponse,
            MessageType::StreamChunk,
            MessageType::StreamEnd,
        ] {
            assert_eq!(MessageType::try_from(ty as u8).unwrap(), ty);
        }
        assert!(MessageType::try_from(0x00).is_err());
        assert!(MessageType::try_from(0x07).is_err());
    }

    #[test]
    fn parse_frame_header_valid() {
        let mut data = vec![0x01];
        data.extend_from_slice(&5u32.to_be_bytes());
        data.extend_from_slice(b"hello");
        let (msg_type, payload) = parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).unwrap();
        assert_eq!(msg_type, MessageType::HandshakeRequest);
        assert_eq!(payload, b"hello");
    }

    /// Zero-length payloads are valid, and bytes past the declared payload are ignored.
    #[test]
    fn parse_frame_header_zero_length_and_trailing_bytes() {
        let mut data = vec![MessageType::StreamEnd as u8];
        data.extend_from_slice(&0u32.to_be_bytes());
        let (msg_type, payload) = parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).unwrap();
        assert_eq!(msg_type, MessageType::StreamEnd);
        assert!(payload.is_empty());

        let mut data = vec![MessageType::StreamChunk as u8];
        data.extend_from_slice(&3u32.to_be_bytes());
        data.extend_from_slice(b"abcdef");
        let (msg_type, payload) = parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).unwrap();
        assert_eq!(msg_type, MessageType::StreamChunk);
        assert_eq!(payload, b"abc");
    }

    /// The caller-supplied limit is inclusive.
    #[test]
    fn parse_frame_header_respects_custom_max() {
        let mut data = vec![0x03];
        data.extend_from_slice(&4u32.to_be_bytes());
        data.extend_from_slice(b"data");
        assert!(parse_frame_header(&data, 3).is_err());
        assert!(parse_frame_header(&data, 4).is_ok());
    }

    #[test]
    fn parse_frame_header_insufficient_header() {
        assert!(parse_frame_header(&[0x01, 0x00], DEFAULT_MAX_FRAME_PAYLOAD).is_err());
    }

    #[test]
    fn parse_frame_header_insufficient_payload() {
        let mut data = vec![0x01];
        data.extend_from_slice(&10u32.to_be_bytes());
        data.extend_from_slice(b"short");
        assert!(parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).is_err());
    }

    #[test]
    fn parse_frame_header_oversized_payload() {
        let mut data = vec![0x01];
        data.extend_from_slice(&(DEFAULT_MAX_FRAME_PAYLOAD as u32 + 1).to_be_bytes());
        assert!(parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).is_err());
    }

    #[test]
    fn parse_frame_header_invalid_message_type() {
        let mut data = vec![0xFF];
        data.extend_from_slice(&0u32.to_be_bytes());
        assert!(parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD).is_err());
    }

    use proptest::prelude::*;

    /// Non-container MessagePack values.
    fn arb_rmpv_leaf() -> impl Strategy<Value = rmpv::Value> {
        prop_oneof![
            Just(rmpv::Value::Nil),
            any::<bool>().prop_map(rmpv::Value::Boolean),
            any::<i64>().prop_map(|v| rmpv::Value::Integer(rmpv::Integer::from(v))),
            any::<u64>().prop_map(|v| rmpv::Value::Integer(rmpv::Integer::from(v))),
            any::<f32>().prop_map(rmpv::Value::F32),
            any::<f64>().prop_map(rmpv::Value::F64),
            "[a-zA-Z0-9_ ]{0,32}"
                .prop_map(|s| rmpv::Value::String(rmpv::Utf8String::from(s.as_str()))),
            proptest::collection::vec(any::<u8>(), 0..64).prop_map(rmpv::Value::Binary),
        ]
    }

    /// Leaves nested up to three levels deep in arrays and maps.
    fn arb_rmpv_value() -> impl Strategy<Value = rmpv::Value> {
        arb_rmpv_leaf().prop_recursive(3, 32, 8, |inner| {
            prop_oneof![
                proptest::collection::vec(inner.clone(), 0..8).prop_map(rmpv::Value::Array),
                proptest::collection::vec((inner.clone(), inner), 0..8).prop_map(rmpv::Value::Map),
            ]
        })
    }

    proptest! {
        #[test]
        fn msgpack_encode_decode_roundtrip(value in arb_rmpv_value()) {
            let bytes = rmp_serde::to_vec(&value).unwrap();
            let decoded: rmpv::Value = rmp_serde::from_slice(&bytes).unwrap();
            prop_assert_eq!(value, decoded);
        }

        // Exercises `hashmap_to_rmpv_map` followed by a MessagePack encode/decode cycle.
        #[test]
        fn hashmap_to_rmpv_map_msgpack_roundtrip(
            keys in proptest::collection::vec("[a-z]{1,8}", 1..8),
            vals in proptest::collection::vec(arb_rmpv_leaf(), 1..8),
        ) {
            let mut map = HashMap::new();
            for (k, v) in keys.into_iter().zip(vals.into_iter()) {
                map.insert(k, v);
            }
            let rmpv_map = hashmap_to_rmpv_map(map);
            let bytes = rmp_serde::to_vec(&rmpv_map).unwrap();
            let decoded: rmpv::Value = rmp_serde::from_slice(&bytes).unwrap();
            prop_assert_eq!(rmpv_map, decoded);
        }

        #[test]
        fn from_typed_roundtrip(
            name in "[a-zA-Z]{1,16}",
            count in any::<u32>(),
            label in "[a-zA-Z0-9 ]{0,32}",
        ) {
            #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)]
            struct TestStruct {
                name: String,
                count: u32,
                label: String,
            }

            let original = TestStruct {
                name: name.clone(),
                count,
                label: label.clone(),
            };

            let req = QuicRequest::from_typed("test", &original).unwrap();
            let resp = QuicResponse {
                success: true,
                data: req.data,
                latency_ms: 0.0,
                error: None,
            };
            let deserialized: TestStruct = resp.deserialize_data().unwrap();
            prop_assert_eq!(original, deserialized);
        }

        #[test]
        fn parse_frame_header_never_panics(data in proptest::collection::vec(any::<u8>(), 0..256)) {
            let _ = parse_frame_header(&data, DEFAULT_MAX_FRAME_PAYLOAD);
        }
    }
}