use std::fs::File;
use std::io::{self, Read, Write};
use std::net::{Shutdown, TcpStream};
use std::sync::{mpsc, Arc};
use std::{
cmp,
thread::{self, JoinHandle},
time,
};
use crate::core::ser;
use crate::core::ser::FixedLength;
use crate::msg::{
read_body, read_discard, read_header, read_item, write_to_buf, MsgHeader, MsgHeaderWrapper,
Type,
};
use crate::types::Error;
use crate::util::read_write::{read_exact, write_all};
use crate::util::{RateCounter, RwLock};
/// Handler for inbound messages read from a peer connection.
///
/// `consume` is called by the poll loop once per known message header. It
/// receives:
/// - the `Message` (header plus access to the body on the underlying stream);
/// - a writer for the same connection;
/// - the connection's byte `Tracker`.
///
/// Returning `Ok(Some(resp))` makes the poll loop write `resp` back to the
/// peer; `Ok(None)` sends nothing. Errors are filtered through `try_break!` in
/// the poll loop: some are ignored, all others terminate the connection.
pub trait MessageHandler: Send + 'static {
fn consume<'a>(
&self,
msg: Message<'a>,
writer: &'a mut dyn Write,
tracker: Arc<Tracker>,
) -> Result<Option<Response<'a>>, Error>;
}
// Evaluates a `Result<_, Error>` inside a loop body.
// - `Ok(v)` yields `Some(v)`.
// - Errors treated as non-fatal for the connection yield `None`: a `Connection`
//   error of kind `WouldBlock`, or any `Store`, `Chain`, `Internal` or
//   `NoDandelionRelay` error.
// - Any other error is logged at debug level and `break`s out of the
//   enclosing loop.
// Must therefore only be used directly inside a `loop`/`while`/`for`.
macro_rules! try_break {
($inner:expr) => {
match $inner {
Ok(v) => Some(v),
// Non-blocking socket has nothing ready right now; try again later.
Err(Error::Connection(ref e)) if e.kind() == io::ErrorKind::WouldBlock => None,
Err(Error::Store(_))
| Err(Error::Chain(_))
| Err(Error::Internal)
| Err(Error::NoDandelionRelay) => None,
Err(ref e) => {
debug!("try_break: exit the loop: {:?}", e);
break;
}
}
};
}
/// An inbound message: its decoded header plus the stream its body is read from.
pub struct Message<'a> {
    pub header: MsgHeader,
    stream: &'a mut dyn Read,
}

impl<'a> Message<'a> {
    /// Wraps an already-read header together with the stream positioned at its body.
    fn from_header(header: MsgHeader, stream: &'a mut dyn Read) -> Message<'a> {
        Message { header, stream }
    }

    /// Reads and deserializes the whole message body described by `self.header`.
    pub fn body<T: ser::Readable>(&mut self) -> Result<T, Error> {
        read_body(&self.header, self.stream)
    }

    /// Reads a single item from the stream via `read_item`.
    ///
    /// The returned `u64` is whatever `read_item` reports alongside the item
    /// (presumably the number of bytes consumed; confirm against `read_item`).
    pub fn streaming_read<T: ser::Readable>(&mut self) -> Result<(T, u64), Error> {
        read_item(self.stream)
    }

    /// Copies exactly `len` bytes of attachment data from the message stream
    /// into `writer`, in chunks of at most 8000 bytes.
    ///
    /// Each chunk read uses a 10 second timeout. Returns the number of bytes
    /// copied, which equals `len` on success.
    ///
    /// # Errors
    /// Propagates any read (including timeout) or write error.
    pub fn copy_attachment(&mut self, len: usize, writer: &mut dyn Write) -> Result<usize, Error> {
        let mut written = 0;
        // Allocate the scratch buffer once (no larger than needed) and reuse it
        // for every chunk instead of allocating per iteration.
        let mut buf = vec![0u8; cmp::min(8000, len)];
        while written < len {
            let read_len = cmp::min(buf.len(), len - written);
            let chunk = &mut buf[..read_len];
            read_exact(
                &mut self.stream,
                chunk,
                time::Duration::from_secs(10),
                true,
            )?;
            writer.write_all(chunk)?;
            written += read_len;
        }
        Ok(written)
    }
}
pub struct Response<'a> {
resp_type: Type,
body: Vec<u8>,
stream: &'a mut dyn Write,
attachment: Option<File>,
}
impl<'a> Response<'a> {
pub fn new<T: ser::Writeable>(
resp_type: Type,
body: T,
stream: &'a mut dyn Write,
) -> Result<Response<'a>, Error> {
let body = ser::ser_vec(&body)?;
Ok(Response {
resp_type,
body,
stream,
attachment: None,
})
}
fn write(mut self, tracker: Arc<Tracker>) -> Result<(), Error> {
let mut msg = ser::ser_vec(&MsgHeader::new(self.resp_type, self.body.len() as u64))?;
msg.append(&mut self.body);
write_all(&mut self.stream, &msg[..], time::Duration::from_secs(10))?;
tracker.inc_sent(msg.len() as u64);
if let Some(mut file) = self.attachment {
let mut buf = [0u8; 8000];
loop {
match file.read(&mut buf[..]) {
Ok(0) => break,
Ok(n) => {
write_all(&mut self.stream, &buf[..n], time::Duration::from_secs(10))?;
tracker.inc_quiet_sent(n as u64);
}
Err(e) => return Err(From::from(e)),
}
}
}
Ok(())
}
pub fn add_attachment(&mut self, file: File) {
self.attachment = Some(file);
}
}
/// Capacity of the bounded outbound message channel created in `listen`.
/// Because `ConnHandle::send` uses `try_send`, sending fails instead of
/// blocking once this many messages are queued.
pub const SEND_CHANNEL_CAP: usize = 100;
/// Control handle used to stop a peer connection thread and wait for it to exit.
pub struct StopHandle {
    /// Sending `()` here asks the peer thread's poll loop to exit.
    pub close_channel: mpsc::Sender<()>,
    peer_thread: Option<JoinHandle<()>>,
}

