use super::tls::TlsConnector;
use super::AsClient;
use crate::error::Error;
use crate::fiber::r#async::Mutex;
use crate::network::client::ClientError;
use crate::network::protocol;
use std::rc::Rc;
use std::sync::Arc;
#[cfg(feature = "internal_test")]
use std::sync::atomic::{AtomicUsize, Ordering};
/// Cached outcome of the most recent connection attempt: either a connected
/// client or the error that ended the attempt. The error is kept behind an
/// `Arc` so it can be cloned out to every caller that observes it.
type ClientOrConnectionClosedError = Result<super::Client, Arc<Error>>;
/// A client that connects lazily on its first request and can be asked to
/// reconnect.
///
/// Cloning is cheap. All clones share the same connection state because it
/// lives behind an `Rc`, so a reconnect requested through one clone is
/// observed by all of them.
#[derive(Debug, Clone)]
pub struct Client {
    /// Shared connection state:
    /// - `None`: not connected yet, or a reconnect was requested;
    /// - `Some(Ok(_))`: a live client that is reused by every request;
    /// - `Some(Err(_))`: a cached connection failure, returned until a reconnect.
    client: Rc<Mutex<Option<ClientOrConnectionClosedError>>>,
    /// Address passed to the underlying connect call (the tests use host
    /// names such as `localhost`).
    url: String,
    /// Port passed to the underlying connect call.
    port: u16,
    /// Protocol settings (e.g. credentials and auth method) used for every
    /// connection attempt.
    protocol_config: protocol::Config,
    /// Optional TLS connector forwarded to `connect_with_config_and_tls`.
    tls_connector: Option<TlsConnector>,
    /// Test-only: an error that the next `send` returns instead of
    /// performing the request (it is consumed with `take`).
    #[cfg(feature = "internal_test")]
    inject_error: Rc<std::cell::RefCell<Option<ClientError>>>,
    /// Test-only: number of connection attempts made so far.
    #[cfg(feature = "internal_test")]
    reconnect_count: Rc<AtomicUsize>,
}
impl Client {
    /// Returns the underlying client, connecting first when no connection
    /// state is cached (initially, or after [`Client::reconnect`]).
    ///
    /// The state mutex stays locked for the whole connection attempt. Callers
    /// that arrive in the meantime therefore wait for that single attempt and
    /// reuse its outcome instead of opening connections of their own.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::ConnectionClosed`] if the connection attempt
    /// fails. The failure is cached: later calls return the same error
    /// without retrying until [`Client::reconnect`] clears it.
    ///
    /// # Panics
    ///
    /// Panics if `connect_with_config_and_tls` returns an error other than
    /// [`ClientError::ConnectionClosed`], which it is expected never to do.
    async fn client(&self) -> Result<super::Client, ClientError> {
        let mut state = self.client.lock().await;
        match &*state {
            Some(Ok(client)) => return Ok(client.clone()),
            Some(Err(e)) => return Err(ClientError::ConnectionClosed(e.clone())),
            None => {}
        }

        #[cfg(feature = "internal_test")]
        {
            self.reconnect_count.fetch_add(1, Ordering::Relaxed);
        }

        // `state` stays locked across this await on purpose: see the doc
        // comment above.
        let res = super::Client::connect_with_config_and_tls(
            &self.url,
            self.port,
            self.protocol_config.clone(),
            self.tls_connector.clone(),
        )
        .await;
        match res {
            Ok(new_client) => {
                *state = Some(Ok(new_client.clone()));
                Ok(new_client)
            }
            Err(ClientError::ConnectionClosed(e)) => {
                *state = Some(Err(e.clone()));
                Err(ClientError::ConnectionClosed(e))
            }
            Err(_) => unreachable!(
                "Client::connect_with_config_and_tls should only return `ConnectionClosed` errors"
            ),
        }
    }

    /// Requests a fresh connection for the next request.
    ///
    /// This neither blocks nor connects by itself. It only drops the cached
    /// connection state, so the next request connects anew. Requests that
    /// have already obtained the old client keep using it.
    ///
    /// If a connection attempt is currently in progress, this is a no-op and
    /// subsequent requests use the outcome of that in-flight attempt. Within
    /// this type, only `client()` holds the state lock across an await, and
    /// only while it is connecting. If that attempt fails, its error is
    /// cached and another `reconnect` is needed.
    pub fn reconnect(&self) {
        if let Some(mut state) = self.client.try_lock() {
            *state = None;
        }
        // Otherwise `client()` is connecting right now; the connection it
        // produces is already a new one, so there is nothing to reset.
    }

    /// Drops the cached connection state and connects right away, waiting
    /// for the outcome. If a connection attempt is already in progress, this
    /// waits for that attempt instead.
    ///
    /// # Errors
    ///
    /// Returns the connection error if no connection could be established.
    pub async fn reconnect_now(&self) -> Result<(), Error> {
        self.reconnect();
        self.client().await?;
        Ok(())
    }

