1use crate::{Error, MessageData, Request, Response, Result};
2use bytes::Bytes;
3use rivven_core::PasswordHash;
4use rivven_protocol::SyncGroupAssignments;
5use sha2::{Digest, Sha256};
6use std::time::Duration;
7use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
8use tokio::net::TcpStream;
9use tracing::{debug, info};
10
11#[cfg(feature = "tls")]
12use rivven_core::tls::{TlsClientStream, TlsConfig, TlsConnector};
13
14const DEFAULT_MAX_RESPONSE_SIZE: usize = 100 * 1024 * 1024;
16
17const DEFAULT_MAX_REQUEST_SIZE: usize = rivven_protocol::MAX_MESSAGE_SIZE;
22
23const DEFAULT_CONNECTION_TIMEOUT: Duration = Duration::from_secs(10);
26
27const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(30);
30
31#[allow(clippy::large_enum_variant)]
39pub(crate) enum ClientStream {
40 Plaintext(TcpStream),
41 #[cfg(feature = "tls")]
42 Tls(TlsClientStream<TcpStream>),
43}
44
45impl AsyncRead for ClientStream {
46 fn poll_read(
47 self: std::pin::Pin<&mut Self>,
48 cx: &mut std::task::Context<'_>,
49 buf: &mut tokio::io::ReadBuf<'_>,
50 ) -> std::task::Poll<std::io::Result<()>> {
51 match self.get_mut() {
52 ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_read(cx, buf),
53 #[cfg(feature = "tls")]
54 ClientStream::Tls(s) => std::pin::Pin::new(s).poll_read(cx, buf),
55 }
56 }
57}
58
59impl AsyncWrite for ClientStream {
60 fn poll_write(
61 self: std::pin::Pin<&mut Self>,
62 cx: &mut std::task::Context<'_>,
63 buf: &[u8],
64 ) -> std::task::Poll<std::io::Result<usize>> {
65 match self.get_mut() {
66 ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_write(cx, buf),
67 #[cfg(feature = "tls")]
68 ClientStream::Tls(s) => std::pin::Pin::new(s).poll_write(cx, buf),
69 }
70 }
71
72 fn poll_flush(
73 self: std::pin::Pin<&mut Self>,
74 cx: &mut std::task::Context<'_>,
75 ) -> std::task::Poll<std::io::Result<()>> {
76 match self.get_mut() {
77 ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_flush(cx),
78 #[cfg(feature = "tls")]
79 ClientStream::Tls(s) => std::pin::Pin::new(s).poll_flush(cx),
80 }
81 }
82
83 fn poll_shutdown(
84 self: std::pin::Pin<&mut Self>,
85 cx: &mut std::task::Context<'_>,
86 ) -> std::task::Poll<std::io::Result<()>> {
87 match self.get_mut() {
88 ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_shutdown(cx),
89 #[cfg(feature = "tls")]
90 ClientStream::Tls(s) => std::pin::Pin::new(s).poll_shutdown(cx),
91 }
92 }
93}
94
95pub struct Client {
101 stream: ClientStream,
102 next_correlation_id: u32,
103 request_timeout: Duration,
105 poisoned: bool,
109}
110
111impl Client {
112 pub async fn connect(addr: &str) -> Result<Self> {
121 Self::connect_with_timeout(addr, DEFAULT_CONNECTION_TIMEOUT).await
122 }
123
124 pub async fn connect_with_timeout(addr: &str, timeout: Duration) -> Result<Self> {
129 info!("Connecting to Rivven server at {}", addr);
130 let stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
131 .await
132 .map_err(|_| Error::TimeoutWithMessage(format!("Connection to {} timed out", addr)))?
133 .map_err(|e| Error::ConnectionError(e.to_string()))?;
134
135 let _ = stream.set_nodelay(true);
137
138 let mut client = Self {
139 stream: ClientStream::Plaintext(stream),
140 next_correlation_id: 0,
141 request_timeout: DEFAULT_REQUEST_TIMEOUT,
142 poisoned: false,
143 };
144
145 client.handshake("rivven-client").await?;
147
148 Ok(client)
149 }
150
151 #[cfg(feature = "tls")]
153 pub async fn connect_tls(
154 addr: &str,
155 tls_config: &TlsConfig,
156 server_name: &str,
157 ) -> Result<Self> {
158 Self::connect_tls_with_timeout(addr, tls_config, server_name, DEFAULT_CONNECTION_TIMEOUT)
159 .await
160 }
161
162 #[cfg(feature = "tls")]
164 pub async fn connect_tls_with_timeout(
165 addr: &str,
166 tls_config: &TlsConfig,
167 server_name: &str,
168 timeout: Duration,
169 ) -> Result<Self> {
170 info!("Connecting to Rivven server at {} with TLS", addr);
171
172 let tcp_stream = tokio::time::timeout(timeout, TcpStream::connect(addr))
175 .await
176 .map_err(|_| {
177 Error::TimeoutWithMessage(format!("TLS connection to {} timed out", addr))
178 })?
179 .map_err(|e| Error::ConnectionError(format!("TCP connection error: {}", e)))?;
180
181 tcp_stream
183 .set_nodelay(true)
184 .map_err(|e| Error::ConnectionError(format!("Failed to set TCP_NODELAY: {}", e)))?;
185
186 let connector = TlsConnector::new(tls_config)
188 .map_err(|e| Error::ConnectionError(format!("TLS config error: {}", e)))?;
189
190 let tls_stream = connector
192 .connect(tcp_stream, server_name)
193 .await
194 .map_err(|e| Error::ConnectionError(format!("TLS handshake error: {}", e)))?;
195
196 info!("TLS connection established to {} ({})", addr, server_name);
197
198 let mut client = Self {
199 stream: ClientStream::Tls(tls_stream),
200 next_correlation_id: 0,
201 request_timeout: DEFAULT_REQUEST_TIMEOUT,
202 poisoned: false,
203 };
204
205 client.handshake("rivven-client").await?;
207
208 Ok(client)
209 }
210
211 #[cfg(feature = "tls")]
213 pub async fn connect_mtls(
214 addr: &str,
215 cert_path: impl Into<std::path::PathBuf>,
216 key_path: impl Into<std::path::PathBuf>,
217 ca_path: impl Into<std::path::PathBuf> + Clone,
218 server_name: &str,
219 ) -> Result<Self> {
220 let tls_config = TlsConfig::mtls_from_pem_files(cert_path, key_path, ca_path);
221 Self::connect_tls(addr, &tls_config, server_name).await
222 }
223
224 pub async fn handshake(&mut self, client_id: &str) -> Result<()> {
233 let request = Request::Handshake {
234 protocol_version: rivven_protocol::PROTOCOL_VERSION,
235 client_id: client_id.to_string(),
236 };
237
238 let response = self.send_request(request).await?;
239
240 match response {
241 Response::HandshakeResult {
242 compatible,
243 message: _,
244 server_version,
245 } => {
246 if compatible {
247 info!(
248 "Handshake OK (client v{}, server v{})",
249 rivven_protocol::PROTOCOL_VERSION,
250 server_version
251 );
252 Ok(())
253 } else {
254 Err(Error::ProtocolError(
255 rivven_protocol::ProtocolError::VersionMismatch {
256 expected: rivven_protocol::PROTOCOL_VERSION,
257 actual: server_version,
258 },
259 ))
260 }
261 }
262 Response::Error { message } => {
263 tracing::warn!(
266 "Server returned error on handshake: {}, proceeding anyway",
267 message
268 );
269 Ok(())
270 }
271 _ => {
272 tracing::warn!(
274 "Server did not return HandshakeResult, proceeding without version check"
275 );
276 Ok(())
277 }
278 }
279 }
280
281 pub fn into_stream(self) -> Result<TcpStream> {
287 match self.stream {
288 ClientStream::Plaintext(s) => Ok(s),
289 #[cfg(feature = "tls")]
290 ClientStream::Tls(_) => Err(Error::ConnectionError(
291 "Cannot extract TcpStream from TLS connection. Use into_client_stream() instead."
292 .to_string(),
293 )),
294 }
295 }
296
297 pub(crate) fn into_client_stream(self) -> ClientStream {
303 self.stream
304 }
305
306 pub fn set_request_timeout(&mut self, timeout: Duration) {
308 self.request_timeout = timeout;
309 }
310
311 pub fn is_tls(&self) -> bool {
313 match &self.stream {
314 ClientStream::Plaintext(_) => false,
315 #[cfg(feature = "tls")]
316 ClientStream::Tls(_) => true,
317 }
318 }
319
320 #[allow(deprecated)]
334 pub async fn authenticate(&mut self, username: &str, password: &str) -> Result<AuthSession> {
335 if !self.is_tls() {
338 return Err(Error::AuthenticationFailed(
339 "SASL/PLAIN requires a TLS connection — use authenticate_scram() for plaintext channels".to_string(),
340 ));
341 }
342
343 let require_tls = true; let request = Request::Authenticate {
347 username: username.to_string(),
348 password: password.to_string(),
349 require_tls,
350 };
351
352 let response = self.send_request(request).await?;
353
354 match response {
355 Response::Authenticated {
356 session_id,
357 expires_in,
358 } => {
359 info!("Authenticated as '{}'", username);
360 Ok(AuthSession {
361 session_id,
362 expires_in,
363 })
364 }
365 Response::Error { message } => Err(Error::AuthenticationFailed(message)),
366 _ => Err(Error::InvalidResponse),
367 }
368 }
369
370 pub async fn authenticate_scram(
388 &mut self,
389 username: &str,
390 password: &str,
391 ) -> Result<AuthSession> {
392 let client_nonce = generate_nonce();
394 let client_first_bare = format!("n={},r={}", escape_username(username), client_nonce);
395 let client_first = format!("n,,{}", client_first_bare);
396
397 debug!("SCRAM: Sending client-first");
398 let request = Request::ScramClientFirst {
399 message: Bytes::from(client_first.clone()),
400 };
401
402 let response = self.send_request(request).await?;
403
404 let server_first = match response {
406 Response::ScramServerFirst { message } => String::from_utf8(message.to_vec())
407 .map_err(|_| Error::AuthenticationFailed("Invalid server-first encoding".into()))?,
408 Response::Error { message } => return Err(Error::AuthenticationFailed(message)),
409 _ => return Err(Error::InvalidResponse),
410 };
411
412 debug!("SCRAM: Received server-first");
413
414 let (combined_nonce, salt_b64, iterations) = parse_server_first(&server_first)?;
416
417 if !combined_nonce.starts_with(&client_nonce) {
419 return Err(Error::AuthenticationFailed("Server nonce mismatch".into()));
420 }
421
422 let salt = base64_decode(&salt_b64)
424 .map_err(|_| Error::AuthenticationFailed("Invalid salt encoding".into()))?;
425
426 let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
428 let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
429 let stored_key = sha256(&client_key);
430
431 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
432 let auth_message = format!(
433 "{},{},{}",
434 client_first_bare, server_first, client_final_without_proof
435 );
436
437 let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
438 let client_proof = xor_bytes(&client_key, &client_signature);
439 let client_proof_b64 = base64_encode(&client_proof);
440
441 let client_final = format!("{},p={}", client_final_without_proof, client_proof_b64);
443
444 debug!("SCRAM: Sending client-final");
445 let request = Request::ScramClientFinal {
446 message: Bytes::from(client_final),
447 };
448
449 let response = self.send_request(request).await?;
450
451 match response {
453 Response::ScramServerFinal {
454 message,
455 session_id,
456 expires_in,
457 } => {
458 let server_final = String::from_utf8(message.to_vec()).map_err(|_| {
459 Error::AuthenticationFailed("Invalid server-final encoding".into())
460 })?;
461
462 if let Some(error_msg) = server_final.strip_prefix("e=") {
464 return Err(Error::AuthenticationFailed(error_msg.to_string()));
465 }
466
467 if let Some(verifier_b64) = server_final.strip_prefix("v=") {
469 let server_key = PasswordHash::hmac_sha256(&salted_password, b"Server Key");
470 let expected_server_sig =
471 PasswordHash::hmac_sha256(&server_key, auth_message.as_bytes());
472 let expected_verifier = base64_encode(&expected_server_sig);
473
474 if verifier_b64 != expected_verifier {
475 return Err(Error::AuthenticationFailed(
476 "Server verification failed".into(),
477 ));
478 }
479 }
480
481 let session_id = session_id.ok_or_else(|| {
482 Error::AuthenticationFailed("No session ID in response".into())
483 })?;
484 let expires_in = expires_in
485 .ok_or_else(|| Error::AuthenticationFailed("No expiry in response".into()))?;
486
487 info!("SCRAM authentication successful for '{}'", username);
488 Ok(AuthSession {
489 session_id,
490 expires_in,
491 })
492 }
493 Response::Error { message } => Err(Error::AuthenticationFailed(message)),
494 _ => Err(Error::InvalidResponse),
495 }
496 }
497
498 pub(crate) async fn send_request(&mut self, request: Request) -> Result<Response> {
509 let timeout_dur = self.request_timeout;
510 match tokio::time::timeout(timeout_dur, self.send_request_inner(request)).await {
511 Ok(result) => result,
512 Err(_elapsed) => {
513 self.poisoned = true;
516 Err(Error::Timeout)
517 }
518 }
519 }
520
521 async fn send_request_inner(&mut self, request: Request) -> Result<Response> {
523 if self.poisoned {
526 return Err(Error::ConnectionError(
527 "Client stream is desynchronized — reconnect required".into(),
528 ));
529 }
530
531 let correlation_id = self.next_correlation_id;
533 self.next_correlation_id = self.next_correlation_id.wrapping_add(1);
534
535 let request_bytes =
537 request.to_wire(rivven_protocol::WireFormat::Postcard, correlation_id)?;
538
539 if request_bytes.len() > DEFAULT_MAX_REQUEST_SIZE {
543 return Err(Error::RequestTooLarge(
544 request_bytes.len(),
545 DEFAULT_MAX_REQUEST_SIZE,
546 ));
547 }
548
549 let len: u32 = request_bytes
554 .len()
555 .try_into()
556 .map_err(|_| Error::RequestTooLarge(request_bytes.len(), u32::MAX as usize))?;
557 self.stream
558 .write_all(&len.to_be_bytes())
559 .await
560 .map_err(|e| {
561 self.poisoned = true;
562 Error::from(e)
563 })?;
564 self.stream.write_all(&request_bytes).await.map_err(|e| {
565 self.poisoned = true;
566 Error::from(e)
567 })?;
568 self.stream.flush().await.map_err(|e| {
569 self.poisoned = true;
570 Error::from(e)
571 })?;
572
573 let mut len_buf = [0u8; 4];
575 self.stream.read_exact(&mut len_buf).await.map_err(|e| {
576 self.poisoned = true;
577 Error::from(e)
578 })?;
579 let msg_len = u32::from_be_bytes(len_buf) as usize;
580
581 if msg_len > DEFAULT_MAX_RESPONSE_SIZE {
583 self.poisoned = true;
584 return Err(Error::ResponseTooLarge(msg_len, DEFAULT_MAX_RESPONSE_SIZE));
585 }
586
587 let mut response_buf = vec![0u8; msg_len];
589 self.stream
590 .read_exact(&mut response_buf)
591 .await
592 .map_err(|e| {
593 self.poisoned = true;
594 Error::from(e)
595 })?;
596
597 let (response, _format, response_correlation_id) = Response::from_wire(&response_buf)?;
601
602 if response_correlation_id != correlation_id {
606 self.poisoned = true;
609 return Err(Error::ProtocolError(
610 rivven_protocol::ProtocolError::InvalidFormat(format!(
611 "Correlation ID mismatch: expected {}, got {}",
612 correlation_id, response_correlation_id
613 )),
614 ));
615 }
616
617 Ok(response)
618 }
619
620 pub async fn consume_pipelined(
630 &mut self,
631 fetches: &[(&str, u32, u64, u32, Option<u8>)],
632 ) -> Result<Vec<Result<Vec<MessageData>>>> {
633 if fetches.is_empty() {
634 return Ok(Vec::new());
635 }
636 if self.poisoned {
637 return Err(Error::ConnectionError(
638 "Client stream is desynchronized — reconnect required".into(),
639 ));
640 }
641
642 let timeout_dur = self.request_timeout;
643 match tokio::time::timeout(timeout_dur, self.consume_pipelined_inner(fetches)).await {
644 Ok(result) => result,
645 Err(_elapsed) => {
646 self.poisoned = true;
647 Err(Error::Timeout)
648 }
649 }
650 }
651
652 async fn consume_pipelined_inner(
653 &mut self,
654 fetches: &[(&str, u32, u64, u32, Option<u8>)],
655 ) -> Result<Vec<Result<Vec<MessageData>>>> {
656 let mut correlation_ids = Vec::with_capacity(fetches.len());
657 let mut bytes_sent = false;
658
659 for &(topic, partition, offset, max_messages, isolation_level) in fetches {
661 let correlation_id = self.next_correlation_id;
662 self.next_correlation_id = self.next_correlation_id.wrapping_add(1);
663 correlation_ids.push(correlation_id);
664
665 let request = Request::Consume {
666 topic: topic.to_string(),
667 partition,
668 offset,
669 max_messages,
670 isolation_level,
671 max_wait_ms: None,
672 };
673
674 let request_bytes = request
675 .to_wire(rivven_protocol::WireFormat::Postcard, correlation_id)
676 .inspect_err(|_| {
677 if bytes_sent {
678 self.poisoned = true;
679 }
680 })?;
681
682 if request_bytes.len() > DEFAULT_MAX_REQUEST_SIZE {
683 if bytes_sent {
684 self.poisoned = true;
685 }
686 return Err(Error::RequestTooLarge(
687 request_bytes.len(),
688 DEFAULT_MAX_REQUEST_SIZE,
689 ));
690 }
691
692 let len: u32 = request_bytes.len().try_into().map_err(|_| {
693 if bytes_sent {
694 self.poisoned = true;
695 }
696 Error::RequestTooLarge(request_bytes.len(), u32::MAX as usize)
697 })?;
698 self.stream
699 .write_all(&len.to_be_bytes())
700 .await
701 .map_err(|e| {
702 if bytes_sent {
703 self.poisoned = true;
704 }
705 Error::from(e)
706 })?;
707 self.stream.write_all(&request_bytes).await.map_err(|e| {
708 self.poisoned = true; Error::from(e)
710 })?;
711 bytes_sent = true;
712 }
713 self.stream.flush().await.map_err(|e| {
714 self.poisoned = true;
715 Error::from(e)
716 })?;
717
718 let mut results = Vec::with_capacity(fetches.len());
720 let mut response_buf: Vec<u8> = Vec::with_capacity(4096);
721 for &expected_cid in &correlation_ids {
722 let mut len_buf = [0u8; 4];
723 self.stream.read_exact(&mut len_buf).await.map_err(|e| {
724 self.poisoned = true;
725 Error::from(e)
726 })?;
727 let msg_len = u32::from_be_bytes(len_buf) as usize;
728
729 if msg_len > DEFAULT_MAX_RESPONSE_SIZE {
730 self.poisoned = true;
731 return Err(Error::ResponseTooLarge(msg_len, DEFAULT_MAX_RESPONSE_SIZE));
732 }
733
734 response_buf.resize(msg_len, 0);
735 self.stream
736 .read_exact(&mut response_buf)
737 .await
738 .map_err(|e| {
739 self.poisoned = true;
740 Error::from(e)
741 })?;
742 let (response, _format, response_cid) = Response::from_wire(&response_buf)
743 .inspect_err(|_| {
744 self.poisoned = true;
745 })?;
746
747 if response_cid != expected_cid {
748 self.poisoned = true;
749 return Err(Error::ProtocolError(
750 rivven_protocol::ProtocolError::InvalidFormat(format!(
751 "Correlation ID mismatch: expected {}, got {}",
752 expected_cid, response_cid
753 )),
754 ));
755 }
756
757 let result = match response {
758 Response::Messages { messages } => Ok(messages),
759 Response::Error { message } => Err(Error::ServerError(message)),
760 _ => Err(Error::InvalidResponse),
761 };
762 results.push(result);
763 }
764
765 Ok(results)
766 }
767
768 pub async fn publish(
770 &mut self,
771 topic: impl Into<String>,
772 value: impl Into<Bytes>,
773 ) -> Result<u64> {
774 self.publish_with_key(topic, None::<Bytes>, value).await
775 }
776
777 pub async fn publish_with_key(
779 &mut self,
780 topic: impl Into<String>,
781 key: Option<impl Into<Bytes>>,
782 value: impl Into<Bytes>,
783 ) -> Result<u64> {
784 let request = Request::Publish {
785 topic: topic.into(),
786 partition: None,
787 key: key.map(|k| k.into()),
788 value: value.into(),
789 leader_epoch: None,
790 };
791
792 let response = self.send_request(request).await?;
793
794 match response {
795 Response::Published { offset, .. } => Ok(offset),
796 Response::Error { message } => Err(Error::ServerError(message)),
797 _ => Err(Error::InvalidResponse),
798 }
799 }
800
801 pub async fn publish_to_partition(
803 &mut self,
804 topic: impl Into<String>,
805 partition: u32,
806 key: Option<impl Into<Bytes>>,
807 value: impl Into<Bytes>,
808 ) -> Result<u64> {
809 let request = Request::Publish {
810 topic: topic.into(),
811 partition: Some(partition),
812 key: key.map(|k| k.into()),
813 value: value.into(),
814 leader_epoch: None,
815 };
816
817 let response = self.send_request(request).await?;
818
819 match response {
820 Response::Published { offset, .. } => Ok(offset),
821 Response::Error { message } => Err(Error::ServerError(message)),
822 _ => Err(Error::InvalidResponse),
823 }
824 }
825
826 pub async fn consume(
832 &mut self,
833 topic: impl Into<String>,
834 partition: u32,
835 offset: u64,
836 max_messages: u32,
837 ) -> Result<Vec<MessageData>> {
838 self.consume_with_isolation(topic, partition, offset, max_messages, None)
839 .await
840 }
841
842 pub async fn consume_with_isolation(
862 &mut self,
863 topic: impl Into<String>,
864 partition: u32,
865 offset: u64,
866 max_messages: u32,
867 isolation_level: Option<u8>,
868 ) -> Result<Vec<MessageData>> {
869 let request = Request::Consume {
870 topic: topic.into(),
871 partition,
872 offset,
873 max_messages,
874 isolation_level,
875 max_wait_ms: None,
876 };
877
878 let response = self.send_request(request).await?;
879
880 match response {
881 Response::Messages { messages } => Ok(messages),
882 Response::Error { message } => Err(Error::ServerError(message)),
883 _ => Err(Error::InvalidResponse),
884 }
885 }
886
887 pub async fn consume_long_poll(
895 &mut self,
896 topic: impl Into<String>,
897 partition: u32,
898 offset: u64,
899 max_messages: u32,
900 isolation_level: Option<u8>,
901 max_wait_ms: u64,
902 ) -> Result<Vec<MessageData>> {
903 let request = Request::Consume {
904 topic: topic.into(),
905 partition,
906 offset,
907 max_messages,
908 isolation_level,
909 max_wait_ms: Some(max_wait_ms),
910 };
911
912 let response = self.send_request(request).await?;
913
914 match response {
915 Response::Messages { messages } => Ok(messages),
916 Response::Error { message } => Err(Error::ServerError(message)),
917 _ => Err(Error::InvalidResponse),
918 }
919 }
920
921 pub async fn consume_read_committed(
928 &mut self,
929 topic: impl Into<String>,
930 partition: u32,
931 offset: u64,
932 max_messages: u32,
933 ) -> Result<Vec<MessageData>> {
934 self.consume_with_isolation(topic, partition, offset, max_messages, Some(1))
935 .await
936 }
937
938 pub async fn create_topic(
940 &mut self,
941 name: impl Into<String>,
942 partitions: Option<u32>,
943 ) -> Result<u32> {
944 let name = name.into();
945 let request = Request::CreateTopic {
946 name: name.clone(),
947 partitions,
948 };
949
950 let response = self.send_request(request).await?;
951
952 match response {
953 Response::TopicCreated { partitions, .. } => Ok(partitions),
954 Response::Error { message } => Err(Error::ServerError(message)),
955 _ => Err(Error::InvalidResponse),
956 }
957 }
958
959 pub async fn list_topics(&mut self) -> Result<Vec<String>> {
961 let request = Request::ListTopics;
962 let response = self.send_request(request).await?;
963
964 match response {
965 Response::Topics { topics } => Ok(topics),
966 Response::Error { message } => Err(Error::ServerError(message)),
967 _ => Err(Error::InvalidResponse),
968 }
969 }
970
971 pub async fn delete_topic(&mut self, name: impl Into<String>) -> Result<()> {
973 let request = Request::DeleteTopic { name: name.into() };
974 let response = self.send_request(request).await?;
975
976 match response {
977 Response::TopicDeleted => Ok(()),
978 Response::Error { message } => Err(Error::ServerError(message)),
979 _ => Err(Error::InvalidResponse),
980 }
981 }
982
983 pub async fn commit_offset(
985 &mut self,
986 consumer_group: impl Into<String>,
987 topic: impl Into<String>,
988 partition: u32,
989 offset: u64,
990 ) -> Result<()> {
991 let request = Request::CommitOffset {
992 consumer_group: consumer_group.into(),
993 topic: topic.into(),
994 partition,
995 offset,
996 };
997
998 let response = self.send_request(request).await?;
999
1000 match response {
1001 Response::OffsetCommitted => Ok(()),
1002 Response::Error { message } => Err(Error::ServerError(message)),
1003 _ => Err(Error::InvalidResponse),
1004 }
1005 }
1006
1007 pub async fn commit_offsets_pipelined(
1012 &mut self,
1013 consumer_group: &str,
1014 offsets: &[(String, u32, u64)],
1015 ) -> Result<Vec<Result<()>>> {
1016 if offsets.is_empty() {
1017 return Ok(Vec::new());
1018 }
1019 if self.poisoned {
1020 return Err(Error::ConnectionError(
1021 "Client stream is desynchronized — reconnect required".into(),
1022 ));
1023 }
1024
1025 let timeout_dur = self.request_timeout;
1026 match tokio::time::timeout(
1027 timeout_dur,
1028 self.commit_offsets_pipelined_inner(consumer_group, offsets),
1029 )
1030 .await
1031 {
1032 Ok(result) => result,
1033 Err(_elapsed) => {
1034 self.poisoned = true;
1035 Err(Error::Timeout)
1036 }
1037 }
1038 }
1039
1040 async fn commit_offsets_pipelined_inner(
1041 &mut self,
1042 consumer_group: &str,
1043 offsets: &[(String, u32, u64)],
1044 ) -> Result<Vec<Result<()>>> {
1045 let mut correlation_ids = Vec::with_capacity(offsets.len());
1046 let mut bytes_sent = false;
1047
1048 for (topic, partition, offset) in offsets {
1050 let correlation_id = self.next_correlation_id;
1051 self.next_correlation_id = self.next_correlation_id.wrapping_add(1);
1052 correlation_ids.push(correlation_id);
1053
1054 let request = Request::CommitOffset {
1055 consumer_group: consumer_group.to_string(),
1056 topic: topic.clone(),
1057 partition: *partition,
1058 offset: *offset,
1059 };
1060
1061 let request_bytes = request
1062 .to_wire(rivven_protocol::WireFormat::Postcard, correlation_id)
1063 .inspect_err(|_| {
1064 if bytes_sent {
1065 self.poisoned = true;
1066 }
1067 })?;
1068
1069 if request_bytes.len() > DEFAULT_MAX_REQUEST_SIZE {
1070 if bytes_sent {
1071 self.poisoned = true;
1072 }
1073 return Err(Error::RequestTooLarge(
1074 request_bytes.len(),
1075 DEFAULT_MAX_REQUEST_SIZE,
1076 ));
1077 }
1078
1079 let len: u32 = request_bytes.len().try_into().map_err(|_| {
1080 if bytes_sent {
1081 self.poisoned = true;
1082 }
1083 Error::RequestTooLarge(request_bytes.len(), u32::MAX as usize)
1084 })?;
1085 self.stream
1086 .write_all(&len.to_be_bytes())
1087 .await
1088 .map_err(|e| {
1089 if bytes_sent {
1090 self.poisoned = true;
1091 }
1092 Error::from(e)
1093 })?;
1094 self.stream.write_all(&request_bytes).await.map_err(|e| {
1095 self.poisoned = true;
1096 Error::from(e)
1097 })?;
1098 bytes_sent = true;
1099 }
1100 self.stream.flush().await.map_err(|e| {
1101 self.poisoned = true;
1102 Error::from(e)
1103 })?;
1104
1105 let mut results = Vec::with_capacity(offsets.len());
1107 for &expected_cid in &correlation_ids {
1108 let mut len_buf = [0u8; 4];
1109 self.stream.read_exact(&mut len_buf).await.map_err(|e| {
1110 self.poisoned = true;
1111 Error::from(e)
1112 })?;
1113 let msg_len = u32::from_be_bytes(len_buf) as usize;
1114
1115 if msg_len > DEFAULT_MAX_RESPONSE_SIZE {
1116 self.poisoned = true;
1117 return Err(Error::ResponseTooLarge(msg_len, DEFAULT_MAX_RESPONSE_SIZE));
1118 }
1119
1120 let mut response_buf = vec![0u8; msg_len];
1121 self.stream
1122 .read_exact(&mut response_buf)
1123 .await
1124 .map_err(|e| {
1125 self.poisoned = true;
1126 Error::from(e)
1127 })?;
1128 let (response, _format, response_cid) = Response::from_wire(&response_buf)
1129 .inspect_err(|_| {
1130 self.poisoned = true;
1131 })?;
1132
1133 if response_cid != expected_cid {
1134 self.poisoned = true;
1135 return Err(Error::ProtocolError(
1136 rivven_protocol::ProtocolError::InvalidFormat(format!(
1137 "Correlation ID mismatch: expected {}, got {}",
1138 expected_cid, response_cid
1139 )),
1140 ));
1141 }
1142
1143 let result = match response {
1144 Response::OffsetCommitted => Ok(()),
1145 Response::Error { message } => Err(Error::ServerError(message)),
1146 _ => Err(Error::InvalidResponse),
1147 };
1148 results.push(result);
1149 }
1150
1151 Ok(results)
1152 }
1153
1154 pub fn is_poisoned(&self) -> bool {
1156 self.poisoned
1157 }
1158
1159 pub async fn get_offset(
1161 &mut self,
1162 consumer_group: impl Into<String>,
1163 topic: impl Into<String>,
1164 partition: u32,
1165 ) -> Result<Option<u64>> {
1166 let request = Request::GetOffset {
1167 consumer_group: consumer_group.into(),
1168 topic: topic.into(),
1169 partition,
1170 };
1171
1172 let response = self.send_request(request).await?;
1173
1174 match response {
1175 Response::Offset { offset } => Ok(offset),
1176 Response::Error { message } => Err(Error::ServerError(message)),
1177 _ => Err(Error::InvalidResponse),
1178 }
1179 }
1180
1181 pub async fn get_offset_bounds(
1187 &mut self,
1188 topic: impl Into<String>,
1189 partition: u32,
1190 ) -> Result<(u64, u64)> {
1191 let request = Request::GetOffsetBounds {
1192 topic: topic.into(),
1193 partition,
1194 };
1195
1196 let response = self.send_request(request).await?;
1197
1198 match response {
1199 Response::OffsetBounds { earliest, latest } => Ok((earliest, latest)),
1200 Response::Error { message } => Err(Error::ServerError(message)),
1201 _ => Err(Error::InvalidResponse),
1202 }
1203 }
1204
1205 pub async fn get_metadata(&mut self, topic: impl Into<String>) -> Result<(String, u32)> {
1207 let request = Request::GetMetadata {
1208 topic: topic.into(),
1209 };
1210
1211 let response = self.send_request(request).await?;
1212
1213 match response {
1214 Response::Metadata { name, partitions } => Ok((name, partitions)),
1215 Response::Error { message } => Err(Error::ServerError(message)),
1216 _ => Err(Error::InvalidResponse),
1217 }
1218 }
1219
1220 pub async fn ping(&mut self) -> Result<()> {
1222 let request = Request::Ping;
1223 let response = self.send_request(request).await?;
1224
1225 match response {
1226 Response::Pong => Ok(()),
1227 Response::Error { message } => Err(Error::ServerError(message)),
1228 _ => Err(Error::InvalidResponse),
1229 }
1230 }
1231
1232 pub async fn register_schema(
1251 &self,
1252 registry_url: &str,
1253 subject: &str,
1254 schema_type: &str,
1255 schema: &str,
1256 ) -> Result<u32> {
1257 let url = registry_url.trim_end_matches('/');
1258 let endpoint = format!("{}/subjects/{}/versions", url, subject);
1259
1260 let body = serde_json::json!({
1261 "schema": schema,
1262 "schemaType": schema_type,
1263 });
1264
1265 #[cfg(feature = "schema-registry")]
1266 {
1267 self.register_schema_reqwest(&endpoint, &body).await
1268 }
1269
1270 #[cfg(not(feature = "schema-registry"))]
1271 {
1272 self.register_schema_inline(url, &endpoint, &body).await
1273 }
1274 }
1275
1276 #[cfg(feature = "schema-registry")]
1278 async fn register_schema_reqwest(
1279 &self,
1280 endpoint: &str,
1281 body: &serde_json::Value,
1282 ) -> Result<u32> {
1283 let client = reqwest::Client::new();
1284 let response = client
1285 .post(endpoint)
1286 .header("Content-Type", "application/vnd.schemaregistry.v1+json")
1287 .json(body)
1288 .send()
1289 .await
1290 .map_err(|e| Error::ConnectionError(format!("schema registry request failed: {e}")))?;
1291
1292 let status = response.status();
1293 if !status.is_success() {
1294 let body_text = response.text().await.unwrap_or_default();
1295 return Err(Error::ServerError(format!(
1296 "schema registry returned HTTP {status}: {body_text}"
1297 )));
1298 }
1299
1300 #[derive(serde::Deserialize)]
1301 struct RegisterResponse {
1302 id: u32,
1303 }
1304
1305 let result: RegisterResponse = response
1306 .json()
1307 .await
1308 .map_err(|e| Error::ConnectionError(format!("failed to parse response: {e}")))?;
1309
1310 Ok(result.id)
1311 }
1312
1313 #[cfg(not(feature = "schema-registry"))]
1319 async fn register_schema_inline(
1320 &self,
1321 base_url: &str,
1322 _endpoint: &str,
1323 body: &serde_json::Value,
1324 ) -> Result<u32> {
1325 use tokio::io::AsyncBufReadExt;
1326 use tokio::io::BufReader;
1327 use tokio::net::TcpStream as TokioTcpStream;
1328
1329 let stripped = base_url.strip_prefix("http://").ok_or_else(|| {
1330 Error::ConnectionError(
1331 "HTTPS requires the `schema-registry` feature; URL must start with http:// without it".into(),
1332 )
1333 })?;
1334 let (host_port, _) = stripped.split_once('/').unwrap_or((stripped, ""));
1335
1336 let path = _endpoint.strip_prefix(base_url).unwrap_or(_endpoint);
1338
1339 let body_bytes = serde_json::to_vec(body)
1340 .map_err(|e| Error::ConnectionError(format!("failed to serialize schema: {e}")))?;
1341
1342 let request = format!(
1343 "POST {} HTTP/1.1\r\nHost: {}\r\nContent-Type: application/vnd.schemaregistry.v1+json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n",
1344 path, host_port, body_bytes.len()
1345 );
1346
1347 let timeout = tokio::time::Duration::from_secs(30);
1348
1349 let mut stream = tokio::time::timeout(timeout, TokioTcpStream::connect(host_port))
1350 .await
1351 .map_err(|_| Error::ConnectionError("schema registry connect timed out".into()))?
1352 .map_err(|e| {
1353 Error::ConnectionError(format!("failed to connect to schema registry: {e}"))
1354 })?;
1355
1356 stream
1357 .write_all(request.as_bytes())
1358 .await
1359 .map_err(|e| Error::ConnectionError(format!("failed to send request: {e}")))?;
1360 stream
1361 .write_all(&body_bytes)
1362 .await
1363 .map_err(|e| Error::ConnectionError(format!("failed to send body: {e}")))?;
1364
1365 let response_body =
1367 tokio::time::timeout(timeout, async {
1368 let mut reader = BufReader::new(stream);
1369
1370 let mut status_line = String::new();
1372 reader.read_line(&mut status_line).await.map_err(|e| {
1373 Error::ConnectionError(format!("failed to read status line: {e}"))
1374 })?;
1375
1376 let status_code: u16 = status_line
1377 .split_whitespace()
1378 .nth(1)
1379 .and_then(|s| s.parse().ok())
1380 .unwrap_or(0);
1381
1382 if !(200..300).contains(&status_code) {
1383 return Err(Error::ServerError(format!(
1384 "schema registry returned HTTP {status_code}"
1385 )));
1386 }
1387
1388 let mut content_length: Option<usize> = None;
1390 let mut is_chunked = false;
1391 loop {
1392 let mut header_line = String::new();
1393 reader.read_line(&mut header_line).await.map_err(|e| {
1394 Error::ConnectionError(format!("failed to read header: {e}"))
1395 })?;
1396
1397 let trimmed = header_line.trim();
1398 if trimmed.is_empty() {
1399 break; }
1401
1402 let lower = trimmed.to_ascii_lowercase();
1403 if let Some(val) = lower.strip_prefix("content-length:") {
1404 content_length = val.trim().parse().ok();
1405 } else if lower.starts_with("transfer-encoding:") && lower.contains("chunked") {
1406 is_chunked = true;
1407 }
1408 }
1409
1410 let body_bytes = if is_chunked {
1412 const MAX_CHUNK_SIZE: usize = 16 * 1024 * 1024;
1414 const MAX_TOTAL_BODY: usize = 16 * 1024 * 1024;
1415 let mut body = Vec::new();
1416 loop {
1417 let mut size_line = String::new();
1418 reader.read_line(&mut size_line).await.map_err(|e| {
1419 Error::ConnectionError(format!("failed to read chunk size: {e}"))
1420 })?;
1421
1422 let chunk_size =
1423 usize::from_str_radix(size_line.trim(), 16).map_err(|_| {
1424 Error::ConnectionError(format!(
1425 "invalid chunk size: {:?}",
1426 size_line.trim()
1427 ))
1428 })?;
1429 if chunk_size == 0 {
1430 let mut trailer_buf = [0u8; 2];
1432 let _ = reader.read_exact(&mut trailer_buf).await;
1433 break;
1434 }
1435
1436 if chunk_size > MAX_CHUNK_SIZE {
1438 return Err(Error::ConnectionError(format!(
1439 "chunk size {} exceeds maximum {}",
1440 chunk_size, MAX_CHUNK_SIZE
1441 )));
1442 }
1443
1444 let mut chunk = vec![0u8; chunk_size];
1445 reader.read_exact(&mut chunk).await.map_err(|e| {
1446 Error::ConnectionError(format!("failed to read chunk data: {e}"))
1447 })?;
1448 body.extend_from_slice(&chunk);
1449
1450 if body.len() > MAX_TOTAL_BODY {
1452 return Err(Error::ConnectionError(format!(
1453 "chunked body {} bytes exceeds maximum {}",
1454 body.len(),
1455 MAX_TOTAL_BODY
1456 )));
1457 }
1458
1459 let mut crlf_buf = [0u8; 2];
1463 reader.read_exact(&mut crlf_buf).await.map_err(|e| {
1464 Error::ConnectionError(format!("failed to read chunk CRLF: {e}"))
1465 })?;
1466 if crlf_buf != [b'\r', b'\n'] {
1467 return Err(Error::ConnectionError(format!(
1468 "expected CRLF after chunk data, got {:02x?}",
1469 crlf_buf
1470 )));
1471 }
1472 }
1473 body
1474 } else if let Some(len) = content_length {
1475 const MAX_CONTENT_LENGTH: usize = 16 * 1024 * 1024;
1476 if len > MAX_CONTENT_LENGTH {
1477 return Err(Error::ConnectionError(format!(
1478 "response Content-Length {} bytes exceeds maximum {}",
1479 len, MAX_CONTENT_LENGTH
1480 )));
1481 }
1482 let mut body = vec![0u8; len];
1483 reader.read_exact(&mut body).await.map_err(|e| {
1484 Error::ConnectionError(format!("failed to read response body: {e}"))
1485 })?;
1486 body
1487 } else {
1488 const MAX_RESPONSE_SIZE: usize = 16 * 1024 * 1024;
1490 let mut body = Vec::with_capacity(4096);
1491 reader.read_to_end(&mut body).await.map_err(|e| {
1492 Error::ConnectionError(format!("failed to read response: {e}"))
1493 })?;
1494 if body.len() > MAX_RESPONSE_SIZE {
1495 return Err(Error::ConnectionError(format!(
1496 "response body {} bytes exceeds maximum {}",
1497 body.len(),
1498 MAX_RESPONSE_SIZE
1499 )));
1500 }
1501 body
1502 };
1503
1504 Ok(body_bytes)
1505 })
1506 .await
1507 .map_err(|_| Error::ConnectionError("schema registry response timed out".into()))??;
1508
1509 #[derive(serde::Deserialize)]
1510 struct RegisterResponse {
1511 id: u32,
1512 }
1513
1514 let result: RegisterResponse = serde_json::from_slice(&response_body).map_err(|e| {
1515 Error::ConnectionError(format!("failed to parse schema registry response: {e}"))
1516 })?;
1517
1518 Ok(result.id)
1519 }
1520
1521 pub async fn list_groups(&mut self) -> Result<Vec<String>> {
1523 let request = Request::ListGroups;
1524
1525 let response = self.send_request(request).await?;
1526
1527 match response {
1528 Response::Groups { groups } => Ok(groups),
1529 Response::Error { message } => Err(Error::ServerError(message)),
1530 _ => Err(Error::InvalidResponse),
1531 }
1532 }
1533
1534 pub async fn describe_group(
1536 &mut self,
1537 consumer_group: impl Into<String>,
1538 ) -> Result<std::collections::HashMap<String, std::collections::HashMap<u32, u64>>> {
1539 let request = Request::DescribeGroup {
1540 consumer_group: consumer_group.into(),
1541 };
1542
1543 let response = self.send_request(request).await?;
1544
1545 match response {
1546 Response::GroupDescription { offsets, .. } => Ok(offsets),
1547 Response::Error { message } => Err(Error::ServerError(message)),
1548 _ => Err(Error::InvalidResponse),
1549 }
1550 }
1551
1552 pub async fn delete_group(&mut self, consumer_group: impl Into<String>) -> Result<()> {
1554 let request = Request::DeleteGroup {
1555 consumer_group: consumer_group.into(),
1556 };
1557
1558 let response = self.send_request(request).await?;
1559
1560 match response {
1561 Response::GroupDeleted => Ok(()),
1562 Response::Error { message } => Err(Error::ServerError(message)),
1563 _ => Err(Error::InvalidResponse),
1564 }
1565 }
1566
1567 pub async fn join_group(
1582 &mut self,
1583 group_id: impl Into<String>,
1584 member_id: impl Into<String>,
1585 session_timeout_ms: u32,
1586 rebalance_timeout_ms: u32,
1587 protocol_type: impl Into<String>,
1588 subscriptions: Vec<String>,
1589 ) -> Result<(u32, String, String, String, Vec<(String, Vec<String>)>)> {
1590 let request = Request::JoinGroup {
1591 group_id: group_id.into(),
1592 member_id: member_id.into(),
1593 session_timeout_ms,
1594 rebalance_timeout_ms,
1595 protocol_type: protocol_type.into(),
1596 subscriptions,
1597 };
1598
1599 let response = self.send_request(request).await?;
1600
1601 match response {
1602 Response::JoinGroupResult {
1603 generation_id,
1604 protocol_type,
1605 member_id,
1606 leader_id,
1607 members,
1608 } => Ok((generation_id, protocol_type, member_id, leader_id, members)),
1609 Response::Error { message } => Err(Error::ServerError(message)),
1610 _ => Err(Error::InvalidResponse),
1611 }
1612 }
1613
1614 pub async fn sync_group(
1627 &mut self,
1628 group_id: impl Into<String>,
1629 generation_id: u32,
1630 member_id: impl Into<String>,
1631 assignments: SyncGroupAssignments,
1632 ) -> Result<Vec<(String, Vec<u32>)>> {
1633 let request = Request::SyncGroup {
1634 group_id: group_id.into(),
1635 generation_id,
1636 member_id: member_id.into(),
1637 assignments,
1638 };
1639
1640 let response = self.send_request(request).await?;
1641
1642 match response {
1643 Response::SyncGroupResult { assignments } => Ok(assignments),
1644 Response::Error { message } => Err(Error::ServerError(message)),
1645 _ => Err(Error::InvalidResponse),
1646 }
1647 }
1648
1649 pub async fn heartbeat(
1655 &mut self,
1656 group_id: impl Into<String>,
1657 generation_id: u32,
1658 member_id: impl Into<String>,
1659 ) -> Result<i32> {
1660 let request = Request::Heartbeat {
1661 group_id: group_id.into(),
1662 generation_id,
1663 member_id: member_id.into(),
1664 };
1665
1666 let response = self.send_request(request).await?;
1667
1668 match response {
1669 Response::HeartbeatResult { error_code } => Ok(error_code),
1670 Response::Error { message } => Err(Error::ServerError(message)),
1671 _ => Err(Error::InvalidResponse),
1672 }
1673 }
1674
1675 pub async fn leave_group(
1680 &mut self,
1681 group_id: impl Into<String>,
1682 member_id: impl Into<String>,
1683 ) -> Result<()> {
1684 let request = Request::LeaveGroup {
1685 group_id: group_id.into(),
1686 member_id: member_id.into(),
1687 };
1688
1689 let response = self.send_request(request).await?;
1690
1691 match response {
1692 Response::LeaveGroupResult => Ok(()),
1693 Response::Error { message } => Err(Error::ServerError(message)),
1694 _ => Err(Error::InvalidResponse),
1695 }
1696 }
1697
1698 pub async fn get_offset_for_timestamp(
1709 &mut self,
1710 topic: impl Into<String>,
1711 partition: u32,
1712 timestamp_ms: i64,
1713 ) -> Result<Option<u64>> {
1714 let request = Request::GetOffsetForTimestamp {
1715 topic: topic.into(),
1716 partition,
1717 timestamp_ms,
1718 };
1719
1720 let response = self.send_request(request).await?;
1721
1722 match response {
1723 Response::OffsetForTimestamp { offset } => Ok(offset),
1724 Response::Error { message } => Err(Error::ServerError(message)),
1725 _ => Err(Error::InvalidResponse),
1726 }
1727 }
1728
1729 pub async fn describe_topic_configs(
1753 &mut self,
1754 topics: &[&str],
1755 ) -> Result<std::collections::HashMap<String, std::collections::HashMap<String, String>>> {
1756 let request = Request::DescribeTopicConfigs {
1757 topics: topics.iter().map(|s| s.to_string()).collect(),
1758 };
1759
1760 let response = self.send_request(request).await?;
1761
1762 match response {
1763 Response::TopicConfigsDescribed { configs } => {
1764 let mut result = std::collections::HashMap::new();
1765 for desc in configs {
1766 let mut topic_configs = std::collections::HashMap::new();
1767 for (key, value) in desc.configs {
1768 topic_configs.insert(key, value.value);
1769 }
1770 result.insert(desc.topic, topic_configs);
1771 }
1772 Ok(result)
1773 }
1774 Response::Error { message } => Err(Error::ServerError(message)),
1775 _ => Err(Error::InvalidResponse),
1776 }
1777 }
1778
1779 pub async fn alter_topic_config(
1802 &mut self,
1803 topic: impl Into<String>,
1804 configs: &[(&str, Option<&str>)],
1805 ) -> Result<AlterTopicConfigResult> {
1806 use rivven_protocol::TopicConfigEntry;
1807
1808 let request = Request::AlterTopicConfig {
1809 topic: topic.into(),
1810 configs: configs
1811 .iter()
1812 .map(|(k, v)| TopicConfigEntry {
1813 key: k.to_string(),
1814 value: v.map(|s| s.to_string()),
1815 })
1816 .collect(),
1817 };
1818
1819 let response = self.send_request(request).await?;
1820
1821 match response {
1822 Response::TopicConfigAltered {
1823 topic,
1824 changed_count,
1825 } => Ok(AlterTopicConfigResult {
1826 topic,
1827 changed_count,
1828 }),
1829 Response::Error { message } => Err(Error::ServerError(message)),
1830 _ => Err(Error::InvalidResponse),
1831 }
1832 }
1833
1834 pub async fn create_partitions(
1855 &mut self,
1856 topic: impl Into<String>,
1857 new_partition_count: u32,
1858 ) -> Result<u32> {
1859 let request = Request::CreatePartitions {
1860 topic: topic.into(),
1861 new_partition_count,
1862 assignments: vec![], };
1864
1865 let response = self.send_request(request).await?;
1866
1867 match response {
1868 Response::PartitionsCreated {
1869 new_partition_count,
1870 ..
1871 } => Ok(new_partition_count),
1872 Response::Error { message } => Err(Error::ServerError(message)),
1873 _ => Err(Error::InvalidResponse),
1874 }
1875 }
1876
1877 pub async fn delete_records(
1907 &mut self,
1908 topic: impl Into<String>,
1909 partition_offsets: &[(u32, u64)],
1910 ) -> Result<Vec<DeleteRecordsResult>> {
1911 let request = Request::DeleteRecords {
1912 topic: topic.into(),
1913 partition_offsets: partition_offsets.to_vec(),
1914 };
1915
1916 let response = self.send_request(request).await?;
1917
1918 match response {
1919 Response::RecordsDeleted { results, .. } => Ok(results),
1920 Response::Error { message } => Err(Error::ServerError(message)),
1921 _ => Err(Error::InvalidResponse),
1922 }
1923 }
1924
1925 pub async fn init_producer_id(
1953 &mut self,
1954 previous_producer_id: Option<u64>,
1955 ) -> Result<ProducerState> {
1956 let request = Request::InitProducerId {
1957 producer_id: previous_producer_id,
1958 };
1959
1960 let response = self.send_request(request).await?;
1961
1962 match response {
1963 Response::ProducerIdInitialized {
1964 producer_id,
1965 producer_epoch,
1966 } => Ok(ProducerState {
1967 producer_id,
1968 producer_epoch,
1969 partition_sequences: std::collections::HashMap::new(),
1970 next_sequence: 0,
1971 }),
1972 Response::Error { message } => Err(Error::ServerError(message)),
1973 _ => Err(Error::InvalidResponse),
1974 }
1975 }
1976
1977 pub async fn publish_idempotent(
2010 &mut self,
2011 topic: impl Into<String>,
2012 key: Option<impl Into<Bytes>>,
2013 value: impl Into<Bytes>,
2014 producer: &mut ProducerState,
2015 ) -> Result<(u64, u32, bool)> {
2016 let topic_str = topic.into();
2017 let sequence = producer.next_sequence;
2021 producer.next_sequence = producer.next_sequence.wrapping_add(1);
2022 if producer.next_sequence <= 0 {
2026 producer.next_sequence = 1;
2027 }
2028
2029 let request = Request::IdempotentPublish {
2030 topic: topic_str,
2031 partition: None,
2032 key: key.map(|k| k.into()),
2033 value: value.into(),
2034 producer_id: producer.producer_id,
2035 producer_epoch: producer.producer_epoch,
2036 sequence,
2037 leader_epoch: None,
2038 };
2039
2040 let response = self.send_request(request).await?;
2041
2042 match response {
2043 Response::IdempotentPublished {
2044 offset,
2045 partition,
2046 duplicate,
2047 } => Ok((offset, partition, duplicate)),
2048 Response::Error { message } => Err(Error::ServerError(message)),
2049 _ => Err(Error::InvalidResponse),
2050 }
2051 }
2052
2053 pub async fn publish_idempotent_to_partition(
2059 &mut self,
2060 topic: impl Into<String>,
2061 partition: u32,
2062 key: Option<impl Into<Bytes>>,
2063 value: impl Into<Bytes>,
2064 producer: &mut ProducerState,
2065 ) -> Result<(u64, u32, bool)> {
2066 let topic_str = topic.into();
2067 let sequence = producer.next_sequence_for(&topic_str, partition);
2068
2069 let request = Request::IdempotentPublish {
2070 topic: topic_str,
2071 partition: Some(partition),
2072 key: key.map(|k| k.into()),
2073 value: value.into(),
2074 producer_id: producer.producer_id,
2075 producer_epoch: producer.producer_epoch,
2076 sequence,
2077 leader_epoch: None,
2078 };
2079
2080 let response = self.send_request(request).await?;
2081
2082 match response {
2083 Response::IdempotentPublished {
2084 offset,
2085 partition: resp_partition,
2086 duplicate,
2087 } => Ok((offset, resp_partition, duplicate)),
2088 Response::Error { message } => Err(Error::ServerError(message)),
2089 _ => Err(Error::InvalidResponse),
2090 }
2091 }
2092
2093 pub async fn begin_transaction(
2123 &mut self,
2124 txn_id: impl Into<String>,
2125 producer: &ProducerState,
2126 timeout_ms: Option<u64>,
2127 ) -> Result<()> {
2128 let request = Request::BeginTransaction {
2129 txn_id: txn_id.into(),
2130 producer_id: producer.producer_id,
2131 producer_epoch: producer.producer_epoch,
2132 timeout_ms,
2133 };
2134
2135 let response = self.send_request(request).await?;
2136
2137 match response {
2138 Response::TransactionStarted { .. } => Ok(()),
2139 Response::Error { message } => Err(Error::ServerError(message)),
2140 _ => Err(Error::InvalidResponse),
2141 }
2142 }
2143
2144 pub async fn add_partitions_to_txn(
2154 &mut self,
2155 txn_id: impl Into<String>,
2156 producer: &ProducerState,
2157 partitions: &[(&str, u32)],
2158 ) -> Result<usize> {
2159 let request = Request::AddPartitionsToTxn {
2160 txn_id: txn_id.into(),
2161 producer_id: producer.producer_id,
2162 producer_epoch: producer.producer_epoch,
2163 partitions: partitions
2164 .iter()
2165 .map(|(t, p)| (t.to_string(), *p))
2166 .collect(),
2167 };
2168
2169 let response = self.send_request(request).await?;
2170
2171 match response {
2172 Response::PartitionsAddedToTxn {
2173 partition_count, ..
2174 } => Ok(partition_count),
2175 Response::Error { message } => Err(Error::ServerError(message)),
2176 _ => Err(Error::InvalidResponse),
2177 }
2178 }
2179
2180 pub async fn publish_transactional(
2195 &mut self,
2196 txn_id: impl Into<String>,
2197 topic: impl Into<String>,
2198 key: Option<impl Into<Bytes>>,
2199 value: impl Into<Bytes>,
2200 producer: &mut ProducerState,
2201 ) -> Result<(u64, u32, i32)> {
2202 let sequence = producer.next_sequence;
2203 producer.next_sequence = producer.next_sequence.wrapping_add(1);
2204 if producer.next_sequence <= 0 {
2206 producer.next_sequence = 1;
2207 }
2208
2209 let request = Request::TransactionalPublish {
2210 txn_id: txn_id.into(),
2211 topic: topic.into(),
2212 partition: None,
2213 key: key.map(|k| k.into()),
2214 value: value.into(),
2215 producer_id: producer.producer_id,
2216 producer_epoch: producer.producer_epoch,
2217 sequence,
2218 leader_epoch: None,
2219 };
2220
2221 let response = self.send_request(request).await?;
2222
2223 match response {
2224 Response::TransactionalPublished {
2225 offset,
2226 partition,
2227 sequence,
2228 } => Ok((offset, partition, sequence)),
2229 Response::Error { message } => Err(Error::ServerError(message)),
2230 _ => Err(Error::InvalidResponse),
2231 }
2232 }
2233
2234 pub async fn add_offsets_to_txn(
2245 &mut self,
2246 txn_id: impl Into<String>,
2247 producer: &ProducerState,
2248 group_id: impl Into<String>,
2249 offsets: &[(&str, u32, i64)],
2250 ) -> Result<()> {
2251 let request = Request::AddOffsetsToTxn {
2252 txn_id: txn_id.into(),
2253 producer_id: producer.producer_id,
2254 producer_epoch: producer.producer_epoch,
2255 group_id: group_id.into(),
2256 offsets: offsets
2257 .iter()
2258 .map(|(t, p, o)| (t.to_string(), *p, *o))
2259 .collect(),
2260 };
2261
2262 let response = self.send_request(request).await?;
2263
2264 match response {
2265 Response::OffsetsAddedToTxn { .. } => Ok(()),
2266 Response::Error { message } => Err(Error::ServerError(message)),
2267 _ => Err(Error::InvalidResponse),
2268 }
2269 }
2270
2271 pub async fn commit_transaction(
2280 &mut self,
2281 txn_id: impl Into<String>,
2282 producer: &ProducerState,
2283 ) -> Result<()> {
2284 let request = Request::CommitTransaction {
2285 txn_id: txn_id.into(),
2286 producer_id: producer.producer_id,
2287 producer_epoch: producer.producer_epoch,
2288 };
2289
2290 let response = self.send_request(request).await?;
2291
2292 match response {
2293 Response::TransactionCommitted { .. } => Ok(()),
2294 Response::Error { message } => Err(Error::ServerError(message)),
2295 _ => Err(Error::InvalidResponse),
2296 }
2297 }
2298
2299 pub async fn abort_transaction(
2308 &mut self,
2309 txn_id: impl Into<String>,
2310 producer: &ProducerState,
2311 ) -> Result<()> {
2312 let request = Request::AbortTransaction {
2313 txn_id: txn_id.into(),
2314 producer_id: producer.producer_id,
2315 producer_epoch: producer.producer_epoch,
2316 };
2317
2318 let response = self.send_request(request).await?;
2319
2320 match response {
2321 Response::TransactionAborted { .. } => Ok(()),
2322 Response::Error { message } => Err(Error::ServerError(message)),
2323 _ => Err(Error::InvalidResponse),
2324 }
2325 }
2326}
2327
/// Client-side state of an idempotent or transactional producer, as created by
/// `init_producer_id`.
#[derive(Clone, Debug)]
pub struct ProducerState {
    /// Producer ID assigned by the server.
    pub producer_id: u64,
    /// Producer epoch assigned together with `producer_id`.
    pub producer_epoch: u16,
    /// Next sequence number for each `(topic, partition)`, handed out by
    /// `next_sequence_for`. A partition seen for the first time starts at 1.
    pub partition_sequences: std::collections::HashMap<(String, u32), i32>,
    /// Producer-wide sequence counter, used when the server chooses the
    /// partition (`publish_idempotent`, `publish_transactional`).
    pub next_sequence: i32,
}
2344
2345impl ProducerState {
2346 pub fn next_sequence_for(&mut self, topic: &str, partition: u32) -> i32 {
2349 let seq = self
2350 .partition_sequences
2351 .entry((topic.to_string(), partition))
2352 .or_insert(1);
2353 let current = *seq;
2354 *seq = seq.wrapping_add(1);
2355 if *seq <= 0 {
2356 *seq = 1;
2357 }
2358 current
2359 }
2360}
2361
/// Outcome of `alter_topic_config`, as reported by the server.
#[derive(Clone, Debug)]
pub struct AlterTopicConfigResult {
    /// Topic whose configuration was altered.
    pub topic: String,
    /// Number of configuration entries the server reports as changed.
    pub changed_count: usize,
}
2370
2371pub use rivven_protocol::DeleteRecordsResult;
2373
/// An authenticated session handed out by the server.
#[derive(Clone, Debug)]
pub struct AuthSession {
    /// Opaque session identifier.
    pub session_id: String,
    /// Session lifetime. The unit is presumably seconds; confirm with the
    /// server.
    pub expires_in: u64,
}
2386
2387pub(crate) fn generate_nonce() -> String {
2393 use rand::Rng;
2394 let mut rng = rand::thread_rng();
2395 let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
2396 base64_encode(&nonce_bytes)
2397}
2398
/// Escapes a username for a SCRAM `n=` attribute (RFC 5802 §5.1).
///
/// `=` becomes `=3D` and `,` becomes `=2C`. Every other character passes
/// through unchanged.
pub(crate) fn escape_username(username: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
2403
2404pub(crate) fn parse_server_first(server_first: &str) -> Result<(String, String, u32)> {
2406 let mut nonce = None;
2407 let mut salt = None;
2408 let mut iterations = None;
2409
2410 for attr in server_first.split(',') {
2411 if let Some(value) = attr.strip_prefix("r=") {
2412 nonce = Some(value.to_string());
2413 } else if let Some(value) = attr.strip_prefix("s=") {
2414 salt = Some(value.to_string());
2415 } else if let Some(value) = attr.strip_prefix("i=") {
2416 iterations = Some(
2417 value
2418 .parse::<u32>()
2419 .map_err(|_| Error::AuthenticationFailed("Invalid iteration count".into()))?,
2420 );
2421 }
2422 }
2423
2424 let nonce = nonce.ok_or_else(|| Error::AuthenticationFailed("Missing nonce".into()))?;
2425 let salt = salt.ok_or_else(|| Error::AuthenticationFailed("Missing salt".into()))?;
2426 let iterations =
2427 iterations.ok_or_else(|| Error::AuthenticationFailed("Missing iterations".into()))?;
2428
2429 if iterations < 4096 {
2432 return Err(Error::AuthenticationFailed(format!(
2433 "SCRAM iteration count {} is below minimum 4096 (possible downgrade attack)",
2434 iterations
2435 )));
2436 }
2437
2438 Ok((nonce, salt, iterations))
2439}
2440
2441pub(crate) fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
2443 let mut result = vec![0u8; 32];
2444
2445 let mut u = PasswordHash::hmac_sha256(password, &[salt, &1u32.to_be_bytes()].concat());
2447 result.copy_from_slice(&u);
2448
2449 for _ in 1..iterations {
2451 u = PasswordHash::hmac_sha256(password, &u);
2452 for (r, ui) in result.iter_mut().zip(u.iter()) {
2453 *r ^= ui;
2454 }
2455 }
2456
2457 result
2458}
2459
2460pub(crate) fn sha256(data: &[u8]) -> Vec<u8> {
2462 let mut hasher = Sha256::new();
2463 hasher.update(data);
2464 hasher.finalize().to_vec()
2465}
2466
/// XORs two byte slices element-wise.
///
/// The output has the length of the shorter input, and any extra bytes in the
/// longer slice are ignored.
pub(crate) fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    std::iter::zip(a, b).map(|(&left, &right)| left ^ right).collect()
}
2471
2472pub(crate) fn base64_encode(data: &[u8]) -> String {
2474 use base64::{engine::general_purpose::STANDARD, Engine};
2475 STANDARD.encode(data)
2476}
2477
2478pub(crate) fn base64_decode(data: &str) -> std::result::Result<Vec<u8>, base64::DecodeError> {
2480 use base64::{engine::general_purpose::STANDARD, Engine};
2481 STANDARD.decode(data)
2482}
2483
#[cfg(test)]
mod tests {
    //! Unit tests for the SCRAM helpers. Cryptographic primitives are checked
    //! against published known-answer vectors, not just for determinism.
    use super::*;

