use std::collections::VecDeque;
use std::io::{ErrorKind, Read, Write};
use std::net::TcpStream;
use std::result::Result;
use std::str;
use std::time::Duration;
use super::message::{
ChannelMessage, ChannelMessageModeControl, ChannelMessageModeIngest, ChannelMessageModeSearch,
ChannelMessageResult,
};
use super::mode::ChannelMode;
use super::statistics::CLIENTS_CONNECTED;
use crate::APP_CONF;
use crate::LINE_FEED;
/// Stateless handle for serving one channel client; the work is done by its associated
/// functions, starting from `ChannelHandle::client(stream)` for an accepted `TcpStream`.
pub struct ChannelHandle;
/// Reason a client's `START` handshake did not succeed.
///
/// Each variant maps to the token reported to the client in the `ENDED <reason>` line.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ChannelHandleError {
    /// The peer closed the connection before sending any handshake data.
    Closed,
    /// `START` was sent without a mode, or with a mode that could not be parsed.
    InvalidMode,
    /// A password is configured but the client did not provide one.
    AuthenticationRequired,
    /// The client provided a password that does not match the configured one.
    AuthenticationFailed,
    /// The first word received was not `START`.
    NotRecognized,
    /// Reading the handshake failed with a timeout error.
    TimedOut,
    /// Reading the handshake failed because the connection was aborted.
    ConnectionAborted,
    /// Reading the handshake was interrupted.
    Interrupted,
    /// Reading the handshake failed for any other reason.
    Unknown,
}
/// Maximum payload size of a single line; advertised to clients as `buffer(..)` in `STARTED`.
const BUFFER_SIZE: usize = 20_000;
/// Extra bytes tolerated past `BUFFER_SIZE` for the line ending
/// (presumably room for a `\r` before the `\n` separator — confirm against clients).
const LINE_END_GAP: usize = 1;
/// Size of each read chunk, and the cap on bytes still pending without a line separator.
const MAX_LINE_SIZE: usize = 1 + LINE_END_GAP + BUFFER_SIZE;

