1use std::{
7 collections::HashMap,
8 ops::{Index, IndexMut},
9 sync::{Arc, Mutex},
10};
11
12use tokio::{
13 pin,
14 sync::{mpsc, watch},
15};
16
17use crate::{
18 async_pipe::{socket_stream_split, AsyncPipe},
19 json_rpc::{new_json_rpc, start_json_rpc},
20 log,
21 singleton::SingletonServer,
22 util::{errors::CodeError, sync::Barrier},
23};
24
25use super::{
26 dev_tunnels::ActiveTunnel,
27 protocol::{
28 self,
29 forward_singleton::{PortList, SetPortsResponse},
30 PortPrivacy,
31 },
32 shutdown_signal::ShutdownSignal,
33};
34
/// Number of outstanding forwarding requests for a single port, split by privacy level.
///
/// A port stays forwarded while either counter is non-zero.
#[derive(Default, Clone)]
struct PortCount {
    /// Requests asking for the port to be forwarded publicly.
    public: u32,
    /// Requests asking for the port to be forwarded privately.
    private: u32,
}
40
41impl Index<PortPrivacy> for PortCount {
42 type Output = u32;
43
44 fn index(&self, privacy: PortPrivacy) -> &Self::Output {
45 match privacy {
46 PortPrivacy::Public => &self.public,
47 PortPrivacy::Private => &self.private,
48 }
49 }
50}
51
52impl IndexMut<PortPrivacy> for PortCount {
53 fn index_mut(&mut self, privacy: PortPrivacy) -> &mut Self::Output {
54 match privacy {
55 PortPrivacy::Public => &mut self.public,
56 PortPrivacy::Private => &mut self.private,
57 }
58 }
59}
60
61impl PortCount {
62 fn is_empty(&self) -> bool {
63 self.public == 0 && self.private == 0
64 }
65
66 fn primary_privacy(&self) -> PortPrivacy {
67 if self.public > 0 {
68 PortPrivacy::Public
69 } else {
70 PortPrivacy::Private
71 }
72 }
73}
74
75type PortMap = HashMap<u16, PortCount>;
76
77struct PortForwardingSender {
80 current: Mutex<PortList>,
92 sender: Arc<Mutex<watch::Sender<PortMap>>>,
93}
94
95impl PortForwardingSender {
96 pub fn set_ports(&self, ports: PortList) {
97 let mut current = self.current.lock().unwrap();
98 self.sender.lock().unwrap().send_modify(|v| {
99 for p in current.iter() {
100 if !ports.contains(p) {
101 let n = v.get_mut(&p.number).expect("expected port in map");
102 n[p.privacy] -= 1;
103 if n.is_empty() {
104 v.remove(&p.number);
105 }
106 }
107 }
108
109 for p in ports.iter() {
110 if !current.contains(p) {
111 match v.get_mut(&p.number) {
112 Some(n) => {
113 n[p.privacy] += 1;
114 }
115 None => {
116 let mut pc = PortCount::default();
117 pc[p.privacy] += 1;
118 v.insert(p.number, pc);
119 }
120 };
121 }
122 }
123
124 current.splice(.., ports);
125 });
126 }
127}
128
129impl Clone for PortForwardingSender {
130 fn clone(&self) -> Self {
131 Self {
132 current: Mutex::new(vec![]),
133 sender: self.sender.clone(),
134 }
135 }
136}
137
138impl Drop for PortForwardingSender {
139 fn drop(&mut self) {
140 self.set_ports(vec![]);
141 }
142}
143
144struct PortForwardingReceiver {
145 receiver: watch::Receiver<PortMap>,
146}
147
148impl PortForwardingReceiver {
149 pub fn new() -> (PortForwardingSender, Self) {
150 let (sender, receiver) = watch::channel(HashMap::new());
151 let handle = PortForwardingSender {
152 current: Mutex::new(vec![]),
153 sender: Arc::new(Mutex::new(sender)),
154 };
155
156 let tracker = Self { receiver };
157
158 (handle, tracker)
159 }
160
161 pub async fn apply_to(&mut self, log: log::Logger, tunnel: Arc<ActiveTunnel>) {
163 let mut current: PortMap = HashMap::new();
164 while self.receiver.changed().await.is_ok() {
165 let next = self.receiver.borrow().clone();
166
167 for (port, count) in current.iter() {
168 let privacy = count.primary_privacy();
169 if !matches!(next.get(port), Some(n) if n.primary_privacy() == privacy) {
170 match tunnel.remove_port(*port).await {
171 Ok(_) => info!(log, "stopped forwarding port {} at {:?}", *port, privacy),
172 Err(e) => error!(log, "failed to stop forwarding port {}: {}", port, e),
173 }
174 }
175 }
176
177 for (port, count) in next.iter() {
178 let privacy = count.primary_privacy();
179 if !matches!(current.get(port), Some(n) if n.primary_privacy() == privacy) {
180 match tunnel.add_port_tcp(*port, privacy).await {
181 Ok(_) => info!(log, "forwarding port {} at {:?}", port, privacy),
182 Err(e) => error!(log, "failed to forward port {}: {}", port, e),
183 }
184 }
185 }
186
187 current = next;
188 }
189 }
190}
191
192pub struct SingletonClientArgs {
193 pub log: log::Logger,
194 pub stream: AsyncPipe,
195 pub shutdown: Barrier<ShutdownSignal>,
196 pub port_requests: watch::Receiver<PortList>,
197}
198
199#[derive(Clone)]
200struct SingletonServerContext {
201 log: log::Logger,
202 handle: PortForwardingSender,
203 tunnel: Arc<ActiveTunnel>,
204}
205
206pub async fn client(args: SingletonClientArgs) -> Result<(), std::io::Error> {
208 let mut rpc = new_json_rpc();
209 let (msg_tx, msg_rx) = mpsc::unbounded_channel();
210 let SingletonClientArgs {
211 log,
212 shutdown,
213 stream,
214 mut port_requests,
215 } = args;
216
217 debug!(
218 log,
219 "An existing port forwarding process is running on this machine, connecting to it..."
220 );
221
222 let caller = rpc.get_caller(msg_tx);
223 let rpc = rpc.methods(()).build(log.clone());
224 let (read, write) = socket_stream_split(stream);
225
226 let serve = start_json_rpc(rpc, read, write, msg_rx, shutdown);
227 let forward = async move {
228 while port_requests.changed().await.is_ok() {
229 let ports = port_requests.borrow().clone();
230 let r = caller
231 .call::<_, _, protocol::forward_singleton::SetPortsResponse>(
232 protocol::forward_singleton::METHOD_SET_PORTS,
233 protocol::forward_singleton::SetPortsParams { ports },
234 )
235 .await
236 .unwrap();
237
238 match r {
239 Err(e) => error!(log, "failed to set ports: {:?}", e),
240 Ok(r) => print_forwarding_addr(&r),
241 };
242 }
243 };
244
245 tokio::select! {
246 r = serve => r.map(|_| ()),
247 _ = forward => Ok(()),
248 }
249}
250
251pub async fn server(
253 log: log::Logger,
254 tunnel: ActiveTunnel,
255 server: SingletonServer,
256 mut port_requests: watch::Receiver<PortList>,
257 shutdown_rx: Barrier<ShutdownSignal>,
258) -> Result<(), CodeError> {
259 let tunnel = Arc::new(tunnel);
260 let (forward_tx, mut forward_rx) = PortForwardingReceiver::new();
261
262 let forward_own_tunnel = tunnel.clone();
263 let forward_own_tx = forward_tx.clone();
264 let forward_own = async move {
265 while port_requests.changed().await.is_ok() {
266 forward_own_tx.set_ports(port_requests.borrow().clone());
267 print_forwarding_addr(&SetPortsResponse {
268 port_format: forward_own_tunnel.get_port_format().ok(),
269 });
270 }
271 };
272
273 tokio::select! {
274 _ = forward_own => Ok(()),
275 _ = forward_rx.apply_to(log.clone(), tunnel.clone()) => Ok(()),
276 r = serve_singleton_rpc(server, log, tunnel, forward_tx, shutdown_rx) => r,
277 }
278}
279
280async fn serve_singleton_rpc(
281 mut server: SingletonServer,
282 log: log::Logger,
283 tunnel: Arc<ActiveTunnel>,
284 forward_tx: PortForwardingSender,
285 shutdown_rx: Barrier<ShutdownSignal>,
286) -> Result<(), CodeError> {
287 let mut own_shutdown = shutdown_rx.clone();
288 let shutdown_fut = own_shutdown.wait();
289 pin!(shutdown_fut);
290
291 loop {
292 let cnx = tokio::select! {
293 c = server.accept() => c?,
294 _ = &mut shutdown_fut => return Ok(()),
295 };
296
297 let (read, write) = socket_stream_split(cnx);
298 let shutdown_rx = shutdown_rx.clone();
299
300 let handle = forward_tx.clone();
301 let log = log.clone();
302 let tunnel = tunnel.clone();
303 tokio::spawn(async move {
304 let rpc = new_json_rpc();
307 let mut rpc = rpc.methods(SingletonServerContext {
308 log: log.clone(),
309 handle,
310 tunnel,
311 });
312
313 rpc.register_sync(
314 protocol::forward_singleton::METHOD_SET_PORTS,
315 |p: protocol::forward_singleton::SetPortsParams, ctx| {
316 info!(ctx.log, "client setting ports to {:?}", p.ports);
317 ctx.handle.set_ports(p.ports);
318 Ok(SetPortsResponse {
319 port_format: ctx.tunnel.get_port_format().ok(),
320 })
321 },
322 );
323
324 let _ = start_json_rpc(rpc.build(log), read, write, (), shutdown_rx).await;
325 });
326 }
327}
328
329fn print_forwarding_addr(r: &SetPortsResponse) {
330 eprintln!("{}\n", serde_json::to_string(r).unwrap());
331}