use core::cell::RefCell;
use std::cell::Cell;
use std::io::{self, Cursor, Read, Write};
use std::net::SocketAddr;
use std::rc::{Rc, Weak};
use std::time::Duration;
use crate::coio::CoIOStream;
use crate::error::Error;
use crate::fiber::{is_cancelled, set_cancellable, sleep, time, Cond, Fiber};
use crate::net_box::stream::ConnStream;
use crate::tuple::Decode;
use crate::unwrap_or;
use super::options::{ConnOptions, ConnTriggers, Options};
use super::promise::Promise;
use super::protocol::{self, Header, Request};
use super::recv_queue::RecvQueue;
use super::schema::ConnSchema;
use super::send_queue::{self, SendQueue};
use super::Conn;
/// Lifecycle state of a [`ConnInner`].
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ConnState {
    /// No connection attempt has been made yet.
    Init,
    /// The socket is being connected and the server greeting read.
    Connecting,
    /// Authenticating with the configured user credentials.
    Auth,
    /// Connected; requests can be sent and responses received.
    Active,
    /// A non-I/O error occurred; the next request closes the connection.
    Error,
    /// An I/O error occurred; the next request reconnects, or fails if reconnecting is disabled.
    ErrorReconnect,
    /// The connection has been closed; requests fail with `NotConnected`.
    Closed,
}
/// Shared state of a single net.box connection: the socket stream, the request/response
/// queues, the two worker fibers and the lifecycle state they coordinate on.
pub struct ConnInner {
    /// Addresses to connect to. All of them are passed to `CoIOStream::connect`; only the
    /// first one is used when a non-zero connect timeout is configured.
    addrs: Vec<SocketAddr>,
    /// Connection options (credentials, timeouts, buffer sizes, reconnect interval).
    options: ConnOptions,
    /// Current lifecycle state; change it only through `update_state` so waiters are woken.
    state: Cell<ConnState>,
    /// Broadcast on every state change.
    state_change_cond: Cond,
    /// Schema cache obtained from `ConnSchema::acquire` for `addrs`.
    // NOTE(review): presumably shared between connections to the same addresses — confirm.
    schema: Rc<ConnSchema>,
    /// Schema version from the header of the last successfully received response.
    schema_version: Cell<Option<u32>>,
    /// Socket stream; `Some` from a successful connect until `disconnect`.
    stream: RefCell<Option<ConnStream>>,
    /// Outgoing request buffer, flushed to `stream` by the send worker.
    send_queue: SendQueue,
    /// Pending responses, filled from `stream` by the receive worker.
    recv_queue: RecvQueue,
    /// Joinable fiber running `send_worker`.
    send_fiber: RefCell<Fiber<'static, Weak<ConnInner>>>,
    /// Joinable fiber running `recv_worker`.
    recv_fiber: RefCell<Fiber<'static, Weak<ConnInner>>>,
    /// User callbacks; taken (set to `None`) when the connection is closed.
    triggers: RefCell<Option<Rc<dyn ConnTriggers>>>,
    /// Last I/O error, stored by `handle_error` and consumed by `reconnect_or_fail`.
    error: RefCell<Option<io::Error>>,
}
impl ConnInner {
    /// Creates the shared connection state and starts its worker fibers.
    ///
    /// The connection starts in [`ConnState::Init`]; no network I/O happens here. The first
    /// call to [`Self::wait_connected`], [`Self::request`] or [`Self::request_async`] connects.
    /// Both worker fibers are joinable and receive only a [`Weak`] reference to the state.
    pub fn new(
        addrs: Vec<SocketAddr>,
        options: ConnOptions,
        triggers: Option<Rc<dyn ConnTriggers>>,
    ) -> Rc<Self> {
        let mut recv_fiber = Fiber::new("_recv_worker", &mut recv_worker);
        recv_fiber.set_joinable(true);
        let mut send_fiber = Fiber::new("_send_worker", &mut send_worker);
        send_fiber.set_joinable(true);
        let conn_inner = Rc::new(ConnInner {
            state: Cell::new(ConnState::Init),
            state_change_cond: Cond::new(),
            schema: ConnSchema::acquire(&addrs),
            schema_version: Cell::new(None),
            stream: RefCell::new(None),
            send_queue: SendQueue::new(
                options.send_buffer_size,
                options.send_buffer_limit,
                options.send_buffer_flush_interval,
            ),
            recv_queue: RecvQueue::new(options.recv_buffer_size),
            send_fiber: RefCell::new(send_fiber),
            recv_fiber: RefCell::new(recv_fiber),
            triggers: RefCell::new(triggers),
            error: RefCell::new(None),
            addrs,
            options,
        });
        conn_inner
            .send_fiber
            .borrow_mut()
            .start(Rc::downgrade(&conn_inner));
        conn_inner
            .recv_fiber
            .borrow_mut()
            .start(Rc::downgrade(&conn_inner));
        conn_inner
    }