    /// Creates a client for `url`:`port` with the default protocol config
    /// and without TLS. No connection is made until the first request.
    pub fn new(url: String, port: u16) -> Self {
        Self::with_config(url, port, Default::default())
    }

    /// Creates a client with a custom protocol config and without TLS.
    /// No connection is made until the first request.
    pub fn with_config(url: String, port: u16, config: protocol::Config) -> Self {
        // Delegate so that the struct literal exists in exactly one place.
        Self::with_config_and_tls(url, port, config, None)
    }

    /// Creates a client with a custom protocol config and an optional TLS
    /// connector. No connection is made until the first request.
    pub fn with_config_and_tls(
        url: String,
        port: u16,
        config: protocol::Config,
        tls_connector: Option<TlsConnector>,
    ) -> Self {
        Self {
            client: Default::default(),
            url,
            port,
            protocol_config: config,
            tls_connector,
            #[cfg(feature = "internal_test")]
            inject_error: Default::default(),
            #[cfg(feature = "internal_test")]
            reconnect_count: Default::default(),
        }
    }

    /// Test-only: the number of reconnections performed so far, i.e. the
    /// connection attempts made by `client()` excluding the very first one.
    #[cfg(feature = "internal_test")]
    pub fn reconnect_count(&self) -> usize {
        self.reconnect_count
            .load(Ordering::Relaxed)
            .saturating_sub(1)
    }
}
#[async_trait::async_trait(?Send)]
impl AsClient for Client {
    /// Sends `request` over the shared connection, connecting first if no
    /// connection state is cached.
    ///
    /// With the `internal_test` feature, a pending injected error is consumed
    /// and returned instead of sending the request. This happens after the
    /// connection has been obtained.
    async fn send<R: protocol::api::Request>(
        &self,
        request: &R,
    ) -> Result<R::Response, ClientError> {
        let conn = self.client().await?;

        #[cfg(feature = "internal_test")]
        {
            // `take()` clears the slot, so each injected error fires only once.
            if let Some(injected) = self.inject_error.borrow_mut().take() {
                return Err(injected);
            }
        }

        conn.send(request).await
    }
}
// Integration tests for the reconnecting `Client`. They run through
// `#[crate::test(tarantool = "crate")]`, presumably inside a Tarantool
// instance that also listens on `listen_port()`. Verify this against
// `crate::test::util`.
#[cfg(feature = "internal_test")]
mod tests {
    use super::*;
    use crate::fiber;
    use crate::fiber::r#async::timeout::IntoTimeout as _;
    use crate::test::util::listen_port;
    #[cfg(feature = "picodata")]
    use crate::test::util::{get_tls_connector, tls_listen_port};
    use std::time::Duration;

    // Per-request timeout, so a hung connection fails the test instead of
    // blocking it forever.
    const _3_SEC: Duration = Duration::from_secs(3);

    // Plain (non-TLS) client authenticating as `test_user` via CHAP-SHA1.
    fn test_client() -> Client {
        Client::with_config(
            "localhost".into(),
            listen_port(),
            protocol::Config {
                creds: Some(("test_user".into(), "password".into())),
                auth_method: crate::auth::AuthMethod::ChapSha1,
                ..Default::default()
            },
        )
    }

    // Same credentials as `test_client`, but over TLS on the TLS listen port.
    #[cfg(feature = "picodata")]
    fn test_tls_client() -> Client {
        Client::with_config_and_tls(
            "127.0.0.1".into(),
            tls_listen_port(),
            protocol::Config {
                creds: Some(("test_user".into(), "password".into())),
                auth_method: crate::auth::AuthMethod::ChapSha1,
                ..Default::default()
            },
            Some(get_tls_connector()),
        )
    }

    // Connecting to port 0 must fail. The OS error text differs between
    // environments, so any of the known messages is accepted.
    #[crate::test(tarantool = "crate")]
    async fn connect_failure() {
        let client = Client::new("localhost".into(), 0);
        let err = client.ping().await.unwrap_err();
        let correct_err = [
            "failed to connect to address 'localhost:0': Connection refused (os error 111)",
            "failed to connect to address 'localhost:0': Cannot assign requested address (os error 99)",
            "failed to connect to address 'localhost:0': Can't assign requested address (os error 49)",
        ].contains(dbg!(&&*err.to_string()));
        assert!(correct_err);
    }

    // The first ping opens the initial connection, which `reconnect_count`
    // does not count. After `reconnect()`, exactly one new connection is
    // opened and then reused by the following ping.
    async fn test_ping_after_reconnect(client: Client) {
        for _ in 0..2 {
            client.ping().timeout(_3_SEC).await.unwrap();
        }
        assert_eq!(client.reconnect_count(), 0);
        client.reconnect();
        for _ in 0..2 {
            client.ping().timeout(_3_SEC).await.unwrap();
        }
        assert_eq!(client.reconnect_count(), 1);
    }

    #[crate::test(tarantool = "crate")]
    async fn ping_after_reconnect() {
        let client = test_client();
        test_ping_after_reconnect(client).await;
    }

