1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
use std::sync::Arc;
use std::time::Duration;
use std::{collections::HashMap, io, net::SocketAddr};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc::Receiver;
use tokio::{io::split, net::TcpStream, sync::mpsc::channel};
use tokio::{
    io::{AsyncRead, AsyncWrite},
    sync::mpsc::Sender,
};
use tokio_rustls::{client::TlsStream, TlsConnector};

use webparse::{BinaryMut, Buf};

use crate::{
    Helper, MappingConfig, ProtClose, ProtCreate, ProtFrame, ProxyResult, TransStream,
};

/// Center client.
/// Owns the connection to the center server; the serve loop reconnects automatically
/// after the connection is lost.
pub struct CenterClient {
    /// TLS client configuration; when `Some`, connections to the server are wrapped in TLS.
    tls_client: Option<Arc<rustls::ClientConfig>>,
    /// Domain used only for TLS server-name verification
    /// (falls back to "soft.wm-proxy.com" when `None`).
    domain: Option<String>,
    /// Address of the center server.
    server_addr: SocketAddr,
    /// Intranet mapping configuration, announced to the server at the start of each session.
    mappings: Vec<MappingConfig>,

    /// Connection stored by `connect`: `Some` here means a plain (non-TLS) connection.
    stream: Option<TcpStream>,
    /// Connection stored by `connect`: `Some` here means a TLS connection.
    tls_stream: Option<TlsStream<TcpStream>>,
    /// Next sock_map id to assign to a locally accepted stream; always odd.
    next_id: u32,

    /// Sends a `Create` frame together with the channel that should receive frames for that
    /// sock_map, so the serve loop can bind the two.
    sender_work: Sender<(ProtCreate, Sender<ProtFrame>)>,
    /// Receiving end of `sender_work`. It is moved into the worker task when serving starts,
    /// which is why serving cannot be started twice.
    receiver_work: Option<Receiver<(ProtCreate, Sender<ProtFrame>)>>,

    /// Sender for protocol frames that must be written to the server; a clone is handed to
    /// every local stream.
    sender: Sender<ProtFrame>,
    /// Receiving end of `sender`; the serve loop encodes these frames onto the server connection.
    receiver: Option<Receiver<ProtFrame>>,
}

impl CenterClient {
    pub fn new(
        server_addr: SocketAddr,
        tls_client: Option<Arc<rustls::ClientConfig>>,
        domain: Option<String>,
        mappings: Vec<MappingConfig>,
    ) -> Self {
        let (sender, receiver) = channel::<ProtFrame>(100);
        let (sender_work, receiver_work) = channel::<(ProtCreate, Sender<ProtFrame>)>(10);

        Self {
            tls_client,
            domain,
            server_addr,
            mappings,
            stream: None,
            tls_stream: None,
            next_id: 1,

            sender_work,
            receiver_work: Some(receiver_work),
            sender,
            receiver: Some(receiver),
        }
    }

    async fn inner_connect(
        tls_client: Option<Arc<rustls::ClientConfig>>,
        server_addr: SocketAddr,
        domain: Option<String>,
    ) -> ProxyResult<(Option<TcpStream>, Option<TlsStream<TcpStream>>)> {
        if tls_client.is_some() {
            println!("connect by tls");
            let connector = TlsConnector::from(tls_client.unwrap());
            let stream = TcpStream::connect(&server_addr).await?;
            // 这里的域名只为认证设置
            let domain =
                rustls::ServerName::try_from(&*domain.unwrap_or("soft.wm-proxy.com".to_string()))
                    .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;

            let outbound = connector.connect(domain, stream).await?;
            Ok((None, Some(outbound)))
        } else {
            let outbound = TcpStream::connect(server_addr).await?;
            Ok((Some(outbound), None))
        }
    }

    pub async fn connect(&mut self) -> ProxyResult<bool> {
        let (stream, tls_stream) = Self::inner_connect(
            self.tls_client.clone(),
            self.server_addr,
            self.domain.clone(),
        )
        .await?;
        self.stream = stream;
        self.tls_stream = tls_stream;
        Ok(self.stream.is_some() || self.tls_stream.is_some())
    }

