use std::collections::HashMap;
use std::io;
use rustfs_kafka::error::{Error, ProtocolError, Result};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::debug;
/// A single TCP connection to a Kafka broker.
///
/// Requests are written as raw bytes and responses are read back using the
/// 4-byte big-endian size prefix (see `request_response`).
#[derive(Debug)]
pub struct AsyncConnection {
    /// Underlying tokio TCP stream; used unbuffered for both directions.
    stream: TcpStream,
    /// The address string this connection was opened with (as passed to `connect`).
    host: String,
}
impl AsyncConnection {
pub async fn connect(host: &str) -> Result<Self> {
debug!("Connecting to {}", host);
let stream = TcpStream::connect(host).await.map_err(|e| {
Error::Connection(rustfs_kafka::error::ConnectionError::Io(io::Error::new(
e.kind(),
format!("connect to {host}: {e}"),
)))
})?;
debug!("Connected to {}", host);
Ok(Self {
stream,
host: host.to_owned(),
})
}
#[must_use]
pub fn host(&self) -> &str {
&self.host
}
pub async fn send(&mut self, data: &[u8]) -> Result<()> {
self.stream
.write_all(data)
.await
.map_err(|e| Error::Connection(rustfs_kafka::error::ConnectionError::Io(e)))?;
self.stream
.flush()
.await
.map_err(|e| Error::Connection(rustfs_kafka::error::ConnectionError::Io(e)))?;
Ok(())
}
pub async fn read_exact(&mut self, n: u64) -> Result<bytes::Bytes> {
let n = usize::try_from(n).map_err(|_| Error::Protocol(ProtocolError::Codec))?;
let mut buf = bytes::BytesMut::with_capacity(n);
buf.resize(n, 0);
self.stream
.read_exact(&mut buf)
.await
.map_err(|e| Error::Connection(rustfs_kafka::error::ConnectionError::Io(e)))?;
Ok(buf.freeze())
}
pub async fn request_response(&mut self, request: &[u8]) -> Result<bytes::Bytes> {
self.send(request).await?;
let mut size_buf = [0u8; 4];
self.stream
.read_exact(&mut size_buf)
.await
.map_err(|e| Error::Connection(rustfs_kafka::error::ConnectionError::Io(e)))?;
let size = i32::from_be_bytes(size_buf);
if size < 0 {
return Err(Error::Protocol(ProtocolError::Codec));
}
self.read_exact(size as u64).await
}
}
/// A lazily populated cache holding at most one [`AsyncConnection`] per host.
pub struct AsyncConnectionPool {
    /// Open connections, keyed by the exact address string passed to `get`.
    connections: HashMap<String, AsyncConnection>,
}
impl AsyncConnectionPool {
    /// Creates an empty pool; connections are opened lazily by [`Self::get`].
    #[must_use]
    pub fn new() -> Self {
        Self {
            connections: HashMap::new(),
        }
    }

    /// Returns the cached connection for `host`, connecting first if there is none.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`AsyncConnection::connect`] when a new
    /// connection has to be opened; in that case nothing is added to the pool.
    pub async fn get(&mut self, host: &str) -> Result<&mut AsyncConnection> {
        if !self.connections.contains_key(host) {
            let conn = AsyncConnection::connect(host).await?;
            // Return the freshly inserted entry directly instead of looking it up again.
            return Ok(self.connections.entry(host.to_owned()).or_insert(conn));
        }
        // Probing with `contains_key` (rather than `if let Some(c) = get_mut(..)`)
        // satisfies the borrow checker and avoids allocating an owned key on the hit path.
        Ok(self
            .connections
            .get_mut(host)
            .expect("entry presence was checked by contains_key above"))
    }

    /// Removes and returns the cached connection for `host`, if any.
    ///
    /// Call this after an I/O error or a cancelled request on that connection.
    /// The stream may be left part-way through a frame, and without eviction
    /// [`Self::get`] would keep handing back the same connection.
    pub fn remove(&mut self, host: &str) -> Option<AsyncConnection> {
        self.connections.remove(host)
    }