    /// Returns `true` while the connection is in [`ConnState::Active`].
    pub fn is_connected(&self) -> bool {
        matches!(self.state.get(), ConnState::Active)
    }

    /// Waits until the connection is active or closed, connecting first if no attempt has been
    /// made yet.
    ///
    /// Returns `Ok(true)` once the connection is active and `Ok(false)` if it is closed. In any
    /// other state this only waits for some other fiber to change the state.
    ///
    /// # Errors
    ///
    /// Returns the error of the initial connection attempt if it is not an I/O error, or an
    /// [`io::ErrorKind::TimedOut`] error if `timeout` elapses first.
    pub fn wait_connected(self: &Rc<Self>, timeout: Option<Duration>) -> Result<bool, Error> {
        let begin_ts = time();
        loop {
            match self.state.get() {
                ConnState::Init => self.init()?,
                ConnState::Active => return Ok(true),
                ConnState::Closed => return Ok(false),
                _ => {
                    let remaining = match timeout {
                        None => None,
                        Some(timeout) => {
                            let elapsed = Duration::from_secs_f64(time() - begin_ts);
                            // `checked_sub` yields `None` once the deadline has passed; that is
                            // a timeout, not a request to wait without one.
                            match timeout.checked_sub(elapsed) {
                                Some(remaining) => Some(remaining),
                                None => {
                                    return Err(io::Error::from(io::ErrorKind::TimedOut).into())
                                }
                            }
                        }
                    };
                    if !self.wait_state_changed(remaining) {
                        return Err(io::Error::from(io::ErrorKind::TimedOut).into());
                    }
                }
            }
        }
    }

