use futures::Future;
use nanoservices_utils::errors::NanoServiceError;
use tokio::net::TcpStream;
/// Opens a TCP connection to `address`.
///
/// Suitable for passing to [`ConnectionCreator::attempt_connection`], since it
/// takes the address as an owned `String` and returns a future.
///
/// # Errors
///
/// Returns a [`NanoServiceError`] with status
/// `NanoServiceErrorStatus::Unknown` if the connection cannot be established.
/// The message embeds the `Debug` form of the underlying I/O error.
pub async fn attempt_tcp_connection(address: String) -> Result<TcpStream, NanoServiceError> {
    // `map_err` already yields the final `Result`; no need to re-match it.
    TcpStream::connect(address).await.map_err(|e| {
        NanoServiceError::new(
            format!("Error connecting to server: {:?}", e),
            nanoservices_utils::errors::NanoServiceErrorStatus::Unknown,
        )
    })
}
/// Holds the address that [`ConnectionCreator::attempt_connection`] hands to
/// the supplied connection closure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionCreator {
    /// Address passed (as an owned copy) to the connection closure on each attempt.
    pub address: String,
}
impl ConnectionCreator {
    /// Makes a single connection attempt by calling `connection_closure` with a
    /// copy of `self.address` and awaiting the future it returns.
    ///
    /// The closure decides how the connection is made, so a real connector
    /// (such as `attempt_tcp_connection`) or a fake one can be supplied.
    ///
    /// # Errors
    ///
    /// Propagates whatever [`NanoServiceError`] the closure's future resolves to.
    pub async fn attempt_connection<F, Fut, Y>(
        &self,
        connection_closure: F,
    ) -> Result<Y, NanoServiceError>
    where
        // `FnOnce` suffices because the closure is called exactly once. There is
        // deliberately no `Copy` bound, so closures capturing owned state work too.
        F: FnOnce(String) -> Fut + Send,
        Fut: Future<Output = Result<Y, NanoServiceError>> + Send,
    {
        connection_closure(self.address.clone()).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::net::TcpListener;

    /// Counts calls to `check_connection`; drives its fail-then-succeed sequence.
    static CONNECTION_ATTEMPT: AtomicUsize = AtomicUsize::new(0);

    /// Fake connector: fails on the first three calls and succeeds on the fourth.
    async fn check_connection(_address: String) -> Result<bool, NanoServiceError> {
        let count = CONNECTION_ATTEMPT.fetch_add(1, Ordering::Relaxed);
        if count == 3 {
            return Ok(true);
        }
        Err(NanoServiceError::new(
            "Error connecting to server".to_string(),
            nanoservices_utils::errors::NanoServiceErrorStatus::Unknown,
        ))
    }

    #[tokio::test]
    async fn test_connection_creator() {
        let connector = ConnectionCreator {
            address: "fake".to_string(),
        };
        for _ in 0..3 {
            assert!(connector.attempt_connection(check_connection).await.is_err());
        }
        assert!(connector.attempt_connection(check_connection).await.is_ok());
    }

    #[tokio::test]
    async fn test_live_reconnection() {
        // Let the OS pick a free port, remember it, then release it so the first
        // attempt has nothing listening on the other end.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        drop(listener);

        let connector = ConnectionCreator { address: addr };
        assert!(connector.attempt_connection(attempt_tcp_connection).await.is_err());

        // Listen on the same port again; the next attempt should now succeed.
        let _listener = TcpListener::bind(connector.address.as_str()).await.unwrap();
        assert!(connector.attempt_connection(attempt_tcp_connection).await.is_ok());
    }
}