use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::atomic::{AtomicI32, Ordering};
use std::time::Duration;
use bytes::{BufMut, Bytes, BytesMut};
use dashmap::DashMap;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::error::ClientError;
use crate::request::ProtocolRequest;
use crate::version::ApiVersionTable;
/// Bundle of the async I/O bounds a connection transport must satisfy.
///
/// Exists so that plain TCP, TLS, or any other duplex stream can be handled
/// uniformly as `Box<dyn ClientDuplex>`.
pub trait ClientDuplex: AsyncRead + AsyncWrite + Send + Unpin {}

/// Every type that meets the bounds is a `ClientDuplex` automatically.
impl<S> ClientDuplex for S where S: AsyncRead + AsyncWrite + Send + Unpin + ?Sized {}
/// In-flight requests keyed by correlation id; each value is the one-shot
/// channel through which that request's response body (or error) is delivered.
type Pending = Arc<DashMap<i32, oneshot::Sender<Result<Bytes, ClientError>>>>;
/// API key of the ApiVersions request, whose response header never carries tagged fields.
const API_VERSIONS_KEY: i16 = 18;
/// Tunables applied when opening and using a [`Connection`].
#[derive(Debug, Clone)]
pub struct ConnectionOptions {
    /// Client id written into every request header.
    pub client_id: String,
    /// Deadline for establishing the TCP connection; also bounds the initial
    /// ApiVersions exchange.
    pub connect_timeout: Duration,
    /// Deadline for each request's response once it has been queued.
    pub request_timeout: Duration,
    /// Optional TLS / SASL settings, consulted by `Connection::connect_with_options`.
    pub security: Option<Box<crate::security::ClientSecurity>>,
}
impl Default for ConnectionOptions {
    /// Client id `"crabka"`, 30-second connect and request timeouts, no security layer.
    fn default() -> Self {
        let thirty_secs = Duration::from_secs(30);
        Self {
            client_id: String::from("crabka"),
            connect_timeout: thirty_secs,
            request_timeout: thirty_secs,
            security: None,
        }
    }
}
/// Handle to a single broker connection.
///
/// Cloning only bumps an `Arc`: every clone shares the same I/O task,
/// correlation-id counter and in-flight request table.
#[derive(Clone)]
pub struct Connection {
    inner: Arc<ConnectionInner>,
}
/// State shared by every clone of a [`Connection`].
struct ConnectionInner {
    /// Version ranges advertised by the broker; filled in once during `Connection::from_stream`.
    versions: ApiVersionTable,
    /// Options the connection was opened with (client id, timeouts, security).
    options: ConnectionOptions,
    /// Source of correlation ids for outgoing requests.
    next_corr_id: AtomicI32,
    /// Requests awaiting a response, keyed by correlation id.
    pending: Pending,
    /// Queue feeding encoded frames to the I/O task.
    writer_tx: mpsc::Sender<DispatchItem>,
    /// Cancelling this token stops the I/O task.
    shutdown: CancellationToken,
    // Kept alongside the connection only; dropping a `JoinHandle` detaches the
    // task instead of aborting it, hence the explicit cancel in `Drop` below.
    _reader: JoinHandle<()>,
    _writer: JoinHandle<()>,
}

