Skip to main content

rivven_client/
client.rs

1use crate::{Error, MessageData, Request, Response, Result};
2use bytes::Bytes;
3use rivven_core::PasswordHash;
4use sha2::{Digest, Sha256};
5use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
6use tokio::net::TcpStream;
7use tracing::{debug, info};
8
9#[cfg(feature = "tls")]
10use std::net::SocketAddr;
11
12#[cfg(feature = "tls")]
13use rivven_core::tls::{TlsClientStream, TlsConfig, TlsConnector};
14
// Upper bound on a single response frame (100 MiB). A server-supplied length
// prefix above this is rejected before any buffer is allocated, so a hostile or
// buggy server cannot make the client reserve arbitrary amounts of memory.
const DEFAULT_MAX_RESPONSE_SIZE: usize = 100 << 20;
17
18// ============================================================================
19// Stream Wrapper
20// ============================================================================
21
22/// Wrapper for either plaintext or TLS streams
23/// Note: TLS variant is significantly larger due to TLS state, but boxing
24/// would add indirection overhead for every I/O operation
25#[allow(clippy::large_enum_variant)]
26enum ClientStream {
27    Plaintext(TcpStream),
28    #[cfg(feature = "tls")]
29    Tls(TlsClientStream<TcpStream>),
30}
31
32impl AsyncRead for ClientStream {
33    fn poll_read(
34        self: std::pin::Pin<&mut Self>,
35        cx: &mut std::task::Context<'_>,
36        buf: &mut tokio::io::ReadBuf<'_>,
37    ) -> std::task::Poll<std::io::Result<()>> {
38        match self.get_mut() {
39            ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_read(cx, buf),
40            #[cfg(feature = "tls")]
41            ClientStream::Tls(s) => std::pin::Pin::new(s).poll_read(cx, buf),
42        }
43    }
44}
45
46impl AsyncWrite for ClientStream {
47    fn poll_write(
48        self: std::pin::Pin<&mut Self>,
49        cx: &mut std::task::Context<'_>,
50        buf: &[u8],
51    ) -> std::task::Poll<std::io::Result<usize>> {
52        match self.get_mut() {
53            ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_write(cx, buf),
54            #[cfg(feature = "tls")]
55            ClientStream::Tls(s) => std::pin::Pin::new(s).poll_write(cx, buf),
56        }
57    }
58
59    fn poll_flush(
60        self: std::pin::Pin<&mut Self>,
61        cx: &mut std::task::Context<'_>,
62    ) -> std::task::Poll<std::io::Result<()>> {
63        match self.get_mut() {
64            ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_flush(cx),
65            #[cfg(feature = "tls")]
66            ClientStream::Tls(s) => std::pin::Pin::new(s).poll_flush(cx),
67        }
68    }
69
70    fn poll_shutdown(
71        self: std::pin::Pin<&mut Self>,
72        cx: &mut std::task::Context<'_>,
73    ) -> std::task::Poll<std::io::Result<()>> {
74        match self.get_mut() {
75            ClientStream::Plaintext(s) => std::pin::Pin::new(s).poll_shutdown(cx),
76            #[cfg(feature = "tls")]
77            ClientStream::Tls(s) => std::pin::Pin::new(s).poll_shutdown(cx),
78        }
79    }
80}
81
82// ============================================================================
83// Client
84// ============================================================================
85
86/// Rivven client for connecting to a Rivven server
87pub struct Client {
88    stream: ClientStream,
89}
90
91impl Client {
92    /// Connect to a Rivven server (plaintext)
93    pub async fn connect(addr: &str) -> Result<Self> {
94        info!("Connecting to Rivven server at {}", addr);
95        let stream = TcpStream::connect(addr)
96            .await
97            .map_err(|e| Error::ConnectionError(e.to_string()))?;
98
99        Ok(Self {
100            stream: ClientStream::Plaintext(stream),
101        })
102    }
103
104    /// Connect to a Rivven server with TLS
105    #[cfg(feature = "tls")]
106    pub async fn connect_tls(
107        addr: &str,
108        tls_config: &TlsConfig,
109        server_name: &str,
110    ) -> Result<Self> {
111        info!("Connecting to Rivven server at {} with TLS", addr);
112
113        // Parse address
114        let socket_addr: SocketAddr = addr
115            .parse()
116            .map_err(|e| Error::ConnectionError(format!("Invalid address: {}", e)))?;
117
118        // Create TLS connector
119        let connector = TlsConnector::new(tls_config)
120            .map_err(|e| Error::ConnectionError(format!("TLS config error: {}", e)))?;
121
122        // Connect with TLS
123        let tls_stream = connector
124            .connect_tcp(socket_addr, server_name)
125            .await
126            .map_err(|e| Error::ConnectionError(format!("TLS connection error: {}", e)))?;
127
128        info!("TLS connection established to {} ({})", addr, server_name);
129
130        Ok(Self {
131            stream: ClientStream::Tls(tls_stream),
132        })
133    }
134
135    /// Connect with mTLS (mutual TLS) using client certificate
136    #[cfg(feature = "tls")]
137    pub async fn connect_mtls(
138        addr: &str,
139        cert_path: impl Into<std::path::PathBuf>,
140        key_path: impl Into<std::path::PathBuf>,
141        ca_path: impl Into<std::path::PathBuf> + Clone,
142        server_name: &str,
143    ) -> Result<Self> {
144        let tls_config = TlsConfig::mtls_from_pem_files(cert_path, key_path, ca_path);
145        Self::connect_tls(addr, &tls_config, server_name).await
146    }
147
148    // ========================================================================
149    // Authentication Methods
150    // ========================================================================
151
152    /// Authenticate with simple username/password
153    ///
154    /// This uses a simple plaintext password protocol. For production use over
155    /// untrusted networks, prefer `authenticate_scram()` or use TLS.
156    pub async fn authenticate(&mut self, username: &str, password: &str) -> Result<AuthSession> {
157        let request = Request::Authenticate {
158            username: username.to_string(),
159            password: password.to_string(),
160        };
161
162        let response = self.send_request(request).await?;
163
164        match response {
165            Response::Authenticated {
166                session_id,
167                expires_in,
168            } => {
169                info!("Authenticated as '{}'", username);
170                Ok(AuthSession {
171                    session_id,
172                    expires_in,
173                })
174            }
175            Response::Error { message } => Err(Error::AuthenticationFailed(message)),
176            _ => Err(Error::InvalidResponse),
177        }
178    }
179
180    /// Authenticate using SCRAM-SHA-256 (secure challenge-response)
181    ///
182    /// SCRAM-SHA-256 (RFC 5802/7677) provides:
183    /// - Password never sent over the wire
184    /// - Mutual authentication (server proves it knows password too)
185    /// - Protection against replay attacks
186    ///
187    /// # Example
188    /// ```no_run
189    /// # use rivven_client::Client;
190    /// # async fn example() -> rivven_client::Result<()> {
191    /// let mut client = Client::connect("127.0.0.1:9092").await?;
192    /// let session = client.authenticate_scram("alice", "password123").await?;
193    /// println!("Session: {} (expires in {}s)", session.session_id, session.expires_in);
194    /// # Ok(())
195    /// # }
196    /// ```
197    pub async fn authenticate_scram(
198        &mut self,
199        username: &str,
200        password: &str,
201    ) -> Result<AuthSession> {
202        // Step 1: Generate client nonce and send client-first message
203        let client_nonce = generate_nonce();
204        let client_first_bare = format!("n={},r={}", escape_username(username), client_nonce);
205        let client_first = format!("n,,{}", client_first_bare);
206
207        debug!("SCRAM: Sending client-first");
208        let request = Request::ScramClientFirst {
209            message: Bytes::from(client_first.clone()),
210        };
211
212        let response = self.send_request(request).await?;
213
214        // Step 2: Parse server-first message
215        let server_first = match response {
216            Response::ScramServerFirst { message } => String::from_utf8(message.to_vec())
217                .map_err(|_| Error::AuthenticationFailed("Invalid server-first encoding".into()))?,
218            Response::Error { message } => return Err(Error::AuthenticationFailed(message)),
219            _ => return Err(Error::InvalidResponse),
220        };
221
222        debug!("SCRAM: Received server-first");
223
224        // Parse server-first: r=<nonce>,s=<salt>,i=<iterations>
225        let (combined_nonce, salt_b64, iterations) = parse_server_first(&server_first)?;
226
227        // Verify server nonce starts with our client nonce
228        if !combined_nonce.starts_with(&client_nonce) {
229            return Err(Error::AuthenticationFailed("Server nonce mismatch".into()));
230        }
231
232        // Decode salt
233        let salt = base64_decode(&salt_b64)
234            .map_err(|_| Error::AuthenticationFailed("Invalid salt encoding".into()))?;
235
236        // Step 3: Compute client proof
237        let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
238        let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
239        let stored_key = sha256(&client_key);
240
241        let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
242        let auth_message = format!(
243            "{},{},{}",
244            client_first_bare, server_first, client_final_without_proof
245        );
246
247        let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
248        let client_proof = xor_bytes(&client_key, &client_signature);
249        let client_proof_b64 = base64_encode(&client_proof);
250
251        // Step 4: Send client-final message
252        let client_final = format!("{},p={}", client_final_without_proof, client_proof_b64);
253
254        debug!("SCRAM: Sending client-final");
255        let request = Request::ScramClientFinal {
256            message: Bytes::from(client_final),
257        };
258
259        let response = self.send_request(request).await?;
260
261        // Step 5: Verify server-final and get session
262        match response {
263            Response::ScramServerFinal {
264                message,
265                session_id,
266                expires_in,
267            } => {
268                let server_final = String::from_utf8(message.to_vec()).map_err(|_| {
269                    Error::AuthenticationFailed("Invalid server-final encoding".into())
270                })?;
271
272                // Check for error response
273                if let Some(error_msg) = server_final.strip_prefix("e=") {
274                    return Err(Error::AuthenticationFailed(error_msg.to_string()));
275                }
276
277                // Verify server signature (mutual authentication)
278                if let Some(verifier_b64) = server_final.strip_prefix("v=") {
279                    let server_key = PasswordHash::hmac_sha256(&salted_password, b"Server Key");
280                    let expected_server_sig =
281                        PasswordHash::hmac_sha256(&server_key, auth_message.as_bytes());
282                    let expected_verifier = base64_encode(&expected_server_sig);
283
284                    if verifier_b64 != expected_verifier {
285                        return Err(Error::AuthenticationFailed(
286                            "Server verification failed".into(),
287                        ));
288                    }
289                }
290
291                let session_id = session_id.ok_or_else(|| {
292                    Error::AuthenticationFailed("No session ID in response".into())
293                })?;
294                let expires_in = expires_in
295                    .ok_or_else(|| Error::AuthenticationFailed("No expiry in response".into()))?;
296
297                info!("SCRAM authentication successful for '{}'", username);
298                Ok(AuthSession {
299                    session_id,
300                    expires_in,
301                })
302            }
303            Response::Error { message } => Err(Error::AuthenticationFailed(message)),
304            _ => Err(Error::InvalidResponse),
305        }
306    }
307
308    // ========================================================================
309    // Request/Response Handling
310    // ========================================================================
311
312    /// Send a request and receive a response
313    async fn send_request(&mut self, request: Request) -> Result<Response> {
314        // Serialize request
315        let request_bytes = request.to_bytes()?;
316
317        // Write length prefix + request
318        let len = request_bytes.len() as u32;
319        self.stream.write_all(&len.to_be_bytes()).await?;
320        self.stream.write_all(&request_bytes).await?;
321        self.stream.flush().await?;
322
323        // Read length prefix
324        let mut len_buf = [0u8; 4];
325        self.stream.read_exact(&mut len_buf).await?;
326        let msg_len = u32::from_be_bytes(len_buf) as usize;
327
328        // Validate response size to prevent memory exhaustion from malicious server
329        if msg_len > DEFAULT_MAX_RESPONSE_SIZE {
330            return Err(Error::ResponseTooLarge(msg_len, DEFAULT_MAX_RESPONSE_SIZE));
331        }
332
333        // Read response
334        let mut response_buf = vec![0u8; msg_len];
335        self.stream.read_exact(&mut response_buf).await?;
336
337        // Deserialize response
338        let response = Response::from_bytes(&response_buf)?;
339
340        Ok(response)
341    }
342
343    /// Publish a message to a topic
344    pub async fn publish(
345        &mut self,
346        topic: impl Into<String>,
347        value: impl Into<Bytes>,
348    ) -> Result<u64> {
349        self.publish_with_key(topic, None::<Bytes>, value).await
350    }
351
352    /// Publish a message with a key to a topic
353    pub async fn publish_with_key(
354        &mut self,
355        topic: impl Into<String>,
356        key: Option<impl Into<Bytes>>,
357        value: impl Into<Bytes>,
358    ) -> Result<u64> {
359        let request = Request::Publish {
360            topic: topic.into(),
361            partition: None,
362            key: key.map(|k| k.into()),
363            value: value.into(),
364        };
365
366        let response = self.send_request(request).await?;
367
368        match response {
369            Response::Published { offset, .. } => Ok(offset),
370            Response::Error { message } => Err(Error::ServerError(message)),
371            _ => Err(Error::InvalidResponse),
372        }
373    }
374
375    /// Publish a message to a specific partition
376    pub async fn publish_to_partition(
377        &mut self,
378        topic: impl Into<String>,
379        partition: u32,
380        key: Option<impl Into<Bytes>>,
381        value: impl Into<Bytes>,
382    ) -> Result<u64> {
383        let request = Request::Publish {
384            topic: topic.into(),
385            partition: Some(partition),
386            key: key.map(|k| k.into()),
387            value: value.into(),
388        };
389
390        let response = self.send_request(request).await?;
391
392        match response {
393            Response::Published { offset, .. } => Ok(offset),
394            Response::Error { message } => Err(Error::ServerError(message)),
395            _ => Err(Error::InvalidResponse),
396        }
397    }
398
399    /// Consume messages from a topic partition
400    pub async fn consume(
401        &mut self,
402        topic: impl Into<String>,
403        partition: u32,
404        offset: u64,
405        max_messages: usize,
406    ) -> Result<Vec<MessageData>> {
407        let request = Request::Consume {
408            topic: topic.into(),
409            partition,
410            offset,
411            max_messages,
412        };
413
414        let response = self.send_request(request).await?;
415
416        match response {
417            Response::Messages { messages } => Ok(messages),
418            Response::Error { message } => Err(Error::ServerError(message)),
419            _ => Err(Error::InvalidResponse),
420        }
421    }
422
423    /// Create a new topic
424    pub async fn create_topic(
425        &mut self,
426        name: impl Into<String>,
427        partitions: Option<u32>,
428    ) -> Result<u32> {
429        let name = name.into();
430        let request = Request::CreateTopic {
431            name: name.clone(),
432            partitions,
433        };
434
435        let response = self.send_request(request).await?;
436
437        match response {
438            Response::TopicCreated { partitions, .. } => Ok(partitions),
439            Response::Error { message } => Err(Error::ServerError(message)),
440            _ => Err(Error::InvalidResponse),
441        }
442    }
443
444    /// List all topics
445    pub async fn list_topics(&mut self) -> Result<Vec<String>> {
446        let request = Request::ListTopics;
447        let response = self.send_request(request).await?;
448
449        match response {
450            Response::Topics { topics } => Ok(topics),
451            Response::Error { message } => Err(Error::ServerError(message)),
452            _ => Err(Error::InvalidResponse),
453        }
454    }
455
456    /// Delete a topic
457    pub async fn delete_topic(&mut self, name: impl Into<String>) -> Result<()> {
458        let request = Request::DeleteTopic { name: name.into() };
459        let response = self.send_request(request).await?;
460
461        match response {
462            Response::TopicDeleted => Ok(()),
463            Response::Error { message } => Err(Error::ServerError(message)),
464            _ => Err(Error::InvalidResponse),
465        }
466    }
467
468    /// Commit consumer offset
469    pub async fn commit_offset(
470        &mut self,
471        consumer_group: impl Into<String>,
472        topic: impl Into<String>,
473        partition: u32,
474        offset: u64,
475    ) -> Result<()> {
476        let request = Request::CommitOffset {
477            consumer_group: consumer_group.into(),
478            topic: topic.into(),
479            partition,
480            offset,
481        };
482
483        let response = self.send_request(request).await?;
484
485        match response {
486            Response::OffsetCommitted => Ok(()),
487            Response::Error { message } => Err(Error::ServerError(message)),
488            _ => Err(Error::InvalidResponse),
489        }
490    }
491
492    /// Get consumer offset
493    pub async fn get_offset(
494        &mut self,
495        consumer_group: impl Into<String>,
496        topic: impl Into<String>,
497        partition: u32,
498    ) -> Result<Option<u64>> {
499        let request = Request::GetOffset {
500            consumer_group: consumer_group.into(),
501            topic: topic.into(),
502            partition,
503        };
504
505        let response = self.send_request(request).await?;
506
507        match response {
508            Response::Offset { offset } => Ok(offset),
509            Response::Error { message } => Err(Error::ServerError(message)),
510            _ => Err(Error::InvalidResponse),
511        }
512    }
513
514    /// Get earliest and latest offsets for a topic partition
515    ///
516    /// Returns (earliest, latest) where:
517    /// - earliest: First available offset (messages before this are deleted/compacted)
518    /// - latest: Next offset to be assigned (one past the last message)
519    pub async fn get_offset_bounds(
520        &mut self,
521        topic: impl Into<String>,
522        partition: u32,
523    ) -> Result<(u64, u64)> {
524        let request = Request::GetOffsetBounds {
525            topic: topic.into(),
526            partition,
527        };
528
529        let response = self.send_request(request).await?;
530
531        match response {
532            Response::OffsetBounds { earliest, latest } => Ok((earliest, latest)),
533            Response::Error { message } => Err(Error::ServerError(message)),
534            _ => Err(Error::InvalidResponse),
535        }
536    }
537
538    /// Get topic metadata
539    pub async fn get_metadata(&mut self, topic: impl Into<String>) -> Result<(String, u32)> {
540        let request = Request::GetMetadata {
541            topic: topic.into(),
542        };
543
544        let response = self.send_request(request).await?;
545
546        match response {
547            Response::Metadata { name, partitions } => Ok((name, partitions)),
548            Response::Error { message } => Err(Error::ServerError(message)),
549            _ => Err(Error::InvalidResponse),
550        }
551    }
552
553    /// Ping the server
554    pub async fn ping(&mut self) -> Result<()> {
555        let request = Request::Ping;
556        let response = self.send_request(request).await?;
557
558        match response {
559            Response::Pong => Ok(()),
560            Response::Error { message } => Err(Error::ServerError(message)),
561            _ => Err(Error::InvalidResponse),
562        }
563    }
564
565    /// Register a schema
566    pub async fn register_schema(
567        &mut self,
568        subject: impl Into<String>,
569        schema: impl Into<String>,
570    ) -> Result<i32> {
571        let request = Request::RegisterSchema {
572            subject: subject.into(),
573            schema: schema.into(),
574        };
575
576        let response = self.send_request(request).await?;
577
578        match response {
579            Response::SchemaRegistered { id } => Ok(id),
580            Response::Error { message } => Err(Error::ServerError(message)),
581            _ => Err(Error::InvalidResponse),
582        }
583    }
584
585    /// Get a schema
586    pub async fn get_schema(&mut self, id: i32) -> Result<String> {
587        let request = Request::GetSchema { id };
588
589        let response = self.send_request(request).await?;
590
591        match response {
592            Response::Schema { id: _, schema } => Ok(schema),
593            Response::Error { message } => Err(Error::ServerError(message)),
594            _ => Err(Error::InvalidResponse),
595        }
596    }
597
598    /// List all consumer groups
599    pub async fn list_groups(&mut self) -> Result<Vec<String>> {
600        let request = Request::ListGroups;
601
602        let response = self.send_request(request).await?;
603
604        match response {
605            Response::Groups { groups } => Ok(groups),
606            Response::Error { message } => Err(Error::ServerError(message)),
607            _ => Err(Error::InvalidResponse),
608        }
609    }
610
611    /// Describe a consumer group (get all committed offsets)
612    pub async fn describe_group(
613        &mut self,
614        consumer_group: impl Into<String>,
615    ) -> Result<std::collections::HashMap<String, std::collections::HashMap<u32, u64>>> {
616        let request = Request::DescribeGroup {
617            consumer_group: consumer_group.into(),
618        };
619
620        let response = self.send_request(request).await?;
621
622        match response {
623            Response::GroupDescription { offsets, .. } => Ok(offsets),
624            Response::Error { message } => Err(Error::ServerError(message)),
625            _ => Err(Error::InvalidResponse),
626        }
627    }
628
629    /// Delete a consumer group
630    pub async fn delete_group(&mut self, consumer_group: impl Into<String>) -> Result<()> {
631        let request = Request::DeleteGroup {
632            consumer_group: consumer_group.into(),
633        };
634
635        let response = self.send_request(request).await?;
636
637        match response {
638            Response::GroupDeleted => Ok(()),
639            Response::Error { message } => Err(Error::ServerError(message)),
640            _ => Err(Error::InvalidResponse),
641        }
642    }
643
644    /// Get the first offset with timestamp >= the given timestamp
645    ///
646    /// # Arguments
647    /// * `topic` - The topic name
648    /// * `partition` - The partition number
649    /// * `timestamp_ms` - Timestamp in milliseconds since Unix epoch
650    ///
651    /// # Returns
652    /// * `Some(offset)` - The first offset with message timestamp >= timestamp_ms
653    /// * `None` - No messages found with timestamp >= timestamp_ms
654    pub async fn get_offset_for_timestamp(
655        &mut self,
656        topic: impl Into<String>,
657        partition: u32,
658        timestamp_ms: i64,
659    ) -> Result<Option<u64>> {
660        let request = Request::GetOffsetForTimestamp {
661            topic: topic.into(),
662            partition,
663            timestamp_ms,
664        };
665
666        let response = self.send_request(request).await?;
667
668        match response {
669            Response::OffsetForTimestamp { offset } => Ok(offset),
670            Response::Error { message } => Err(Error::ServerError(message)),
671            _ => Err(Error::InvalidResponse),
672        }
673    }
674}
675
676// ============================================================================
677// Authentication Session
678// ============================================================================
679
/// Session returned by a successful `authenticate` / `authenticate_scram` call.
#[derive(Clone, Debug)]
pub struct AuthSession {
    /// Server-issued session identifier to present on subsequent requests.
    pub session_id: String,
    /// Lifetime of the session, in seconds, as reported by the server.
    pub expires_in: u64,
}
688
689// ============================================================================
690// SCRAM Helper Functions
691// ============================================================================
692
693/// Generate a random nonce for SCRAM authentication
694fn generate_nonce() -> String {
695    use rand::Rng;
696    let mut rng = rand::thread_rng();
697    let nonce_bytes: Vec<u8> = (0..24).map(|_| rng.gen()).collect();
698    base64_encode(&nonce_bytes)
699}
700
/// Escape a username for the SCRAM `n=` attribute (RFC 5802 `saslname`):
/// `=` becomes `=3D` and `,` becomes `=2C`; every other character is kept.
fn escape_username(username: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
705
706/// Parse server-first message
707fn parse_server_first(server_first: &str) -> Result<(String, String, u32)> {
708    let mut nonce = None;
709    let mut salt = None;
710    let mut iterations = None;
711
712    for attr in server_first.split(',') {
713        if let Some(value) = attr.strip_prefix("r=") {
714            nonce = Some(value.to_string());
715        } else if let Some(value) = attr.strip_prefix("s=") {
716            salt = Some(value.to_string());
717        } else if let Some(value) = attr.strip_prefix("i=") {
718            iterations = Some(
719                value
720                    .parse::<u32>()
721                    .map_err(|_| Error::AuthenticationFailed("Invalid iteration count".into()))?,
722            );
723        }
724    }
725
726    let nonce = nonce.ok_or_else(|| Error::AuthenticationFailed("Missing nonce".into()))?;
727    let salt = salt.ok_or_else(|| Error::AuthenticationFailed("Missing salt".into()))?;
728    let iterations =
729        iterations.ok_or_else(|| Error::AuthenticationFailed("Missing iterations".into()))?;
730
731    Ok((nonce, salt, iterations))
732}
733
734/// PBKDF2-HMAC-SHA256 key derivation
735fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
736    let mut result = vec![0u8; 32];
737
738    // U1 = PRF(Password, Salt || INT(1))
739    let mut u = PasswordHash::hmac_sha256(password, &[salt, &1u32.to_be_bytes()].concat());
740    result.copy_from_slice(&u);
741
742    // Ui = PRF(Password, Ui-1)
743    for _ in 1..iterations {
744        u = PasswordHash::hmac_sha256(password, &u);
745        for (r, ui) in result.iter_mut().zip(u.iter()) {
746            *r ^= ui;
747        }
748    }
749
750    result
751}
752
753/// SHA-256 hash
754fn sha256(data: &[u8]) -> Vec<u8> {
755    let mut hasher = Sha256::new();
756    hasher.update(data);
757    hasher.finalize().to_vec()
758}
759
/// Byte-wise XOR of two slices.
///
/// The output is as long as the shorter input; surplus bytes of the longer
/// slice are ignored.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (left, right) in a.iter().zip(b) {
        out.push(left ^ right);
    }
    out
}
764
765/// Base64 encode
766fn base64_encode(data: &[u8]) -> String {
767    use base64::{engine::general_purpose::STANDARD, Engine};
768    STANDARD.encode(data)
769}
770
771/// Base64 decode
772fn base64_decode(data: &str) -> std::result::Result<Vec<u8>, base64::DecodeError> {
773    use base64::{engine::general_purpose::STANDARD, Engine};
774    STANDARD.decode(data)
775}
776
777// ============================================================================
778// Tests
779// ============================================================================
780
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escape_username() {
        assert_eq!(escape_username("alice"), "alice");
        assert_eq!(escape_username("user=name"), "user=3Dname");
        assert_eq!(escape_username("user,name"), "user=2Cname");
        assert_eq!(escape_username("user=,name"), "user=3D=2Cname");
    }

    #[test]
    fn test_parse_server_first() {
        let server_first = "r=clientnonce+servernonce,s=c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();

        assert_eq!(nonce, "clientnonce+servernonce");
        assert_eq!(salt, "c2FsdA==");
        assert_eq!(iterations, 4096);
    }

    #[test]
    fn test_parse_server_first_missing_nonce() {
        let server_first = "s=c2FsdA==,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_salt() {
        let server_first = "r=nonce,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_iterations() {
        let server_first = "r=nonce,s=c2FsdA==";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_xor_bytes() {
        assert_eq!(xor_bytes(&[0xFF, 0x00], &[0xFF, 0xFF]), vec![0x00, 0xFF]);
        assert_eq!(xor_bytes(&[0x12, 0x34], &[0x12, 0x34]), vec![0x00, 0x00]);
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"hello world";
        let encoded = base64_encode(data);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_sha256() {
        // SHA-256 of empty string
        let hash = sha256(b"");
        assert_eq!(hash.len(), 32);
        // Known hash value
        assert_eq!(
            hex::encode(&hash),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn test_pbkdf2_sha256() {
        // Published PBKDF2-HMAC-SHA256 test vectors (dkLen = 32) using the
        // RFC 6070 inputs; RFC 6070 itself only covers HMAC-SHA1.
        let password = b"password";
        let salt = b"salt";

        let one = pbkdf2_sha256(password, salt, 1);
        assert_eq!(one.len(), 32);
        assert_eq!(
            hex::encode(&one),
            "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b"
        );

        let many = pbkdf2_sha256(password, salt, 4096);
        assert_eq!(
            hex::encode(&many),
            "c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a"
        );

        // The derivation is deterministic.
        assert_eq!(pbkdf2_sha256(password, salt, 1), one);
    }

    #[test]
    fn test_scram_sha256_rfc7677_vector() {
        // Example exchange from RFC 7677 §3: user "user", password "pencil".
        let client_first_bare = "n=user,r=rOprNGfwEbeRWgbNEkqO";
        let server_first = "r=rOprNGfwEbeRWgbNEkqO%hvYDpWUa2RaTCAfuxFIlj)hNlF$k0,\
                            s=W22ZaJ0SNY7soEsUEjb6gQ==,i=4096";

        let (nonce, salt_b64, iterations) = parse_server_first(server_first).unwrap();
        let salt = base64_decode(&salt_b64).unwrap();

        let salted_password = pbkdf2_sha256(b"pencil", &salt, iterations);
        let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
        let stored_key = sha256(&client_key);

        let client_final_without_proof = format!("c=biws,r={}", nonce);
        let auth_message = format!(
            "{},{},{}",
            client_first_bare, server_first, client_final_without_proof
        );

        let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
        let client_proof = xor_bytes(&client_key, &client_signature);
        assert_eq!(
            base64_encode(&client_proof),
            "dHzbZapWIk4jUhN+Ute9ytag9zjfMHgsqmmiz7AndVQ="
        );

        let server_key = PasswordHash::hmac_sha256(&salted_password, b"Server Key");
        let server_signature = PasswordHash::hmac_sha256(&server_key, auth_message.as_bytes());
        assert_eq!(
            base64_encode(&server_signature),
            "6rriTRBi23WpRR/wtup+mMhUZUn/dB5nLTJRsjl95G4="
        );
    }

    #[test]
    fn test_generate_nonce() {
        let nonce1 = generate_nonce();
        let nonce2 = generate_nonce();

        // Nonces should be non-empty
        assert!(!nonce1.is_empty());
        assert!(!nonce2.is_empty());

        // Nonces should be different (with overwhelming probability)
        assert_ne!(nonce1, nonce2);

        // Should be valid base64 wrapping 24 random bytes
        let raw = base64_decode(&nonce1).unwrap();
        assert_eq!(raw.len(), 24);
    }
}