cyfs_bdt/utils/
stream_pool.rs

1use log::*;
2use std::{
3    collections::{LinkedList, BTreeMap}, 
4    time::Duration, 
5    task::{Context, Poll, Waker}, 
6    pin::Pin, 
7    sync::{atomic::{AtomicBool, Ordering}}, 
8    ops::Deref, 
9    net::Shutdown, 
10};
11use cyfs_debug::Mutex;
12use async_std::{
13    future, 
14    sync::{Arc}, 
15    task, 
16    channel::{bounded, Sender, Receiver}, 
17    io::prelude::{Read, Write}
18};
19use futures::StreamExt;
20
21use cyfs_base::*;
22use crate::{
23    types::*, 
24    tunnel::{BuildTunnelParams, TunnelState}, 
25    stream::{StreamGuard, StreamContainer, StreamState, StreamListenerGuard},  
26    stack::{Stack}
27};
28
29
30#[derive(Clone)]
31pub struct PooledStream(Arc<PooledStreamImpl>);
32
33impl std::fmt::Display for PooledStream {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        write!(f, "PooledStream {{stream:{}}}", self.0.stream)
36    }
37}
38
39
40impl Deref for PooledStream {
41    type Target = StreamContainer;
42    fn deref(&self) -> &StreamContainer {
43        &self.0.stream
44    }
45}
46
47enum PooledStreamType {
48    Active(StreamPoolConnector), 
49    Passive(StreamPoolListener), 
50}
51
52struct PooledStreamImpl {
53    shutdown: AtomicBool, 
54    stream_type: PooledStreamType, 
55    stream: StreamGuard, 
56}
57
58impl PooledStream {
59    pub fn shutdown(&self, which: Shutdown) -> std::io::Result<()> {
60        self.0.shutdown.store(true, Ordering::SeqCst);
61        self.0.stream.shutdown(which)
62    }
63}
64
65impl Drop for PooledStreamImpl {
66    fn drop(&mut self) {
67        let shutdown = self.shutdown.load(Ordering::SeqCst);
68        match &self.stream_type {
69            PooledStreamType::Passive(owner) => owner.recycle(&self.stream, shutdown), 
70            PooledStreamType::Active(owner) => owner.recycle(&self.stream, shutdown)
71        };
72    }
73}
74
75
76impl Read for PooledStream {
77    fn poll_read(
78        self: Pin<&mut Self>,
79        cx: &mut Context<'_>,
80        buf: &mut [u8],
81    ) -> Poll<std::io::Result<usize>> {
82        let mut stream = self.0.stream.clone();
83        Pin::new(&mut stream).poll_read(cx, buf)
84    }
85
86    fn poll_read_vectored(
87        self: Pin<&mut Self>,
88        cx: &mut Context<'_>,
89        bufs: &mut [std::io::IoSliceMut<'_>],
90    ) -> Poll<std::io::Result<usize>> {
91        let mut stream = self.0.stream.clone();
92        Pin::new(&mut stream).poll_read_vectored(cx, bufs)
93    }
94}
95
96
97impl Write for PooledStream {
98    fn poll_write(
99        self: Pin<&mut Self>,
100        cx: &mut Context<'_>,
101        buf: &[u8],
102    ) -> Poll<std::io::Result<usize>> {
103        let mut stream = self.0.stream.clone();
104        Pin::new(&mut stream).poll_write(cx, buf)
105    }
106
107    fn poll_write_vectored(
108        self: Pin<&mut Self>,
109        cx: &mut Context<'_>,
110        bufs: &[std::io::IoSlice<'_>],
111    ) -> Poll<std::io::Result<usize>> {
112        let mut stream = self.0.stream.clone();
113        Pin::new(&mut stream).poll_write_vectored(cx, bufs)
114    }
115
116    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
117        let mut stream = self.0.stream.clone();
118        Pin::new(&mut stream).poll_flush(cx)
119    }
120
121    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
122        let mut stream = self.0.stream.clone();
123        Pin::new(&mut stream).poll_close(cx)
124    }
125}
126
127
128#[derive(Clone)]
129struct StreamPoolConnector(Arc<StreamPoolConnectorImpl>);
130
131impl std::fmt::Display for StreamPoolConnector {
132    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
133        write!(f, "StreamPoolConnector {{local:{} remote:{} port:{}}}", self.0.stack.local_device_id(), self.0.remote, self.0.port)
134    }
135}
136
137struct StreamPoolConnectorImpl {
138    stack: Stack, 
139    remote: DeviceId, 
140    port: u16, 
141    capacity: usize, 
142    timeout: Duration, 
143    stream_list: Mutex<LinkedList<(StreamGuard, Timestamp)>>,
144}
145
146impl StreamPoolConnector {
147    pub fn new(
148        stack: Stack, 
149        remote: &DeviceId, 
150        port: u16, 
151        capacity: usize, 
152        timeout: Duration
153    ) -> Self {
154        Self(Arc::new(StreamPoolConnectorImpl {
155            stack, 
156            remote: remote.clone(), 
157            port, 
158            capacity, 
159            timeout, 
160            stream_list: Mutex::new(LinkedList::new()), 
161        }))
162    }
163
164    pub fn stream_count(&self) -> usize {
165        self.0.stream_list.lock().unwrap().len()
166    }
167
168    pub fn remote(&self) -> (&DeviceId, u16) {
169        (&self.0.remote, self.0.port)
170    }
171
172    fn wrap_stream(&self, stream: StreamGuard) -> PooledStream {
173        PooledStream(Arc::new(PooledStreamImpl {
174            shutdown: AtomicBool::new(false), 
175            stream_type: PooledStreamType::Active(self.clone()), 
176            stream, 
177        }))
178    }
179
180    pub async fn connect(&self) -> BuckyResult<PooledStream> {
181        let exists = {
182            let mut stream_list = self.0.stream_list.lock().unwrap();
183            stream_list.pop_front()
184        };
185        if let Some((stream, _)) = exists {
186            debug!("{} connect return reused stream {}", self, stream);
187            Ok(self.wrap_stream(stream))
188        } else {
189            debug!("{} will connect new stream", self);
190            let stack = &self.0.stack;
191            if let Some(remote_device) = stack.device_cache().get(&self.0.remote).await {
192                let build_params = BuildTunnelParams {
193                    remote_const: remote_device.desc().clone(),
194                    remote_sn: None,
195                    remote_desc: Some(remote_device)
196                };
197                let stream = stack.stream_manager().connect(
198                    self.0.port, 
199                    vec![], 
200                    build_params).await
201                    .map_err(|e| {
202                        warn!("{} connect new stream failed for {}", self, e);
203                        e
204                    })?;
205                info!("{} return newly connected stream {}", self, stream);
206                Ok(self.wrap_stream(stream))
207            } else {
208                let e = BuckyError::new(BuckyErrorCode::NotFound, "device desc not cached");
209                warn!("{} connect failed for {}", self, e);
210                Err(e)
211            }
212        }
213    }
214
215
216    fn recycle(&self, stream: &StreamGuard, shutdown: bool) {
217        debug!("{} will recycle stream {}", self, stream);
218        if let StreamState::Establish(_) = stream.state() {
219            if !shutdown {
220                let mut stream_list = self.0.stream_list.lock().unwrap();
221                if stream_list.len() < self.0.capacity {
222                    stream_list.push_back((stream.clone(), bucky_time_now()));
223                } else {
224                    warn!("{} drop stream {} for full", self, stream);
225                }   
226            } else {
227                self.check_tunnel();
228                info!("{} drop stream {} for shutdown", self, stream);
229            }
230        } else {
231            self.check_tunnel();
232            warn!("{} drop stream {} for not establish", self, stream);
233        }
234    }
235
236    fn drop_stream(&self, remote_timestamp: Option<Timestamp>) {
237        let remove = if let Some(remote_timestamp) = remote_timestamp {
238            let mut remain = LinkedList::new();
239            let mut remove = LinkedList::new();
240            let mut streams = self.0.stream_list.lock().unwrap();
241            while let Some((stream, last_used)) = streams.pop_back() {
242                match stream.state() {
243                    StreamState::Establish(remote) => {
244                        if remote <= remote_timestamp {
245                            remove.push_back((stream, last_used));
246                        } else {
247                            remain.push_back((stream, last_used));
248                        }
249                    },
250                    _ => {
251                        remove.push_back((stream, last_used));
252                    }
253                }
254            }
255            *streams = remain;
256            remove
257        } else {
258            let mut remove = LinkedList::new();
259            let mut streams = self.0.stream_list.lock().unwrap();
260            remove.append(&mut *streams);
261            remove
262        };
263        
264        for (stream, _) in remove {
265            let _ = stream.shutdown(Shutdown::Both);
266        }
267    }
268
269    fn check_tunnel(&self) {
270        if let Some(tunnel) = self.0.stack.tunnel_manager().container_of(self.remote().0) {
271            let state = tunnel.state();
272            match state {
273                TunnelState::Active(remote) => {
274                    self.drop_stream(Some(remote));
275                }, 
276                TunnelState::Dead => {
277                    self.drop_stream(None);
278                }, 
279                _ => {}
280            }
281        } 
282    }
283
284    fn on_time_escape(&self, now: Timestamp) {
285        let remove = {
286            let mut remain = LinkedList::new();
287            let mut remove = LinkedList::new();
288            let mut streams = self.0.stream_list.lock().unwrap();
289            while let Some((stream, last_used)) = streams.pop_front() {
290                match stream.state() {
291                    StreamState::Establish(_) => {
292                        if now > last_used && Duration::from_micros(now - last_used) > self.0.timeout {
293                            remove.push_back(stream);
294                        } else {
295                            remain.push_back((stream, last_used));
296                        }
297                    },
298                    _ => {
299                        remove.push_back(stream);
300                    }
301                }
302            }
303            *streams = remain;
304            remove
305        };
306        
307        for stream in remove {
308            info!("{} shutdown stream {} for pool timeout", self, stream);
309            let _ = stream.shutdown(Shutdown::Both);
310        }
311    }
312}
313
314#[derive(Clone)]
315struct StreamPoolListener(Arc<StreamPoolListenerImpl>);
316
317struct StreamPoolListenerImpl {
318    origin_listener: StreamListenerGuard, 
319    sender: Sender<StreamGuard>, 
320    recver: Receiver<StreamGuard>
321}
322
323impl std::fmt::Display for StreamPoolListener {
324    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
325        write!(f, "StreamPoolListener {{listener:{}}}", self.0.origin_listener)
326    }
327}
328
329impl StreamPoolListener {
330    pub fn new(origin_listener: StreamListenerGuard, backlog: usize) -> Self {
331        let (sender, recver) = bounded::<StreamGuard>(backlog);
332        
333        let listener = Self(Arc::new(StreamPoolListenerImpl {
334            origin_listener, 
335            sender, 
336            recver
337        }));
338
339        {
340            let listener = listener.clone();
341            task::spawn(async move {
342                listener.listen().await;
343            });
344        }
345
346        listener
347    }
348
349    async fn listen(&self) {
350        info!("{} listen", self);
351        let mut incoming = self.0.origin_listener.incoming();
352        let listener = self.clone();
353        loop {
354            match incoming.next().await {
355                Some(ret) => {
356                    match ret {
357                        Ok(pre_stream) => {
358                            info!("{} accept stream {}", listener, pre_stream.stream);
359                            let listener = self.clone();
360                            task::spawn(async move {
361                                match pre_stream.stream.confirm(b"".as_ref()).await {
362                                    Ok(_) => {
363                                        debug!("{} confirm stream {}", listener, pre_stream.stream);
364                                        let _ = listener.0.sender.try_send(pre_stream.stream);
365                                    },
366                                    Err(e) => {
367                                        error!("{} confirm stream {} failed for {}", listener, pre_stream.stream, e);
368                                    }
369                                }
370                            });
371                        }, 
372                        Err(e) => {
373                            error!("{} stop listen for {}", listener, e);
374                            break;
375                        }
376                    }
377                }, 
378                None => {
379                    // do nothing
380                }
381            }
382        }
383    } 
384
385    fn wrap_stream(&self, stream: StreamGuard) -> PooledStream {
386        PooledStream(Arc::new(PooledStreamImpl {
387            shutdown: AtomicBool::new(false), 
388            stream_type: PooledStreamType::Passive(self.clone()), 
389            stream, 
390        }))
391    }
392
393    pub async fn accept(&self) -> BuckyResult<PooledStream> {
394        match self.0.recver.recv().await {
395            Ok(stream) => {
396                debug!("{} accepted stream {}", self, stream);
397                Ok(self.wrap_stream(stream))
398            },
399            Err(_) => unreachable!()
400        }
401    }
402
403    pub fn incoming(&self) -> PooledStreamIncoming {
404        PooledStreamIncoming {
405            owner: self.clone(), 
406            state: Arc::new(Mutex::new(IncommingState {
407                exists: None, 
408                waker: None
409            }))
410        }
411    }
412
413    pub fn recycle(&self, stream: &StreamGuard, shutdown: bool) {
414        if !shutdown {
415            debug!("{} recyle stream {}", self, stream);
416            let pool = self.clone();
417            let stream = stream.clone();
418            task::spawn(async move {
419                match stream.readable().await {
420                    Ok(len) => {
421                        if len != 0 {
422                            debug!("{} return resued stream {}", pool, stream);
423                            let _ = pool.0.sender.try_send(stream);
424                        } else {
425                            // do nothing
426                            debug!("{} drop stream {} for remote closed", pool, stream);
427                        }
428                    }, 
429                    Err(e) => {
430                        // do nothing
431                        warn!("{} drop stream {} for {}", pool, stream, e);
432                    }
433                }
434            });
435        }
436    }
437}
438
439
440struct IncommingState {
441    exists: Option<std::io::Result<PooledStream>>,
442    waker: Option<Waker>,
443}
444
445pub struct PooledStreamIncoming {
446    owner: StreamPoolListener, 
447    state: Arc<Mutex<IncommingState>>
448}
449
450impl async_std::stream::Stream for PooledStreamIncoming {
451    type Item = std::io::Result<PooledStream>;
452
453    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<<Self as async_std::stream::Stream>::Item>> {
454        let exists = {
455            let mut state = self.state.lock().unwrap();
456            match &state.exists {
457                Some(exists) => {
458                    match exists {
459                        Ok(stream) => {
460                            let exists = Ok(stream.clone());
461                            state.exists = None;
462                            Some(exists)
463                        }, 
464                        Err(_) => {
465                            Some(Err(std::io::Error::new(std::io::ErrorKind::Other, BuckyError::new(BuckyErrorCode::ErrorState, "listener stopped"))))
466                        }
467                    }
468                }, 
469                None => {
470                    assert!(state.waker.is_none());
471                    state.waker = Some(cx.waker().clone());
472                    None
473                }
474            }
475        };
476
477        if exists.is_none() {
478            let owner = self.owner.clone();
479            let state = self.state.clone();
480            task::spawn(async move {
481                let next = owner.accept().await
482                    .map_err(|_| std::io::Error::new(std::io::ErrorKind::Other, BuckyError::new(BuckyErrorCode::ErrorState, "listener stopped")));
483                let waker = {
484                    let mut state = state.lock().unwrap();
485                    assert!(state.waker.is_some());
486                    assert!(state.exists.is_none());
487                    state.exists = Some(next);
488                    let mut waker = None;
489                    std::mem::swap(&mut state.waker, &mut waker);
490                    waker.unwrap()
491                };
492                waker.wake();
493            });
494            Poll::Pending
495        } else {
496            Poll::Ready(exists)
497        }
498    }
499}
500
501
502struct StreamPoolImpl {
503    stack: Stack, 
504    port: u16, 
505    config: StreamPoolConfig, 
506    connectors: Mutex<BTreeMap<DeviceId, StreamPoolConnector>>, 
507    listener: StreamPoolListener
508}
509
510#[derive(Clone)]
511pub struct StreamPool(Arc<StreamPoolImpl>);
512
513impl std::fmt::Display for StreamPool {
514    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
515        write!(f, "StreamPool {{local:{} port:{}}}", self.0.stack.local_device_id(), self.port())
516    }
517}
518
519
/// Tuning parameters for a `StreamPool`.
#[derive(Debug)]
pub struct StreamPoolConfig {
    /// Maximum number of idle streams kept per remote device.
    pub capacity: usize,
    /// Capacity of the queue of ready streams on the listening side.
    pub backlog: usize,
    /// Interval used by the pool's idle-stream expiry timer.
    pub atomic_interval: Duration,
    /// Idle time after which a pooled stream is shut down.
    pub timeout: Duration,
}

