use super::*;
use crate::{
messaging::SerialisedFrame,
net::{
buffers::{BufferChunk, BufferPool, DecodeBuffer},
frames::{Frame, FramingError, Hello, Start, FRAME_HEAD_LEN},
},
};
use bytes::{Buf, BytesMut};
use mio::{net::TcpStream, Token};
use network_thread::*;
use std::{
cell::RefCell,
collections::VecDeque,
fmt::Formatter,
io,
io::{Error, ErrorKind, Read, Write},
net::{Shutdown::Both, SocketAddr},
};
/// Lifecycle of a [`TcpChannel`]: connection handshake, graceful close, and errors.
///
/// Every variant except `Initialising` and `Error` carries a `SocketAddr` and a `SessionId`.
#[derive(Debug)]
// NOTE(review): `unused` is presumably allowed because some variant fields are never read
// (e.g. the address in `Requested` and `Closed` is not read anywhere in this file) — confirm
// and narrow the allowance if possible.
#[allow(unused)]
pub(crate) enum ChannelState {
    /// Awaiting the peer's `Hello`; `handle_hello` then replies with `Start` and moves the
    /// channel to `Initialised`.
    Requested(SocketAddr, SessionId),
    /// No session yet: `initialise` sends `Hello` in this state, and `handle_start` moves the
    /// channel to `Connected`.
    Initialising,
    /// `Start` has been sent in reply to `Hello`; `handle_ack` moves the channel to `Connected`.
    Initialised(SocketAddr, SessionId),
    /// Handshake complete.
    Connected(SocketAddr, SessionId),
    /// A local graceful shutdown was started (`initiate_graceful_shutdown`) and a `Bye` was
    /// queued; receiving the peer's `Bye` moves the channel to `Closed`.
    CloseRequested(SocketAddr, SessionId),
    /// The peer's `Bye` arrived while `Connected`, and our reply `Bye` could not be fully
    /// flushed yet.
    CloseReceived(SocketAddr, SessionId),
    /// The channel has been closed.
    Closed(SocketAddr, SessionId),
    /// The socket reported an error (see `TcpChannel::read_state`).
    Error(Error),
}
impl ChannelState {
fn session_id(&self) -> Option<SessionId> {
match self {
ChannelState::Requested(_, session) => Some(*session),
ChannelState::Initialising => None,
ChannelState::Initialised(_, session) => Some(*session),
ChannelState::Connected(_, session) => Some(*session),
ChannelState::CloseRequested(_, session) => Some(*session),
ChannelState::CloseReceived(_, session) => Some(*session),
ChannelState::Closed(_, session) => Some(*session),
ChannelState::Error(_) => None,
}
}
}
/// A TCP connection to a remote peer, together with its framing buffers and handshake state.
///
/// Received bytes are decoded into `Frame`s through `input_buffer`. Outgoing frames are queued
/// in `outbound_queue` and written to `stream` by `try_drain`.
pub(crate) struct TcpChannel {
    /// The underlying mio socket.
    stream: TcpStream,
    /// Serialised frames waiting to be written; a partially written frame is kept at the front.
    outbound_queue: VecDeque<SerialisedFrame>,
    /// The mio `Token` associated with this channel (presumably its poll registration — confirm).
    pub token: Token,
    /// Remote address; replaced by the address carried in the peer's `Hello`/`Start` frame.
    address: SocketAddr,
    /// Received bytes from which frames are decoded.
    input_buffer: DecodeBuffer,
    /// Current lifecycle state of the channel.
    pub state: ChannelState,
    /// Number of `Data` frames decoded on this channel.
    pub messages: u32,
    /// Local address advertised to the peer in the `Start` frame.
    own_addr: SocketAddr,
    /// Value applied via `set_nodelay` when the handshake reaches `Connected`.
    nodelay: bool,
}
impl TcpChannel {
    /// Creates a channel wrapping `stream`.
    ///
    /// The decode buffer is built from `buffer_chunk` using the buffer configuration of
    /// `network_config`. The `TCP_NODELAY` preference is also read from `network_config`;
    /// it is applied to the socket only once the handshake completes.
    pub fn new(
        stream: TcpStream,
        token: Token,
        address: SocketAddr,
        buffer_chunk: BufferChunk,
        state: ChannelState,
        own_addr: SocketAddr,
        network_config: &NetworkConfig,
    ) -> Self {
        let input_buffer = DecodeBuffer::new(buffer_chunk, network_config.get_buffer_config());
        TcpChannel {
            stream,
            outbound_queue: VecDeque::new(),
            token,
            address,
            input_buffer,
            state,
            messages: 0,
            own_addr,
            nodelay: network_config.get_tcp_nodelay(),
        }
    }

