Skip to main content

btlv/
lib.rs

#[cfg(feature = "bigsize")]
pub mod bigsize;
#[cfg(not(feature = "bigsize"))]
pub(crate) mod bigsize;
pub(crate) mod encoding;
pub mod error;
pub mod stream;

#[cfg(feature = "serde")]
mod serde;

pub use error::{Result, TlvError};
pub use stream::{TlvRecord, TlvStream};

// Re-export of the `serde` crate for code generated by `tlv_struct!`, so the
// generated impls do not require a direct `serde` dependency in the invoking
// crate. Not part of the public API.
#[cfg(feature = "serde")]
#[doc(hidden)]
pub use ::serde as __serde;

15// Helper traits used by the tlv_struct! macro. Not part of the public API.
16#[doc(hidden)]
17pub mod _macro_support {
18    use crate::error::{Result, TlvError};
19
20    // Re-export for integration tests; not part of the public API.
21    pub mod bigsize {
22        pub use crate::bigsize::{decode, encode};
23    }
24    pub mod encoding {
25        pub use crate::encoding::{decode_tu64, encode_tu64};
26    }
27
28    pub trait TlvTu64Encode {
29        fn to_tu64_value(&self) -> u64;
30    }
31    impl TlvTu64Encode for u64 {
32        fn to_tu64_value(&self) -> u64 {
33            *self
34        }
35    }
36    impl TlvTu64Encode for u32 {
37        fn to_tu64_value(&self) -> u64 {
38            *self as u64
39        }
40    }
41
42    pub trait TlvTu64Decode: Sized {
43        fn from_tu64_value(v: u64) -> Result<Self>;
44    }
45    impl TlvTu64Decode for u64 {
46        fn from_tu64_value(v: u64) -> Result<Self> {
47            Ok(v)
48        }
49    }
50    impl TlvTu64Decode for u32 {
51        fn from_tu64_value(v: u64) -> Result<Self> {
52            u32::try_from(v).map_err(|_| TlvError::Overflow)
53        }
54    }
55
56    pub trait TlvBytesEncode {
57        fn to_tlv_vec(&self) -> Vec<u8>;
58    }
59    impl TlvBytesEncode for Vec<u8> {
60        fn to_tlv_vec(&self) -> Vec<u8> {
61            self.clone()
62        }
63    }
64    impl<const N: usize> TlvBytesEncode for [u8; N] {
65        fn to_tlv_vec(&self) -> Vec<u8> {
66            self.to_vec()
67        }
68    }
69
70    pub trait TlvBytesDecode: Sized {
71        fn from_tlv_raw(raw: &[u8], type_num: u64) -> Result<Self>;
72    }
73    impl TlvBytesDecode for Vec<u8> {
74        fn from_tlv_raw(raw: &[u8], _type_num: u64) -> Result<Self> {
75            Ok(raw.to_vec())
76        }
77    }
78    impl<const N: usize> TlvBytesDecode for [u8; N] {
79        fn from_tlv_raw(raw: &[u8], type_num: u64) -> Result<Self> {
80            raw.try_into().map_err(|_| TlvError::InvalidLength {
81                type_: type_num,
82                expected: N,
83                actual: raw.len(),
84            })
85        }
86    }
87}
88
/// Internal helper for `tlv_struct!`: expands to `Serialize`/`Deserialize`
/// impls for `$name` when btlv is built with its `serde` feature.
///
/// The feature check must happen here, at the definition site. A
/// `#[cfg(feature = "serde")]` written inside `tlv_struct!`'s expansion would
/// be evaluated against the features of the crate that *invokes* the macro,
/// so downstream structs would silently miss these impls (or fail to compile
/// when that crate has a `serde` feature but no `serde` dependency).
#[cfg(feature = "serde")]
#[doc(hidden)]
#[macro_export]
macro_rules! __tlv_struct_serde_impls {
    ($name:ident) => {
        impl $crate::__serde::Serialize for $name {
            fn serialize<S: $crate::__serde::Serializer>(
                &self,
                serializer: S,
            ) -> ::core::result::Result<S::Ok, S::Error> {
                let stream = <$crate::TlvStream as ::core::convert::From<&$name>>::from(self);
                $crate::__serde::Serialize::serialize(&stream, serializer)
            }
        }

        impl<'de> $crate::__serde::Deserialize<'de> for $name {
            fn deserialize<D: $crate::__serde::Deserializer<'de>>(
                deserializer: D,
            ) -> ::core::result::Result<Self, D::Error> {
                let stream = <$crate::TlvStream as $crate::__serde::Deserialize<'de>>::deserialize(
                    deserializer,
                )?;
                <$name as ::core::convert::TryFrom<&$crate::TlvStream>>::try_from(&stream)
                    .map_err($crate::__serde::de::Error::custom)
            }
        }
    };
}

