cli/tunnels/
local_forwarding.rs

1/*---------------------------------------------------------------------------------------------
2 *  Copyright (c) Microsoft Corporation. All rights reserved.
3 *  Licensed under the MIT License. See License.txt in the project root for license information.
4 *--------------------------------------------------------------------------------------------*/
5
6use std::{
7	collections::HashMap,
8	ops::{Index, IndexMut},
9	sync::{Arc, Mutex},
10};
11
12use tokio::{
13	pin,
14	sync::{mpsc, watch},
15};
16
17use crate::{
18	async_pipe::{socket_stream_split, AsyncPipe},
19	json_rpc::{new_json_rpc, start_json_rpc},
20	log,
21	singleton::SingletonServer,
22	util::{errors::CodeError, sync::Barrier},
23};
24
25use super::{
26	dev_tunnels::ActiveTunnel,
27	protocol::{
28		self,
29		forward_singleton::{PortList, SetPortsResponse},
30		PortPrivacy,
31	},
32	shutdown_signal::ShutdownSignal,
33};
34
35#[derive(Default, Clone)]
36struct PortCount {
37	public: u32,
38	private: u32,
39}
40
41impl Index<PortPrivacy> for PortCount {
42	type Output = u32;
43
44	fn index(&self, privacy: PortPrivacy) -> &Self::Output {
45		match privacy {
46			PortPrivacy::Public => &self.public,
47			PortPrivacy::Private => &self.private,
48		}
49	}
50}
51
52impl IndexMut<PortPrivacy> for PortCount {
53	fn index_mut(&mut self, privacy: PortPrivacy) -> &mut Self::Output {
54		match privacy {
55			PortPrivacy::Public => &mut self.public,
56			PortPrivacy::Private => &mut self.private,
57		}
58	}
59}
60
61impl PortCount {
62	fn is_empty(&self) -> bool {
63		self.public == 0 && self.private == 0
64	}
65
66	fn primary_privacy(&self) -> PortPrivacy {
67		if self.public > 0 {
68			PortPrivacy::Public
69		} else {
70			PortPrivacy::Private
71		}
72	}
73}
74
75type PortMap = HashMap<u16, PortCount>;
76
77/// The PortForwardingHandle is given out to multiple consumers to allow
78/// them to set_ports that they want to be forwarded.
79struct PortForwardingSender {
80	/// Todo: when `SyncUnsafeCell` is no longer nightly, we can use it here with
81	/// the following comment:
82	///
83	/// SyncUnsafeCell is used and safe here because PortForwardingSender is used
84	/// exclusively in synchronous dispatch *and* we create a new sender in the
85	/// context for each connection, in `serve_singleton_rpc`.
86	///
87	/// If PortForwardingSender is ever used in a different context, this should
88	/// be refactored, e.g. to use locks or `&mut self` in set_ports`
89	///
90	/// see https://doc.rust-lang.org/stable/std/cell/struct.SyncUnsafeCell.html
91	current: Mutex<PortList>,
92	sender: Arc<Mutex<watch::Sender<PortMap>>>,
93}
94
95impl PortForwardingSender {
96	pub fn set_ports(&self, ports: PortList) {
97		let mut current = self.current.lock().unwrap();
98		self.sender.lock().unwrap().send_modify(|v| {
99			for p in current.iter() {
100				if !ports.contains(p) {
101					let n = v.get_mut(&p.number).expect("expected port in map");
102					n[p.privacy] -= 1;
103					if n.is_empty() {
104						v.remove(&p.number);
105					}
106				}
107			}
108
109			for p in ports.iter() {
110				if !current.contains(p) {
111					match v.get_mut(&p.number) {
112						Some(n) => {
113							n[p.privacy] += 1;
114						}
115						None => {
116							let mut pc = PortCount::default();
117							pc[p.privacy] += 1;
118							v.insert(p.number, pc);
119						}
120					};
121				}
122			}
123
124			current.splice(.., ports);
125		});
126	}
127}
128
129impl Clone for PortForwardingSender {
130	fn clone(&self) -> Self {
131		Self {
132			current: Mutex::new(vec![]),
133			sender: self.sender.clone(),
134		}
135	}
136}
137
138impl Drop for PortForwardingSender {
139	fn drop(&mut self) {
140		self.set_ports(vec![]);
141	}
142}
143
144struct PortForwardingReceiver {
145	receiver: watch::Receiver<PortMap>,
146}
147
148impl PortForwardingReceiver {
149	pub fn new() -> (PortForwardingSender, Self) {
150		let (sender, receiver) = watch::channel(HashMap::new());
151		let handle = PortForwardingSender {
152			current: Mutex::new(vec![]),
153			sender: Arc::new(Mutex::new(sender)),
154		};
155
156		let tracker = Self { receiver };
157
158		(handle, tracker)
159	}
160
161	/// Applies all changes from PortForwardingHandles to the tunnel.
162	pub async fn apply_to(&mut self, log: log::Logger, tunnel: Arc<ActiveTunnel>) {
163		let mut current: PortMap = HashMap::new();
164		while self.receiver.changed().await.is_ok() {
165			let next = self.receiver.borrow().clone();
166
167			for (port, count) in current.iter() {
168				let privacy = count.primary_privacy();
169				if !matches!(next.get(port), Some(n) if n.primary_privacy() == privacy) {
170					match tunnel.remove_port(*port).await {
171						Ok(_) => info!(log, "stopped forwarding port {} at {:?}", *port, privacy),
172						Err(e) => error!(log, "failed to stop forwarding port {}: {}", port, e),
173					}
174				}
175			}
176
177			for (port, count) in next.iter() {
178				let privacy = count.primary_privacy();
179				if !matches!(current.get(port), Some(n) if n.primary_privacy() == privacy) {
180					match tunnel.add_port_tcp(*port, privacy).await {
181						Ok(_) => info!(log, "forwarding port {} at {:?}", port, privacy),
182						Err(e) => error!(log, "failed to forward port {}: {}", port, e),
183					}
184				}
185			}
186
187			current = next;
188		}
189	}
190}
191
192pub struct SingletonClientArgs {
193	pub log: log::Logger,
194	pub stream: AsyncPipe,
195	pub shutdown: Barrier<ShutdownSignal>,
196	pub port_requests: watch::Receiver<PortList>,
197}
198
199#[derive(Clone)]
200struct SingletonServerContext {
201	log: log::Logger,
202	handle: PortForwardingSender,
203	tunnel: Arc<ActiveTunnel>,
204}
205
206/// Serves a client singleton for port forwarding.
207pub async fn client(args: SingletonClientArgs) -> Result<(), std::io::Error> {
208	let mut rpc = new_json_rpc();
209	let (msg_tx, msg_rx) = mpsc::unbounded_channel();
210	let SingletonClientArgs {
211		log,
212		shutdown,
213		stream,
214		mut port_requests,
215	} = args;
216
217	debug!(
218		log,
219		"An existing port forwarding process is running on this machine, connecting to it..."
220	);
221
222	let caller = rpc.get_caller(msg_tx);
223	let rpc = rpc.methods(()).build(log.clone());
224	let (read, write) = socket_stream_split(stream);
225
226	let serve = start_json_rpc(rpc, read, write, msg_rx, shutdown);
227	let forward = async move {
228		while port_requests.changed().await.is_ok() {
229			let ports = port_requests.borrow().clone();
230			let r = caller
231				.call::<_, _, protocol::forward_singleton::SetPortsResponse>(
232					protocol::forward_singleton::METHOD_SET_PORTS,
233					protocol::forward_singleton::SetPortsParams { ports },
234				)
235				.await
236				.unwrap();
237
238			match r {
239				Err(e) => error!(log, "failed to set ports: {:?}", e),
240				Ok(r) => print_forwarding_addr(&r),
241			};
242		}
243	};
244
245	tokio::select! {
246		r = serve => r.map(|_| ()),
247		_ = forward => Ok(()),
248	}
249}
250
251/// Serves a port-forwarding singleton.
252pub async fn server(
253	log: log::Logger,
254	tunnel: ActiveTunnel,
255	server: SingletonServer,
256	mut port_requests: watch::Receiver<PortList>,
257	shutdown_rx: Barrier<ShutdownSignal>,
258) -> Result<(), CodeError> {
259	let tunnel = Arc::new(tunnel);
260	let (forward_tx, mut forward_rx) = PortForwardingReceiver::new();
261
262	let forward_own_tunnel = tunnel.clone();
263	let forward_own_tx = forward_tx.clone();
264	let forward_own = async move {
265		while port_requests.changed().await.is_ok() {
266			forward_own_tx.set_ports(port_requests.borrow().clone());
267			print_forwarding_addr(&SetPortsResponse {
268				port_format: forward_own_tunnel.get_port_format().ok(),
269			});
270		}
271	};
272
273	tokio::select! {
274		_ = forward_own => Ok(()),
275		_ = forward_rx.apply_to(log.clone(), tunnel.clone()) => Ok(()),
276		r = serve_singleton_rpc(server, log, tunnel, forward_tx, shutdown_rx) => r,
277	}
278}
279
280async fn serve_singleton_rpc(
281	mut server: SingletonServer,
282	log: log::Logger,
283	tunnel: Arc<ActiveTunnel>,
284	forward_tx: PortForwardingSender,
285	shutdown_rx: Barrier<ShutdownSignal>,
286) -> Result<(), CodeError> {
287	let mut own_shutdown = shutdown_rx.clone();
288	let shutdown_fut = own_shutdown.wait();
289	pin!(shutdown_fut);
290
291	loop {
292		let cnx = tokio::select! {
293			c = server.accept() => c?,
294			_ = &mut shutdown_fut => return Ok(()),
295		};
296
297		let (read, write) = socket_stream_split(cnx);
298		let shutdown_rx = shutdown_rx.clone();
299
300		let handle = forward_tx.clone();
301		let log = log.clone();
302		let tunnel = tunnel.clone();
303		tokio::spawn(async move {
304			// we make an rpc for the connection instead of re-using a dispatcher
305			// so that we can have the "handle" drop when the connection drops.
306			let rpc = new_json_rpc();
307			let mut rpc = rpc.methods(SingletonServerContext {
308				log: log.clone(),
309				handle,
310				tunnel,
311			});
312
313			rpc.register_sync(
314				protocol::forward_singleton::METHOD_SET_PORTS,
315				|p: protocol::forward_singleton::SetPortsParams, ctx| {
316					info!(ctx.log, "client setting ports to {:?}", p.ports);
317					ctx.handle.set_ports(p.ports);
318					Ok(SetPortsResponse {
319						port_format: ctx.tunnel.get_port_format().ok(),
320					})
321				},
322			);
323
324			let _ = start_json_rpc(rpc.build(log), read, write, (), shutdown_rx).await;
325		});
326	}
327}
328
329fn print_forwarding_addr(r: &SetPortsResponse) {
330	eprintln!("{}\n", serde_json::to_string(r).unwrap());
331}