#![warn(rust_2018_idioms)]
#![feature(async_await)]
use futures::{SinkExt, Stream};
use std::{env, error::Error, net::SocketAddr};
use tokio::{
codec::{FramedRead, FramedWrite},
io,
sync::{mpsc, oneshot},
};
/// Program entry point.
///
/// Runs [`run`] on a spawned task and waits for it over a oneshot channel.
/// The task's outcome is sent back through the channel, so a failure in
/// `run` becomes `main`'s error instead of a panic inside the spawned task.
#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let (tx, rx) = oneshot::channel();
    tokio::spawn(async move {
        // `Box<dyn Error>` is not `Send`, so carry the error across as text.
        let outcome = run().await.map_err(|e| e.to_string());
        // Ignored on purpose: this only fails if the receiver is already gone.
        let _ = tx.send(outcome);
    });
    // Outer `?`: the task ended without reporting (e.g. it panicked).
    // Inner result: the error `run` returned, re-boxed for `main`.
    rx.await?.map_err(Into::into)
}
/// Parses the command line and runs the client.
///
/// Usage: `<addr> [--udp]`. TCP is used unless `--udp` is given. Only the
/// first occurrence of `--udp` is removed from the argument list. The first
/// remaining argument must parse as a `SocketAddr`.
///
/// # Errors
/// Returns an error if no address is given, if the address does not parse,
/// or if the selected transport fails.
async fn run() -> Result<(), Box<dyn Error>> {
    let mut args: Vec<String> = env::args().skip(1).collect();

    // Strip the flag so it cannot be mistaken for the address below.
    let udp_flag = args.iter().position(|arg| arg == "--udp");
    if let Some(idx) = udp_flag {
        args.remove(idx);
    }
    let use_tcp = udp_flag.is_none();

    let addr = args
        .first()
        .ok_or("this program requires at least one argument")?
        .parse::<SocketAddr>()?;

    let input = stdin();
    let output = FramedWrite::new(io::stdout(), codec::Bytes);

    if use_tcp {
        tcp::connect(&addr, input, output).await?;
    } else {
        udp::connect(&addr, input, output).await?;
    }
    Ok(())
}
/// Returns a stream of chunks read from the process's standard input.
///
/// A spawned task drives a `FramedRead` over `io::stdin()` (framed with
/// `codec::Bytes`) and forwards every item, including read errors, into an
/// unbounded channel. The receiving half of that channel is returned.
///
/// NOTE(review): the channel is unbounded, so if stdin produces data faster
/// than the consumer drains it, memory use grows without limit.
fn stdin() -> impl Stream<Item = Result<Vec<u8>, io::Error>> + Unpin {
    let mut stdin = FramedRead::new(io::stdin(), codec::Bytes);
    let (mut tx, rx) = mpsc::unbounded_channel();
    tokio::spawn(async move {
        // Items are forwarded whole, so the only failure source is the channel
        // sink itself, which fails once the receiver has been dropped (the
        // transport is done). Stop quietly instead of panicking the task.
        let _ = tx.send_all(&mut stdin).await;
    });
    rx
}
/// TCP transport: pipes stdin to a TCP connection and the connection to stdout.
mod tcp {
    use super::codec;
    use futures::{future, Sink, SinkExt, Stream, StreamExt};
    use std::{error::Error, io, net::SocketAddr};
    use tokio::{
        codec::{FramedRead, FramedWrite},
        net::TcpStream,
    };

    /// Connects to `addr` over TCP and copies data in both directions.
    ///
    /// * `stdin` is forwarded into the socket.
    /// * Chunks read from the socket are sent into `stdout`.
    ///
    /// Both directions run concurrently, and this returns once both have
    /// completed. Socket read errors are reported on stderr and filtered out
    /// of the stream rather than propagated.
    ///
    /// # Errors
    /// Fails if the connection cannot be established, or if either direction
    /// fails. When both fail, the stdin → socket error is the one returned.
    pub async fn connect(
        addr: &SocketAddr,
        stdin: impl Stream<Item = Result<Vec<u8>, io::Error>> + Unpin,
        mut stdout: impl Sink<Vec<u8>, Error = io::Error> + Unpin,
    ) -> Result<(), Box<dyn Error>> {
        let (r, w) = TcpStream::connect(addr).await?.split();
        let sink = FramedWrite::new(w, codec::Bytes);
        let mut stream = FramedRead::new(r, codec::Bytes).filter_map(|i| match i {
            Ok(i) => future::ready(Some(i)),
            Err(e) => {
                // stdout carries the socket payload, so diagnostics go to
                // stderr; otherwise they would be mixed into the received data.
                eprintln!("failed to read from socket; error={}", e);
                future::ready(None)
            }
        });
        match future::join(stdin.forward(sink), stdout.send_all(&mut stream)).await {
            (Err(e), _) | (_, Err(e)) => Err(e.into()),
            _ => Ok(()),
        }
    }
}
/// UDP transport: sends each stdin chunk as one datagram and writes each
/// received datagram to stdout.
mod udp {
    use futures::{future, Sink, SinkExt, Stream, StreamExt};
    use std::{error::Error, io, net::SocketAddr};
    use tokio::net::udp::{
        split::{UdpSocketRecvHalf, UdpSocketSendHalf},
        UdpSocket,
    };

