Skip to main content

rivven_client/
client.rs

1use crate::{Error, MessageData, Request, Response, Result};
2use bytes::Bytes;
3use rivven_core::PasswordHash;
4use rivven_protocol::SyncGroupAssignments;
5use sha2::{Digest, Sha256};
6use std::time::Duration;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tokio::net::TcpStream;
9use tracing::{debug, info};
10
11#[cfg(feature = "tls")]
12use rivven_core::tls::{TlsClientStream, TlsConfig, TlsConnector};
13
/// Upper bound on a single response frame (100 MiB).
///
/// Stops a malicious or buggy server from announcing an enormous length
/// prefix and making the client allocate unbounded memory.
const DEFAULT_MAX_RESPONSE_SIZE: usize = 100 << 20;
16
17// Maximum request size — aligned with server `max_request_size` and
18// `rivven_protocol::MAX_MESSAGE_SIZE` (10 MiB). Requests exceeding this
19// are rejected client-side before touching the wire, preventing TCP-level
20// deadlocks when the server must drain an oversized body.
21const DEFAULT_MAX_REQUEST_SIZE: usize = rivven_protocol::MAX_MESSAGE_SIZE;
22
/// How long to wait for a connection to be established (10 s).
///
/// Without it an unreachable host would stall the caller until the
/// OS-level TCP timeout (~75-120 s) fires.
const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_millis(10_000);

