1use anyhow::anyhow;
2use bytes::{Buf, BufMut};
3use cln_rpc::primitives::TlvEntry;
4use serde::{Deserialize, Deserializer};
5
6#[derive(Clone, Debug)]
10pub struct SerializedTlvStream {
11 entries: Vec<TlvEntry>,
12}
13
14impl SerializedTlvStream {
15 pub fn new() -> Self {
16 Self { entries: vec![] }
17 }
18
19 pub fn get(&self, typ: u64) -> Option<TlvEntry> {
20 self.entries.iter().filter(|e| e.typ == typ).next().cloned()
21 }
22
23 pub fn insert(&mut self, e: TlvEntry) -> Result<(), anyhow::Error> {
24 if let Some(old) = self.get(e.typ) {
25 return Err(anyhow!(
26 "TlvStream contains entry of type={}, old={:?}, new={:?}",
27 e.typ,
28 old,
29 e
30 ));
31 }
32
33 self.entries.push(e);
34 self.entries
35 .sort_by(|a, b| a.typ.partial_cmp(&b.typ).unwrap());
36
37 Ok(())
38 }
39
40 pub fn set_bytes<T>(&mut self, typ: u64, val: T)
41 where
42 T: AsRef<[u8]>,
43 {
44 let pos = self.entries.iter().position(|e| e.typ == typ);
45 match pos {
46 Some(i) => self.entries[i].value = val.as_ref().to_vec(),
47 None => self
48 .insert(TlvEntry {
49 typ,
50 value: val.as_ref().to_vec(),
51 })
52 .unwrap(),
53 }
54 }
55
56 pub fn set_tu64(&mut self, typ: u64, val: TU64) {
57 let mut b = bytes::BytesMut::new();
58 b.put_tu64(val);
59 self.set_bytes(typ, b)
60 }
61}
62
/// Builds a value by decoding a byte buffer.
///
/// Implementors report malformed input through the associated `Error` type.
pub trait FromBytes
where
    Self: Sized,
{
    /// Error produced when the bytes cannot be decoded.
    type Error;

    /// Decodes `s` into a new value of the implementing type.
    fn from_bytes<T: AsRef<[u8]> + 'static>(s: T) -> Result<Self, Self::Error>;
}
69
70impl FromBytes for SerializedTlvStream {
71 type Error = anyhow::Error;
72 fn from_bytes<T>(s: T) -> Result<Self, Self::Error>
73 where
74 T: AsRef<[u8]> + 'static,
75 {
76 let mut b = s.as_ref();
77 let mut entries: Vec<TlvEntry> = vec![];
79 while b.remaining() >= 2 {
80 let typ = b.get_compact_size() as u64;
81 let len = b.get_compact_size() as usize;
82 let value = b.copy_to_bytes(len).to_vec();
83 entries.push(TlvEntry { typ, value });
84 }
85
86 Ok(SerializedTlvStream { entries })
87 }
88}
89
/// Variable-length unsigned integer, as read by `ProtoBuf::get_compact_size`
/// and written by `ProtoBufMut::put_compact_size`.
///
/// Values below 0xfd are a single byte. Larger values are a marker byte
/// (0xfd, 0xfe or 0xff) followed by a u16, u32 or u64. The `bytes` getters and
/// putters used for these are big-endian.
pub type CompactSize = u64;

