1use std::{
2 collections::{LinkedList, HashMap},
3 net::{UdpSocket, SocketAddr},
4 cell::RefCell,
5 thread,
6 time::Duration,
7};
8use cyfs_debug::Mutex;
9use async_std::{
10 sync::{Arc},
11 task,
12 future
13};
14use cyfs_base::*;
15use crate::{
16 types::*,
17 interface::udp::MTU_LARGE
18};
19use std::time::{UNIX_EPOCH, SystemTime};
20
/// Runtime configuration for the proxy service.
#[derive(Clone)]
pub struct Config {
    /// Keepalive period; presumably the idle time after which a tunnel may be
    /// recycled (TunnelsManager has a field of the same name) — confirm with users.
    pub keepalive: Duration,
}
25
26#[derive(Clone, Debug)]
27pub struct ProxyDeviceStub {
28 pub id: DeviceId,
29 pub timestamp: Timestamp,
30}
31
32#[derive(Clone, Debug)]
33pub struct ProxyEndpointStub {
34 endpoint: SocketAddr,
35 last_active: Timestamp
36}
37
38#[derive(Clone)]
39struct ProxyTunnel {
40 device_pair: (ProxyDeviceStub, ProxyDeviceStub),
41 endpoint_pair: (Option<ProxyEndpointStub>, Option<ProxyEndpointStub>),
42 last_active: Timestamp
43}
44
45impl std::fmt::Display for ProxyTunnel {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 write!(f, "ProxyTunnel")
48 }
49}
50
51impl ProxyTunnel {
52 fn new(device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> Self {
53 Self {
54 device_pair,
55 endpoint_pair: (None, None),
56 last_active: bucky_time_now()
57 }
58 }
59
60 fn recyclable(&self, now: Timestamp, keepalive: Duration) -> bool {
61 if now > self.last_active && Duration::from_micros(now - self.last_active) > keepalive {
62 true
63 } else {
64 false
65 }
66 }
67
68 fn on_device_pair(&mut self, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
69 self.last_active = bucky_time_now();
70 let (left, right) = device_pair;
71 let (fl, fr) = {
72 if left.id.eq(&self.device_pair.0.id) && right.id.eq(&self.device_pair.1.id) {
73 Ok((&mut self.device_pair.0, &mut self.device_pair.1))
74 } else if right.id.eq(&self.device_pair.0.id) && left.id.eq(&self.device_pair.1.id) {
75 Ok((&mut self.device_pair.1, &mut self.device_pair.0))
76 } else {
77 trace!("{} ignore device pair ({:?}, {:?}) for not match {:?}", self, left, right, self.device_pair);
78 Err(BuckyError::new(BuckyErrorCode::NotMatch, "device pair not match"))
79 }
80 }?;
81 if left.timestamp > fl.timestamp {
82 fl.timestamp = left.timestamp;
83 self.endpoint_pair = (None, None);
84 trace!("proxy tunnel update endpoint pair to (None, None)");
85 }
86 if right.timestamp > right.timestamp {
87 fr.timestamp = right.timestamp;
88 self.endpoint_pair = (None, None);
89 trace!("proxy tunnel update endpoint pair to (None, None)");
90 }
91 Ok(())
92 }
93
94 fn on_proxied_datagram(&mut self, mix_hash: &KeyMixHash, from: &SocketAddr) -> Option<SocketAddr> {
95 self.last_active = bucky_time_now();
96 if self.endpoint_pair.0.is_none() {
97 self.endpoint_pair.0 = Some(ProxyEndpointStub {
98 endpoint: *from,
99 last_active: bucky_time_now()
100 });
101 trace!("{} mix_hash:{} update endpoint pair to {:?}", self, mix_hash, self.endpoint_pair);
102 None
103 } else if self.endpoint_pair.1.is_none() {
104 let left = self.endpoint_pair.0.as_mut().unwrap();
105 if left.endpoint.eq(from) {
106 left.last_active = bucky_time_now();
107 } else {
108 self.endpoint_pair.1 = Some(ProxyEndpointStub {
109 endpoint: *from,
110 last_active: bucky_time_now()
111 });
112 }
113 trace!("{} mix_hash:{} update endpoint pair to {:?}", self, mix_hash, self.endpoint_pair);
114 None
115 } else {
116 let left = self.endpoint_pair.0.as_mut().unwrap();
117 let right = self.endpoint_pair.1.as_mut().unwrap();
118
119 if left.endpoint.eq(from) {
120 left.last_active = bucky_time_now();
121 Some(right.endpoint)
122 } else if right.endpoint.eq(from) {
123 right.last_active = bucky_time_now();
124 Some(left.endpoint)
125 } else {
126 *left = right.clone();
127 right.endpoint = *from;
128 right.last_active = bucky_time_now();
129 trace!("ProxyTunnel mix_hash:{} mix_hash update endpoint pair to ({:?}, {:?})", mix_hash, left, right);
130 Some(left.endpoint)
131 }
132 }
133 }
134}
135
136#[derive(Clone)]
137struct TunnelMixHash {
138 tunnel: ProxyTunnel,
139 mix_key: AesKey,
140 mixhash: Vec<MixHashInfo>,
141}
142
143impl TunnelMixHash {
144 pub fn recyclable(&self, now: Timestamp, keepalive: Duration) -> bool {
145 self.tunnel.recyclable(now, keepalive)
146 }
147
148 pub fn new(mix_key: AesKey, tunnel: ProxyTunnel) -> Self {
149 TunnelMixHash {
150 tunnel,
151 mix_key,
152 mixhash: Vec::new(),
153 }
154 }
155
156 pub fn rehash(&mut self, min: u64, max: u64) -> (Vec<KeyMixHash>, Vec<KeyMixHash>) {
157 let mut timeout_n = 0;
158 let mut next_ts = min;
159 for h in self.mixhash.as_slice() {
160 let t = h.minute_timestamp;
161 if t < min {
162 timeout_n += 1;
163 } else if t > next_ts {
164 next_ts = t + 1;
165 }
166 }
167
168 let removed: Vec<MixHashInfo> = self.mixhash.splice(..timeout_n, vec![].iter().cloned()).collect();
169 let removed = removed.iter().map(|h| h.hash.clone()).collect();
170
171 let mut added = vec![];
172 if next_ts < max {
173 for t in next_ts..(max+1) {
174 let h = MixHashInfo::new(self.mix_key.mix_hash(Some(t)), t);
175 added.push(h.hash.clone());
176 self.mixhash.push(h);
177 }
178 }
179
180 (added, removed)
181 }
182}
183
184#[derive(Clone)]
185struct MixHashInfo {
186 hash: KeyMixHash,
187 minute_timestamp: u64,
188}
189
190impl MixHashInfo {
191 pub fn new(hash: KeyMixHash, minute_timestamp: u64) -> Self {
192 MixHashInfo {
193 hash,
194 minute_timestamp
195 }
196 }
197}
198
199struct TunnelsManager {
200 tunnel_mixhash_map: HashMap<KeyMixHash, TunnelMixHash>,
201 tunnel_mixkey_list: LinkedList<TunnelMixHash>,
202 keepalive: Duration,
203 mixhash_live_minutes: u64,
204}
205
206impl TunnelsManager {
207 pub fn default() -> Self {
208 let def_keepalive = 60;
209 let def_mixhash_live_minute = 31;
210
211 Self {
212 tunnel_mixhash_map: HashMap::new(),
213 tunnel_mixkey_list: LinkedList::new(),
214 keepalive: Duration::from_secs(def_keepalive),
215 mixhash_live_minutes: def_mixhash_live_minute,
216 }
217 }
218}
219
220impl TunnelsManager {
221 fn minute_timestamp_range(&self) -> (u64, u64) {
222 let minute_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() / 60;
223 let min = minute_timestamp - (self.mixhash_live_minutes - 1) / 2;
224 let max = minute_timestamp + (self.mixhash_live_minutes - 1) / 2;
225
226 (min, max)
227 }
228
229 fn mixkey_update(&mut self, mix_key: AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
230 let tunnel = self.tunnel_mixhash_map.get(&mix_key.mix_hash(None)).unwrap();
231 let mut tunnel = tunnel.tunnel.clone();
232 let (left, right) = device_pair;
233
234 let (fl, fr) = {
235 if left.id.eq(&tunnel.device_pair.0.id) && right.id.eq(&tunnel.device_pair.1.id) {
236 Ok((&mut tunnel.device_pair.0, &mut tunnel.device_pair.1))
237 } else if right.id.eq(&tunnel.device_pair.0.id) && left.id.eq(&tunnel.device_pair.1.id) {
238 Ok((&mut tunnel.device_pair.1, &mut tunnel.device_pair.0))
239 } else {
240 trace!("{} ignore device pair ({:?}, {:?}) for not match {:?}", tunnel, left, right, tunnel.device_pair);
241 Err(BuckyError::new(BuckyErrorCode::NotMatch, "device pair not match"))
242 }
243 }?;
244 if left.timestamp > fl.timestamp {
245 fl.timestamp = left.timestamp;
246 tunnel.endpoint_pair = (None, None);
247 trace!("proxy tunnel update endpoint pair to (None, None)");
248 }
249 if right.timestamp > right.timestamp {
250 fr.timestamp = right.timestamp;
251 tunnel.endpoint_pair = (None, None);
252 trace!("proxy tunnel update endpoint pair to (None, None)");
253 }
254
255 Ok(())
256 }
257
258 fn mixkey_add(&mut self, mix_key: AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
259 let mut tunnel = TunnelMixHash::new(mix_key.clone(), ProxyTunnel::new(device_pair));
260
261 let (min, max) = self.minute_timestamp_range();
262 let (added, _) = tunnel.rehash(min, max);
263
264 self.tunnel_mixkey_list.push_front(tunnel.clone());
265
266 for h in added.as_slice() {
267 self.tunnel_mixhash_map.insert(h.clone(), tunnel.clone());
268 }
269 self.tunnel_mixhash_map.insert(mix_key.mix_hash(None), tunnel.clone());
270
271 Ok(())
272 }
273
274 pub fn create_tunnel(&mut self, mix_key: AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
275 let mix_hash = mix_key.mix_hash(None);
276
277 if self.has_tunnel(&mix_hash) {
278 self.mixkey_update(mix_key, device_pair)
279 } else {
280 self.mixkey_add(mix_key, device_pair)
281 }
282 }
283
284 pub fn on_proxied_datagram(&mut self, datagram: &[u8], from: &SocketAddr) -> Option<SocketAddr> {
285 match KeyMixHash::raw_decode(datagram) {
286 Ok((mut mix_hash, _)) => {
287 mix_hash.as_mut()[0] &= 0x7f;
288 if let Some(tunnel) = self.tunnel_mixhash_map.get_mut(&mix_hash) {
289 trace!("{} recv datagram of mix_hash: {}", tunnel.tunnel, mix_hash);
290 tunnel.tunnel.on_proxied_datagram(&mix_hash, from)
291 } else {
292 trace!("ignore datagram of mix_hash: {}", mix_hash);
293 None
294 }
295 },
296 _ => {
297 trace!("ignore datagram for invalid key foramt");
298 None
299 }
300 }
301 }
302
303 pub fn has_tunnel(&self, mix_hash: &KeyMixHash) -> bool {
304 if let Some(_) = self.tunnel_mixhash_map.get(mix_hash) {
305 true
306 } else {
307 false
308 }
309 }
310
311 pub fn rehash(&mut self) {
312 let (min, max) = self.minute_timestamp_range();
313
314 trace!("rehash min={} max={}", min, max);
315
316 for (_, tunnel) in self.tunnel_mixkey_list.iter_mut().enumerate() {
317 let (added, removed) = tunnel.rehash(min, max);
318 for h in added.as_slice() {
319 self.tunnel_mixhash_map.insert(h.clone(), tunnel.clone());
320 }
321 for h in removed.as_slice() {
322 self.tunnel_mixhash_map.remove(h);
323 }
324 }
325 }
326
327 pub fn recycle(&mut self) {
328 let now = bucky_time_now();
329
330 trace!("recycle now={}", now);
331
332 let mut removed = Vec::new();
333 for (i, tunnel) in self.tunnel_mixkey_list.iter_mut().enumerate() {
334 if tunnel.recyclable(now, self.keepalive) {
335 removed.push(i-removed.len());
336 }
337 }
338
339 for i in 0..removed.len() {
340 let mut last_part = self.tunnel_mixkey_list.split_off(*removed.get(i).unwrap());
341 let tunnel = last_part.pop_front().unwrap();
342 self.tunnel_mixkey_list.append(&mut last_part);
343
344 self.tunnel_mixhash_map.remove(&tunnel.mix_key.mix_hash(None));
345 for i in 0..tunnel.mixhash.len() {
346 let mixhash = tunnel.mixhash.get(i).unwrap();
347 self.tunnel_mixhash_map.remove(&mixhash.hash);
348 }
349 }
350 }
351}
352
353struct ProxyInterfaceImpl {
354 config: Config,
355 socket: UdpSocket,
356 outer: SocketAddr,
357 tunnels: Mutex<TunnelsManager>,
358}
359
360#[derive(Clone)]
361struct ProxyInterface(Arc<ProxyInterfaceImpl>);
362
363impl std::fmt::Display for ProxyInterface {
364 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365 write!(f, "ProxyInterface:{{endpoint:{:?}}}", self.local())
366 }
367}
368
369thread_local! {
370 static UDP_RECV_BUFFER: RefCell<[u8; MTU_LARGE]> = RefCell::new([0u8; MTU_LARGE]);
371}
372
373impl ProxyInterface {
374 fn open(config: Config, local: SocketAddr, outer: Option<SocketAddr>) -> BuckyResult<Self> {
375 let socket = UdpSocket::bind(local)
376 .map_err(|e| {
377 error!("ProxyInterface bind socket on {:?} failed for {}", local, e);
378 e
379 })?;
380 let interface = Self(Arc::new(ProxyInterfaceImpl {
381 config,
382 socket,
383 outer: outer.unwrap_or(local),
384 tunnels: Mutex::new(TunnelsManager::default()),
385 }));
386
387 let num_cpus = 4;
388 let pool_size = num_cpus + 2;
389 for _ in 0..pool_size {
390 let interface = interface.clone();
391 thread::spawn(move || {
392 interface.proxy_loop();
393 });
394 }
395
396 {
397 let interface = interface.clone();
398 task::spawn(async move {
399 interface.timer().await;
400 });
401 }
402
403 Ok(interface)
404 }
405
406 fn local(&self) -> SocketAddr {
407 self.0.socket.local_addr().unwrap()
408 }
409
410 fn outer(&self) -> &SocketAddr {
411 &self.0.outer
412 }
413
414 async fn timer(&self) {
415 let tick_sec = 60;
416 loop {
417 {
418 let mut tunnels = self.0.tunnels.lock().unwrap();
419 tunnels.recycle();
420 tunnels.rehash();
421 }
422
423 let _ = future::timeout(Duration::from_secs(tick_sec), future::pending::<()>()).await;
424 }
425 }
426
427 fn proxy_loop(&self) {
428 info!("{} started", self);
429 loop {
430 UDP_RECV_BUFFER.with(|thread_recv_buf| {
431 let recv_buf = &mut thread_recv_buf.borrow_mut()[..];
432 loop {
433 let rr = self.0.socket.recv_from(recv_buf);
434 if rr.is_ok() {
435 let (len, from) = rr.unwrap();
436 let recv = &recv_buf[..len];
437 trace!("{} recv datagram len {} from {:?}", self, len, from);
438 self.on_proxied_datagram(recv, &from);
439 } else {
440 let err = rr.err().unwrap();
441 if let Some(10054i32) = err.raw_os_error() {
442 trace!("{} socket recv failed for {}, ingore this error", self, err);
448 } else {
449 info!("{} socket recv failed for {}, break recv loop", self, err);
450 break;
451 }
452 }
453 }
454 });
455 }
456 }
457
458 fn has_tunnel(&self, key: &KeyMixHash) -> bool {
459 self.0.tunnels.lock().unwrap().has_tunnel(key)
460 }
461
462 fn on_proxied_datagram(&self, datagram: &[u8], from: &SocketAddr) {
463 let proxy_to = {
464 self.0.tunnels.lock().unwrap().on_proxied_datagram(datagram, from)
465 };
466
467 if let Some(proxy_to) = proxy_to {
468 let _ = self.0.socket.send_to(datagram, &proxy_to);
469 }
470 }
471
472 fn create_tunnel(&self, mix_key: AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
473 self.0.tunnels.lock().unwrap().create_tunnel(mix_key, device_pair)
474 }
475}
476
477pub struct ProxyTunnelManager {
478 interface: ProxyInterface
479}
480
481impl std::fmt::Display for ProxyTunnelManager {
482 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483 write!(f, "ProxyTunnelManager")
484 }
485}
486
487impl ProxyTunnelManager {
488 pub fn open(config: Config, listen: &[(SocketAddr, Option<SocketAddr>)]) -> BuckyResult<Self> {
489 let (local, outer) = listen[0];
491 let interface = ProxyInterface::open(config, local, outer)?;
492 Ok(Self {
493 interface
494 })
495 }
496
497 pub fn create_tunnel(&self, mix_key: &AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<SocketAddr> {
498 let _ = self.interface.create_tunnel(mix_key.clone(), device_pair)?;
499 Ok(self.interface.outer().clone())
500 }
501
502 pub fn tunnel_of(&self, key: &KeyMixHash) -> Option<SocketAddr> {
503 self.interface.has_tunnel(key);
504 Some(self.interface.outer().clone())
505 }
506}