ddk_messages/
message_handler.rs

1//! Struct used to help send and receive DLC related messages.
2
3use std::{
4    collections::{HashMap, VecDeque},
5    fmt::Display,
6    sync::Mutex,
7};
8
9use lightning::types::features::{InitFeatures, NodeFeatures};
10use lightning::{
11    io::Cursor,
12    ln::{
13        msgs::{DecodeError, LightningError},
14        peer_handler::CustomMessageHandler,
15        wire::{CustomMessageReader, Type},
16    },
17    util::ser::{Readable, Writeable, MAX_BUF_SIZE},
18};
19use secp256k1_zkp::PublicKey;
20
21use crate::{
22    segmentation::{get_segments, segment_reader::SegmentReader},
23    Message, WireMessage,
24};
25
/// MessageHandler is used to send and receive messages through the custom
/// message handling mechanism of the LDK. It also handles message segmentation
/// by splitting large messages when sending and re-constructing them when
/// receiving.
///
/// All state is kept behind [`Mutex`]es, which lets every method operate on a
/// shared `&self` reference.
pub struct MessageHandler {
    /// Outgoing messages paired with the destination node id; filled by
    /// `send_message` and drained by `get_and_clear_pending_msg`.
    msg_events: Mutex<VecDeque<(PublicKey, WireMessage)>>,
    /// Fully decoded incoming messages paired with the sender's node id;
    /// drained by `get_and_clear_received_messages`.
    msg_received: Mutex<Vec<(PublicKey, Message)>>,
    /// Per-peer reassembly state for segmented messages, keyed by the
    /// sender's node id.
    segment_readers: Mutex<HashMap<PublicKey, SegmentReader>>,
}
35
36impl Default for MessageHandler {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl MessageHandler {
43    /// Creates a new instance of a [`MessageHandler`]
44    pub fn new() -> Self {
45        MessageHandler {
46            msg_events: Mutex::new(VecDeque::new()),
47            msg_received: Mutex::new(Vec::new()),
48            segment_readers: Mutex::new(HashMap::new()),
49        }
50    }
51
52    /// Returns the messages received by the message handler and empty the
53    /// receiving buffer.
54    pub fn get_and_clear_received_messages(&self) -> Vec<(PublicKey, Message)> {
55        let mut ret = Vec::new();
56        std::mem::swap(&mut *self.msg_received.lock().unwrap(), &mut ret);
57        ret
58    }
59
60    /// Send a message to the peer with given node id. Not that the message is not
61    /// sent right away, but only when the LDK
62    /// [`lightning::ln::peer_handler::PeerManager::process_events`] is next called.
63    pub fn send_message(&self, node_id: PublicKey, msg: Message) {
64        if msg.serialized_length() > MAX_BUF_SIZE {
65            let (seg_start, seg_chunks) = get_segments(msg.encode(), msg.type_id());
66            let mut msg_events = self.msg_events.lock().unwrap();
67            msg_events.push_back((node_id, WireMessage::SegmentStart(seg_start)));
68            for chunk in seg_chunks {
69                msg_events.push_back((node_id, WireMessage::SegmentChunk(chunk)));
70            }
71        } else {
72            self.msg_events
73                .lock()
74                .unwrap()
75                .push_back((node_id, WireMessage::Message(msg)));
76        }
77    }
78
79    /// Returns whether the message handler has any message to be sent.
80    pub fn has_pending_messages(&self) -> bool {
81        !self.msg_events.lock().unwrap().is_empty()
82    }
83}
84
/// Expands to a `match` on `$msg_type` that decodes `$buffer` into the
/// `Message` variant paired with the matching crate-level type id constant.
///
/// Evaluates to `Ok(Some(WireMessage::Message(..)))` for a known type and to
/// `Ok(None)` for any other type. Decoding errors are propagated with `?`, so
/// the expansion must sit in a function returning `Result<_, DecodeError>`.
macro_rules! handle_read_dlc_messages {
    ($msg_type:ident, $buffer:ident, $(($type_id:ident, $variant:ident)),*) => {{
        match $msg_type {
            $(
                $crate::$type_id => {
                    let message = Message::$variant(Readable::read($buffer)?);
                    Ok(Some(WireMessage::Message(message)))
                }
            )*
            // Not a DLC message type: let the caller decide what to do with it.
            _ => Ok(None),
        }
    }};
}
96
/// Parses a DLC message from a buffer.
///
/// `msg_type` is the wire type id, which the caller has already read.
/// `buffer` is expected to be positioned at the start of the message body.
///
/// Returns `Ok(Some(WireMessage::Message(_)))` when `msg_type` is one of the
/// DLC message types listed below. Any other type returns `Ok(None)`, which
/// leaves it to the caller to decide how to treat unknown messages.
///
/// # Errors
///
/// Returns the [`DecodeError`] produced by the message's `Readable`
/// implementation if the body cannot be decoded.
pub fn read_dlc_message<R: ::lightning::io::Read>(
    msg_type: u16,
    buffer: &mut R,
) -> Result<Option<WireMessage>, DecodeError> {
    // Each pair maps a crate-level type id constant to the `Message` variant
    // its body is decoded into.
    handle_read_dlc_messages!(
        msg_type,
        buffer,
        (OFFER_TYPE, Offer),
        (ACCEPT_TYPE, Accept),
        (SIGN_TYPE, Sign),
        (OFFER_CHANNEL_TYPE, OfferChannel),
        (ACCEPT_CHANNEL_TYPE, AcceptChannel),
        (SIGN_CHANNEL_TYPE, SignChannel),
        (SETTLE_CHANNEL_OFFER_TYPE, SettleOffer),
        (SETTLE_CHANNEL_ACCEPT_TYPE, SettleAccept),
        (SETTLE_CHANNEL_CONFIRM_TYPE, SettleConfirm),
        (SETTLE_CHANNEL_FINALIZE_TYPE, SettleFinalize),
        (RENEW_CHANNEL_OFFER_TYPE, RenewOffer),
        (RENEW_CHANNEL_ACCEPT_TYPE, RenewAccept),
        (RENEW_CHANNEL_CONFIRM_TYPE, RenewConfirm),
        (RENEW_CHANNEL_FINALIZE_TYPE, RenewFinalize),
        (COLLABORATIVE_CLOSE_OFFER_TYPE, CollaborativeCloseOffer),
        (REJECT, Reject)
    )
}
123
124/// Implementation of the `CustomMessageReader` trait is required to decode
125/// custom messages in the LDK.
126impl CustomMessageReader for MessageHandler {
127    type CustomMessage = WireMessage;
128    fn read<R: ::lightning::io::Read>(
129        &self,
130        msg_type: u16,
131        buffer: &mut R,
132    ) -> Result<Option<WireMessage>, DecodeError> {
133        let decoded = match msg_type {
134            crate::segmentation::SEGMENT_START_TYPE => {
135                WireMessage::SegmentStart(Readable::read(buffer)?)
136            }
137            crate::segmentation::SEGMENT_CHUNK_TYPE => {
138                WireMessage::SegmentChunk(Readable::read(buffer)?)
139            }
140            _ => return read_dlc_message(msg_type, buffer),
141        };
142
143        Ok(Some(decoded))
144    }
145}
146
147/// Implementation of the `CustomMessageHandler` trait is required to handle
148/// custom messages in the LDK.
149impl CustomMessageHandler for MessageHandler {
150    fn peer_connected(
151        &self,
152        _their_node_id: PublicKey,
153        _msg: &lightning::ln::msgs::Init,
154        _inbound: bool,
155    ) -> Result<(), ()> {
156        Ok(())
157    }
158
159    fn peer_disconnected(&self, _their_node_id: PublicKey) {}
160
161    fn handle_custom_message(
162        &self,
163        msg: WireMessage,
164        org: PublicKey,
165    ) -> Result<(), LightningError> {
166        let mut segment_readers = self.segment_readers.lock().unwrap();
167        let segment_reader = segment_readers.entry(org).or_default();
168
169        if segment_reader.expecting_chunk() {
170            match msg {
171                WireMessage::SegmentChunk(s) => {
172                    if let Some(msg) = segment_reader
173                        .process_segment_chunk(s)
174                        .map_err(|e| to_ln_error(e, "Error processing segment chunk"))?
175                    {
176                        let mut buf = Cursor::new(msg);
177                        let message_type = <u16 as Readable>::read(&mut buf).map_err(|e| {
178                            to_ln_error(e, "Could not reconstruct message from segments")
179                        })?;
180                        if let WireMessage::Message(m) = self
181                            .read(message_type, &mut buf)
182                            .map_err(|e| {
183                                to_ln_error(e, "Could not reconstruct message from segments")
184                            })?
185                            .expect("to have a message")
186                        {
187                            self.msg_received.lock().unwrap().push((org, m));
188                        } else {
189                            return Err(to_ln_error(
190                                "Unexpected message type",
191                                &message_type.to_string(),
192                            ));
193                        }
194                    }
195                    return Ok(());
196                }
197                _ => {
198                    // We were expecting a segment chunk but received something
199                    // else, we reset the state.
200                    segment_reader.reset();
201                }
202            }
203        }
204
205        match msg {
206            WireMessage::Message(m) => self.msg_received.lock().unwrap().push((org, m)),
207            WireMessage::SegmentStart(s) => segment_reader
208                .process_segment_start(s)
209                .map_err(|e| to_ln_error(e, "Error processing segment start"))?,
210            WireMessage::SegmentChunk(_) => {
211                return Err(LightningError {
212                    err: "Received a SegmentChunk while not expecting one.".to_string(),
213                    action: lightning::ln::msgs::ErrorAction::DisconnectPeer { msg: None },
214                });
215            }
216        };
217        Ok(())
218    }
219
220    fn get_and_clear_pending_msg(&self) -> Vec<(PublicKey, Self::CustomMessage)> {
221        self.msg_events.lock().unwrap().drain(..).collect()
222    }
223
224    fn provided_node_features(&self) -> NodeFeatures {
225        NodeFeatures::empty()
226    }
227
228    fn provided_init_features(&self, _their_node_id: PublicKey) -> InitFeatures {
229        InitFeatures::empty()
230    }
231}
232
233#[inline]
234fn to_ln_error<T: Display>(e: T, msg: &str) -> LightningError {
235    LightningError {
236        err: format!("{msg}: {e}"),
237        action: lightning::ln::msgs::ErrorAction::DisconnectPeer { msg: None },
238    }
239}
240
#[cfg(test)]
mod tests {
    use secp256k1_zkp::{SecretKey, SECP256K1};

    use crate::{
        segmentation::{SegmentChunk, SegmentStart},
        AcceptDlc, OfferDlc, SignDlc,
    };

    use super::*;

    /// Deterministic public key (derived from the secret key `1`) used as the
    /// peer id in tests.
    fn some_pk() -> PublicKey {
        PublicKey::from_secret_key(
            SECP256K1,
            &SecretKey::from_slice(&secp256k1_zkp::constants::ONE).unwrap(),
        )
    }

    /// Deserializes the JSON `$input` into `$type` and runs it through
    /// `handler_read_test`.
    macro_rules! read_test {
        ($type: ty, $input: ident) => {
            let msg: $type = serde_json::from_str(&$input).unwrap();
            handler_read_test(msg);
        };
    }

    /// Writes `msg` with its type prefix, reads it back through the
    /// handler's `CustomMessageReader::read`, and checks that the decoded
    /// message carries the same type id as the original.
    fn handler_read_test<T: Writeable + Readable + PartialEq + Type + std::fmt::Debug>(msg: T) {
        let mut buf = Vec::new();
        msg.type_id()
            .write(&mut buf)
            .expect("Error writing type id");
        msg.write(&mut buf).expect("Error writing message");
        let handler = MessageHandler::new();
        let mut reader = Cursor::new(&mut buf);
        let message_type =
            <u16 as Readable>::read(&mut reader).expect("to be able to read the type prefix.");
        let decoded = handler
            .read(message_type, &mut reader)
            .expect("to be able to read the message")
            .expect("to have a message");
        // Decoding into some message is not enough; it must be the right kind.
        assert_eq!(msg.type_id(), decoded.type_id());
    }

    #[test]
    fn read_offer_test() {
        let input = include_str!("./test_inputs/offer_msg.json");
        read_test!(OfferDlc, input);
    }

    #[test]
    fn read_accept_test() {
        let input = include_str!("./test_inputs/accept_msg.json");
        read_test!(AcceptDlc, input);
    }

    #[test]
    fn read_sign_test() {
        let input = include_str!("./test_inputs/sign_msg.json");
        read_test!(SignDlc, input);
    }

    #[test]
    fn read_segment_start_test() {
        let input = include_str!("./test_inputs/segment_start_msg.json");
        read_test!(SegmentStart, input);
    }

    #[test]
    fn read_segment_chunk_test() {
        let input = include_str!("./test_inputs/segment_chunk_msg.json");
        read_test!(SegmentChunk, input);
    }

    #[test]
    fn read_unknown_message_returns_none() {
        let handler = MessageHandler::new();
        let mut buf = &[0u8; 10];
        let mut reader = Cursor::new(&mut buf);
        let message_type = 0;

        assert!(handler
            .read(message_type, &mut reader)
            .expect("should not error on unknown messages")
            .is_none());
    }

    #[test]
    fn send_regular_message_test() {
        let input = include_str!("./test_inputs/offer_msg.json");
        let msg: OfferDlc = serde_json::from_str(input).unwrap();
        let handler = MessageHandler::new();
        handler.send_message(some_pk(), Message::Offer(msg));
        assert_eq!(handler.msg_events.lock().unwrap().len(), 1);
    }

    #[test]
    fn send_large_message_segmented_test() {
        let input = include_str!("./test_inputs/accept_msg.json");
        let msg: AcceptDlc = serde_json::from_str(input).unwrap();
        let handler = MessageHandler::new();
        handler.send_message(some_pk(), Message::Accept(msg));
        assert!(handler.msg_events.lock().unwrap().len() > 1);
    }

    #[test]
    fn is_empty_after_clearing_msg_events_test() {
        let input = include_str!("./test_inputs/accept_msg.json");
        let msg: AcceptDlc = serde_json::from_str(input).unwrap();
        let handler = MessageHandler::new();
        handler.send_message(some_pk(), Message::Accept(msg));
        handler.get_and_clear_pending_msg();
        assert!(!handler.has_pending_messages());
    }

    #[test]
    fn send_message_with_dlc_input_test() {
        let input = include_str!("./test_inputs/offer_msg_with_dlc_input.json");
        let msg: OfferDlc = serde_json::from_str(input).unwrap();
        let handler = MessageHandler::new();
        handler.send_message(some_pk(), Message::Offer(msg));
        handler.get_and_clear_pending_msg();
        assert!(!handler.has_pending_messages());
    }

    #[test]
    #[ignore = "Need to regenerate the segment start and chunk messages for an accept contract with optional funding input"]
    fn rebuilds_segments_properly_test() {
        let input1 = include_str!("./test_inputs/segment_start_msg.json");
        let input2 = include_str!("./test_inputs/segment_chunk_msg.json");
        let segment_start: SegmentStart = serde_json::from_str(input1).unwrap();
        let segment_chunk: SegmentChunk = serde_json::from_str(input2).unwrap();

        let handler = MessageHandler::new();
        handler
            .handle_custom_message(WireMessage::SegmentStart(segment_start), some_pk())
            .expect("to be able to process segment start");
        handler
            .handle_custom_message(WireMessage::SegmentChunk(segment_chunk), some_pk())
            .expect("to be able to process segment chunk");
        let msg = handler.get_and_clear_received_messages();
        assert_eq!(1, msg.len());
        if let (_, Message::Accept(_)) = msg[0] {
        } else {
            panic!("Expected an accept message");
        }
    }
}