impl Drop for ConnectionInner {
    /// Runs when the last `Connection` clone is dropped: stop the I/O task so the
    /// socket is released even if `Connection::close` was never called.
    fn drop(&mut self) {
        self.shutdown.cancel();
    }
}
/// One encoded request (header from `build_request_header` plus the request body)
/// queued for the I/O task, which writes it through its framing codec.
struct DispatchItem {
    bytes: Bytes,
}
impl Connection {
pub async fn connect(
addr: SocketAddr,
options: ConnectionOptions,
) -> Result<Self, ClientError> {
let stream = tokio::time::timeout(options.connect_timeout, TcpStream::connect(addr))
.await
.map_err(|_| ClientError::Timeout(options.connect_timeout))?
.map_err(|source| ClientError::Connect { addr, source })?;
stream.set_nodelay(true).ok();
Self::from_stream(Box::new(stream), options).await
}
pub async fn connect_with_options(
addr: SocketAddr,
options: ConnectionOptions,
) -> Result<Self, ClientError> {
match options.security.clone() {
Some(sec) => Self::connect_secured(addr, options, sec.as_ref()).await,
None => Self::connect(addr, options).await,
}
}
pub async fn connect_secured(
addr: SocketAddr,
options: ConnectionOptions,
security: &crate::security::ClientSecurity,
) -> Result<Self, ClientError> {
let tcp = tokio::time::timeout(options.connect_timeout, TcpStream::connect(addr))
.await
.map_err(|_| ClientError::Timeout(options.connect_timeout))?
.map_err(|source| ClientError::Connect { addr, source })?;
tcp.set_nodelay(true).ok();
let mut stream: Box<dyn ClientDuplex> = if security.protocol.requires_tls() {
let tls = security.tls.as_ref().ok_or_else(|| {
ClientError::Io(std::io::Error::other("TLS protocol without tls config"))
})?;
let connector = tls
.connector()
.map_err(|e| ClientError::Io(std::io::Error::other(e)))?;
let sni =
tokio_rustls::rustls::pki_types::ServerName::try_from(tls.server_name.clone())
.map_err(|e| {
ClientError::Io(std::io::Error::other(format!("invalid SNI: {e}")))
})?;
let s = connector
.connect(sni, tcp)
.await
.map_err(|e| ClientError::Io(std::io::Error::other(e.to_string())))?;
Box::new(s)
} else {
Box::new(tcp)
};
if security.protocol.requires_sasl() {
let creds = security.sasl.as_ref().ok_or_else(|| {
ClientError::Io(std::io::Error::other("SASL protocol without credentials"))
})?;
let target = addr.ip().to_string();
let server_name = security.sasl_handshake_host(Some(target.as_str()));
crate::sasl::outbound_sasl(&mut *stream, creds, server_name)
.await
.map_err(|e| ClientError::Io(std::io::Error::other(e.to_string())))?;
}
Self::from_stream(stream, options).await
}
pub async fn from_stream(
stream: Box<dyn ClientDuplex>,
options: ConnectionOptions,
) -> Result<Self, ClientError> {
let (writer_tx, writer_rx) = mpsc::channel::<DispatchItem>(64);
let shutdown = CancellationToken::new();
let pending: Pending = Arc::new(DashMap::new());
let (reader_handle, writer_handle) =
spawn_io_tasks(stream, writer_rx, shutdown.clone(), Arc::clone(&pending));
let mut conn = Self {
inner: Arc::new(ConnectionInner {
versions: ApiVersionTable::default(),
options: options.clone(),
next_corr_id: AtomicI32::new(0),
pending,
writer_tx,
shutdown,
_reader: reader_handle,
_writer: writer_handle,
}),
};
let versions = fetch_api_versions(&conn).await?;
let inner = Arc::get_mut(&mut conn.inner).expect("unique handle at connect-time");
inner.versions = versions;
Ok(conn)
}
pub async fn send<R: ProtocolRequest>(&self, req: R) -> Result<R::Response, ClientError> {
let version = self.inner.versions.negotiate::<R>()?;
let corr_id = self.inner.next_corr_id.fetch_add(1, Ordering::Relaxed);
let body_flexible = version >= R::FLEXIBLE_MIN;
let mut frame = build_request_header(
R::API_KEY,
version,
corr_id,
&self.inner.options.client_id,
body_flexible,
);
req.encode(&mut frame, version)?;
let (tx, rx) = oneshot::channel::<Result<Bytes, ClientError>>();
self.inner.pending.insert(corr_id, tx);
self.inner
.writer_tx
.send(DispatchItem {
bytes: frame.freeze(),
})
.await
.map_err(|_| ClientError::Disconnected)?;
let body_bytes = match tokio::time::timeout(self.inner.options.request_timeout, rx).await {
Ok(Ok(Ok(b))) => b,
Ok(Ok(Err(e))) => return Err(e),
Ok(Err(_recv_closed)) => return Err(ClientError::Disconnected),
Err(_timeout) => {
self.inner.pending.remove(&corr_id);
return Err(ClientError::Timeout(self.inner.options.request_timeout));
}
};
let mut cursor: &[u8] = &body_bytes;
let uses_flexible_resp_header = body_flexible && R::API_KEY != API_VERSIONS_KEY;
if uses_flexible_resp_header && !cursor.is_empty() {
cursor = &cursor[1..];
}
let resp = <R::Response as crabka_protocol::Decode>::decode(&mut cursor, version)?;
Ok(resp)
}
pub async fn raw_request(
&self,
api_key: i16,
api_version: i16,
body: Bytes,
) -> Result<Bytes, ClientError> {
let corr_id = self.inner.next_corr_id.fetch_add(1, Ordering::Relaxed);
let mut frame = build_request_header(
api_key,
api_version,
corr_id,
&self.inner.options.client_id,
true,
);
frame.put_slice(&body);
let (tx, rx) = oneshot::channel::<Result<Bytes, ClientError>>();
self.inner.pending.insert(corr_id, tx);
self.inner
.writer_tx
.send(DispatchItem {
bytes: frame.freeze(),
})
.await
.map_err(|_| ClientError::Disconnected)?;
let body_bytes = match tokio::time::timeout(self.inner.options.request_timeout, rx).await {
Ok(Ok(Ok(b))) => b,
Ok(Ok(Err(e))) => return Err(e),
Ok(Err(_recv_closed)) => return Err(ClientError::Disconnected),
Err(_timeout) => {
self.inner.pending.remove(&corr_id);
return Err(ClientError::Timeout(self.inner.options.request_timeout));
}
};
let slice: &[u8] = &body_bytes;
let out = if slice.is_empty() {
Bytes::new()
} else {
body_bytes.slice(1..)
};
Ok(out)
}
#[must_use]
pub fn versions(&self) -> &ApiVersionTable {
&self.inner.versions
}
pub fn close(self) {
self.inner.shutdown.cancel();
}
}
/// Spawns the single task that owns `stream` and drives both directions of it.
///
/// Outbound frames from `writer_rx` are written in arrival order. Each inbound
/// frame is matched to a waiter in `pending` by its leading 4-byte big-endian
/// correlation id, and the remaining bytes are delivered as the response body.
/// Frames shorter than 4 bytes, or with an unknown id, are ignored.
///
/// The task stops when `shutdown` is cancelled, when the stream errors or ends,
/// or when every sender for `writer_rx` is dropped (no `Connection` handle
/// remains). On exit, every still-pending request receives
/// `ClientError::Disconnected`.
///
/// Returns two handles to match `ConnectionInner`'s fields. The second is a
/// no-op task, because reading and writing share one task.
fn spawn_io_tasks(
    stream: Box<dyn ClientDuplex>,
    mut writer_rx: mpsc::Receiver<DispatchItem>,
    shutdown: CancellationToken,
    pending: Pending,
) -> (JoinHandle<()>, JoinHandle<()>) {
    use futures_util::{SinkExt, StreamExt};
    let mut framed = crate::transport::frame_generic(stream);
    let combined = tokio::spawn(async move {
        loop {
            tokio::select! {
                () = shutdown.cancelled() => break,
                maybe_item = writer_rx.recv() => {
                    // `None`: all senders are gone, so nobody can send or wait anymore.
                    let Some(item) = maybe_item else { break; };
                    if framed.send(item.bytes).await.is_err() {
                        break;
                    }
                }
                maybe_frame = framed.next() => {
                    let Some(Ok(frame)) = maybe_frame else { break; };
                    if frame.len() < 4 { continue; }
                    let corr_id = i32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]);
                    if let Some((_, tx)) = pending.remove(&corr_id) {
                        let body = Bytes::copy_from_slice(&frame[4..]);
                        let _ = tx.send(Ok(body));
                    }
                }
            }
        }
        // Close the queue before draining. Otherwise a request could be enqueued
        // after the drain and then silently dropped along with the receiver,
        // leaving its caller waiting for a full request timeout.
        writer_rx.close();
        // Collect the keys first: removing entries while a DashMap iterator holds
        // shard locks can deadlock.
        let keys: Vec<i32> = pending.iter().map(|e| *e.key()).collect();
        for k in keys {
            if let Some((_, tx)) = pending.remove(&k) {
                let _ = tx.send(Err(ClientError::Disconnected));
            }
        }
    });
    let noop = tokio::spawn(async {});
    (combined, noop)
}
/// Encodes a request header — api key, api version, correlation id and
/// length-prefixed client id — into a fresh buffer, ready for the request body
/// to be appended.
///
/// With `with_tagged_fields`, an empty tagged-field section (a single `0` byte)
/// is appended, as flexible request headers require.
///
/// # Panics
///
/// Panics if `client_id` is longer than `i16::MAX` bytes.
fn build_request_header(
    api_key: i16,
    version: i16,
    corr_id: i32,
    client_id: &str,
    with_tagged_fields: bool,
) -> BytesMut {
    let id_bytes = client_id.as_bytes();
    let id_len = i16::try_from(id_bytes.len()).expect("client_id fits in i16");
    let mut header = BytesMut::with_capacity(32);
    header.put_i16(api_key);
    header.put_i16(version);
    header.put_i32(corr_id);
    header.put_i16(id_len);
    header.put_slice(id_bytes);
    if with_tagged_fields {
        // Zero tagged fields, encoded as an unsigned varint.
        header.put_u8(0);
    }
    header
}
async fn fetch_api_versions(conn: &Connection) -> Result<ApiVersionTable, ClientError> {
use crabka_protocol::Encode;
use crabka_protocol::owned::api_versions_request::ApiVersionsRequest;
use crabka_protocol::owned::api_versions_response::ApiVersionsResponse;
let req = ApiVersionsRequest::default();
let corr_id = conn.inner.next_corr_id.fetch_add(1, Ordering::Relaxed);
let mut frame = build_request_header(
ApiVersionsRequest::API_KEY,
0,
corr_id,
&conn.inner.options.client_id,
false,
);
req.encode(&mut frame, 0)?;
let (tx, rx) = oneshot::channel::<Result<Bytes, ClientError>>();
conn.inner.pending.insert(corr_id, tx);
conn.inner
.writer_tx
.send(DispatchItem {
bytes: frame.freeze(),
})
.await
.map_err(|_| ClientError::Disconnected)?;
let body_bytes = tokio::time::timeout(conn.inner.options.connect_timeout, rx)
.await
.map_err(|_| ClientError::Timeout(conn.inner.options.connect_timeout))?
.map_err(|_| ClientError::Disconnected)??;
let mut cursor: &[u8] = &body_bytes;
let resp = <ApiVersionsResponse as crabka_protocol::Decode>::decode(&mut cursor, 0)?;
if resp.error_code != 0 {
return Err(ClientError::Server {
error_code: resp.error_code,
});
}
let entries = resp
.api_keys
.iter()
.map(|k| (k.api_key, k.min_version, k.max_version));
Ok(ApiVersionTable::from_entries(entries))
}
/// Tests for `Connection::connect_secured` against an in-process mock broker.
#[cfg(test)]
mod secured_tests {
    use super::*;
    use crate::security::{ClientSecurity, SaslCredentials};
    use crabka_security::ListenerProtocol;
    /// Checks that a SASL_PLAINTEXT connect performs the SaslHandshake and
    /// SaslAuthenticate round trips, then the ApiVersions exchange, and completes.
    #[tokio::test]
    async fn connect_secured_runs_sasl_then_api_versions() {
        use crabka_protocol::Encode;
        use crabka_protocol::owned::api_versions_response::ApiVersionsResponse;
        use crabka_protocol::owned::sasl_authenticate_response::SaslAuthenticateResponse;
        use crabka_protocol::owned::sasl_handshake_response::SaslHandshakeResponse;
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        // Mock broker: answers exactly three requests, in order, with canned bodies.
        let server = tokio::spawn(async move {
            let (mut s, _) = listener.accept().await.unwrap();
            // (encoded response body, whether to emit a flexible response header
            // with an empty tagged-field byte after the correlation id)
            let replies: [(BytesMut, bool); 3] = [
                // SaslHandshake response, encoded at v1, non-flexible header.
                {
                    let mut b = BytesMut::new();
                    SaslHandshakeResponse {
                        error_code: 0,
                        ..Default::default()
                    }
                    .encode(&mut b, 1)
                    .unwrap();
                    (b, false)
                },
                // SaslAuthenticate response, encoded at v2, flexible header.
                {
                    let mut b = BytesMut::new();
                    SaslAuthenticateResponse {
                        error_code: 0,
                        ..Default::default()
                    }
                    .encode(&mut b, 2)
                    .unwrap();
                    (b, true)
                },
                // ApiVersions response, encoded at v0, non-flexible header.
                {
                    let mut b = BytesMut::new();
                    ApiVersionsResponse::default().encode(&mut b, 0).unwrap();
                    (b, false)
                },
            ];
            for (body, flex_header) in replies {
                // Read one length-prefixed request frame.
                let req_len = s.read_u32().await.unwrap();
                let mut req = vec![0u8; req_len as usize];
                s.read_exact(&mut req).await.unwrap();
                // Correlation id sits after the i16 api key and i16 api version.
                let corr = i32::from_be_bytes([req[4], req[5], req[6], req[7]]);
                let mut frame = BytesMut::new();
                frame.put_i32(corr);
                if flex_header {
                    frame.put_u8(0);
                }
                frame.put_slice(&body);
                s.write_u32(u32::try_from(frame.len()).unwrap())
                    .await
                    .unwrap();
                s.write_all(&frame).await.unwrap();
                s.flush().await.unwrap();
            }
        });
        let security = ClientSecurity {
            protocol: ListenerProtocol::SaslPlaintext,
            tls: None,
            sasl: Some(SaslCredentials::Plain {
                username: "u".into(),
                password: "p".into(),
            }),
            sasl_host: None,
        };
        let conn = Connection::connect_secured(addr, ConnectionOptions::default(), &security)
            .await
            .expect("secured connect completes");
        conn.close();
        server.await.unwrap();
    }
}