1#![deny(rustdoc::broken_intra_doc_links)]
26#![deny(rustdoc::private_intra_doc_links)]
27#![deny(missing_docs)]
28#![cfg_attr(docsrs, feature(doc_auto_cfg))]
29
30use bitcoin::secp256k1::PublicKey;
31
32use tokio::net::TcpStream;
33use tokio::sync::mpsc;
34use tokio::time;
35
36use lightning::ln::msgs::SocketAddress;
37use lightning::ln::peer_handler;
38use lightning::ln::peer_handler::APeerManager;
39use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait;
40
41use std::future::Future;
42use std::hash::Hash;
43use std::net::SocketAddr;
44use std::net::TcpStream as StdTcpStream;
45use std::ops::Deref;
46use std::pin::Pin;
47use std::sync::atomic::{AtomicU64, Ordering};
48use std::sync::{Arc, Mutex};
49use std::task::{self, Poll};
50use std::time::Duration;
51
// Process-wide counter from which every new `Connection` draws its `id`. `SocketDescriptor`
// equality and hashing are based solely on that id.
static ID_COUNTER: AtomicU64 = AtomicU64::new(0);
53
/// Result of awaiting a `TwoSelector` or `ThreeSelector`: records which of the selected
/// futures completed, together with that future's output.
pub(crate) enum SelectorOutput {
    /// The `a` future completed with the contained value.
    A(Option<()>),
    /// The `b` future completed with the contained value.
    B(Option<()>),
    /// The `c` future completed with the contained value (only produced by `ThreeSelector`).
    C(tokio::io::Result<()>),
}
63
/// A minimal hand-rolled "select" over two `Unpin` futures which both yield `Option<()>`.
///
/// Awaiting it yields a `SelectorOutput` naming whichever future was found ready first.
pub(crate) struct TwoSelector<
    A: Future<Output = Option<()>> + Unpin,
    B: Future<Output = Option<()>> + Unpin,
> {
    /// First future; it is polled before `b`.
    pub a: A,
    /// Second future; only polled when `a` is still pending.
    pub b: B,
}
71
72impl<A: Future<Output = Option<()>> + Unpin, B: Future<Output = Option<()>> + Unpin> Future
73 for TwoSelector<A, B>
74{
75 type Output = SelectorOutput;
76 fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll<SelectorOutput> {
77 match Pin::new(&mut self.a).poll(ctx) {
78 Poll::Ready(res) => {
79 return Poll::Ready(SelectorOutput::A(res));
80 },
81 Poll::Pending => {},
82 }
83 match Pin::new(&mut self.b).poll(ctx) {
84 Poll::Ready(res) => {
85 return Poll::Ready(SelectorOutput::B(res));
86 },
87 Poll::Pending => {},
88 }
89 Poll::Pending
90 }
91}
92
/// A minimal hand-rolled "select" over three `Unpin` futures: two yielding `Option<()>` and
/// one yielding a `tokio::io::Result<()>`.
///
/// Awaiting it yields a `SelectorOutput` naming whichever future was found ready first.
pub(crate) struct ThreeSelector<
    A: Future<Output = Option<()>> + Unpin,
    B: Future<Output = Option<()>> + Unpin,
    C: Future<Output = tokio::io::Result<()>> + Unpin,