/// Read/write timeout, in seconds, applied before the `START` handshake has completed.
const TCP_TIMEOUT_NON_ESTABLISHED: u64 = 10;
/// Protocol revision reported to clients as `protocol(..)` in `STARTED`.
const PROTOCOL_REVISION: u8 = 1;
/// Byte that terminates each incoming command line.
const BUFFER_LINE_SEPARATOR: u8 = b'\n';
lazy_static! {
// Greeting line written to every client as soon as it connects (before `START`),
// rendered as `CONNECTED <{name} v{version}>`. Name and version are read at compile
// time (`env!`) from the Cargo package metadata.
static ref CONNECTED_BANNER: String = format!(
"CONNECTED <{} v{}>",
env!("CARGO_PKG_NAME"),
env!("CARGO_PKG_VERSION")
);
}
impl ChannelHandleError {
pub fn to_str(&self) -> &'static str {
match *self {
ChannelHandleError::Closed => "closed",
ChannelHandleError::InvalidMode => "invalid_mode",
ChannelHandleError::AuthenticationRequired => "authentication_required",
ChannelHandleError::AuthenticationFailed => "authentication_failed",
ChannelHandleError::NotRecognized => "not_recognized",
ChannelHandleError::TimedOut => "timed_out",
ChannelHandleError::ConnectionAborted => "connection_aborted",
ChannelHandleError::Interrupted => "interrupted",
ChannelHandleError::Unknown => "unknown",
}
}
}
impl ChannelHandle {
pub fn client(mut stream: TcpStream) {
ChannelHandle::configure_stream(&stream, false);
write!(stream, "{}{}", *CONNECTED_BANNER, LINE_FEED).expect("write failed");
*CLIENTS_CONNECTED.write().unwrap() += 1;
match Self::ensure_start(&stream) {
Ok(mode) => {
ChannelHandle::configure_stream(&stream, true);
write!(
stream,
"STARTED {} protocol({}) buffer({}){}",
mode.to_str(),
PROTOCOL_REVISION,
BUFFER_SIZE,
LINE_FEED
)
.expect("write failed");
Self::handle_stream(mode, stream);
}
Err(err) => {
write!(stream, "ENDED {}{}", err.to_str(), LINE_FEED).expect("write failed");
}
}
*CLIENTS_CONNECTED.write().unwrap() -= 1;
}
fn configure_stream(stream: &TcpStream, is_established: bool) {
let tcp_timeout = if is_established {
APP_CONF.channel.tcp_timeout
} else {
TCP_TIMEOUT_NON_ESTABLISHED
};
assert!(stream.set_nodelay(true).is_ok());
assert!(stream
.set_read_timeout(Some(Duration::new(tcp_timeout, 0)))
.is_ok());
assert!(stream
.set_write_timeout(Some(Duration::new(tcp_timeout, 0)))
.is_ok());
}
fn handle_stream(mode: ChannelMode, mut stream: TcpStream) {
let mut buffer: VecDeque<u8> = VecDeque::with_capacity(MAX_LINE_SIZE);
'handler: loop {
let mut read = [0; MAX_LINE_SIZE];
match stream.read(&mut read) {
Ok(n) => {
if n == 0 {
break;
}
{
let buffer_len = n + buffer.len();
if buffer_len > MAX_LINE_SIZE {
error!("closing channel thread because of buffer overflow");
panic!("buffer overflow ({}/{} bytes)", buffer_len, MAX_LINE_SIZE);
}
}
buffer.extend(&read[0..n]);
{
let mut processed_line = Vec::with_capacity(MAX_LINE_SIZE);
while let Some(byte) = buffer.pop_front() {
if byte == BUFFER_LINE_SEPARATOR {
if Self::on_message(&mode, &stream, &processed_line)
== ChannelMessageResult::Close
{
break 'handler;
}
processed_line.clear();
} else {
processed_line.push(byte);
}
}
if !processed_line.is_empty() {
buffer.extend(processed_line);
}
}
}
Err(err) => {
error!("closing channel thread with traceback: {}", err);
panic!("closing channel");
}
}
}
}
fn ensure_start(mut stream: &TcpStream) -> Result<ChannelMode, ChannelHandleError> {
#[allow(clippy::never_loop)]
loop {
let mut read = [0; MAX_LINE_SIZE];
match stream.read(&mut read) {
Ok(n) => {
if n == 0 {
return Err(ChannelHandleError::Closed);
}
let mut parts = str::from_utf8(&read[0..n]).unwrap_or("").split_whitespace();
if parts.next().unwrap_or("").to_uppercase().as_str() == "START" {
if let Some(res_mode) = parts.next() {
debug!("got mode response: {}", res_mode);
if let Ok(mode) = ChannelMode::from_str(res_mode) {
if let Some(ref auth_password) = APP_CONF.channel.auth_password {
if let Some(provided_auth) = parts.next() {
if provided_auth != auth_password {
info!("password provided, but does not match");
return Err(ChannelHandleError::AuthenticationFailed);
}
} else {
info!("no password provided, but one required");
return Err(ChannelHandleError::AuthenticationRequired);
}
}
return Ok(mode);
}
}
return Err(ChannelHandleError::InvalidMode);
}
return Err(ChannelHandleError::NotRecognized);
}
Err(err) => {
let err_reason = match err.kind() {
ErrorKind::TimedOut => ChannelHandleError::TimedOut,
ErrorKind::ConnectionAborted => ChannelHandleError::ConnectionAborted,
ErrorKind::Interrupted => ChannelHandleError::Interrupted,
_ => ChannelHandleError::Unknown,
};
return Err(err_reason);
}
}
}
}
fn on_message(
mode: &ChannelMode,
stream: &TcpStream,
message_slice: &[u8],
) -> ChannelMessageResult {
match mode {
ChannelMode::Search => {
ChannelMessage::on::<ChannelMessageModeSearch>(stream, message_slice)
}
ChannelMode::Ingest => {
ChannelMessage::on::<ChannelMessageModeIngest>(stream, message_slice)
}
ChannelMode::Control => {
ChannelMessage::on::<ChannelMessageModeControl>(stream, message_slice)
}
}
}
}