1#![deny(rustdoc::broken_intra_doc_links)]
26#![deny(rustdoc::private_intra_doc_links)]
27#![deny(missing_docs)]
28#![cfg_attr(docsrs, feature(doc_cfg))]
29
30use bitcoin::secp256k1::PublicKey;
31
32use tokio::net::TcpStream;
33use tokio::sync::mpsc;
34use tokio::time;
35
36use lightning::ln::msgs::SocketAddress;
37use lightning::ln::peer_handler;
38use lightning::ln::peer_handler::APeerManager;
39use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait;
40
41use std::future::Future;
42use std::hash::Hash;
43use std::net::SocketAddr;
44use std::net::TcpStream as StdTcpStream;
45use std::ops::Deref;
46use std::pin::Pin;
47use std::sync::atomic::{AtomicU64, Ordering};
48use std::sync::{Arc, Mutex};
49use std::task::{self, Poll};
50use std::time::Duration;
51
/// Source of per-connection ids: each `Connection::new` takes the next value with `fetch_add`, and
/// `SocketDescriptor` equality/hashing is based on that id.
static ID_COUNTER: AtomicU64 = AtomicU64::new(0);
53
/// Result of awaiting a `TwoSelector` or `ThreeSelector`: identifies which inner future completed
/// (futures are polled in `a`, `b`, `c` order, so the earliest ready one wins) and carries its
/// output.
pub(crate) enum SelectorOutput {
	/// The `a` future completed with this output.
	A(Option<()>),
	/// The `b` future completed with this output.
	B(Option<()>),
	/// The `c` future (only present in `ThreeSelector`) completed with this output.
	C(tokio::io::Result<()>),
}
63
64pub(crate) struct TwoSelector<
65 A: Future<Output = Option<()>> + Unpin,
66 B: Future<Output = Option<()>> + Unpin,
67> {
68 pub a: A,
69 pub b: B,
70}
71
72impl<A: Future<Output = Option<()>> + Unpin, B: Future<Output = Option<()>> + Unpin> Future
73 for TwoSelector<A, B>
74{
75 type Output = SelectorOutput;
76 fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll<SelectorOutput> {
77 match Pin::new(&mut self.a).poll(ctx) {
78 Poll::Ready(res) => {
79 return Poll::Ready(SelectorOutput::A(res));
80 },
81 Poll::Pending => {},
82 }
83 match Pin::new(&mut self.b).poll(ctx) {
84 Poll::Ready(res) => {
85 return Poll::Ready(SelectorOutput::B(res));
86 },
87 Poll::Pending => {},
88 }
89 Poll::Pending
90 }
91}
92
93pub(crate) struct ThreeSelector<
94 A: Future<Output = Option<()>> + Unpin,
95 B: Future<Output = Option<()>> + Unpin,
96 C: Future<Output = tokio::io::Result<()>> + Unpin,
97> {
98 pub a: A,
99 pub b: B,
100 pub c: C,
101}
102
103impl<
104 A: Future<Output = Option<()>> + Unpin,
105 B: Future<Output = Option<()>> + Unpin,
106 C: Future<Output = tokio::io::Result<()>> + Unpin,
107 > Future for ThreeSelector<A, B, C>
108{
109 type Output = SelectorOutput;
110 fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll<SelectorOutput> {
111 match Pin::new(&mut self.a).poll(ctx) {
112 Poll::Ready(res) => {
113 return Poll::Ready(SelectorOutput::A(res));
114 },
115 Poll::Pending => {},
116 }
117 match Pin::new(&mut self.b).poll(ctx) {
118 Poll::Ready(res) => {
119 return Poll::Ready(SelectorOutput::B(res));
120 },
121 Poll::Pending => {},
122 }
123 match Pin::new(&mut self.c).poll(ctx) {
124 Poll::Ready(res) => {
125 return Poll::Ready(SelectorOutput::C(res));
126 },
127 Poll::Pending => {},
128 }
129 Poll::Pending
130 }
131}
132
/// Per-socket state shared (as `Arc<Mutex<Connection>>`) between the connection's read loop
/// (`Connection::schedule_read`) and every `SocketDescriptor` created for it.
struct Connection {
	// The write side of the socket. It is `take()`n (set to `None`) when the read loop exits;
	// `send_data` returns 0 once it is gone.
	writer: Option<Arc<TcpStream>>,
	// Signalled when the socket may be writable again (from the wakers built in `send_data`) or
	// when `disconnect_socket` wants the read loop to re-check `rl_requested_disconnect`. The
	// channel has capacity 1, so extra signals are simply dropped by `try_send`.
	write_avail: mpsc::Sender<()>,
	// Signalled by `send_data` to wake the read loop when reads are resumed after a pause, so it
	// picks up the new `read_paused` value.
	read_waker: mpsc::Sender<()>,
	// While true, the read loop does not wait on socket readability (so no `read_event` calls are
	// made). Set from `send_data`'s `continue_read` argument.
	read_paused: bool,
	// Set by `disconnect_socket`; the read loop exits with `CloseConnection` once it sees it.
	rl_requested_disconnect: bool,
	// Unique id (from `ID_COUNTER`) copied into each `SocketDescriptor` for equality/hashing.
	id: u64,
}
impl Connection {
	/// Calls `process_events` on `peer_manager` each time a message is received on
	/// `event_receiver`.
	///
	/// Returns once `recv()` yields `None`. Per tokio's mpsc semantics, that happens after every
	/// sender has been dropped; here the only sender is `schedule_read`'s `event_waker`.
	async fn poll_event_process<PM: Deref + 'static + Send + Sync>(
		peer_manager: PM, mut event_receiver: mpsc::Receiver<()>,
	) where
		PM::Target: APeerManager<Descriptor = SocketDescriptor>,
	{
		loop {
			if event_receiver.recv().await.is_none() {
				return;
			}
			peer_manager.as_ref().process_events();
		}
	}

	/// The per-connection read loop.
	///
	/// Spawns `poll_event_process` for this connection, then repeatedly waits on, in priority
	/// order:
	/// - `write_avail_receiver` — calls `write_buffer_space_avail` so the `PeerManager` can
	///   flush;
	/// - `read_wake_receiver` — does nothing except loop again, so a changed `read_paused` is
	///   picked up;
	/// - socket readability (skipped while `read_paused`) — reads up to 4096 bytes and passes
	///   them to `read_event`.
	///
	/// After each event the `poll_event_process` task is poked so it calls `process_events`.
	///
	/// The loop exits with one of two outcomes:
	/// - `CloseConnection` — `rl_requested_disconnect` is set, or a `PeerManager` call returns
	///   `Err`;
	/// - `PeerDisconnected` — the socket errors or reaches EOF.
	///
	/// On exit the writer is removed from the shared state. Only for `PeerDisconnected` is the
	/// `PeerManager` informed via `socket_disconnected`, followed by `process_events`.
	async fn schedule_read<PM: Deref + 'static + Send + Sync + Clone>(
		peer_manager: PM, us: Arc<Mutex<Self>>, reader: Arc<TcpStream>,
		mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>,
	) where
		PM::Target: APeerManager<Descriptor = SocketDescriptor>,
	{
		// Channel used to wake the `poll_event_process` task spawned below. That task exits once
		// `event_waker` is dropped at the end of this function.
		let (event_waker, event_receiver) = mpsc::channel(1);
		tokio::spawn(Self::poll_event_process(peer_manager.clone(), event_receiver));

		// Reused read buffer: at most 4KiB is handed to `read_event` per loop iteration.
		let mut buf = [0; 4096];

		let mut our_descriptor = SocketDescriptor::new(Arc::clone(&us));
		// Why the read loop stopped.
		enum Disconnect {
			// The `PeerManager` asked for the disconnect, either by returning an `Err` or via
			// `disconnect_socket` setting `rl_requested_disconnect`. We therefore do not call
			// `socket_disconnected`.
			CloseConnection,
			// The socket failed or was closed by the peer, so the `PeerManager` still needs to be
			// told via `socket_disconnected`.
			PeerDisconnected,
		}
		let disconnect_type = loop {
			// Check for a requested disconnect and snapshot `read_paused`. The lock is released at
			// the end of this block, before the await below.
			let read_paused = {
				let us_lock = us.lock().unwrap();
				if us_lock.rl_requested_disconnect {
					break Disconnect::CloseConnection;
				}
				us_lock.read_paused
			};
			// The selectors require `Unpin` futures, hence the `Box::pin`s. While reads are paused
			// we do not wait on socket readability at all.
			let select_result = if read_paused {
				TwoSelector {
					a: Box::pin(write_avail_receiver.recv()),
					b: Box::pin(read_wake_receiver.recv()),
				}
				.await
			} else {
				ThreeSelector {
					a: Box::pin(write_avail_receiver.recv()),
					b: Box::pin(read_wake_receiver.recv()),
					c: Box::pin(reader.readable()),
				}
				.await
			};
			match select_result {
				SelectorOutput::A(v) => {
					// The sending half (`write_avail`) is stored in `us`, which we hold, so the
					// channel cannot have been closed.
					assert!(v.is_some());
					if peer_manager.as_ref().write_buffer_space_avail(&mut our_descriptor).is_err()
					{
						break Disconnect::CloseConnection;
					}
				},
				SelectorOutput::B(some) => {
					// Nothing to do: this wake-up only makes the next iteration re-read
					// `read_paused`. The sender (`read_waker`) also lives in `us`, so `None` is
					// unexpected.
					debug_assert!(some.is_some());
				},
				SelectorOutput::C(res) => {
					if res.is_err() {
						break Disconnect::PeerDisconnected;
					}
					match reader.try_read(&mut buf) {
						// A zero-length read means the peer closed the connection.
						Ok(0) => break Disconnect::PeerDisconnected,
						Ok(len) => {
							let read_res =
								peer_manager.as_ref().read_event(&mut our_descriptor, &buf[0..len]);
							match read_res {
								Ok(()) => {},
								Err(_) => break Disconnect::CloseConnection,
							}
						},
						Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
							// tokio documents that `readable()` may complete spuriously, in which
							// case `try_read` reports `WouldBlock`; just go around the loop again.
						},
						Err(_) => break Disconnect::PeerDisconnected,
					}
				},
			}
			// Ask `poll_event_process` to run `process_events`. If a wake-up is already queued,
			// the capacity-1 channel makes `try_send` fail, which is fine.
			let _ = event_waker.try_send(());

			// Explicitly give up our timeslice between events, presumably so one busy peer cannot
			// starve other tasks.
			let _ = tokio::task::yield_now().await;
		};
		// Drop our reference to the write half; `send_data` returns 0 from now on.
		us.lock().unwrap().writer.take();
		if let Disconnect::PeerDisconnected = disconnect_type {
			peer_manager.as_ref().socket_disconnected(&our_descriptor);
			peer_manager.as_ref().process_events();
		}
	}

	/// Wraps `stream` (switched to non-blocking mode) in a tokio `TcpStream` and builds the shared
	/// connection state with a fresh id from `ID_COUNTER`.
	///
	/// Returns, in order:
	/// - the tokio stream, used as the read half;
	/// - the receiver for `write_avail`;
	/// - the receiver for `read_waker`;
	/// - the shared `Connection`.
	///
	/// Both channels have capacity 1. Panics if the stream cannot be made non-blocking or
	/// converted with `TcpStream::from_std` (both results are `unwrap`ped).
	fn new(
		stream: StdTcpStream,
	) -> (Arc<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
		let (write_avail, write_receiver) = mpsc::channel(1);
		let (read_waker, read_receiver) = mpsc::channel(1);
		stream.set_nonblocking(true).unwrap();
		let tokio_stream = Arc::new(TcpStream::from_std(stream).unwrap());

		let id = ID_COUNTER.fetch_add(1, Ordering::AcqRel);
		// The writer shares the same underlying stream as the returned read half.
		let writer = Some(Arc::clone(&tokio_stream));
		let conn = Arc::new(Mutex::new(Self {
			writer,
			write_avail,
			read_waker,
			read_paused: false,
			rl_requested_disconnect: false,
			id,
		}));
		(tokio_stream, write_receiver, read_receiver, conn)
	}
}
304
305fn get_addr_from_stream(stream: &StdTcpStream) -> Option<SocketAddress> {
306 match stream.peer_addr() {
307 Ok(SocketAddr::V4(sockaddr)) => {
308 Some(SocketAddress::TcpIpV4 { addr: sockaddr.ip().octets(), port: sockaddr.port() })
309 },
310 Ok(SocketAddr::V6(sockaddr)) => {
311 Some(SocketAddress::TcpIpV6 { addr: sockaddr.ip().octets(), port: sockaddr.port() })
312 },
313 Err(_) => None,
314 }
315}
316
317pub fn setup_inbound<PM: Deref + 'static + Send + Sync + Clone>(
324 peer_manager: PM, stream: StdTcpStream,
325) -> impl std::future::Future<Output = ()>
326where
327 PM::Target: APeerManager<Descriptor = SocketDescriptor>,
328{
329 let remote_addr = get_addr_from_stream(&stream);
330 let (reader, write_receiver, read_receiver, us) = Connection::new(stream);
331 #[cfg(test)]
332 let last_us = Arc::clone(&us);
333
334 let handle_opt = if peer_manager
335 .as_ref()
336 .new_inbound_connection(SocketDescriptor::new(Arc::clone(&us)), remote_addr)
337 .is_ok()
338 {
339 let handle = tokio::spawn(Connection::schedule_read(
340 peer_manager,
341 us,
342 reader,
343 read_receiver,
344 write_receiver,
345 ));
346 Some(handle)
347 } else {
348 None
351 };
352
353 async move {
354 if let Some(handle) = handle_opt {
355 if let Err(e) = handle.await {
356 assert!(e.is_cancelled());
357 } else {
358 #[cfg(test)]
364 debug_assert!(Arc::try_unwrap(last_us).is_ok());
365 }
366 }
367 }
368}
369
370pub fn setup_outbound<PM: Deref + 'static + Send + Sync + Clone>(
378 peer_manager: PM, their_node_id: PublicKey, stream: StdTcpStream,
379) -> impl std::future::Future<Output = ()>
380where
381 PM::Target: APeerManager<Descriptor = SocketDescriptor>,
382{
383 let remote_addr = get_addr_from_stream(&stream);
384 let (reader, mut write_receiver, read_receiver, us) = Connection::new(stream);
385 #[cfg(test)]
386 let last_us = Arc::clone(&us);
387 let handle_opt = if let Ok(initial_send) = peer_manager.as_ref().new_outbound_connection(
388 their_node_id,
389 SocketDescriptor::new(Arc::clone(&us)),
390 remote_addr,
391 ) {
392 let handle = tokio::spawn(async move {
393 let send_fut = async {
399 loop {
400 match SocketDescriptor::new(Arc::clone(&us)).send_data(&initial_send, true) {
401 v if v == initial_send.len() => break Ok(()),
402 0 => {
403 write_receiver.recv().await;
404 },
408 _ => {
409 eprintln!("Failed to write first full message to socket!");
410 peer_manager
411 .as_ref()
412 .socket_disconnected(&SocketDescriptor::new(Arc::clone(&us)));
413 break Err(());
414 },
415 }
416 }
417 };
418 let timeout_send_fut = tokio::time::timeout(Duration::from_millis(100), send_fut);
419 if let Ok(Ok(())) = timeout_send_fut.await {
420 Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver)
421 .await;
422 }
423 });
424 Some(handle)
425 } else {
426 None
429 };
430
431 async move {
432 if let Some(handle) = handle_opt {
433 if let Err(e) = handle.await {
434 assert!(e.is_cancelled());
435 } else {
436 #[cfg(test)]
442 debug_assert!(Arc::try_unwrap(last_us).is_ok());
443 }
444 }
445 }
446}
447
448pub async fn connect_outbound<PM: Deref + 'static + Send + Sync + Clone>(
460 peer_manager: PM, their_node_id: PublicKey, addr: SocketAddr,
461) -> Option<impl std::future::Future<Output = ()>>
462where
463 PM::Target: APeerManager<Descriptor = SocketDescriptor>,
464{
465 let connect_fut = async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) };
466 if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), connect_fut).await {
467 Some(setup_outbound(peer_manager, their_node_id, stream))
468 } else {
469 None
470 }
471}
472
473const SOCK_WAKER_VTABLE: task::RawWakerVTable = task::RawWakerVTable::new(
474 clone_socket_waker,
475 wake_socket_waker,
476 wake_socket_waker_by_ref,
477 drop_socket_waker,
478);
479
480fn clone_socket_waker(orig_ptr: *const ()) -> task::RawWaker {
481 let new_waker = unsafe { Arc::from_raw(orig_ptr as *const mpsc::Sender<()>) };
482 let res = write_avail_to_waker(&new_waker);
483 let _ = Arc::into_raw(new_waker);
485 res
486}
487fn wake_socket_waker(orig_ptr: *const ()) {
492 let sender = unsafe { &mut *(orig_ptr as *mut mpsc::Sender<()>) };
493 let _ = sender.try_send(());
494 drop_socket_waker(orig_ptr);
495}
496fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
497 let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
498 let sender = unsafe { &*sender_ptr };
499 let _ = sender.try_send(());
500}
501fn drop_socket_waker(orig_ptr: *const ()) {
502 let _orig_arc = unsafe { Arc::from_raw(orig_ptr as *mut mpsc::Sender<()>) };
503 }
505fn write_avail_to_waker(sender: &Arc<mpsc::Sender<()>>) -> task::RawWaker {
506 let new_ptr = Arc::into_raw(Arc::clone(&sender));
507 task::RawWaker::new(new_ptr as *const (), &SOCK_WAKER_VTABLE)
508}
509
/// The `SocketDescriptor` used by this crate: a handle to one tokio-managed TCP connection that
/// is handed to the `PeerManager`.
///
/// Descriptors are created by `setup_inbound`/`setup_outbound`. Clones refer to the same
/// underlying connection, and equality and hashing compare only a per-connection id.
pub struct SocketDescriptor {
	// Shared connection state (writer, wake-up channels, pause/disconnect flags).
	conn: Arc<Mutex<Connection>>,
	// A clone of the connection's `write_avail` sender, kept behind an `Arc`. Each waker built in
	// `send_data` stores a raw pointer to this `Arc`, so creating a waker only bumps a refcount
	// rather than cloning the sender.
	write_avail_sender: Arc<mpsc::Sender<()>>,
	// Copy of the connection's `id`, used by `PartialEq` and `Hash`.
	id: u64,
}
521impl SocketDescriptor {
522 fn new(conn: Arc<Mutex<Connection>>) -> Self {
523 let (id, write_avail_sender) = {
524 let us = conn.lock().unwrap();
525 (us.id, Arc::new(us.write_avail.clone()))
526 };
527 Self { conn, id, write_avail_sender }
528 }
529}
impl peer_handler::SocketDescriptor for SocketDescriptor {
	/// Writes as much of `data` as the socket accepts right now and returns the number of bytes
	/// written.
	///
	/// It also records the read-pause request: `read_paused` becomes `!continue_read`, and when
	/// reads are resumed after a pause the read loop is woken. It returns 0 without writing if
	/// the connection's writer has already been taken (shutdown) or if `data` is empty.
	///
	/// If the socket is not writable, `poll_write_ready` registers a waker tied to this
	/// connection's `write_avail` channel. When that waker fires, the read loop calls
	/// `write_buffer_space_avail`. On a write error, the count written so far is returned.
	fn send_data(&mut self, data: &[u8], continue_read: bool) -> usize {
		let mut us = self.conn.lock().unwrap();
		// The writer is `take()`n when the connection is shutting down; stop writing from then on.
		if us.writer.is_none() {
			return 0;
		}

		let read_was_paused = us.read_paused;
		us.read_paused = !continue_read;

		// Resuming reads after a pause: wake the read loop so it polls the socket again.
		if continue_read && read_was_paused {
			let _ = us.read_waker.try_send(());
		}

		if data.is_empty() {
			return 0;
		}
		// SAFETY: `write_avail_to_waker` returns a `RawWaker` that owns a fresh strong reference to
		// `write_avail_sender`, and its vtable (`SOCK_WAKER_VTABLE`) manages exactly that reference.
		let waker =
			unsafe { task::Waker::from_raw(write_avail_to_waker(&self.write_avail_sender)) };
		let mut ctx = task::Context::from_waker(&waker);
		let mut written_len = 0;
		loop {
			match us.writer.as_ref().unwrap().poll_write_ready(&mut ctx) {
				task::Poll::Ready(Ok(())) => {
					match us.writer.as_ref().unwrap().try_write(&data[written_len..]) {
						Ok(res) => {
							debug_assert_ne!(res, 0);
							written_len += res;
							if written_len == data.len() {
								return written_len;
							}
						},
						// The socket was not actually writable. Go back to `poll_write_ready`,
						// which registers the waker if it now returns `Pending`.
						Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
							continue;
						},
						Err(_) => return written_len,
					}
				},
				task::Poll::Ready(Err(_)) => return written_len,
				// Not writable: `ctx`'s waker is registered and will signal `write_avail` later.
				task::Poll::Pending => return written_len,
			}
		}
	}

	/// Marks the connection for disconnection and signals `write_avail`. This wakes the read loop
	/// so it sees `rl_requested_disconnect` and exits.
	fn disconnect_socket(&mut self) {
		let mut us = self.conn.lock().unwrap();
		us.rl_requested_disconnect = true;
		let _ = us.write_avail.try_send(());
	}
}
589impl Clone for SocketDescriptor {
590 fn clone(&self) -> Self {
591 Self {
592 conn: Arc::clone(&self.conn),
593 id: self.id,
594 write_avail_sender: Arc::clone(&self.write_avail_sender),
595 }
596 }
597}
598impl Eq for SocketDescriptor {}
599impl PartialEq for SocketDescriptor {
600 fn eq(&self, o: &Self) -> bool {
601 self.id == o.id
602 }
603}
604impl Hash for SocketDescriptor {
605 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
606 self.id.hash(state);
607 }
608}
609
#[cfg(test)]
mod tests {
	use bitcoin::constants::ChainHash;
	use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
	use bitcoin::Network;
	use lightning::ln::msgs::*;
	use lightning::ln::peer_handler::{IgnoringMessageHandler, MessageHandler, PeerManager};
	use lightning::ln::types::ChannelId;
	use lightning::routing::gossip::NodeId;
	use lightning::types::features::*;
	use lightning::util::test_utils::TestNodeSigner;