> {
    /// First future; it is polled before `b` and `c`.
    pub a: A,
    /// Second future; only polled when `a` is still pending.
    pub b: B,
    /// Third future; only polled when both `a` and `b` are still pending.
    pub c: C,
}
102
103impl<
104 A: Future<Output = Option<()>> + Unpin,
105 B: Future<Output = Option<()>> + Unpin,
106 C: Future<Output = tokio::io::Result<()>> + Unpin,
107 > Future for ThreeSelector<A, B, C>
108{
109 type Output = SelectorOutput;
110 fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll<SelectorOutput> {
111 match Pin::new(&mut self.a).poll(ctx) {
112 Poll::Ready(res) => {
113 return Poll::Ready(SelectorOutput::A(res));
114 },
115 Poll::Pending => {},
116 }
117 match Pin::new(&mut self.b).poll(ctx) {
118 Poll::Ready(res) => {
119 return Poll::Ready(SelectorOutput::B(res));
120 },
121 Poll::Pending => {},
122 }
123 match Pin::new(&mut self.c).poll(ctx) {
124 Poll::Ready(res) => {
125 return Poll::Ready(SelectorOutput::C(res));
126 },
127 Poll::Pending => {},
128 }
129 Poll::Pending
130 }
131}
132
/// Shared per-socket state, held behind an `Arc<Mutex<_>>` by the read task and by every
/// `SocketDescriptor` created for the connection.
struct Connection {
    // The stream used for writing. `schedule_read` takes this (leaving `None`) when its loop
    // exits, after which `send_data` returns 0 without touching the socket.
    writer: Option<Arc<TcpStream>>,
    // Sending half of a depth-1 channel whose receiving half is selected on in
    // `schedule_read`; each message received there results in a call to
    // `write_buffer_space_avail`. It is sent to by the write-ready waker built in
    // `write_avail_to_waker` (via the clone held by `SocketDescriptor`) and by
    // `disconnect_socket`.
    write_avail: mpsc::Sender<()>,
    // Sending half of a depth-1 channel used to wake the read loop so that it restarts and
    // re-reads `read_paused`. Sent to from `send_data`.
    read_waker: mpsc::Sender<()>,
    // While `true`, `schedule_read` does not wait for the socket to become readable.
    read_paused: bool,
    // Set by `disconnect_socket`; checked at the top of every `schedule_read` iteration.
    rl_requested_disconnect: bool,
    // Value drawn from `ID_COUNTER` when the connection was created.
    id: u64,
}
impl Connection {
    /// Calls `process_events()` on `peer_manager` once for every message received on
    /// `event_receiver`, and returns once that channel reports it is closed.
    async fn poll_event_process<PM: Deref + 'static + Send + Sync>(
        peer_manager: PM, mut event_receiver: mpsc::Receiver<()>,
    ) where
        PM::Target: APeerManager<Descriptor = SocketDescriptor>,
    {
        loop {
            if event_receiver.recv().await.is_none() {
                return;
            }
            peer_manager.as_ref().process_events();
        }
    }

    /// The per-connection read loop.
    ///
    /// Repeatedly waits for one of: a write-available notification, a read wake-up, or (when
    /// reads are not paused) the socket becoming readable, and forwards the event to
    /// `peer_manager`. When the loop ends the writer is taken out of `us` and, unless the
    /// disconnect was initiated through the `PeerManager`, `socket_disconnected` is called.
    ///
    /// * `us` - shared connection state.
    /// * `reader` - the stream to read from.
    /// * `read_wake_receiver` / `write_avail_receiver` - receiving halves of the channels
    ///   whose senders live in `us` (`read_waker` and `write_avail` respectively).
    async fn schedule_read<PM: Deref + 'static + Send + Sync + Clone>(
        peer_manager: PM, us: Arc<Mutex<Self>>, reader: Arc<TcpStream>,
        mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>,
    ) where
        PM::Target: APeerManager<Descriptor = SocketDescriptor>,
    {
        // Event processing runs in its own task (`poll_event_process`); `event_waker` is how
        // this loop pokes it. The task ends once `event_waker` is dropped.
        let (event_waker, event_receiver) = mpsc::channel(1);
        tokio::spawn(Self::poll_event_process(peer_manager.clone(), event_receiver));

        // At most 4KiB is handed to the `PeerManager` per iteration.
        let mut buf = [0; 4096];

        let mut our_descriptor = SocketDescriptor::new(us.clone());
        // Why the loop below ended.
        enum Disconnect {
            // The disconnect came from the `PeerManager` side: one of its calls returned an
            // error or `disconnect_socket` was invoked. `socket_disconnected` is NOT called in
            // this case.
            CloseConnection,
            // The socket itself failed or was closed by the peer. `socket_disconnected` IS
            // called in this case so that the `PeerManager` learns about it.
            PeerDisconnected,
        }
        let disconnect_type = loop {
            // The lock is only held inside this block, never across an await point.
            let read_paused = {
                let us_lock = us.lock().unwrap();
                if us_lock.rl_requested_disconnect {
                    break Disconnect::CloseConnection;
                }
                us_lock.read_paused
            };
            // While reads are paused we do not wait on socket readability at all.
            let select_result = if read_paused {
                TwoSelector {
                    a: Box::pin(write_avail_receiver.recv()),
                    b: Box::pin(read_wake_receiver.recv()),
                }
                .await
            } else {
                ThreeSelector {
                    a: Box::pin(write_avail_receiver.recv()),
                    b: Box::pin(read_wake_receiver.recv()),
                    c: Box::pin(reader.readable()),
                }
                .await
            };
            match select_result {
                SelectorOutput::A(v) => {
                    // A sender for this channel is stored in `us`, which we hold, so the
                    // channel is not expected to be closed here.
                    assert!(v.is_some());
                    if peer_manager.as_ref().write_buffer_space_avail(&mut our_descriptor).is_err()
                    {
                        break Disconnect::CloseConnection;
                    }
                },
                SelectorOutput::B(some) => {
                    // Nothing to do beyond restarting the loop, which re-reads `read_paused`.
                    // As above, the sender lives in `us`, so `None` is unexpected.
                    debug_assert!(some.is_some());
                },
                SelectorOutput::C(res) => {
                    if res.is_err() {
                        break Disconnect::PeerDisconnected;
                    }
                    match reader.try_read(&mut buf) {
                        // A zero-length read is treated as the peer having closed the socket.
                        Ok(0) => break Disconnect::PeerDisconnected,
                        Ok(len) => {
                            let read_res =
                                peer_manager.as_ref().read_event(&mut our_descriptor, &buf[0..len]);
                            // Note the lock is taken only after `read_event` has returned.
                            let mut us_lock = us.lock().unwrap();
                            match read_res {
                                Ok(pause_read) => {
                                    if pause_read {
                                        us_lock.read_paused = true;
                                    }
                                },
                                Err(_) => break Disconnect::CloseConnection,
                            }
                        },
                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                            // The readiness notification did not result in any data (e.g. a
                            // spurious wake-up); simply go around the loop again.
                        },
                        Err(_) => break Disconnect::PeerDisconnected,
                    }
                },
            }
            // Ask the event-processing task to run. A failed send is ignored; with a depth-1
            // channel that typically just means a wake-up is already queued.
            let _ = event_waker.try_send(());

            // Give other tasks a chance to run before handling more data for this peer.
            let _ = tokio::task::yield_now().await;
        };
        // Drop our writer handle so that subsequent `send_data` calls return 0.
        us.lock().unwrap().writer.take();
        if let Disconnect::PeerDisconnected = disconnect_type {
            peer_manager.as_ref().socket_disconnected(&our_descriptor);
            peer_manager.as_ref().process_events();
        }
    }

    /// Wraps `stream` for use with tokio and builds the shared connection state.
    ///
    /// Returns, in order: the tokio stream (for reading), the receiver paired with
    /// `write_avail`, the receiver paired with `read_waker`, and the shared `Connection`.
    ///
    /// # Panics
    ///
    /// Panics if the stream cannot be switched to non-blocking mode or cannot be converted
    /// into a tokio `TcpStream`.
    fn new(
        stream: StdTcpStream,
    ) -> (Arc<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
        // Both channels only carry "wake up" signals, so a depth of 1 is used: senders use
        // `try_send` and ignore failures, treating an already-queued signal as sufficient.
        let (write_avail, write_receiver) = mpsc::channel(1);
        let (read_waker, read_receiver) = mpsc::channel(1);
        stream.set_nonblocking(true).unwrap();
        let tokio_stream = Arc::new(TcpStream::from_std(stream).unwrap());

        let id = ID_COUNTER.fetch_add(1, Ordering::AcqRel);
        // Reader and writer share the same underlying stream.
        let writer = Some(Arc::clone(&tokio_stream));
        let conn = Arc::new(Mutex::new(Self {
            writer,
            write_avail,
            read_waker,
            read_paused: false,
            rl_requested_disconnect: false,
            id,
        }));
        (tokio_stream, write_receiver, read_receiver, conn)
    }
}
309
310fn get_addr_from_stream(stream: &StdTcpStream) -> Option<SocketAddress> {
311 match stream.peer_addr() {
312 Ok(SocketAddr::V4(sockaddr)) => {
313 Some(SocketAddress::TcpIpV4 { addr: sockaddr.ip().octets(), port: sockaddr.port() })
314 },
315 Ok(SocketAddr::V6(sockaddr)) => {
316 Some(SocketAddress::TcpIpV6 { addr: sockaddr.ip().octets(), port: sockaddr.port() })
317 },
318 Err(_) => None,
319 }
320}
321
322pub fn setup_inbound<PM: Deref + 'static + Send + Sync + Clone>(
329 peer_manager: PM, stream: StdTcpStream,
330) -> impl std::future::Future<Output = ()>
331where
332 PM::Target: APeerManager<Descriptor = SocketDescriptor>,
333{
334 let remote_addr = get_addr_from_stream(&stream);
335 let (reader, write_receiver, read_receiver, us) = Connection::new(stream);
336 #[cfg(test)]
337 let last_us = Arc::clone(&us);
338
339 let handle_opt = if peer_manager
340 .as_ref()
341 .new_inbound_connection(SocketDescriptor::new(us.clone()), remote_addr)
342 .is_ok()
343 {
344 let handle = tokio::spawn(Connection::schedule_read(
345 peer_manager,
346 us,
347 reader,
348 read_receiver,
349 write_receiver,
350 ));
351 Some(handle)
352 } else {
353 None
356 };
357
358 async move {
359 if let Some(handle) = handle_opt {
360 if let Err(e) = handle.await {
361 assert!(e.is_cancelled());
362 } else {
363 #[cfg(test)]
369 debug_assert!(Arc::try_unwrap(last_us).is_ok());
370 }
371 }
372 }
373}
374
375pub fn setup_outbound<PM: Deref + 'static + Send + Sync + Clone>(
383 peer_manager: PM, their_node_id: PublicKey, stream: StdTcpStream,
384) -> impl std::future::Future<Output = ()>
385where
386 PM::Target: APeerManager<Descriptor = SocketDescriptor>,
387{
388 let remote_addr = get_addr_from_stream(&stream);
389 let (reader, mut write_receiver, read_receiver, us) = Connection::new(stream);
390 #[cfg(test)]
391 let last_us = Arc::clone(&us);
392 let handle_opt = if let Ok(initial_send) = peer_manager.as_ref().new_outbound_connection(
393 their_node_id,
394 SocketDescriptor::new(us.clone()),
395 remote_addr,
396 ) {
397 let handle = tokio::spawn(async move {
398 let send_fut = async {
404 loop {
405 match SocketDescriptor::new(us.clone()).send_data(&initial_send, true) {
406 v if v == initial_send.len() => break Ok(()),
407 0 => {
408 write_receiver.recv().await;
409 },
413 _ => {
414 eprintln!("Failed to write first full message to socket!");
415 peer_manager
416 .as_ref()
417 .socket_disconnected(&SocketDescriptor::new(Arc::clone(&us)));
418 break Err(());
419 },
420 }
421 }
422 };
423 let timeout_send_fut = tokio::time::timeout(Duration::from_millis(100), send_fut);
424 if let Ok(Ok(())) = timeout_send_fut.await {
425 Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver)
426 .await;
427 }
428 });
429 Some(handle)
430 } else {
431 None
434 };
435
436 async move {
437 if let Some(handle) = handle_opt {
438 if let Err(e) = handle.await {
439 assert!(e.is_cancelled());
440 } else {
441 #[cfg(test)]
447 debug_assert!(Arc::try_unwrap(last_us).is_ok());
448 }
449 }
450 }
451}
452
453pub async fn connect_outbound<PM: Deref + 'static + Send + Sync + Clone>(
465 peer_manager: PM, their_node_id: PublicKey, addr: SocketAddr,
466) -> Option<impl std::future::Future<Output = ()>>
467where
468 PM::Target: APeerManager<Descriptor = SocketDescriptor>,
469{
470 let connect_fut = async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) };
471 if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), connect_fut).await {
472 Some(setup_outbound(peer_manager, their_node_id, stream))
473 } else {
474 None
475 }
476}
477
// Vtable for the `Waker` handed to tokio when polling a socket for write-readiness (see
// `send_data`). The data pointer paired with it is produced by `Arc::into_raw` on an
// `Arc<mpsc::Sender<()>>` in `write_avail_to_waker`. The entries are, in order: clone, wake,
// wake_by_ref, drop.
const SOCK_WAKER_VTABLE: task::RawWakerVTable = task::RawWakerVTable::new(
    clone_socket_waker,
    wake_socket_waker,
    wake_socket_waker_by_ref,
    drop_socket_waker,
);
484
485fn clone_socket_waker(orig_ptr: *const ()) -> task::RawWaker {
486 let new_waker = unsafe { Arc::from_raw(orig_ptr as *const mpsc::Sender<()>) };
487 let res = write_avail_to_waker(&new_waker);
488 let _ = Arc::into_raw(new_waker);
490 res
491}
492fn wake_socket_waker(orig_ptr: *const ()) {
497 let sender = unsafe { &mut *(orig_ptr as *mut mpsc::Sender<()>) };
498 let _ = sender.try_send(());
499 drop_socket_waker(orig_ptr);
500}
501fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
502 let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
503 let sender = unsafe { &*sender_ptr };
504 let _ = sender.try_send(());
505}
506fn drop_socket_waker(orig_ptr: *const ()) {
507 let _orig_arc = unsafe { Arc::from_raw(orig_ptr as *mut mpsc::Sender<()>) };
508 }
510fn write_avail_to_waker(sender: &Arc<mpsc::Sender<()>>) -> task::RawWaker {
511 let new_ptr = Arc::into_raw(Arc::clone(&sender));
512 task::RawWaker::new(new_ptr as *const (), &SOCK_WAKER_VTABLE)
513}
514
/// The socket descriptor handed to the `PeerManager` for connections managed by this crate.
///
/// It is a handle onto shared per-connection state: clones refer to the same connection, and
/// equality and hashing consider only the connection's numeric id.
pub struct SocketDescriptor {
    // Shared state of the underlying connection.
    conn: Arc<Mutex<Connection>>,
    // A clone of the connection's `write_avail` sender, kept in its own `Arc` so that it can
    // back the write-ready `Waker` (see `write_avail_to_waker`).
    write_avail_sender: Arc<mpsc::Sender<()>>,
    // Copied from the connection's `id` at construction.
    id: u64,
}
526impl SocketDescriptor {
527 fn new(conn: Arc<Mutex<Connection>>) -> Self {
528 let (id, write_avail_sender) = {
529 let us = conn.lock().unwrap();
530 (us.id, Arc::new(us.write_avail.clone()))
531 };
532 Self { conn, id, write_avail_sender }
533 }
534}
535impl peer_handler::SocketDescriptor for SocketDescriptor {
536 fn send_data(&mut self, data: &[u8], resume_read: bool) -> usize {
537 let mut us = self.conn.lock().unwrap();
542 if us.writer.is_none() {
543 return 0;
545 }
546
547 if resume_read && us.read_paused {
548 us.read_paused = false;
552 let _ = us.read_waker.try_send(());
553 }
554 if data.is_empty() {
555 return 0;
556 }
557 let waker =
558 unsafe { task::Waker::from_raw(write_avail_to_waker(&self.write_avail_sender)) };
559 let mut ctx = task::Context::from_waker(&waker);
560 let mut written_len = 0;
561 loop {
562 match us.writer.as_ref().unwrap().poll_write_ready(&mut ctx) {
563 task::Poll::Ready(Ok(())) => {
564 match us.writer.as_ref().unwrap().try_write(&data[written_len..]) {
565 Ok(res) => {
566 debug_assert_ne!(res, 0);
567 written_len += res;
568 if written_len == data.len() {
569 return written_len;
570 }
571 },
572 Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
573 continue;
574 },
575 Err(_) => return written_len,
576 }
577 },
578 task::Poll::Ready(Err(_)) => return written_len,
579 task::Poll::Pending => {
580 us.read_paused = true;
584 let _ = us.read_waker.try_send(());
587 return written_len;
588 },
589 }
590 }
591 }
592
593 fn disconnect_socket(&mut self) {
594 let mut us = self.conn.lock().unwrap();
595 us.rl_requested_disconnect = true;
596 let _ = us.write_avail.try_send(());
598 }
599}
600impl Clone for SocketDescriptor {
601 fn clone(&self) -> Self {
602 Self {
603 conn: Arc::clone(&self.conn),
604 id: self.id,
605 write_avail_sender: Arc::clone(&self.write_avail_sender),
606 }
607 }
608}
609impl Eq for SocketDescriptor {}
610impl PartialEq for SocketDescriptor {
611 fn eq(&self, o: &Self) -> bool {
612 self.id == o.id
613 }
614}
615impl Hash for SocketDescriptor {
616 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
617 self.id.hash(state);
618 }
619}
620
#[cfg(test)]
mod tests {
    use bitcoin::constants::ChainHash;
    use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
    use bitcoin::Network;
    use lightning::events::*;
    use lightning::ln::msgs::*;
    use lightning::ln::peer_handler::{IgnoringMessageHandler, MessageHandler, PeerManager};
    use lightning::routing::gossip::NodeId;
    use lightning::types::features::*;
    use lightning::util::test_utils::TestNodeSigner;

