use std::collections::HashMap;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::RwLock;
use chashmap::CHashMap;
use chashmap::WriteGuard;
use log::trace;
use crate::KfSocket;
use crate::KfSocketError;
/// A pool of [`KfSocket`] connections keyed by `T`.
///
/// Sockets are stored in a `CHashMap` and handed out as `WriteGuard`s by `get_socket`,
/// giving the caller exclusive access to that one entry while the guard is alive.
#[derive(Debug)]
pub struct SocketPool<T: Eq + Hash> {
    /// Registered sockets, one per key.
    clients: CHashMap<T, KfSocket>,
    /// Every key that has been registered through `insert_socket` (value is always `true`).
    ///
    /// NOTE(review): this map is written by `insert_socket` but not read anywhere in the code
    /// visible here — confirm whether it is still needed.
    ids: RwLock<HashMap<T, bool>>,
}
impl<T> SocketPool<T>
where
    T: Eq + PartialEq + Hash + Debug + Clone,
    KfSocket: Sync,
{
    /// Creates an empty pool with no registered sockets.
    // NOTE(review): presumably `new` has no non-test callers in this crate, hence the
    // `dead_code` allowance — TODO confirm and drop the attribute if that changes.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self {
            clients: CHashMap::new(),
            ids: RwLock::new(HashMap::new()),
        }
    }

    /// Registers `socket` under `id`, replacing any socket previously stored for that id.
    ///
    /// # Panics
    ///
    /// Panics if the `ids` lock is poisoned (another thread panicked while holding it).
    pub fn insert_socket(&self, id: T, socket: KfSocket) {
        trace!("inserting connection: {:#?}", id);
        // The `ids` write guard lives until the end of this function, so concurrent
        // `insert_socket` calls are serialized across both `ids` and `clients`.
        let mut ids = self.ids.write().expect("id lock must always lock");
        ids.insert(id.clone(), true);
        // Last use of `id`: move it instead of cloning again.
        self.clients.insert(id, socket);
    }

    /// Returns exclusive access to the socket registered for `id`, if it is usable.
    ///
    /// Returns `None` when nothing is registered for `id`, or when the registered socket
    /// reports `is_stale()`. A stale socket is *not* removed here; it stays in the pool until
    /// a later `insert_socket` for the same id replaces it.
    pub fn get_socket(&self, id: &T) -> Option<WriteGuard<'_, T, KfSocket>> {
        if let Some(client) = self.clients.get_mut(id) {
            trace!("got existing connection: {:#?}, returning", id);
            if client.is_stale() {
                trace!("connection is stale, do not return");
                // Dropping `client` here releases the entry guard before returning.
                None
            } else {
                Some(client)
            }
        } else {
            trace!("no existing connection: {:#?}, returning", id);
            None
        }
    }
}
impl<T> SocketPool<T>
where
    T: Eq + PartialEq + Hash + Debug + Clone + ToString,
    KfSocket: Sync,
{
    /// Connects to the address given by `id.to_string()` and registers the new socket under
    /// `id`, replacing any socket already stored for it.
    ///
    /// # Errors
    ///
    /// Returns whatever error `make_connection_with_addr` reports for that address.
    pub async fn make_connection(&self, id: T) -> Result<(), KfSocketError> {
        // The key doubles as the address; render it before `id` is moved into the call.
        let target: String = id.to_string();
        self.make_connection_with_addr(id, target.as_str()).await
    }
}
impl<T> SocketPool<T>
where
    T: Eq + PartialEq + Hash + Debug + Clone,
    KfSocket: Sync,
{
    /// Connects to `addr` and registers the resulting socket under `id`, replacing any
    /// socket already stored for that id.
    ///
    /// # Errors
    ///
    /// Returns the error from `KfSocket::connect` if the connection cannot be established;
    /// in that case the pool is left unchanged.
    pub async fn make_connection_with_addr(
        &self,
        id: T,
        addr: &str,
    ) -> Result<(), KfSocketError> {
        trace!("creating new connection: {:#?}", addr);
        let client = KfSocket::connect(addr).await?;
        trace!("got connection to server: {:#?}", &id);
        // `id` is still needed for the trace below, so the pool gets a clone.
        self.insert_socket(id.clone(), client);
        trace!("finish connection to server: {:#?}", &id);
        Ok(())
    }

    /// Returns the usable socket for `id`, connecting to `addr` first if there is none.
    ///
    /// A registered socket is reused unless it is stale (see `get_socket`); otherwise a new
    /// connection is made to `addr` and looked up again.
    ///
    /// `addr` only needs to live for the duration of the call: the returned guard borrows
    /// `self`, not `addr`.
    ///
    /// Returns `Ok(None)` if the socket found after connecting is not usable. For example, it
    /// is already stale, or another task replaced or marked it between the insert and the
    /// lookup; the two steps are not atomic.
    ///
    /// # Errors
    ///
    /// Returns the connection error from `make_connection_with_addr`.
    pub async fn get_or_make<'a>(
        &'a self,
        id: T,
        addr: &str,
    ) -> Result<Option<WriteGuard<'a, T, KfSocket>>, KfSocketError> {
        if let Some(socket) = self.get_socket(&id) {
            return Ok(Some(socket));
        }
        self.make_connection_with_addr(id.clone(), addr).await?;
        Ok(self.get_socket(&id))
    }
}
// Async tests for `SocketPool` that talk to real local TCP listeners.
//
// NOTE(review): these tests bind fixed ports (127.0.0.1:20001 / 127.0.0.1:20002) and use
// sleeps for ordering between client and server, so they may be flaky if the ports are
// taken or the machine is slow.
#[cfg(test)]
pub(crate) mod test {
    use std::net::SocketAddr;
    use std::time::Duration;