    /// Mutable access to the underlying socket.
    pub fn stream_mut(&mut self) -> &mut TcpStream {
        &mut self.stream
    }

    /// Shared access to the underlying socket.
    #[allow(dead_code)]
    pub fn stream(&self) -> &TcpStream {
        &self.stream
    }

    /// `true` if the handshake has completed and the channel is `Connected`.
    pub fn connected(&self) -> bool {
        matches!(self.state, ChannelState::Connected(_, _))
    }

    /// Session id of the current state, if the state carries one.
    pub fn session_id(&self) -> Option<SessionId> {
        self.state.session_id()
    }

    /// Serialises `frame` into a buffer sized exactly to its encoded length (header included).
    ///
    /// Returns `None` if the frame could not be encoded.
    fn encode_frame(frame: &mut Frame) -> Option<SerialisedFrame> {
        let len = frame.encoded_len() + FRAME_HEAD_LEN as usize;
        let mut bytes = BytesMut::with_capacity(len);
        frame.encode_into(&mut bytes).ok()?;
        Some(SerialisedFrame::Bytes(bytes.freeze()))
    }

    /// Encodes `frame`, appends it to the outbound queue, and attempts to flush the queue.
    ///
    /// Flush errors are ignored here. Anything that could not be written stays queued
    /// for a later `try_drain`.
    ///
    /// # Panics
    /// Panics if the frame cannot be encoded.
    fn send_frame(&mut self, mut frame: Frame) {
        match Self::encode_frame(&mut frame) {
            Some(serialised) => {
                self.outbound_queue.push_back(serialised);
                let _ = self.try_drain();
            }
            None => panic!("Failed to encode bytes for Frame {:?}", frame.frame_type()),
        }
    }

    /// Returns the current state.
    ///
    /// If the socket reports a pending error, or querying it fails, the channel is first
    /// moved into `ChannelState::Error`.
    pub fn read_state(&mut self) -> &ChannelState {
        match self.stream.take_error() {
            Ok(Some(error)) | Err(error) => self.state = ChannelState::Error(error),
            Ok(None) => {}
        }
        &self.state
    }

    /// Starts the handshake by sending `Hello` announcing `addr`.
    ///
    /// Does nothing unless the channel is `Initialising`.
    pub fn initialise(&mut self, addr: &SocketAddr) {
        if let ChannelState::Initialising = self.state {
            let hello = Frame::Hello(Hello::new(*addr));
            self.send_frame(hello);
        }
    }

    /// Answers the peer's `Hello` on a `Requested` channel.
    ///
    /// Sends `Start` carrying our own address and the session id, adopts the address from
    /// `hello`, and moves to `Initialised`. Ignored in any other state.
    pub fn handle_hello(&mut self, hello: &Hello) {
        if let ChannelState::Requested(_, id) = self.state {
            let start = Frame::Start(Start::new(self.own_addr, id));
            self.send_frame(start);
            self.state = ChannelState::Initialised(hello.addr, id);
            self.address = hello.addr;
        }
    }