impl StopHandle {
    /// Signals the peer thread to stop without waiting for it.
    ///
    /// If the receiving end is already gone, only logs at debug level.
    pub fn stop(&self) {
        self.close_channel.send(()).unwrap_or_else(|_| {
            debug!("peer's close_channel is disconnected, must be stopped already")
        });
    }

    /// Joins the peer thread, at most once.
    ///
    /// When called from the peer thread itself, it does not join (that would
    /// deadlock). It logs the attempt and releases the handle instead.
    pub fn wait(&mut self) {
        let peer_thread = match self.peer_thread.take() {
            Some(handle) => handle,
            None => return,
        };
        let peer_id = peer_thread.thread().id();
        if thread::current().id() == peer_id {
            debug!("attempt to wait for thread {:?} from itself", peer_id);
            return;
        }
        debug!("waiting for thread {:?} exit", peer_id);
        if let Err(e) = peer_thread.join() {
            error!("failed to wait for peer thread to stop: {:?}", e);
        }
    }
}
/// Handle for queuing outbound messages to a peer's poll thread.
pub struct ConnHandle {
    /// Bounded queue of fully serialized messages (header + body).
    pub send_channel: mpsc::SyncSender<Vec<u8>>,
}

impl ConnHandle {
    /// Serializes `body` as a `msg_type` message and queues it without blocking.
    ///
    /// Returns the serialized size in bytes.
    ///
    /// # Errors
    /// Fails if serialization fails, or if `try_send` fails (queue full or
    /// receiver gone).
    pub fn send<T>(&self, body: T, msg_type: Type) -> Result<u64, Error>
    where
        T: ser::Writeable,
    {
        let payload = write_to_buf(body, msg_type)?;
        let size = payload.len() as u64;
        self.send_channel.try_send(payload)?;
        Ok(size)
    }
}
/// Shared byte-rate counters for one connection, split by direction.
pub struct Tracker {
    /// Counter for bytes written to the peer.
    pub sent_bytes: Arc<RwLock<RateCounter>>,
    /// Counter for bytes read from the peer.
    pub received_bytes: Arc<RwLock<RateCounter>>,
}

impl Tracker {
    /// Creates a tracker with fresh counters for both directions.
    pub fn new() -> Self {
        Self {
            received_bytes: Arc::new(RwLock::new(RateCounter::new())),
            sent_bytes: Arc::new(RwLock::new(RateCounter::new())),
        }
    }

    /// Records `size` received bytes via `RateCounter::inc`.
    pub fn inc_received(&self, size: u64) {
        let mut counter = self.received_bytes.write();
        counter.inc(size);
    }

    /// Records `size` sent bytes via `RateCounter::inc`.
    pub fn inc_sent(&self, size: u64) {
        let mut counter = self.sent_bytes.write();
        counter.inc(size);
    }

    /// Records `size` received bytes via `RateCounter::inc_quiet`.
    pub fn inc_quiet_received(&self, size: u64) {
        let mut counter = self.received_bytes.write();
        counter.inc_quiet(size);
    }