	use tokio::sync::mpsc;

	use std::mem;
	use std::sync::atomic::{AtomicBool, Ordering};
	use std::sync::{Arc, Mutex};
	use std::time::Duration;

	/// A logger that prints every record to stdout.
	pub struct TestLogger();
	impl lightning::util::logger::Logger for TestLogger {
		fn log(&self, record: lightning::util::logger::Record) {
			println!(
				"{:<5} [{} : {}, {}] {}",
				record.level.to_string(),
				record.module_path,
				record.file,
				record.line,
				record.args
			);
		}
	}

	/// A message handler that reports when one specific peer connects or disconnects. Tests can
	/// also queue `MessageSendEvent`s on it for the `PeerManager` to pick up.
	struct MsgHandler {
		// The peer whose connect/disconnect events are reported.
		expected_pubkey: PublicKey,
		// Signalled when `expected_pubkey` connects.
		pubkey_connected: mpsc::Sender<()>,
		// Signalled when `expected_pubkey` disconnects.
		pubkey_disconnected: mpsc::Sender<()>,
		// Set once `expected_pubkey` has disconnected.
		disconnected_flag: AtomicBool,
		// Returned (and cleared) by `get_and_clear_pending_msg_events`.
		msg_events: Mutex<Vec<MessageSendEvent>>,
	}
	impl RoutingMessageHandler for MsgHandler {
		fn handle_node_announcement(
			&self, _their_node_id: Option<PublicKey>, _msg: &NodeAnnouncement,
		) -> Result<bool, LightningError> {
			Ok(false)
		}
		fn handle_channel_announcement(
			&self, _their_node_id: Option<PublicKey>, _msg: &ChannelAnnouncement,
		) -> Result<bool, LightningError> {
			Ok(false)
		}
		fn handle_channel_update(
			&self, _their_node_id: Option<PublicKey>, _msg: &ChannelUpdate,
		) -> Result<bool, LightningError> {
			Ok(false)
		}
		fn get_next_channel_announcement(
			&self, _starting_point: u64,
		) -> Option<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> {
			None
		}
		fn get_next_node_announcement(
			&self, _starting_point: Option<&NodeId>,
		) -> Option<NodeAnnouncement> {
			None
		}
		fn handle_reply_channel_range(
			&self, _their_node_id: PublicKey, _msg: ReplyChannelRange,
		) -> Result<(), LightningError> {
			Ok(())
		}
		fn handle_reply_short_channel_ids_end(
			&self, _their_node_id: PublicKey, _msg: ReplyShortChannelIdsEnd,
		) -> Result<(), LightningError> {
			Ok(())
		}
		fn handle_query_channel_range(
			&self, _their_node_id: PublicKey, _msg: QueryChannelRange,
		) -> Result<(), LightningError> {
			Ok(())
		}
		fn handle_query_short_channel_ids(
			&self, _their_node_id: PublicKey, _msg: QueryShortChannelIds,
		) -> Result<(), LightningError> {
			Ok(())
		}
		fn processing_queue_high(&self) -> bool {
			false
		}
	}
	impl ChannelMessageHandler for MsgHandler {
		fn handle_open_channel(&self, _their_node_id: PublicKey, _msg: &OpenChannel) {}
		fn handle_accept_channel(&self, _their_node_id: PublicKey, _msg: &AcceptChannel) {}
		fn handle_funding_created(&self, _their_node_id: PublicKey, _msg: &FundingCreated) {}
		fn handle_funding_signed(&self, _their_node_id: PublicKey, _msg: &FundingSigned) {}
		fn handle_channel_ready(&self, _their_node_id: PublicKey, _msg: &ChannelReady) {}
		fn handle_shutdown(&self, _their_node_id: PublicKey, _msg: &Shutdown) {}
		fn handle_closing_signed(&self, _their_node_id: PublicKey, _msg: &ClosingSigned) {}
		#[cfg(simple_close)]
		fn handle_closing_complete(&self, _their_node_id: PublicKey, _msg: ClosingComplete) {}
		#[cfg(simple_close)]
		fn handle_closing_sig(&self, _their_node_id: PublicKey, _msg: ClosingSig) {}
		fn handle_update_add_htlc(&self, _their_node_id: PublicKey, _msg: &UpdateAddHTLC) {}
		fn handle_update_fulfill_htlc(&self, _their_node_id: PublicKey, _msg: UpdateFulfillHTLC) {}
		fn handle_update_fail_htlc(&self, _their_node_id: PublicKey, _msg: &UpdateFailHTLC) {}
		fn handle_update_fail_malformed_htlc(
			&self, _their_node_id: PublicKey, _msg: &UpdateFailMalformedHTLC,
		) {
		}
		fn handle_commitment_signed(&self, _their_node_id: PublicKey, _msg: &CommitmentSigned) {}
		fn handle_commitment_signed_batch(
			&self, _their_node_id: PublicKey, _channel_id: ChannelId, _batch: Vec<CommitmentSigned>,
		) {
		}
		fn handle_revoke_and_ack(&self, _their_node_id: PublicKey, _msg: &RevokeAndACK) {}
		fn handle_update_fee(&self, _their_node_id: PublicKey, _msg: &UpdateFee) {}
		fn handle_announcement_signatures(
			&self, _their_node_id: PublicKey, _msg: &AnnouncementSignatures,
		) {
		}
		fn handle_channel_update(&self, _their_node_id: PublicKey, _msg: &ChannelUpdate) {}
		fn handle_open_channel_v2(&self, _their_node_id: PublicKey, _msg: &OpenChannelV2) {}
		fn handle_accept_channel_v2(&self, _their_node_id: PublicKey, _msg: &AcceptChannelV2) {}
		fn handle_stfu(&self, _their_node_id: PublicKey, _msg: &Stfu) {}
		fn handle_splice_init(&self, _their_node_id: PublicKey, _msg: &SpliceInit) {}
		fn handle_splice_ack(&self, _their_node_id: PublicKey, _msg: &SpliceAck) {}
		fn handle_splice_locked(&self, _their_node_id: PublicKey, _msg: &SpliceLocked) {}
		fn handle_tx_add_input(&self, _their_node_id: PublicKey, _msg: &TxAddInput) {}
		fn handle_tx_add_output(&self, _their_node_id: PublicKey, _msg: &TxAddOutput) {}
		fn handle_tx_remove_input(&self, _their_node_id: PublicKey, _msg: &TxRemoveInput) {}
		fn handle_tx_remove_output(&self, _their_node_id: PublicKey, _msg: &TxRemoveOutput) {}
		fn handle_tx_complete(&self, _their_node_id: PublicKey, _msg: &TxComplete) {}
		fn handle_tx_signatures(&self, _their_node_id: PublicKey, _msg: &TxSignatures) {}
		fn handle_tx_init_rbf(&self, _their_node_id: PublicKey, _msg: &TxInitRbf) {}
		fn handle_tx_ack_rbf(&self, _their_node_id: PublicKey, _msg: &TxAckRbf) {}
		fn handle_tx_abort(&self, _their_node_id: PublicKey, _msg: &TxAbort) {}
		fn handle_peer_storage(&self, _their_node_id: PublicKey, _msg: PeerStorage) {}
		fn handle_peer_storage_retrieval(
			&self, _their_node_id: PublicKey, _msg: PeerStorageRetrieval,
		) {
		}
		fn handle_channel_reestablish(&self, _their_node_id: PublicKey, _msg: &ChannelReestablish) {
		}
		fn handle_error(&self, _their_node_id: PublicKey, _msg: &ErrorMessage) {}
		fn get_chain_hashes(&self) -> Option<Vec<ChainHash>> {
			Some(vec![ChainHash::using_genesis_block(Network::Testnet)])
		}
		fn message_received(&self) {}
	}
	impl BaseMessageHandler for MsgHandler {
		fn peer_disconnected(&self, their_node_id: PublicKey) {
			if their_node_id == self.expected_pubkey {
				self.disconnected_flag.store(true, Ordering::SeqCst);
				// `try_send` only needs `&self`; a full channel just means we already signalled.
				let _ = self.pubkey_disconnected.try_send(());
			}
		}
		fn peer_connected(
			&self, their_node_id: PublicKey, _init_msg: &Init, _inbound: bool,
		) -> Result<(), ()> {
			if their_node_id == self.expected_pubkey {
				let _ = self.pubkey_connected.try_send(());
			}
			Ok(())
		}
		fn provided_node_features(&self) -> NodeFeatures {
			NodeFeatures::empty()
		}
		fn provided_init_features(&self, _their_node_id: PublicKey) -> InitFeatures {
			InitFeatures::empty()
		}
		fn get_and_clear_pending_msg_events(&self) -> Vec<MessageSendEvent> {
			let mut ret = Vec::new();
			mem::swap(&mut *self.msg_events.lock().unwrap(), &mut ret);
			ret
		}
	}