/// Internal helper for `tlv_struct!`: no-op variant used when btlv is built
/// without its `serde` feature.
#[cfg(not(feature = "serde"))]
#[doc(hidden)]
#[macro_export]
macro_rules! __tlv_struct_serde_impls {
    ($name:ident) => {};
}

/// Declare a Rust struct that maps to/from a TLV stream.
///
/// Besides the struct itself (which derives `Debug`, `Clone`, `PartialEq` and
/// `Eq`), the macro generates `to_tlv_bytes` / `from_tlv_bytes` methods, a
/// `From<&Struct> for TlvStream` impl and a `TryFrom<&TlvStream> for Struct`
/// impl.
///
/// # Encoding tags
/// - `tu64` — variable-length minimal big-endian integer (field type: `u64` or `u32`)
/// - `u64`  — fixed 8-byte big-endian (field type: `u64`)
/// - `bytes` — raw bytes (field type: `Vec<u8>`, `[u8; N]`, or `Option` variants)
///
/// Fields typed as `Option<T>` are automatically optional: omitted when `None`,
/// decoded as `None` when absent from the stream. The type must be spelled
/// literally as `Option<...>`; a path such as `std::option::Option<T>` is
/// treated as a required field.
///
/// Each field carries exactly one `#[tlv(type, encoding)]` attribute, which
/// may be preceded by doc comments; other field attributes are not accepted.
/// The trailing comma after the last field is optional.
///
/// # Decoding errors
/// `from_tlv_bytes` and `TryFrom<&TlvStream>` fail with:
/// - `TlvError::MissingRequired(type)` when a non-`Option` field is absent,
/// - `TlvError::Overflow` when a `tu64` value does not fit a `u32` field,
/// - `TlvError::InvalidLength { .. }` when a `bytes` record does not match the
///   length of a fixed-size `[u8; N]` field,
///
/// plus any error reported while parsing the stream itself. Record types that
/// are present in the stream but not declared in the struct are ignored, even
/// if they are even; callers that must reject unknown even types (BOLT #1,
/// "it's OK to be odd") have to inspect the `TlvStream` themselves.
///
/// # Example
/// ```
/// btlv::tlv_struct! {
///     pub struct OnionPayload {
///         #[tlv(2, tu64)]
///         pub amt_to_forward: u64,
///         #[tlv(4, tu64)]
///         pub outgoing_cltv_value: u32,
///         #[tlv(6, bytes)]
///         pub short_channel_id: Option<[u8; 8]>,
///     }
/// }
///
/// let payload = OnionPayload {
///     amt_to_forward: 1000,
///     outgoing_cltv_value: 800000,
///     short_channel_id: None,
/// };
/// let bytes = payload.to_tlv_bytes().unwrap();
/// let decoded = OnionPayload::from_tlv_bytes(&bytes).unwrap();
/// assert_eq!(decoded.amt_to_forward, 1000);
/// assert_eq!(decoded.outgoing_cltv_value, 800000);
/// assert_eq!(decoded.short_channel_id, None);
/// ```
///
/// The last field may omit its trailing comma:
///
/// ```
/// btlv::tlv_struct! {
///     pub struct Ping {
///         #[tlv(1, tu64)]
///         pub nonce: u64
///     }
/// }
///
/// let ping = Ping { nonce: 7 };
/// let bytes = ping.to_tlv_bytes().unwrap();
/// assert_eq!(Ping::from_tlv_bytes(&bytes).unwrap(), ping);
/// ```
///
/// # Serde support
///
/// When btlv is built with its `serde` feature (on by default), macro-generated
/// structs implement `Serialize` and `Deserialize`. The wire-format bytes are
/// encoded as a hex string, matching `TlvStream`'s serde representation. The
/// impls are gated on btlv's own feature, so the invoking crate needs neither
/// a `serde` feature of its own nor a direct `serde` dependency for them to be
/// generated.
///
/// ```
/// # #[cfg(feature = "serde")] {
/// btlv::tlv_struct! {
///     pub struct Invoice {
///         #[tlv(2, tu64)]
///         pub amount_msat: u64,
///         #[tlv(4, tu64)]
///         pub expiry: u32,
///     }
/// }
///
/// let inv = Invoice { amount_msat: 50_000, expiry: 3600 };
///
/// // Serialize to JSON (the TLV bytes become a hex string)
/// let json = serde_json::to_string(&inv).unwrap();
///
/// // Deserialize back
/// let decoded: Invoice = serde_json::from_str(&json).unwrap();
/// assert_eq!(decoded, inv);
/// # }
/// ```
#[macro_export]
macro_rules! tlv_struct {
    // Top-level entry point: start the tt-muncher to classify fields.
    (
        $(#[$struct_meta:meta])*
        $vis:vis struct $name:ident {
            $($rest:tt)*
        }
    ) => {
        $crate::tlv_struct!(@munch
            [$(#[$struct_meta])* $vis struct $name]
            []
            $($rest)*
        );
    };

    // tt-muncher: optional field (`Option<T>`) followed by a comma.
    (@munch
        [$($header:tt)*]
        [$($acc:tt)*]
        $(#[doc = $doc:literal])*
        #[tlv($type_num:expr, $enc:ident)]
        $field_vis:vis $field:ident : Option<$inner_ty:ty>,
        $($rest:tt)*
    ) => {
        $crate::tlv_struct!(@munch
            [$($header)*]
            [$($acc)*
                $(#[doc = $doc])*
                ($type_num, $enc, optional)
                $field_vis $field : Option<$inner_ty>,
            ]
            $($rest)*
        );
    };

    // tt-muncher: required (non-`Option`) field followed by a comma.
    (@munch
        [$($header:tt)*]
        [$($acc:tt)*]
        $(#[doc = $doc:literal])*
        #[tlv($type_num:expr, $enc:ident)]
        $field_vis:vis $field:ident : $field_ty:ty,
        $($rest:tt)*
    ) => {
        $crate::tlv_struct!(@munch
            [$($header)*]
            [$($acc)*
                $(#[doc = $doc])*
                ($type_num, $enc, required)
                $field_vis $field : $field_ty,
            ]
            $($rest)*
        );
    };

    // tt-muncher: last field written without a trailing comma. Re-feed it with
    // a comma so the two arms above classify it. The `Option<T>` arm must come
    // first: once a type has been captured as `$field_ty:ty` it is an opaque
    // fragment and can no longer be matched against `Option<...>`.
    (@munch
        [$($header:tt)*]
        [$($acc:tt)*]
        $(#[doc = $doc:literal])*
        #[tlv($type_num:expr, $enc:ident)]
        $field_vis:vis $field:ident : Option<$inner_ty:ty>
    ) => {
        $crate::tlv_struct!(@munch
            [$($header)*]
            [$($acc)*]
            $(#[doc = $doc])*
            #[tlv($type_num, $enc)]
            $field_vis $field : Option<$inner_ty>,
        );
    };
    (@munch
        [$($header:tt)*]
        [$($acc:tt)*]
        $(#[doc = $doc:literal])*
        #[tlv($type_num:expr, $enc:ident)]
        $field_vis:vis $field:ident : $field_ty:ty
    ) => {
        $crate::tlv_struct!(@munch
            [$($header)*]
            [$($acc)*]
            $(#[doc = $doc])*
            #[tlv($type_num, $enc)]
            $field_vis $field : $field_ty,
        );
    };

    // tt-muncher: all fields classified — emit @impl_struct.
    (@munch
        [$(#[$struct_meta:meta])* $vis:vis struct $name:ident]
        [$($acc:tt)*]
    ) => {
        $crate::tlv_struct!(@impl_struct
            $(#[$struct_meta])*
            $vis struct $name {
                $($acc)*
            }
        );
    };

    // Internal: struct definition + impls. Paths are fully qualified because
    // this expands in the caller's crate, whose prelude/edition may differ.
    (@impl_struct
        $(#[$struct_meta:meta])*
        $vis:vis struct $name:ident {
            $(
                $(#[doc = $doc:literal])*
                ($type_num:expr, $enc:ident, $optionality:ident)
                $field_vis:vis $field:ident : $field_ty:ty,
            )*
        }
    ) => {
        $(#[$struct_meta])*
        #[derive(Debug, Clone, PartialEq, Eq)]
        $vis struct $name {
            $(
                $(#[doc = $doc])*
                $field_vis $field : $field_ty,
            )*
        }

        impl $name {
            /// Serialize this struct to TLV wire-format bytes.
            ///
            /// # Errors
            ///
            /// Returns any error reported by `TlvStream::to_bytes`.
            pub fn to_tlv_bytes(&self) -> $crate::Result<::std::vec::Vec<u8>> {
                <$crate::TlvStream as ::core::convert::From<&Self>>::from(self).to_bytes()
            }

            /// Deserialize from TLV wire-format bytes.
            ///
            /// # Errors
            ///
            /// Fails if `bytes` is not a valid TLV stream, if a required field
            /// is absent (`TlvError::MissingRequired`), if a `tu64` value does
            /// not fit the field type (`TlvError::Overflow`), or if a `bytes`
            /// record does not match a fixed-size array field
            /// (`TlvError::InvalidLength`).
            pub fn from_tlv_bytes(bytes: &[u8]) -> $crate::Result<Self> {
                let stream = $crate::TlvStream::from_bytes(bytes)?;
                <Self as ::core::convert::TryFrom<&$crate::TlvStream>>::try_from(&stream)
            }
        }

        impl ::core::convert::From<&$name> for $crate::TlvStream {
            fn from(val: &$name) -> Self {
                let mut stream = <$crate::TlvStream as ::core::default::Default>::default();
                $(
                    $crate::tlv_struct!(@encode_field stream, val, $field, $type_num, $enc, $optionality);
                )*
                stream
            }
        }

        impl ::core::convert::TryFrom<&$crate::TlvStream> for $name {
            type Error = $crate::TlvError;

            fn try_from(stream: &$crate::TlvStream) -> ::core::result::Result<Self, Self::Error> {
                ::core::result::Result::Ok($name {
                    $(
                        $field: $crate::tlv_struct!(@decode_field stream, $type_num, $enc, $optionality),
                    )*
                })
            }
        }

        // Serde impls are emitted (or not) according to btlv's own `serde`
        // feature; see `__tlv_struct_serde_impls`.
        $crate::__tlv_struct_serde_impls!($name);
    };

    // === Encode: tu64 required ===
    (@encode_field $stream:ident, $val:ident, $field:ident, $type_num:expr, tu64, required) => {
        $stream.set_tu64($type_num, $crate::_macro_support::TlvTu64Encode::to_tu64_value(&$val.$field));
    };
    // === Encode: tu64 optional ===
    (@encode_field $stream:ident, $val:ident, $field:ident, $type_num:expr, tu64, optional) => {
        if let Some(ref v) = $val.$field {
            $stream.set_tu64($type_num, $crate::_macro_support::TlvTu64Encode::to_tu64_value(v));
        }
    };
    // === Encode: u64 required ===
    (@encode_field $stream:ident, $val:ident, $field:ident, $type_num:expr, u64, required) => {
        $stream.set_u64($type_num, $val.$field);
    };
    // === Encode: u64 optional ===
    (@encode_field $stream:ident, $val:ident, $field:ident, $type_num:expr, u64, optional) => {
        if let Some(v) = $val.$field {
            $stream.set_u64($type_num, v);
        }
    };
    // === Encode: bytes required ===
    (@encode_field $stream:ident, $val:ident, $field:ident, $type_num:expr, bytes, required) => {
        $stream.insert($type_num, $crate::_macro_support::TlvBytesEncode::to_tlv_vec(&$val.$field));
    };
    // === Encode: bytes optional ===
    (@encode_field $stream:ident, $val:ident, $field:ident, $type_num:expr, bytes, optional) => {
        if let Some(ref v) = $val.$field {
            $stream.insert($type_num, $crate::_macro_support::TlvBytesEncode::to_tlv_vec(v));
        }
    };

    // === Decode: tu64 required ===
    (@decode_field $stream:ident, $type_num:expr, tu64, required) => {{
        let v = $stream.get_tu64($type_num)?
            .ok_or($crate::TlvError::MissingRequired($type_num))?;
        $crate::_macro_support::TlvTu64Decode::from_tu64_value(v)?
    }};
    // === Decode: tu64 optional ===
    (@decode_field $stream:ident, $type_num:expr, tu64, optional) => {{
        match $stream.get_tu64($type_num)? {
            Some(v) => Some($crate::_macro_support::TlvTu64Decode::from_tu64_value(v)?),
            None => None,
        }
    }};
    // === Decode: u64 required ===
    (@decode_field $stream:ident, $type_num:expr, u64, required) => {{
        $stream.get_u64($type_num)?
            .ok_or($crate::TlvError::MissingRequired($type_num))?
    }};
    // === Decode: u64 optional ===
    (@decode_field $stream:ident, $type_num:expr, u64, optional) => {{
        $stream.get_u64($type_num)?
    }};
    // === Decode: bytes required ===
    (@decode_field $stream:ident, $type_num:expr, bytes, required) => {{
        let raw = $stream.get($type_num)
            .ok_or($crate::TlvError::MissingRequired($type_num))?;
        $crate::_macro_support::TlvBytesDecode::from_tlv_raw(raw, $type_num)?
    }};
    // === Decode: bytes optional ===
    (@decode_field $stream:ident, $type_num:expr, bytes, optional) => {{
        match $stream.get($type_num) {
            Some(raw) => Some($crate::_macro_support::TlvBytesDecode::from_tlv_raw(raw, $type_num)?),
            None => None,
        }
    }};
}

#[cfg(test)]
mod tests {
    use super::*;

    // Mixed required and optional fields
    tlv_struct! {
        /// An onion payload for testing.
        pub struct OnionPayload {
            /// Amount to forward in msat
            #[tlv(2, tu64)]
            pub amt_to_forward: u64,
            /// Outgoing CLTV value
            #[tlv(4, tu64)]
            pub outgoing_cltv_value: u32,
            /// Short channel ID
            #[tlv(6, bytes)]
            pub short_channel_id: Option<[u8; 8]>,
            /// Payment secret
            #[tlv(8, bytes)]
            pub payment_secret: Option<[u8; 32]>,
        }
    }

    #[test]
    fn onion_payload_roundtrip_all_fields() {
        let scid = [0x00, 0x73, 0x00, 0x0f, 0x2c, 0x00, 0x07, 0x00];
        let secret = [0xab; 32];
        let payload = OnionPayload {
            amt_to_forward: 1000,
            outgoing_cltv_value: 800000,
            short_channel_id: Some(scid),
            payment_secret: Some(secret),
        };

        let bytes = payload.to_tlv_bytes().unwrap();
        let decoded = OnionPayload::from_tlv_bytes(&bytes).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn onion_payload_roundtrip_optional_none() {
        let payload = OnionPayload {
            amt_to_forward: 500,
            outgoing_cltv_value: 144,
            short_channel_id: None,
            payment_secret: None,
        };

        let bytes = payload.to_tlv_bytes().unwrap();
        let decoded = OnionPayload::from_tlv_bytes(&bytes).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn onion_payload_missing_required_field() {
        let mut stream = TlvStream::default();
        stream.set_tu64(2, 1000);
        let bytes = stream.to_bytes().unwrap();

        let err = OnionPayload::from_tlv_bytes(&bytes).unwrap_err();
        assert!(matches!(err, TlvError::MissingRequired(4)));
    }

    #[test]
    fn onion_payload_wrong_length_bytes() {
        let mut stream = TlvStream::default();
        stream.set_tu64(2, 1000);
        stream.set_tu64(4, 144);
        stream.insert(6, vec![0x00; 5]);
        let bytes = stream.to_bytes().unwrap();

        let err = OnionPayload::from_tlv_bytes(&bytes).unwrap_err();
        assert!(matches!(
            err,
            TlvError::InvalidLength {
                type_: 6,
                expected: 8,
                actual: 5,
            }
        ));
    }

    #[test]
    fn onion_payload_wrong_length_optional_bytes() {
        let mut stream = TlvStream::default();
        stream.set_tu64(2, 1000);
        stream.set_tu64(4, 144);
        stream.insert(8, vec![0x00; 31]);
        let bytes = stream.to_bytes().unwrap();

        let err = OnionPayload::from_tlv_bytes(&bytes).unwrap_err();
        assert!(matches!(
            err,
            TlvError::InvalidLength {
                type_: 8,
                expected: 32,
                actual: 31,
            }
        ));
    }

    #[test]
    fn onion_payload_u32_field_overflow() {
        let mut stream = TlvStream::default();
        stream.set_tu64(2, 1000);
        // One past the largest value the `u32` field can hold.
        stream.set_tu64(4, u64::from(u32::MAX) + 1);
        let bytes = stream.to_bytes().unwrap();

        let err = OnionPayload::from_tlv_bytes(&bytes).unwrap_err();
        assert!(matches!(err, TlvError::Overflow));
    }

    #[test]
    fn onion_payload_tu64_is_minimal() {
        let payload = OnionPayload {
            amt_to_forward: 1000,
            outgoing_cltv_value: 144,
            short_channel_id: None,
            payment_secret: None,
        };

        let stream: TlvStream = (&payload).into();
        // 1000 = 0x03e8 and 144 = 0x90: no leading zero bytes.
        assert_eq!(stream.get(2).map(|v| v.to_vec()), Some(vec![0x03, 0xe8]));
        assert_eq!(stream.get(4).map(|v| v.to_vec()), Some(vec![0x90]));
        // `None` optional fields are not written at all.
        assert!(stream.get(6).is_none());
        assert!(stream.get(8).is_none());
    }

    #[test]
    fn onion_payload_to_stream_and_back() {
        let payload = OnionPayload {
            amt_to_forward: 42,
            outgoing_cltv_value: 100,
            short_channel_id: None,
            payment_secret: None,
        };

        let stream: TlvStream = (&payload).into();
        let back = OnionPayload::try_from(&stream).unwrap();
        assert_eq!(back, payload);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn onion_payload_serde_rejects_missing_required() {
        let mut stream = TlvStream::default();
        stream.set_tu64(2, 1000);
        let json = serde_json::to_string(&stream).unwrap();

        assert!(serde_json::from_str::<OnionPayload>(&json).is_err());
    }

    // All-required struct
    tlv_struct! {
        pub struct SimplePayload {
            #[tlv(2, tu64)]
            pub amount: u64,
            #[tlv(4, tu64)]
            pub cltv: u32,
        }
    }

    #[test]
    fn simple_payload_roundtrip() {
        let p = SimplePayload {
            amount: 999,
            cltv: 800000,
        };
        let bytes = p.to_tlv_bytes().unwrap();
        let d = SimplePayload::from_tlv_bytes(&bytes).unwrap();
        assert_eq!(d, p);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn simple_payload_serde_json_roundtrip() {
        let p = SimplePayload {
            amount: 999,
            cltv: 800000,
        };
        let json = serde_json::to_string(&p).unwrap();
        let d: SimplePayload = serde_json::from_str(&json).unwrap();
        assert_eq!(d, p);
    }

    // All-optional struct
    tlv_struct! {
        pub struct OptionalOnly {
            #[tlv(1, tu64)]
            pub a: Option<u64>,
            #[tlv(3, bytes)]
            pub b: Option<Vec<u8>>,
        }
    }

    #[test]
    fn optional_only_empty() {
        let p = OptionalOnly { a: None, b: None };
        let bytes = p.to_tlv_bytes().unwrap();
        assert!(bytes.is_empty());
        let d = OptionalOnly::from_tlv_bytes(&bytes).unwrap();
        assert_eq!(d, p);
    }

    #[test]
    fn optional_only_with_values() {
        let p = OptionalOnly {
            a: Some(42),
            b: Some(vec![0xde, 0xad]),
        };
        let bytes = p.to_tlv_bytes().unwrap();
        let d = OptionalOnly::from_tlv_bytes(&bytes).unwrap();
        assert_eq!(d, p);
    }

    // Struct with required Vec<u8> bytes
    tlv_struct! {
        pub struct WithRequiredBytes {
            #[tlv(1, bytes)]
            pub data: Vec<u8>,
            #[tlv(3, tu64)]
            pub count: u64,
        }
    }

    #[test]
    fn required_bytes_roundtrip() {
        let p = WithRequiredBytes {
            data: vec![0x01, 0x02, 0x03],
            count: 7,
        };
        let bytes = p.to_tlv_bytes().unwrap();
        let d = WithRequiredBytes::from_tlv_bytes(&bytes).unwrap();
        assert_eq!(d, p);
    }

    // Fixed u64 encoding
    tlv_struct! {
        pub struct FixedU64Struct {
            #[tlv(65537, u64)]
            pub extra_fee: u64,
            #[tlv(65539, u64)]
            pub optional_fee: Option<u64>,
        }
    }

    #[test]
    fn fixed_u64_roundtrip() {
        let p = FixedU64Struct {
            extra_fee: 1000,
            optional_fee: Some(500),
        };
        let bytes = p.to_tlv_bytes().unwrap();
        let d = FixedU64Struct::from_tlv_bytes(&bytes).unwrap();
        assert_eq!(d, p);

        let p2 = FixedU64Struct {
            extra_fee: 42,
            optional_fee: None,
        };
        let bytes2 = p2.to_tlv_bytes().unwrap();
        let d2 = FixedU64Struct::from_tlv_bytes(&bytes2).unwrap();
        assert_eq!(d2, p2);
    }

    #[test]
    fn fixed_u64_is_eight_bytes() {
        let p = FixedU64Struct {
            extra_fee: 1,
            optional_fee: None,
        };

        let stream: TlvStream = (&p).into();
        // Unlike `tu64`, the `u64` tag always writes all 8 big-endian bytes.
        assert_eq!(
            stream.get(65537).map(|v| v.to_vec()),
            Some(vec![0, 0, 0, 0, 0, 0, 0, 1])
        );
        assert!(stream.get(65539).is_none());
    }
}
565        let bytes2 = p2.to_tlv_bytes().unwrap();
566        let d2 = FixedU64Struct::from_tlv_bytes(&bytes2).unwrap();
567        assert_eq!(d2, p2);
568    }
569}