Skip to main content

crabka_client_core/
connection.rs

1//! Single-broker `Connection`: TCP socket + reader/writer tasks +
2//! correlation-ID multiplexing.
3
4use std::net::SocketAddr;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicI32, Ordering};
7use std::time::Duration;
8
9use bytes::{BufMut, Bytes, BytesMut};
10use dashmap::DashMap;
11use tokio::io::{AsyncRead, AsyncWrite};
12use tokio::net::TcpStream;
13use tokio::sync::{mpsc, oneshot};
14use tokio::task::JoinHandle;
15use tokio_util::sync::CancellationToken;
16
17use crate::error::ClientError;
18use crate::request::ProtocolRequest;
19use crate::version::ApiVersionTable;
20
21/// Trait alias for the duplex stream types `Connection::from_stream`
22/// accepts (`TcpStream`, `tokio_rustls::client::TlsStream`, etc.). Boxed
23/// so callers can hand in heterogeneous stream types via one path.
24pub trait ClientDuplex: AsyncRead + AsyncWrite + Send + Unpin {}
25impl<T: AsyncRead + AsyncWrite + Send + Unpin + ?Sized> ClientDuplex for T {}
26
27type Pending = Arc<DashMap<i32, oneshot::Sender<Result<Bytes, ClientError>>>>;
28
/// API key of the Kafka `ApiVersions` request/response pair.
///
/// It is needed to special-case the response header. An
/// `ApiVersionsResponse` is always framed with `ResponseHeader v0` (no
/// tagged-fields byte), even when the request version is flexible (v3+).
const API_VERSIONS_KEY: i16 = 18;
35
36/// Connect-time + per-request configuration knobs.
37#[derive(Debug, Clone)]
38pub struct ConnectionOptions {
39    pub client_id: String,
40    pub connect_timeout: Duration,
41    pub request_timeout: Duration,
42    /// Client-side TLS/SASL policy. `None` = plaintext (default).
43    ///
44    /// Boxed so `ConnectionOptions` stays small: it is cloned widely and
45    /// embedded in many connection-building futures, and `ClientSecurity`
46    /// carries several `String`/`PathBuf` fields that would otherwise
47    /// bloat every such future.
48    pub security: Option<Box<crate::security::ClientSecurity>>,
49}
50
51impl Default for ConnectionOptions {
52    fn default() -> Self {
53        Self {
54            client_id: "crabka".into(),
55            connect_timeout: Duration::from_secs(30),
56            request_timeout: Duration::from_secs(30),
57            security: None,
58        }
59    }
60}
61
62/// A connection to a single Kafka broker.
63#[derive(Clone)]
64pub struct Connection {
65    inner: Arc<ConnectionInner>,
66}
67
68struct ConnectionInner {
69    versions: ApiVersionTable,
70    options: ConnectionOptions,
71    next_corr_id: AtomicI32,
72    pending: Pending,
73    writer_tx: mpsc::Sender<DispatchItem>,
74    shutdown: CancellationToken,
75    _reader: JoinHandle<()>,
76    _writer: JoinHandle<()>,
77}
78
79struct DispatchItem {
80    bytes: Bytes,
81}
82
83impl Connection {
84    /// Connect to `addr`, negotiate API versions, return a usable `Connection`.
85    pub async fn connect(
86        addr: SocketAddr,
87        options: ConnectionOptions,
88    ) -> Result<Self, ClientError> {
89        let stream = tokio::time::timeout(options.connect_timeout, TcpStream::connect(addr))
90            .await
91            .map_err(|_| ClientError::Timeout(options.connect_timeout))?
92            .map_err(|source| ClientError::Connect { addr, source })?;
93
94        stream.set_nodelay(true).ok();
95
96        Self::from_stream(Box::new(stream), options).await
97    }
98
99    /// Connect to `addr` honouring `options.security`: a secured (TLS/SASL)
100    /// dial when a policy is set, plaintext otherwise.
101    ///
102    /// This is the single connect entry point for every metadata-client
103    /// site (pool, admin, RLMM fetch loop) so the plaintext-vs-secured
104    /// branch can't drift between them. The plaintext (`None`) path is
105    /// byte-identical to [`Self::connect`].
106    ///
107    /// # Errors
108    /// Propagates [`Self::connect`] / [`Self::connect_secured`] failures.
109    pub async fn connect_with_options(
110        addr: SocketAddr,
111        options: ConnectionOptions,
112    ) -> Result<Self, ClientError> {
113        match options.security.clone() {
114            Some(sec) => Self::connect_secured(addr, options, sec.as_ref()).await,
115            None => Self::connect(addr, options).await,
116        }
117    }
118
119    /// Connect to `addr`, applying `security` (TLS then SASL) before the
120    /// API-versions bootstrap. `Plaintext` is identical to [`Self::connect`].
121    ///
122    /// # Errors
123    ///
124    /// Returns [`ClientError::Connect`] / [`ClientError::Timeout`] on the
125    /// TCP dial, or [`ClientError::Io`] if the TLS or SASL handshake fails
126    /// or the security policy is internally inconsistent (e.g. a TLS
127    /// protocol with no TLS config).
128    pub async fn connect_secured(
129        addr: SocketAddr,
130        options: ConnectionOptions,
131        security: &crate::security::ClientSecurity,
132    ) -> Result<Self, ClientError> {
133        let tcp = tokio::time::timeout(options.connect_timeout, TcpStream::connect(addr))
134            .await
135            .map_err(|_| ClientError::Timeout(options.connect_timeout))?
136            .map_err(|source| ClientError::Connect { addr, source })?;
137        tcp.set_nodelay(true).ok();
138
139        // 1. TLS (if the protocol demands it).
140        let mut stream: Box<dyn ClientDuplex> = if security.protocol.requires_tls() {
141            let tls = security.tls.as_ref().ok_or_else(|| {
142                ClientError::Io(std::io::Error::other("TLS protocol without tls config"))
143            })?;
144            let connector = tls
145                .connector()
146                .map_err(|e| ClientError::Io(std::io::Error::other(e)))?;
147            let sni =
148                tokio_rustls::rustls::pki_types::ServerName::try_from(tls.server_name.clone())
149                    .map_err(|e| {
150                        ClientError::Io(std::io::Error::other(format!("invalid SNI: {e}")))
151                    })?;
152            let s = connector
153                .connect(sni, tcp)
154                .await
155                .map_err(|e| ClientError::Io(std::io::Error::other(e.to_string())))?;
156            Box::new(s)
157        } else {
158            Box::new(tcp)
159        };
160
161        // 2. SASL (if the protocol demands it).
162        if security.protocol.requires_sasl() {
163            let creds = security.sasl.as_ref().ok_or_else(|| {
164                ClientError::Io(std::io::Error::other("SASL protocol without credentials"))
165            })?;
166            // GSSAPI SPN host: explicit `sasl_host`, else TLS SNI, else the
167            // connection's target IP, else "localhost". The target IP is a
168            // last resort — for GSSAPI the caller should set `sasl_host` so
169            // the principal matches the broker's advertised hostname.
170            let target = addr.ip().to_string();
171            let server_name = security.sasl_handshake_host(Some(target.as_str()));
172            crate::sasl::outbound_sasl(&mut *stream, creds, server_name)
173                .await
174                .map_err(|e| ClientError::Io(std::io::Error::other(e.to_string())))?;
175        }
176
177        Self::from_stream(stream, options).await
178    }
179
180    /// Build a `Connection` over a pre-established, optionally
181    /// pre-authenticated stream. Negotiates API versions over the stream
182    /// and returns a usable `Connection`.
183    ///
184    /// Used by the broker's `InterBrokerClient` integration: TLS + SASL
185    /// handshake run before this call, so the stream is already
186    /// authenticated. From here on the connection's normal request /
187    /// response framing applies.
188    pub async fn from_stream(
189        stream: Box<dyn ClientDuplex>,
190        options: ConnectionOptions,
191    ) -> Result<Self, ClientError> {
192        let (writer_tx, writer_rx) = mpsc::channel::<DispatchItem>(64);
193        let shutdown = CancellationToken::new();
194        let pending: Pending = Arc::new(DashMap::new());
195
196        let (reader_handle, writer_handle) =
197            spawn_io_tasks(stream, writer_rx, shutdown.clone(), Arc::clone(&pending));
198
199        let mut conn = Self {
200            inner: Arc::new(ConnectionInner {
201                versions: ApiVersionTable::default(),
202                options: options.clone(),
203                next_corr_id: AtomicI32::new(0),
204                pending,
205                writer_tx,
206                shutdown,
207                _reader: reader_handle,
208                _writer: writer_handle,
209            }),
210        };
211
212        let versions = fetch_api_versions(&conn).await?;
213        let inner = Arc::get_mut(&mut conn.inner).expect("unique handle at connect-time");
214        inner.versions = versions;
215
216        Ok(conn)
217    }
218
219    /// Send a typed request and await the typed response.
220    ///
221    /// The version is negotiated from the broker-advertised table populated
222    /// during `connect`. The request and response headers are encoded and
223    /// decoded automatically.
224    ///
225    /// # Errors
226    ///
227    /// Returns `ClientError::IncompatibleVersion` if there is no mutually
228    /// supported version, `ClientError::Disconnected` if the I/O loop has
229    /// exited, or `ClientError::Timeout` if no response arrives in time.
230    pub async fn send<R: ProtocolRequest>(&self, req: R) -> Result<R::Response, ClientError> {
231        // 1. Negotiate version.
232        let version = self.inner.versions.negotiate::<R>()?;
233
234        // 2. Allocate correlation ID.
235        let corr_id = self.inner.next_corr_id.fetch_add(1, Ordering::Relaxed);
236
237        // 3. Build request header + encoded body into one frame.
238        //
239        // The header has a trailing tagged-fields byte (header v2) iff the
240        // body is flexible. The `client_id` field is always i16 NULLABLE_STRING
241        // per the upstream `RequestHeader.json` schema.
242        let body_flexible = version >= R::FLEXIBLE_MIN;
243        let mut frame = build_request_header(
244            R::API_KEY,
245            version,
246            corr_id,
247            &self.inner.options.client_id,
248            body_flexible,
249        );
250        req.encode(&mut frame, version)?;
251
252        // 4. Register the oneshot before dispatching (avoids a race).
253        let (tx, rx) = oneshot::channel::<Result<Bytes, ClientError>>();
254        self.inner.pending.insert(corr_id, tx);
255
256        // 5. Dispatch to writer.
257        self.inner
258            .writer_tx
259            .send(DispatchItem {
260                bytes: frame.freeze(),
261            })
262            .await
263            .map_err(|_| ClientError::Disconnected)?;
264
265        // 6. Await response with timeout.
266        let body_bytes = match tokio::time::timeout(self.inner.options.request_timeout, rx).await {
267            Ok(Ok(Ok(b))) => b,
268            Ok(Ok(Err(e))) => return Err(e),
269            Ok(Err(_recv_closed)) => return Err(ClientError::Disconnected),
270            Err(_timeout) => {
271                // Evict the pending entry so the reader won't try to fulfil it.
272                self.inner.pending.remove(&corr_id);
273                return Err(ClientError::Timeout(self.inner.options.request_timeout));
274            }
275        };
276
277        // 7. Decode the response.
278        //
279        // The reader has already stripped the 4-byte correlation_id prefix.
280        // What remains is: [ResponseHeader fields after corr_id] + [response body].
281        //
282        // ResponseHeader version rules:
283        //   - ApiVersionsResponse (api_key=18): always ResponseHeader v0, which
284        //     has NO fields after the correlation_id. This is a long-standing
285        //     Kafka asymmetry — even flexible ApiVersions responses use v0 header.
286        //   - All other flexible messages (version >= FLEXIBLE_MIN): ResponseHeader
287        //     v1 adds 1 byte for the tagged-fields count (0x00 when empty).
288        //   - Non-flexible messages: ResponseHeader v0 (no bytes after corr_id).
289        let mut cursor: &[u8] = &body_bytes;
290        let uses_flexible_resp_header = body_flexible && R::API_KEY != API_VERSIONS_KEY;
291        if uses_flexible_resp_header && !cursor.is_empty() {
292            // Consume the tagged-fields byte (always 0x00 in practice).
293            cursor = &cursor[1..];
294        }
295
296        let resp = <R::Response as crabka_protocol::Decode>::decode(&mut cursor, version)?;
297        Ok(resp)
298    }
299
300    /// Send a hand-framed request and await the raw response body.
301    ///
302    /// This bypasses the typed [`ProtocolRequest`] codegen path so callers
303    /// can speak Crabka-private APIs (e.g., the controller's Raft RPCs at
304    /// api keys 1000+) whose wire types live outside `crabka-protocol`.
305    ///
306    /// The header is always written as `RequestHeader v2` (flexible) with
307    /// an empty trailing tagged-fields byte. The response is assumed to
308    /// use `ResponseHeader v1` (flexible): the I/O loop strips the 4-byte
309    /// correlation id, and this method strips the leading tagged-fields
310    /// byte before returning. Callers receive the raw body bytes only.
311    ///
312    /// `body` is the encoded request body (everything after the request
313    /// header), exactly as it should appear on the wire.
314    ///
315    /// # Errors
316    ///
317    /// Returns [`ClientError::Disconnected`] if the I/O loop has exited
318    /// or [`ClientError::Timeout`] if no response arrives within the
319    /// configured request timeout.
320    pub async fn raw_request(
321        &self,
322        api_key: i16,
323        api_version: i16,
324        body: Bytes,
325    ) -> Result<Bytes, ClientError> {
326        let corr_id = self.inner.next_corr_id.fetch_add(1, Ordering::Relaxed);
327
328        // RequestHeader v2 (flexible). Crabka-private api keys are always
329        // declared flexible so the header shape is predictable.
330        let mut frame = build_request_header(
331            api_key,
332            api_version,
333            corr_id,
334            &self.inner.options.client_id,
335            true,
336        );
337        frame.put_slice(&body);
338
339        let (tx, rx) = oneshot::channel::<Result<Bytes, ClientError>>();
340        self.inner.pending.insert(corr_id, tx);
341
342        self.inner
343            .writer_tx
344            .send(DispatchItem {
345                bytes: frame.freeze(),
346            })
347            .await
348            .map_err(|_| ClientError::Disconnected)?;
349
350        let body_bytes = match tokio::time::timeout(self.inner.options.request_timeout, rx).await {
351            Ok(Ok(Ok(b))) => b,
352            Ok(Ok(Err(e))) => return Err(e),
353            Ok(Err(_recv_closed)) => return Err(ClientError::Disconnected),
354            Err(_timeout) => {
355                self.inner.pending.remove(&corr_id);
356                return Err(ClientError::Timeout(self.inner.options.request_timeout));
357            }
358        };
359
360        // ResponseHeader v1: 1-byte empty-tagged-fields marker after the
361        // already-stripped correlation id. Drop it if present.
362        let slice: &[u8] = &body_bytes;
363        let out = if slice.is_empty() {
364            Bytes::new()
365        } else {
366            body_bytes.slice(1..)
367        };
368        Ok(out)
369    }
370
371    /// Negotiated API versions known to this connection.
372    #[must_use]
373    pub fn versions(&self) -> &ApiVersionTable {
374        &self.inner.versions
375    }
376
377    /// Close the connection, cancelling all background tasks.
378    pub fn close(self) {
379        self.inner.shutdown.cancel();
380        // The Arc gets dropped when `self` does; `JoinHandle`s abort naturally.
381    }
382}
383
384/// Spawn the combined I/O task on a single `Framed` socket.
385///
386/// One task owns the entire `Framed` and `select!`s between incoming
387/// frames (from the broker) and outgoing dispatch items (from callers).
388/// A no-op second handle preserves the `(reader, writer)` API shape
389/// expected by `ConnectionInner`.
390fn spawn_io_tasks(
391    stream: Box<dyn ClientDuplex>,
392    mut writer_rx: mpsc::Receiver<DispatchItem>,
393    shutdown: CancellationToken,
394    pending: Pending,
395) -> (JoinHandle<()>, JoinHandle<()>) {
396    use futures_util::{SinkExt, StreamExt};
397
398    let mut framed = crate::transport::frame_generic(stream);
399    let pending_for_drain = Arc::clone(&pending);
400
401    let combined = tokio::spawn(async move {
402        loop {
403            tokio::select! {
404                () = shutdown.cancelled() => break,
405                Some(item) = writer_rx.recv() => {
406                    if framed.send(item.bytes).await.is_err() {
407                        break;
408                    }
409                }
410                maybe_frame = framed.next() => {
411                    let Some(frame) = maybe_frame else { break; };
412                    let Ok(frame) = frame else { break; };
413                    if frame.len() < 4 { continue; }
414                    let corr_id = i32::from_be_bytes([frame[0], frame[1], frame[2], frame[3]]);
415                    if let Some((_, tx)) = pending.remove(&corr_id) {
416                        let body = Bytes::copy_from_slice(&frame[4..]);
417                        let _ = tx.send(Ok(body));
418                    }
419                }
420            }
421        }
422        // Drain pending: every outstanding request fails with Disconnected.
423        let keys: Vec<i32> = pending_for_drain.iter().map(|e| *e.key()).collect();
424        for k in keys {
425            if let Some((_, tx)) = pending_for_drain.remove(&k) {
426                let _ = tx.send(Err(ClientError::Disconnected));
427            }
428        }
429    });
430
431    let noop = tokio::spawn(async {});
432    (combined, noop)
433}
434
435/// Build an encoded `RequestHeader` into a `BytesMut`.
436///
437/// Kafka has only two `RequestHeader` formats:
438///
439/// - **v1** (non-flexible): `api_key` + `version` + `corr_id` + i16
440///   `client_id` length + `client_id` bytes.
441/// - **v2** (flexible): same fields *plus* a trailing `tagged_fields` byte
442///   (`0x00` when empty).
443///
444/// Note that `client_id` is `NULLABLE_STRING` (i16 length) in **both**
445/// versions — the upstream `RequestHeader.json` schema marks the field as
446/// `"flexibleVersions": "none"`, so even a v2 header keeps the i16-length
447/// encoding. Using UVARINT here causes the broker to misread the length and
448/// throw `InvalidRequestException` during header parsing.
449///
450/// Pass `with_tagged_fields = true` iff the request body is flexible
451/// (`version >= R::FLEXIBLE_MIN`).
452fn build_request_header(
453    api_key: i16,
454    version: i16,
455    corr_id: i32,
456    client_id: &str,
457    with_tagged_fields: bool,
458) -> BytesMut {
459    let mut buf = BytesMut::with_capacity(32);
460    buf.put_i16(api_key);
461    buf.put_i16(version);
462    buf.put_i32(corr_id);
463    let n = i16::try_from(client_id.len()).expect("client_id fits in i16");
464    buf.put_i16(n);
465    buf.put_slice(client_id.as_bytes());
466    if with_tagged_fields {
467        buf.put_u8(0); // empty tagged fields
468    }
469    buf
470}
471
472/// Send an `ApiVersionsRequest` at version 0 and return the negotiated table.
473///
474/// This is the bootstrap step inside `connect`: no version table exists yet,
475/// so we cannot use `Connection::send`. Version 0 is guaranteed to be
476/// supported by every broker.
477async fn fetch_api_versions(conn: &Connection) -> Result<ApiVersionTable, ClientError> {
478    use crabka_protocol::Encode;
479    use crabka_protocol::owned::api_versions_request::ApiVersionsRequest;
480    use crabka_protocol::owned::api_versions_response::ApiVersionsResponse;
481
482    let req = ApiVersionsRequest::default();
483    let corr_id = conn.inner.next_corr_id.fetch_add(1, Ordering::Relaxed);
484
485    // v0 is non-flexible: header v1, no tagged-fields byte.
486    let mut frame = build_request_header(
487        ApiVersionsRequest::API_KEY,
488        0,
489        corr_id,
490        &conn.inner.options.client_id,
491        false,
492    );
493    req.encode(&mut frame, 0)?;
494
495    let (tx, rx) = oneshot::channel::<Result<Bytes, ClientError>>();
496    conn.inner.pending.insert(corr_id, tx);
497    conn.inner
498        .writer_tx
499        .send(DispatchItem {
500            bytes: frame.freeze(),
501        })
502        .await
503        .map_err(|_| ClientError::Disconnected)?;
504
505    let body_bytes = tokio::time::timeout(conn.inner.options.connect_timeout, rx)
506        .await
507        .map_err(|_| ClientError::Timeout(conn.inner.options.connect_timeout))?
508        .map_err(|_| ClientError::Disconnected)??;
509
510    // ResponseHeader v0: only correlation_id (already stripped by the reader).
511    // No tagged-fields byte — this holds for all ApiVersionsResponse versions,
512    // including flexible ones (the Kafka asymmetry documented in `send`).
513    let mut cursor: &[u8] = &body_bytes;
514    let resp = <ApiVersionsResponse as crabka_protocol::Decode>::decode(&mut cursor, 0)?;
515    if resp.error_code != 0 {
516        return Err(ClientError::Server {
517            error_code: resp.error_code,
518        });
519    }
520
521    let entries = resp
522        .api_keys
523        .iter()
524        .map(|k| (k.api_key, k.min_version, k.max_version));
525    Ok(ApiVersionTable::from_entries(entries))
526}
527
#[cfg(test)]
mod secured_tests {
    use super::*;
    use crate::security::{ClientSecurity, SaslCredentials};
    use crabka_security::ListenerProtocol;

