1use std::{
4 collections::{HashMap, VecDeque},
5 fmt::Display,
6 sync::Mutex,
7};
8
9use lightning::types::features::{InitFeatures, NodeFeatures};
10use lightning::{
11 io::Cursor,
12 ln::{
13 msgs::{DecodeError, LightningError},
14 peer_handler::CustomMessageHandler,
15 wire::{CustomMessageReader, Type},
16 },
17 util::ser::{Readable, Writeable, MAX_BUF_SIZE},
18};
19use secp256k1_zkp::PublicKey;
20
21use crate::{
22 segmentation::{get_segments, segment_reader::SegmentReader},
23 Message, WireMessage,
24};
25
/// Sends and receives DLC messages through LDK's custom message mechanism.
///
/// Outgoing messages whose serialized length exceeds [`MAX_BUF_SIZE`] are queued
/// as a `SegmentStart` followed by `SegmentChunk`s; incoming segments are
/// reassembled per peer before the resulting [`Message`] is stored.
pub struct MessageHandler {
    /// Outgoing queue of `(destination node id, wire message)` pairs, drained by
    /// `get_and_clear_pending_msg`.
    msg_events: Mutex<VecDeque<(PublicKey, WireMessage)>>,
    /// Fully decoded messages together with the node id they came from, drained
    /// by `get_and_clear_received_messages`.
    msg_received: Mutex<Vec<(PublicKey, Message)>>,
    /// Per-peer state used to reassemble segmented incoming messages.
    segment_readers: Mutex<HashMap<PublicKey, SegmentReader>>,
}
35
36impl Default for MessageHandler {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl MessageHandler {
43 pub fn new() -> Self {
45 MessageHandler {
46 msg_events: Mutex::new(VecDeque::new()),
47 msg_received: Mutex::new(Vec::new()),
48 segment_readers: Mutex::new(HashMap::new()),
49 }
50 }
51
52 pub fn get_and_clear_received_messages(&self) -> Vec<(PublicKey, Message)> {
55 let mut ret = Vec::new();
56 std::mem::swap(&mut *self.msg_received.lock().unwrap(), &mut ret);
57 ret
58 }
59
60 pub fn send_message(&self, node_id: PublicKey, msg: Message) {
64 if msg.serialized_length() > MAX_BUF_SIZE {
65 let (seg_start, seg_chunks) = get_segments(msg.encode(), msg.type_id());
66 let mut msg_events = self.msg_events.lock().unwrap();
67 msg_events.push_back((node_id, WireMessage::SegmentStart(seg_start)));
68 for chunk in seg_chunks {
69 msg_events.push_back((node_id, WireMessage::SegmentChunk(chunk)));
70 }
71 } else {
72 self.msg_events
73 .lock()
74 .unwrap()
75 .push_back((node_id, WireMessage::Message(msg)));
76 }
77 }
78
79 pub fn has_pending_messages(&self) -> bool {
81 !self.msg_events.lock().unwrap().is_empty()
82 }
83}
84
// Expands to a `match` on `$msg_type` that, for each `($type_id, $variant)` pair,
// decodes `$buffer` into `Message::$variant` when `$msg_type` equals
// `$crate::$type_id`. The expansion evaluates to
// `Ok(Some(WireMessage::Message(..)))`; decode errors propagate through `?`, and
// an unlisted type id makes the *enclosing function* return `Ok(None)`.
macro_rules! handle_read_dlc_messages {
    ($msg_type:ident, $buffer:ident, $(($type_id:ident, $variant:ident)),*) => {{
        let decoded = match $msg_type {
            $(
                $crate::$type_id => Message::$variant(Readable::read($buffer)?),
            )*
            // Unknown type ids are not an error: the caller returns `None`.
            _ => return Ok(None),
        };
        Ok(Some(WireMessage::Message(decoded)))
    }};
}
96
/// Decodes a DLC protocol message of type `msg_type` from `buffer`.
///
/// Returns `Ok(None)` when `msg_type` is not one of the DLC message types listed
/// in the body. Segmentation message types are not in that list; they are
/// decoded by `CustomMessageReader::read` before it delegates here.
///
/// # Errors
///
/// Propagates the [`DecodeError`] returned by the matched message type's
/// `Readable` implementation.
pub fn read_dlc_message<R: ::lightning::io::Read>(
    msg_type: u16,
    buffer: &mut R,
) -> Result<Option<WireMessage>, DecodeError> {
    handle_read_dlc_messages!(
        msg_type,
        buffer,
        (OFFER_TYPE, Offer),
        (ACCEPT_TYPE, Accept),
        (SIGN_TYPE, Sign),
        (OFFER_CHANNEL_TYPE, OfferChannel),
        (ACCEPT_CHANNEL_TYPE, AcceptChannel),
        (SIGN_CHANNEL_TYPE, SignChannel),
        (SETTLE_CHANNEL_OFFER_TYPE, SettleOffer),
        (SETTLE_CHANNEL_ACCEPT_TYPE, SettleAccept),
        (SETTLE_CHANNEL_CONFIRM_TYPE, SettleConfirm),
        (SETTLE_CHANNEL_FINALIZE_TYPE, SettleFinalize),
        (RENEW_CHANNEL_OFFER_TYPE, RenewOffer),
        (RENEW_CHANNEL_ACCEPT_TYPE, RenewAccept),
        (RENEW_CHANNEL_CONFIRM_TYPE, RenewConfirm),
        (RENEW_CHANNEL_FINALIZE_TYPE, RenewFinalize),
        (COLLABORATIVE_CLOSE_OFFER_TYPE, CollaborativeCloseOffer),
        (REJECT, Reject)
    )
}
123
124impl CustomMessageReader for MessageHandler {
127 type CustomMessage = WireMessage;
128 fn read<R: ::lightning::io::Read>(
129 &self,
130 msg_type: u16,
131 buffer: &mut R,
132 ) -> Result<Option<WireMessage>, DecodeError> {
133 let decoded = match msg_type {
134 crate::segmentation::SEGMENT_START_TYPE => {
135 WireMessage::SegmentStart(Readable::read(buffer)?)
136 }
137 crate::segmentation::SEGMENT_CHUNK_TYPE => {
138 WireMessage::SegmentChunk(Readable::read(buffer)?)
139 }
140 _ => return read_dlc_message(msg_type, buffer),
141 };
142
143 Ok(Some(decoded))
144 }
145}
146
147impl CustomMessageHandler for MessageHandler {
150 fn peer_connected(
151 &self,
152 _their_node_id: PublicKey,
153 _msg: &lightning::ln::msgs::Init,
154 _inbound: bool,
155 ) -> Result<(), ()> {
156 Ok(())
157 }
158
159 fn peer_disconnected(&self, _their_node_id: PublicKey) {}
160
161 fn handle_custom_message(
162 &self,
163 msg: WireMessage,
164 org: PublicKey,
165 ) -> Result<(), LightningError> {
166 let mut segment_readers = self.segment_readers.lock().unwrap();
167 let segment_reader = segment_readers.entry(org).or_default();
168
169 if segment_reader.expecting_chunk() {
170 match msg {
171 WireMessage::SegmentChunk(s) => {
172 if let Some(msg) = segment_reader
173 .process_segment_chunk(s)
174 .map_err(|e| to_ln_error(e, "Error processing segment chunk"))?
175 {
176 let mut buf = Cursor::new(msg);
177 let message_type = <u16 as Readable>::read(&mut buf).map_err(|e| {
178 to_ln_error(e, "Could not reconstruct message from segments")
179 })?;
180 if let WireMessage::Message(m) = self
181 .read(message_type, &mut buf)
182 .map_err(|e| {
183 to_ln_error(e, "Could not reconstruct message from segments")
184 })?
185 .expect("to have a message")
186 {
187 self.msg_received.lock().unwrap().push((org, m));
188 } else {
189 return Err(to_ln_error(
190 "Unexpected message type",
191 &message_type.to_string(),
192 ));
193 }
194 }
195 return Ok(());
196 }
197 _ => {
198 segment_reader.reset();
201 }
202 }
203 }
204
205 match msg {
206 WireMessage::Message(m) => self.msg_received.lock().unwrap().push((org, m)),
207 WireMessage::SegmentStart(s) => segment_reader
208 .process_segment_start(s)
209 .map_err(|e| to_ln_error(e, "Error processing segment start"))?,
210 WireMessage::SegmentChunk(_) => {
211 return Err(LightningError {
212 err: "Received a SegmentChunk while not expecting one.".to_string(),
213 action: lightning::ln::msgs::ErrorAction::DisconnectPeer { msg: None },
214 });
215 }
216 };
217 Ok(())
218 }
219
220 fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> {
221 self.msg_events.lock().unwrap().drain(..).collect()
222 }
223
224 fn provided_node_features(&self) -> NodeFeatures {
225 NodeFeatures::empty()
226 }
227
228 fn provided_init_features(&self, _their_node_id: PublicKey) -> InitFeatures {
229 InitFeatures::empty()
230 }
231}
232
233#[inline]
234fn to_ln_error<T: Display>(e: T, msg: &str) -> LightningError {
235 LightningError {
236 err: format!("{msg}: {e}"),
237 action: lightning::ln::msgs::ErrorAction::DisconnectPeer { msg: None },
238 }
239}
240
#[cfg(test)]
mod tests {
    use secp256k1_zkp::{SecretKey, SECP256K1};

    use crate::{
        segmentation::{SegmentChunk, SegmentStart},
        AcceptDlc, OfferDlc, SignDlc,
    };

    use super::*;

    /// Public key derived from the secret key `1`, used as an arbitrary peer id.
    fn some_pk() -> PublicKey {
        PublicKey::from_secret_key(
            SECP256K1,
            &SecretKey::from_slice(&secp256k1_zkp::constants::ONE).unwrap(),
        )
    }

    /// Deserializes the JSON in `$input` as `$type` and round-trips it through
    /// the handler's reader.
    macro_rules! read_test {
        ($type: ty, $input: ident) => {
            let msg: $type = serde_json::from_str(&$input).unwrap();
            handler_read_test(msg);
        };
    }

    /// Encodes `msg` as `type id || body` and decodes it with
    /// `MessageHandler::read`. It then checks that the decoded message reports
    /// the same type id and re-encodes to the same body bytes, so a decoder that
    /// maps to the wrong variant or drops fields is caught.
    fn handler_read_test<T: Writeable + Type>(msg: T) {
        let mut buf = Vec::new();
        msg.type_id()
            .write(&mut buf)
            .expect("Error writing type id");
        msg.write(&mut buf).expect("Error writing message");
        let handler = MessageHandler::new();
        let mut reader = Cursor::new(&mut buf);
        let message_type =
            <u16 as Readable>::read(&mut reader).expect("to be able to read the type prefix.");
        let decoded = handler
            .read(message_type, &mut reader)
            .expect("to be able to read the message")
            .expect("to have a message");
        assert_eq!(decoded.type_id(), message_type);
        assert_eq!(decoded.encode(), msg.encode());
    }

    #[test]
    fn read_offer_test() {
        let input = include_str!("./test_inputs/offer_msg.json");
        read_test!(OfferDlc, input);
    }

    #[test]
    fn read_accept_test() {
        let input = include_str!("./test_inputs/accept_msg.json");
        read_test!(AcceptDlc, input);
    }

    #[test]
    fn read_sign_test() {
        let input = include_str!("./test_inputs/sign_msg.json");
        read_test!(SignDlc, input);
    }

    #[test]
    fn read_segment_start_test() {
        let input = include_str!("./test_inputs/segment_start_msg.json");
        read_test!(SegmentStart, input);
    }

    #[test]
    fn read_segment_chunk_test() {
        let input = include_str!("./test_inputs/segment_chunk_msg.json");
        read_test!(SegmentChunk, input);
    }

    #[test]
    fn read_unknown_message_returns_none() {
        let handler = MessageHandler::new();
        let mut buf = &[0u8; 10];
        let mut reader = Cursor::new(&mut buf);
        let message_type = 0;

        assert!(handler
            .read(message_type, &mut reader)
            .expect("should not error on unknown messages")
            .is_none());
    }

    #[test]
    fn send_regular_message_test() {
        let input = include_str!("./test_inputs/offer_msg.json");
        let msg: OfferDlc = serde_json::from_str(input).unwrap();
        let handler = MessageHandler::new();
        handler.send_message(some_pk(), Message::Offer(msg));
        assert_eq!(handler.msg_events.lock().unwrap().len(), 1);
    }

    #[test]
    fn send_large_message_segmented_test() {
        let input = include_str!("./test_inputs/accept_msg.json");
        let msg: AcceptDlc = serde_json::from_str(input).unwrap();
        let handler = MessageHandler::new();
        handler.send_message(some_pk(), Message::Accept(msg));
        assert!(handler.msg_events.lock().unwrap().len() > 1);
    }

    #[test]
    fn is_empty_after_clearing_msg_events_test() {
        let input = include_str!("./test_inputs/accept_msg.json");
        let msg: AcceptDlc = serde_json::from_str(input).unwrap();
        let handler = MessageHandler::new();
        handler.send_message(some_pk(), Message::Accept(msg));
        handler.get_and_clear_pending_msg();
        assert!(!handler.has_pending_messages());
    }

    #[test]
    fn send_message_with_dlc_input_test() {
        let input = include_str!("./test_inputs/offer_msg_with_dlc_input.json");
        let msg: OfferDlc = serde_json::from_str(input).unwrap();
        let handler = MessageHandler::new();
        handler.send_message(some_pk(), Message::Offer(msg));
        handler.get_and_clear_pending_msg();
        assert!(!handler.has_pending_messages());
    }

    #[test]
    #[ignore = "Need to regenerate the segment start and chunk messages for an accept contract with optional funding input"]
    fn rebuilds_segments_properly_test() {
        let input1 = include_str!("./test_inputs/segment_start_msg.json");
        let input2 = include_str!("./test_inputs/segment_chunk_msg.json");
        let segment_start: SegmentStart = serde_json::from_str(input1).unwrap();
        let segment_chunk: SegmentChunk = serde_json::from_str(input2).unwrap();

        let handler = MessageHandler::new();
        handler
            .handle_custom_message(WireMessage::SegmentStart(segment_start), some_pk())
            .expect("to be able to process segment start");
        handler
            .handle_custom_message(WireMessage::SegmentChunk(segment_chunk), some_pk())
            .expect("to be able to process segment chunk");
        let msg = handler.get_and_clear_received_messages();
        assert_eq!(1, msg.len());
        assert!(
            matches!(msg[0], (_, Message::Accept(_))),
            "Expected an accept message"
        );
    }
}