/// How long a single request/response round-trip may take (30 s).
///
/// Keeps callers from blocking forever on a stalled or silent server.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(30_000);
30
31// ============================================================================
32// Stream Wrapper
33// ============================================================================
34
35/// Wrapper for either plaintext or TLS streams
36/// Note: TLS variant is significantly larger due to TLS state, but boxing
37/// would add indirection overhead for every I/O operation
38#[allow(clippy::large_enum_variant)]
39pub(crate) enum ClientStream {
40    Plaintext(TcpStream),
41    #[cfg(feature = "tls")]
42    Tls(TlsClientStream<TcpStream>),
43}
44
45impl AsyncRead for ClientStream {
46    fn poll_read(
47        self: std::pin::Pin<&mut Self>,
48        cx: &mut std::task::Context<'_>,
49        buf: &mut tokio::io::ReadBuf<'_>,
50    ) -> std::task::Poll<std::io::Result<()>> {
51        match self.get_mut() {
52            ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_read(cx, buf),
53            #[cfg(feature = "tls")]
54            ClientStream::Tls(s) => std::pin::Pin::new(s).poll_read(cx, buf),
55        }
56    }
57}
58
59impl AsyncWrite for ClientStream {
60    fn poll_write(
61        self: std::pin::Pin<&mut Self>,
62        cx: &mut std::task::Context<'_>,
63        buf: &[u8],
64    ) -> std::task::Poll<std::io::Result<usize>> {
65        match self.get_mut() {
66            ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_write(cx, buf),
67            #[cfg(feature = "tls")]
68            ClientStream::Tls(s) => std::pin::Pin::new(s).poll_write(cx, buf),
69        }
70    }
71
72    fn poll_flush(
73        self: std::pin::Pin<&mut Self>,
74        cx: &mut std::task::Context<'_>,
75    ) -> std::task::Poll<std::io::Result<()>> {
76        match self.get_mut() {
77            ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_flush(cx),
78            #[cfg(feature = "tls")]
79            ClientStream::Tls(s) => std::pin::Pin::new(s).poll_flush(cx),
80        }
81    }
82
83    fn poll_shutdown(
84        self: std::pin::Pin<&mut Self>,
85        cx: &mut std::task::Context<'_>,
86    ) -> std::task::Poll<std::io::Result<()>> {
87        match self.get_mut() {
88            ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_shutdown(cx),
89            #[cfg(feature = "tls")]
90            ClientStream::Tls(s) => std::pin::Pin::new(s).poll_shutdown(cx),
91        }
92    }
93}
94
95// ============================================================================
96// Client
97// ============================================================================
98
99/// Rivven client for connecting to a Rivven server
100pub struct Client {
101    stream: ClientStream,
102    next_correlation_id: u32,
103    /// Per-request timeout for send_request() I/O.
104    request_timeout: Duration,
105    /// Set to true when the stream is desynchronized (e.g. correlation ID
106    /// mismatch). All subsequent requests immediately return an error,
107    /// forcing the caller to reconnect.
108    poisoned: bool,
109}
110
111impl Client {
112    /// Connect to a Rivven server (plaintext)
113    ///
114    /// Automatically sends a protocol handshake after connecting.
115    /// The handshake validates protocol version compatibility before any
116    /// application requests are sent.
117    ///
118    /// Uses a default connection timeout of 10 seconds. For a custom
119    /// timeout, use [`connect_with_timeout`](Self::connect_with_timeout).
120    pub async fn connect(addr: &str) -> Result<Self> {
121        Self::connect_with_timeout(addr, DEFAULT_CONNECTION_TIMEOUT).await
122    }
123
124    /// Connect to a Rivven server with a custom connection timeout.
125    ///
126    /// Wraps the TCP connect + handshake in `tokio::time::timeout` so
127    /// callers never block longer than `timeout` on an unreachable host.
128    pub async fn connect_with_timeout(addr: &str, timeout: Duration) -> Result<Self> {
129        info!("Connecting to Rivven server at {}", addr);
130        let stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
131            .await
132            .map_err(|_| Error::TimeoutWithMessage(format!("Connection to {} timed out", addr)))?
133            .map_err(|e| Error::ConnectionError(e.to_string()))?;
134
135        // Disable Nagle's algorithm for lower latency on small request/response RPC.
136        let _ = stream.set_nodelay(true);
137
138        let mut client = Self {
139            stream: ClientStream::Plaintext(stream),
140            next_correlation_id: 0,
141            request_timeout: DEFAULT_REQUEST_TIMEOUT,
142            poisoned: false,
143        };
144
145        // auto-handshake on connect
146        client.handshake("rivven-client").await?;
147
148        Ok(client)
149    }
150
151    /// Connect to a Rivven server with TLS
152    #[cfg(feature = "tls")]
153    pub async fn connect_tls(
154        addr: &str,
155        tls_config: &TlsConfig,
156        server_name: &str,
157    ) -> Result<Self> {
158        Self::connect_tls_with_timeout(addr, tls_config, server_name, DEFAULT_CONNECTION_TIMEOUT)
159            .await
160    }
161
162    /// Connect to a Rivven server with TLS and a custom connection timeout.
163    #[cfg(feature = "tls")]
164    pub async fn connect_tls_with_timeout(
165        addr: &str,
166        tls_config: &TlsConfig,
167        server_name: &str,
168        timeout: Duration,
169    ) -> Result<Self> {
170        info!("Connecting to Rivven server at {} with TLS", addr);
171
172        // Resolve address — supports both IP:port and DNS hostname:port,
173        // unlike the old SocketAddr::parse() which rejected DNS names.
174        let tcp_stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
175            .await
176            .map_err(|_| {
177                Error::TimeoutWithMessage(format!("TLS connection to {} timed out", addr))
178            })?
179            .map_err(|e| Error::ConnectionError(format!("TCP connection error: {}", e)))?;
180
181        // Disable Nagle for low-latency RPC (matches plaintext path)
182        tcp_stream
183            .set_nodelay(true)
184            .map_err(|e| Error::ConnectionError(format!("Failed to set TCP_NODELAY: {}", e)))?;
185
186        // Create TLS connector
187        let connector = TlsConnector::new(tls_config)
188            .map_err(|e| Error::ConnectionError(format!("TLS config error: {}", e)))?;
189
190        // Wrap the TCP stream in TLS
191        let tls_stream = connector
192            .connect(tcp_stream, server_name)
193            .await
194            .map_err(|e| Error::ConnectionError(format!("TLS handshake error: {}", e)))?;
195
196        info!("TLS connection established to {} ({})", addr, server_name);
197
198        let mut client = Self {
199            stream: ClientStream::Tls(tls_stream),
200            next_correlation_id: 0,
201            request_timeout: DEFAULT_REQUEST_TIMEOUT,
202            poisoned: false,
203        };
204
205        // auto-handshake on TLS connect
206        client.handshake("rivven-client").await?;
207
208        Ok(client)
209    }
210
211    /// Connect with mTLS (mutual TLS) using client certificate
212    #[cfg(feature = "tls")]
213    pub async fn connect_mtls(
214        addr: &str,
215        cert_path: impl Into<std::path::PathBuf>,
216        key_path: impl Into<std::path::PathBuf>,
217        ca_path: impl Into<std::path::PathBuf> + Clone,
218        server_name: &str,
219    ) -> Result<Self> {
220        let tls_config = TlsConfig::mtls_from_pem_files(cert_path, key_path, ca_path);
221        Self::connect_tls(addr, &tls_config, server_name).await
222    }
223
224    // ========================================================================
225    // Handshake
226    // ========================================================================
227
228    /// Send a protocol version handshake to the server.
229    ///
230    /// Validates that the server speaks a compatible protocol version.
231    /// Called automatically by `connect()` and `connect_tls()`.
232    pub async fn handshake(&mut self, client_id: &str) -> Result<()> {
233        let request = Request::Handshake {
234            protocol_version: rivven_protocol::PROTOCOL_VERSION,
235            client_id: client_id.to_string(),
236        };
237
238        let response = self.send_request(request).await?;
239
240        match response {
241            Response::HandshakeResult {
242                compatible,
243                message: _,
244                server_version,
245            } => {
246                if compatible {
247                    info!(
248                        "Handshake OK (client v{}, server v{})",
249                        rivven_protocol::PROTOCOL_VERSION,
250                        server_version
251                    );
252                    Ok(())
253                } else {
254                    Err(Error::ProtocolError(
255                        rivven_protocol::ProtocolError::VersionMismatch {
256                            expected: rivven_protocol::PROTOCOL_VERSION,
257                            actual: server_version,
258                        },
259                    ))
260                }
261            }
262            Response::Error { message } => {
263                // Server doesn't support handshake — log warning but proceed
264                // for backward compatibility with older servers
265                tracing::warn!(
266                    "Server returned error on handshake: {}, proceeding anyway",
267                    message
268                );
269                Ok(())
270            }
271            _ => {
272                // Unknown response — old server, proceed anyway
273                tracing::warn!(
274                    "Server did not return HandshakeResult, proceeding without version check"
275                );
276                Ok(())
277            }
278        }
279    }
280
281    /// Consume the client and return the underlying TCP stream.
282    ///
283    /// This is used by the producer to hand off the authenticated +
284    /// handshaked connection to its background sender task.
285    /// Only works for plaintext connections; for TLS use `into_client_stream()`.
286    pub fn into_stream(self) -> Result<TcpStream> {
287        match self.stream {
288            ClientStream::Plaintext(s) => Ok(s),
289            #[cfg(feature = "tls")]
290            ClientStream::Tls(_) => Err(Error::ConnectionError(
291                "Cannot extract TcpStream from TLS connection. Use into_client_stream() instead."
292                    .to_string(),
293            )),
294        }
295    }
296
297    /// Consume the client and return the underlying `ClientStream`.
298    ///
299    /// Works for both plaintext and TLS connections. Used by the
300    /// producer to hand off the authenticated + handshaked connection
301    /// to its background sender task.
302    pub(crate) fn into_client_stream(self) -> ClientStream {
303        self.stream
304    }
305
306    /// Set the per-request timeout for `send_request()` I/O.
307    pub fn set_request_timeout(&mut self, timeout: Duration) {
308        self.request_timeout = timeout;
309    }
310
311    /// Returns `true` when the underlying transport is TLS-encrypted.
312    pub fn is_tls(&self) -> bool {
313        match &self.stream {
314            ClientStream::Plaintext(_) => false,
315            #[cfg(feature = "tls")]
316            ClientStream::Tls(_) => true,
317        }
318    }
319
320    // ========================================================================
321    // Authentication Methods
322    // ========================================================================
323
324    /// Authenticate with simple username/password
325    ///
326    /// # Security — plaintext credentials
327    ///
328    /// This uses SASL/PLAIN which sends the password in **cleartext** on the
329    /// wire.  The client automatically sets `require_tls = true` when the
330    /// underlying connection is TLS-encrypted so the server will reject the
331    /// request if it somehow arrives over a non-TLS channel.  For untrusted
332    /// networks, prefer `authenticate_scram()` which never sends the password.
333    #[allow(deprecated)]
334    pub async fn authenticate(&mut self, username: &str, password: &str) -> Result<AuthSession> {
335        // Client-side guard: SASL/PLAIN sends the password in cleartext.
336        // Without TLS, a network observer captures credentials immediately.
337        if !self.is_tls() {
338            return Err(Error::AuthenticationFailed(
339                "SASL/PLAIN requires a TLS connection — use authenticate_scram() for plaintext channels".to_string(),
340            ));
341        }
342
343        // Tell the server whether we are on a TLS connection so it can
344        // reject the request when TLS is not active.
345        let require_tls = true; // always true — we checked above
346        let request = Request::Authenticate {
347            username: username.to_string(),
348            password: password.to_string(),
349            require_tls,
350        };
351
352        let response = self.send_request(request).await?;
353
354        match response {
355            Response::Authenticated {
356                session_id,
357                expires_in,
358            } => {
359                info!("Authenticated as '{}'", username);
360                Ok(AuthSession {
361                    session_id,
362                    expires_in,
363                })
364            }
365            Response::Error { message } => Err(Error::AuthenticationFailed(message)),
366            _ => Err(Error::InvalidResponse),
367        }
368    }
369
370    /// Authenticate using SCRAM-SHA-256 (secure challenge-response)
371    ///
372    /// SCRAM-SHA-256 (RFC 5802/7677) provides:
373    /// - Password never sent over the wire
374    /// - Mutual authentication (server proves it knows password too)
375    /// - Protection against replay attacks
376    ///
377    /// # Example
378    /// ```no_run
379    /// # use rivven_client::Client;
380    /// # async fn example() -> rivven_client::Result<()> {
381    /// let mut client = Client::connect("127.0.0.1:9092").await?;
382    /// let session = client.authenticate_scram("alice", "password123").await?;
383    /// println!("Session: {} (expires in {}s)", session.session_id, session.expires_in);
384    /// # Ok(())
385    /// # }
386    /// ```
387    pub async fn authenticate_scram(
388        &mut self,
389        username: &str,
390        password: &str,
391    ) -> Result<AuthSession> {
392        // Step 1: Generate client nonce and send client-first message
393        let client_nonce = generate_nonce();
394        let client_first_bare = format!("n={},r={}", escape_username(username), client_nonce);
395        let client_first = format!("n,,{}", client_first_bare);
396
397        debug!("SCRAM: Sending client-first");
398        let request = Request::ScramClientFirst {
399            message: Bytes::from(client_first.clone()),
400        };
401
402        let response = self.send_request(request).await?;
403
404        // Step 2: Parse server-first message
405        let server_first = match response {
406            Response::ScramServerFirst { message } => String::from_utf8(message.to_vec())
407                .map_err(|_| Error::AuthenticationFailed("Invalid server-first encoding".into()))?,
408            Response::Error { message } => return Err(Error::AuthenticationFailed(message)),
409            _ => return Err(Error::InvalidResponse),
410        };
411
412        debug!("SCRAM: Received server-first");
413
414        // Parse server-first: r=<nonce>,s=<salt>,i=<iterations>
415        let (combined_nonce, salt_b64, iterations) = parse_server_first(&server_first)?;
416
417        // Verify server nonce starts with our client nonce
418        if !combined_nonce.starts_with(&client_nonce) {
419            return Err(Error::AuthenticationFailed("Server nonce mismatch".into()));
420        }
421
422        // Decode salt
423        let salt = base64_decode(&salt_b64)
424            .map_err(|_| Error::AuthenticationFailed("Invalid salt encoding".into()))?;
425
426        // Step 3: Compute client proof
427        let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
428        let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
429        let stored_key = sha256(&client_key);
430
431        let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
432        let auth_message = format!(
433            "{},{},{}",
434            client_first_bare, server_first, client_final_without_proof
435        );
436
437        let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
438        let client_proof = xor_bytes(&client_key, &client_signature);
439        let client_proof_b64 = base64_encode(&client_proof);
440
441        // Step 4: Send client-final message
442        let client_final = format!("{},p={}", client_final_without_proof, client_proof_b64);
443
444        debug!("SCRAM: Sending client-final");
445        let request = Request::ScramClientFinal {
446            message: Bytes::from(client_final),
447        };
448
449        let response = self.send_request(request).await?;
450
451        // Step 5: Verify server-final and get session
452        match response {
453            Response::ScramServerFinal {
454                message,
455                session_id,
456                expires_in,
457            } => {
458                let server_final = String::from_utf8(message.to_vec()).map_err(|_| {
459                    Error::AuthenticationFailed("Invalid server-final encoding".into())
460                })?;
461
462                // Check for error response
463                if let Some(error_msg) = server_final.strip_prefix("e=") {
464                    return Err(Error::AuthenticationFailed(error_msg.to_string()));
465                }
466
467                // Verify server signature (mutual authentication)
468                if let Some(verifier_b64) = server_final.strip_prefix("v=") {
469                    let server_key = PasswordHash::hmac_sha256(&salted_password, b"Server Key");
470                    let expected_server_sig =
471                        PasswordHash::hmac_sha256(&server_key, auth_message.as_bytes());
472                    let expected_verifier = base64_encode(&expected_server_sig);
473
474                    if verifier_b64 != expected_verifier {
475                        return Err(Error::AuthenticationFailed(
476                            "Server verification failed".into(),
477                        ));
478                    }
479                }
480
481                let session_id = session_id.ok_or_else(|| {
482                    Error::AuthenticationFailed("No session ID in response".into())
483                })?;
484                let expires_in = expires_in
485                    .ok_or_else(|| Error::AuthenticationFailed("No expiry in response".into()))?;
486
487                info!("SCRAM authentication successful for '{}'", username);
488                Ok(AuthSession {
489                    session_id,
490                    expires_in,
491                })
492            }
493            Response::Error { message } => Err(Error::AuthenticationFailed(message)),
494            _ => Err(Error::InvalidResponse),
495        }
496    }
497
498    // ========================================================================
499    // Request/Response Handling
500    // ========================================================================
501
502    /// Send a request and receive a response.
503    ///
504    /// The entire I/O round-trip (write + read) is wrapped in
505    /// `tokio::time::timeout` using the client's `request_timeout`
506    /// (default 30 s). A stalled or unresponsive server will therefore
507    /// return `Error::Timeout` instead of blocking the caller forever.
508    pub(crate) async fn send_request(&mut self, request: Request) -> Result<Response> {
509        let timeout_dur = self.request_timeout;
510        match tokio::time::timeout(timeout_dur, self.send_request_inner(request)).await {
511            Ok(result) => result,
512            Err(_elapsed) => {
513                // Timeout may have cancelled mid-I/O — stream is potentially
514                // desynchronized. Poison so the next call reconnects.
515                self.poisoned = true;
516                Err(Error::Timeout)
517            }
518        }
519    }
520
521    /// Inner implementation of send_request without timeout wrapper.
522    async fn send_request_inner(&mut self, request: Request) -> Result<Response> {
523        // Fail immediately if the stream is desynchronized. The caller must
524        // reconnect to get a new, clean Client instance.
525        if self.poisoned {
526            return Err(Error::ConnectionError(
527                "Client stream is desynchronized — reconnect required".into(),
528            ));
529        }
530
531        // Generate sequential correlation ID
532        let correlation_id = self.next_correlation_id;
533        self.next_correlation_id = self.next_correlation_id.wrapping_add(1);
534
535        // Serialize request with wire format prefix and correlation ID
536        let request_bytes =
537            request.to_wire(rivven_protocol::WireFormat::Postcard, correlation_id)?;
538
539        // Reject oversized requests client-side before touching the wire.
540        // This prevents a TCP-level deadlock where write_all() blocks waiting
541        // for the server to read, while the server rejects and tries to respond.
542        if request_bytes.len() > DEFAULT_MAX_REQUEST_SIZE {
543            return Err(Error::RequestTooLarge(
544                request_bytes.len(),
545                DEFAULT_MAX_REQUEST_SIZE,
546            ));
547        }
548
549        // Write length prefix + request.
550        // After the first write_all succeeds, bytes may be on the wire.
551        // Any subsequent I/O failure desynchronizes the TCP stream, so we
552        // must poison the client to prevent silent corruption.
553        let len: u32 = request_bytes
554            .len()
555            .try_into()
556            .map_err(|_| Error::RequestTooLarge(request_bytes.len(), u32::MAX as usize))?;
557        self.stream
558            .write_all(&len.to_be_bytes())
559            .await
560            .map_err(|e| {
561                self.poisoned = true;
562                Error::from(e)
563            })?;
564        self.stream.write_all(&request_bytes).await.map_err(|e| {
565            self.poisoned = true;
566            Error::from(e)
567        })?;
568        self.stream.flush().await.map_err(|e| {
569            self.poisoned = true;
570            Error::from(e)
571        })?;
572
573        // Read length prefix — request was sent, so read failure desynchronizes
574        let mut len_buf = [0u8; 4];
575        self.stream.read_exact(&mut len_buf).await.map_err(|e| {
576            self.poisoned = true;
577            Error::from(e)
578        })?;
579        let msg_len = u32::from_be_bytes(len_buf) as usize;
580
581        // Validate response size to prevent memory exhaustion from malicious server
582        if msg_len > DEFAULT_MAX_RESPONSE_SIZE {
583            self.poisoned = true;
584            return Err(Error::ResponseTooLarge(msg_len, DEFAULT_MAX_RESPONSE_SIZE));
585        }
586
587        // Read response — partial read desynchronizes
588        let mut response_buf = vec![0u8; msg_len];
589        self.stream
590            .read_exact(&mut response_buf)
591            .await
592            .map_err(|e| {
593                self.poisoned = true;
594                Error::from(e)
595            })?;
596
597        // Deserialize response (auto-detects wire format).
598        // The full response was consumed from the wire, so framing is intact
599        // regardless of deserialization outcome — no need to poison here.
600        let (response, _format, response_correlation_id) = Response::from_wire(&response_buf)?;
601
602        // Validate that the response correlation ID matches the
603        // request we sent. A mismatch indicates stream desynchronization
604        // (e.g. partial reads) or a buggy server.
605        if response_correlation_id != correlation_id {
606            // Mark the client as poisoned — the stream is no longer usable
607            // because subsequent reads would parse at wrong byte boundaries.
608            self.poisoned = true;
609            return Err(Error::ProtocolError(
610                rivven_protocol::ProtocolError::InvalidFormat(format!(
611                    "Correlation ID mismatch: expected {}, got {}",
612                    correlation_id, response_correlation_id
613                )),
614            ));
615        }
616
617        Ok(response)
618    }
619
620    /// Consume from multiple partitions using request pipelining.
621    ///
622    /// Sends all `Consume` requests back-to-back on the wire *before*
623    /// reading any response. This eliminates per-partition round-trip
624    /// latency and avoids head-of-line blocking when fetching from many
625    /// partitions over a single connection.
626    ///
627    /// The returned `Vec` has one entry per input partition in the same
628    /// order. Each entry is `Ok(messages)` or `Err(Error)`.
629    pub async fn consume_pipelined(
630        &mut self,
631        fetches: &[(&str, u32, u64, u32, Option<u8>)],
632    ) -> Result<Vec<Result<Vec<MessageData>>>> {
633        if fetches.is_empty() {
634            return Ok(Vec::new());
635        }
636        if self.poisoned {
637            return Err(Error::ConnectionError(
638                "Client stream is desynchronized — reconnect required".into(),
639            ));
640        }
641
642        let timeout_dur = self.request_timeout;
643        match tokio::time::timeout(timeout_dur, self.consume_pipelined_inner(fetches)).await {
644            Ok(result) => result,
645            Err(_elapsed) => {
646                self.poisoned = true;
647                Err(Error::Timeout)
648            }
649        }
650    }
651
652    async fn consume_pipelined_inner(
653        &mut self,
654        fetches: &[(&str, u32, u64, u32, Option<u8>)],
655    ) -> Result<Vec<Result<Vec<MessageData>>>> {
656        let mut correlation_ids = Vec::with_capacity(fetches.len());
657        let mut bytes_sent = false;
658
659        // Phase 1: Send all requests without waiting for responses.
660        for &(topic, partition, offset, max_messages, isolation_level) in fetches {
661            let correlation_id = self.next_correlation_id;
662            self.next_correlation_id = self.next_correlation_id.wrapping_add(1);
663            correlation_ids.push(correlation_id);
664
665            let request = Request::Consume {
666                topic: topic.to_string(),
667                partition,
668                offset,
669                max_messages,
670                isolation_level,
671                max_wait_ms: None,
672            };
673
674            let request_bytes = request
675                .to_wire(rivven_protocol::WireFormat::Postcard, correlation_id)
676                .inspect_err(|_| {
677                    if bytes_sent {
678                        self.poisoned = true;
679                    }
680                })?;
681
682            if request_bytes.len() > DEFAULT_MAX_REQUEST_SIZE {
683                if bytes_sent {
684                    self.poisoned = true;
685                }
686                return Err(Error::RequestTooLarge(
687                    request_bytes.len(),
688                    DEFAULT_MAX_REQUEST_SIZE,
689                ));
690            }
691
692            let len: u32 = request_bytes.len().try_into().map_err(|_| {
693                if bytes_sent {
694                    self.poisoned = true;
695                }
696                Error::RequestTooLarge(request_bytes.len(), u32::MAX as usize)
697            })?;
698            self.stream
699                .write_all(&len.to_be_bytes())
700                .await
701                .map_err(|e| {
702                    if bytes_sent {
703                        self.poisoned = true;
704                    }
705                    Error::from(e)
706                })?;
707            self.stream.write_all(&request_bytes).await.map_err(|e| {
708                self.poisoned = true; // At least length prefix was sent
709                Error::from(e)
710            })?;
711            bytes_sent = true;
712        }
713        self.stream.flush().await.map_err(|e| {
714            self.poisoned = true;
715            Error::from(e)
716        })?;
717
718        // Phase 2: Read all responses in-order.
719        let mut results = Vec::with_capacity(fetches.len());
720        let mut response_buf: Vec<u8> = Vec::with_capacity(4096);
721        for &expected_cid in &correlation_ids {
722            let mut len_buf = [0u8; 4];
723            self.stream.read_exact(&mut len_buf).await.map_err(|e| {
724                self.poisoned = true;
725                Error::from(e)
726            })?;
727            let msg_len = u32::from_be_bytes(len_buf) as usize;
728
729            if msg_len > DEFAULT_MAX_RESPONSE_SIZE {
730                self.poisoned = true;
731                return Err(Error::ResponseTooLarge(msg_len, DEFAULT_MAX_RESPONSE_SIZE));
732            }
733
734            response_buf.resize(msg_len, 0);
735            self.stream
736                .read_exact(&mut response_buf)
737                .await
738                .map_err(|e| {
739                    self.poisoned = true;
740                    Error::from(e)
741                })?;
742            let (response, _format, response_cid) = Response::from_wire(&response_buf)
743                .inspect_err(|_| {
744                    self.poisoned = true;
745                })?;
746
747            if response_cid != expected_cid {
748                self.poisoned = true;
749                return Err(Error::ProtocolError(
750                    rivven_protocol::ProtocolError::InvalidFormat(format!(
751                        "Correlation ID mismatch: expected {}, got {}",
752                        expected_cid, response_cid
753                    )),
754                ));
755            }
756
757            let result = match response {
758                Response::Messages { messages } => Ok(messages),
759                Response::Error { message } => Err(Error::ServerError(message)),
760                _ => Err(Error::InvalidResponse),
761            };
762            results.push(result);
763        }
764
765        Ok(results)
766    }
767
768    /// Publish a message to a topic
769    pub async fn publish(
770        &mut self,
771        topic: impl Into<String>,
772        value: impl Into<Bytes>,
773    ) -> Result<u64> {
774        self.publish_with_key(topic, None::<Bytes>, value).await
775    }
776
777    /// Publish a message with a key to a topic
778    pub async fn publish_with_key(
779        &mut self,
780        topic: impl Into<String>,
781        key: Option<impl Into<Bytes>>,
782        value: impl Into<Bytes>,
783    ) -> Result<u64> {
784        let request = Request::Publish {
785            topic: topic.into(),
786            partition: None,
787            key: key.map(|k| k.into()),
788            value: value.into(),
789            leader_epoch: None,
790        };
791
792        let response = self.send_request(request).await?;
793
794        match response {
795            Response::Published { offset, .. } => Ok(offset),
796            Response::Error { message } => Err(Error::ServerError(message)),
797            _ => Err(Error::InvalidResponse),
798        }
799    }
800
801    /// Publish a message to a specific partition
802    pub async fn publish_to_partition(
803        &mut self,
804        topic: impl Into<String>,
805        partition: u32,
806        key: Option<impl Into<Bytes>>,
807        value: impl Into<Bytes>,
808    ) -> Result<u64> {
809        let request = Request::Publish {
810            topic: topic.into(),
811            partition: Some(partition),
812            key: key.map(|k| k.into()),
813            value: value.into(),
814            leader_epoch: None,
815        };
816
817        let response = self.send_request(request).await?;
818
819        match response {
820            Response::Published { offset, .. } => Ok(offset),
821            Response::Error { message } => Err(Error::ServerError(message)),
822            _ => Err(Error::InvalidResponse),
823        }
824    }
825
826    /// Consume messages from a topic partition
827    ///
828    /// Uses read_uncommitted isolation level (default).
829    /// For transactional consumers that should not see aborted transaction messages,
830    /// use [`Self::consume_with_isolation`] with `isolation_level = 1` (read_committed).
831    pub async fn consume(
832        &mut self,
833        topic: impl Into<String>,
834        partition: u32,
835        offset: u64,
836        max_messages: u32,
837    ) -> Result<Vec<MessageData>> {
838        self.consume_with_isolation(topic, partition, offset, max_messages, None)
839            .await
840    }
841
842    /// Consume messages from a topic partition with specified isolation level
843    ///
844    /// # Arguments
845    /// * `topic` - Topic name
846    /// * `partition` - Partition number
847    /// * `offset` - Starting offset
848    /// * `max_messages` - Maximum messages to return
849    /// * `isolation_level` - Transaction isolation level:
850    ///   - `None` or `Some(0)` = read_uncommitted (default): Returns all messages
851    ///   - `Some(1)` = read_committed: Filters out messages from aborted transactions
852    ///
853    /// # Read Committed Isolation
854    ///
855    /// When using `isolation_level = Some(1)` (read_committed), the consumer will:
856    /// - Not see messages from transactions that were aborted
857    /// - Not see control records (transaction markers)
858    /// - Only see committed transactional messages
859    ///
860    /// This is essential for exactly-once semantics (EOS) consumers.
861    pub async fn consume_with_isolation(
862        &mut self,
863        topic: impl Into<String>,
864        partition: u32,
865        offset: u64,
866        max_messages: u32,
867        isolation_level: Option<u8>,
868    ) -> Result<Vec<MessageData>> {
869        let request = Request::Consume {
870            topic: topic.into(),
871            partition,
872            offset,
873            max_messages,
874            isolation_level,
875            max_wait_ms: None,
876        };
877
878        let response = self.send_request(request).await?;
879
880        match response {
881            Response::Messages { messages } => Ok(messages),
882            Response::Error { message } => Err(Error::ServerError(message)),
883            _ => Err(Error::InvalidResponse),
884        }
885    }
886
887    /// Consume messages with long-polling support
888    ///
889    /// If no data is available immediately, the server will hold the request
890    /// for up to `max_wait_ms` milliseconds before returning an empty response.
891    /// This avoids tight polling loops and reduces network overhead.
892    ///
893    /// Capped server-side at 30 000 ms. `0` or `None` = immediate (no waiting).
894    pub async fn consume_long_poll(
895        &mut self,
896        topic: impl Into<String>,
897        partition: u32,
898        offset: u64,
899        max_messages: u32,
900        isolation_level: Option<u8>,
901        max_wait_ms: u64,
902    ) -> Result<Vec<MessageData>> {
903        let request = Request::Consume {
904            topic: topic.into(),
905            partition,
906            offset,
907            max_messages,
908            isolation_level,
909            max_wait_ms: Some(max_wait_ms),
910        };
911
912        let response = self.send_request(request).await?;
913
914        match response {
915            Response::Messages { messages } => Ok(messages),
916            Response::Error { message } => Err(Error::ServerError(message)),
917            _ => Err(Error::InvalidResponse),
918        }
919    }
920
921    /// Consume messages with read_committed isolation level
922    ///
923    /// This is a convenience method for transactional consumers that should
924    /// only see committed messages. Messages from aborted transactions are filtered out.
925    ///
926    /// Equivalent to calling [`Self::consume_with_isolation`] with `isolation_level = Some(1)`.
927    pub async fn consume_read_committed(
928        &mut self,
929        topic: impl Into<String>,
930        partition: u32,
931        offset: u64,
932        max_messages: u32,
933    ) -> Result<Vec<MessageData>> {
934        self.consume_with_isolation(topic, partition, offset, max_messages, Some(1))
935            .await
936    }
937
938    /// Create a new topic
939    pub async fn create_topic(
940        &mut self,
941        name: impl Into<String>,
942        partitions: Option<u32>,
943    ) -> Result<u32> {
944        let name = name.into();
945        let request = Request::CreateTopic {
946            name: name.clone(),
947            partitions,
948        };
949
950        let response = self.send_request(request).await?;
951
952        match response {
953            Response::TopicCreated { partitions, .. } => Ok(partitions),
954            Response::Error { message } => Err(Error::ServerError(message)),
955            _ => Err(Error::InvalidResponse),
956        }
957    }
958
959    /// List all topics
960    pub async fn list_topics(&mut self) -> Result<Vec<String>> {
961        let request = Request::ListTopics;
962        let response = self.send_request(request).await?;
963
964        match response {
965            Response::Topics { topics } => Ok(topics),
966            Response::Error { message } => Err(Error::ServerError(message)),
967            _ => Err(Error::InvalidResponse),
968        }
969    }
970
971    /// Delete a topic
972    pub async fn delete_topic(&mut self, name: impl Into<String>) -> Result<()> {
973        let request = Request::DeleteTopic { name: name.into() };
974        let response = self.send_request(request).await?;
975
976        match response {
977            Response::TopicDeleted => Ok(()),
978            Response::Error { message } => Err(Error::ServerError(message)),
979            _ => Err(Error::InvalidResponse),
980        }
981    }
982
983    /// Commit consumer offset
984    pub async fn commit_offset(
985        &mut self,
986        consumer_group: impl Into<String>,
987        topic: impl Into<String>,
988        partition: u32,
989        offset: u64,
990    ) -> Result<()> {
991        let request = Request::CommitOffset {
992            consumer_group: consumer_group.into(),
993            topic: topic.into(),
994            partition,
995            offset,
996        };
997
998        let response = self.send_request(request).await?;
999
1000        match response {
1001            Response::OffsetCommitted => Ok(()),
1002            Response::Error { message } => Err(Error::ServerError(message)),
1003            _ => Err(Error::InvalidResponse),
1004        }
1005    }
1006
1007    /// Commit offsets for multiple partitions using request pipelining.
1008    ///
1009    /// Sends all `CommitOffset` requests at once, then reads all responses.
1010    /// Returns per-partition results in the same order as `offsets`.
1011    pub async fn commit_offsets_pipelined(
1012        &mut self,
1013        consumer_group: &str,
1014        offsets: &[(String, u32, u64)],
1015    ) -> Result<Vec<Result<()>>> {
1016        if offsets.is_empty() {
1017            return Ok(Vec::new());
1018        }
1019        if self.poisoned {
1020            return Err(Error::ConnectionError(
1021                "Client stream is desynchronized — reconnect required".into(),
1022            ));
1023        }
1024
1025        let timeout_dur = self.request_timeout;
1026        match tokio::time::timeout(
1027            timeout_dur,
1028            self.commit_offsets_pipelined_inner(consumer_group, offsets),
1029        )
1030        .await
1031        {
1032            Ok(result) => result,
1033            Err(_elapsed) => {
1034                self.poisoned = true;
1035                Err(Error::Timeout)
1036            }
1037        }
1038    }
1039
1040    async fn commit_offsets_pipelined_inner(
1041        &mut self,
1042        consumer_group: &str,
1043        offsets: &[(String, u32, u64)],
1044    ) -> Result<Vec<Result<()>>> {
1045        let mut correlation_ids = Vec::with_capacity(offsets.len());
1046        let mut bytes_sent = false;
1047
1048        // Phase 1: Send all commit requests.
1049        for (topic, partition, offset) in offsets {
1050            let correlation_id = self.next_correlation_id;
1051            self.next_correlation_id = self.next_correlation_id.wrapping_add(1);
1052            correlation_ids.push(correlation_id);
1053
1054            let request = Request::CommitOffset {
1055                consumer_group: consumer_group.to_string(),
1056                topic: topic.clone(),
1057                partition: *partition,
1058                offset: *offset,
1059            };
1060
1061            let request_bytes = request
1062                .to_wire(rivven_protocol::WireFormat::Postcard, correlation_id)
1063                .inspect_err(|_| {
1064                    if bytes_sent {
1065                        self.poisoned = true;
1066                    }
1067                })?;
1068
1069            if request_bytes.len() > DEFAULT_MAX_REQUEST_SIZE {
1070                if bytes_sent {
1071                    self.poisoned = true;
1072                }
1073                return Err(Error::RequestTooLarge(
1074                    request_bytes.len(),
1075                    DEFAULT_MAX_REQUEST_SIZE,
1076                ));
1077            }
1078
1079            let len: u32 = request_bytes.len().try_into().map_err(|_| {
1080                if bytes_sent {
1081                    self.poisoned = true;
1082                }
1083                Error::RequestTooLarge(request_bytes.len(), u32::MAX as usize)
1084            })?;
1085            self.stream
1086                .write_all(&len.to_be_bytes())
1087                .await
1088                .map_err(|e| {
1089                    if bytes_sent {
1090                        self.poisoned = true;
1091                    }
1092                    Error::from(e)
1093                })?;
1094            self.stream.write_all(&request_bytes).await.map_err(|e| {
1095                self.poisoned = true;
1096                Error::from(e)
1097            })?;
1098            bytes_sent = true;
1099        }
1100        self.stream.flush().await.map_err(|e| {
1101            self.poisoned = true;
1102            Error::from(e)
1103        })?;
1104
1105        // Phase 2: Read all responses in-order.
1106        let mut results = Vec::with_capacity(offsets.len());
1107        for &expected_cid in &correlation_ids {
1108            let mut len_buf = [0u8; 4];
1109            self.stream.read_exact(&mut len_buf).await.map_err(|e| {
1110                self.poisoned = true;
1111                Error::from(e)
1112            })?;
1113            let msg_len = u32::from_be_bytes(len_buf) as usize;
1114
1115            if msg_len > DEFAULT_MAX_RESPONSE_SIZE {
1116                self.poisoned = true;
1117                return Err(Error::ResponseTooLarge(msg_len, DEFAULT_MAX_RESPONSE_SIZE));
1118            }
1119
1120            let mut response_buf = vec![0u8; msg_len];
1121            self.stream
1122                .read_exact(&mut response_buf)
1123                .await
1124                .map_err(|e| {
1125                    self.poisoned = true;
1126                    Error::from(e)
1127                })?;
1128            let (response, _format, response_cid) = Response::from_wire(&response_buf)
1129                .inspect_err(|_| {
1130                    self.poisoned = true;
1131                })?;
1132
1133            if response_cid != expected_cid {
1134                self.poisoned = true;
1135                return Err(Error::ProtocolError(
1136                    rivven_protocol::ProtocolError::InvalidFormat(format!(
1137                        "Correlation ID mismatch: expected {}, got {}",
1138                        expected_cid, response_cid
1139                    )),
1140                ));
1141            }
1142
1143            let result = match response {
1144                Response::OffsetCommitted => Ok(()),
1145                Response::Error { message } => Err(Error::ServerError(message)),
1146                _ => Err(Error::InvalidResponse),
1147            };
1148            results.push(result);
1149        }
1150
1151        Ok(results)
1152    }
1153
1154    /// Returns `true` if the client stream is desynchronized and unusable.
1155    pub fn is_poisoned(&self) -> bool {
1156        self.poisoned
1157    }
1158
1159    /// Get consumer offset
1160    pub async fn get_offset(
1161        &mut self,
1162        consumer_group: impl Into<String>,
1163        topic: impl Into<String>,
1164        partition: u32,
1165    ) -> Result<Option<u64>> {
1166        let request = Request::GetOffset {
1167            consumer_group: consumer_group.into(),
1168            topic: topic.into(),
1169            partition,
1170        };
1171
1172        let response = self.send_request(request).await?;
1173
1174        match response {
1175            Response::Offset { offset } => Ok(offset),
1176            Response::Error { message } => Err(Error::ServerError(message)),
1177            _ => Err(Error::InvalidResponse),
1178        }
1179    }
1180
1181    /// Get earliest and latest offsets for a topic partition
1182    ///
1183    /// Returns (earliest, latest) where:
1184    /// - earliest: First available offset (messages before this are deleted/compacted)
1185    /// - latest: Next offset to be assigned (one past the last message)
1186    pub async fn get_offset_bounds(
1187        &mut self,
1188        topic: impl Into<String>,
1189        partition: u32,
1190    ) -> Result<(u64, u64)> {
1191        let request = Request::GetOffsetBounds {
1192            topic: topic.into(),
1193            partition,
1194        };
1195
1196        let response = self.send_request(request).await?;
1197
1198        match response {
1199            Response::OffsetBounds { earliest, latest } => Ok((earliest, latest)),
1200            Response::Error { message } => Err(Error::ServerError(message)),
1201            _ => Err(Error::InvalidResponse),
1202        }
1203    }
1204
1205    /// Get topic metadata
1206    pub async fn get_metadata(&mut self, topic: impl Into<String>) -> Result<(String, u32)> {
1207        let request = Request::GetMetadata {
1208            topic: topic.into(),
1209        };
1210
1211        let response = self.send_request(request).await?;
1212
1213        match response {
1214            Response::Metadata { name, partitions } => Ok((name, partitions)),
1215            Response::Error { message } => Err(Error::ServerError(message)),
1216            _ => Err(Error::InvalidResponse),
1217        }
1218    }
1219
1220    /// Ping the server
1221    pub async fn ping(&mut self) -> Result<()> {
1222        let request = Request::Ping;
1223        let response = self.send_request(request).await?;
1224
1225        match response {
1226            Response::Pong => Ok(()),
1227            Response::Error { message } => Err(Error::ServerError(message)),
1228            _ => Err(Error::InvalidResponse),
1229        }
1230    }
1231
1232    /// Register a schema with the schema registry (via HTTP REST API)
1233    ///
1234    /// The schema registry runs as a separate service (`rivven-schema`) with a
1235    /// Confluent-compatible REST API. This method performs a POST to
1236    /// `{registry_url}/subjects/{subject}/versions`.
1237    ///
1238    /// Supports both `http://` and `https://` registry URLs. HTTPS requires the
1239    /// `schema-registry` feature which brings in `reqwest` with `rustls-tls`.
1240    /// Without the feature flag, a minimal inline HTTP/1.1 client is used (HTTP only).
1241    ///
1242    /// # Arguments
1243    /// * `registry_url` - Schema registry base URL (e.g., `http://localhost:8081` or `https://registry.example.com`)
1244    /// * `subject` - Subject name (typically `{topic}-key` or `{topic}-value`)
1245    /// * `schema_type` - Schema format: `"AVRO"`, `"PROTOBUF"`, or `"JSON"`
1246    /// * `schema` - The schema definition string
1247    ///
1248    /// # Returns
1249    /// The global schema ID on success.
1250    pub async fn register_schema(
1251        &self,
1252        registry_url: &str,
1253        subject: &str,
1254        schema_type: &str,
1255        schema: &str,
1256    ) -> Result<u32> {
1257        let url = registry_url.trim_end_matches('/');
1258        let endpoint = format!("{}/subjects/{}/versions", url, subject);
1259
1260        let body = serde_json::json!({
1261            "schema": schema,
1262            "schemaType": schema_type,
1263        });
1264
1265        #[cfg(feature = "schema-registry")]
1266        {
1267            self.register_schema_reqwest(&endpoint, &body).await
1268        }
1269
1270        #[cfg(not(feature = "schema-registry"))]
1271        {
1272            self.register_schema_inline(url, &endpoint, &body).await
1273        }
1274    }
1275
1276    /// HTTPS-capable schema registration using reqwest (requires `schema-registry` feature)
1277    #[cfg(feature = "schema-registry")]
1278    async fn register_schema_reqwest(
1279        &self,
1280        endpoint: &str,
1281        body: &serde_json::Value,
1282    ) -> Result<u32> {
1283        let client = reqwest::Client::new();
1284        let response = client
1285            .post(endpoint)
1286            .header("Content-Type", "application/vnd.schemaregistry.v1+json")
1287            .json(body)
1288            .send()
1289            .await
1290            .map_err(|e| Error::ConnectionError(format!("schema registry request failed: {e}")))?;
1291
1292        let status = response.status();
1293        if !status.is_success() {
1294            let body_text = response.text().await.unwrap_or_default();
1295            return Err(Error::ServerError(format!(
1296                "schema registry returned HTTP {status}: {body_text}"
1297            )));
1298        }
1299
1300        #[derive(serde::Deserialize)]
1301        struct RegisterResponse {
1302            id: u32,
1303        }
1304
1305        let result: RegisterResponse = response
1306            .json()
1307            .await
1308            .map_err(|e| Error::ConnectionError(format!("failed to parse response: {e}")))?;
1309
1310        Ok(result.id)
1311    }
1312
1313    /// Minimal inline HTTP/1.1 client for schema registration (HTTP only, no external deps).
1314    ///
1315    /// Handles both `Content-Length` and `Transfer-Encoding: chunked` responses.
1316    /// For production deployments behind proxies or with HTTPS requirements,
1317    /// enable the `schema-registry` feature to use `reqwest` instead.
1318    #[cfg(not(feature = "schema-registry"))]
1319    async fn register_schema_inline(
1320        &self,
1321        base_url: &str,
1322        _endpoint: &str,
1323        body: &serde_json::Value,
1324    ) -> Result<u32> {
1325        use tokio::io::AsyncBufReadExt;
1326        use tokio::io::BufReader;
1327        use tokio::net::TcpStream as TokioTcpStream;
1328
1329        let stripped = base_url.strip_prefix("http://").ok_or_else(|| {
1330            Error::ConnectionError(
1331                "HTTPS requires the `schema-registry` feature; URL must start with http:// without it".into(),
1332            )
1333        })?;
1334        let (host_port, _) = stripped.split_once('/').unwrap_or((stripped, ""));
1335
1336        // Extract path from endpoint (skip the base URL part)
1337        let path = _endpoint.strip_prefix(base_url).unwrap_or(_endpoint);
1338
1339        let body_bytes = serde_json::to_vec(body)
1340            .map_err(|e| Error::ConnectionError(format!("failed to serialize schema: {e}")))?;
1341
1342        let request = format!(
1343            "POST {} HTTP/1.1\r\nHost: {}\r\nContent-Type: application/vnd.schemaregistry.v1+json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
1344            path, host_port, body_bytes.len()
1345        );
1346
1347        let timeout = tokio::time::Duration::from_secs(30);
1348
1349        let mut stream = tokio::time::timeout(timeout, TokioTcpStream::connect(host_port))
1350            .await
1351            .map_err(|_| Error::ConnectionError("schema registry connect timed out".into()))?
1352            .map_err(|e| {
1353                Error::ConnectionError(format!("failed to connect to schema registry: {e}"))
1354            })?;
1355
1356        stream
1357            .write_all(request.as_bytes())
1358            .await
1359            .map_err(|e| Error::ConnectionError(format!("failed to send request: {e}")))?;
1360        stream
1361            .write_all(&body_bytes)
1362            .await
1363            .map_err(|e| Error::ConnectionError(format!("failed to send body: {e}")))?;
1364
1365        // Read response with timeout
1366        let response_body =
1367            tokio::time::timeout(timeout, async {
1368                let mut reader = BufReader::new(stream);
1369
1370                // --- Parse status line ---
1371                let mut status_line = String::new();
1372                reader.read_line(&mut status_line).await.map_err(|e| {
1373                    Error::ConnectionError(format!("failed to read status line: {e}"))
1374                })?;
1375
1376                let status_code: u16 = status_line
1377                    .split_whitespace()
1378                    .nth(1)
1379                    .and_then(|s| s.parse().ok())
1380                    .unwrap_or(0);
1381
1382                if !(200..300).contains(&status_code) {
1383                    return Err(Error::ServerError(format!(
1384                        "schema registry returned HTTP {status_code}"
1385                    )));
1386                }
1387
1388                // --- Parse headers ---
1389                let mut content_length: Option<usize> = None;
1390                let mut is_chunked = false;
1391                loop {
1392                    let mut header_line = String::new();
1393                    reader.read_line(&mut header_line).await.map_err(|e| {
1394                        Error::ConnectionError(format!("failed to read header: {e}"))
1395                    })?;
1396
1397                    let trimmed = header_line.trim();
1398                    if trimmed.is_empty() {
1399                        break; // End of headers
1400                    }
1401
1402                    let lower = trimmed.to_ascii_lowercase();
1403                    if let Some(val) = lower.strip_prefix("content-length:") {
1404                        content_length = val.trim().parse().ok();
1405                    } else if lower.starts_with("transfer-encoding:") && lower.contains("chunked") {
1406                        is_chunked = true;
1407                    }
1408                }
1409
1410                // --- Read body ---
1411                let body_bytes = if is_chunked {
1412                    // Chunked transfer-encoding: read chunk-size\r\n, chunk-data\r\n, repeat until 0\r\n
1413                    const MAX_CHUNK_SIZE: usize = 16 * 1024 * 1024;
1414                    const MAX_TOTAL_BODY: usize = 16 * 1024 * 1024;
1415                    let mut body = Vec::new();
1416                    loop {
1417                        let mut size_line = String::new();
1418                        reader.read_line(&mut size_line).await.map_err(|e| {
1419                            Error::ConnectionError(format!("failed to read chunk size: {e}"))
1420                        })?;
1421
1422                        let chunk_size =
1423                            usize::from_str_radix(size_line.trim(), 16).map_err(|_| {
1424                                Error::ConnectionError(format!(
1425                                    "invalid chunk size: {:?}",
1426                                    size_line.trim()
1427                                ))
1428                            })?;
1429                        if chunk_size == 0 {
1430                            // Terminal chunk — consume trailing CRLF per RFC 7230 §4.1
1431                            let mut trailer_buf = [0u8; 2];
1432                            let _ = reader.read_exact(&mut trailer_buf).await;
1433                            break;
1434                        }
1435
1436                        // Guard against malicious chunk sizes
1437                        if chunk_size > MAX_CHUNK_SIZE {
1438                            return Err(Error::ConnectionError(format!(
1439                                "chunk size {} exceeds maximum {}",
1440                                chunk_size, MAX_CHUNK_SIZE
1441                            )));
1442                        }
1443
1444                        let mut chunk = vec![0u8; chunk_size];
1445                        reader.read_exact(&mut chunk).await.map_err(|e| {
1446                            Error::ConnectionError(format!("failed to read chunk data: {e}"))
1447                        })?;
1448                        body.extend_from_slice(&chunk);
1449
1450                        // Guard against unbounded total body accumulation
1451                        if body.len() > MAX_TOTAL_BODY {
1452                            return Err(Error::ConnectionError(format!(
1453                                "chunked body {} bytes exceeds maximum {}",
1454                                body.len(),
1455                                MAX_TOTAL_BODY
1456                            )));
1457                        }
1458
1459                        // Consume trailing \r\n (exactly 2 bytes) after chunk data.
1460                        // Using read_exact instead of read_line prevents unbounded
1461                        // memory reads on malformed input.
1462                        let mut crlf_buf = [0u8; 2];
1463                        reader.read_exact(&mut crlf_buf).await.map_err(|e| {
1464                            Error::ConnectionError(format!("failed to read chunk CRLF: {e}"))
1465                        })?;
1466                        if crlf_buf != [b'\r', b'\n'] {
1467                            return Err(Error::ConnectionError(format!(
1468                                "expected CRLF after chunk data, got {:02x?}",
1469                                crlf_buf
1470                            )));
1471                        }
1472                    }
1473                    body
1474                } else if let Some(len) = content_length {
1475                    const MAX_CONTENT_LENGTH: usize = 16 * 1024 * 1024;
1476                    if len > MAX_CONTENT_LENGTH {
1477                        return Err(Error::ConnectionError(format!(
1478                            "response Content-Length {} bytes exceeds maximum {}",
1479                            len, MAX_CONTENT_LENGTH
1480                        )));
1481                    }
1482                    let mut body = vec![0u8; len];
1483                    reader.read_exact(&mut body).await.map_err(|e| {
1484                        Error::ConnectionError(format!("failed to read response body: {e}"))
1485                    })?;
1486                    body
1487                } else {
1488                    // Fallback: Connection: close — read until EOF (capped at 16 MB)
1489                    const MAX_RESPONSE_SIZE: usize = 16 * 1024 * 1024;
1490                    let mut body = Vec::with_capacity(4096);
1491                    reader.read_to_end(&mut body).await.map_err(|e| {
1492                        Error::ConnectionError(format!("failed to read response: {e}"))
1493                    })?;
1494                    if body.len() > MAX_RESPONSE_SIZE {
1495                        return Err(Error::ConnectionError(format!(
1496                            "response body {} bytes exceeds maximum {}",
1497                            body.len(),
1498                            MAX_RESPONSE_SIZE
1499                        )));
1500                    }
1501                    body
1502                };
1503
1504                Ok(body_bytes)
1505            })
1506            .await
1507            .map_err(|_| Error::ConnectionError("schema registry response timed out".into()))??;
1508
1509        #[derive(serde::Deserialize)]
1510        struct RegisterResponse {
1511            id: u32,
1512        }
1513
1514        let result: RegisterResponse = serde_json::from_slice(&response_body).map_err(|e| {
1515            Error::ConnectionError(format!("failed to parse schema registry response: {e}"))
1516        })?;
1517
1518        Ok(result.id)
1519    }
1520
1521    /// List all consumer groups
1522    pub async fn list_groups(&mut self) -> Result<Vec<String>> {
1523        let request = Request::ListGroups;
1524
1525        let response = self.send_request(request).await?;
1526
1527        match response {
1528            Response::Groups { groups } => Ok(groups),
1529            Response::Error { message } => Err(Error::ServerError(message)),
1530            _ => Err(Error::InvalidResponse),
1531        }
1532    }
1533
1534    /// Describe a consumer group (get all committed offsets)
1535    pub async fn describe_group(
1536        &mut self,
1537        consumer_group: impl Into<String>,
1538    ) -> Result<std::collections::HashMap<String, std::collections::HashMap<u32, u64>>> {
1539        let request = Request::DescribeGroup {
1540            consumer_group: consumer_group.into(),
1541        };
1542
1543        let response = self.send_request(request).await?;
1544
1545        match response {
1546            Response::GroupDescription { offsets, .. } => Ok(offsets),
1547            Response::Error { message } => Err(Error::ServerError(message)),
1548            _ => Err(Error::InvalidResponse),
1549        }
1550    }
1551
1552    /// Delete a consumer group
1553    pub async fn delete_group(&mut self, consumer_group: impl Into<String>) -> Result<()> {
1554        let request = Request::DeleteGroup {
1555            consumer_group: consumer_group.into(),
1556        };
1557
1558        let response = self.send_request(request).await?;
1559
1560        match response {
1561            Response::GroupDeleted => Ok(()),
1562            Response::Error { message } => Err(Error::ServerError(message)),
1563            _ => Err(Error::InvalidResponse),
1564        }
1565    }
1566
1567    // =========================================================================
1568    // Consumer Group Coordination
1569    // =========================================================================
1570
1571    /// Join a consumer group.
1572    ///
1573    /// The coordinator assigns (or generates) a member ID and returns the
1574    /// current generation, leader, and the full member list. The leader
1575    /// uses the member list to compute partition assignments and sends
1576    /// them via [`sync_group`](Self::sync_group).
1577    ///
1578    /// # Returns
1579    /// `(generation_id, protocol_type, member_id, leader_id, members)`
1580    /// where `members` is `Vec<(member_id, subscriptions)>`.
1581    pub async fn join_group(
1582        &mut self,
1583        group_id: impl Into<String>,
1584        member_id: impl Into<String>,
1585        session_timeout_ms: u32,
1586        rebalance_timeout_ms: u32,
1587        protocol_type: impl Into<String>,
1588        subscriptions: Vec<String>,
1589    ) -> Result<(u32, String, String, String, Vec<(String, Vec<String>)>)> {
1590        let request = Request::JoinGroup {
1591            group_id: group_id.into(),
1592            member_id: member_id.into(),
1593            session_timeout_ms,
1594            rebalance_timeout_ms,
1595            protocol_type: protocol_type.into(),
1596            subscriptions,
1597        };
1598
1599        let response = self.send_request(request).await?;
1600
1601        match response {
1602            Response::JoinGroupResult {
1603                generation_id,
1604                protocol_type,
1605                member_id,
1606                leader_id,
1607                members,
1608            } => Ok((generation_id, protocol_type, member_id, leader_id, members)),
1609            Response::Error { message } => Err(Error::ServerError(message)),
1610            _ => Err(Error::InvalidResponse),
1611        }
1612    }
1613
1614    /// Synchronize consumer group assignments.
1615    ///
1616    /// The group leader sends partition assignments for all members;
1617    /// followers send an empty assignment list. Every member receives
1618    /// their own assignment in the response.
1619    ///
1620    /// # Arguments
1621    /// * `assignments` — `Vec<(member_id, Vec<(topic, Vec<partition>)>)>`.
1622    ///   Only the leader should provide non-empty assignments.
1623    ///
1624    /// # Returns
1625    /// This member's partition assignment: `Vec<(topic, Vec<partition>)>`.
1626    pub async fn sync_group(
1627        &mut self,
1628        group_id: impl Into<String>,
1629        generation_id: u32,
1630        member_id: impl Into<String>,
1631        assignments: SyncGroupAssignments,
1632    ) -> Result<Vec<(String, Vec<u32>)>> {
1633        let request = Request::SyncGroup {
1634            group_id: group_id.into(),
1635            generation_id,
1636            member_id: member_id.into(),
1637            assignments,
1638        };
1639
1640        let response = self.send_request(request).await?;
1641
1642        match response {
1643            Response::SyncGroupResult { assignments } => Ok(assignments),
1644            Response::Error { message } => Err(Error::ServerError(message)),
1645            _ => Err(Error::InvalidResponse),
1646        }
1647    }
1648
1649    /// Send a heartbeat to the consumer group coordinator.
1650    ///
1651    /// # Returns
1652    /// * `0` — OK, member is in sync
1653    /// * `27` — REBALANCE_IN_PROGRESS, member should rejoin
1654    pub async fn heartbeat(
1655        &mut self,
1656        group_id: impl Into<String>,
1657        generation_id: u32,
1658        member_id: impl Into<String>,
1659    ) -> Result<i32> {
1660        let request = Request::Heartbeat {
1661            group_id: group_id.into(),
1662            generation_id,
1663            member_id: member_id.into(),
1664        };
1665
1666        let response = self.send_request(request).await?;
1667
1668        match response {
1669            Response::HeartbeatResult { error_code } => Ok(error_code),
1670            Response::Error { message } => Err(Error::ServerError(message)),
1671            _ => Err(Error::InvalidResponse),
1672        }
1673    }
1674
1675    /// Leave a consumer group.
1676    ///
1677    /// Gracefully removes this member, triggering a rebalance for
1678    /// remaining group members. Call this during consumer shutdown.
1679    pub async fn leave_group(
1680        &mut self,
1681        group_id: impl Into<String>,
1682        member_id: impl Into<String>,
1683    ) -> Result<()> {
1684        let request = Request::LeaveGroup {
1685            group_id: group_id.into(),
1686            member_id: member_id.into(),
1687        };
1688
1689        let response = self.send_request(request).await?;
1690
1691        match response {
1692            Response::LeaveGroupResult => Ok(()),
1693            Response::Error { message } => Err(Error::ServerError(message)),
1694            _ => Err(Error::InvalidResponse),
1695        }
1696    }
1697
1698    /// Get the first offset with timestamp >= the given timestamp
1699    ///
1700    /// # Arguments
1701    /// * `topic` - The topic name
1702    /// * `partition` - The partition number
1703    /// * `timestamp_ms` - Timestamp in milliseconds since Unix epoch
1704    ///
1705    /// # Returns
1706    /// * `Some(offset)` - The first offset with message timestamp >= timestamp_ms
1707    /// * `None` - No messages found with timestamp >= timestamp_ms
1708    pub async fn get_offset_for_timestamp(
1709        &mut self,
1710        topic: impl Into<String>,
1711        partition: u32,
1712        timestamp_ms: i64,
1713    ) -> Result<Option<u64>> {
1714        let request = Request::GetOffsetForTimestamp {
1715            topic: topic.into(),
1716            partition,
1717            timestamp_ms,
1718        };
1719
1720        let response = self.send_request(request).await?;
1721
1722        match response {
1723            Response::OffsetForTimestamp { offset } => Ok(offset),
1724            Response::Error { message } => Err(Error::ServerError(message)),
1725            _ => Err(Error::InvalidResponse),
1726        }
1727    }
1728
1729    // ========================================================================
1730    // Admin API
1731    // ========================================================================
1732
1733    /// Describe topic configurations
1734    ///
1735    /// Returns the current configuration for the specified topics.
1736    ///
1737    /// # Arguments
1738    /// * `topics` - Topics to describe (empty slice = all topics)
1739    ///
1740    /// # Example
1741    /// ```no_run
1742    /// # use rivven_client::Client;
1743    /// # async fn example() -> rivven_client::Result<()> {
1744    /// let mut client = Client::connect("127.0.0.1:9092").await?;
1745    /// let configs = client.describe_topic_configs(&["orders", "events"]).await?;
1746    /// for (topic, config) in configs {
1747    ///     println!("{}: {:?}", topic, config);
1748    /// }
1749    /// # Ok(())
1750    /// # }
1751    /// ```
1752    pub async fn describe_topic_configs(
1753        &mut self,
1754        topics: &[&str],
1755    ) -> Result<std::collections::HashMap<String, std::collections::HashMap<String, String>>> {
1756        let request = Request::DescribeTopicConfigs {
1757            topics: topics.iter().map(|s| s.to_string()).collect(),
1758        };
1759
1760        let response = self.send_request(request).await?;
1761
1762        match response {
1763            Response::TopicConfigsDescribed { configs } => {
1764                let mut result = std::collections::HashMap::new();
1765                for desc in configs {
1766                    let mut topic_configs = std::collections::HashMap::new();
1767                    for (key, value) in desc.configs {
1768                        topic_configs.insert(key, value.value);
1769                    }
1770                    result.insert(desc.topic, topic_configs);
1771                }
1772                Ok(result)
1773            }
1774            Response::Error { message } => Err(Error::ServerError(message)),
1775            _ => Err(Error::InvalidResponse),
1776        }
1777    }
1778
1779    /// Alter topic configuration
1780    ///
1781    /// Modifies configuration for an existing topic. Pass `None` as value to reset
1782    /// a configuration key to its default.
1783    ///
1784    /// # Arguments
1785    /// * `topic` - Topic name
1786    /// * `configs` - Configuration changes: (key, value) pairs. Use `None` to reset to default.
1787    ///
1788    /// # Example
1789    /// ```no_run
1790    /// # use rivven_client::Client;
1791    /// # async fn example() -> rivven_client::Result<()> {
1792    /// let mut client = Client::connect("127.0.0.1:9092").await?;
1793    /// let result = client.alter_topic_config("orders", &[
1794    ///     ("retention.ms", Some("86400000")),  // 1 day retention
1795    ///     ("cleanup.policy", Some("compact")), // Enable compaction
1796    /// ]).await?;
1797    /// println!("Changed {} configs", result.changed_count);
1798    /// # Ok(())
1799    /// # }
1800    /// ```
1801    pub async fn alter_topic_config(
1802        &mut self,
1803        topic: impl Into<String>,
1804        configs: &[(&str, Option<&str>)],
1805    ) -> Result<AlterTopicConfigResult> {
1806        use rivven_protocol::TopicConfigEntry;
1807
1808        let request = Request::AlterTopicConfig {
1809            topic: topic.into(),
1810            configs: configs
1811                .iter()
1812                .map(|(k, v)| TopicConfigEntry {
1813                    key: k.to_string(),
1814                    value: v.map(|s| s.to_string()),
1815                })
1816                .collect(),
1817        };
1818
1819        let response = self.send_request(request).await?;
1820
1821        match response {
1822            Response::TopicConfigAltered {
1823                topic,
1824                changed_count,
1825            } => Ok(AlterTopicConfigResult {
1826                topic,
1827                changed_count,
1828            }),
1829            Response::Error { message } => Err(Error::ServerError(message)),
1830            _ => Err(Error::InvalidResponse),
1831        }
1832    }
1833
1834    /// Create additional partitions for an existing topic
1835    ///
1836    /// Increases the partition count for a topic. The new partition count
1837    /// must be greater than the current count (you cannot reduce partitions).
1838    ///
1839    /// # Arguments
1840    /// * `topic` - Topic name
1841    /// * `new_partition_count` - New total partition count
1842    ///
1843    /// # Example
1844    /// ```no_run
1845    /// # use rivven_client::Client;
1846    /// # async fn example() -> rivven_client::Result<()> {
1847    /// let mut client = Client::connect("127.0.0.1:9092").await?;
1848    /// // Increase from 3 to 6 partitions
1849    /// let new_count = client.create_partitions("orders", 6).await?;
1850    /// println!("Topic now has {} partitions", new_count);
1851    /// # Ok(())
1852    /// # }
1853    /// ```
1854    pub async fn create_partitions(
1855        &mut self,
1856        topic: impl Into<String>,
1857        new_partition_count: u32,
1858    ) -> Result<u32> {
1859        let request = Request::CreatePartitions {
1860            topic: topic.into(),
1861            new_partition_count,
1862            assignments: vec![], // Let broker auto-assign
1863        };
1864
1865        let response = self.send_request(request).await?;
1866
1867        match response {
1868            Response::PartitionsCreated {
1869                new_partition_count,
1870                ..
1871            } => Ok(new_partition_count),
1872            Response::Error { message } => Err(Error::ServerError(message)),
1873            _ => Err(Error::InvalidResponse),
1874        }
1875    }
1876
1877    /// Delete records before a given offset (log truncation)
1878    ///
1879    /// Removes all records with offsets less than the specified offset for each
1880    /// partition. This is useful for freeing up disk space or removing old data.
1881    ///
1882    /// # Arguments
1883    /// * `topic` - Topic name
1884    /// * `partition_offsets` - List of (partition, before_offset) pairs
1885    ///
1886    /// # Returns
1887    /// A list of results indicating the new low watermark for each partition.
1888    ///
1889    /// # Example
1890    /// ```no_run
1891    /// # use rivven_client::Client;
1892    /// # async fn example() -> rivven_client::Result<()> {
1893    /// let mut client = Client::connect("127.0.0.1:9092").await?;
1894    /// // Delete records before offset 1000 on partitions 0, 1, 2
1895    /// let results = client.delete_records("orders", &[
1896    ///     (0, 1000),
1897    ///     (1, 1000),
1898    ///     (2, 1000),
1899    /// ]).await?;
1900    /// for r in results {
1901    ///     println!("Partition {}: low watermark now {}", r.partition, r.low_watermark);
1902    /// }
1903    /// # Ok(())
1904    /// # }
1905    /// ```
1906    pub async fn delete_records(
1907        &mut self,
1908        topic: impl Into<String>,
1909        partition_offsets: &[(u32, u64)],
1910    ) -> Result<Vec<DeleteRecordsResult>> {
1911        let request = Request::DeleteRecords {
1912            topic: topic.into(),
1913            partition_offsets: partition_offsets.to_vec(),
1914        };
1915
1916        let response = self.send_request(request).await?;
1917
1918        match response {
1919            Response::RecordsDeleted { results, .. } => Ok(results),
1920            Response::Error { message } => Err(Error::ServerError(message)),
1921            _ => Err(Error::InvalidResponse),
1922        }
1923    }
1924
1925    // =========================================================================
1926    // Idempotent Producer API
1927    // =========================================================================
1928
1929    /// Initialize an idempotent producer
1930    ///
1931    /// Returns a producer ID and epoch that should be used for all subsequent
1932    /// idempotent publish operations. The broker uses these to detect and
1933    /// deduplicate messages in case of retries.
1934    ///
1935    /// # Arguments
1936    /// * `previous_producer_id` - If reconnecting, pass the previous producer_id
1937    ///   to bump the epoch (prevents zombie producers)
1938    ///
1939    /// # Returns
1940    /// `ProducerState` containing the producer_id and producer_epoch
1941    ///
1942    /// # Example
1943    /// ```no_run
1944    /// # use rivven_client::Client;
1945    /// # async fn example() -> rivven_client::Result<()> {
1946    /// let mut client = Client::connect("127.0.0.1:9092").await?;
1947    /// let producer = client.init_producer_id(None).await?;
1948    /// println!("Producer ID: {}, Epoch: {}", producer.producer_id, producer.producer_epoch);
1949    /// # Ok(())
1950    /// # }
1951    /// ```
1952    pub async fn init_producer_id(
1953        &mut self,
1954        previous_producer_id: Option<u64>,
1955    ) -> Result<ProducerState> {
1956        let request = Request::InitProducerId {
1957            producer_id: previous_producer_id,
1958        };
1959
1960        let response = self.send_request(request).await?;
1961
1962        match response {
1963            Response::ProducerIdInitialized {
1964                producer_id,
1965                producer_epoch,
1966            } => Ok(ProducerState {
1967                producer_id,
1968                producer_epoch,
1969                partition_sequences: std::collections::HashMap::new(),
1970                next_sequence: 0,
1971            }),
1972            Response::Error { message } => Err(Error::ServerError(message)),
1973            _ => Err(Error::InvalidResponse),
1974        }
1975    }
1976
1977    /// Publish a message with idempotent semantics
1978    ///
1979    /// Uses producer_id/epoch/sequence for exactly-once delivery. The broker
1980    /// deduplicates messages based on these values, making retries safe.
1981    ///
1982    /// # Arguments
1983    /// * `topic` - Topic to publish to
1984    /// * `key` - Optional message key (used for partitioning)
1985    /// * `value` - Message payload
1986    /// * `producer` - Producer state from `init_producer_id`
1987    ///
1988    /// # Returns
1989    /// Tuple of (offset, partition, was_duplicate)
1990    ///
1991    /// # Example
1992    /// ```no_run
1993    /// # use rivven_client::Client;
1994    /// # async fn example() -> rivven_client::Result<()> {
1995    /// let mut client = Client::connect("127.0.0.1:9092").await?;
1996    /// let mut producer = client.init_producer_id(None).await?;
1997    ///
1998    /// let (offset, partition, duplicate) = client
1999    ///     .publish_idempotent("orders", None::<Vec<u8>>, b"order-1".to_vec(), &mut producer)
2000    ///     .await?;
2001    ///
2002    /// println!("Published to partition {} at offset {}", partition, offset);
2003    /// if duplicate {
2004    ///     println!("(This was a retry - message already existed)");
2005    /// }
2006    /// # Ok(())
2007    /// # }
2008    /// ```
2009    pub async fn publish_idempotent(
2010        &mut self,
2011        topic: impl Into<String>,
2012        key: Option<impl Into<Bytes>>,
2013        value: impl Into<Bytes>,
2014        producer: &mut ProducerState,
2015    ) -> Result<(u64, u32, bool)> {
2016        let topic_str = topic.into();
2017        // Use per-partition sequence when a partition is known; fall back to
2018        // the global counter for server-assigned partitions. The broker
2019        // deduplicates per (producer_id, partition, sequence).
2020        let sequence = producer.next_sequence;
2021        producer.next_sequence = producer.next_sequence.wrapping_add(1);
2022        // After wrapping past i32::MAX, reset to 1 (not 0)
2023        // because sequence 0 was used for the first message. Reusing 0
2024        // could collide with the broker's dedup window.
2025        if producer.next_sequence <= 0 {
2026            producer.next_sequence = 1;
2027        }
2028
2029        let request = Request::IdempotentPublish {
2030            topic: topic_str,
2031            partition: None,
2032            key: key.map(|k| k.into()),
2033            value: value.into(),
2034            producer_id: producer.producer_id,
2035            producer_epoch: producer.producer_epoch,
2036            sequence,
2037            leader_epoch: None,
2038        };
2039
2040        let response = self.send_request(request).await?;
2041
2042        match response {
2043            Response::IdempotentPublished {
2044                offset,
2045                partition,
2046                duplicate,
2047            } => Ok((offset, partition, duplicate)),
2048            Response::Error { message } => Err(Error::ServerError(message)),
2049            _ => Err(Error::InvalidResponse),
2050        }
2051    }
2052
2053    /// Publish a message idempotently to a **specific** partition.
2054    ///
2055    /// Uses the per-(topic, partition) sequence counter from `ProducerState`
2056    /// instead of the global counter, giving the broker correct dedup tracking
2057    /// per partition.
2058    pub async fn publish_idempotent_to_partition(
2059        &mut self,
2060        topic: impl Into<String>,
2061        partition: u32,
2062        key: Option<impl Into<Bytes>>,
2063        value: impl Into<Bytes>,
2064        producer: &mut ProducerState,
2065    ) -> Result<(u64, u32, bool)> {
2066        let topic_str = topic.into();
2067        let sequence = producer.next_sequence_for(&topic_str, partition);
2068
2069        let request = Request::IdempotentPublish {
2070            topic: topic_str,
2071            partition: Some(partition),
2072            key: key.map(|k| k.into()),
2073            value: value.into(),
2074            producer_id: producer.producer_id,
2075            producer_epoch: producer.producer_epoch,
2076            sequence,
2077            leader_epoch: None,
2078        };
2079
2080        let response = self.send_request(request).await?;
2081
2082        match response {
2083            Response::IdempotentPublished {
2084                offset,
2085                partition: resp_partition,
2086                duplicate,
2087            } => Ok((offset, resp_partition, duplicate)),
2088            Response::Error { message } => Err(Error::ServerError(message)),
2089            _ => Err(Error::InvalidResponse),
2090        }
2091    }
2092
2093    // =========================================================================
2094    // Transaction API
2095    // =========================================================================
2096
2097    /// Begin a new transaction
2098    ///
2099    /// Starts a transaction that can span multiple topics and partitions.
2100    /// All writes within the transaction are atomic - they either all succeed
2101    /// or all fail together.
2102    ///
2103    /// # Arguments
2104    /// * `txn_id` - Unique transaction identifier (should be stable per producer)
2105    /// * `producer` - Producer state from `init_producer_id`
2106    /// * `timeout_ms` - Optional transaction timeout (defaults to 60s)
2107    ///
2108    /// # Example
2109    /// ```no_run
2110    /// # use rivven_client::Client;
2111    /// # async fn example() -> rivven_client::Result<()> {
2112    /// let mut client = Client::connect("127.0.0.1:9092").await?;
2113    /// let producer = client.init_producer_id(None).await?;
2114    ///
2115    /// // Start a transaction
2116    /// client.begin_transaction("txn-1", &producer, None).await?;
2117    /// // ... publish messages ...
2118    /// client.commit_transaction("txn-1", &producer).await?;
2119    /// # Ok(())
2120    /// # }
2121    /// ```
2122    pub async fn begin_transaction(
2123        &mut self,
2124        txn_id: impl Into<String>,
2125        producer: &ProducerState,
2126        timeout_ms: Option<u64>,
2127    ) -> Result<()> {
2128        let request = Request::BeginTransaction {
2129            txn_id: txn_id.into(),
2130            producer_id: producer.producer_id,
2131            producer_epoch: producer.producer_epoch,
2132            timeout_ms,
2133        };
2134
2135        let response = self.send_request(request).await?;
2136
2137        match response {
2138            Response::TransactionStarted { .. } => Ok(()),
2139            Response::Error { message } => Err(Error::ServerError(message)),
2140            _ => Err(Error::InvalidResponse),
2141        }
2142    }
2143
2144    /// Add partitions to an active transaction
2145    ///
2146    /// Registers partitions that will be written to within the transaction.
2147    /// This must be called before publishing to a new partition.
2148    ///
2149    /// # Arguments
2150    /// * `txn_id` - Transaction identifier
2151    /// * `producer` - Producer state
2152    /// * `partitions` - List of (topic, partition) pairs to add
2153    pub async fn add_partitions_to_txn(
2154        &mut self,
2155        txn_id: impl Into<String>,
2156        producer: &ProducerState,
2157        partitions: &[(&str, u32)],
2158    ) -> Result<usize> {
2159        let request = Request::AddPartitionsToTxn {
2160            txn_id: txn_id.into(),
2161            producer_id: producer.producer_id,
2162            producer_epoch: producer.producer_epoch,
2163            partitions: partitions
2164                .iter()
2165                .map(|(t, p)| (t.to_string(), *p))
2166                .collect(),
2167        };
2168
2169        let response = self.send_request(request).await?;
2170
2171        match response {
2172            Response::PartitionsAddedToTxn {
2173                partition_count, ..
2174            } => Ok(partition_count),
2175            Response::Error { message } => Err(Error::ServerError(message)),
2176            _ => Err(Error::InvalidResponse),
2177        }
2178    }
2179
2180    /// Publish a message within a transaction
2181    ///
2182    /// Like `publish_idempotent`, but the message is only visible to consumers
2183    /// after the transaction is committed.
2184    ///
2185    /// # Arguments
2186    /// * `txn_id` - Transaction identifier
2187    /// * `topic` - Topic to publish to
2188    /// * `key` - Optional message key
2189    /// * `value` - Message payload
2190    /// * `producer` - Producer state with sequence tracking
2191    ///
2192    /// # Returns
2193    /// Tuple of (offset, partition, sequence) - offset is pending until commit
2194    pub async fn publish_transactional(
2195        &mut self,
2196        txn_id: impl Into<String>,
2197        topic: impl Into<String>,
2198        key: Option<impl Into<Bytes>>,
2199        value: impl Into<Bytes>,
2200        producer: &mut ProducerState,
2201    ) -> Result<(u64, u32, i32)> {
2202        let sequence = producer.next_sequence;
2203        producer.next_sequence = producer.next_sequence.wrapping_add(1);
2204        // Avoid reusing sequence 0 after wrap
2205        if producer.next_sequence <= 0 {
2206            producer.next_sequence = 1;
2207        }
2208
2209        let request = Request::TransactionalPublish {
2210            txn_id: txn_id.into(),
2211            topic: topic.into(),
2212            partition: None,
2213            key: key.map(|k| k.into()),
2214            value: value.into(),
2215            producer_id: producer.producer_id,
2216            producer_epoch: producer.producer_epoch,
2217            sequence,
2218            leader_epoch: None,
2219        };
2220
2221        let response = self.send_request(request).await?;
2222
2223        match response {
2224            Response::TransactionalPublished {
2225                offset,
2226                partition,
2227                sequence,
2228            } => Ok((offset, partition, sequence)),
2229            Response::Error { message } => Err(Error::ServerError(message)),
2230            _ => Err(Error::InvalidResponse),
2231        }
2232    }
2233
2234    /// Add consumer offsets to a transaction
2235    ///
2236    /// For exactly-once consume-transform-produce patterns: commits consumer
2237    /// offsets atomically with the produced messages.
2238    ///
2239    /// # Arguments
2240    /// * `txn_id` - Transaction identifier
2241    /// * `producer` - Producer state
2242    /// * `group_id` - Consumer group ID
2243    /// * `offsets` - List of (topic, partition, offset) to commit
2244    pub async fn add_offsets_to_txn(
2245        &mut self,
2246        txn_id: impl Into<String>,
2247        producer: &ProducerState,
2248        group_id: impl Into<String>,
2249        offsets: &[(&str, u32, i64)],
2250    ) -> Result<()> {
2251        let request = Request::AddOffsetsToTxn {
2252            txn_id: txn_id.into(),
2253            producer_id: producer.producer_id,
2254            producer_epoch: producer.producer_epoch,
2255            group_id: group_id.into(),
2256            offsets: offsets
2257                .iter()
2258                .map(|(t, p, o)| (t.to_string(), *p, *o))
2259                .collect(),
2260        };
2261
2262        let response = self.send_request(request).await?;
2263
2264        match response {
2265            Response::OffsetsAddedToTxn { .. } => Ok(()),
2266            Response::Error { message } => Err(Error::ServerError(message)),
2267            _ => Err(Error::InvalidResponse),
2268        }
2269    }
2270
2271    /// Commit a transaction
2272    ///
2273    /// Makes all writes in the transaction visible to consumers atomically.
2274    /// If this fails, the transaction should be aborted.
2275    ///
2276    /// # Arguments
2277    /// * `txn_id` - Transaction identifier
2278    /// * `producer` - Producer state
2279    pub async fn commit_transaction(
2280        &mut self,
2281        txn_id: impl Into<String>,
2282        producer: &ProducerState,
2283    ) -> Result<()> {
2284        let request = Request::CommitTransaction {
2285            txn_id: txn_id.into(),
2286            producer_id: producer.producer_id,
2287            producer_epoch: producer.producer_epoch,
2288        };
2289
2290        let response = self.send_request(request).await?;
2291
2292        match response {
2293            Response::TransactionCommitted { .. } => Ok(()),
2294            Response::Error { message } => Err(Error::ServerError(message)),
2295            _ => Err(Error::InvalidResponse),
2296        }
2297    }
2298
2299    /// Abort a transaction
2300    ///
2301    /// Discards all writes in the transaction. Call this if any write fails
2302    /// or if you need to cancel the transaction.
2303    ///
2304    /// # Arguments
2305    /// * `txn_id` - Transaction identifier
2306    /// * `producer` - Producer state
2307    pub async fn abort_transaction(
2308        &mut self,
2309        txn_id: impl Into<String>,
2310        producer: &ProducerState,
2311    ) -> Result<()> {
2312        let request = Request::AbortTransaction {
2313            txn_id: txn_id.into(),
2314            producer_id: producer.producer_id,
2315            producer_epoch: producer.producer_epoch,
2316        };
2317
2318        let response = self.send_request(request).await?;
2319
2320        match response {
2321            Response::TransactionAborted { .. } => Ok(()),
2322            Response::Error { message } => Err(Error::ServerError(message)),
2323            _ => Err(Error::InvalidResponse),
2324        }
2325    }
2326}
2327
/// Client-side bookkeeping for an idempotent or transactional producer.
#[derive(Debug, Clone)]
pub struct ProducerState {
    /// Producer ID handed out by the broker.
    pub producer_id: u64,
    /// Current producer epoch (bumped when re-initializing with a previous ID).
    pub producer_epoch: u16,
    /// Next sequence number per `(topic, partition)`, used when the caller
    /// picks the partition explicitly so the broker can deduplicate per
    /// partition.
    pub partition_sequences: std::collections::HashMap<(String, u32), i32>,
    /// Producer-wide sequence counter, used when the partition is chosen by
    /// the server (`partition = None`). The broker tracks it per producer_id.
    pub next_sequence: i32,
}

