1use crate::{Error, MessageData, Request, Response, Result};
2use bytes::Bytes;
3use rivven_core::PasswordHash;
4use sha2::{Digest, Sha256};
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use tokio::net::TcpStream;
7use tracing::{debug, info};
8
9#[cfg(feature = "tls")]
10use std::net::SocketAddr;
11
12#[cfg(feature = "tls")]
13use rivven_core::tls::{TlsClientStream, TlsConfig, TlsConnector};
14
/// Largest response payload, in bytes, that the client will allocate a buffer
/// for (100 MiB). Frames announcing a larger length are rejected before reading.
const DEFAULT_MAX_RESPONSE_SIZE: usize = 100 * 1024 * 1024;
17
18#[allow(clippy::large_enum_variant)]
26enum ClientStream {
27 Plaintext(TcpStream),
28 #[cfg(feature = "tls")]
29 Tls(TlsClientStream<TcpStream>),
30}
31
32impl AsyncRead for ClientStream {
33 fn poll_read(
34 self: std::pin::Pin<&mut Self>,
35 cx: &mut std::task::Context<'_>,
36 buf: &mut tokio::io::ReadBuf<'_>,
37 ) -> std::task::Poll<std::io::Result<()>> {
38 match self.get_mut() {
39 ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_read(cx, buf),
40 #[cfg(feature = "tls")]
41 ClientStream::Tls(s) => std::pin::Pin::new(s).poll_read(cx, buf),
42 }
43 }
44}
45
46impl AsyncWrite for ClientStream {
47 fn poll_write(
48 self: std::pin::Pin<&mut Self>,
49 cx: &mut std::task::Context<'_>,
50 buf: &[u8],
51 ) -> std::task::Poll<std::io::Result<usize>> {
52 match self.get_mut() {
53 ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_write(cx, buf),
54 #[cfg(feature = "tls")]
55 ClientStream::Tls(s) => std::pin::Pin::new(s).poll_write(cx, buf),
56 }
57 }
58
59 fn poll_flush(
60 self: std::pin::Pin<&mut Self>,
61 cx: &mut std::task::Context<'_>,
62 ) -> std::task::Poll<std::io::Result<()>> {
63 match self.get_mut() {
64 ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_flush(cx),
65 #[cfg(feature = "tls")]
66 ClientStream::Tls(s) => std::pin::Pin::new(s).poll_flush(cx),
67 }
68 }
69
70 fn poll_shutdown(
71 self: std::pin::Pin<&mut Self>,
72 cx: &mut std::task::Context<'_>,
73 ) -> std::task::Poll<std::io::Result<()>> {
74 match self.get_mut() {
75 ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_shutdown(cx),
76 #[cfg(feature = "tls")]
77 ClientStream::Tls(s) => std::pin::Pin::new(s).poll_shutdown(cx),
78 }
79 }
80}
81
82pub struct Client {
88 stream: ClientStream,
89}
90
91impl Client {
92 pub async fn connect(addr: &str) -> Result<Self> {
94 info!("Connecting to Rivven server at {}", addr);
95 let stream = TcpStream::connect(addr)
96 .await
97 .map_err(|e| Error::ConnectionError(e.to_string()))?;
98
99 Ok(Self {
100 stream: ClientStream::Plaintext(stream),
101 })
102 }
103
104 #[cfg(feature = "tls")]
106 pub async fn connect_tls(
107 addr: &str,
108 tls_config: &TlsConfig,
109 server_name: &str,
110 ) -> Result<Self> {
111 info!("Connecting to Rivven server at {} with TLS", addr);
112
113 let socket_addr: SocketAddr = addr
115 .parse()
116 .map_err(|e| Error::ConnectionError(format!("Invalid address: {}", e)))?;
117
118 let connector = TlsConnector::new(tls_config)
120 .map_err(|e| Error::ConnectionError(format!("TLS config error: {}", e)))?;
121
122 let tls_stream = connector
124 .connect_tcp(socket_addr, server_name)
125 .await
126 .map_err(|e| Error::ConnectionError(format!("TLS connection error: {}", e)))?;
127
128 info!("TLS connection established to {} ({})", addr, server_name);
129
130 Ok(Self {
131 stream: ClientStream::Tls(tls_stream),
132 })
133 }
134
135 #[cfg(feature = "tls")]
137 pub async fn connect_mtls(
138 addr: &str,
139 cert_path: impl Into<std::path::PathBuf>,
140 key_path: impl Into<std::path::PathBuf>,
141 ca_path: impl Into<std::path::PathBuf> + Clone,
142 server_name: &str,
143 ) -> Result<Self> {
144 let tls_config = TlsConfig::mtls_from_pem_files(cert_path, key_path, ca_path);
145 Self::connect_tls(addr, &tls_config, server_name).await
146 }
147
148 pub async fn authenticate(&mut self, username: &str, password: &str) -> Result<AuthSession> {
157 let request = Request::Authenticate {
158 username: username.to_string(),
159 password: password.to_string(),
160 };
161
162 let response = self.send_request(request).await?;
163
164 match response {
165 Response::Authenticated {
166 session_id,
167 expires_in,
168 } => {
169 info!("Authenticated as '{}'", username);
170 Ok(AuthSession {
171 session_id,
172 expires_in,
173 })
174 }
175 Response::Error { message } => Err(Error::AuthenticationFailed(message)),
176 _ => Err(Error::InvalidResponse),
177 }
178 }
179
180 pub async fn authenticate_scram(
198 &mut self,
199 username: &str,
200 password: &str,
201 ) -> Result<AuthSession> {
202 let client_nonce = generate_nonce();
204 let client_first_bare = format!("n={},r={}", escape_username(username), client_nonce);
205 let client_first = format!("n,,{}", client_first_bare);
206
207 debug!("SCRAM: Sending client-first");
208 let request = Request::ScramClientFirst {
209 message: Bytes::from(client_first.clone()),
210 };
211
212 let response = self.send_request(request).await?;
213
214 let server_first = match response {
216 Response::ScramServerFirst { message } => String::from_utf8(message.to_vec())
217 .map_err(|_| Error::AuthenticationFailed("Invalid server-first encoding".into()))?,
218 Response::Error { message } => return Err(Error::AuthenticationFailed(message)),
219 _ => return Err(Error::InvalidResponse),
220 };
221
222 debug!("SCRAM: Received server-first");
223
224 let (combined_nonce, salt_b64, iterations) = parse_server_first(&server_first)?;
226
227 if !combined_nonce.starts_with(&client_nonce) {
229 return Err(Error::AuthenticationFailed("Server nonce mismatch".into()));
230 }
231
232 let salt = base64_decode(&salt_b64)
234 .map_err(|_| Error::AuthenticationFailed("Invalid salt encoding".into()))?;
235
236 let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
238 let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
239 let stored_key = sha256(&client_key);
240
241 let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
242 let auth_message = format!(
243 "{},{},{}",
244 client_first_bare, server_first, client_final_without_proof
245 );
246
247 let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
248 let client_proof = xor_bytes(&client_key, &client_signature);
249 let client_proof_b64 = base64_encode(&client_proof);
250
251 let client_final = format!("{},p={}", client_final_without_proof, client_proof_b64);
253
254 debug!("SCRAM: Sending client-final");
255 let request = Request::ScramClientFinal {
256 message: Bytes::from(client_final),
257 };
258
259 let response = self.send_request(request).await?;
260
261 match response {
263 Response::ScramServerFinal {
264 message,
265 session_id,
266 expires_in,
267 } => {
268 let server_final = String::from_utf8(message.to_vec()).map_err(|_| {
269 Error::AuthenticationFailed("Invalid server-final encoding".into())
270 })?;
271
272 if let Some(error_msg) = server_final.strip_prefix("e=") {
274 return Err(Error::AuthenticationFailed(error_msg.to_string()));
275 }
276
277 if let Some(verifier_b64) = server_final.strip_prefix("v=") {
279 let server_key = PasswordHash::hmac_sha256(&salted_password, b"Server Key");
280 let expected_server_sig =
281 PasswordHash::hmac_sha256(&server_key, auth_message.as_bytes());
282 let expected_verifier = base64_encode(&expected_server_sig);
283
284 if verifier_b64 != expected_verifier {
285 return Err(Error::AuthenticationFailed(
286 "Server verification failed".into(),
287 ));
288 }
289 }
290
291 let session_id = session_id.ok_or_else(|| {
292 Error::AuthenticationFailed("No session ID in response".into())
293 })?;
294 let expires_in = expires_in
295 .ok_or_else(|| Error::AuthenticationFailed("No expiry in response".into()))?;
296
297 info!("SCRAM authentication successful for '{}'", username);
298 Ok(AuthSession {
299 session_id,
300 expires_in,
301 })
302 }
303 Response::Error { message } => Err(Error::AuthenticationFailed(message)),
304 _ => Err(Error::InvalidResponse),
305 }
306 }
307
308 async fn send_request(&mut self, request: Request) -> Result<Response> {
314 let request_bytes = request.to_bytes()?;
316
317 let len = request_bytes.len() as u32;
319 self.stream.write_all(&len.to_be_bytes()).await?;
320 self.stream.write_all(&request_bytes).await?;
321 self.stream.flush().await?;
322
323 let mut len_buf = [0u8; 4];
325 self.stream.read_exact(&mut len_buf).await?;
326 let msg_len = u32::from_be_bytes(len_buf) as usize;
327
328 if msg_len > DEFAULT_MAX_RESPONSE_SIZE {
330 return Err(Error::ResponseTooLarge(msg_len, DEFAULT_MAX_RESPONSE_SIZE));
331 }
332
333 let mut response_buf = vec![0u8; msg_len];
335 self.stream.read_exact(&mut response_buf).await?;
336
337 let response = Response::from_bytes(&response_buf)?;
339
340 Ok(response)
341 }
342
343 pub async fn publish(
345 &mut self,
346 topic: impl Into<String>,
347 value: impl Into<Bytes>,
348 ) -> Result<u64> {
349 self.publish_with_key(topic, None::<Bytes>, value).await
350 }
351
352 pub async fn publish_with_key(
354 &mut self,
355 topic: impl Into<String>,
356 key: Option<impl Into<Bytes>>,
357 value: impl Into<Bytes>,
358 ) -> Result<u64> {
359 let request = Request::Publish {
360 topic: topic.into(),
361 partition: None,
362 key: key.map(|k| k.into()),
363 value: value.into(),
364 };
365
366 let response = self.send_request(request).await?;
367
368 match response {
369 Response::Published { offset, .. } => Ok(offset),
370 Response::Error { message } => Err(Error::ServerError(message)),
371 _ => Err(Error::InvalidResponse),
372 }
373 }
374
375 pub async fn publish_to_partition(
377 &mut self,
378 topic: impl Into<String>,
379 partition: u32,
380 key: Option<impl Into<Bytes>>,
381 value: impl Into<Bytes>,
382 ) -> Result<u64> {
383 let request = Request::Publish {
384 topic: topic.into(),
385 partition: Some(partition),
386 key: key.map(|k| k.into()),
387 value: value.into(),
388 };
389
390 let response = self.send_request(request).await?;
391
392 match response {
393 Response::Published { offset, .. } => Ok(offset),
394 Response::Error { message } => Err(Error::ServerError(message)),
395 _ => Err(Error::InvalidResponse),
396 }
397 }
398
399 pub async fn consume(
401 &mut self,
402 topic: impl Into<String>,
403 partition: u32,
404 offset: u64,
405 max_messages: usize,
406 ) -> Result<Vec<MessageData>> {
407 let request = Request::Consume {
408 topic: topic.into(),
409 partition,
410 offset,
411 max_messages,
412 };
413
414 let response = self.send_request(request).await?;
415
416 match response {
417 Response::Messages { messages } => Ok(messages),
418 Response::Error { message } => Err(Error::ServerError(message)),
419 _ => Err(Error::InvalidResponse),
420 }
421 }
422
423 pub async fn create_topic(
425 &mut self,
426 name: impl Into<String>,
427 partitions: Option<u32>,
428 ) -> Result<u32> {
429 let name = name.into();
430 let request = Request::CreateTopic {
431 name: name.clone(),
432 partitions,
433 };
434
435 let response = self.send_request(request).await?;
436
437 match response {
438 Response::TopicCreated { partitions, .. } => Ok(partitions),
439 Response::Error { message } => Err(Error::ServerError(message)),
440 _ => Err(Error::InvalidResponse),
441 }
442 }
443
444 pub async fn list_topics(&mut self) -> Result<Vec<String>> {
446 let request = Request::ListTopics;
447 let response = self.send_request(request).await?;
448
449 match response {
450 Response::Topics { topics } => Ok(topics),
451 Response::Error { message } => Err(Error::ServerError(message)),
452 _ => Err(Error::InvalidResponse),
453 }
454 }
455
456 pub async fn delete_topic(&mut self, name: impl Into<String>) -> Result<()> {
458 let request = Request::DeleteTopic { name: name.into() };
459 let response = self.send_request(request).await?;
460
461 match response {
462 Response::TopicDeleted => Ok(()),
463 Response::Error { message } => Err(Error::ServerError(message)),
464 _ => Err(Error::InvalidResponse),
465 }
466 }
467
468 pub async fn commit_offset(
470 &mut self,
471 consumer_group: impl Into<String>,
472 topic: impl Into<String>,
473 partition: u32,
474 offset: u64,
475 ) -> Result<()> {
476 let request = Request::CommitOffset {
477 consumer_group: consumer_group.into(),
478 topic: topic.into(),
479 partition,
480 offset,
481 };
482
483 let response = self.send_request(request).await?;
484
485 match response {
486 Response::OffsetCommitted => Ok(()),
487 Response::Error { message } => Err(Error::ServerError(message)),
488 _ => Err(Error::InvalidResponse),
489 }
490 }
491
492 pub async fn get_offset(
494 &mut self,
495 consumer_group: impl Into<String>,
496 topic: impl Into<String>,
497 partition: u32,
498 ) -> Result<Option<u64>> {
499 let request = Request::GetOffset {
500 consumer_group: consumer_group.into(),
501 topic: topic.into(),
502 partition,
503 };
504
505 let response = self.send_request(request).await?;
506
507 match response {
508 Response::Offset { offset } => Ok(offset),
509 Response::Error { message } => Err(Error::ServerError(message)),
510 _ => Err(Error::InvalidResponse),
511 }
512 }
513
514 pub async fn get_offset_bounds(
520 &mut self,
521 topic: impl Into<String>,
522 partition: u32,
523 ) -> Result<(u64, u64)> {
524 let request = Request::GetOffsetBounds {
525 topic: topic.into(),
526 partition,
527 };
528
529 let response = self.send_request(request).await?;
530
531 match response {
532 Response::OffsetBounds { earliest, latest } => Ok((earliest, latest)),
533 Response::Error { message } => Err(Error::ServerError(message)),
534 _ => Err(Error::InvalidResponse),
535 }
536 }
537
538 pub async fn get_metadata(&mut self, topic: impl Into<String>) -> Result<(String, u32)> {
540 let request = Request::GetMetadata {
541 topic: topic.into(),
542 };
543
544 let response = self.send_request(request).await?;
545
546 match response {
547 Response::Metadata { name, partitions } => Ok((name, partitions)),
548 Response::Error { message } => Err(Error::ServerError(message)),
549 _ => Err(Error::InvalidResponse),
550 }
551 }
552
553 pub async fn ping(&mut self) -> Result<()> {
555 let request = Request::Ping;
556 let response = self.send_request(request).await?;
557
558 match response {
559 Response::Pong => Ok(()),
560 Response::Error { message } => Err(Error::ServerError(message)),
561 _ => Err(Error::InvalidResponse),
562 }
563 }
564
565 pub async fn register_schema(
567 &mut self,
568 subject: impl Into<String>,
569 schema: impl Into<String>,
570 ) -> Result<i32> {
571 let request = Request::RegisterSchema {
572 subject: subject.into(),
573 schema: schema.into(),
574 };
575
576 let response = self.send_request(request).await?;
577
578 match response {
579 Response::SchemaRegistered { id } => Ok(id),
580 Response::Error { message } => Err(Error::ServerError(message)),
581 _ => Err(Error::InvalidResponse),
582 }
583 }
584
585 pub async fn get_schema(&mut self, id: i32) -> Result<String> {
587 let request = Request::GetSchema { id };
588
589 let response = self.send_request(request).await?;
590
591 match response {
592 Response::Schema { id: _, schema } => Ok(schema),
593 Response::Error { message } => Err(Error::ServerError(message)),
594 _ => Err(Error::InvalidResponse),
595 }
596 }
597
598 pub async fn list_groups(&mut self) -> Result<Vec<String>> {
600 let request = Request::ListGroups;
601
602 let response = self.send_request(request).await?;
603
604 match response {
605 Response::Groups { groups } => Ok(groups),
606 Response::Error { message } => Err(Error::ServerError(message)),
607 _ => Err(Error::InvalidResponse),
608 }
609 }
610
611 pub async fn describe_group(
613 &mut self,
614 consumer_group: impl Into<String>,
615 ) -> Result<std::collections::HashMap<String, std::collections::HashMap<u32, u64>>> {
616 let request = Request::DescribeGroup {
617 consumer_group: consumer_group.into(),
618 };
619
620 let response = self.send_request(request).await?;
621
622 match response {
623 Response::GroupDescription { offsets, .. } => Ok(offsets),
624 Response::Error { message } => Err(Error::ServerError(message)),
625 _ => Err(Error::InvalidResponse),
626 }
627 }
628
629 pub async fn delete_group(&mut self, consumer_group: impl Into<String>) -> Result<()> {
631 let request = Request::DeleteGroup {
632 consumer_group: consumer_group.into(),
633 };
634
635 let response = self.send_request(request).await?;
636
637 match response {
638 Response::GroupDeleted => Ok(()),
639 Response::Error { message } => Err(Error::ServerError(message)),
640 _ => Err(Error::InvalidResponse),
641 }
642 }
643
644 pub async fn get_offset_for_timestamp(
655 &mut self,
656 topic: impl Into<String>,
657 partition: u32,
658 timestamp_ms: i64,
659 ) -> Result<Option<u64>> {
660 let request = Request::GetOffsetForTimestamp {
661 topic: topic.into(),
662 partition,
663 timestamp_ms,
664 };
665
666 let response = self.send_request(request).await?;
667
668 match response {
669 Response::OffsetForTimestamp { offset } => Ok(offset),
670 Response::Error { message } => Err(Error::ServerError(message)),
671 _ => Err(Error::InvalidResponse),
672 }
673 }
674}
675
/// Session details returned by a successful authentication.
#[derive(Debug, Clone)]
pub struct AuthSession {
    /// Session identifier issued by the server.
    pub session_id: String,
    /// Session lifetime as reported by the server.
    /// NOTE(review): unit is presumably seconds — confirm against the server.
    pub expires_in: u64,
}
688
689fn generate_nonce() -> String {
695 use rand::Rng;
696 let mut rng = rand::thread_rng();
697 let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
698 base64_encode(&nonce_bytes)
699}
700
/// Escapes a username for use in a SCRAM message: `=` becomes `=3D` and `,`
/// becomes `=2C`; every other character is copied unchanged.
fn escape_username(username: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
705
706fn parse_server_first(server_first: &str) -> Result<(String, String, u32)> {
708 let mut nonce = None;
709 let mut salt = None;
710 let mut iterations = None;
711
712 for attr in server_first.split(',') {
713 if let Some(value) = attr.strip_prefix("r=") {
714 nonce = Some(value.to_string());
715 } else if let Some(value) = attr.strip_prefix("s=") {
716 salt = Some(value.to_string());
717 } else if let Some(value) = attr.strip_prefix("i=") {
718 iterations = Some(
719 value
720 .parse::<u32>()
721 .map_err(|_| Error::AuthenticationFailed("Invalid iteration count".into()))?,
722 );
723 }
724 }
725
726 let nonce = nonce.ok_or_else(|| Error::AuthenticationFailed("Missing nonce".into()))?;
727 let salt = salt.ok_or_else(|| Error::AuthenticationFailed("Missing salt".into()))?;
728 let iterations =
729 iterations.ok_or_else(|| Error::AuthenticationFailed("Missing iterations".into()))?;
730
731 Ok((nonce, salt, iterations))
732}
733
734fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
736 let mut result = vec![0u8; 32];
737
738 let mut u = PasswordHash::hmac_sha256(password, &[salt, &1u32.to_be_bytes()].concat());
740 result.copy_from_slice(&u);
741
742 for _ in 1..iterations {
744 u = PasswordHash::hmac_sha256(password, &u);
745 for (r, ui) in result.iter_mut().zip(u.iter()) {
746 *r ^= ui;
747 }
748 }
749
750 result
751}
752
753fn sha256(data: &[u8]) -> Vec<u8> {
755 let mut hasher = Sha256::new();
756 hasher.update(data);
757 hasher.finalize().to_vec()
758}
759
/// XORs `a` and `b` byte by byte. The result is as long as the shorter input;
/// surplus bytes of the longer one are ignored.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    let overlap = a.len().min(b.len());
    let mut mixed = Vec::with_capacity(overlap);
    for idx in 0..overlap {
        mixed.push(a[idx] ^ b[idx]);
    }
    mixed
}
764
765fn base64_encode(data: &[u8]) -> String {
767 use base64::{engine::general_purpose::STANDARD, Engine};
768 STANDARD.encode(data)
769}
770
771fn base64_decode(data: &str) -> std::result::Result<Vec<u8>, base64::DecodeError> {
773 use base64::{engine::general_purpose::STANDARD, Engine};
774 STANDARD.decode(data)
775}
776
/// Unit tests for the SCRAM helper functions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escape_username() {
        assert_eq!(escape_username("alice"), "alice");
        assert_eq!(escape_username("user=name"), "user=3Dname");
        assert_eq!(escape_username("user,name"), "user=2Cname");
        assert_eq!(escape_username("user=,name"), "user=3D=2Cname");
    }

    #[test]
    fn test_parse_server_first() {
        let server_first = "r=clientnonce+servernonce,s=c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "clientnonce+servernonce");
        assert_eq!(salt, "c2FsdA==");
        assert_eq!(iterations, 4096);
    }

    #[test]
    fn test_parse_server_first_missing_nonce() {
        let server_first = "s=c2FsdA==,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_salt() {
        let server_first = "r=nonce,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_iterations() {
        let server_first = "r=nonce,s=c2FsdA==";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_invalid_iterations() {
        // A non-numeric iteration count must be rejected, not ignored.
        let server_first = "r=nonce,s=c2FsdA==,i=abc";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_xor_bytes() {
        assert_eq!(xor_bytes(&[0xFF, 0x00], &[0xFF, 0xFF]), vec![0x00, 0xFF]);
        assert_eq!(xor_bytes(&[0x12, 0x34], &[0x12, 0x34]), vec![0x00, 0x00]);
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"hello world";
        let encoded = base64_encode(data);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_sha256() {
        // SHA-256 of the empty input.
        let hash = sha256(b"");
        assert_eq!(hash.len(), 32);
        assert_eq!(
            hex::encode(&hash),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn test_pbkdf2_sha256() {
        // Published PBKDF2-HMAC-SHA256 vectors for password "password",
        // salt "salt", 32-byte output.
        let vectors: [(u32, &str); 3] = [
            (
                1,
                "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b",
            ),
            (
                2,
                "ae4d0c95af6b46d32d0adff928f06dd02a303f8ef3c251dfd6e2d85a95474c43",
            ),
            (
                4096,
                "c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a",
            ),
        ];

        for &(iterations, expected_hex) in &vectors {
            let derived = pbkdf2_sha256(b"password", b"salt", iterations);
            assert_eq!(derived.len(), 32);
            assert_eq!(hex::encode(&derived), expected_hex);
        }

        // The derivation is deterministic.
        assert_eq!(
            pbkdf2_sha256(b"password", b"salt", 1),
            pbkdf2_sha256(b"password", b"salt", 1)
        );
    }

    #[test]
    fn test_generate_nonce() {
        let nonce1 = generate_nonce();
        let nonce2 = generate_nonce();

        assert!(!nonce1.is_empty());
        assert!(!nonce2.is_empty());

        // Two nonces drawn from the RNG should never collide in practice.
        assert_ne!(nonce1, nonce2);

        // Each nonce is valid base64 wrapping 24 random bytes.
        assert_eq!(base64_decode(&nonce1).unwrap().len(), 24);
        assert_eq!(base64_decode(&nonce2).unwrap().len(), 24);
    }
}