    use futures::future::join;
    use futures::stream::StreamExt;
    use log::debug;
    use log::error;

    use flv_future_aio::net::TcpListener;
    use flv_future_aio::timer::sleep;
    use flv_future_aio::test_async;

    use super::KfSocket;
    use super::KfSocketError;
    use super::SocketPool;
    use crate::test_request::EchoRequest;
    use kf_protocol::api::RequestMessage;

    // Keyed by the server address string, so `make_connection` can derive the address
    // from the key via `ToString`.
    type TestPooling = SocketPool<String>;

    /// Binds `socket_addr`, accepts at most one connection, and sends it a single
    /// `EchoRequest` whose message is `"Hello"`.
    ///
    /// `id` only tags the log lines. The listener is dropped when the function returns, so
    /// a later call can bind the same address again.
    ///
    /// NOTE(review): one log line below says the server sleeps for 100ms, but this function
    /// performs no sleep.
    pub(crate) async fn server_loop(
        socket_addr: &SocketAddr,
        id: u16,
    ) -> Result<(), KfSocketError> {
        debug!("server: {}-{} ready to bind", socket_addr, id);
        let listener = TcpListener::bind(&socket_addr).await?;
        debug!(
            "server: {}-{} successfully binding. waiting for incoming",
            socket_addr, id
        );
        let mut incoming = listener.incoming();
        if let Some(stream) = incoming.next().await {
            debug!(
                "server: {}-{} got connection from client, sending rely",
                socket_addr, id
            );
            let stream = stream?;
            let mut socket: KfSocket = stream.into();
            let msg: RequestMessage<EchoRequest> = RequestMessage::new_request(EchoRequest {
                msg: "Hello".to_owned(),
            });
            socket.get_mut_sink().send_request(&msg).await?;
            debug!("server: {}-{} finish send echo", socket_addr, id);
            // `socket` is dropped at the end of this branch. Presumably that closes the
            // connection, which `client_check` relies on when it expects end-of-stream.
        } else {
            error!("no content from client");
        }
        // Release the incoming-connection stream explicitly before the listener goes away.
        drop(incoming);
        debug!(
            "server: {}-{} sleeping for 100ms to give client chances",
            socket_addr, id
        );
        debug!("server: {}-{} server loop ended", socket_addr, id);
        Ok(())
    }