    async fn inner_serve<T>(
        stream: T,
        sender: &mut Sender<ProtFrame>,
        receiver_work: &mut Receiver<(ProtCreate, Sender<ProtFrame>)>,
        receiver: &mut Receiver<ProtFrame>,
        mappings: &mut Vec<MappingConfig>,
    ) -> ProxyResult<()>
    where
        T: AsyncRead + AsyncWrite + Unpin,
    {
        let mut map = HashMap::<u32, Sender<ProtFrame>>::new();
        let mut read_buf = BinaryMut::new();
        let mut write_buf = BinaryMut::new();
        let (mut reader, mut writer) = split(stream);
        let mut vec = vec![0u8; 4096];
        let is_closed;
        println!("mappings = {:?}", mappings);
        if mappings.len() > 0 {
            println!("encode mapping = {:?}", mappings);
            ProtFrame::new_mapping(0, mappings.clone()).encode(&mut write_buf)?;
        }
        loop {
            let _ = tokio::select! {
                // 严格的顺序流
                biased;
                // 新的流建立,这里接收Create并进行绑定
                r = receiver_work.recv() => {
                    if let Some((create, sender)) = r {
                        map.insert(create.sock_map(), sender);
                        let _ = create.encode(&mut write_buf);
                    }
                }
                // 数据的接收,并将数据写入给远程端
                r = receiver.recv() => {
                    if let Some(p) = r {
                        let _ = p.encode(&mut write_buf);
                    }
                }
                // 数据的等待读取,一旦流可读则触发,读到0则关闭主动关闭所有连接
                r = reader.read(&mut vec) => {
                    match r {
                        Ok(0)=>{
                            is_closed=true;
                            break;
                        }
                        Ok(n) => {
                            read_buf.put_slice(&vec[..n]);
                        }
                        Err(_err) => {
                            is_closed = true;
                            break;
                        },
                    }
                }
                // 一旦有写数据,则尝试写入数据,写入成功后扣除相应的数据
                r = writer.write(write_buf.chunk()), if write_buf.has_remaining() => {
                    match r {
                        Ok(n) => {
                            write_buf.advance(n);
                            if !write_buf.has_remaining() {
                                write_buf.clear();
                            }
                        }
                        Err(e) => {
                            println!("center_client errrrr = {:?}", e);
                        },
                    }
                }
            };

            loop {
                // 将读出来的数据全部解析成ProtFrame并进行相应的处理,如果是0则是自身消息,其它进行转发
                match Helper::decode_frame(&mut read_buf)? {
                    Some(p) => {
                        match p {
                            ProtFrame::Create(p) => {
                                let domain = p.domain().clone().unwrap_or(String::new());
                                let mut local_addr = None;
                                for m in &*mappings {
                                    if m.domain == domain {
                                        local_addr = m.local_addr.clone();
                                    } else if domain.len() == 0 && m.is_tcp() {
                                        local_addr = m.local_addr.clone();
                                    }
                                }
                                if local_addr.is_none() {
                                    log::warn!("local addr is none, can't mapping");
                                    continue;
                                }
                                let (virtual_sender, virtual_receiver) = channel::<ProtFrame>(10);
                                map.insert(p.sock_map(), virtual_sender);
                                
                                let domain = local_addr.unwrap();
                                let sock_map = p.sock_map();
                                let sender = sender.clone();
                                println!("receiver sock_map {}, domain = {}", sock_map, domain);
                                // let (flag, username, password, udp_bind) = (option.flag, option.username.clone(), option.password.clone(), option.udp_bind.clone());
                                tokio::spawn(async move {
                                    let stream = TcpStream::connect(domain).await;
                                    println!("connect server {:?}", stream);
                                    if let Ok(tcp) = stream {
                                        let trans = TransStream::new(
                                            tcp,
                                            sock_map,
                                            sender,
                                            virtual_receiver,
                                        );
                                        let _ = trans.copy_wait().await;
                                        // let _ = copy_bidirectional(&mut tcp, &mut stream).await;
                                    }
                                });
                            }
                            ProtFrame::Close(_) | ProtFrame::Data(_) => {
                                if let Some(sender) = map.get(&p.sock_map()) {
                                    let _ = sender.try_send(p);
                                }
                            }
                            ProtFrame::Mapping(_) => {}
                        }
                    }
                    None => {
                        break;
                    }
                }
            }
            if !read_buf.has_remaining() {
                read_buf.clear();
            }
        }
        if is_closed {
            for v in map {
                let _ = v.1.try_send(ProtFrame::Close(ProtClose::new(v.0)));
            }
        }
        Ok(())
    }

    pub async fn serve(&mut self) -> ProxyResult<()> {
        let tls_client = self.tls_client.clone();
        let server = self.server_addr.clone();
        let domain = self.domain.clone();

        let stream = self.stream.take();
        let tls_stream = self.tls_stream.take();
        let mut client_sender = self.sender.clone();
        let mut client_receiver = self.receiver.take().unwrap();
        let mut receiver_work = self.receiver_work.take().unwrap();
        let mut mappings = self.mappings.clone();
        tokio::spawn(async move {
            let mut stream = stream;
            let mut tls_stream = tls_stream;
            loop {
                if stream.is_some() {
                    let _ = Self::inner_serve(
                        stream.take().unwrap(),
                        &mut client_sender,
                        &mut receiver_work,
                        &mut client_receiver,
                        &mut mappings,
                    )
                    .await;
                } else if tls_stream.is_some() {
                    let _ = Self::inner_serve(
                        tls_stream.take().unwrap(),
                        &mut client_sender,
                        &mut receiver_work,
                        &mut client_receiver,
                        &mut mappings,
                    )
                    .await;
                };
                match Self::inner_connect(tls_client.clone(), server.clone(), domain.clone()).await
                {
                    Ok((s, tls)) => {
                        stream = s;
                        tls_stream = tls;
                    }
                    Err(_err) => {
                        tokio::time::sleep(Duration::from_millis(1000)).await;
                    }
                }
            }
        });

        Ok(())
    }

    fn calc_next_id(&mut self) -> u32 {
        let id = self.next_id;
        self.next_id += 2;
        id
    }

    pub async fn deal_new_stream<T>(&mut self, inbound: T) -> ProxyResult<()>
    where
        T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
    {
        let id = self.calc_next_id();
        let sender = self.sender.clone();
        let (stream_sender, stream_receiver) = channel::<ProtFrame>(10);
        let _ = self
            .sender_work
            .send((ProtCreate::new(id, None), stream_sender))
            .await;
        tokio::spawn(async move {
            let trans = TransStream::new(inbound, id, sender, stream_receiver);
            let _ = trans.copy_wait().await;
        });
        Ok(())
    }
}