	/// Creates a connected pair of loopback TCP sockets, returned as `(client, server)`.
	///
	/// Tries a fixed list of ports in order and uses the first one that can be bound. Panics if
	/// none can be bound, or if connecting to or accepting on the bound port fails. This requires
	/// a loopback interface configured with 127.0.0.1.
	fn make_tcp_connection() -> (std::net::TcpStream, std::net::TcpStream) {
		for port in [9735u16, 19735, 9997, 9998, 9999, 46926] {
			if let Ok(listener) = std::net::TcpListener::bind(("127.0.0.1", port)) {
				let client = std::net::TcpStream::connect(("127.0.0.1", port)).unwrap();
				return (client, listener.accept().unwrap().0);
			}
		}
		panic!("Failed to bind to v4 localhost on common ports");
	}

	/// Connects two `PeerManager`s (one outbound, one inbound) over a real TCP socket. It then has
	/// `a` disconnect `b` and checks that both sides observe the connect and the disconnect.
	async fn do_basic_connection_test() {
		let secp_ctx = Secp256k1::new();
		// Distinct keys, so each handler's `expected_pubkey` identifies the *other* node.
		let a_key = SecretKey::from_slice(&[1; 32]).unwrap();
		let b_key = SecretKey::from_slice(&[2; 32]).unwrap();
		let a_pub = PublicKey::from_secret_key(&secp_ctx, &a_key);
		let b_pub = PublicKey::from_secret_key(&secp_ctx, &b_key);

		let (a_connected_sender, mut a_connected) = mpsc::channel(1);
		let (a_disconnected_sender, mut a_disconnected) = mpsc::channel(1);
		let a_handler = Arc::new(MsgHandler {
			expected_pubkey: b_pub,
			pubkey_connected: a_connected_sender,
			pubkey_disconnected: a_disconnected_sender,
			disconnected_flag: AtomicBool::new(false),
			msg_events: Mutex::new(Vec::new()),
		});
		let a_msg_handler = MessageHandler {
			chan_handler: Arc::clone(&a_handler),
			route_handler: Arc::clone(&a_handler),
			onion_message_handler: Arc::new(IgnoringMessageHandler {}),
			custom_message_handler: Arc::new(IgnoringMessageHandler {}),
			send_only_message_handler: Arc::new(IgnoringMessageHandler {}),
		};
		let a_manager = Arc::new(PeerManager::new(
			a_msg_handler,
			0,
			&[1; 32],
			Arc::new(TestLogger()),
			Arc::new(TestNodeSigner::new(a_key)),
		));

		let (b_connected_sender, mut b_connected) = mpsc::channel(1);
		let (b_disconnected_sender, mut b_disconnected) = mpsc::channel(1);
		let b_handler = Arc::new(MsgHandler {
			expected_pubkey: a_pub,
			pubkey_connected: b_connected_sender,
			pubkey_disconnected: b_disconnected_sender,
			disconnected_flag: AtomicBool::new(false),
			msg_events: Mutex::new(Vec::new()),
		});
		let b_msg_handler = MessageHandler {
			chan_handler: Arc::clone(&b_handler),
			route_handler: Arc::clone(&b_handler),
			onion_message_handler: Arc::new(IgnoringMessageHandler {}),
			custom_message_handler: Arc::new(IgnoringMessageHandler {}),
			send_only_message_handler: Arc::new(IgnoringMessageHandler {}),
		};
		let b_manager = Arc::new(PeerManager::new(
			b_msg_handler,
			0,
			&[2; 32],
			Arc::new(TestLogger()),
			Arc::new(TestNodeSigner::new(b_key)),
		));

		let (conn_a, conn_b) = make_tcp_connection();

		let fut_a = super::setup_outbound(Arc::clone(&a_manager), b_pub, conn_a);
		let fut_b = super::setup_inbound(b_manager, conn_b);

		tokio::time::timeout(Duration::from_secs(10), a_connected.recv()).await.unwrap();
		tokio::time::timeout(Duration::from_secs(1), b_connected.recv()).await.unwrap();

		// Ask `a`'s PeerManager to disconnect `b` on its next `process_events`.
		a_handler.msg_events.lock().unwrap().push(MessageSendEvent::HandleError {
			node_id: b_pub,
			action: ErrorAction::DisconnectPeer { msg: None },
		});
		assert!(!a_handler.disconnected_flag.load(Ordering::SeqCst));
		assert!(!b_handler.disconnected_flag.load(Ordering::SeqCst));

		a_manager.process_events();
		tokio::time::timeout(Duration::from_secs(10), a_disconnected.recv()).await.unwrap();
		tokio::time::timeout(Duration::from_secs(1), b_disconnected.recv()).await.unwrap();
		assert!(a_handler.disconnected_flag.load(Ordering::SeqCst));
		assert!(b_handler.disconnected_flag.load(Ordering::SeqCst));

		fut_a.await;
		fut_b.await;
	}

