use core::pin::Pin;
use std::collections::HashMap;
use std::io;
use std::string::ToString;
use std::sync::{Arc, Mutex, RwLock};
use futures::select;
use futures::stream::StreamExt;
use futures::SinkExt;
use futures_channel::mpsc;
use futures_channel::oneshot;
use futures_util::FutureExt;
use tokio::net::TcpStream;
use tokio::time::{sleep_until, Duration, Instant};
use tokio_util::time::{delay_queue, DelayQueue};
use tokio_util::codec::{Decoder, Framed};
use crate::tarantool::codec::{RequestId, TarantoolCodec, TarantoolFramedRequest};
use crate::tarantool::packets::{AuthPacket, CommandPacket, TarantoolRequest, TarantoolResponse};
/// Framed transport: a `TcpStream` wrapped with `TarantoolCodec`.
pub type TarantoolFramed = Framed<TcpStream, TarantoolCodec>;
/// Sending half of the one-shot channel on which the result of a single
/// command is delivered.
pub type CallbackSender = oneshot::Sender<io::Result<TarantoolResponse>>;
/// Sending half of an unbounded channel on which `ClientStatus` changes are
/// published to a subscriber.
pub type ReconnectNotifySender = mpsc::UnboundedSender<ClientStatus>;
/// Error text for a dead dispatcher. Not referenced in this file;
/// presumably used by the client front-end — confirm against callers.
pub static ERROR_DISPATCH_THREAD_IS_DEAD: &str = "DISPATCH THREAD IS DEAD!";
/// Error text for a disconnected client. Not referenced in this file;
/// presumably used by the client front-end — confirm against callers.
pub static ERROR_CLIENT_DISCONNECTED: &str = "CLIENT DISCONNECTED!";
/// Connection settings for the Tarantool client.
///
/// Create one with [`ClientConfig::new`] and refine it with the chainable
/// `set_*` methods, each of which consumes and returns the configuration.
#[derive(Clone, PartialEq, Eq, Hash, Debug)]
pub struct ClientConfig {
    /// Server address handed to `TcpStream::connect`.
    addr: String,
    /// Login sent in the authentication packet.
    login: String,
    /// Password sent in the authentication packet.
    password: String,
    /// Pause between reconnect attempts, in milliseconds.
    reconnect_time_ms: u64,
    /// Per-request timeout in milliseconds; `None` means no timeout is registered.
    timeout_time_ms: Option<u64>,
}

impl ClientConfig {
    /// Builds a configuration with a 10 000 ms reconnect pause and no request timeout.
    pub fn new<S0, S, S1>(addr: S0, login: S, password: S1) -> ClientConfig
    where
        S0: Into<String>,
        S: Into<String>,
        S1: Into<String>,
    {
        let (addr, login, password) = (addr.into(), login.into(), password.into());
        Self {
            addr,
            login,
            password,
            reconnect_time_ms: 10000,
            timeout_time_ms: None,
        }
    }

    /// Returns the configuration with the per-request timeout set to `timeout_time_ms`.
    pub fn set_timeout_time_ms(self, timeout_time_ms: u64) -> ClientConfig {
        Self {
            timeout_time_ms: Some(timeout_time_ms),
            ..self
        }
    }