    #[cfg(feature = "picodata")]
    #[crate::test(tarantool = "crate")]
    async fn tls_ping_after_reconnect() {
        let client = test_tls_client();
        test_ping_after_reconnect(client).await;
    }

    // `reconnect()` is lazy: nothing happens until the next request.
    // `reconnect_now()` connects immediately.
    #[crate::test(tarantool = "crate")]
    async fn reconnect_now_vs_later() {
        let client = test_client();
        client.ping().timeout(_3_SEC).await.unwrap();
        assert_eq!(client.reconnect_count(), 0);
        client.reconnect();
        assert_eq!(client.reconnect_count(), 0);
        client.ping().timeout(_3_SEC).await.unwrap();
        assert_eq!(client.reconnect_count(), 1);
        client.reconnect_now().await.unwrap();
        assert_eq!(client.reconnect_count(), 2);
    }

    // Simulates a network failure by injecting a `ConnectionClosed` error
    // into the next `send`, then checks that `reconnect_now` opens a fresh
    // connection. This is done twice.
    fn test_reconnect_on_network_error(client: Client) {
        use std::io::{Error as IOError, ErrorKind};
        use std::sync::Arc;
        fiber::block_on(async {
            let err = ClientError::ConnectionClosed(Arc::new(
                IOError::from(ErrorKind::ConnectionAborted).into(),
            ));
            *client.inject_error.borrow_mut() = Some(err);
            client.ping().timeout(_3_SEC).await.unwrap_err();
            client.reconnect_now().await.unwrap();
            assert_eq!(client.reconnect_count(), 1);
            let err = ClientError::ConnectionClosed(Arc::new(
                IOError::from(ErrorKind::ConnectionAborted).into(),
            ));
            *client.inject_error.borrow_mut() = Some(err);
            client.ping().timeout(_3_SEC).await.unwrap_err();
            client.reconnect_now().await.unwrap();
            assert_eq!(client.reconnect_count(), 2);
        });
    }

    #[crate::test(tarantool = "crate")]
    fn reconnect_on_network_error() {
        let client = test_client();
        test_reconnect_on_network_error(client);
    }

    #[cfg(feature = "picodata")]
    #[crate::test(tarantool = "crate")]
    fn tls_reconnect_on_network_error() {
        let client = test_tls_client();
        test_reconnect_on_network_error(client);
    }

    // An `eval` still waiting for its reply keeps using the old connection
    // while another fiber reconnects, and the reply still arrives. The eval
    // blocks on a Lua channel that is filled only after the reconnect.
    // Presumably this works because the test instance is itself the server
    // the client connects to.
    #[crate::test(tarantool = "crate")]
    fn old_connection_remains_for_old_request() {
        let lua = crate::global_lua();
        lua.exec(
            "fiber = require('fiber')
_G.reconnect_test_chan = fiber.channel()",
        )
        .unwrap();
        let client = test_client();
        fiber::block_on(client.ping()).unwrap();
        assert_eq!(client.reconnect_count(), 0);
        let client_clone = client.clone();
        // Deferred work: reconnect through a clone (which shares state), then
        // release the blocked eval.
        let jh = fiber::defer_async(async move {
            client_clone.reconnect_now().await.unwrap();
            assert_eq!(client_clone.reconnect_count(), 1);
            lua.exec("_G.reconnect_test_chan:put(42)").unwrap();
        });
        fiber::block_on(async move {
            let result = client
                .eval("return _G.reconnect_test_chan:get()", &())
                .await
                .unwrap()
                .decode::<(i32,)>()
                .unwrap();
            assert_eq!(result, (42,));
            assert_eq!(client.reconnect_count(), 1);
        });
        jh.join();
    }

    // Ten pings are polled concurrently from one fiber. They all wait on the
    // same state mutex, so only one connection is established and shared.
    async fn test_concurrent_messages_one_fiber(client: Client) {
        let mut ping_futures = vec![];
        for _ in 0..10 {
            ping_futures.push(client.ping());
        }
        for res in futures::future::join_all(ping_futures).await {
            res.unwrap();
        }
    }

    #[crate::test(tarantool = "crate")]
    async fn concurrent_messages_one_fiber() {
        let client = test_client();
        test_concurrent_messages_one_fiber(client).await;
    }

    #[cfg(feature = "picodata")]
    #[crate::test(tarantool = "crate")]
    async fn tls_concurrent_messages_one_fiber() {
        let client = test_tls_client();
        test_concurrent_messages_one_fiber(client).await;
    }

    // A failed connection is cached. After a single `reconnect()`, repeated
    // failing requests trigger exactly one new connection attempt.
    #[crate::test(tarantool = "crate")]
    async fn try_reconnect_only_once() {
        let client = Client::new("localhost".into(), 0);
        client.ping().await.unwrap_err();
        assert_eq!(client.reconnect_count(), 0);
        client.reconnect();
        for _ in 0..10 {
            client.ping().await.unwrap_err();
        }
        assert_eq!(client.reconnect_count(), 1);
    }
}