	#[tokio::test(flavor = "multi_thread")]
	async fn basic_threaded_connection_test() {
		do_basic_connection_test().await;
	}

	#[tokio::test]
	async fn basic_unthreaded_connection_test() {
		do_basic_connection_test().await;
	}

	/// Sets up inbound and outbound handling on sockets whose other ends are dropped immediately.
	/// This exercises connection setup racing with the remote side disconnecting.
	async fn race_disconnect_accept() {
		let secp_ctx = Secp256k1::new();
		let a_key = SecretKey::from_slice(&[1; 32]).unwrap();
		let b_key = SecretKey::from_slice(&[2; 32]).unwrap();
		let b_pub = PublicKey::from_secret_key(&secp_ctx, &b_key);

		let a_msg_handler = MessageHandler {
			chan_handler: Arc::new(lightning::ln::peer_handler::ErroringMessageHandler::new()),
			onion_message_handler: Arc::new(IgnoringMessageHandler {}),
			route_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler {}),
			custom_message_handler: Arc::new(IgnoringMessageHandler {}),
			send_only_message_handler: Arc::new(IgnoringMessageHandler {}),
		};
		let a_manager = Arc::new(PeerManager::new(
			a_msg_handler,
			0,
			&[1; 32],
			Arc::new(TestLogger()),
			Arc::new(TestNodeSigner::new(a_key)),
		));

		// In both pairs, the end we do not keep is dropped right away, closing the connection.
		let conn_a = {
			let (conn_a, _) = make_tcp_connection();
			conn_a
		};
		let conn_b = {
			let (_, conn_b) = make_tcp_connection();
			conn_b
		};

		// Run connection setup inside new tokio tasks.
		let manager_reference = Arc::clone(&a_manager);
		tokio::spawn(async move { super::setup_inbound(manager_reference, conn_a).await });
		tokio::spawn(async move { super::setup_outbound(a_manager, b_pub, conn_b).await });
	}

	#[tokio::test(flavor = "multi_thread")]
	async fn threaded_race_disconnect_accept() {
		race_disconnect_accept().await;
	}

	#[tokio::test]
	async fn unthreaded_race_disconnect_accept() {
		race_disconnect_accept().await;
	}
}