1use std::net::SocketAddr;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicI32, Ordering};
7use std::time::Duration;
8
9use bytes::{BufMut, Bytes, BytesMut};
10use dashmap::DashMap;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::TcpStream;
13use tokio::sync::{mpsc, oneshot};
14use tokio::task::JoinHandle;
15use tokio_util::sync::CancellationToken;
16
17use crate::error::ClientError;
18use crate::request::ProtocolRequest;
19use crate::version::ApiVersionTable;
20
21pub trait ClientDuplex: AsyncRead + AsyncWrite + Send + Unpin {}
25impl<T: AsyncRead + AsyncWrite + Send + Unpin + ?Sized> ClientDuplex for T {}
26
/// In-flight requests keyed by correlation id, shared between request callers
/// and the I/O task. The I/O task completes an entry with `Ok(body)` when the
/// matching response frame arrives, or with `Err(ClientError::Disconnected)`
/// when its loop ends.
type Pending = Arc<DashMap<i32, oneshot::Sender<Result<Bytes, ClientError>>>>;

/// API key of the ApiVersions request. `Connection::send` treats its response
/// header as non-flexible (no tagged-field byte) even at flexible versions.
const API_VERSIONS_KEY: i16 = 18;
35
/// Settings used when opening a [`Connection`].
#[derive(Debug, Clone)]
pub struct ConnectionOptions {
    /// Client id written into the header of every request the connection sends.
    pub client_id: String,
    /// Upper bound for the TCP connect and for the ApiVersions exchange that
    /// runs while the connection is being established.
    pub connect_timeout: Duration,
    /// How long a single request waits for its response before failing with
    /// `ClientError::Timeout`.
    pub request_timeout: Duration,
    /// TLS/SASL settings. `Connection::connect_with_options` connects in
    /// plaintext when this is `None`.
    pub security: Option<Box<crate::security::ClientSecurity>>,
}
50
51impl Default for ConnectionOptions {
52 fn default() -> Self {
53 Self {
54 client_id: "crabka".into(),
55 connect_timeout: Duration::from_secs(30),
56 request_timeout: Duration::from_secs(30),
57 security: None,
58 }
59 }
60}
61
/// Handle to a single broker connection.
///
/// Cloning is cheap: every clone shares the same `ConnectionInner` (socket
/// I/O task, correlation-id counter and pending-request table). Request
/// methods take `&self`, so several requests can be in flight at once.
#[derive(Clone)]
pub struct Connection {
    inner: Arc<ConnectionInner>,
}
67
/// State shared by every clone of a [`Connection`].
struct ConnectionInner {
    /// Version ranges from the broker's ApiVersions response; filled in once
    /// while the connection is being established.
    versions: ApiVersionTable,
    /// Options the connection was opened with.
    options: ConnectionOptions,
    /// Correlation-id source: each request takes the current value and
    /// increments it.
    next_corr_id: AtomicI32,
    /// Requests awaiting a response, keyed by correlation id.
    pending: Pending,
    /// Queue of encoded request frames consumed by the I/O task.
    writer_tx: mpsc::Sender<DispatchItem>,
    /// Cancelling this token makes the I/O task exit (see `Connection::close`).
    shutdown: CancellationToken,
    // Task handles are held for ownership only; they are not awaited here.
    _reader: JoinHandle<()>,
    _writer: JoinHandle<()>,
}
78
/// One outbound request queued for the I/O task.
struct DispatchItem {
    /// Request header (from `build_request_header`) followed by the encoded
    /// request body. It carries no size prefix; the framing layer from
    /// `crate::transport::frame_generic` is expected to add one (the test
    /// server in this file reads a `u32` length before each request).
    bytes: Bytes,
}
82
83impl Connection {
84 pub async fn connect(
86 addr: SocketAddr,
87 options: ConnectionOptions,
88 ) -> Result<Self, ClientError> {
89 let stream = tokio::time::timeout(options.connect_timeout, TcpStream::connect(addr))
90 .await
91 .map_err(|_| ClientError::Timeout(options.connect_timeout))?
92 .map_err(|source| ClientError::Connect { addr, source })?;
93
94 stream.set_nodelay(true).ok();
95
96 Self::from_stream(Box::new(stream), options).await
97 }
98
99 pub async fn connect_with_options(
110 addr: SocketAddr,
111 options: ConnectionOptions,
112 ) -> Result<Self, ClientError> {
113 match options.security.clone() {
114 Some(sec) => Self::connect_secured(addr, options, sec.as_ref()).await,
115 None => Self::connect(addr, options).await,
116 }
117 }
118
119 pub async fn connect_secured(
129 addr: SocketAddr,
130 options: ConnectionOptions,
131 security: &crate::security::ClientSecurity,
132 ) -> Result<Self, ClientError> {
133 let tcp = tokio::time::timeout(options.connect_timeout, TcpStream::connect(addr))
134 .await
135 .map_err(|_| ClientError::Timeout(options.connect_timeout))?
136 .map_err(|source| ClientError::Connect { addr, source })?;
137 tcp.set_nodelay(true).ok();
138
139 let mut stream: Box<dyn ClientDuplex> = if security.protocol.requires_tls() {
141 let tls = security.tls.as_ref().ok_or_else(|| {
142 ClientError::Io(std::io::Error::other("TLS protocol without tls config"))
143 })?;
144 let connector = tls
145 .connector()
146 .map_err(|e| ClientError::Io(std::io::Error::other(e)))?;
147 let sni =
148 tokio_rustls::rustls::pki_types::ServerName::try_from(tls.server_name.clone())
149 .map_err(|e| {
150 ClientError::Io(std::io::Error::other(format!("invalid SNI: {e}")))
151 })?;
152 let s = connector
153 .connect(sni, tcp)
154 .await
155 .map_err(|e| ClientError::Io(std::io::Error::other(e.to_string())))?;
156 Box::new(s)
157 } else {
158 Box::new(tcp)
159 };
160
161 if security.protocol.requires_sasl() {
163 let creds = security.sasl.as_ref().ok_or_else(|| {
164 ClientError::Io(std::io::Error::other("SASL protocol without credentials"))
165 })?;
166 let target = addr.ip().to_string();
171 let server_name = security.sasl_handshake_host(Some(target.as_str()));
172 crate::sasl::outbound_sasl(&mut *stream, creds, server_name)
173 .await
174 .map_err(|e| ClientError::Io(std::io::Error::other(e.to_string())))?;
175 }
176
177 Self::from_stream(stream, options).await
178 }
179
180 pub async fn from_stream(
189 stream: Box<dyn ClientDuplex>,
190 options: ConnectionOptions,
191 ) -> Result<Self, ClientError> {
192 let (writer_tx, writer_rx) = mpsc::channel::<DispatchItem>(64);
193 let shutdown = CancellationToken::new();
194 let pending: Pending = Arc::new(DashMap::new());
195
196 let (reader_handle, writer_handle) =
197 spawn_io_tasks(stream, writer_rx, shutdown.clone(), Arc::clone(&pending));
198
199 let mut conn = Self {
200 inner: Arc::new(ConnectionInner {
201 versions: ApiVersionTable::default(),
202 options: options.clone(),
203 next_corr_id: AtomicI32::new(0),
204 pending,
205 writer_tx,
206 shutdown,
207 _reader: reader_handle,
208 _writer: writer_handle,
209 }),
210 };
211
212 let versions = fetch_api_versions(&conn).await?;
213 let inner = Arc::get_mut(&mut conn.inner).expect("unique handle at connect-time");
214 inner.versions = versions;
215
216 Ok(conn)
217 }
218
219 pub async fn send<R: ProtocolRequest>(&self, req: R) -> Result<R::Response, ClientError> {
231 let version = self.inner.versions.negotiate::<R>()?;
233
234 let corr_id = self.inner.next_corr_id.fetch_add(1, Ordering::Relaxed);
236
237 let body_flexible = version >= R::FLEXIBLE_MIN;
243 let mut frame = build_request_header(
244 R::API_KEY,
245 version,
246 corr_id,
247 &self.inner.options.client_id,
248 body_flexible,
249 );
250 req.encode(&mut frame, version)?;
251
252 let (tx, rx) = oneshot::channel::<Result<Bytes, ClientError>>();
254 self.inner.pending.insert(corr_id, tx);
255
256 self.inner
258 .writer_tx
259 .send(DispatchItem {
260 bytes: frame.freeze(),
261 })
262 .await
263 .map_err(|_| ClientError::Disconnected)?;
264
265 let body_bytes = match tokio::time::timeout(self.inner.options.request_timeout, rx).await {
267 Ok(Ok(Ok(b))) => b,
268 Ok(Ok(Err(e))) => return Err(e),
269 Ok(Err(_recv_closed)) => return Err(ClientError::Disconnected),
270 Err(_timeout) => {
271 self.inner.pending.remove(&corr_id);
273 return Err(ClientError::Timeout(self.inner.options.request_timeout));
274 }
275 };
276
277 let mut cursor: &[u8] = &body_bytes;
290 let uses_flexible_resp_header = body_flexible && R::API_KEY != API_VERSIONS_KEY;
291 if uses_flexible_resp_header && !cursor.is_empty() {
292 cursor = &cursor[1..];
294 }
295
296 let resp = <R::Response as crabka_protocol::Decode>::decode(&mut cursor, version)?;
297 Ok(resp)
298 }
299
300 pub async fn raw_request(
321 &self,
322 api_key: i16,
323 api_version: i16,
324 body: Bytes,
325 ) -> Result<Bytes, ClientError> {
326 let corr_id = self.inner.next_corr_id.fetch_add(1, Ordering::Relaxed);
327
328 let mut frame = build_request_header(
331 api_key,
332 api_version,
333 corr_id,
334 &self.inner.options.client_id,
335 true,
336 );
337 frame.put_slice(&body);
338
339 let (tx, rx) = oneshot::channel::<Result<Bytes, ClientError>>();
340 self.inner.pending.insert(corr_id, tx);
341
342 self.inner
343 .writer_tx
344 .send(DispatchItem {
345 bytes: frame.freeze(),
346 })
347 .await
348 .map_err(|_| ClientError::Disconnected)?;
349
350 let body_bytes = match tokio::time::timeout(self.inner.options.request_timeout, rx).await {
351 Ok(Ok(Ok(b))) => b,
352 Ok(Ok(Err(e))) => return Err(e),
353 Ok(Err(_recv_closed)) => return Err(ClientError::Disconnected),
354 Err(_timeout) => {
355 self.inner.pending.remove(&corr_id);
356 return Err(ClientError::Timeout(self.inner.options.request_timeout));
357 }
358 };
359
360 let slice: &[u8] = &body_bytes;
363 let out = if slice.is_empty() {
364 Bytes::new()
365 } else {
366 body_bytes.slice(1..)
367 };
368 Ok(out)
369 }
370
371 #[must_use]
373 pub fn versions(&self) -> &ApiVersionTable {
374 &self.inner.versions
375 }
376
377 pub fn close(self) {
379 self.inner.shutdown.cancel();
380 }
382}
383
384fn spawn_io_tasks(
391 stream: Box<dyn ClientDuplex>,
392 mut writer_rx: mpsc::Receiver<DispatchItem>,
393 shutdown: CancellationToken,
394 pending: Pending,
395) -> (JoinHandle<()>, JoinHandle<()>) {
396 use futures_util::{SinkExt, StreamExt};
397
398 let mut framed = crate::transport::frame_generic(stream);
399 let pending_for_drain = Arc::clone(&pending);
400
401 let combined = tokio::spawn(async move {
402 loop {
403 tokio::select! {
404 () = shutdown.cancelled() => break,
405 Some(item) = writer_rx.recv() => {
406 if framed.send(item.bytes).await.is_err() {
407 break;
408 }
409 }
410 maybe_frame = framed.next() => {
411 let Some(frame) = maybe_frame else { break; };
412 let Ok(frame) = frame else { break; };
413 if frame.len() < 4 { continue; }
414 let corr_id = i32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]);
415 if let Some((_, tx)) = pending.remove(&corr_id) {
416 let body = Bytes::copy_from_slice(&frame[4..]);
417 let _ = tx.send(Ok(body));
418 }
419 }
420 }
421 }
422 let keys: Vec<i32> = pending_for_drain.iter().map(|e| *e.key()).collect();
424 for k in keys {
425 if let Some((_, tx)) = pending_for_drain.remove(&k) {
426 let _ = tx.send(Err(ClientError::Disconnected));
427 }
428 }
429 });
430
431 let noop = tokio::spawn(async {});
432 (combined, noop)
433}
434
435fn build_request_header(
453 api_key: i16,
454 version: i16,
455 corr_id: i32,
456 client_id: &str,
457 with_tagged_fields: bool,
458) -> BytesMut {
459 let mut buf = BytesMut::with_capacity(32);
460 buf.put_i16(api_key);
461 buf.put_i16(version);
462 buf.put_i32(corr_id);
463 let n = i16::try_from(client_id.len()).expect("client_id fits in i16");
464 buf.put_i16(n);
465 buf.put_slice(client_id.as_bytes());
466 if with_tagged_fields {
467 buf.put_u8(0); }
469 buf
470}
471
472async fn fetch_api_versions(conn: &Connection) -> Result<ApiVersionTable, ClientError> {
478 use crabka_protocol::Encode;
479 use crabka_protocol::owned::api_versions_request::ApiVersionsRequest;
480 use crabka_protocol::owned::api_versions_response::ApiVersionsResponse;
481
482 let req = ApiVersionsRequest::default();
483 let corr_id = conn.inner.next_corr_id.fetch_add(1, Ordering::Relaxed);
484
485 let mut frame = build_request_header(
487 ApiVersionsRequest::API_KEY,
488 0,
489 corr_id,
490 &conn.inner.options.client_id,
491 false,
492 );
493 req.encode(&mut frame, 0)?;
494
495 let (tx, rx) = oneshot::channel::<Result<Bytes, ClientError>>();
496 conn.inner.pending.insert(corr_id, tx);
497 conn.inner
498 .writer_tx
499 .send(DispatchItem {
500 bytes: frame.freeze(),
501 })
502 .await
503 .map_err(|_| ClientError::Disconnected)?;
504
505 let body_bytes = tokio::time::timeout(conn.inner.options.connect_timeout, rx)
506 .await
507 .map_err(|_| ClientError::Timeout(conn.inner.options.connect_timeout))?
508 .map_err(|_| ClientError::Disconnected)??;
509
510 let mut cursor: &[u8] = &body_bytes;
514 let resp = <ApiVersionsResponse as crabka_protocol::Decode>::decode(&mut cursor, 0)?;
515 if resp.error_code != 0 {
516 return Err(ClientError::Server {
517 error_code: resp.error_code,
518 });
519 }
520
521 let entries = resp
522 .api_keys
523 .iter()
524 .map(|k| (k.api_key, k.min_version, k.max_version));
525 Ok(ApiVersionTable::from_entries(entries))
526}
527
// Tests for `Connection::connect_secured` against a scripted in-process server.
#[cfg(test)]
mod secured_tests {
    use super::*;
    use crate::security::{ClientSecurity, SaslCredentials};
    use crabka_security::ListenerProtocol;