impl Default for StreamPoolConfig {
    /// Defaults: 10 idle streams per remote, backlog of 100, 5 s timer interval,
    /// 30 s idle timeout.
    fn default() -> Self {
        let check_interval = Duration::from_secs(5);
        let idle_timeout = Duration::from_secs(30);
        StreamPoolConfig {
            capacity: 10,
            backlog: 100,
            atomic_interval: check_interval,
            timeout: idle_timeout,
        }
    }
}
538
539impl StreamPool {
540    pub fn new(
541        stack: Stack, 
542        port: u16, 
543        config: StreamPoolConfig
544    ) -> BuckyResult<Self> {
545        info!("create stream pool on port {} config {:?}", port, config);
546        let origin_listener = stack.stream_manager().listen(port)?;
547        let listener = StreamPoolListener::new(origin_listener, config.backlog);
548
549        let pool = Self(Arc::new(StreamPoolImpl {
550            stack: stack.clone(), 
551            port, 
552            config, 
553            connectors: Mutex::new(BTreeMap::new()), 
554            listener
555        }));
556
557        {
558            let pool = pool.clone();
559            task::spawn(async move {
560                let _ = future::timeout(pool.config().atomic_interval, future::pending::<()>()).await;
561                pool.on_time_escape(bucky_time_now());
562            });
563        }
564
565        Ok(pool)
566    }
567
568    fn on_time_escape(&self, now: Timestamp) {
569        let connectors: Vec<StreamPoolConnector> = self.0.connectors.lock().unwrap().values().cloned().collect();
570
571        for connector in connectors {
572            connector.on_time_escape(now);
573        }
574    }
575
576    pub async fn connect(&self, remote: &DeviceId) -> BuckyResult<PooledStream> {
577        debug!("{} will connect to {}", self, remote);
578        let connector = {
579            let mut connectors = self.0.connectors.lock().unwrap();
580            if let Some(connector) = connectors.get(remote) {
581                connector.clone()
582            } else {
583                let connector = StreamPoolConnector::new(self.0.stack.clone(), remote, self.port(), self.config().capacity, self.config().timeout);
584                connectors.insert(remote.clone(), connector.clone());
585                connector
586            }
587        };
588        connector.connect().await
589    }
590
591    pub fn incoming(&self) -> PooledStreamIncoming {
592        self.0.listener.incoming()
593    }
594
595    pub fn port(&self) -> u16 {
596        self.0.port
597    }
598
599    pub fn config(&self) -> &StreamPoolConfig {
600        &self.0.config
601    }
602}