    use tokio::sync::mpsc;

    use std::mem;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::{Arc, Mutex};
    use std::time::Duration;

    /// Logger which prints every record to stdout.
    pub struct TestLogger();
    impl lightning::util::logger::Logger for TestLogger {
        fn log(&self, record: lightning::util::logger::Record) {
            println!(
                "{:<5} [{} : {}, {}] {}",
                record.level.to_string(),
                record.module_path,
                record.file,
                record.line,
                record.args
            );
        }
    }

    /// Message handler which reports connection and disconnection of one expected peer over
    /// channels, and lets tests queue `MessageSendEvent`s for the `PeerManager` to pick up.
    struct MsgHandler {
        // The only peer whose (dis)connection is reported.
        expected_pubkey: PublicKey,
        // Signalled from `peer_connected` for the expected peer.
        pubkey_connected: mpsc::Sender<()>,
        // Signalled from `peer_disconnected` for the expected peer.
        pubkey_disconnected: mpsc::Sender<()>,
        // Set to `true` once the expected peer has disconnected.
        disconnected_flag: AtomicBool,
        // Events returned (and cleared) by `get_and_clear_pending_msg_events`.
        msg_events: Mutex<Vec<MessageSendEvent>>,
    }
    // Routing messages are accepted and otherwise ignored; nothing is ever offered for relay.
    impl RoutingMessageHandler for MsgHandler {
        fn handle_node_announcement(
            &self, _their_node_id: Option<PublicKey>, _msg: &NodeAnnouncement,
        ) -> Result<bool, LightningError> {
            Ok(false)
        }
        fn handle_channel_announcement(
            &self, _their_node_id: Option<PublicKey>, _msg: &ChannelAnnouncement,
        ) -> Result<bool, LightningError> {
            Ok(false)
        }
        fn handle_channel_update(
            &self, _their_node_id: Option<PublicKey>, _msg: &ChannelUpdate,
        ) -> Result<bool, LightningError> {
            Ok(false)
        }
        fn get_next_channel_announcement(
            &self, _starting_point: u64,
        ) -> Option<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> {
            None
        }
        fn get_next_node_announcement(
            &self, _starting_point: Option<&NodeId>,
        ) -> Option<NodeAnnouncement> {
            None
        }
        fn peer_connected(
            &self, _their_node_id: PublicKey, _init_msg: &Init, _inbound: bool,
        ) -> Result<(), ()> {
            Ok(())
        }
        fn handle_reply_channel_range(
            &self, _their_node_id: PublicKey, _msg: ReplyChannelRange,
        ) -> Result<(), LightningError> {
            Ok(())
        }
        fn handle_reply_short_channel_ids_end(
            &self, _their_node_id: PublicKey, _msg: ReplyShortChannelIdsEnd,
        ) -> Result<(), LightningError> {
            Ok(())
        }
        fn handle_query_channel_range(
            &self, _their_node_id: PublicKey, _msg: QueryChannelRange,
        ) -> Result<(), LightningError> {
            Ok(())
        }
        fn handle_query_short_channel_ids(
            &self, _their_node_id: PublicKey, _msg: QueryShortChannelIds,
        ) -> Result<(), LightningError> {
            Ok(())
        }
        fn provided_node_features(&self) -> NodeFeatures {
            NodeFeatures::empty()
        }
        fn provided_init_features(&self, _their_node_id: PublicKey) -> InitFeatures {
            InitFeatures::empty()
        }
        fn processing_queue_high(&self) -> bool {
            false
        }
    }
    // Channel messages are ignored; only peer connection and disconnection are acted upon.
    impl ChannelMessageHandler for MsgHandler {
        fn handle_open_channel(&self, _their_node_id: PublicKey, _msg: &OpenChannel) {}
        fn handle_accept_channel(&self, _their_node_id: PublicKey, _msg: &AcceptChannel) {}
        fn handle_funding_created(&self, _their_node_id: PublicKey, _msg: &FundingCreated) {}
        fn handle_funding_signed(&self, _their_node_id: PublicKey, _msg: &FundingSigned) {}
        fn handle_channel_ready(&self, _their_node_id: PublicKey, _msg: &ChannelReady) {}
        fn handle_shutdown(&self, _their_node_id: PublicKey, _msg: &Shutdown) {}
        fn handle_closing_signed(&self, _their_node_id: PublicKey, _msg: &ClosingSigned) {}
        fn handle_update_add_htlc(&self, _their_node_id: PublicKey, _msg: &UpdateAddHTLC) {}
        fn handle_update_fulfill_htlc(&self, _their_node_id: PublicKey, _msg: &UpdateFulfillHTLC) {}
        fn handle_update_fail_htlc(&self, _their_node_id: PublicKey, _msg: &UpdateFailHTLC) {}
        fn handle_update_fail_malformed_htlc(
            &self, _their_node_id: PublicKey, _msg: &UpdateFailMalformedHTLC,
        ) {
        }
        fn handle_commitment_signed(&self, _their_node_id: PublicKey, _msg: &CommitmentSigned) {}
        fn handle_revoke_and_ack(&self, _their_node_id: PublicKey, _msg: &RevokeAndACK) {}
        fn handle_update_fee(&self, _their_node_id: PublicKey, _msg: &UpdateFee) {}
        fn handle_announcement_signatures(
            &self, _their_node_id: PublicKey, _msg: &AnnouncementSignatures,
        ) {
        }
        fn handle_channel_update(&self, _their_node_id: PublicKey, _msg: &ChannelUpdate) {}
        fn handle_open_channel_v2(&self, _their_node_id: PublicKey, _msg: &OpenChannelV2) {}
        fn handle_accept_channel_v2(&self, _their_node_id: PublicKey, _msg: &AcceptChannelV2) {}
        fn handle_stfu(&self, _their_node_id: PublicKey, _msg: &Stfu) {}
        #[cfg(splicing)]
        fn handle_splice_init(&self, _their_node_id: PublicKey, _msg: &SpliceInit) {}
        #[cfg(splicing)]
        fn handle_splice_ack(&self, _their_node_id: PublicKey, _msg: &SpliceAck) {}
        #[cfg(splicing)]
        fn handle_splice_locked(&self, _their_node_id: PublicKey, _msg: &SpliceLocked) {}
        fn handle_tx_add_input(&self, _their_node_id: PublicKey, _msg: &TxAddInput) {}
        fn handle_tx_add_output(&self, _their_node_id: PublicKey, _msg: &TxAddOutput) {}
        fn handle_tx_remove_input(&self, _their_node_id: PublicKey, _msg: &TxRemoveInput) {}
        fn handle_tx_remove_output(&self, _their_node_id: PublicKey, _msg: &TxRemoveOutput) {}
        fn handle_tx_complete(&self, _their_node_id: PublicKey, _msg: &TxComplete) {}
        fn handle_tx_signatures(&self, _their_node_id: PublicKey, _msg: &TxSignatures) {}
        fn handle_tx_init_rbf(&self, _their_node_id: PublicKey, _msg: &TxInitRbf) {}
        fn handle_tx_ack_rbf(&self, _their_node_id: PublicKey, _msg: &TxAckRbf) {}
        fn handle_tx_abort(&self, _their_node_id: PublicKey, _msg: &TxAbort) {}
        // Records the disconnect and notifies the test; the `unwrap` makes a full or closed
        // notification channel a test failure.
        fn peer_disconnected(&self, their_node_id: PublicKey) {
            if their_node_id == self.expected_pubkey {
                self.disconnected_flag.store(true, Ordering::SeqCst);
                self.pubkey_disconnected.clone().try_send(()).unwrap();
            }
        }
        // Notifies the test that the expected peer connected.
        fn peer_connected(
            &self, their_node_id: PublicKey, _init_msg: &Init, _inbound: bool,
        ) -> Result<(), ()> {
            if their_node_id == self.expected_pubkey {
                self.pubkey_connected.clone().try_send(()).unwrap();
            }
            Ok(())
        }
        fn handle_channel_reestablish(&self, _their_node_id: PublicKey, _msg: &ChannelReestablish) {
        }
        fn handle_error(&self, _their_node_id: PublicKey, _msg: &ErrorMessage) {}
        fn provided_node_features(&self) -> NodeFeatures {
            NodeFeatures::empty()
        }
        fn provided_init_features(&self, _their_node_id: PublicKey) -> InitFeatures {
            InitFeatures::empty()
        }
        fn get_chain_hashes(&self) -> Option<Vec<ChainHash>> {
            Some(vec![ChainHash::using_genesis_block(Network::Testnet)])
        }
        fn message_received(&self) {}
    }
    impl MessageSendEventsProvider for MsgHandler {
        // Hands back everything queued in `msg_events`, leaving the queue empty.
        fn get_and_clear_pending_msg_events(&self) -> Vec<MessageSendEvent> {
            let mut ret = Vec::new();
            mem::swap(&mut *self.msg_events.lock().unwrap(), &mut ret);
            ret
        }
    }

