// cyfs_bdt/pn/service/proxy.rs

use std::{
    cell::RefCell,
    collections::{HashMap, LinkedList},
    net::{SocketAddr, UdpSocket},
    thread,
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use async_std::{future, sync::Arc, task};
use cyfs_base::*;
use cyfs_debug::Mutex;

use crate::{interface::udp::MTU_LARGE, types::*};
20
/// Configuration for the proxy tunnel service.
#[derive(Clone)]
pub struct Config {
    /// Keep-alive interval carried by the configuration.
    pub keepalive: Duration,
}
25
26#[derive(Clone, Debug)]
27pub struct ProxyDeviceStub {
28    pub id: DeviceId, 
29    pub timestamp: Timestamp, 
30}
31
32#[derive(Clone, Debug)]
33pub struct ProxyEndpointStub {
34    endpoint: SocketAddr, 
35    last_active: Timestamp
36}
37
38#[derive(Clone)]
39struct ProxyTunnel {
40    device_pair: (ProxyDeviceStub, ProxyDeviceStub), 
41    endpoint_pair: (Option<ProxyEndpointStub>, Option<ProxyEndpointStub>), 
42    last_active: Timestamp
43}
44
45impl std::fmt::Display for ProxyTunnel {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        write!(f, "ProxyTunnel")
48    }
49}
50
51impl ProxyTunnel {
52    fn new(device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> Self {
53        Self {
54            device_pair, 
55            endpoint_pair: (None, None), 
56            last_active: bucky_time_now()
57        }
58    }
59
60    fn recyclable(&self, now: Timestamp, keepalive: Duration) -> bool {
61        if now > self.last_active && Duration::from_micros(now - self.last_active) > keepalive {
62            true
63        } else {
64            false
65        }
66    }
67
68    fn on_device_pair(&mut self, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
69        self.last_active = bucky_time_now();
70        let (left, right) = device_pair;
71        let (fl, fr) = {
72            if left.id.eq(&self.device_pair.0.id) && right.id.eq(&self.device_pair.1.id) {
73                Ok((&mut self.device_pair.0, &mut self.device_pair.1))
74            } else if right.id.eq(&self.device_pair.0.id) && left.id.eq(&self.device_pair.1.id) {
75                Ok((&mut self.device_pair.1, &mut self.device_pair.0))
76            } else {
77                trace!("{} ignore device pair ({:?}, {:?}) for not match {:?}", self, left, right, self.device_pair);
78                Err(BuckyError::new(BuckyErrorCode::NotMatch, "device pair not match"))
79            }
80        }?;
81        if left.timestamp > fl.timestamp {
82            fl.timestamp = left.timestamp;
83            self.endpoint_pair = (None, None);
84            trace!("proxy tunnel update endpoint pair to (None, None)");
85        }
86        if right.timestamp > right.timestamp {
87            fr.timestamp = right.timestamp;
88            self.endpoint_pair = (None, None);
89            trace!("proxy tunnel update endpoint pair to (None, None)");
90        }
91        Ok(())
92    }
93
94    fn on_proxied_datagram(&mut self, mix_hash: &KeyMixHash, from: &SocketAddr) -> Option<SocketAddr> {
95        self.last_active = bucky_time_now();
96        if self.endpoint_pair.0.is_none() {
97            self.endpoint_pair.0 = Some(ProxyEndpointStub {
98                endpoint: *from, 
99                last_active: bucky_time_now()
100            });
101            trace!("{} mix_hash:{} update endpoint pair to {:?}", self, mix_hash, self.endpoint_pair);
102            None
103        } else if self.endpoint_pair.1.is_none() {
104            let left = self.endpoint_pair.0.as_mut().unwrap(); 
105            if left.endpoint.eq(from) {
106                left.last_active = bucky_time_now();
107            } else {
108                self.endpoint_pair.1 = Some(ProxyEndpointStub {
109                    endpoint: *from, 
110                    last_active: bucky_time_now()
111                });
112            }
113            trace!("{} mix_hash:{} update endpoint pair to {:?}", self, mix_hash, self.endpoint_pair);
114            None
115        } else {
116            let left = self.endpoint_pair.0.as_mut().unwrap(); 
117            let right = self.endpoint_pair.1.as_mut().unwrap(); 
118
119            if left.endpoint.eq(from) {
120                left.last_active = bucky_time_now();
121                Some(right.endpoint)
122            } else if right.endpoint.eq(from) {
123                right.last_active = bucky_time_now();
124                Some(left.endpoint)
125            } else {
126                *left = right.clone();
127                right.endpoint = *from;
128                right.last_active = bucky_time_now();
129                trace!("ProxyTunnel mix_hash:{} mix_hash update endpoint pair to ({:?}, {:?})", mix_hash, left, right);
130                Some(left.endpoint)
131            }
132        }
133    }
134}
135
136#[derive(Clone)]
137struct TunnelMixHash {
138    tunnel: ProxyTunnel,
139    mix_key: AesKey,
140    mixhash: Vec<MixHashInfo>,
141}
142
143impl TunnelMixHash {
144    pub fn recyclable(&self, now: Timestamp, keepalive: Duration) -> bool {
145        self.tunnel.recyclable(now, keepalive)
146    }
147
148    pub fn new(mix_key: AesKey, tunnel: ProxyTunnel) -> Self {
149        TunnelMixHash {
150            tunnel,
151            mix_key,
152            mixhash: Vec::new(),
153        }
154    }
155
156    pub fn rehash(&mut self, min: u64, max: u64) -> (Vec<KeyMixHash>, Vec<KeyMixHash>) {
157        let mut timeout_n = 0;
158        let mut next_ts = min;
159        for h in self.mixhash.as_slice() {
160            let t = h.minute_timestamp;
161            if t < min {
162                timeout_n += 1;
163            } else if t > next_ts {
164                next_ts = t + 1;
165            }
166        }
167
168        let removed: Vec<MixHashInfo> = self.mixhash.splice(..timeout_n, vec![].iter().cloned()).collect();
169        let removed = removed.iter().map(|h| h.hash.clone()).collect();
170
171        let mut added = vec![];
172        if next_ts < max {
173            for t in next_ts..(max+1) {
174                let h = MixHashInfo::new(self.mix_key.mix_hash(Some(t)), t);
175                added.push(h.hash.clone());
176                self.mixhash.push(h);
177            }
178        }
179
180        (added, removed)
181    }
182}
183
184#[derive(Clone)]
185struct MixHashInfo {
186    hash: KeyMixHash,
187    minute_timestamp: u64,
188}
189
190impl MixHashInfo {
191    pub fn new(hash: KeyMixHash, minute_timestamp: u64) -> Self {
192        MixHashInfo {
193            hash,
194            minute_timestamp
195        }
196    }
197}
198
199struct TunnelsManager {
200	tunnel_mixhash_map: HashMap<KeyMixHash, TunnelMixHash>,
201    tunnel_mixkey_list: LinkedList<TunnelMixHash>,
202    keepalive: Duration,
203    mixhash_live_minutes: u64,
204}
205
206impl TunnelsManager {
207    pub fn default() -> Self {
208        let def_keepalive = 60;
209        let def_mixhash_live_minute = 31;
210
211        Self {
212            tunnel_mixhash_map: HashMap::new(),
213            tunnel_mixkey_list: LinkedList::new(),
214            keepalive: Duration::from_secs(def_keepalive),
215            mixhash_live_minutes: def_mixhash_live_minute,
216        }
217    }
218}
219
220impl TunnelsManager {
221    fn minute_timestamp_range(&self) -> (u64, u64) {
222        let minute_timestamp = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs() / 60;
223        let min = minute_timestamp - (self.mixhash_live_minutes - 1) / 2;
224        let max = minute_timestamp + (self.mixhash_live_minutes - 1) / 2;
225
226        (min, max)
227    }
228
229    fn mixkey_update(&mut self,  mix_key: AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
230        let tunnel = self.tunnel_mixhash_map.get(&mix_key.mix_hash(None)).unwrap();
231        let mut tunnel = tunnel.tunnel.clone();
232        let (left, right) = device_pair;
233
234        let (fl, fr) = {
235            if left.id.eq(&tunnel.device_pair.0.id) && right.id.eq(&tunnel.device_pair.1.id) {
236                Ok((&mut tunnel.device_pair.0, &mut tunnel.device_pair.1))
237            } else if right.id.eq(&tunnel.device_pair.0.id) && left.id.eq(&tunnel.device_pair.1.id) {
238                Ok((&mut tunnel.device_pair.1, &mut tunnel.device_pair.0))
239            } else {
240                trace!("{} ignore device pair ({:?}, {:?}) for not match {:?}", tunnel, left, right, tunnel.device_pair);
241                Err(BuckyError::new(BuckyErrorCode::NotMatch, "device pair not match"))
242            }
243        }?;
244        if left.timestamp > fl.timestamp {
245            fl.timestamp = left.timestamp;
246            tunnel.endpoint_pair = (None, None);
247            trace!("proxy tunnel update endpoint pair to (None, None)");
248        }
249        if right.timestamp > right.timestamp {
250            fr.timestamp = right.timestamp;
251            tunnel.endpoint_pair = (None, None);
252            trace!("proxy tunnel update endpoint pair to (None, None)");
253        }
254
255        Ok(())
256    }
257
258    fn mixkey_add(&mut self,  mix_key: AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
259        let mut tunnel = TunnelMixHash::new(mix_key.clone(), ProxyTunnel::new(device_pair));
260
261        let (min, max) = self.minute_timestamp_range();
262        let (added, _) = tunnel.rehash(min, max);
263
264        self.tunnel_mixkey_list.push_front(tunnel.clone());
265
266        for h in added.as_slice() {
267            self.tunnel_mixhash_map.insert(h.clone(), tunnel.clone());
268        }
269        self.tunnel_mixhash_map.insert(mix_key.mix_hash(None), tunnel.clone());
270
271        Ok(())
272    }
273
274    pub fn create_tunnel(&mut self, mix_key: AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
275        let mix_hash = mix_key.mix_hash(None);
276
277        if self.has_tunnel(&mix_hash) {
278            self.mixkey_update(mix_key, device_pair)
279        } else {
280            self.mixkey_add(mix_key, device_pair)
281        }
282    }
283
284    pub fn on_proxied_datagram(&mut self, datagram: &[u8], from: &SocketAddr) -> Option<SocketAddr> {
285        match KeyMixHash::raw_decode(datagram) {
286            Ok((mut mix_hash, _)) => {
287                mix_hash.as_mut()[0] &= 0x7f;
288                if let Some(tunnel) = self.tunnel_mixhash_map.get_mut(&mix_hash) {
289                    trace!("{} recv datagram of mix_hash: {}", tunnel.tunnel, mix_hash);
290                    tunnel.tunnel.on_proxied_datagram(&mix_hash, from)
291                } else {
292                    trace!("ignore datagram of mix_hash: {}", mix_hash);
293                    None
294                }
295            }, 
296            _ => {
297                trace!("ignore datagram for invalid key foramt");
298                None
299            }
300        }
301    }
302
303    pub fn has_tunnel(&self, mix_hash: &KeyMixHash) -> bool {
304        if let Some(_) = self.tunnel_mixhash_map.get(mix_hash) {
305            true
306        } else {
307            false
308        }
309    }
310
311    pub fn rehash(&mut self) {
312        let (min, max) = self.minute_timestamp_range();
313
314        trace!("rehash min={} max={}", min, max);
315
316        for (_, tunnel) in self.tunnel_mixkey_list.iter_mut().enumerate() {
317            let (added, removed) = tunnel.rehash(min, max);
318            for h in added.as_slice() {
319                self.tunnel_mixhash_map.insert(h.clone(), tunnel.clone());
320            }
321            for h in removed.as_slice() {
322                self.tunnel_mixhash_map.remove(h);
323            }
324        }
325    }
326
327    pub fn recycle(&mut self) {
328        let now = bucky_time_now();
329
330        trace!("recycle now={}", now);
331
332        let mut removed = Vec::new();
333        for (i, tunnel) in self.tunnel_mixkey_list.iter_mut().enumerate() {
334            if tunnel.recyclable(now, self.keepalive) {
335                removed.push(i-removed.len());
336            }
337        }
338
339        for i in 0..removed.len() {
340            let mut last_part = self.tunnel_mixkey_list.split_off(*removed.get(i).unwrap());
341            let tunnel = last_part.pop_front().unwrap();
342            self.tunnel_mixkey_list.append(&mut last_part);
343
344            self.tunnel_mixhash_map.remove(&tunnel.mix_key.mix_hash(None));
345            for i in 0..tunnel.mixhash.len() {
346                let mixhash = tunnel.mixhash.get(i).unwrap();
347                self.tunnel_mixhash_map.remove(&mixhash.hash);
348            }
349        }
350    }
351}
352
353struct ProxyInterfaceImpl {
354    config: Config, 
355    socket: UdpSocket, 
356    outer: SocketAddr, 
357    tunnels: Mutex<TunnelsManager>,
358}
359
360#[derive(Clone)]
361struct ProxyInterface(Arc<ProxyInterfaceImpl>);
362
363impl std::fmt::Display for ProxyInterface {
364    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
365        write!(f, "ProxyInterface:{{endpoint:{:?}}}", self.local())
366    }
367}
368
369thread_local! {
370    static UDP_RECV_BUFFER: RefCell<[u8; MTU_LARGE]> = RefCell::new([0u8; MTU_LARGE]);
371}
372
373impl ProxyInterface {
374    fn open(config: Config, local: SocketAddr, outer: Option<SocketAddr>) -> BuckyResult<Self> {
375        let socket = UdpSocket::bind(local)
376            .map_err(|e| {
377                error!("ProxyInterface bind socket on {:?} failed for {}", local, e);
378                e
379            })?;
380        let interface = Self(Arc::new(ProxyInterfaceImpl {
381            config, 
382            socket, 
383            outer: outer.unwrap_or(local), 
384            tunnels: Mutex::new(TunnelsManager::default()),
385        }));
386
387        let num_cpus = 4;
388        let pool_size = num_cpus + 2;
389        for _ in 0..pool_size {
390            let interface = interface.clone();
391            thread::spawn(move || {
392                interface.proxy_loop();
393            });
394        }
395
396        {
397            let interface = interface.clone();
398            task::spawn(async move {
399                interface.timer().await;
400            });
401        }
402        
403        Ok(interface)
404    }
405
406    fn local(&self) -> SocketAddr {
407        self.0.socket.local_addr().unwrap()
408    }
409
410    fn outer(&self) -> &SocketAddr {
411        &self.0.outer
412    }
413
414    async fn timer(&self) {
415        let tick_sec = 60;
416        loop {
417            {
418                let mut tunnels = self.0.tunnels.lock().unwrap();
419                tunnels.recycle();
420                tunnels.rehash();
421            }
422
423            let _ = future::timeout(Duration::from_secs(tick_sec), future::pending::<()>()).await;
424        }
425    }
426
427    fn proxy_loop(&self) {
428        info!("{} started", self);
429        loop {
430            UDP_RECV_BUFFER.with(|thread_recv_buf| {
431                let recv_buf = &mut thread_recv_buf.borrow_mut()[..];
432                loop {
433                    let rr = self.0.socket.recv_from(recv_buf);
434                    if rr.is_ok() {
435                        let (len, from) = rr.unwrap();
436                        let recv = &recv_buf[..len];
437                        trace!("{} recv datagram len {} from {:?}", self, len, from);
438                        self.on_proxied_datagram(recv, &from);
439                    } else {
440                        let err = rr.err().unwrap();
441                        if let Some(10054i32) = err.raw_os_error() {
442                            // In Windows, if host A use UDP socket and call sendto() to send something to host B,
443                            // but B doesn't bind any port so that B doesn't receive the message,
444                            // and then host A call recvfrom() to receive some message,
445                            // recvfrom() will failed, and WSAGetLastError() will return 10054.
446                            // It's a bug of Windows.
447                            trace!("{} socket recv failed for {}, ingore this error", self, err);
448                        } else {
449                            info!("{} socket recv failed for {}, break recv loop", self, err);
450                            break;
451                        }
452                    }
453                }
454            });
455        }
456    }
457
458    fn has_tunnel(&self, key: &KeyMixHash) -> bool {
459        self.0.tunnels.lock().unwrap().has_tunnel(key)
460    }
461
462    fn on_proxied_datagram(&self, datagram: &[u8], from: &SocketAddr) {
463        let proxy_to = {
464            self.0.tunnels.lock().unwrap().on_proxied_datagram(datagram, from)
465        };
466
467        if let Some(proxy_to) = proxy_to {
468            let _ = self.0.socket.send_to(datagram, &proxy_to);
469        }
470    }
471
472    fn create_tunnel(&self, mix_key: AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<()> {
473        self.0.tunnels.lock().unwrap().create_tunnel(mix_key, device_pair)
474    }
475}
476
477pub struct ProxyTunnelManager {
478    interface: ProxyInterface
479}
480
481impl std::fmt::Display for ProxyTunnelManager {
482    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
483        write!(f, "ProxyTunnelManager")
484    }
485}
486
487impl ProxyTunnelManager {
488    pub fn open(config: Config, listen: &[(SocketAddr, Option<SocketAddr>)]) -> BuckyResult<Self> {
489        //TODO: 支持多interface扩展
490        let (local, outer) = listen[0];
491        let interface = ProxyInterface::open(config, local, outer)?;
492        Ok(Self {
493            interface
494        })
495    }
496
497    pub fn create_tunnel(&self, mix_key: &AesKey, device_pair: (ProxyDeviceStub, ProxyDeviceStub)) -> BuckyResult<SocketAddr> {
498        let _ = self.interface.create_tunnel(mix_key.clone(), device_pair)?;
499        Ok(self.interface.outer().clone())
500    }
501
502    pub fn tunnel_of(&self, key: &KeyMixHash) -> Option<SocketAddr> {
503        self.interface.has_tunnel(key);
504        Some(self.interface.outer().clone())
505    }
506}