use std::sync::Arc;
use std::time::Duration;
use std::{collections::HashMap, io, net::SocketAddr};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc::Receiver;
use tokio::{io::split, net::TcpStream, sync::mpsc::channel};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::mpsc::Sender,
};
use tokio_rustls::{client::TlsStream, TlsConnector};
use webparse::{BinaryMut, Buf};
use crate::{
Helper, MappingConfig, ProtClose, ProtCreate, ProtFrame, ProxyResult, TransStream,
};
/// Client that connects to a center server and multiplexes proxied streams over that one
/// connection.
///
/// Holds the connection settings, an optionally pre-established connection, and two frame
/// queues whose receiving ends are handed to `serve`: one for registering locally accepted
/// streams, one for frames those streams send towards the server.
pub struct CenterClient {
    /// Address of the center server to connect (and reconnect) to.
    server_addr: SocketAddr,
    /// TLS client configuration; when `Some`, connections are made over TLS.
    tls_client: Option<Arc<rustls::ClientConfig>>,
    /// TLS server name; `"soft.wm-proxy.com"` is used when this is `None`.
    domain: Option<String>,
    /// Mappings announced to the server at the start of each connection and used to pick the
    /// local address for server-initiated streams.
    mappings: Vec<MappingConfig>,

    /// Plain TCP connection opened by `connect`, consumed by `serve`.
    stream: Option<TcpStream>,
    /// TLS connection opened by `connect`, consumed by `serve`.
    tls_stream: Option<TlsStream<TcpStream>>,

    /// Id handed to the next locally initiated stream (starts at 1, advances by 2).
    next_id: u32,

    /// Queue of (create request, per-stream frame sender) for locally accepted streams.
    sender_work: Sender<(ProtCreate, Sender<ProtFrame>)>,
    /// Receiving end of `sender_work`; taken by `serve`.
    receiver_work: Option<Receiver<(ProtCreate, Sender<ProtFrame>)>>,
    /// Frames produced by local relay tasks, bound for the server.
    sender: Sender<ProtFrame>,
    /// Receiving end of `sender`; taken by `serve`.
    receiver: Option<Receiver<ProtFrame>>,
}
impl CenterClient {
    /// Creates a client for the center server at `server_addr`.
    ///
    /// No connection is made here; call [`CenterClient::connect`] and/or
    /// [`CenterClient::serve`]. When `tls_client` is `Some`, connections are wrapped in TLS and
    /// `domain` (default `"soft.wm-proxy.com"`) is used as the TLS server name. `mappings` are
    /// announced to the server at the start of every connection.
    pub fn new(
        server_addr: SocketAddr,
        tls_client: Option<Arc<rustls::ClientConfig>>,
        domain: Option<String>,
        mappings: Vec<MappingConfig>,
    ) -> Self {
        let (sender, receiver) = channel::<ProtFrame>(100);
        let (sender_work, receiver_work) = channel::<(ProtCreate, Sender<ProtFrame>)>(10);
        Self {
            tls_client,
            domain,
            server_addr,
            mappings,
            stream: None,
            tls_stream: None,
            // Locally initiated stream ids are odd: start at 1 and step by 2.
            next_id: 1,
            sender_work,
            receiver_work: Some(receiver_work),
            sender,
            receiver: Some(receiver),
        }
    }

    /// Opens a connection to `server_addr`.
    ///
    /// Returns `(Some(tcp), None)` for a plain connection, or `(None, Some(tls))` when
    /// `tls_client` is configured.
    ///
    /// # Errors
    /// Fails if the TCP connect fails, the server name is not a valid DNS name, or the TLS
    /// handshake fails.
    async fn inner_connect(
        tls_client: Option<Arc<rustls::ClientConfig>>,
        server_addr: SocketAddr,
        domain: Option<String>,
    ) -> ProxyResult<(Option<TcpStream>, Option<TlsStream<TcpStream>>)> {
        match tls_client {
            Some(config) => {
                log::debug!("connect by tls");
                let connector = TlsConnector::from(config);
                let stream = TcpStream::connect(server_addr).await?;
                let server_name = rustls::ServerName::try_from(
                    domain.as_deref().unwrap_or("soft.wm-proxy.com"),
                )
                .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid dnsname"))?;
                let outbound = connector.connect(server_name, stream).await?;
                Ok((None, Some(outbound)))
            }
            None => {
                let outbound = TcpStream::connect(server_addr).await?;
                Ok((Some(outbound), None))
            }
        }
    }