    /// Connects with `SaslPlaintext` + PLAIN credentials to a fake broker and
    /// expects `connect_secured` to succeed.
    ///
    /// The fake broker does not inspect request contents beyond the
    /// correlation id. It answers the first three frames it reads, in a fixed
    /// order: a SASL handshake response, a SASL authenticate response, then
    /// an ApiVersions response.
    #[tokio::test]
    async fn connect_secured_runs_sasl_then_api_versions() {
        use crabka_protocol::Encode;
        use crabka_protocol::owned::api_versions_response::ApiVersionsResponse;
        use crabka_protocol::owned::sasl_authenticate_response::SaslAuthenticateResponse;
        use crabka_protocol::owned::sasl_handshake_response::SaslHandshakeResponse;
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        // Port 0: let the OS pick a free port, then read it back.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let server = tokio::spawn(async move {
            let (mut s, _) = listener.accept().await.unwrap();
            // (encoded response body, whether the response header carries a
            // tagged-field byte). The encode versions (1 for handshake, 2 for
            // authenticate, 0 for ApiVersions) presumably mirror what
            // `crate::sasl::outbound_sasl` and `fetch_api_versions` decode
            // with — confirm if the SASL module changes.
            let replies: [(BytesMut, bool); 3] = [
                {
                    let mut b = BytesMut::new();
                    SaslHandshakeResponse {
                        error_code: 0,
                        ..Default::default()
                    }
                    .encode(&mut b, 1)
                    .unwrap();
                    (b, false)
                },
                {
                    let mut b = BytesMut::new();
                    SaslAuthenticateResponse {
                        error_code: 0,
                        ..Default::default()
                    }
                    .encode(&mut b, 2)
                    .unwrap();
                    (b, true)
                },
                {
                    let mut b = BytesMut::new();
                    ApiVersionsResponse::default().encode(&mut b, 0).unwrap();
                    (b, false)
                },
            ];
            for (body, flex_header) in replies {
                // Each request: u32 length prefix, then the frame. Bytes 0-3
                // are api_key/api_version and bytes 4-7 the correlation id
                // (layout from `build_request_header`).
                let req_len = s.read_u32().await.unwrap();
                let mut req = vec![0u8; req_len as usize];
                s.read_exact(&mut req).await.unwrap();
                let corr = i32::from_be_bytes([req[4], req[5], req[6], req[7]]);
                // Reply: correlation id, optional empty tagged-field byte,
                // body, all behind a u32 length prefix.
                let mut frame = BytesMut::new();
                frame.put_i32(corr);
                if flex_header {
                    frame.put_u8(0);
                }
                frame.put_slice(&body);
                s.write_u32(u32::try_from(frame.len()).unwrap())
                    .await
                    .unwrap();
                s.write_all(&frame).await.unwrap();
                s.flush().await.unwrap();
            }
        });
        let security = ClientSecurity {
            protocol: ListenerProtocol::SaslPlaintext,
            tls: None,
            sasl: Some(SaslCredentials::Plain {
                username: "u".into(),
                password: "p".into(),
            }),
            sasl_host: None,
        };
        let conn = Connection::connect_secured(addr, ConnectionOptions::default(), &security)
            .await
            .expect("secured connect completes");
        conn.close();
        // Fails the test if the server task panicked (e.g. on a short read).
        server.await.unwrap();
    }
}