/// Truncated u64: a big-endian u64 with its leading zero bytes dropped, as
/// written by `ProtoBufMut::put_tu64`. The value 0 encodes as zero bytes.
pub type TU64 = u64;
95
96pub trait ProtoBuf: Buf {
98 fn get_compact_size(&mut self) -> CompactSize {
99 match self.get_u8() {
100 253 => self.get_u16().into(),
101 254 => self.get_u32().into(),
102 255 => self.get_u64(),
103 v => v.into(),
104 }
105 .into()
106 }
107
108 fn get_tu64(&mut self) -> TU64 {
109 match self.remaining() {
110 1 => self.get_u8() as u64,
111 2 => self.get_u16() as u64,
112 4 => self.get_u32() as u64,
113 8 => self.get_u64() as u64,
114 l => panic!("Unexpect TU64 length: {}", l),
115 }
116 }
117}
118
// Opt the byte containers into the `ProtoBuf` read helpers. `&[u8]` is the
// reader used by `SerializedTlvStream::from_bytes`, and `Bytes` is the one
// used by the `Deserialize` impl.
impl ProtoBuf for bytes::Bytes {}
impl ProtoBuf for &[u8] {}
impl ProtoBuf for bytes::buf::Take<bytes::Bytes> {}
122
123pub trait ProtoBufMut: bytes::BufMut {
124 fn put_compact_size(&mut self, cs: CompactSize) {
125 match cs as u64 {
126 0..=0xFC => self.put_u8(cs as u8),
127 0xFD..=0xFFFF => {
128 self.put_u8(253);
129 self.put_u16(cs as u16);
130 }
131 0x10000..=0xFFFFFFFF => {
132 self.put_u8(254);
133 self.put_u32(cs as u32);
134 }
135 v => {
136 self.put_u8(255);
137 self.put_u64(v);
138 }
139 }
140 }
141
142 fn put_tu64(&mut self, u: TU64) {
143 let b: Vec<u8> = u
147 .to_be_bytes()
148 .iter()
149 .map(|x| x.clone())
150 .skip_while(|&x| x == 0)
151 .collect();
152 self.put_slice(&b);
153 }
154}
155
// `BytesMut` is the buffer `set_tu64` and `ToBytes::to_bytes` serialize into.
impl ProtoBufMut for bytes::BytesMut {}
157
/// Encodes a value into an owned byte vector.
pub trait ToBytes
where
    Self: Sized,
{
    /// Consumes `s` and returns its serialized bytes.
    fn to_bytes(s: Self) -> Vec<u8>;
}
161
162impl ToBytes for SerializedTlvStream {
163 fn to_bytes(s: Self) -> Vec<u8> {
164 let mut b = bytes::BytesMut::new();
165
166 for e in s.entries.iter() {
167 b.put_compact_size(e.typ);
168 b.put_compact_size(e.value.len() as u64);
169 b.put(&e.value[..]);
170 }
171 b.to_vec()
172 }
173}
174
175impl<'de> Deserialize<'de> for SerializedTlvStream {
176 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
177 where
178 D: Deserializer<'de>,
179 {
180 let s: String = Deserialize::deserialize(deserializer)?;
182 let mut b: bytes::Bytes = hex::decode(s)
183 .map_err(|e| serde::de::Error::custom(e.to_string()))?
184 .into();
185
186 let l = b.get_compact_size();
188 let b = b.take(l as usize); Self::from_bytes(b.into_inner()).map_err(|e| serde::de::Error::custom(e.to_string()))
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes a two-record stream. Type 33001 holds an ASCII value starting
    /// with `lnbc100n1`. Type 33003 holds the 2-byte big-endian value 0x2742,
    /// which is 10050.
    #[test]
    fn test_tlv_stream() {
        let raw_hex = "fd80e9fd01046c6e62633130306e31706e6670757a677070356a73376e653571727465326d707874666d7a703638667838676a7376776b3034366c76357a76707a7832766d687478683763777364717163717a7a737871797a3576717370357a6d6b78726539686d6864617378336b75357070336a38366472337778656b78336437383363706c6161363068783870357564733971787071797367717132646c68656177796c677534346567393363766e78666e64747a646a6e647465666b726861727763746b3368783766656e67346179746e6a3277686d74716665636a7930776777396c6665727072386b686d64667771736e386d6d7a3776643565776a34756370656439787076fd80eb022742";
        let stream = SerializedTlvStream::from_bytes(hex::decode(raw_hex).unwrap()).unwrap();

        assert!(stream.get(33001).is_some());

        let amount_msat = stream
            .get(33003)
            .map(|entry| u16::from_be_bytes(entry.value[..].try_into().unwrap()));
        assert_eq!(amount_msat, Some(10050));
    }
}