    /// Sends one request and waits for its response.
    ///
    /// `request_producer` writes the request into the send buffer; its `u64` argument is
    /// presumably the request's sync id. `response_consumer` decodes the response body given
    /// its header. `options` is passed on to the receive queue.
    ///
    /// The connection is established lazily. It is closed when in [`ConnState::Error`] and
    /// reconnected (or failed) when in [`ConnState::ErrorReconnect`]. The schema version of
    /// every successfully received response is remembered for [`Self::refresh_schema`].
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::NotConnected`] once the connection is closed, the send error if
    /// the request could not be queued, or whatever the receive queue reports.
    pub fn request<Fp, Fc, R>(
        self: &Rc<Self>,
        request_producer: Fp,
        response_consumer: Fc,
        options: &Options,
    ) -> Result<R, Error>
    where
        Fp: FnOnce(&mut Cursor<Vec<u8>>, u64) -> Result<(), Error>,
        Fc: FnOnce(&mut Cursor<Vec<u8>>, &Header) -> Result<R, Error>,
    {
        loop {
            match self.state.get() {
                ConnState::Init => self.init()?,
                ConnState::Active => {
                    // The producer is consumed here, so every path below returns.
                    let sync = self
                        .send_queue
                        .send(request_producer)
                        .map_err(|err| self.send_error(err))?;
                    return self
                        .recv_queue
                        .recv(sync, response_consumer, options)
                        .map(|response| {
                            self.schema_version
                                .set(Some(response.header.schema_version));
                            response.payload
                        });
                }
                ConnState::Error => self.disconnect(),
                ConnState::ErrorReconnect => self.reconnect_or_fail()?,
                ConnState::Closed => {
                    return Err(io::Error::from(io::ErrorKind::NotConnected).into())
                }
                _ => {
                    self.wait_state_changed(None);
                }
            }
        }
    }

    /// Queues `request` and returns a [`Promise`] that is fulfilled when its response arrives.
    ///
    /// Connection handling is the same as in [`Self::request`].
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::NotConnected`] once the connection is closed, or the send error
    /// if the request could not be queued.
    pub(crate) fn request_async<I, O>(self: &Rc<Self>, request: I) -> crate::Result<Promise<O>>
    where
        I: Request,
        O: for<'de> Decode<'de> + 'static,
    {
        loop {
            match self.state.get() {
                ConnState::Init => self.init()?,
                ConnState::Active => {
                    // The request is consumed here, so every path below returns.
                    let sync = self
                        .send_queue
                        .send(protocol::request_producer(request))
                        .map_err(|err| self.send_error(err))?;
                    let promise = Promise::new(Rc::downgrade(self));
                    self.recv_queue.add_consumer(sync, promise.downgrade());
                    return Ok(promise);
                }
                ConnState::Error => self.disconnect(),
                ConnState::ErrorReconnect => self.reconnect_or_fail()?,
                ConnState::Closed => {
                    return Err(io::Error::from(io::ErrorKind::NotConnected).into())
                }
                _ => {
                    self.wait_state_changed(None);
                }
            }
        }
    }

    /// Looks up a space id by name, refreshing the schema cache first if needed.
    pub fn lookup_space(self: &Rc<Self>, name: &str) -> Result<Option<u32>, Error> {
        self.refresh_schema()?;
        Ok(self.schema.lookup_space(name))
    }

    /// Looks up an index id by name within `space_id`, refreshing the schema cache first if needed.
    pub fn lookup_index(self: &Rc<Self>, name: &str, space_id: u32) -> Result<Option<u32>, Error> {
        self.refresh_schema()?;
        Ok(self.schema.lookup_index(name, space_id))
    }

    /// Closes the connection and stops both worker fibers.
    ///
    /// An in-progress connection attempt is allowed to finish first. Does nothing if the
    /// connection is already closed.
    pub fn close(self: &Rc<Self>) {
        // Wait only while an attempt is in progress. Waiting for `Active` would hang forever
        // if the attempt ends in `Error`/`ErrorReconnect`.
        while matches!(self.state.get(), ConnState::Connecting | ConnState::Auth) {
            self.wait_state_changed(None);
        }
        if matches!(self.state.get(), ConnState::Closed) {
            return;
        }
        self.disconnect();
        let mut send_fiber = self.send_fiber.borrow_mut();
        send_fiber.cancel();
        send_fiber.join();
        let mut recv_fiber = self.recv_fiber.borrow_mut();
        recv_fiber.cancel();
        recv_fiber.join();
    }

    /// Performs the first connection attempt. Failures are routed through [`Self::handle_error`].
    fn init(self: &Rc<Self>) -> Result<(), Error> {
        self.connect().or_else(|err| self.handle_error(err))
    }

    /// Returns `options.connect_timeout`, or `None` when it is zero, which means "no timeout".
    fn effective_connect_timeout(&self) -> Option<Duration> {
        let timeout = self.options.connect_timeout;
        if timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
            None
        } else {
            Some(timeout)
        }
    }

    /// Opens the socket, reads the greeting, authenticates if a user is configured, installs
    /// the stream and fires the `on_connect` trigger.
    ///
    /// With a connect timeout only the first address is tried; without one, all addresses are
    /// handed to `CoIOStream::connect`.
    fn connect(self: &Rc<Self>) -> Result<(), Error> {
        self.update_state(ConnState::Connecting);
        let mut stream = match self.effective_connect_timeout() {
            None => CoIOStream::connect(&*self.addrs)?,
            Some(timeout) => {
                let addr = self.addrs.first().ok_or_else(|| {
                    io::Error::new(io::ErrorKind::InvalidInput, "no address to connect to")
                })?;
                CoIOStream::connect_timeout(addr, timeout)?
            }
        };
        let salt = protocol::decode_greeting(&mut stream)?;
        if !self.options.user.is_empty() {
            self.update_state(ConnState::Auth);
            self.auth(&mut stream, &salt)?;
        }
        self.stream.replace(Some(ConnStream::new(stream)?));
        self.update_state(ConnState::Active);
        // Clone the handle so the `RefCell` is not borrowed while user code runs. A trigger
        // that closes the connection would otherwise panic in `disconnect`.
        let triggers = self.triggers.borrow().clone();
        if let Some(triggers) = triggers {
            triggers.on_connect(&Conn::downgrade(self.clone()))?;
        }
        Ok(())
    }

    /// Sends an auth request over `stream` and checks the server's reply.
    ///
    /// # Errors
    ///
    /// Returns I/O and decoding errors, or the server's error if the reply status is non-zero.
    fn auth(&self, stream: &mut CoIOStream, salt: &[u8]) -> Result<(), Error> {
        let mut cur = Cursor::new(Vec::new());
        let sync = self.send_queue.next_sync();
        send_queue::write_to_buffer(&mut cur, sync, |buf, sync| {
            protocol::encode_auth(
                buf,
                self.options.user.as_str(),
                self.options.password.as_str(),
                salt,
                sync,
            )
        })?;
        stream.write_all(cur.get_ref())?;
        // The reply is length-prefixed. Read exactly that many bytes into the reused buffer.
        let response_len = rmp::decode::read_u32(stream)?;
        {
            let buffer = cur.get_mut();
            buffer.clear();
            buffer.reserve(response_len as usize);
            stream.take(u64::from(response_len)).read_to_end(buffer)?;
            cur.set_position(0);
        }
        let header = protocol::decode_header(&mut cur)?;
        if header.status_code != 0 {
            // The error body is already buffered in `cur`. Reading it from `stream` would
            // consume bytes of the next message, or block waiting for them.
            return Err(protocol::decode_error(&mut cur)?.into());
        }
        Ok(())
    }

    /// Waits for the connection, then refreshes the schema cache and fires `on_schema_reload`
    /// if the cache changed.
    fn refresh_schema(self: &Rc<Self>) -> Result<(), Error> {
        // A zero connect timeout means "no timeout", matching `connect`.
        self.wait_connected(self.effective_connect_timeout())?;
        if self.schema.refresh(self, self.schema_version.get())? {
            // Do not hold the `RefCell` borrow while user code runs.
            let triggers = self.triggers.borrow().clone();
            if let Some(triggers) = triggers {
                triggers.on_schema_reload(&Conn::downgrade(self.clone()));
            }
        }
        Ok(())
    }

    /// Sets the state and wakes every fiber waiting in [`Self::wait_state_changed`].
    fn update_state(&self, state: ConnState) {
        self.state.set(state);
        self.state_change_cond.broadcast();
    }

    /// Blocks the current fiber until the state-change condition is broadcast or `timeout`
    /// elapses. A `false` result is treated as a timeout by [`Self::wait_connected`].
    fn wait_state_changed(&self, timeout: Option<Duration>) -> bool {
        match timeout {
            Some(timeout) => self.state_change_cond.wait_timeout(timeout),
            None => self.state_change_cond.wait(),
        }
    }

    /// Moves the connection into an error state according to `err`.
    ///
    /// I/O errors are stored for [`Self::reconnect_or_fail`], move the state to
    /// [`ConnState::ErrorReconnect`] and yield `Ok(())`. Any other error moves the state to
    /// [`ConnState::Error`] and is returned. Once the connection is closed, `err` is discarded
    /// and `Ok(())` returned.
    fn handle_error(&self, err: Error) -> Result<(), Error> {
        if matches!(self.state.get(), ConnState::Closed) {
            return Ok(());
        }
        match err {
            Error::IO(err) => {
                self.error.replace(Some(err));
                self.update_state(ConnState::ErrorReconnect);
                Ok(())
            }
            err => {
                self.update_state(ConnState::Error);
                Err(err)
            }
        }
    }

    /// Records a failed send via [`Self::handle_error`] and returns the error to report to the
    /// caller whose request was not sent.
    fn send_error(&self, err: Error) -> Error {
        // `handle_error` keeps I/O errors for reconnection, and swallows everything once
        // closed. Remember a kind so the caller still gets an error instead of a panic.
        let kind = match &err {
            Error::IO(io_err) => io_err.kind(),
            _ => io::ErrorKind::NotConnected,
        };
        match self.handle_error(err) {
            Err(err) => err,
            Ok(()) => io::Error::from(kind).into(),
        }
    }

    /// Handles [`ConnState::ErrorReconnect`]: fails with the stored I/O error if reconnecting
    /// is disabled (`reconnect_after` is zero); otherwise sleeps and reconnects.
    fn reconnect_or_fail(self: &Rc<Self>) -> Result<(), Error> {
        if matches!(self.state.get(), ConnState::Closed) {
            return Ok(());
        }
        let error = match self.error.replace(None) {
            Some(error) => error,
            None => {
                // Another fiber already took the error and is reconnecting (it yields in
                // `sleep`/`connect`). Wait for it to move the state on instead of panicking.
                self.wait_state_changed(None);
                return Ok(());
            }
        };
        let reconnect_after = self.options.reconnect_after;
        if reconnect_after.as_secs() == 0 && reconnect_after.subsec_nanos() == 0 {
            self.update_state(ConnState::Error);
            return Err(error.into());
        }
        sleep(reconnect_after);
        // The connection may have been closed while this fiber slept. Reconnecting would
        // revive it after its worker fibers have already exited.
        if matches!(self.state.get(), ConnState::Closed) {
            return Ok(());
        }
        if let Err(err) = self.connect() {
            self.handle_error(err)?;
        }
        Ok(())
    }

    /// Closes the connection: moves to [`ConnState::Closed`], closes both queues, drops the
    /// stream and fires `on_disconnect`. Does nothing if already closed.
    fn disconnect(&self) {
        if matches!(self.state.get(), ConnState::Closed) {
            return;
        }
        self.update_state(ConnState::Closed);
        if let Some(stream) = self.stream.borrow().as_ref() {
            // NOTE(review): presumably wakes `recv_worker` out of a blocking read so it can
            // observe `Closed` — confirm.
            if stream.is_reader_acquired() {
                self.recv_fiber.borrow().wakeup();
            }
        }
        self.recv_queue.close();
        self.send_queue.close();
        self.stream.replace(None);
        // Taking the triggers means they fire at most once and never after closing.
        if let Some(triggers) = self.triggers.replace(None) {
            triggers.on_disconnect();
        }
    }
}
/// Body of the `_send_worker` fiber: while the connection is active, flushes the send queue
/// to the connection stream.
///
/// Returns `0` when the fiber is cancelled, the connection has been dropped, or it is closed.
// NOTE(review): the boxed `Weak` argument presumably matches the entry-point signature
// expected by `Fiber::new`, hence the clippy allowances — confirm.
#[allow(clippy::redundant_allocation, clippy::boxed_local)]
fn send_worker(conn: Box<Weak<ConnInner>>) -> i32 {
    set_cancellable(true);
    let weak_conn = *conn;
    loop {
        if is_cancelled() {
            return 0;
        }
        let conn = unwrap_or!(weak_conn.upgrade(), return 0);
        match conn.state.get() {
            ConnState::Active => {
                let mut writer = conn.stream.borrow().as_ref().unwrap().acquire_writer();
                if let Err(e) = conn.send_queue.flush_to_stream(&mut writer) {
                    if is_cancelled() {
                        return 0;
                    }
                    // `handle_error` has already moved the connection into an error state. A
                    // returned (non-I/O) error must not panic inside the worker fiber; request
                    // callers observe it through that state instead.
                    let _ = conn.handle_error(e.into());
                }
            }
            ConnState::Closed => return 0,
            _ => {
                conn.wait_state_changed(None);
            }
        }
    }
}
/// Body of the `_recv_worker` fiber: while the connection is active, pulls incoming data from
/// the connection stream into the receive queue.
///
/// If `pull` reports that no data was pulled while the connection is still active, the
/// connection is closed. Returns `0` when the fiber is cancelled, the connection has been
/// dropped, or it is closed.
// NOTE(review): the boxed `Weak` argument presumably matches the entry-point signature
// expected by `Fiber::new`, hence the clippy allowances — confirm.
#[allow(clippy::redundant_allocation, clippy::boxed_local)]
fn recv_worker(conn: Box<Weak<ConnInner>>) -> i32 {
    set_cancellable(true);
    let weak_conn = *conn;
    loop {
        if is_cancelled() {
            return 0;
        }
        let conn = unwrap_or!(weak_conn.upgrade(), return 0);
        match conn.state.get() {
            ConnState::Active => {
                // Scope the reader so it is released before the result is handled.
                let result = {
                    let mut reader = conn.stream.borrow().as_ref().unwrap().acquire_reader();
                    conn.recv_queue.pull(&mut reader)
                };
                match result {
                    Err(e) => {
                        if is_cancelled() {
                            return 0;
                        }
                        // Non-I/O errors (e.g. decoding failures) come back from `handle_error`
                        // after it switched to `ConnState::Error`. Unwrapping would panic inside
                        // the worker fiber, so the error surfaces through that state instead.
                        let _ = conn.handle_error(e);
                    }
                    Ok(is_data_pulled) => {
                        if !is_data_pulled && conn.is_connected() {
                            conn.disconnect();
                        }
                    }
                }
            }
            ConnState::Closed => return 0,
            _ => {
                conn.wait_state_changed(None);
            }
        }
    }
}