    /// Completes the handshake for an `Initialising` channel that received `Start`.
    ///
    /// Applies the configured `TCP_NODELAY` setting, replies with `Ack`, and moves to
    /// `Connected` with the address and session id from `start`. Ignored in any other state.
    ///
    /// # Panics
    /// Panics if setting `TCP_NODELAY` on the socket fails.
    pub fn handle_start(&mut self, start: &Start) {
        if let ChannelState::Initialising = self.state {
            let ack = Frame::Ack();
            self.stream
                .set_nodelay(self.nodelay)
                .expect("set nodelay failed");
            self.send_frame(ack);
            self.state = ChannelState::Connected(start.addr, start.id);
            self.address = start.addr;
        }
    }

    /// Address of the remote peer.
    pub fn address(&self) -> SocketAddr {
        self.address
    }

    /// Finishes the handshake of an `Initialised` channel upon receiving `Ack`.
    ///
    /// Applies the configured `TCP_NODELAY` setting and moves to `Connected`. Ignored in any
    /// other state.
    ///
    /// # Panics
    /// Panics if setting `TCP_NODELAY` on the socket fails.
    pub fn handle_ack(&mut self) {
        if let ChannelState::Initialised(addr, id) = self.state {
            self.stream
                .set_nodelay(self.nodelay)
                .expect("set nodelay failed");
            self.state = ChannelState::Connected(addr, id);
        }
    }

    /// Exchanges the decode buffer's chunk with `new_buffer` (delegates to `DecodeBuffer`).
    pub fn swap_buffer(&mut self, new_buffer: &mut BufferChunk) {
        self.input_buffer.swap_buffer(new_buffer);
    }

    /// Removes and returns every queued outbound frame, in queue order.
    pub fn take_outbound(&mut self) -> Vec<SerialisedFrame> {
        self.outbound_queue.drain(..).collect()
    }

    /// Reads from the socket as needed and decodes the next frame.
    ///
    /// `Ack` and `Bye` frames also drive the state machine (`handle_ack` / `handle_bye`)
    /// before being returned. Returns `Ok(None)` when no complete frame is available.
    ///
    /// # Errors
    /// - Propagates socket and decode-buffer errors.
    /// - Returns the "No Buffer Space" error if the pool has no spare buffer to swap in.
    /// - Returns an `InvalidData` error for framing errors other than `NoData`.
    pub fn read_frame(&mut self, buffer_pool: &RefCell<BufferPool>) -> io::Result<Option<Frame>> {
        if !self.input_buffer.has_frame()? {
            match self.receive() {
                Ok(()) => {}
                // Presumably matches the "No Buffer Space" error raised by `receive` when the
                // decode buffer has no writeable space left.
                Err(err) if no_buffer_space(&err) => {
                    if !self.input_buffer.has_frame()? {
                        let mut pool = buffer_pool.borrow_mut();
                        let mut buffer_chunk = pool.get_buffer().ok_or(err)?;
                        self.swap_buffer(&mut buffer_chunk);
                        pool.return_buffer(buffer_chunk);
                        // Release the RefCell borrow before recursing, which borrows again.
                        drop(pool);
                        return self.read_frame(buffer_pool);
                    }
                }
                Err(err) => return Err(err),
            }
        }
        match self.decode() {
            Ok(frame) => {
                match &frame {
                    Frame::Ack() => self.handle_ack(),
                    Frame::Bye() => self.handle_bye(),
                    _ => {}
                }
                Ok(Some(frame))
            }
            Err(FramingError::NoData) => Ok(None),
            Err(_) => Err(Error::new(ErrorKind::InvalidData, "Framing Error")),
        }
    }