    /// Connects to the center server (over TLS when configured) and keeps the connection for
    /// [`CenterClient::serve`].
    ///
    /// # Errors
    /// Propagates connection and TLS handshake failures.
    pub async fn connect(&mut self) -> ProxyResult<bool> {
        let (stream, tls_stream) = Self::inner_connect(
            self.tls_client.clone(),
            self.server_addr,
            self.domain.clone(),
        )
        .await?;
        self.stream = stream;
        self.tls_stream = tls_stream;
        Ok(self.stream.is_some() || self.tls_stream.is_some())
    }

    /// Runs one session over an established connection to the center server.
    ///
    /// Announces `mappings` first, then multiplexes until the connection ends:
    /// - create requests from `receiver_work` (streams opened by `deal_new_stream`) are
    ///   registered and forwarded to the server;
    /// - frames from `receiver` (produced by local relay tasks) are forwarded to the server;
    /// - frames read from the server are routed by `dispatch_frame`.
    ///
    /// When the session ends, for any reason, every stream still registered is sent a `Close`.
    ///
    /// # Errors
    /// Fails if the mapping frame cannot be encoded, or if bytes from the server fail to decode
    /// as a frame. EOF, read errors and write errors end the session with `Ok(())`.
    async fn inner_serve<T>(
        stream: T,
        sender: &mut Sender<ProtFrame>,
        receiver_work: &mut Receiver<(ProtCreate, Sender<ProtFrame>)>,
        receiver: &mut Receiver<ProtFrame>,
        mappings: &mut Vec<MappingConfig>,
    ) -> ProxyResult<()>
    where
        T: AsyncRead + AsyncWrite + Unpin,
    {
        let mut map = HashMap::<u32, Sender<ProtFrame>>::new();
        let mut read_buf = BinaryMut::new();
        let mut write_buf = BinaryMut::new();
        let (mut reader, mut writer) = split(stream);
        let mut read_chunk = vec![0u8; 4096];
        // Cleared once `receiver_work` reports it is closed; see the `None` arm below.
        let mut work_open = true;

        log::debug!("mappings = {:?}", mappings);
        if !mappings.is_empty() {
            ProtFrame::new_mapping(0, mappings.clone()).encode(&mut write_buf)?;
        }

        loop {
            tokio::select! {
                biased;
                r = receiver_work.recv(), if work_open => {
                    match r {
                        Some((create, stream_sender)) => {
                            map.insert(create.sock_map(), stream_sender);
                            let _ = create.encode(&mut write_buf);
                        }
                        // Every `sender_work` handle is gone. `recv` would now return `None`
                        // on every poll and, being first in a biased select, would starve the
                        // other branches in a busy loop, so stop polling it.
                        None => work_open = false,
                    }
                }
                // `sender` is borrowed for the whole session, so this channel cannot close here.
                r = receiver.recv() => {
                    if let Some(frame) = r {
                        let _ = frame.encode(&mut write_buf);
                    }
                }
                r = reader.read(&mut read_chunk) => {
                    match r {
                        Ok(0) => break,
                        Ok(n) => {
                            read_buf.put_slice(&read_chunk[..n]);
                        }
                        Err(err) => {
                            log::debug!("center_client read error = {:?}", err);
                            break;
                        }
                    }
                }
                r = writer.write(write_buf.chunk()), if write_buf.has_remaining() => {
                    match r {
                        Ok(n) => {
                            write_buf.advance(n);
                            if !write_buf.has_remaining() {
                                write_buf.clear();
                            }
                        }
                        // Retrying the same write would fail again and spin; the connection
                        // is unusable, so end the session and let `serve` reconnect.
                        Err(err) => {
                            log::warn!("center_client write error = {:?}", err);
                            break;
                        }
                    }
                }
            }

            // Route every frame `decode_frame` can produce from the buffered bytes.
            loop {
                let decoded = Helper::decode_frame(&mut read_buf).map_err(|err| {
                    // The byte stream can no longer be framed: release the attached streams
                    // before propagating the error.
                    Self::close_streams(&mut map);
                    err
                })?;
                match decoded {
                    Some(frame) => {
                        Self::dispatch_frame(frame, &mut map, &*sender, mappings.as_slice())
                    }
                    None => break,
                }
            }
            if !read_buf.has_remaining() {
                read_buf.clear();
            }
        }

        Self::close_streams(&mut map);
        Ok(())
    }

