cyfs_bdt/stream/
manager.rs

1use async_std::sync::{Arc, Weak};
2use std::{
3    collections::{BTreeMap}, 
4    sync::{RwLock, Mutex}
5};
6use lru_time_cache::LruCache;
7use cyfs_base::*;
8use crate::{
9    types::*, 
10    protocol::{*, v0::*},
11    interface::*,  
12    tunnel::{TunnelGuard, TunnelContainer, BuildTunnelParams}, 
13    stack::{Stack, WeakStack}
14};
15use super::{
16    container::*, 
17    listener::*
18};
19use log::*;
20
/// Upper bound, in bytes, on the `question` payload accepted by `StreamManager::connect` (25 KiB).
const QUESTION_MAX_LEN: usize = 25 * 1024;
22
23#[derive(PartialEq, Eq, PartialOrd, Ord, Clone)]
24pub struct RemoteSequence(DeviceId, TempSeq);
25
26impl std::fmt::Display for RemoteSequence {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        write!(f, "{{id:{}, seq:{:?}}}", self.0, self.1)
29    }
30}
31
32impl From<(DeviceId, TempSeq)> for RemoteSequence {
33    fn from(p: (DeviceId, TempSeq)) -> Self {
34        Self(p.0, p.1)
35    }
36}
37
38struct StreamEntries {
39    id_entries: BTreeMap<IncreaseId, StreamContainer>,
40    remote_entries: BTreeMap<RemoteSequence, StreamContainer>, 
41}
42
43impl StreamEntries {
44    fn stream_of_id(&self, id: &IncreaseId) -> Option<StreamContainer> {
45        self.id_entries.get(id).cloned()
46    }
47
48    fn stream_of_remote_sequence(&self, rs: &RemoteSequence) -> Option<StreamContainer> {
49        self.remote_entries.get(rs).cloned()
50    }
51
52    fn remove_stream(&mut self, stream: &StreamContainer) -> (
53        Option<IncreaseId>, 
54        Option<RemoteSequence>
55    ) {
56        let remote_seq = RemoteSequence::from((stream.remote().0.clone(), stream.sequence()));
57        (self.id_entries.remove(&stream.local_id()).map(|_| stream.local_id()), 
58            self.remote_entries.remove(&remote_seq).map(|_| remote_seq))
59    }
60}
61
62struct ReservingEntries {
63    id_entries: LruCache<IncreaseId, StreamContainer>,
64    remote_entries: LruCache<RemoteSequence, StreamContainer>, 
65}
66
67impl ReservingEntries {
68    fn stream_of_id(&mut self, id: &IncreaseId) -> Option<StreamContainer> {
69        self.id_entries.get(id).cloned()
70    }
71
72    fn stream_of_remote_sequence(&mut self, rs: &RemoteSequence) -> Option<StreamContainer> {
73        self.remote_entries.get(rs).cloned()
74    }
75}
76
77struct StreamManagerImpl {
78    stack: WeakStack, 
79    stream_entries: RwLock<StreamEntries>, 
80    reserving_entries: Mutex<ReservingEntries>, 
81    acceptor_entries: RwLock<BTreeMap<u16, StreamListener>>
82}
83
84
85#[derive(Clone)]
86pub struct StreamManager(Arc<StreamManagerImpl>);
87
88#[derive(Clone)]
89pub struct WeakStreamManager(Weak<StreamManagerImpl>);
90
91impl std::fmt::Display for StreamManager {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        write!(f, "StreamManager {{local:{}}}", Stack::from(&self.0.stack).local_device_id())
94    }
95}
96
97impl StreamManager {
98    pub fn new(stack: WeakStack) -> Self {
99        let strong_stack = Stack::from(&stack);
100        Self(Arc::new(StreamManagerImpl {
101            stack, 
102            stream_entries: RwLock::new(StreamEntries {
103                id_entries: BTreeMap::new(), 
104                remote_entries: BTreeMap::new(), 
105            }), 
106            reserving_entries: Mutex::new(ReservingEntries {
107                id_entries: LruCache::with_expiry_duration(strong_stack.config().stream.stream.package.msl), 
108                remote_entries: LruCache::with_expiry_duration(strong_stack.config().stream.stream.package.msl), 
109            }), 
110            acceptor_entries: RwLock::new(BTreeMap::new())
111        }))
112    }
113
114    fn to_weak(&self) -> WeakStreamManager {
115        WeakStreamManager(Arc::downgrade(&self.0))
116    }
117
118    // connect完成是返回stream
119    pub async fn connect(
120        &self, 
121        port: u16, 
122        question: Vec<u8>, 
123        build_params: BuildTunnelParams
124    ) -> Result<StreamGuard, BuckyError> {
125        if question.len() > QUESTION_MAX_LEN {
126            return Err(BuckyError::new(
127                BuckyErrorCode::Failed,
128                format!("question's length large than {}", QUESTION_MAX_LEN),
129            ));
130        }
131
132        info!("{} connect stream to {}:{} with params {}", self, build_params.remote_const.device_id(), port, build_params);
133        let manager_impl = &self.0;
134        let stack = Stack::from(&manager_impl.stack);
135        let local_id = stack.id_generator().generate();
136        let tunnel = stack.tunnel_manager().create_container(&build_params.remote_const)?;
137
138        let stream = StreamContainer::new(
139            manager_impl.stack.clone(), 
140            tunnel.clone(), 
141            port, 
142            local_id, 
143            tunnel.generate_sequence());
144        manager_impl.stream_entries.write().unwrap().id_entries.insert(local_id, stream.clone());
145        stream.connect(question, build_params).await.map_err(|err| {self.remove_stream(&stream, true); err})?;
146        Ok(StreamGuard::from(stream))
147    }
148
149    pub fn listen(&self, port: u16) -> Result<StreamListenerGuard, BuckyError> {
150        let stack = Stack::from(&self.0.stack);
151        let mut entries = self.0.acceptor_entries.write().unwrap();
152        match entries.get(&port) {
153            Some(_) => {
154                Err(BuckyError::new(BuckyErrorCode::AlreadyExists, "port is listening"))
155            },
156            None => {
157                let acceptor = StreamListener::new(self.to_weak(), port, stack.config().stream.listener.backlog);
158                entries.insert(port, acceptor.clone());
159                Ok(StreamListenerGuard::from(acceptor))
160            }
161        }.map(|v| {info!("{} listen on {}", self, port);v})
162            .map_err(|err| {error!("{} listen on {} failed for {}", self, port, err); err})
163    } 
164
165    fn stream_of_id(&self, id: &IncreaseId) -> Option<StreamContainer> {
166        self.0.stream_entries.read().unwrap().stream_of_id(id)
167            .or_else(|| self.0.reserving_entries.lock().unwrap().stream_of_id(id))
168    }
169
170    pub(crate) fn stream_of_remote_sequence(&self, rs: &RemoteSequence) -> Option<StreamContainer> {
171        self.0.stream_entries.read().unwrap().stream_of_remote_sequence(rs)
172            .or_else(|| self.0.reserving_entries.lock().unwrap().stream_of_remote_sequence(rs))
173    }
174
175    pub(crate) fn remove_stream(&self, stream: &StreamContainer, reserving: bool) {
176        info!("{} remove from stream manager", stream);
177        let (local_id, remote_seq) = self.0.stream_entries.write().unwrap().remove_stream(stream);
178        if reserving {
179            info!("{} reserved closed in stream manager ", stream);
180            let mut entries = self.0.reserving_entries.lock().unwrap();
181            if let Some(local_id) = local_id {
182                entries.id_entries.insert(local_id, stream.clone());
183            }
184            if let Some(remote_seq) = remote_seq {
185                entries.remote_entries.insert(remote_seq, stream.clone());
186            }            
187        }
188    }
189
190    pub(crate) fn remove_acceptor(&self, acceptor: &StreamListener) {
191        self.0.acceptor_entries.write().unwrap().remove(&acceptor.port());
192    }
193
194    fn try_accept(
195        &self, 
196        tunnel: TunnelGuard, 
197        port: u16, 
198        sequence: TempSeq, 
199        remote_id: IncreaseId, 
200        question: Vec<u8>) -> Option<StreamContainer> {
201        match self.0.acceptor_entries.read().unwrap().get(&port).map(|a| a.clone()) {
202            Some(acceptor) => {
203                let manager_impl = &self.0;
204                let local_id = Stack::from(&manager_impl.stack).id_generator().generate();
205                let stream = StreamContainer::new(
206                    manager_impl.stack.clone(), 
207                    tunnel.clone(), 
208                    port, 
209                    local_id, 
210                    sequence);
211                stream.accept(remote_id);
212                // 先加入到stream entries
213                if let Some(exists) = {
214                    let remote_seq = RemoteSequence(tunnel.remote().clone(), sequence);
215                    let mut stream_entries = manager_impl.stream_entries.write().unwrap();
216                    if let Some(exists) = stream_entries.remote_entries.get(&remote_seq) {
217                        Some(exists.clone())
218                    } else {
219                        stream_entries.remote_entries.insert(remote_seq, stream.clone());
220                        stream_entries.id_entries.insert(local_id, stream.clone());
221                        None
222                    }                    
223                } {
224                    let _ = stream.cancel_connecting_with(&BuckyError::new(BuckyErrorCode::AlreadyExists, "duplicate accepting stream"));
225                    Some(exists)
226                } else {
227                    // 因为可能会失败,用guard保证reset掉,从stream entries中移除
228                    let _ = acceptor.push_stream(PreAcceptedStream {
229                        stream: StreamGuard::from(stream.clone()),
230                        question
231                    });
232                    Some(stream)  
233                }
234            }, 
235            None => {
236                debug!("{} is not listening {}", self, port);
237                None
238            }
239        }
240    }
241
242    pub(crate) fn on_statistic(&self) -> String {
243        let stream_count = self.0.stream_entries.read().unwrap().id_entries.len();
244        format!("StreamCount: {}", stream_count)
245    }
246}
247
248impl From<&WeakStreamManager> for StreamManager {
249    fn from(w: &WeakStreamManager) -> Self {
250        Self(w.0.upgrade().unwrap())
251    }
252}
253
254impl From<&WeakStreamManager> for Stack {
255    fn from(w: &WeakStreamManager) -> Stack {
256        Stack::from(&StreamManager::from(w).0.stack)
257    }
258}
259
260impl OnPackage<SessionData, &TunnelContainer> for StreamManager {
261    fn on_package(&self, pkg: &SessionData, tunnel: &TunnelContainer) -> Result<OnPackageResult, BuckyError> {
262        let stack = Stack::from(&self.0.stack);
263        match {
264            if pkg.is_syn() {
265                debug!("{} on {} from {}", self, pkg, tunnel.remote());
266                let syn_info = pkg.syn_info.as_ref().unwrap();
267                let remote_seq = RemoteSequence(tunnel.remote().clone(), syn_info.sequence);
268                if let Some(stream) = self.stream_of_remote_sequence(&remote_seq) {
269                    Some(stream)
270                } else {
271                    let mut question = vec![0; pkg.payload.as_ref().len()];
272                    question.copy_from_slice(pkg.payload.as_ref());
273
274                    self.try_accept(
275                        stack.tunnel_manager().container_of(tunnel.remote()).unwrap(), 
276                        syn_info.to_vport,
277                        syn_info.sequence,  
278                        pkg.session_id, 
279                        question)
280                }
281            } else if pkg.is_syn_ack() {
282                debug!("{} on {} from {}", self, pkg, tunnel.remote());
283                let to_session_id = pkg.to_session_id.as_ref().unwrap();
284                self.stream_of_id(to_session_id)
285            } else {
286                self.stream_of_id(&pkg.session_id)
287            }
288        } {
289            Some(stream) => {
290                stream.on_package(pkg, None)
291            },
292            None => {
293                debug!("{} ingore {} for no valid stream", self, pkg);
294
295                if !pkg.is_flags_contain(SESSIONDATA_FLAG_RESET) {
296                    let mut rst_pkg = SessionData::new();
297                    rst_pkg.flags_add(SESSIONDATA_FLAG_RESET);
298                    rst_pkg.to_session_id = Some(pkg.session_id);
299                    rst_pkg.send_time = bucky_time_now();
300
301                    let _ = tunnel.send_package(DynamicPackage::from(rst_pkg), false);
302                }
303
304                Err(BuckyError::new(BuckyErrorCode::NotFound, "stream of id not found"))
305            }
306        }
307    }
308}
309
310
311impl OnPackage<TcpSynConnection, (&TunnelContainer, tcp::AcceptInterface)> for StreamManager {
312    fn on_package(&self, pkg: &TcpSynConnection, context: (&TunnelContainer, tcp::AcceptInterface)) -> Result<OnPackageResult, BuckyError> {
313        let (tunnel, interface) = context;
314        let remote_seq = RemoteSequence(tunnel.remote().clone(), pkg.sequence);
315        let stack = Stack::from(&self.0.stack);
316        match {
317            if let Some(stream) = self.stream_of_remote_sequence(&remote_seq) {
318                Some(stream)
319            } else {
320                let mut question = vec![0; pkg.payload.as_ref().len()];
321                question.copy_from_slice(pkg.payload.as_ref());
322                self.try_accept(
323                    stack.tunnel_manager().container_of(tunnel.remote()).unwrap(), 
324                    pkg.to_vport,
325                    pkg.sequence,  
326                    pkg.from_session_id, 
327                    question)
328            }
329        } {
330            Some(stream) => stream.on_package(pkg, interface), 
331            None => Err(BuckyError::new(BuckyErrorCode::NotFound, "stream of id not found"))
332        }
333    }
334}
335
336// tcp 反连的请求
337impl OnPackage<TcpSynConnection, &TunnelContainer> for StreamManager {
338    fn on_package(&self, pkg: &TcpSynConnection, tunnel: &TunnelContainer) -> Result<OnPackageResult, BuckyError> {
339        if pkg.reverse_endpoint.is_none() {
340            return Err(BuckyError::new(BuckyErrorCode::InvalidInput, "tcp syn connection should has reverse endpoints"));
341        }
342        let stack = Stack::from(&self.0.stack);
343        let remote_seq = RemoteSequence(tunnel.remote().clone(), pkg.sequence);
344        match {
345            if let Some(stream) = self.stream_of_remote_sequence(&remote_seq) {
346                Some(stream)
347            } else {
348                let mut question = vec![0; pkg.payload.as_ref().len()];
349                question.copy_from_slice(pkg.payload.as_ref());
350                if let Some(guard) = stack.tunnel_manager().container_of(tunnel.remote()) {
351                    self.try_accept(
352                        guard, 
353                        pkg.to_vport,
354                        pkg.sequence,  
355                        pkg.from_session_id, 
356                        question)
357                } else {
358                    error!("{} tunnel released, pkg={:?}, tunnel={}", self, pkg, tunnel);
359                    None
360                }
361            }
362        } {
363            Some(stream) => stream.on_package(pkg, None), 
364            None => Err(BuckyError::new(BuckyErrorCode::NotFound, "stream of id not found"))
365        }
366    }
367}
368
369impl OnPackage<TcpAckConnection, (&TunnelContainer, tcp::AcceptInterface)> for StreamManager {
370    fn on_package(&self, pkg: &TcpAckConnection, context: (&TunnelContainer, tcp::AcceptInterface)) -> Result<OnPackageResult, BuckyError> {
371        let (_tunnel, interface) = context;
372        match self.stream_of_id(&pkg.to_session_id) {
373            Some(stream) => stream.on_package(pkg, interface), 
374            None => Err(BuckyError::new(BuckyErrorCode::NotFound, "stream of id not found"))
375        }
376    }
377}
378
379