use std::io;
use crate::{
ws::{WsHandshake, WsOption, WsTrait},
ProtError, ProtResult, RecvRequest, RecvResponse, Server,
};
use algorithm::buf::{BinaryMut, Bt};
use async_trait::async_trait;
use tokio::{
io::{split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt},
net::{TcpStream, ToSocketAddrs},
sync::mpsc::{channel, Receiver, Sender},
};
use webparse::{ws::OwnedMessage, Response};
/// Bridges a websocket connection carried over `io` to a plain TCP
/// connection opened to `addr`.
///
/// The I/O trait bounds live on the `impl` block that needs them rather than
/// on the type definition, so merely naming or storing the type imposes no
/// constraints.
pub struct WsToStream<T, A> {
    /// Address the outgoing TCP connection is opened to.
    addr: A,
    /// When set, only websocket upgrade requests whose host matches are accepted.
    domain: Option<String>,
    /// Transport on which the websocket server is run.
    io: T,
}
/// Websocket callback state used by `WsToStream` to relay websocket traffic
/// to a TCP stream.
struct Operate {
    /// Receiver handed to the websocket server when the connection opens;
    /// becomes `None` once it has been taken.
    receiver: Option<Receiver<OwnedMessage>>,
    /// Carries payloads of incoming websocket messages to the TCP relay loop.
    stream_sender: Sender<Vec<u8>>,
    /// Optional host that upgrade requests must match.
    domain: Option<String>,
}
#[async_trait]
impl WsTrait for Operate {
    /// Validates an upgrade request and builds the handshake response.
    ///
    /// When a domain is configured and the request's host differs from it, a
    /// `400` response with body `host not match` is returned instead of a
    /// handshake.
    #[inline]
    async fn on_request(&mut self, req: &RecvRequest) -> ProtResult<RecvResponse> {
        // `is_some()` is checked first so the host is only compared when a
        // domain restriction is actually configured.
        if self.domain.is_some() && req.get_host() != self.domain {
            return Ok(Response::builder()
                .status(400)
                .body("host not match")?
                .into_type());
        }
        WsHandshake::build_request(req)
    }

    /// Builds the connection options, passing the stored message receiver to
    /// the server through `WsOption::set_receiver`.
    ///
    /// The receiver is moved out on the first call. Any later call returns an
    /// option without a receiver.
    async fn on_open(&mut self, _shake: WsHandshake) -> ProtResult<Option<WsOption>> {
        let mut option = WsOption::new();
        if let Some(receiver) = self.receiver.take() {
            option.set_receiver(receiver);
        }
        Ok(Some(option))
    }

    /// Forwards the payload of text and binary messages to the TCP relay as raw
    /// bytes. Any other message variant is ignored.
    ///
    /// # Errors
    /// Returns `ProtError::Extension("close")` when the relay's receiving end
    /// has been dropped.
    async fn on_message(&mut self, msg: OwnedMessage) -> ProtResult<()> {
        let payload = match msg {
            OwnedMessage::Text(text) => text.into_bytes(),
            OwnedMessage::Binary(bytes) => bytes,
            _ => return Ok(()),
        };
        self.stream_sender
            .send(payload)
            .await
            .map_err(|_| ProtError::Extension("close"))?;
        Ok(())
    }