    /// Runs two sequential single-connection server rounds on `addr`: round 0, then
    /// round 1, which rebinds the same address after round 0 finishes.
    ///
    /// `_client_count` is currently unused.
    ///
    /// # Panics
    ///
    /// Panics if `addr` is not a valid `SocketAddr`.
    async fn create_server(addr: String, _client_count: u16) -> Result<(), KfSocketError> {
        let socket_addr = addr.parse::<SocketAddr>().expect("parse");
        {
            server_loop(&socket_addr, 0).await?;
        }
        {
            server_loop(&socket_addr, 1).await?;
        }
        Ok(())
    }

    /// Connects the pool to `addr`, expects one `"Hello"` echo request followed by
    /// end-of-stream, then marks the pooled socket stale.
    ///
    /// Marking it stale makes `get_socket` return `None` for `addr` until a new socket is
    /// inserted, so the next round has to reconnect.
    ///
    /// NOTE(review): the first log line says "100 second" but the code sleeps 10ms. If the
    /// stream yields nothing at all, no assertion fires.
    async fn client_check(
        client_pool: &TestPooling,
        addr: String,
        id: u16,
    ) -> Result<(), KfSocketError> {
        debug!(
            "client: {}-{} client start: sleeping for 100 second to give server chances",
            &addr, id
        );
        sleep(Duration::from_millis(10)).await;
        debug!("client: {}-{} trying to connect to server", &addr, id);
        client_pool.make_connection(addr.clone()).await?;
        if let Some(mut client_socket) = client_pool.get_socket(&addr) {
            debug!("client: {}-{} got socket from server", &addr, id);
            // Inner scope: `req_stream` presumably borrows `client_socket` mutably, so it
            // must end before `set_stale()` below.
            {
                let mut req_stream = client_socket.get_mut_stream().request_stream();
                debug!(
                    "client: {}-{} waiting for echo request from server",
                    &addr, id
                );
                let next = req_stream.next().await;
                if let Some(result) = next {
                    let req_msg: RequestMessage<EchoRequest> = result?;
                    debug!(
                        "client: {}-{} message {} from server",
                        &addr, id, req_msg.request.msg
                    );
                    assert_eq!(req_msg.request.msg, "Hello");
                    debug!(
                        "client: {}-{} wait for 2nd, server should terminate this point",
                        &addr, id
                    );
                    // The server sends exactly one request, so the stream should end here.
                    let next2 = req_stream.next().await;
                    assert!(next2.is_none(), "next2 should be none");
                    debug!("client: {}-{} 2nd wait finished", &addr, id);
                }
            }
            debug!("client: {}-{} mark as stale", &addr, id);
            client_socket.set_stale();
            Ok(())
        } else {
            panic!("not able to connect: {}", addr);
        }
    }

    /// Runs two `client_check` rounds against `addr`, waiting 1 second between them so the
    /// server's second round can bind again.
    ///
    /// # Panics
    ///
    /// Panics if either round returns an error.
    async fn test_client(client_pool: &TestPooling, addr: String) -> Result<(), KfSocketError> {
        client_check(client_pool, addr.clone(), 0)
            .await
            .expect("should finished");
        debug!("client wait for 1 second for 2nd server to come up");
        sleep(Duration::from_millis(1000)).await;
        client_check(client_pool, addr.clone(), 1)
            .await
            .expect("should be finished");
        Ok(())
    }

    /// Drives two servers and two clients concurrently, with both clients sharing one pool.
    ///
    /// NOTE(review): the joined results are discarded (`_fr`). Client failures still fail
    /// the test through the `expect`s in `test_client`, but `create_server` errors are
    /// ignored. `count` is passed to `create_server`, which does not use it.
    #[test_async]
    async fn test_pool() -> Result<(), KfSocketError> {
        let count = 1;
        let addr1 = "127.0.0.1:20001".to_owned();
        let addr2 = "127.0.0.1:20002".to_owned();
        let server_ft1 = create_server(addr1.clone(), count);
        let server_ft2 = create_server(addr2.clone(), count);
        let client_pool = TestPooling::new();
        let client_ft1 = test_client(&client_pool, addr1);
        let client_ft2 = test_client(&client_pool, addr2);
        let _fr = join(join(client_ft1, client_ft2), join(server_ft1, server_ft2)).await;
        Ok(())
    }
}