    /// Returns the configuration with the reconnect pause set to `reconnect_time_ms`.
    pub fn set_reconnect_time_ms(self, reconnect_time_ms: u64) -> ClientConfig {
        Self {
            reconnect_time_ms,
            ..self
        }
    }
}
/// Connection state published by the dispatcher, both through the shared
/// status cell and to every registered status subscriber.
#[derive(Clone, Debug)]
pub enum ClientStatus {
/// Not assigned anywhere in this file; presumably the state before the
/// dispatcher starts — confirm against the code that creates the status cell.
Init,
/// Set once when `Dispatch::run` starts.
Connecting,
/// Not assigned anywhere in this file.
Handshaking,
/// Set after authentication has succeeded.
Connected,
/// Not assigned anywhere in this file. Carries a textual description.
Disconnecting(String),
/// Set when the connection loop fails; carries the error text.
Disconnected(String),
/// Set when the command channel has been closed, right before
/// `Dispatch::run` returns.
Closed,
}
/// State of the dispatcher: it owns the connection loop, forwards commands
/// from the command channel to the server and routes responses back to the
/// per-request callbacks.
pub struct Dispatch {
// Connection settings (address, credentials, reconnect pause).
config: ClientConfig,
// Incoming commands, each paired with the callback that receives its result.
command_receiver: mpsc::UnboundedReceiver<(CommandPacket, CallbackSender)>,
// Set once the command channel has yielded `None`; makes `run` finish.
is_command_receiver_closed: bool,
// Callbacks of requests that were registered and not yet answered, by request id.
awaiting_callbacks: HashMap<RequestId, CallbackSender>,
// Status subscribers; shared, so the list may be modified outside the dispatcher.
notify_callbacks: Arc<Mutex<Vec<ReconnectNotifySender>>>,
// Request that could not be written to the socket yet (restored after a failed send).
buffered_command: Option<TarantoolFramedRequest>,
// Last request id handed out; reset to 3 after each successful authentication.
command_counter: RequestId,
// Copy of `config.timeout_time_ms`, taken in `new`.
timeout_time_ms: Option<u64>,
// Deadlines of in-flight requests (filled only when a timeout is configured).
timeout_queue: DelayQueue<RequestId>,
// Maps a request id to its key in `timeout_queue` so the deadline can be cancelled.
timeout_id_to_key: HashMap<RequestId, delay_queue::Key>,
// Shared cell holding the most recently published client status.
status: Arc<RwLock<ClientStatus>>,
}
impl Dispatch {
/// Creates a dispatcher that has not connected yet.
///
/// * `config` - connection settings; its request timeout is also copied into
///   the dispatcher.
/// * `command_receiver` - stream of commands paired with their result callbacks.
/// * `status` - shared cell the dispatcher writes every status change into.
/// * `notify_callbacks` - shared list of status subscribers.
pub fn new(
    config: ClientConfig,
    command_receiver: mpsc::UnboundedReceiver<(CommandPacket, CallbackSender)>,
    status: Arc<RwLock<ClientStatus>>,
    notify_callbacks: Arc<Mutex<Vec<ReconnectNotifySender>>>,
) -> Dispatch {
    Dispatch {
        // Read the timeout first: `config` is moved on the next line.
        timeout_time_ms: config.timeout_time_ms,
        config,
        status,
        notify_callbacks,
        command_receiver,
        is_command_receiver_closed: false,
        command_counter: 3,
        buffered_command: None,
        awaiting_callbacks: HashMap::new(),
        timeout_id_to_key: HashMap::new(),
        timeout_queue: DelayQueue::new(),
    }
}
/// Publishes `status` to every registered subscriber and drops the
/// subscribers whose receiving side is gone.
///
/// The subscriber list is taken out of the shared mutex before any `.await`
/// so the lock is never held across a suspension point. Subscribers that
/// register while the notifications are being sent stay in the shared list;
/// the surviving ones are appended after them.
///
/// # Panics
/// Panics if the subscriber mutex is poisoned.
async fn send_notify(&mut self, status: &ClientStatus) {
    // The guard is a temporary of this statement and is released right away.
    let subscribers = std::mem::take(&mut *self.notify_callbacks.lock().unwrap());
    let mut alive = Vec::with_capacity(subscribers.len());
    for mut subscriber in subscribers {
        if subscriber.send(status.clone()).await.is_ok() {
            alive.push(subscriber);
        }
    }
    // Move the senders back instead of cloning each of them.
    self.notify_callbacks.lock().unwrap().extend(alive);
}
/// Writes the buffered request, if there is one, to `sink`.
///
/// On success the buffer is left empty. When the send fails the request is
/// put back into the buffer and the failure is reported as
/// `io::ErrorKind::ConnectionAborted` carrying the original error text.
async fn try_send_buffered_command(&mut self, sink: &mut TarantoolFramed) -> io::Result<()> {
    let pending = match self.buffered_command.take() {
        Some(pending) => pending,
        None => return Ok(()),
    };
    // `send` consumes its item, so a copy is sent and the original is kept
    // for the failure path.
    match Pin::new(sink).send(pending.clone()).await {
        Ok(()) => Ok(()),
        Err(send_error) => {
            self.buffered_command = Some(pending);
            Err(io::Error::new(
                io::ErrorKind::ConnectionAborted,
                send_error.to_string(),
            ))
        }
    }
}
/// Fails every outstanding request with `error_description`.
///
/// This covers the requests awaiting a response, the buffered request, and
/// the commands already queued in the command channel. Registered timeouts
/// are discarded as well. Send failures on individual callbacks are ignored:
/// the caller has simply stopped waiting.
///
/// If draining the command channel reveals that it is closed, that fact is
/// recorded in `is_command_receiver_closed`. Without this the closure would
/// be lost: a receiver that has already reported the end of the stream is
/// terminated, and `select!` skips terminated streams, so the dispatcher
/// would never learn that it has to shut down.
fn send_error_to_all(&mut self, error_description: String) {
    let make_error = || io::Error::new(io::ErrorKind::Other, error_description.clone());

    for (_, callback_sender) in self.awaiting_callbacks.drain() {
        let _res = callback_sender.send(Err(make_error()));
    }
    self.buffered_command = None;
    if self.timeout_time_ms.is_some() {
        self.timeout_id_to_key.clear();
        self.timeout_queue.clear();
    }
    if !self.is_command_receiver_closed {
        loop {
            match self.command_receiver.try_next() {
                // A queued command: reject it.
                Ok(Some((_, callback_sender))) => {
                    let _res = callback_sender.send(Err(make_error()));
                }
                // Channel closed and fully drained: remember it.
                Ok(None) => {
                    self.is_command_receiver_closed = true;
                    break;
                }
                // Nothing queued right now, channel still open.
                Err(_) => break,
            }
        }
    }
}
/// Handles one item taken from the command channel.
///
/// A previously buffered request is flushed first. For `Some` command a new
/// request id is allocated, the callback is registered, a deadline is added
/// to the timeout queue when a timeout is configured, and the request is
/// written to `sink`.
///
/// # Errors
/// * the error of `try_send_buffered_command` when writing to `sink` fails;
/// * `io::ErrorKind::InvalidInput` when `command` is `None`, i.e. the command
///   channel is closed; `is_command_receiver_closed` is set in that case.
async fn process_command(
    &mut self,
    command: Option<(CommandPacket, CallbackSender)>,
    sink: &mut TarantoolFramed,
) -> io::Result<()> {
    self.try_send_buffered_command(sink).await?;

    let (command_packet, callback_sender) = match command {
        Some(pair) => pair,
        None => {
            self.is_command_receiver_closed = true;
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "inbound commands queue is over",
            ));
        }
    };

    let request_id = self.increment_command_counter();
    self.awaiting_callbacks.insert(request_id, callback_sender);
    self.buffered_command = Some((request_id, TarantoolRequest::Command(command_packet)));

    if let Some(timeout_time_ms) = self.timeout_time_ms {
        let deadline = Instant::now() + Duration::from_millis(timeout_time_ms);
        let delay_key = self.timeout_queue.insert_at(request_id, deadline);
        self.timeout_id_to_key.insert(request_id, delay_key);
    }

    self.try_send_buffered_command(sink).await
}
/// Routes one item read from the server stream to the callback waiting for it.
///
/// For a decoded reply (successful or an error reply) the request's deadline
/// is cancelled and the result is handed to the registered callback, if any.
/// A reply with an unknown request id is ignored.
///
/// # Errors
/// Returns `io::ErrorKind::ConnectionAborted` when the stream has ended
/// (`None`).
async fn process_tarantool_response(
    &mut self,
    response: Option<io::Result<(RequestId, io::Result<TarantoolResponse>)>>,
) -> io::Result<()> {
    debug!("receive command! {:?} ", response);

    let (request_id, outcome) = match response {
        Some(Ok(reply)) => reply,
        // NOTE(review): an error item from the stream is deliberately ignored
        // here (it is only visible in the debug log above).
        Some(Err(_)) => return Ok(()),
        None => {
            return Err(io::Error::new(
                io::ErrorKind::ConnectionAborted,
                "return none from stream!",
            ))
        }
    };

    match &outcome {
        Ok(command_packet) => debug!("receive command! {} {:?} ", request_id, command_packet),
        Err(e) => debug!("receive command! {} {:?} ", request_id, e),
    }

    if self.timeout_time_ms.is_some() {
        if let Some(delay_key) = self.timeout_id_to_key.remove(&request_id) {
            self.timeout_queue.remove(&delay_key);
        }
    }
    if let Some(callback) = self.awaiting_callbacks.remove(&request_id) {
        // The receiver may be gone already; that is not an error.
        let _send_res = callback.send(outcome);
    }
    Ok(())
}
/// Advances the request counter by one and returns the new value, which
/// becomes the id of the next request.
fn increment_command_counter(&mut self) -> RequestId {
    let next_id = self.command_counter + 1;
    self.command_counter = next_id;
    next_id
}
/// Resets the request counter to its initial value, 3 (the same value `new`
/// starts with). Called after a successful authentication, which itself
/// uses request id 2.
fn clean_command_counter(&mut self) {
self.command_counter = 3;
}
/// Stores `client_status` in the shared status cell and then notifies all
/// status subscribers.
///
/// The cell is updated before the notifications go out, so a subscriber that
/// reacts to a notification by reading the shared status never sees the
/// previous state. The write guard is released before the first `.await`.
///
/// # Panics
/// Panics if the status lock or the subscriber mutex is poisoned.
async fn set_status(&mut self, client_status: ClientStatus) {
    *self.status.write().unwrap() = client_status.clone();
    self.send_notify(&client_status).await;
}
/// Drives the dispatcher until the command channel is closed.
///
/// Publishes `Connecting` once, then repeatedly connects and processes
/// commands. Whenever the connection loop fails, `Disconnected` (with the
/// error text) is published and every outstanding request is failed with the
/// same text. After that:
///
/// * if the command channel has been closed, `Closed` is published and the
///   method returns immediately, without waiting out the reconnect pause;
/// * otherwise the dispatcher sleeps for `reconnect_time_ms` and tries again.
pub async fn run(&mut self) {
    self.set_status(ClientStatus::Connecting).await;
    loop {
        match self.connect_and_process_commands().await {
            Ok(()) => return,
            Err(e) => {
                let description = e.to_string();
                self.set_status(ClientStatus::Disconnected(description.clone()))
                    .await;
                self.send_error_to_all(description);
            }
        }
        // No client is left to serve: shut down instead of sleeping first.
        if self.is_command_receiver_closed {
            self.set_status(ClientStatus::Closed).await;
            return;
        }
        let retry_at = Instant::now() + Duration::from_millis(self.config.reconnect_time_ms);
        sleep_until(retry_at).await;
    }
}
/// Opens a TCP connection to the configured address, authenticates, publishes
/// `Connected` and then serves the connection until something fails.
///
/// The loop itself never breaks, so this function only returns through one of
/// the `?` operators, i.e. with an error: a connect or authentication
/// failure, the end of the server stream, a failed write, or the closure of
/// the command channel.
///
/// NOTE(review): no branch of the `select!` below polls `timeout_queue`, so
/// within this function the registered deadlines are never acted upon —
/// confirm whether request expiry is handled elsewhere.
async fn connect_and_process_commands(&mut self) -> io::Result<()> {
let tcp_stream = TcpStream::connect(self.config.addr.clone()).await?;
let mut framed_io = self.auth(tcp_stream).await?;
self.set_status(ClientStatus::Connected).await;
loop {
// Wait for whichever comes first: a reply from the server or a new command.
select! {
tarantool_response = framed_io.next().fuse() => {
self.process_tarantool_response(tarantool_response).await?
}
command = self.command_receiver.next() => {
self.process_command(command, &mut framed_io).await?
}
}
}
}
/// Wraps `tcp_stream` in the Tarantool codec and authenticates with the
/// configured login and password, using request id 2.
///
/// The first item read from the stream (presumably the server greeting —
/// confirm against the codec) is discarded. On success the request counter is
/// reset and the framed stream is returned.
///
/// # Errors
/// * the send error when the authentication packet cannot be written;
/// * `io::ErrorKind::PermissionDenied` when the server answers with an error;
/// * the transport/codec error when reading the answer fails;
/// * `io::ErrorKind::ConnectionAborted` when the stream ends before an answer
///   arrives.
///
/// The last two cases used to be treated as a successful login, which made
/// the dispatcher report `Connected` on a connection that was already dead.
async fn auth(&mut self, tcp_stream: TcpStream) -> io::Result<TarantoolFramed> {
    let codec: TarantoolCodec = Default::default();
    let mut framed_io = codec.framed(tcp_stream);
    let _first_response = framed_io.next().await;

    let auth_request = TarantoolRequest::Auth(AuthPacket {
        login: self.config.login.clone(),
        password: self.config.password.clone(),
    });
    framed_io.send((2, auth_request)).await?;

    match framed_io.next().await {
        Some(Ok((_, Ok(_)))) => {
            self.clean_command_counter();
            Ok(framed_io)
        }
        Some(Ok((_, Err(e)))) => Err(io::Error::new(
            io::ErrorKind::PermissionDenied,
            e.to_string(),
        )),
        Some(Err(e)) => Err(e),
        None => Err(io::Error::new(
            io::ErrorKind::ConnectionAborted,
            "connection closed during authentication",
        )),
    }
}
}