    /// Reads everything currently available from the socket into the decode buffer.
    ///
    /// Returns `Ok(())` once a read would block or returns 0 bytes.
    ///
    /// # Errors
    /// - An `InvalidInput` "No Buffer Space" error when the decode buffer has no writeable
    ///   space.
    /// - The `Interrupted` error after `MAX_INTERRUPTS` interrupted reads.
    /// - Any other read error.
    fn receive(&mut self) -> io::Result<()> {
        let mut interrupts = 0;
        loop {
            let buf = self
                .input_buffer
                .get_writeable()
                .ok_or_else(|| io::Error::new(ErrorKind::InvalidInput, "No Buffer Space"))?;
            match self.stream.read(buf) {
                // NOTE(review): a 0-byte read usually means the peer closed the connection;
                // it is not surfaced to the caller here.
                Ok(0) => return Ok(()),
                Ok(n) => self.input_buffer.advance_writeable(n),
                Err(err) if would_block(&err) => return Ok(()),
                Err(err) if interrupted(&err) => {
                    interrupts += 1;
                    if interrupts >= MAX_INTERRUPTS {
                        return Err(err);
                    }
                }
                Err(err) => return Err(err),
            }
        }
    }

    /// Queues a `Bye` frame and tries to flush the outbound queue.
    ///
    /// # Errors
    /// Returns an `Interrupted` "bye not sent" error if the queue could not be fully drained.
    /// In that case the `Bye`, or a frame queued before it, is still pending.
    ///
    /// # Panics
    /// Panics if the `Bye` frame cannot be encoded.
    pub fn send_bye(&mut self) -> io::Result<()> {
        let mut bye = Frame::Bye();
        match Self::encode_frame(&mut bye) {
            Some(serialised) => {
                self.outbound_queue.push_back(serialised);
                let _ = self.try_drain();
            }
            None => panic!("Unable to send bye bytes, failed to encode!"),
        }
        if self.outbound_queue.is_empty() {
            Ok(())
        } else {
            Err(Error::new(ErrorKind::Interrupted, "bye not sent"))
        }
    }

    /// Reacts to the peer's `Bye`.
    ///
    /// - `Connected`: answers with our own `Bye`. The channel becomes `Closed` if that
    ///   flushes, otherwise it stays `CloseReceived`.
    /// - `CloseRequested`: the close handshake is complete, so the channel becomes `Closed`.
    /// - Any other state: ignored.
    pub fn handle_bye(&mut self) {
        match self.state {
            ChannelState::Connected(addr, id) => {
                self.state = ChannelState::CloseReceived(addr, id);
                if self.send_bye().is_ok() {
                    self.state = ChannelState::Closed(addr, id);
                }
            }
            ChannelState::CloseRequested(addr, id) => self.state = ChannelState::Closed(addr, id),
            _ => {}
        }
    }

    /// Starts a graceful close of a `Connected` channel.
    ///
    /// Moves to `CloseRequested` and sends `Bye`; the result of the flush is ignored.
    /// Ignored in any other state.
    pub fn initiate_graceful_shutdown(&mut self) {
        if let ChannelState::Connected(addr, id) = self.state {
            self.state = ChannelState::CloseRequested(addr, id);
            let _ = self.send_bye();
        }
    }

    /// Shuts down both directions of the socket, ignoring errors.
    ///
    /// A `Connected` channel is marked `Closed`; other states are left unchanged.
    pub fn shutdown(&mut self) {
        let _ = self.stream.shutdown(Both);
        if let ChannelState::Connected(addr, id) = self.state {
            self.state = ChannelState::Closed(addr, id);
        }
    }

    /// Takes the next frame from the decode buffer.
    ///
    /// Each `Data` frame increments `messages`. The counter wraps instead of panicking on
    /// overflow in debug builds.
    pub fn decode(&mut self) -> Result<Frame, FramingError> {
        let frame = self.input_buffer.get_frame()?;
        if matches!(frame, Frame::Data(_)) {
            self.messages = self.messages.wrapping_add(1);
        }
        Ok(frame)
    }