    /// Sends a `Close` frame to every stream still registered in `map` and empties it.
    fn close_streams(map: &mut HashMap<u32, Sender<ProtFrame>>) {
        for (sock_map, stream_sender) in map.drain() {
            let _ = stream_sender.try_send(ProtFrame::Close(ProtClose::new(sock_map)));
        }
    }

    /// Routes one frame received from the center server to the stream it belongs to.
    fn dispatch_frame(
        frame: ProtFrame,
        map: &mut HashMap<u32, Sender<ProtFrame>>,
        sender: &Sender<ProtFrame>,
        mappings: &[MappingConfig],
    ) {
        match frame {
            ProtFrame::Create(create) => Self::open_local_stream(create, map, sender, mappings),
            ProtFrame::Close(_) => {
                // Forward the close and forget the stream; otherwise `map` would keep one
                // sender for every stream ever opened during this session.
                if let Some(stream_sender) = map.remove(&frame.sock_map()) {
                    let _ = stream_sender.try_send(frame);
                }
            }
            ProtFrame::Data(_) => {
                if let Some(stream_sender) = map.get(&frame.sock_map()) {
                    // NOTE(review): `try_send` drops the frame when the per-stream queue
                    // (capacity 10) is full, which loses data for that stream. Awaiting
                    // instead would stall every stream on this connection, so this is left
                    // as-is pending a flow-control design.
                    let _ = stream_sender.try_send(frame);
                }
            }
            ProtFrame::Mapping(_) => {}
        }
    }

    /// Handles a server `Create`: picks the local address configured for the requested domain
    /// and spawns a task that connects to it and relays frames for the stream.
    ///
    /// If no address is configured, or the local connect fails, a `Close` for the stream is
    /// sent back so the server can release its side.
    fn open_local_stream(
        create: ProtCreate,
        map: &mut HashMap<u32, Sender<ProtFrame>>,
        sender: &Sender<ProtFrame>,
        mappings: &[MappingConfig],
    ) {
        let domain = create.domain().clone().unwrap_or_default();
        let sock_map = create.sock_map();
        // The last matching mapping wins; an empty domain also matches any TCP mapping.
        let local_addr = mappings
            .iter()
            .rev()
            .find(|m| m.domain == domain || (domain.is_empty() && m.is_tcp()))
            .and_then(|m| m.local_addr.clone());
        let local_addr = match local_addr {
            Some(addr) => addr,
            None => {
                log::warn!("local addr is none, can't mapping");
                let _ = sender.try_send(ProtFrame::Close(ProtClose::new(sock_map)));
                return;
            }
        };

        let (virtual_sender, virtual_receiver) = channel::<ProtFrame>(10);
        map.insert(sock_map, virtual_sender);
        let sender = sender.clone();
        log::debug!("receiver sock_map {}, domain = {}", sock_map, local_addr);
        tokio::spawn(async move {
            match TcpStream::connect(local_addr).await {
                Ok(tcp) => {
                    let trans = TransStream::new(tcp, sock_map, sender, virtual_receiver);
                    let _ = trans.copy_wait().await;
                }
                Err(err) => {
                    log::warn!("connect local server for sock_map {} failed: {:?}", sock_map, err);
                    let _ = sender
                        .send(ProtFrame::Close(ProtClose::new(sock_map)))
                        .await;
                }
            }
        });
    }