    /// Size of the receive buffer.
    ///
    /// NOTE(review): datagrams longer than this do not fit. On POSIX systems
    /// the excess bytes are discarded; other platforms may report an error.
    const RECV_BUF_LEN: usize = 1024;

    /// Binds an ephemeral local UDP socket of the same address family as
    /// `addr`, connects it to `addr`, then runs the send and receive loops
    /// concurrently.
    ///
    /// The receive loop only exits on error, so this function does not
    /// return `Ok` in practice. It keeps running after stdin ends, until a
    /// send or receive error occurs.
    pub async fn connect(
        addr: &SocketAddr,
        stdin: impl Stream<Item = Result<Vec<u8>, io::Error>> + Unpin,
        stdout: impl Sink<Vec<u8>, Error = io::Error> + Unpin,
    ) -> Result<(), Box<dyn Error>> {
        // Port 0 on the wildcard address of the matching family lets the OS
        // choose the local endpoint.
        let bind_addr = if addr.ip().is_ipv4() {
            "0.0.0.0:0".parse()?
        } else {
            "[::]:0".parse()?
        };
        let socket = UdpSocket::bind(&bind_addr)?;
        socket.connect(addr)?;
        let (mut r, mut w) = socket.split();
        future::try_join(send(stdin, &mut w), recv(stdout, &mut r)).await?;
        Ok(())
    }

    /// Sends every chunk from `stdin` as a single datagram.
    ///
    /// Stops at the first read or send error, or returns `Ok` once the
    /// stream ends.
    async fn send(
        mut stdin: impl Stream<Item = Result<Vec<u8>, io::Error>> + Unpin,
        writer: &mut UdpSocketSendHalf,
    ) -> Result<(), io::Error> {
        while let Some(item) = stdin.next().await {
            let buf = item?;
            writer.send(&buf[..]).await?;
        }
        Ok(())
    }

    /// Receives datagrams forever and forwards each non-empty payload to
    /// `stdout`. Returns only when a receive or sink error occurs.
    async fn recv(
        mut stdout: impl Sink<Vec<u8>, Error = io::Error> + Unpin,
        reader: &mut UdpSocketRecvHalf,
    ) -> Result<(), io::Error> {
        // One buffer reused for every datagram; only the filled prefix is copied out.
        let mut buf = vec![0; RECV_BUF_LEN];
        loop {
            let n = reader.recv(&mut buf[..]).await?;
            if n > 0 {
                // Forward exactly the bytes received. Sending the whole buffer
                // would append zero or stale padding to every datagram.
                stdout.send(buf[..n].to_vec()).await?;
            }
        }
    }
}
/// A pass-through codec for raw bytes: no delimiters, no length prefixes.
mod codec {
    use bytes::{BufMut, BytesMut};
    use std::io;
    use tokio::codec::{Decoder, Encoder};

    /// Frames data as arbitrary byte chunks.
    pub struct Bytes;

    impl Decoder for Bytes {
        type Item = Vec<u8>;
        type Error = io::Error;

        /// Takes every byte currently buffered as a single item.
        /// Returns `None` if the buffer is empty.
        fn decode(&mut self, buf: &mut BytesMut) -> io::Result<Option<Vec<u8>>> {
            if buf.is_empty() {
                return Ok(None);
            }
            let len = buf.len();
            Ok(Some(buf.split_to(len).to_vec()))
        }
    }

    impl Encoder for Bytes {
        type Item = Vec<u8>;
        type Error = io::Error;

        /// Appends `data` unchanged to the outgoing buffer.
        fn encode(&mut self, data: Vec<u8>, buf: &mut BytesMut) -> io::Result<()> {
            // In `bytes` 0.4, `BufMut::put` on a `BytesMut` does not grow the
            // buffer and panics on insufficient capacity, so reserve first.
            buf.reserve(data.len());
            buf.put(&data[..]);
            Ok(())
        }
    }
}