    /// Lists the hosts that currently have a cached connection, in arbitrary order.
    #[must_use]
    pub fn hosts(&self) -> Vec<&str> {
        self.connections.keys().map(String::as_str).collect()
    }
}
impl Default for AsyncConnectionPool {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use rustfs_kafka::error::ConnectionError;
    use std::io;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    use super::*;

    /// Starts a one-shot loopback server. It accepts one client, reads a
    /// request of `request_len` bytes, writes `reply`, then closes the socket.
    /// Returns the server's address and a handle yielding the request bytes.
    async fn one_shot_server(
        request_len: usize,
        reply: Vec<u8>,
    ) -> (String, tokio::task::JoinHandle<Vec<u8>>) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap().to_string();
        let handle = tokio::spawn(async move {
            let (mut sock, _) = listener.accept().await.unwrap();
            // Drain the request before replying so closing the socket sends a
            // clean FIN rather than a reset over unread data.
            let mut request = vec![0u8; request_len];
            sock.read_exact(&mut request).await.unwrap();
            sock.write_all(&reply).await.unwrap();
            request
        });
        (addr, handle)
    }

    #[test]
    fn pool_new_creates_empty_pool() {
        let pool = AsyncConnectionPool::new();
        assert!(pool.hosts().is_empty());
    }

    #[test]
    fn pool_default_matches_new() {
        let pool = AsyncConnectionPool::default();
        assert!(pool.hosts().is_empty());
    }

    #[tokio::test]
    async fn connect_unreachable_host_returns_io_error() {
        let result = AsyncConnection::connect("127.0.0.1:1").await;
        assert!(matches!(
            result,
            Err(Error::Connection(ConnectionError::Io(_)))
        ));
    }

    #[tokio::test]
    async fn pool_get_unreachable_host_propagates_error() {
        let mut pool = AsyncConnectionPool::new();
        let result = pool.get("127.0.0.1:1").await;
        assert!(result.is_err());
        assert!(pool.hosts().is_empty());
    }

    #[tokio::test]
    async fn request_response_returns_length_prefixed_body() {
        let (addr, server) = one_shot_server(4, vec![0, 0, 0, 3, b'a', b'b', b'c']).await;
        let mut conn = AsyncConnection::connect(&addr).await.unwrap();
        assert_eq!(conn.host(), addr);
        let body = conn.request_response(b"ping").await.unwrap();
        assert_eq!(&body[..], b"abc");
        assert_eq!(server.await.unwrap(), b"ping".to_vec());
    }

    #[tokio::test]
    async fn request_response_accepts_empty_body() {
        let (addr, server) = one_shot_server(1, vec![0, 0, 0, 0]).await;
        let mut conn = AsyncConnection::connect(&addr).await.unwrap();
        let body = conn.request_response(b"x").await.unwrap();
        assert!(body.is_empty());
        server.await.unwrap();
    }

    #[tokio::test]
    async fn request_response_rejects_negative_size() {
        let (addr, server) = one_shot_server(1, (-1i32).to_be_bytes().to_vec()).await;
        let mut conn = AsyncConnection::connect(&addr).await.unwrap();
        let result = conn.request_response(b"x").await;
        assert!(matches!(result, Err(Error::Protocol(ProtocolError::Codec))));
        server.await.unwrap();
    }

    #[tokio::test]
    async fn request_response_reports_eof_on_truncated_body() {
        // Claims a 10-byte body but sends only 3 bytes before closing.
        let (addr, server) = one_shot_server(1, vec![0, 0, 0, 10, 1, 2, 3]).await;
        let mut conn = AsyncConnection::connect(&addr).await.unwrap();
        let result = conn.request_response(b"x").await;
        assert!(matches!(
            result,
            Err(Error::Connection(ConnectionError::Io(ref e)))
                if e.kind() == io::ErrorKind::UnexpectedEof
        ));
        server.await.unwrap();
    }
}