    /// Appends an already serialised frame to the outbound queue without flushing.
    pub fn enqueue_serialised(&mut self, serialized: SerialisedFrame) {
        self.outbound_queue.push_back(serialized);
    }

    /// Writes queued frames to the socket, in order, until the queue is empty or the socket
    /// would block.
    ///
    /// A partially written frame is advanced past the written bytes and put back at the
    /// front of the queue. Returns the number of bytes written by this call.
    ///
    /// # Errors
    /// - The `Interrupted` error after `MAX_INTERRUPTS` interrupted writes in one call.
    /// - A `WriteZero` error if the socket accepts 0 bytes of a non-empty frame. Such a
    ///   socket will not accept more data, so retrying would spin forever.
    /// - Any other write error.
    ///
    /// In every error case the frame being written stays at the front of the queue.
    pub fn try_drain(&mut self) -> io::Result<usize> {
        let mut sent_bytes: usize = 0;
        let mut interrupts = 0;
        while let Some(mut serialized_frame) = self.outbound_queue.pop_front() {
            match self.write_serialized(&serialized_frame) {
                Ok(n) => {
                    sent_bytes += n;
                    // Skip the written prefix; `true` when part of the frame is still unsent.
                    let partially_written = match &mut serialized_frame {
                        SerialisedFrame::Bytes(bytes) => {
                            let partial = n < bytes.len();
                            if partial {
                                bytes.advance(n);
                            }
                            partial
                        }
                        SerialisedFrame::ChunkLease(chunk) => {
                            let partial = n < chunk.remaining();
                            if partial {
                                chunk.advance(n);
                            }
                            partial
                        }
                        SerialisedFrame::ChunkRef(chunk) => {
                            let partial = n < chunk.remaining();
                            if partial {
                                chunk.advance(n);
                            }
                            partial
                        }
                    };
                    if partially_written {
                        self.outbound_queue.push_front(serialized_frame);
                        if n == 0 {
                            return Err(Error::new(
                                ErrorKind::WriteZero,
                                "failed to write queued frame to the stream",
                            ));
                        }
                    }
                }
                Err(ref err) if would_block(err) => {
                    self.outbound_queue.push_front(serialized_frame);
                    return Ok(sent_bytes);
                }
                Err(err) if interrupted(&err) => {
                    self.outbound_queue.push_front(serialized_frame);
                    interrupts += 1;
                    if interrupts >= MAX_INTERRUPTS {
                        return Err(err);
                    }
                }
                Err(err) => {
                    self.outbound_queue.push_front(serialized_frame);
                    return Err(err);
                }
            }
        }
        Ok(sent_bytes)
    }

    /// Performs a single `write` of the frame's current contiguous chunk.
    ///
    /// May write fewer bytes than the frame has remaining.
    fn write_serialized(&mut self, serialized: &SerialisedFrame) -> io::Result<usize> {
        match serialized {
            SerialisedFrame::ChunkLease(chunk) => self.stream.write(chunk.chunk()),
            SerialisedFrame::Bytes(bytes) => self.stream.write(bytes.chunk()),
            SerialisedFrame::ChunkRef(chunkref) => self.stream.write(chunkref.chunk()),
        }
    }

    /// Shuts down both directions of the socket, ignoring errors, without changing the state.
    pub(crate) fn kill(&mut self) {
        let _ = self.stream.shutdown(Both);
    }
}
impl std::fmt::Debug for TcpChannel {
    /// Summarises the channel for diagnostics.
    ///
    /// Shows the state, the decoded-message count, and the decode buffer. For the outbound
    /// queue only the number of pending frames is shown, not their contents.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let pending_frames = self.outbound_queue.len();
        let mut builder = f.debug_struct("TcpChannel");
        builder.field("State", &self.state);
        builder.field("Messages", &self.messages);
        builder.field("Decode Buffer", &self.input_buffer);
        builder.field("Outbound Queue", &pending_frames);
        builder.finish()
    }
}