    /// Spawns the background task that keeps the tunnel to the center server running.
    ///
    /// The task first uses the connection opened by [`CenterClient::connect`], if any. Whenever
    /// a session ends it reconnects, waiting one second after each failed attempt. This method
    /// returns as soon as the task is spawned.
    ///
    /// # Errors
    /// Fails if `serve` has already been called on this client, because the frame queues it
    /// takes over are gone.
    pub async fn serve(&mut self) -> ProxyResult<()> {
        let already_serving =
            || io::Error::new(io::ErrorKind::Other, "center client is already serving");
        let mut client_receiver = self.receiver.take().ok_or_else(already_serving)?;
        let mut receiver_work = self.receiver_work.take().ok_or_else(already_serving)?;
        let mut client_sender = self.sender.clone();
        let mut mappings = self.mappings.clone();
        let tls_client = self.tls_client.clone();
        let domain = self.domain.clone();
        let server_addr = self.server_addr;
        let mut stream = self.stream.take();
        let mut tls_stream = self.tls_stream.take();

        tokio::spawn(async move {
            loop {
                // On the first pass there is no connection if `connect` was never called.
                if let Some(tcp) = stream.take() {
                    let _ = Self::inner_serve(
                        tcp,
                        &mut client_sender,
                        &mut receiver_work,
                        &mut client_receiver,
                        &mut mappings,
                    )
                    .await;
                } else if let Some(tls) = tls_stream.take() {
                    let _ = Self::inner_serve(
                        tls,
                        &mut client_sender,
                        &mut receiver_work,
                        &mut client_receiver,
                        &mut mappings,
                    )
                    .await;
                }

                match Self::inner_connect(tls_client.clone(), server_addr, domain.clone()).await {
                    Ok((tcp, tls)) => {
                        stream = tcp;
                        tls_stream = tls;
                    }
                    Err(_) => {
                        log::warn!("center_client connect to {} failed, retrying", server_addr);
                        tokio::time::sleep(Duration::from_millis(1000)).await;
                    }
                }
            }
        });
        Ok(())
    }

    /// Returns the id for the next locally initiated stream.
    ///
    /// Ids start at 1 and advance by 2. Wrapping keeps them odd, so they never become 0 (the
    /// id `inner_serve` passes to `ProtFrame::new_mapping`), and avoids an overflow panic
    /// after `u32::MAX`.
    fn calc_next_id(&mut self) -> u32 {
        let id = self.next_id;
        self.next_id = self.next_id.wrapping_add(2);
        id
    }

    /// Tunnels a locally accepted connection through the center server.
    ///
    /// Allocates a stream id, queues a `Create` for it together with the channel that will
    /// receive the server's frames, then spawns a `TransStream` relaying between `inbound` and
    /// the server.
    ///
    /// # Errors
    /// Fails if the work queue's receiver has been dropped. In that case no relay is spawned,
    /// because the stream could never be registered.
    pub async fn deal_new_stream<T>(&mut self, inbound: T) -> ProxyResult<()>
    where
        T: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
    {
        let id = self.calc_next_id();
        let (stream_sender, stream_receiver) = channel::<ProtFrame>(10);
        // Queue the Create before spawning the relay. The session loop polls this queue
        // before the shared frame queue (biased select), so the server sees the Create before
        // any data for the stream.
        self.sender_work
            .send((ProtCreate::new(id, None), stream_sender))
            .await
            .map_err(|_| {
                io::Error::new(io::ErrorKind::BrokenPipe, "center client work queue is closed")
            })?;
        let sender = self.sender.clone();
        tokio::spawn(async move {
            let trans = TransStream::new(inbound, id, sender, stream_receiver);
            let _ = trans.copy_wait().await;
        });
        Ok(())
    }
}