    #[test]
    fn test_escape_username() {
        assert_eq!(escape_username(""), "");
        assert_eq!(escape_username("alice"), "alice");
        assert_eq!(escape_username("user=name"), "user=3Dname");
        assert_eq!(escape_username("user,name"), "user=2Cname");
        assert_eq!(escape_username("user=,name"), "user=3D=2Cname");
    }

    #[test]
    fn test_parse_server_first() {
        let server_first = "r=clientnonce+servernonce,s=c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "clientnonce+servernonce");
        assert_eq!(salt, "c2FsdA==");
        assert_eq!(iterations, 4096);
    }

    #[test]
    fn test_parse_server_first_missing_nonce() {
        let server_first = "s=c2FsdA==,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_salt() {
        let server_first = "r=nonce,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_iterations() {
        let server_first = "r=nonce,s=c2FsdA==";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_rejects_low_iterations() {
        // One below the enforced minimum must be rejected (downgrade guard).
        let server_first = "r=nonce,s=c2FsdA==,i=4095";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_invalid_iterations() {
        assert!(parse_server_first("r=nonce,s=c2FsdA==,i=many").is_err());
        assert!(parse_server_first("r=nonce,s=c2FsdA==,i=-1").is_err());
    }

    #[test]
    fn test_xor_bytes() {
        assert_eq!(xor_bytes(&[0xFF, 0x00], &[0xFF, 0xFF]), vec![0x00, 0xFF]);
        assert_eq!(xor_bytes(&[0x12, 0x34], &[0x12, 0x34]), vec![0x00, 0x00]);
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"hello world";
        let encoded = base64_encode(data);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_sha256() {
        // SHA-256 of the empty string.
        let hash = sha256(b"");
        assert_eq!(hash.len(), 32);
        assert_eq!(
            hex::encode(&hash),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
        // FIPS 180-2 example: SHA-256("abc").
        assert_eq!(
            hex::encode(sha256(b"abc")),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_pbkdf2_sha256() {
        let password = b"password";
        let salt = b"salt";

        let result = pbkdf2_sha256(password, salt, 1);
        assert_eq!(result.len(), 32);
        assert_eq!(result, pbkdf2_sha256(password, salt, 1));

        // Published PBKDF2-HMAC-SHA256 vectors (P = "password", S = "salt",
        // dkLen = 32). These catch a wrong HMAC key/message order or XOR chain,
        // which a determinism-only check would miss.
        assert_eq!(
            hex::encode(&result),
            "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b"
        );
        assert_eq!(
            hex::encode(pbkdf2_sha256(password, salt, 2)),
            "ae4d0c95af6b46d32d0adff928f06dd02a303f8ef3c251dfd6e2d85a95474c43"
        );
        assert_eq!(
            hex::encode(pbkdf2_sha256(password, salt, 4096)),
            "c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a"
        );
    }

    #[test]
    fn test_generate_nonce() {
        let nonce1 = generate_nonce();
        let nonce2 = generate_nonce();

        assert!(!nonce1.is_empty());
        assert!(!nonce2.is_empty());

        // Two fresh nonces colliding would indicate a broken RNG.
        assert_ne!(nonce1, nonce2);

        // Nonces are base64 of 24 random bytes.
        assert_eq!(base64_decode(&nonce1).unwrap().len(), 24);
    }
}