    /// No periodic work is needed for this bridge.
    async fn on_interval(&mut self, _option: &mut Option<WsOption>) -> ProtResult<()> {
        Ok(())
    }
}
impl<T: AsyncRead + AsyncWrite + Unpin + Send + 'static, A: ToSocketAddrs> WsToStream<T, A> {
    /// Creates a bridge that will serve websocket traffic on `io` and relay it
    /// to a TCP connection opened to `addr`.
    ///
    /// No I/O happens here. The TCP connection is only opened by
    /// [`WsToStream::copy_bidirectional`]. This currently never returns `Err`.
    pub fn new(io: T, addr: A) -> ProtResult<Self> {
        Ok(Self {
            addr,
            io,
            domain: None,
        })
    }

    /// Restricts websocket upgrades to requests whose host equals `domain`.
    /// Requests with any other host receive a `400 host not match` response.
    pub fn set_domain(&mut self, domain: String) {
        self.domain = Some(domain);
    }

    /// Connects to the target address and relays data in both directions
    /// until one side closes.
    ///
    /// The websocket server runs on a spawned task.
    /// - Text and binary websocket messages are written to the TCP stream.
    /// - Bytes read from the TCP stream are sent back as binary websocket
    ///   messages.
    ///
    /// # Errors
    /// Returns an error if the TCP connection cannot be established, or any
    /// error returned by [`WsToStream::bind`].
    pub async fn copy_bidirectional(self) -> ProtResult<()> {
        // Destructure up front so the spawned task captures plain locals rather
        // than fields of a partially moved `self`.
        let Self { addr, domain, io } = self;
        let (ws_sender, ws_receiver) = channel(10);
        let (stream_sender, stream_receiver) = channel::<Vec<u8>>(10);
        let stream = TcpStream::connect(addr).await?;
        tokio::spawn(async move {
            let mut server = Server::new(io, None);
            server.set_callback_ws(Box::new(Operate {
                stream_sender,
                receiver: Some(ws_receiver),
                domain,
            }));
            let e = server.incoming().await;
            println!("close server ==== e = {:?}", e);
        });
        Self::bind(stream, ws_sender, stream_receiver).await
    }

    /// Relays bytes between `io` and the websocket channels until one side
    /// closes.
    ///
    /// - Bytes read from `io` are forwarded through `ws_sender` as
    ///   [`OwnedMessage::Binary`] messages.
    /// - Payloads received on `stream_receiver` are written to `io`.
    ///
    /// Returns `Ok(())` when `io` reaches EOF. Any bytes still buffered for the
    /// websocket side are forwarded first, best-effort.
    ///
    /// # Errors
    /// - Returns an error when `stream_receiver` is closed. Bytes still pending
    ///   for `io` are flushed first.
    /// - Returns an error when the receiver behind `ws_sender` is dropped.
    /// - Returns an error when `io` stops accepting writes.
    /// - Returns any I/O error from `io`.
    pub async fn bind<S>(
        io: S,
        ws_sender: Sender<OwnedMessage>,
        mut stream_receiver: Receiver<Vec<u8>>,
    ) -> ProtResult<()>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        let mut buf = vec![0; 20480];
        let (mut reader, mut writer) = split(io);
        // `read`: bytes read from `io` and not yet forwarded to the websocket.
        // `write`: bytes received from the websocket and not yet written to `io`.
        let (mut read, mut write) = (BinaryMut::new(), BinaryMut::new());
        loop {
            tokio::select! {
                n = reader.read(&mut buf) => {
                    let n = n?;
                    if n == 0 {
                        // EOF: without this, bytes read in earlier iterations but
                        // not yet forwarded would be silently discarded.
                        if read.has_remaining() {
                            let msg = OwnedMessage::Binary(read.chunk().to_vec());
                            read.clear();
                            // Best effort: if the websocket side is already gone
                            // there is nobody left to deliver to.
                            let _ = ws_sender.send(msg).await;
                        }
                        return Ok(());
                    }
                    read.put_slice(&buf[..n]);
                },
                r = writer.write(write.chunk()), if write.has_remaining() => {
                    let n = r?;
                    if n == 0 {
                        // A zero-length write of a non-empty buffer means the
                        // stream can no longer accept data; retrying would spin.
                        return Err(io::Error::new(
                            io::ErrorKind::WriteZero,
                            "stream refused further writes",
                        )
                        .into());
                    }
                    write.advance(n);
                    if !write.has_remaining() {
                        write.clear();
                    }
                }
                r = stream_receiver.recv() => {
                    match r {
                        Some(v) => {
                            write.put_slice(&v);
                        }
                        None => {
                            // The websocket side is gone. Deliver what it already
                            // sent before reporting the closure.
                            if write.has_remaining() {
                                writer.write_all(write.chunk()).await?;
                                write.clear();
                            }
                            writer.flush().await?;
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "websocket sender closed",
                            )
                            .into());
                        }
                    }
                }
                p = ws_sender.reserve(), if read.has_remaining() => {
                    match p {
                        Err(_) => {
                            return Err(io::Error::new(
                                io::ErrorKind::InvalidInput,
                                "websocket receiver closed",
                            )
                            .into());
                        }
                        Ok(permit) => {
                            let msg = OwnedMessage::Binary(read.chunk().to_vec());
                            read.clear();
                            permit.send(msg);
                        }
                    }
                }
            }
        }
    }
}