impl ProducerState {
    /// Hand out the next sequence number for `topic`/`partition` and advance
    /// its counter.
    ///
    /// A partition seen for the first time starts at `1`. When the counter
    /// would wrap past `i32::MAX` it restarts at `1` instead of going
    /// non-positive.
    pub fn next_sequence_for(&mut self, topic: &str, partition: u32) -> i32 {
        let slot = self
            .partition_sequences
            .entry((topic.to_owned(), partition))
            .or_insert(1);
        let issued = *slot;
        let following = issued.wrapping_add(1);
        *slot = if following > 0 { following } else { 1 };
        issued
    }
}
2361
/// Result of altering a topic's configuration.
///
/// Derives `PartialEq`/`Eq` so callers can compare outcomes directly, for
/// example in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AlterTopicConfigResult {
    /// Topic name, as echoed back by the broker
    pub topic: String,
    /// Number of configuration entries the broker reports as changed
    pub changed_count: usize,
}
2370
2371/// Result of deleting records from a partition
2372pub use rivven_protocol::DeleteRecordsResult;
2373
2374// ============================================================================
2375// Authentication Session
2376// ============================================================================
2377
/// Session details returned after a successful authentication.
#[derive(Debug, Clone)]
pub struct AuthSession {
    /// Identifier to present on subsequent requests.
    pub session_id: String,
    /// Session lifetime, in seconds.
    pub expires_in: u64,
}
2386
2387// ============================================================================
2388// SCRAM Helper Functions
2389// ============================================================================
2390
2391/// Generate a random nonce for SCRAM authentication
2392pub(crate) fn generate_nonce() -> String {
2393    use rand::Rng;
2394    let mut rng = rand::thread_rng();
2395    let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
2396    base64_encode(&nonce_bytes)
2397}
2398
/// Escape a username as a SCRAM `saslname` (RFC 5802).
///
/// `=` becomes `=3D` and `,` becomes `=2C`; every other character is copied
/// through unchanged.
pub(crate) fn escape_username(username: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
2403
2404/// Parse server-first message
2405pub(crate) fn parse_server_first(server_first: &str) -> Result<(String, String, u32)> {
2406    let mut nonce = None;
2407    let mut salt = None;
2408    let mut iterations = None;
2409
2410    for attr in server_first.split(',') {
2411        if let Some(value) = attr.strip_prefix("r=") {
2412            nonce = Some(value.to_string());
2413        } else if let Some(value) = attr.strip_prefix("s=") {
2414            salt = Some(value.to_string());
2415        } else if let Some(value) = attr.strip_prefix("i=") {
2416            iterations = Some(
2417                value
2418                    .parse::<u32>()
2419                    .map_err(|_| Error::AuthenticationFailed("Invalid iteration count".into()))?,
2420            );
2421        }
2422    }
2423
2424    let nonce = nonce.ok_or_else(|| Error::AuthenticationFailed("Missing nonce".into()))?;
2425    let salt = salt.ok_or_else(|| Error::AuthenticationFailed("Missing salt".into()))?;
2426    let iterations =
2427        iterations.ok_or_else(|| Error::AuthenticationFailed("Missing iterations".into()))?;
2428
2429    // RFC 7677 §4 requires a minimum of 4096 iterations for SCRAM-SHA-256.
2430    // A malicious server could send i=1 to weaken key derivation.
2431    if iterations < 4096 {
2432        return Err(Error::AuthenticationFailed(format!(
2433            "SCRAM iteration count {} is below minimum 4096 (possible downgrade attack)",
2434            iterations
2435        )));
2436    }
2437
2438    Ok((nonce, salt, iterations))
2439}
2440
2441/// PBKDF2-HMAC-SHA256 key derivation
2442pub(crate) fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
2443    let mut result = vec![0u8; 32];
2444
2445    // U1 = PRF(Password, Salt || INT(1))
2446    let mut u = PasswordHash::hmac_sha256(password, &[salt, &1u32.to_be_bytes()].concat());
2447    result.copy_from_slice(&u);
2448
2449    // Ui = PRF(Password, Ui-1)
2450    for _ in 1..iterations {
2451        u = PasswordHash::hmac_sha256(password, &u);
2452        for (r, ui) in result.iter_mut().zip(u.iter()) {
2453            *r ^= ui;
2454        }
2455    }
2456
2457    result
2458}
2459
2460/// SHA-256 hash
2461pub(crate) fn sha256(data: &[u8]) -> Vec<u8> {
2462    let mut hasher = Sha256::new();
2463    hasher.update(data);
2464    hasher.finalize().to_vec()
2465}
2466
/// Byte-wise XOR of two slices.
///
/// The output is as long as the shorter input; trailing bytes of the longer
/// slice are ignored.
pub(crate) fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    std::iter::zip(a, b).map(|(&lhs, &rhs)| lhs ^ rhs).collect()
}
2471
2472/// Base64 encode
2473pub(crate) fn base64_encode(data: &[u8]) -> String {
2474    use base64::{engine::general_purpose::STANDARD, Engine};
2475    STANDARD.encode(data)
2476}
2477
2478/// Base64 decode
2479pub(crate) fn base64_decode(data: &str) -> std::result::Result<Vec<u8>, base64::DecodeError> {
2480    use base64::{engine::general_purpose::STANDARD, Engine};
2481    STANDARD.decode(data)
2482}
2483
2484// ============================================================================
2485// Tests
2486// ============================================================================
2487
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escape_username() {
        assert_eq!(escape_username("alice"), "alice");
        assert_eq!(escape_username("user=name"), "user=3Dname");
        assert_eq!(escape_username("user,name"), "user=2Cname");
        assert_eq!(escape_username("user=,name"), "user=3D=2Cname");
        assert_eq!(escape_username(""), "");
    }

    #[test]
    fn test_parse_server_first() {
        let server_first = "r=clientnonce+servernonce,s=c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "clientnonce+servernonce");
        assert_eq!(salt, "c2FsdA==");
        assert_eq!(iterations, 4096);
    }

    #[test]
    fn test_parse_server_first_missing_nonce() {
        let server_first = "s=c2FsdA==,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_salt() {
        let server_first = "r=nonce,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_iterations() {
        let server_first = "r=nonce,s=c2FsdA==";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_rejects_low_iteration_count() {
        // RFC 7677 §4 minimum is 4096; anything lower is a downgrade.
        assert!(parse_server_first("r=nonce,s=c2FsdA==,i=4095").is_err());
        assert!(parse_server_first("r=nonce,s=c2FsdA==,i=1").is_err());
    }

    #[test]
    fn test_parse_server_first_rejects_non_numeric_iterations() {
        assert!(parse_server_first("r=nonce,s=c2FsdA==,i=many").is_err());
        assert!(parse_server_first("r=nonce,s=c2FsdA==,i=-4096").is_err());
    }

    #[test]
    fn test_xor_bytes() {
        assert_eq!(xor_bytes(&[0xFF, 0x00], &[0xFF, 0xFF]), vec![0x00, 0xFF]);
        assert_eq!(xor_bytes(&[0x12, 0x34], &[0x12, 0x34]), vec![0x00, 0x00]);
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"hello world";
        let encoded = base64_encode(data);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_sha256() {
        // SHA-256 of empty string
        let hash = sha256(b"");
        assert_eq!(hash.len(), 32);
        // Known hash value
        assert_eq!(
            hex::encode(&hash),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn test_pbkdf2_sha256() {
        // RFC 7914 §11: PBKDF2-HMAC-SHA256(P="passwd", S="salt", c=1, dkLen=64).
        // pbkdf2_sha256 produces only the first 32-byte block, which is the
        // prefix of the 64-byte RFC output.
        let result = pbkdf2_sha256(b"passwd", b"salt", 1);
        assert_eq!(result.len(), 32);
        assert_eq!(
            hex::encode(&result),
            "55ac046e56e3089fec1691c22544b605f94185216dde0465e68b9d57c20dacbc"
        );

        // Widely published PBKDF2-HMAC-SHA256 vectors (P="password",
        // S="salt", dkLen=32). c > 1 exercises the Ui chaining loop.
        assert_eq!(
            hex::encode(pbkdf2_sha256(b"password", b"salt", 2)),
            "ae4d0c95af6b46d32d0adff928f06dd02a303f8ef3c251dfd6e2d85a95474c43"
        );
        assert_eq!(
            hex::encode(pbkdf2_sha256(b"password", b"salt", 4096)),
            "c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a"
        );
    }

    #[test]
    fn test_generate_nonce() {
        let nonce1 = generate_nonce();
        let nonce2 = generate_nonce();

        // Nonces should be non-empty
        assert!(!nonce1.is_empty());
        assert!(!nonce2.is_empty());

        // Nonces should be different (with overwhelming probability)
        assert_ne!(nonce1, nonce2);

        // Should be valid base64 carrying 24 random bytes
        assert_eq!(base64_decode(&nonce1).unwrap().len(), 24);

        // A ',' would corrupt the comma-separated SCRAM message
        assert!(!nonce1.contains(','));
    }
}