    /// A SASL_PLAINTEXT connect must run the SASL exchange and then the
    /// ApiVersions bootstrap.
    ///
    /// The scripted fake broker answers, in order:
    /// - `SaslHandshakeResponse` v1 with a v0 response header,
    /// - `SaslAuthenticateResponse` v2 with a flexible (v1) response header,
    /// - a default `ApiVersionsResponse` v0 with a v0 response header,
    ///   so that `from_stream` succeeds.
    #[tokio::test]
    async fn connect_secured_runs_sasl_then_api_versions() {
        use crabka_protocol::Encode;
        use crabka_protocol::owned::api_versions_response::ApiVersionsResponse;
        use crabka_protocol::owned::sasl_authenticate_response::SaslAuthenticateResponse;
        use crabka_protocol::owned::sasl_handshake_response::SaslHandshakeResponse;
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let fake_broker = tokio::spawn(async move {
            let (mut sock, _) = listener.accept().await.unwrap();

            let mut handshake = BytesMut::new();
            SaslHandshakeResponse {
                error_code: 0,
                ..Default::default()
            }
            .encode(&mut handshake, 1)
            .unwrap();

            let mut authenticate = BytesMut::new();
            SaslAuthenticateResponse {
                error_code: 0,
                ..Default::default()
            }
            .encode(&mut authenticate, 2)
            .unwrap();

            let mut api_versions = BytesMut::new();
            ApiVersionsResponse::default()
                .encode(&mut api_versions, 0)
                .unwrap();

            // (response body, response header carries a tagged-fields byte?)
            // ApiVersions always answers with a v0 response header.
            let script: [(BytesMut, bool); 3] = [
                (handshake, false),
                (authenticate, true),
                (api_versions, false),
            ];

            for (payload, tagged_header) in script {
                // One length-prefixed request. Its correlation id follows
                // api_key (i16) and api_version (i16).
                let request_len = sock.read_u32().await.unwrap();
                let mut request = vec![0u8; request_len as usize];
                sock.read_exact(&mut request).await.unwrap();
                let correlation_id =
                    i32::from_be_bytes([request[4], request[5], request[6], request[7]]);

                let mut response = BytesMut::new();
                response.put_i32(correlation_id);
                if tagged_header {
                    response.put_u8(0);
                }
                response.put_slice(&payload);

                let response_len = u32::try_from(response.len()).unwrap();
                sock.write_u32(response_len).await.unwrap();
                sock.write_all(&response).await.unwrap();
                sock.flush().await.unwrap();
            }
        });

        let security = ClientSecurity {
            protocol: ListenerProtocol::SaslPlaintext,
            tls: None,
            sasl: Some(SaslCredentials::Plain {
                username: "u".into(),
                password: "p".into(),
            }),
            sasl_host: None,
        };

        let conn = Connection::connect_secured(addr, ConnectionOptions::default(), &security)
            .await
            .expect("secured connect completes");
        conn.close();
        fake_broker.await.unwrap();
    }
}