    /// Records `size` sent bytes via `RateCounter::inc_quiet`.
    pub fn inc_quiet_sent(&self, size: u64) {
        let mut counter = self.sent_bytes.write();
        counter.inc_quiet(size);
    }
}
/// Starts servicing `stream` on a dedicated peer thread.
///
/// Switches the socket to non-blocking mode and spawns the poll loop, which
/// dispatches inbound messages to `handler`. Returns two handles:
/// - a `ConnHandle` for queuing outbound messages (at most `SEND_CHANNEL_CAP`
///   pending);
/// - a `StopHandle` for stopping and joining the thread.
///
/// # Errors
/// Returns the I/O error if the socket cannot be made non-blocking, cloned,
/// or if the thread cannot be spawned.
pub fn listen<H>(
    stream: TcpStream,
    tracker: Arc<Tracker>,
    handler: H,
) -> io::Result<(ConnHandle, StopHandle)>
where
    H: MessageHandler,
{
    let (send_tx, send_rx) = mpsc::sync_channel(SEND_CHANNEL_CAP);
    let (close_tx, close_rx) = mpsc::channel();
    // This function already reports failures through io::Result, so propagate
    // the error instead of panicking the calling thread.
    stream.set_nonblocking(true)?;
    let peer_thread = poll(stream, handler, send_rx, close_rx, tracker)?;
    Ok((
        ConnHandle {
            send_channel: send_tx,
        },
        StopHandle {
            close_channel: close_tx,
            peer_thread: Some(peer_thread),
        },
    ))
}
/// Spawns the "peer" thread that services one connection until it is told to
/// stop or hits a fatal error.
///
/// Each iteration of the loop:
/// 1. Reads at most one message header. A known message is passed to
///    `handler` and any response is written back. An unknown message's body
///    is discarded.
/// 2. Writes at most one queued outbound buffer from `send_rx`.
/// 3. Checks `close_rx` for a stop signal.
/// 4. Sleeps 5 ms.
///
/// Errors are filtered by `try_break!`. On loop exit the socket is shut down
/// in both directions.
///
/// # Errors
/// Returns the I/O error if the stream cannot be cloned or the thread cannot
/// be spawned.
fn poll<H>(
    conn: TcpStream,
    handler: H,
    send_rx: mpsc::Receiver<Vec<u8>>,
    close_rx: mpsc::Receiver<()>,
    tracker: Arc<Tracker>,
) -> io::Result<JoinHandle<()>>
where
    H: MessageHandler,
{
    // Separate handles for reading and writing. Clone failures are reported
    // through the io::Result return instead of panicking.
    let mut reader = conn.try_clone()?;
    let mut writer = conn.try_clone()?;
    thread::Builder::new()
        .name("peer".to_string())
        .spawn(move || {
            let sleep_time = time::Duration::from_millis(5);
            // Holds Ok(data) when the previous write's error was swallowed by
            // try_break! (e.g. WouldBlock), so the same buffer is retried next
            // iteration before anything new is taken from send_rx.
            let mut retry_send = Err(());
            loop {
                match try_break!(read_header(&mut reader, None)) {
                    Some(MsgHeaderWrapper::Known(header)) => {
                        let msg = Message::from_header(header, &mut reader);
                        trace!(
                            "Received message header, type {:?}, len {}.",
                            msg.header.msg_type,
                            msg.header.msg_len
                        );
                        tracker.inc_received(MsgHeader::LEN as u64 + msg.header.msg_len);
                        if let Some(Some(resp)) =
                            try_break!(handler.consume(msg, &mut writer, tracker.clone()))
                        {
                            try_break!(resp.write(tracker.clone()));
                        }
                    }
                    Some(MsgHeaderWrapper::Unknown(msg_len)) => {
                        tracker.inc_received(MsgHeader::LEN as u64 + msg_len);
                        try_break!(read_discard(msg_len, &mut reader));
                    }
                    None => {}
                }

                let maybe_data = retry_send.or_else(|_| send_rx.try_recv());
                retry_send = Err(());
                if let Ok(data) = maybe_data {
                    let written = try_break!(write_all(
                        &mut writer,
                        &data[..],
                        std::time::Duration::from_secs(10)
                    )
                    .map_err(&From::from));
                    if written.is_none() {
                        retry_send = Ok(data);
                    }
                }

                if close_rx.try_recv().is_ok() {
                    break;
                }
                thread::sleep(sleep_time);
            }
            debug!(
                "Shutting down connection with {}",
                conn.peer_addr()
                    .map(|a| a.to_string())
                    .unwrap_or_else(|_| "?".to_owned())
            );
            // Best effort: the peer may already have closed the socket.
            let _ = conn.shutdown(Shutdown::Both);
        })
}