lightning_net_tokio/lib.rs

// This file is Copyright its original authors, visible in version control
// history.
//
// This file is licensed under the Apache License, Version 2.0 <LICENSE-APACHE
// or http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your option.
// You may not use this file except in accordance with one or both of these
// licenses.

//! A socket handling library for those running in Tokio environments who wish to use
//! rust-lightning with native [`TcpStream`]s.
//!
//! Designed to be as simple as possible, the high-level usage is almost as simple as "hand over a
//! [`TcpStream`] and a reference to a [`PeerManager`] and the rest is handled".
//!
//! The [`PeerManager`], due to the fire-and-forget nature of this logic, must be a reference
//! (e.g. an [`Arc`]) and must use the [`SocketDescriptor`] provided here as the [`PeerManager`]'s
//! `SocketDescriptor` implementation.
//!
//! Three methods are exposed to register a new connection for handling in [`tokio::spawn`] calls;
//! see their individual docs for details.
//!
//! [`PeerManager`]: lightning::ln::peer_handler::PeerManager
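//!
//! For example, a minimal sketch of connecting out to a peer (where `peer_manager`, `pubkey`, and
//! `addr` are placeholders for values your application provides) might look like:
//!
//! ```ignore
//! let peer_manager_ref = Arc::clone(&peer_manager);
//! tokio::spawn(async move {
//! 	// connect_outbound returns None if the TCP connection fails or times out.
//! 	if let Some(disconnected_fut) =
//! 		lightning_net_tokio::connect_outbound(peer_manager_ref, pubkey, addr).await
//! 	{
//! 		// Awaiting this future is optional - it merely completes once the peer disconnects.
//! 		disconnected_fut.await;
//! 	}
//! });
//! ```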

#![deny(rustdoc::broken_intra_doc_links)]
#![deny(rustdoc::private_intra_doc_links)]
#![deny(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]

use bitcoin::secp256k1::PublicKey;

use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::time;

use lightning::ln::msgs::SocketAddress;
use lightning::ln::peer_handler;
use lightning::ln::peer_handler::APeerManager;
use lightning::ln::peer_handler::SocketDescriptor as LnSocketTrait;

use std::future::Future;
use std::hash::Hash;
use std::net::SocketAddr;
use std::net::TcpStream as StdTcpStream;
use std::ops::Deref;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{self, Poll};
use std::time::Duration;

static ID_COUNTER: AtomicU64 = AtomicU64::new(0);

// We only need to select over multiple futures in one place, and taking on the full `tokio/macros`
// dependency tree in order to do so (which has broken our MSRV before) is excessive. Instead, we
// define trivial two- and three-way selector futures with the specific types we need and just
// use those.

pub(crate) enum SelectorOutput {
	A(Option<()>),
	B(Option<()>),
	C(tokio::io::Result<()>),
}

pub(crate) struct TwoSelector<
	A: Future<Output = Option<()>> + Unpin,
	B: Future<Output = Option<()>> + Unpin,
> {
	pub a: A,
	pub b: B,
}

impl<A: Future<Output = Option<()>> + Unpin, B: Future<Output = Option<()>> + Unpin> Future
	for TwoSelector<A, B>
{
	type Output = SelectorOutput;
	fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll<SelectorOutput> {
		match Pin::new(&mut self.a).poll(ctx) {
			Poll::Ready(res) => {
				return Poll::Ready(SelectorOutput::A(res));
			},
			Poll::Pending => {},
		}
		match Pin::new(&mut self.b).poll(ctx) {
			Poll::Ready(res) => {
				return Poll::Ready(SelectorOutput::B(res));
			},
			Poll::Pending => {},
		}
		Poll::Pending
	}
}

pub(crate) struct ThreeSelector<
	A: Future<Output = Option<()>> + Unpin,
	B: Future<Output = Option<()>> + Unpin,
	C: Future<Output = tokio::io::Result<()>> + Unpin,
> {
	pub a: A,
	pub b: B,
	pub c: C,
}

impl<
		A: Future<Output = Option<()>> + Unpin,
		B: Future<Output = Option<()>> + Unpin,
		C: Future<Output = tokio::io::Result<()>> + Unpin,
	> Future for ThreeSelector<A, B, C>
{
	type Output = SelectorOutput;
	fn poll(mut self: Pin<&mut Self>, ctx: &mut task::Context<'_>) -> Poll<SelectorOutput> {
		match Pin::new(&mut self.a).poll(ctx) {
			Poll::Ready(res) => {
				return Poll::Ready(SelectorOutput::A(res));
			},
			Poll::Pending => {},
		}
		match Pin::new(&mut self.b).poll(ctx) {
			Poll::Ready(res) => {
				return Poll::Ready(SelectorOutput::B(res));
			},
			Poll::Pending => {},
		}
		match Pin::new(&mut self.c).poll(ctx) {
			Poll::Ready(res) => {
				return Poll::Ready(SelectorOutput::C(res));
			},
			Poll::Pending => {},
		}
		Poll::Pending
	}
}

/// Connection contains all our internal state for a connection - we hold a reference to the
/// Connection object (in an Arc<Mutex<>>) in each SocketDescriptor we create as well as in the
/// read future (which is returned by schedule_read).
struct Connection {
	writer: Option<Arc<TcpStream>>,
	// Because our PeerManager is templated by user-provided types, and we can't (as far as I can
	// tell) have a const RawWakerVTable built out of templated functions, we need some indirection
	// between being woken up with write-ready and calling PeerManager::write_buffer_space_avail.
	// This provides that indirection, with a Sender which gets handed to the PeerManager Arc on
	// the schedule_read stack.
	//
	// An alternative (likely more efficient) approach would involve creating a RawWakerVTable at
	// runtime with functions templated by the Arc<PeerManager> type, calling
	// write_buffer_space_avail directly from tokio's write wake, however doing so would require
	// more unsafe voodoo than I really feel like writing.
	write_avail: mpsc::Sender<()>,
	// When we are told by rust-lightning to pause read (because we have writes backing up), we do
	// so by setting read_paused. At that point, the read task will stop reading bytes from the
	// socket. To wake it up (without otherwise changing its state), we can push a value into this
	// Sender.
	read_waker: mpsc::Sender<()>,
	read_paused: bool,
	rl_requested_disconnect: bool,
	id: u64,
}
impl Connection {
	async fn poll_event_process<PM: Deref + 'static + Send + Sync>(
		peer_manager: PM, mut event_receiver: mpsc::Receiver<()>,
	) where
		PM::Target: APeerManager<Descriptor = SocketDescriptor>,
	{
		loop {
			if event_receiver.recv().await.is_none() {
				return;
			}
			peer_manager.as_ref().process_events();
		}
	}

	async fn schedule_read<PM: Deref + 'static + Send + Sync + Clone>(
		peer_manager: PM, us: Arc<Mutex<Self>>, reader: Arc<TcpStream>,
		mut read_wake_receiver: mpsc::Receiver<()>, mut write_avail_receiver: mpsc::Receiver<()>,
	) where
		PM::Target: APeerManager<Descriptor = SocketDescriptor>,
	{
		// Create a waker to wake up poll_event_process, above
		let (event_waker, event_receiver) = mpsc::channel(1);
		tokio::spawn(Self::poll_event_process(peer_manager.clone(), event_receiver));

		// 4KiB is nice and big without handling too many messages all at once, giving other peers
		// a chance to do some work.
		let mut buf = [0; 4096];

		let mut our_descriptor = SocketDescriptor::new(Arc::clone(&us));
		// An enum describing why we did/are disconnecting:
		enum Disconnect {
			// Rust-Lightning told us to disconnect, either by returning an Err or by calling
			// SocketDescriptor::disconnect_socket.
			// In this case, we do not call peer_manager.socket_disconnected() as Rust-Lightning
			// already knows we're disconnected.
			CloseConnection,
			// The connection was disconnected for some other reason, ie because the socket was
			// closed.
			// In this case, we do need to call peer_manager.socket_disconnected() to inform
			// Rust-Lightning that the socket is gone.
			PeerDisconnected,
		}
		let disconnect_type = loop {
			let read_paused = {
				let us_lock = us.lock().unwrap();
				if us_lock.rl_requested_disconnect {
					break Disconnect::CloseConnection;
				}
				us_lock.read_paused
			};
			// TODO: Drop the Box'ing of the futures once Rust has pin-on-stack support.
			let select_result = if read_paused {
				TwoSelector {
					a: Box::pin(write_avail_receiver.recv()),
					b: Box::pin(read_wake_receiver.recv()),
				}
				.await
			} else {
				ThreeSelector {
					a: Box::pin(write_avail_receiver.recv()),
					b: Box::pin(read_wake_receiver.recv()),
					c: Box::pin(reader.readable()),
				}
				.await
			};
			match select_result {
				SelectorOutput::A(v) => {
					assert!(v.is_some()); // We can't have dropped the sending end, it's in the `us` Arc!
					if peer_manager.as_ref().write_buffer_space_avail(&mut our_descriptor).is_err()
					{
						break Disconnect::CloseConnection;
					}
				},
				SelectorOutput::B(some) => {
					// The mpsc Receiver should only return `None` if the write side has been
					// dropped, but that shouldn't be possible since it's referenced by the Self in
					// `us`.
					debug_assert!(some.is_some());
				},
				SelectorOutput::C(res) => {
					if res.is_err() {
						break Disconnect::PeerDisconnected;
					}
					match reader.try_read(&mut buf) {
						Ok(0) => break Disconnect::PeerDisconnected,
						Ok(len) => {
							let read_res =
								peer_manager.as_ref().read_event(&mut our_descriptor, &buf[0..len]);
							match read_res {
								Ok(()) => {},
								Err(_) => break Disconnect::CloseConnection,
							}
						},
						Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
							// readable() is allowed to spuriously wake, so we have to handle
							// WouldBlock here.
						},
						Err(_) => break Disconnect::PeerDisconnected,
					}
				},
			}
			let _ = event_waker.try_send(());

			// At this point we've processed a message or two, and reset the ping timer for this
			// peer, at least in the "are we still receiving messages" context. If we don't give up
			// our timeslice to another task we may just spin on this peer, starving other peers
			// and eventually disconnecting them for ping timeouts. Instead, we explicitly yield
			// here.
			let _ = tokio::task::yield_now().await;
		};
		us.lock().unwrap().writer.take();
		if let Disconnect::PeerDisconnected = disconnect_type {
			peer_manager.as_ref().socket_disconnected(&our_descriptor);
			peer_manager.as_ref().process_events();
		}
	}

	fn new(
		stream: StdTcpStream,
	) -> (Arc<TcpStream>, mpsc::Receiver<()>, mpsc::Receiver<()>, Arc<Mutex<Self>>) {
		// We only ever need a channel of depth 1 here: if we returned a non-full write to the
		// PeerManager, we will eventually get notified that there is room in the socket to write
		// new bytes, which will generate an event. That event will be popped off the queue before
		// we call write_buffer_space_avail, ensuring that we have room to push a new () if, during
		// the write_buffer_space_avail() call, send_data() returns a non-full write.
		let (write_avail, write_receiver) = mpsc::channel(1);
		// Similarly here - our only goal is to make sure the reader wakes up at some point after
		// we shove a value into the channel which comes after we've reset the read_paused bool to
		// false.
		let (read_waker, read_receiver) = mpsc::channel(1);
		stream.set_nonblocking(true).unwrap();
		let tokio_stream = Arc::new(TcpStream::from_std(stream).unwrap());

		let id = ID_COUNTER.fetch_add(1, Ordering::AcqRel);
		let writer = Some(Arc::clone(&tokio_stream));
		let conn = Arc::new(Mutex::new(Self {
			writer,
			write_avail,
			read_waker,
			read_paused: false,
			rl_requested_disconnect: false,
			id,
		}));
		(tokio_stream, write_receiver, read_receiver, conn)
	}
}

fn get_addr_from_stream(stream: &StdTcpStream) -> Option<SocketAddress> {
	match stream.peer_addr() {
		Ok(SocketAddr::V4(sockaddr)) => {
			Some(SocketAddress::TcpIpV4 { addr: sockaddr.ip().octets(), port: sockaddr.port() })
		},
		Ok(SocketAddr::V6(sockaddr)) => {
			Some(SocketAddress::TcpIpV6 { addr: sockaddr.ip().octets(), port: sockaddr.port() })
		},
		Err(_) => None,
	}
}

/// Process incoming messages and feed outgoing messages on the provided socket generated by
/// accepting an incoming connection.
///
/// The returned future will complete when the peer is disconnected and associated handling
/// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do
/// not need to poll the provided future in order to make progress.
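///
/// A minimal accept-loop sketch (where `peer_manager` is an `Arc`-wrapped `PeerManager` and the
/// listening port is a placeholder; both are illustrative values the caller provides):
///
/// ```ignore
/// let listener = tokio::net::TcpListener::bind("0.0.0.0:9735").await?;
/// loop {
/// 	let (tokio_stream, _addr) = listener.accept().await?;
/// 	let peer_manager = Arc::clone(&peer_manager);
/// 	tokio::spawn(async move {
/// 		// setup_inbound takes a std::net::TcpStream, so convert back from the tokio type.
/// 		let stream = tokio_stream.into_std().unwrap();
/// 		lightning_net_tokio::setup_inbound(peer_manager, stream).await;
/// 	});
/// }
/// ```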
pub fn setup_inbound<PM: Deref + 'static + Send + Sync + Clone>(
	peer_manager: PM, stream: StdTcpStream,
) -> impl std::future::Future<Output = ()>
where
	PM::Target: APeerManager<Descriptor = SocketDescriptor>,
{
	let remote_addr = get_addr_from_stream(&stream);
	let (reader, write_receiver, read_receiver, us) = Connection::new(stream);
	#[cfg(test)]
	let last_us = Arc::clone(&us);

	let handle_opt = if peer_manager
		.as_ref()
		.new_inbound_connection(SocketDescriptor::new(Arc::clone(&us)), remote_addr)
		.is_ok()
	{
		let handle = tokio::spawn(Connection::schedule_read(
			peer_manager,
			us,
			reader,
			read_receiver,
			write_receiver,
		));
		Some(handle)
	} else {
		// Note that we will skip socket_disconnected here, in accordance with the PeerManager
		// requirements.
		None
	};

	async move {
		if let Some(handle) = handle_opt {
			if let Err(e) = handle.await {
				assert!(e.is_cancelled());
			} else {
				// This is certainly not guaranteed to always be true - the read loop may exit
				// while there are still pending write wakers that need to be woken up after the
				// socket shutdown(). Still, as a check during testing, to make sure tokio doesn't
				// keep too many wakers around, this makes sense. The race should be rare (we do
				// some work after shutdown()) and an error would be a major memory leak.
				#[cfg(test)]
				debug_assert!(Arc::try_unwrap(last_us).is_ok());
			}
		}
	}
}

/// Process incoming messages and feed outgoing messages on the provided socket generated by
/// making an outbound connection which is expected to be accepted by a peer with the given
/// public key. The relevant processing is set to run free (via tokio::spawn).
///
/// The returned future will complete when the peer is disconnected and associated handling
/// futures are freed, though, because all processing futures are spawned with tokio::spawn, you do
/// not need to poll the provided future in order to make progress.
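///
/// A minimal sketch (where `peer_manager`, `their_pubkey`, and an already-connected
/// `std::net::TcpStream` named `stream` are placeholders the caller provides):
///
/// ```ignore
/// let connection_closed_fut =
/// 	lightning_net_tokio::setup_outbound(Arc::clone(&peer_manager), their_pubkey, stream);
/// tokio::spawn(async move {
/// 	// Awaiting this is optional; it completes once the peer disconnects and the handling
/// 	// futures have been freed.
/// 	connection_closed_fut.await;
/// });
/// ```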
pub fn setup_outbound<PM: Deref + 'static + Send + Sync + Clone>(
	peer_manager: PM, their_node_id: PublicKey, stream: StdTcpStream,
) -> impl std::future::Future<Output = ()>
where
	PM::Target: APeerManager<Descriptor = SocketDescriptor>,
{
	let remote_addr = get_addr_from_stream(&stream);
	let (reader, mut write_receiver, read_receiver, us) = Connection::new(stream);
	#[cfg(test)]
	let last_us = Arc::clone(&us);
	let handle_opt = if let Ok(initial_send) = peer_manager.as_ref().new_outbound_connection(
		their_node_id,
		SocketDescriptor::new(Arc::clone(&us)),
		remote_addr,
	) {
		let handle = tokio::spawn(async move {
			// We should essentially always have enough room in a TCP socket buffer to send the
			// initial 10s of bytes. However, tokio running in single-threaded mode will always
			// fail writes and wake us back up later to write. Thus, we handle a single
			// std::task::Poll::Pending but still expect to write the full set of bytes at once
			// and use a relatively tight timeout.
			let send_fut = async {
				loop {
					match SocketDescriptor::new(Arc::clone(&us)).send_data(&initial_send, true) {
						v if v == initial_send.len() => break Ok(()),
						0 => {
							write_receiver.recv().await;
							// In theory we could check whether we've been instructed to disconnect
							// the peer here, but it's OK to just skip it - we'll check for it in
							// schedule_read prior to any relevant calls into RL.
						},
						_ => {
							eprintln!("Failed to write first full message to socket!");
							peer_manager
								.as_ref()
								.socket_disconnected(&SocketDescriptor::new(Arc::clone(&us)));
							break Err(());
						},
					}
				}
			};
			let timeout_send_fut = tokio::time::timeout(Duration::from_millis(100), send_fut);
			if let Ok(Ok(())) = timeout_send_fut.await {
				Connection::schedule_read(peer_manager, us, reader, read_receiver, write_receiver)
					.await;
			}
		});
		Some(handle)
	} else {
		// Note that we will skip socket_disconnected here, in accordance with the PeerManager
		// requirements.
		None
	};

	async move {
		if let Some(handle) = handle_opt {
			if let Err(e) = handle.await {
				assert!(e.is_cancelled());
			} else {
				// This is certainly not guaranteed to always be true - the read loop may exit
				// while there are still pending write wakers that need to be woken up after the
				// socket shutdown(). Still, as a check during testing, to make sure tokio doesn't
				// keep too many wakers around, this makes sense. The race should be rare (we do
				// some work after shutdown()) and an error would be a major memory leak.
				#[cfg(test)]
				debug_assert!(Arc::try_unwrap(last_us).is_ok());
			}
		}
	}
}

/// Process incoming messages and feed outgoing messages on a new connection made to the given
/// socket address which is expected to be accepted by a peer with the given public key (by
/// scheduling futures with tokio::spawn).
///
/// Shorthand for TcpStream::connect(addr) with a timeout followed by setup_outbound().
///
/// Returns a future (as the fn is async) which needs to be polled to complete the connection and
/// connection setup. That future then returns a future which will complete when the peer is
/// disconnected and associated handling futures are freed, though, because all processing in said
/// futures is spawned with tokio::spawn, you do not need to poll the second future in order to
/// make progress.
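///
/// A minimal sketch (with `peer_manager`, `their_pubkey`, and `addr` as placeholders provided by
/// the caller):
///
/// ```ignore
/// // The outer future must be polled to drive the TCP connect and connection setup...
/// if let Some(connection_closed_fut) =
/// 	lightning_net_tokio::connect_outbound(Arc::clone(&peer_manager), their_pubkey, addr).await
/// {
/// 	// ...while this inner future only signals disconnection and need not be polled.
/// 	connection_closed_fut.await;
/// }
/// ```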
pub async fn connect_outbound<PM: Deref + 'static + Send + Sync + Clone>(
	peer_manager: PM, their_node_id: PublicKey, addr: SocketAddr,
) -> Option<impl std::future::Future<Output = ()>>
where
	PM::Target: APeerManager<Descriptor = SocketDescriptor>,
{
	let connect_fut = async { TcpStream::connect(&addr).await.map(|s| s.into_std().unwrap()) };
	if let Ok(Ok(stream)) = time::timeout(Duration::from_secs(10), connect_fut).await {
		Some(setup_outbound(peer_manager, their_node_id, stream))
	} else {
		None
	}
}

const SOCK_WAKER_VTABLE: task::RawWakerVTable = task::RawWakerVTable::new(
	clone_socket_waker,
	wake_socket_waker,
	wake_socket_waker_by_ref,
	drop_socket_waker,
);

fn clone_socket_waker(orig_ptr: *const ()) -> task::RawWaker {
	let new_waker = unsafe { Arc::from_raw(orig_ptr as *const mpsc::Sender<()>) };
	let res = write_avail_to_waker(&new_waker);
	// Don't decrement the refcount when dropping new_waker by turning it back `into_raw`.
	let _ = Arc::into_raw(new_waker);
	res
}
// When waking, an error should be fine. Most likely we got two send_datas in a row, both of which
// failed to fully write, but we only need to call write_buffer_space_avail() once. Otherwise, the
// sending thread may have already gone away due to a socket close, in which case there's nothing
// to wake up anyway.
fn wake_socket_waker(orig_ptr: *const ()) {
	let sender = unsafe { &mut *(orig_ptr as *mut mpsc::Sender<()>) };
	let _ = sender.try_send(());
	drop_socket_waker(orig_ptr);
}
fn wake_socket_waker_by_ref(orig_ptr: *const ()) {
	let sender_ptr = orig_ptr as *const mpsc::Sender<()>;
	let sender = unsafe { &*sender_ptr };
	let _ = sender.try_send(());
}
fn drop_socket_waker(orig_ptr: *const ()) {
	let _orig_arc = unsafe { Arc::from_raw(orig_ptr as *mut mpsc::Sender<()>) };
	// _orig_arc is now dropped
}
fn write_avail_to_waker(sender: &Arc<mpsc::Sender<()>>) -> task::RawWaker {
	let new_ptr = Arc::into_raw(Arc::clone(&sender));
	task::RawWaker::new(new_ptr as *const (), &SOCK_WAKER_VTABLE)
}

/// The SocketDescriptor used to refer to sockets by a PeerHandler. This is pub only as it is a
/// type in the template of PeerHandler.
pub struct SocketDescriptor {
	conn: Arc<Mutex<Connection>>,
	// We store a copy of the mpsc::Sender to wake the read task in an Arc here. While we could
	// simply clone the sender and store a copy in each waker, that would require allocating for
	// each waker. Instead, we simply `Arc::clone` it, creating a new reference, and store the
	// pointer in the waker.
	write_avail_sender: Arc<mpsc::Sender<()>>,
	id: u64,
}
impl SocketDescriptor {
	fn new(conn: Arc<Mutex<Connection>>) -> Self {
		let (id, write_avail_sender) = {
			let us = conn.lock().unwrap();
			(us.id, Arc::new(us.write_avail.clone()))
		};
		Self { conn, id, write_avail_sender }
	}
}
impl peer_handler::SocketDescriptor for SocketDescriptor {
	fn send_data(&mut self, data: &[u8], continue_read: bool) -> usize {
		// To send data, we take a lock on our Connection to access the TcpStream, writing to it if
		// there's room in the kernel buffer, or otherwise creating a new Waker with a
		// SocketDescriptor in it which can send on the write_avail Sender, waking up the
		// processing future which will call write_buffer_space_avail and we'll end up back here.
		let mut us = self.conn.lock().unwrap();
		if us.writer.is_none() {
			// The writer gets take()n when it is time to shut down, so just fast-return 0 here.
			return 0;
		}

		let read_was_paused = us.read_paused;
		us.read_paused = !continue_read;

		if continue_read && read_was_paused {
			// The schedule_read future may go to lock up but end up getting woken up by there
			// being more room in the write buffer, dropping the other end of this Sender
			// before we get here, so we ignore any failures to wake it up.
			let _ = us.read_waker.try_send(());
		}

		if data.is_empty() {
			return 0;
		}
		let waker =
			unsafe { task::Waker::from_raw(write_avail_to_waker(&self.write_avail_sender)) };
		let mut ctx = task::Context::from_waker(&waker);
		let mut written_len = 0;
		loop {
			match us.writer.as_ref().unwrap().poll_write_ready(&mut ctx) {
				task::Poll::Ready(Ok(())) => {
					match us.writer.as_ref().unwrap().try_write(&data[written_len..]) {
						Ok(res) => {
							debug_assert_ne!(res, 0);
							written_len += res;
							if written_len == data.len() {
								return written_len;
							}
						},
						Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
							continue;
						},
						Err(_) => return written_len,
					}
				},
				task::Poll::Ready(Err(_)) => return written_len,
				task::Poll::Pending => return written_len,
			}
		}
	}

	fn disconnect_socket(&mut self) {
		let mut us = self.conn.lock().unwrap();
		us.rl_requested_disconnect = true;
		// Wake up the sending thread, assuming it is still alive
		let _ = us.write_avail.try_send(());
	}
}
impl Clone for SocketDescriptor {
	fn clone(&self) -> Self {
		Self {
			conn: Arc::clone(&self.conn),
			id: self.id,
			write_avail_sender: Arc::clone(&self.write_avail_sender),
		}
	}
}
impl Eq for SocketDescriptor {}
impl PartialEq for SocketDescriptor {
	fn eq(&self, o: &Self) -> bool {
		self.id == o.id
	}
}
impl Hash for SocketDescriptor {
	fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
		self.id.hash(state);
	}
}

#[cfg(test)]
mod tests {
	use bitcoin::constants::ChainHash;
	use bitcoin::secp256k1::{PublicKey, Secp256k1, SecretKey};
	use bitcoin::Network;
	use lightning::ln::msgs::*;
	use lightning::ln::peer_handler::{IgnoringMessageHandler, MessageHandler, PeerManager};
	use lightning::ln::types::ChannelId;
	use lightning::routing::gossip::NodeId;
	use lightning::types::features::*;
	use lightning::util::test_utils::TestNodeSigner;

	use tokio::sync::mpsc;

	use std::mem;
	use std::sync::atomic::{AtomicBool, Ordering};
	use std::sync::{Arc, Mutex};
	use std::time::Duration;

	pub struct TestLogger();
	impl lightning::util::logger::Logger for TestLogger {
		fn log(&self, record: lightning::util::logger::Record) {
			println!(
				"{:<5} [{} : {}, {}] {}",
				record.level.to_string(),
				record.module_path,
				record.file,
				record.line,
				record.args
			);
		}
	}

	struct MsgHandler {
		expected_pubkey: PublicKey,
		pubkey_connected: mpsc::Sender<()>,
		pubkey_disconnected: mpsc::Sender<()>,
		disconnected_flag: AtomicBool,
		msg_events: Mutex<Vec<MessageSendEvent>>,
	}
	impl RoutingMessageHandler for MsgHandler {
		fn handle_node_announcement(
			&self, _their_node_id: Option<PublicKey>, _msg: &NodeAnnouncement,
		) -> Result<bool, LightningError> {
			Ok(false)
		}
		fn handle_channel_announcement(
			&self, _their_node_id: Option<PublicKey>, _msg: &ChannelAnnouncement,
		) -> Result<bool, LightningError> {
			Ok(false)
		}
		fn handle_channel_update(
			&self, _their_node_id: Option<PublicKey>, _msg: &ChannelUpdate,
		) -> Result<bool, LightningError> {
			Ok(false)
		}
		fn get_next_channel_announcement(
			&self, _starting_point: u64,
		) -> Option<(ChannelAnnouncement, Option<ChannelUpdate>, Option<ChannelUpdate>)> {
			None
		}
		fn get_next_node_announcement(
			&self, _starting_point: Option<&NodeId>,
		) -> Option<NodeAnnouncement> {
			None
		}
		fn handle_reply_channel_range(
			&self, _their_node_id: PublicKey, _msg: ReplyChannelRange,
		) -> Result<(), LightningError> {
			Ok(())
		}
		fn handle_reply_short_channel_ids_end(
			&self, _their_node_id: PublicKey, _msg: ReplyShortChannelIdsEnd,
		) -> Result<(), LightningError> {
			Ok(())
		}
		fn handle_query_channel_range(
			&self, _their_node_id: PublicKey, _msg: QueryChannelRange,
		) -> Result<(), LightningError> {
			Ok(())
		}
		fn handle_query_short_channel_ids(
			&self, _their_node_id: PublicKey, _msg: QueryShortChannelIds,
		) -> Result<(), LightningError> {
			Ok(())
		}
		fn processing_queue_high(&self) -> bool {
			false
		}
	}
	impl ChannelMessageHandler for MsgHandler {
		fn handle_open_channel(&self, _their_node_id: PublicKey, _msg: &OpenChannel) {}
		fn handle_accept_channel(&self, _their_node_id: PublicKey, _msg: &AcceptChannel) {}
		fn handle_funding_created(&self, _their_node_id: PublicKey, _msg: &FundingCreated) {}
		fn handle_funding_signed(&self, _their_node_id: PublicKey, _msg: &FundingSigned) {}
		fn handle_channel_ready(&self, _their_node_id: PublicKey, _msg: &ChannelReady) {}
		fn handle_shutdown(&self, _their_node_id: PublicKey, _msg: &Shutdown) {}
		fn handle_closing_signed(&self, _their_node_id: PublicKey, _msg: &ClosingSigned) {}
		#[cfg(simple_close)]
		fn handle_closing_complete(&self, _their_node_id: PublicKey, _msg: ClosingComplete) {}
		#[cfg(simple_close)]
		fn handle_closing_sig(&self, _their_node_id: PublicKey, _msg: ClosingSig) {}
		fn handle_update_add_htlc(&self, _their_node_id: PublicKey, _msg: &UpdateAddHTLC) {}
		fn handle_update_fulfill_htlc(&self, _their_node_id: PublicKey, _msg: UpdateFulfillHTLC) {}
		fn handle_update_fail_htlc(&self, _their_node_id: PublicKey, _msg: &UpdateFailHTLC) {}
		fn handle_update_fail_malformed_htlc(
			&self, _their_node_id: PublicKey, _msg: &UpdateFailMalformedHTLC,
		) {
		}
		fn handle_commitment_signed(&self, _their_node_id: PublicKey, _msg: &CommitmentSigned) {}
		fn handle_commitment_signed_batch(
			&self, _their_node_id: PublicKey, _channel_id: ChannelId, _batch: Vec<CommitmentSigned>,
		) {
		}
		fn handle_revoke_and_ack(&self, _their_node_id: PublicKey, _msg: &RevokeAndACK) {}
		fn handle_update_fee(&self, _their_node_id: PublicKey, _msg: &UpdateFee) {}
		fn handle_announcement_signatures(
			&self, _their_node_id: PublicKey, _msg: &AnnouncementSignatures,
		) {
		}
		fn handle_channel_update(&self, _their_node_id: PublicKey, _msg: &ChannelUpdate) {}
		fn handle_open_channel_v2(&self, _their_node_id: PublicKey, _msg: &OpenChannelV2) {}
		fn handle_accept_channel_v2(&self, _their_node_id: PublicKey, _msg: &AcceptChannelV2) {}
		fn handle_stfu(&self, _their_node_id: PublicKey, _msg: &Stfu) {}
		fn handle_splice_init(&self, _their_node_id: PublicKey, _msg: &SpliceInit) {}
		fn handle_splice_ack(&self, _their_node_id: PublicKey, _msg: &SpliceAck) {}
		fn handle_splice_locked(&self, _their_node_id: PublicKey, _msg: &SpliceLocked) {}
		fn handle_tx_add_input(&self, _their_node_id: PublicKey, _msg: &TxAddInput) {}
		fn handle_tx_add_output(&self, _their_node_id: PublicKey, _msg: &TxAddOutput) {}
		fn handle_tx_remove_input(&self, _their_node_id: PublicKey, _msg: &TxRemoveInput) {}
		fn handle_tx_remove_output(&self, _their_node_id: PublicKey, _msg: &TxRemoveOutput) {}
		fn handle_tx_complete(&self, _their_node_id: PublicKey, _msg: &TxComplete) {}
		fn handle_tx_signatures(&self, _their_node_id: PublicKey, _msg: &TxSignatures) {}
		fn handle_tx_init_rbf(&self, _their_node_id: PublicKey, _msg: &TxInitRbf) {}
		fn handle_tx_ack_rbf(&self, _their_node_id: PublicKey, _msg: &TxAckRbf) {}
		fn handle_tx_abort(&self, _their_node_id: PublicKey, _msg: &TxAbort) {}
		fn handle_peer_storage(&self, _their_node_id: PublicKey, _msg: PeerStorage) {}
		fn handle_peer_storage_retrieval(
			&self, _their_node_id: PublicKey, _msg: PeerStorageRetrieval,
		) {
		}
		fn handle_channel_reestablish(&self, _their_node_id: PublicKey, _msg: &ChannelReestablish) {
		}
		fn handle_error(&self, _their_node_id: PublicKey, _msg: &ErrorMessage) {}
		fn get_chain_hashes(&self) -> Option<Vec<ChainHash>> {
			Some(vec![ChainHash::using_genesis_block(Network::Testnet)])
		}
		fn message_received(&self) {}
	}
	impl BaseMessageHandler for MsgHandler {
		fn peer_disconnected(&self, their_node_id: PublicKey) {
			if their_node_id == self.expected_pubkey {
				self.disconnected_flag.store(true, Ordering::SeqCst);
				// This method is called twice as this handler is registered as two of the message
				// handlers. `try_send` will fail the second time.
				let _ = self.pubkey_disconnected.clone().try_send(());
			}
		}
		fn peer_connected(
			&self, their_node_id: PublicKey, _init_msg: &Init, _inbound: bool,
		) -> Result<(), ()> {
			if their_node_id == self.expected_pubkey {
				// This method is called twice as this handler is registered as two of the message
				// handlers. `try_send` will fail the second time.
				let _ = self.pubkey_connected.clone().try_send(());
			}
			Ok(())
		}
		fn provided_node_features(&self) -> NodeFeatures {
			NodeFeatures::empty()
		}
		fn provided_init_features(&self, _their_node_id: PublicKey) -> InitFeatures {
			InitFeatures::empty()
		}
		fn get_and_clear_pending_msg_events(&self) -> Vec<MessageSendEvent> {
			let mut ret = Vec::new();
			mem::swap(&mut *self.msg_events.lock().unwrap(), &mut ret);
			ret
		}
	}

	fn make_tcp_connection() -> (std::net::TcpStream, std::net::TcpStream) {
		if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:9735") {
			(std::net::TcpStream::connect("127.0.0.1:9735").unwrap(), listener.accept().unwrap().0)
		} else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:19735") {
			(std::net::TcpStream::connect("127.0.0.1:19735").unwrap(), listener.accept().unwrap().0)
		} else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:9997") {
			(std::net::TcpStream::connect("127.0.0.1:9997").unwrap(), listener.accept().unwrap().0)
		} else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:9998") {
			(std::net::TcpStream::connect("127.0.0.1:9998").unwrap(), listener.accept().unwrap().0)
		} else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:9999") {
			(std::net::TcpStream::connect("127.0.0.1:9999").unwrap(), listener.accept().unwrap().0)
		} else if let Ok(listener) = std::net::TcpListener::bind("127.0.0.1:46926") {
			(std::net::TcpStream::connect("127.0.0.1:46926").unwrap(), listener.accept().unwrap().0)
		} else {
			panic!("Failed to bind to v4 localhost on common ports");
		}
	}

	async fn do_basic_connection_test() {
		let secp_ctx = Secp256k1::new();
		let a_key = SecretKey::from_slice(&[1; 32]).unwrap();
		let b_key = SecretKey::from_slice(&[2; 32]).unwrap();
		let a_pub = PublicKey::from_secret_key(&secp_ctx, &a_key);
		let b_pub = PublicKey::from_secret_key(&secp_ctx, &b_key);

		let (a_connected_sender, mut a_connected) = mpsc::channel(1);
		let (a_disconnected_sender, mut a_disconnected) = mpsc::channel(1);
		let a_handler = Arc::new(MsgHandler {
			expected_pubkey: b_pub,
			pubkey_connected: a_connected_sender,
			pubkey_disconnected: a_disconnected_sender,
			disconnected_flag: AtomicBool::new(false),
			msg_events: Mutex::new(Vec::new()),
		});
		let a_msg_handler = MessageHandler {
			chan_handler: Arc::clone(&a_handler),
			route_handler: Arc::clone(&a_handler),
			onion_message_handler: Arc::new(IgnoringMessageHandler {}),
			custom_message_handler: Arc::new(IgnoringMessageHandler {}),
			send_only_message_handler: Arc::new(IgnoringMessageHandler {}),
		};
		let a_manager = Arc::new(PeerManager::new(
			a_msg_handler,
			0,
			&[1; 32],
			Arc::new(TestLogger()),
			Arc::new(TestNodeSigner::new(a_key)),
		));

		let (b_connected_sender, mut b_connected) = mpsc::channel(1);
		let (b_disconnected_sender, mut b_disconnected) = mpsc::channel(1);
		let b_handler = Arc::new(MsgHandler {
			expected_pubkey: a_pub,
			pubkey_connected: b_connected_sender,
			pubkey_disconnected: b_disconnected_sender,
			disconnected_flag: AtomicBool::new(false),
			msg_events: Mutex::new(Vec::new()),
		});
		let b_msg_handler = MessageHandler {
			chan_handler: Arc::clone(&b_handler),
			route_handler: Arc::clone(&b_handler),
			onion_message_handler: Arc::new(IgnoringMessageHandler {}),
			custom_message_handler: Arc::new(IgnoringMessageHandler {}),
			send_only_message_handler: Arc::new(IgnoringMessageHandler {}),
		};
		let b_manager = Arc::new(PeerManager::new(
			b_msg_handler,
			0,
			&[2; 32],
			Arc::new(TestLogger()),
			Arc::new(TestNodeSigner::new(b_key)),
		));

		// We bind on localhost, hoping the environment is properly configured with a local
		// address. This may not always be the case in containers and the like, so if this test is
		// failing for you check that you have a loopback interface and it is configured with
		// 127.0.0.1.
		let (conn_a, conn_b) = make_tcp_connection();

		let fut_a = super::setup_outbound(Arc::clone(&a_manager), b_pub, conn_a);
		let fut_b = super::setup_inbound(b_manager, conn_b);

		tokio::time::timeout(Duration::from_secs(10), a_connected.recv()).await.unwrap();
		tokio::time::timeout(Duration::from_secs(1), b_connected.recv()).await.unwrap();

		a_handler.msg_events.lock().unwrap().push(MessageSendEvent::HandleError {
			node_id: b_pub,
			action: ErrorAction::DisconnectPeer { msg: None },
		});
		assert!(!a_handler.disconnected_flag.load(Ordering::SeqCst));
		assert!(!b_handler.disconnected_flag.load(Ordering::SeqCst));

		a_manager.process_events();
		tokio::time::timeout(Duration::from_secs(10), a_disconnected.recv()).await.unwrap();
		tokio::time::timeout(Duration::from_secs(1), b_disconnected.recv()).await.unwrap();
		assert!(a_handler.disconnected_flag.load(Ordering::SeqCst));
		assert!(b_handler.disconnected_flag.load(Ordering::SeqCst));

		fut_a.await;
		fut_b.await;
	}

	#[tokio::test(flavor = "multi_thread")]
	async fn basic_threaded_connection_test() {
		do_basic_connection_test().await;
	}

	#[tokio::test]
	async fn basic_unthreaded_connection_test() {
		do_basic_connection_test().await;
	}

	async fn race_disconnect_accept() {
		// Previously, if we handed an already-disconnected socket to `setup_inbound` we'd panic.
		// This attempts to find other similar races by opening connections and shutting them down
		// while connecting. Sadly in testing this did *not* reproduce the previous issue.
		let secp_ctx = Secp256k1::new();
		let a_key = SecretKey::from_slice(&[1; 32]).unwrap();
		let b_key = SecretKey::from_slice(&[2; 32]).unwrap();
		let b_pub = PublicKey::from_secret_key(&secp_ctx, &b_key);

		let a_msg_handler = MessageHandler {
			chan_handler: Arc::new(lightning::ln::peer_handler::ErroringMessageHandler::new()),
			onion_message_handler: Arc::new(IgnoringMessageHandler {}),
			route_handler: Arc::new(lightning::ln::peer_handler::IgnoringMessageHandler {}),
			custom_message_handler: Arc::new(IgnoringMessageHandler {}),
			send_only_message_handler: Arc::new(IgnoringMessageHandler {}),
		};
		let a_manager = Arc::new(PeerManager::new(
			a_msg_handler,
			0,
			&[1; 32],
			Arc::new(TestLogger()),
			Arc::new(TestNodeSigner::new(a_key)),
		));

		// Make two connections, one for an inbound and one for an outbound connection
		let conn_a = {
			let (conn_a, _) = make_tcp_connection();
			conn_a
		};
		let conn_b = {
			let (_, conn_b) = make_tcp_connection();
			conn_b
		};

		// Call connection setup inside new tokio tasks.
		let manager_reference = Arc::clone(&a_manager);
		tokio::spawn(async move { super::setup_inbound(manager_reference, conn_a).await });
		tokio::spawn(async move { super::setup_outbound(a_manager, b_pub, conn_b).await });
	}

	#[tokio::test(flavor = "multi_thread")]
	async fn threaded_race_disconnect_accept() {
		race_disconnect_accept().await;
	}

	#[tokio::test]
	async fn unthreaded_race_disconnect_accept() {
		race_disconnect_accept().await;
	}
}