    /// Creates a connected pair of localhost TCP streams, returned as
    /// `(connecting side, accepting side)`.
    ///
    /// A fixed list of ports is tried in order; panics if none of them can be bound.
    fn make_tcp_connection() -> (std::net::TcpStream, std::net::TcpStream) {
        if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:9735") {
            (std::net::TcpStream::connect("127.0.0.1:9735").unwrap(), listener.accept().unwrap().0)
        } else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:19735") {
            (std::net::TcpStream::connect("127.0.0.1:19735").unwrap(), listener.accept().unwrap().0)
        } else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:9997") {
            (std::net::TcpStream::connect("127.0.0.1:9997").unwrap(), listener.accept().unwrap().0)
        } else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:9998") {
            (std::net::TcpStream::connect("127.0.0.1:9998").unwrap(), listener.accept().unwrap().0)
        } else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:9999") {
            (std::net::TcpStream::connect("127.0.0.1:9999").unwrap(), listener.accept().unwrap().0)
        } else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:46926") {
            (std::net::TcpStream::connect("127.0.0.1:46926").unwrap(), listener.accept().unwrap().0)
        } else {
            panic!("Failed to bind to v4 localhost on common ports");
        }
    }

    /// Connects two `PeerManager`s over a local TCP socket, waits for both to report the
    /// connection, has A disconnect B, and checks that both sides observe the disconnect and
    /// that both connection futures complete.
    async fn do_basic_connection_test() {
        let secp_ctx = Secp256k1::new();
        // NOTE(review): both keys are built from the same bytes, so `a_pub == b_pub` —
        // confirm this is intentional.
        let a_key = SecretKey::from_slice(&[1; 32]).unwrap();
        let b_key = SecretKey::from_slice(&[1; 32]).unwrap();
        let a_pub = PublicKey::from_secret_key(&secp_ctx, &a_key);
        let b_pub = PublicKey::from_secret_key(&secp_ctx, &b_key);

        let (a_connected_sender, mut a_connected) = mpsc::channel(1);
        let (a_disconnected_sender, mut a_disconnected) = mpsc::channel(1);
        let a_handler = Arc::new(MsgHandler {
            expected_pubkey: b_pub,
            pubkey_connected: a_connected_sender,
            pubkey_disconnected: a_disconnected_sender,
            disconnected_flag: AtomicBool::new(false),
            msg_events: Mutex::new(Vec::new()),
        });
        let a_msg_handler = MessageHandler {
            chan_handler: Arc::clone(&a_handler),
            route_handler: Arc::clone(&a_handler),
            onion_message_handler: Arc::new(IgnoringMessageHandler {}),
            custom_message_handler: Arc::new(IgnoringMessageHandler {}),
        };
        let a_manager = Arc::new(PeerManager::new(
            a_msg_handler,
            0,
            &[1; 32],
            Arc::new(TestLogger()),
            Arc::new(TestNodeSigner::new(a_key)),
        ));

        let (b_connected_sender, mut b_connected) = mpsc::channel(1);
        let (b_disconnected_sender, mut b_disconnected) = mpsc::channel(1);
        let b_handler = Arc::new(MsgHandler {
            expected_pubkey: a_pub,
            pubkey_connected: b_connected_sender,
            pubkey_disconnected: b_disconnected_sender,
            disconnected_flag: AtomicBool::new(false),
            msg_events: Mutex::new(Vec::new()),
        });
        let b_msg_handler = MessageHandler {
            chan_handler: Arc::clone(&b_handler),
            route_handler: Arc::clone(&b_handler),
            onion_message_handler: Arc::new(IgnoringMessageHandler {}),
            custom_message_handler: Arc::new(IgnoringMessageHandler {}),
        };
        let b_manager = Arc::new(PeerManager::new(
            b_msg_handler,
            0,
            &[2; 32],
            Arc::new(TestLogger()),
            Arc::new(TestNodeSigner::new(b_key)),
        ));

        let (conn_a, conn_b) = make_tcp_connection();

        // The futures are not polled until the end of the test; the connections make
        // progress regardless because the work runs in spawned tasks.
        let fut_a = super::setup_outbound(Arc::clone(&a_manager), b_pub, conn_a);
        let fut_b = super::setup_inbound(b_manager, conn_b);

        tokio::time::timeout(Duration::from_secs(10), a_connected.recv()).await.unwrap();
        tokio::time::timeout(Duration::from_secs(1), b_connected.recv()).await.unwrap();

        // Queue an event on A which asks its `PeerManager` to disconnect B.
        a_handler.msg_events.lock().unwrap().push(MessageSendEvent::HandleError {
            node_id: b_pub,
            action: ErrorAction::DisconnectPeer { msg: None },
        });
        assert!(!a_handler.disconnected_flag.load(Ordering::SeqCst));
        assert!(!b_handler.disconnected_flag.load(Ordering::SeqCst));

        a_manager.process_events();
        tokio::time::timeout(Duration::from_secs(10), a_disconnected.recv()).await.unwrap();
        tokio::time::timeout(Duration::from_secs(1), b_disconnected.recv()).await.unwrap();
        assert!(a_handler.disconnected_flag.load(Ordering::SeqCst));
        assert!(b_handler.disconnected_flag.load(Ordering::SeqCst));

        fut_a.await;
        fut_b.await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn basic_threaded_connection_test() {
        do_basic_connection_test().await;
    }

    #[tokio::test]
    async fn basic_unthreaded_connection_test() {
        do_basic_connection_test().await;
    }

    /// Hands sockets whose other end has already been dropped to `setup_inbound` and
    /// `setup_outbound`, exercising connection setup on an already-closed connection.
    async fn race_disconnect_accept() {
        let secp_ctx = Secp256k1::new();
        let a_key = SecretKey::from_slice(&[1; 32]).unwrap();
        let b_key = SecretKey::from_slice(&[2; 32]).unwrap();
        let b_pub = PublicKey::from_secret_key(&secp_ctx, &b_key);

        let a_msg_handler = MessageHandler {
            chan_handler: Arc::new(lightning::ln::peer_handler::ErroringMessageHandler::new()),
            onion_message_handler: Arc::new(IgnoringMessageHandler {}),
            route_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler {}),
            custom_message_handler: Arc::new(IgnoringMessageHandler {}),
        };
        let a_manager = Arc::new(PeerManager::new(
            a_msg_handler,
            0,
            &[1; 32],
            Arc::new(TestLogger()),
            Arc::new(TestNodeSigner::new(a_key)),
        ));

        // Two separate connections are made; in each case only one end is kept and the other
        // end is dropped (and so closed) immediately.
        let conn_a = {
            let (conn_a, _) = make_tcp_connection();
            conn_a
        };
        let conn_b = {
            let (_, conn_b) = make_tcp_connection();
            conn_b
        };

        // Run both setups inside freshly spawned tasks; the spawned tasks are not awaited.
        let manager_reference = Arc::clone(&a_manager);
        tokio::spawn(async move { super::setup_inbound(manager_reference, conn_a).await });
        tokio::spawn(async move { super::setup_outbound(a_manager, b_pub, conn_b).await });
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn threaded_race_disconnect_accept() {
        race_disconnect_accept().await;
    }

    #[tokio::test]
    async fn unthreaded_race_disconnect_accept() {
        race_disconnect_accept().await;
    }
}