use crate::{Error, MessageData, Request, Response, Result};
use bytes::Bytes;
use rivven_core::PasswordHash;
use sha2::{Digest, Sha256};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::net::TcpStream;
use tracing::{debug, info};
#[cfg(feature = "tls")]
use std::net::SocketAddr;
#[cfg(feature = "tls")]
use rivven_core::tls::{TlsClientStream, TlsConfig, TlsConnector};
/// Largest response frame (in bytes) the client will allocate for: 100 MiB.
/// Frames announcing a larger length are rejected before any body is read.
const DEFAULT_MAX_RESPONSE_SIZE: usize = 100 << 20;
/// Transport beneath a [`Client`]: raw TCP, or TLS over TCP when the `tls`
/// feature is enabled.
// The lint is silenced rather than boxing the TLS variant: each `Client` owns
// exactly one `ClientStream`, so the extra inline size costs nothing measurable.
#[allow(clippy::large_enum_variant)]
enum ClientStream {
    /// Unencrypted TCP connection.
    Plaintext(TcpStream),
    /// TLS session layered over a TCP connection.
    #[cfg(feature = "tls")]
    Tls(TlsClientStream<TcpStream>),
}
/// Forwards reads to whichever transport backs the connection.
impl AsyncRead for ClientStream {
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;
        // Both inner transports are `Unpin` (required by `Pin::new`), so the
        // enum can be unpinned and each variant re-pinned directly.
        let this = self.get_mut();
        match this {
            Self::Plaintext(tcp) => Pin::new(tcp).poll_read(cx, buf),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
        }
    }
}
/// Forwards writes, flushes and shutdowns to whichever transport backs the
/// connection.
impl AsyncWrite for ClientStream {
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<std::io::Result<usize>> {
        use std::pin::Pin;
        let this = self.get_mut();
        match this {
            Self::Plaintext(tcp) => Pin::new(tcp).poll_write(cx, buf),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;
        let this = self.get_mut();
        match this {
            Self::Plaintext(tcp) => Pin::new(tcp).poll_flush(cx),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => Pin::new(tls).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<std::io::Result<()>> {
        use std::pin::Pin;
        let this = self.get_mut();
        match this {
            Self::Plaintext(tcp) => Pin::new(tcp).poll_shutdown(cx),
            #[cfg(feature = "tls")]
            Self::Tls(tls) => Pin::new(tls).poll_shutdown(cx),
        }
    }
}
/// A connection to a Rivven server.
///
/// Requests are written and their responses read back one at a time over a
/// single stream; every request method takes `&mut self`, so one `Client` is
/// driven by one task at a time.
pub struct Client {
    /// Underlying transport (plaintext TCP, or TLS with the `tls` feature).
    stream: ClientStream,
}
impl Client {
pub async fn connect(addr: &str) -> Result<Self> {
info!("Connecting to Rivven server at {}", addr);
let stream = TcpStream::connect(addr)
.await
.map_err(|e| Error::ConnectionError(e.to_string()))?;
Ok(Self {
stream: ClientStream::Plaintext(stream),
})
}
#[cfg(feature = "tls")]
pub async fn connect_tls(
addr: &str,
tls_config: &TlsConfig,
server_name: &str,
) -> Result<Self> {
info!("Connecting to Rivven server at {} with TLS", addr);
let socket_addr: SocketAddr = addr
.parse()
.map_err(|e| Error::ConnectionError(format!("Invalid address: {}", e)))?;
let connector = TlsConnector::new(tls_config)
.map_err(|e| Error::ConnectionError(format!("TLS config error: {}", e)))?;
let tls_stream = connector
.connect_tcp(socket_addr, server_name)
.await
.map_err(|e| Error::ConnectionError(format!("TLS connection error: {}", e)))?;
info!("TLS connection established to {} ({})", addr, server_name);
Ok(Self {
stream: ClientStream::Tls(tls_stream),
})
}
#[cfg(feature = "tls")]
pub async fn connect_mtls(
addr: &str,
cert_path: impl Into<std::path::PathBuf>,
key_path: impl Into<std::path::PathBuf>,
ca_path: impl Into<std::path::PathBuf> + Clone,
server_name: &str,
) -> Result<Self> {
let tls_config = TlsConfig::mtls_from_pem_files(cert_path, key_path, ca_path);
Self::connect_tls(addr, &tls_config, server_name).await
}
pub async fn authenticate(&mut self, username: &str, password: &str) -> Result<AuthSession> {
let request = Request::Authenticate {
username: username.to_string(),
password: password.to_string(),
};
let response = self.send_request(request).await?;
match response {
Response::Authenticated {
session_id,
expires_in,
} => {
info!("Authenticated as '{}'", username);
Ok(AuthSession {
session_id,
expires_in,
})
}
Response::Error { message } => Err(Error::AuthenticationFailed(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn authenticate_scram(
&mut self,
username: &str,
password: &str,
) -> Result<AuthSession> {
let client_nonce = generate_nonce();
let client_first_bare = format!("n={},r={}", escape_username(username), client_nonce);
let client_first = format!("n,,{}", client_first_bare);
debug!("SCRAM: Sending client-first");
let request = Request::ScramClientFirst {
message: Bytes::from(client_first.clone()),
};
let response = self.send_request(request).await?;
let server_first = match response {
Response::ScramServerFirst { message } => String::from_utf8(message.to_vec())
.map_err(|_| Error::AuthenticationFailed("Invalid server-first encoding".into()))?,
Response::Error { message } => return Err(Error::AuthenticationFailed(message)),
_ => return Err(Error::InvalidResponse),
};
debug!("SCRAM: Received server-first");
let (combined_nonce, salt_b64, iterations) = parse_server_first(&server_first)?;
if !combined_nonce.starts_with(&client_nonce) {
return Err(Error::AuthenticationFailed("Server nonce mismatch".into()));
}
let salt = base64_decode(&salt_b64)
.map_err(|_| Error::AuthenticationFailed("Invalid salt encoding".into()))?;
let salted_password = pbkdf2_sha256(password.as_bytes(), &salt, iterations);
let client_key = PasswordHash::hmac_sha256(&salted_password, b"Client Key");
let stored_key = sha256(&client_key);
let client_final_without_proof = format!("c=biws,r={}", combined_nonce);
let auth_message = format!(
"{},{},{}",
client_first_bare, server_first, client_final_without_proof
);
let client_signature = PasswordHash::hmac_sha256(&stored_key, auth_message.as_bytes());
let client_proof = xor_bytes(&client_key, &client_signature);
let client_proof_b64 = base64_encode(&client_proof);
let client_final = format!("{},p={}", client_final_without_proof, client_proof_b64);
debug!("SCRAM: Sending client-final");
let request = Request::ScramClientFinal {
message: Bytes::from(client_final),
};
let response = self.send_request(request).await?;
match response {
Response::ScramServerFinal {
message,
session_id,
expires_in,
} => {
let server_final = String::from_utf8(message.to_vec()).map_err(|_| {
Error::AuthenticationFailed("Invalid server-final encoding".into())
})?;
if let Some(error_msg) = server_final.strip_prefix("e=") {
return Err(Error::AuthenticationFailed(error_msg.to_string()));
}
if let Some(verifier_b64) = server_final.strip_prefix("v=") {
let server_key = PasswordHash::hmac_sha256(&salted_password, b"Server Key");
let expected_server_sig =
PasswordHash::hmac_sha256(&server_key, auth_message.as_bytes());
let expected_verifier = base64_encode(&expected_server_sig);
if verifier_b64 != expected_verifier {
return Err(Error::AuthenticationFailed(
"Server verification failed".into(),
));
}
}
let session_id = session_id.ok_or_else(|| {
Error::AuthenticationFailed("No session ID in response".into())
})?;
let expires_in = expires_in
.ok_or_else(|| Error::AuthenticationFailed("No expiry in response".into()))?;
info!("SCRAM authentication successful for '{}'", username);
Ok(AuthSession {
session_id,
expires_in,
})
}
Response::Error { message } => Err(Error::AuthenticationFailed(message)),
_ => Err(Error::InvalidResponse),
}
}
async fn send_request(&mut self, request: Request) -> Result<Response> {
let request_bytes = request.to_bytes()?;
let len = request_bytes.len() as u32;
self.stream.write_all(&len.to_be_bytes()).await?;
self.stream.write_all(&request_bytes).await?;
self.stream.flush().await?;
let mut len_buf = [0u8; 4];
self.stream.read_exact(&mut len_buf).await?;
let msg_len = u32::from_be_bytes(len_buf) as usize;
if msg_len > DEFAULT_MAX_RESPONSE_SIZE {
return Err(Error::ResponseTooLarge(msg_len, DEFAULT_MAX_RESPONSE_SIZE));
}
let mut response_buf = vec![0u8; msg_len];
self.stream.read_exact(&mut response_buf).await?;
let response = Response::from_bytes(&response_buf)?;
Ok(response)
}
pub async fn publish(
&mut self,
topic: impl Into<String>,
value: impl Into<Bytes>,
) -> Result<u64> {
self.publish_with_key(topic, None::<Bytes>, value).await
}
pub async fn publish_with_key(
&mut self,
topic: impl Into<String>,
key: Option<impl Into<Bytes>>,
value: impl Into<Bytes>,
) -> Result<u64> {
let request = Request::Publish {
topic: topic.into(),
partition: None,
key: key.map(|k| k.into()),
value: value.into(),
};
let response = self.send_request(request).await?;
match response {
Response::Published { offset, .. } => Ok(offset),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn publish_to_partition(
&mut self,
topic: impl Into<String>,
partition: u32,
key: Option<impl Into<Bytes>>,
value: impl Into<Bytes>,
) -> Result<u64> {
let request = Request::Publish {
topic: topic.into(),
partition: Some(partition),
key: key.map(|k| k.into()),
value: value.into(),
};
let response = self.send_request(request).await?;
match response {
Response::Published { offset, .. } => Ok(offset),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn consume(
&mut self,
topic: impl Into<String>,
partition: u32,
offset: u64,
max_messages: usize,
) -> Result<Vec<MessageData>> {
self.consume_with_isolation(topic, partition, offset, max_messages, None)
.await
}
pub async fn consume_with_isolation(
&mut self,
topic: impl Into<String>,
partition: u32,
offset: u64,
max_messages: usize,
isolation_level: Option<u8>,
) -> Result<Vec<MessageData>> {
let request = Request::Consume {
topic: topic.into(),
partition,
offset,
max_messages,
isolation_level,
};
let response = self.send_request(request).await?;
match response {
Response::Messages { messages } => Ok(messages),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn consume_read_committed(
&mut self,
topic: impl Into<String>,
partition: u32,
offset: u64,
max_messages: usize,
) -> Result<Vec<MessageData>> {
self.consume_with_isolation(topic, partition, offset, max_messages, Some(1))
.await
}
pub async fn create_topic(
&mut self,
name: impl Into<String>,
partitions: Option<u32>,
) -> Result<u32> {
let name = name.into();
let request = Request::CreateTopic {
name: name.clone(),
partitions,
};
let response = self.send_request(request).await?;
match response {
Response::TopicCreated { partitions, .. } => Ok(partitions),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn list_topics(&mut self) -> Result<Vec<String>> {
let request = Request::ListTopics;
let response = self.send_request(request).await?;
match response {
Response::Topics { topics } => Ok(topics),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn delete_topic(&mut self, name: impl Into<String>) -> Result<()> {
let request = Request::DeleteTopic { name: name.into() };
let response = self.send_request(request).await?;
match response {
Response::TopicDeleted => Ok(()),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn commit_offset(
&mut self,
consumer_group: impl Into<String>,
topic: impl Into<String>,
partition: u32,
offset: u64,
) -> Result<()> {
let request = Request::CommitOffset {
consumer_group: consumer_group.into(),
topic: topic.into(),
partition,
offset,
};
let response = self.send_request(request).await?;
match response {
Response::OffsetCommitted => Ok(()),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn get_offset(
&mut self,
consumer_group: impl Into<String>,
topic: impl Into<String>,
partition: u32,
) -> Result<Option<u64>> {
let request = Request::GetOffset {
consumer_group: consumer_group.into(),
topic: topic.into(),
partition,
};
let response = self.send_request(request).await?;
match response {
Response::Offset { offset } => Ok(offset),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn get_offset_bounds(
&mut self,
topic: impl Into<String>,
partition: u32,
) -> Result<(u64, u64)> {
let request = Request::GetOffsetBounds {
topic: topic.into(),
partition,
};
let response = self.send_request(request).await?;
match response {
Response::OffsetBounds { earliest, latest } => Ok((earliest, latest)),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn get_metadata(&mut self, topic: impl Into<String>) -> Result<(String, u32)> {
let request = Request::GetMetadata {
topic: topic.into(),
};
let response = self.send_request(request).await?;
match response {
Response::Metadata { name, partitions } => Ok((name, partitions)),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn ping(&mut self) -> Result<()> {
let request = Request::Ping;
let response = self.send_request(request).await?;
match response {
Response::Pong => Ok(()),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn list_groups(&mut self) -> Result<Vec<String>> {
let request = Request::ListGroups;
let response = self.send_request(request).await?;
match response {
Response::Groups { groups } => Ok(groups),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn describe_group(
&mut self,
consumer_group: impl Into<String>,
) -> Result<std::collections::HashMap<String, std::collections::HashMap<u32, u64>>> {
let request = Request::DescribeGroup {
consumer_group: consumer_group.into(),
};
let response = self.send_request(request).await?;
match response {
Response::GroupDescription { offsets, .. } => Ok(offsets),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn delete_group(&mut self, consumer_group: impl Into<String>) -> Result<()> {
let request = Request::DeleteGroup {
consumer_group: consumer_group.into(),
};
let response = self.send_request(request).await?;
match response {
Response::GroupDeleted => Ok(()),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn get_offset_for_timestamp(
&mut self,
topic: impl Into<String>,
partition: u32,
timestamp_ms: i64,
) -> Result<Option<u64>> {
let request = Request::GetOffsetForTimestamp {
topic: topic.into(),
partition,
timestamp_ms,
};
let response = self.send_request(request).await?;
match response {
Response::OffsetForTimestamp { offset } => Ok(offset),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn describe_topic_configs(
&mut self,
topics: &[&str],
) -> Result<std::collections::HashMap<String, std::collections::HashMap<String, String>>> {
let request = Request::DescribeTopicConfigs {
topics: topics.iter().map(|s| s.to_string()).collect(),
};
let response = self.send_request(request).await?;
match response {
Response::TopicConfigsDescribed { configs } => {
let mut result = std::collections::HashMap::new();
for desc in configs {
let mut topic_configs = std::collections::HashMap::new();
for (key, value) in desc.configs {
topic_configs.insert(key, value.value);
}
result.insert(desc.topic, topic_configs);
}
Ok(result)
}
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn alter_topic_config(
&mut self,
topic: impl Into<String>,
configs: &[(&str, Option<&str>)],
) -> Result<AlterTopicConfigResult> {
use rivven_protocol::TopicConfigEntry;
let request = Request::AlterTopicConfig {
topic: topic.into(),
configs: configs
.iter()
.map(|(k, v)| TopicConfigEntry {
key: k.to_string(),
value: v.map(|s| s.to_string()),
})
.collect(),
};
let response = self.send_request(request).await?;
match response {
Response::TopicConfigAltered {
topic,
changed_count,
} => Ok(AlterTopicConfigResult {
topic,
changed_count,
}),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn create_partitions(
&mut self,
topic: impl Into<String>,
new_partition_count: u32,
) -> Result<u32> {
let request = Request::CreatePartitions {
topic: topic.into(),
new_partition_count,
assignments: vec![], };
let response = self.send_request(request).await?;
match response {
Response::PartitionsCreated {
new_partition_count,
..
} => Ok(new_partition_count),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn delete_records(
&mut self,
topic: impl Into<String>,
partition_offsets: &[(u32, u64)],
) -> Result<Vec<DeleteRecordsResult>> {
let request = Request::DeleteRecords {
topic: topic.into(),
partition_offsets: partition_offsets.to_vec(),
};
let response = self.send_request(request).await?;
match response {
Response::RecordsDeleted { results, .. } => Ok(results),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn init_producer_id(
&mut self,
previous_producer_id: Option<u64>,
) -> Result<ProducerState> {
let request = Request::InitProducerId {
producer_id: previous_producer_id,
};
let response = self.send_request(request).await?;
match response {
Response::ProducerIdInitialized {
producer_id,
producer_epoch,
} => Ok(ProducerState {
producer_id,
producer_epoch,
next_sequence: 0,
}),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn publish_idempotent(
&mut self,
topic: impl Into<String>,
key: Option<impl Into<Bytes>>,
value: impl Into<Bytes>,
producer: &mut ProducerState,
) -> Result<(u64, u32, bool)> {
let sequence = producer.next_sequence;
producer.next_sequence += 1;
let request = Request::IdempotentPublish {
topic: topic.into(),
partition: None,
key: key.map(|k| k.into()),
value: value.into(),
producer_id: producer.producer_id,
producer_epoch: producer.producer_epoch,
sequence,
};
let response = self.send_request(request).await?;
match response {
Response::IdempotentPublished {
offset,
partition,
duplicate,
} => Ok((offset, partition, duplicate)),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn begin_transaction(
&mut self,
txn_id: impl Into<String>,
producer: &ProducerState,
timeout_ms: Option<u64>,
) -> Result<()> {
let request = Request::BeginTransaction {
txn_id: txn_id.into(),
producer_id: producer.producer_id,
producer_epoch: producer.producer_epoch,
timeout_ms,
};
let response = self.send_request(request).await?;
match response {
Response::TransactionStarted { .. } => Ok(()),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn add_partitions_to_txn(
&mut self,
txn_id: impl Into<String>,
producer: &ProducerState,
partitions: &[(&str, u32)],
) -> Result<usize> {
let request = Request::AddPartitionsToTxn {
txn_id: txn_id.into(),
producer_id: producer.producer_id,
producer_epoch: producer.producer_epoch,
partitions: partitions
.iter()
.map(|(t, p)| (t.to_string(), *p))
.collect(),
};
let response = self.send_request(request).await?;
match response {
Response::PartitionsAddedToTxn {
partition_count, ..
} => Ok(partition_count),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn publish_transactional(
&mut self,
txn_id: impl Into<String>,
topic: impl Into<String>,
key: Option<impl Into<Bytes>>,
value: impl Into<Bytes>,
producer: &mut ProducerState,
) -> Result<(u64, u32, i32)> {
let sequence = producer.next_sequence;
producer.next_sequence += 1;
let request = Request::TransactionalPublish {
txn_id: txn_id.into(),
topic: topic.into(),
partition: None,
key: key.map(|k| k.into()),
value: value.into(),
producer_id: producer.producer_id,
producer_epoch: producer.producer_epoch,
sequence,
};
let response = self.send_request(request).await?;
match response {
Response::TransactionalPublished {
offset,
partition,
sequence,
} => Ok((offset, partition, sequence)),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn add_offsets_to_txn(
&mut self,
txn_id: impl Into<String>,
producer: &ProducerState,
group_id: impl Into<String>,
offsets: &[(&str, u32, i64)],
) -> Result<()> {
let request = Request::AddOffsetsToTxn {
txn_id: txn_id.into(),
producer_id: producer.producer_id,
producer_epoch: producer.producer_epoch,
group_id: group_id.into(),
offsets: offsets
.iter()
.map(|(t, p, o)| (t.to_string(), *p, *o))
.collect(),
};
let response = self.send_request(request).await?;
match response {
Response::OffsetsAddedToTxn { .. } => Ok(()),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn commit_transaction(
&mut self,
txn_id: impl Into<String>,
producer: &ProducerState,
) -> Result<()> {
let request = Request::CommitTransaction {
txn_id: txn_id.into(),
producer_id: producer.producer_id,
producer_epoch: producer.producer_epoch,
};
let response = self.send_request(request).await?;
match response {
Response::TransactionCommitted { .. } => Ok(()),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
pub async fn abort_transaction(
&mut self,
txn_id: impl Into<String>,
producer: &ProducerState,
) -> Result<()> {
let request = Request::AbortTransaction {
txn_id: txn_id.into(),
producer_id: producer.producer_id,
producer_epoch: producer.producer_epoch,
};
let response = self.send_request(request).await?;
match response {
Response::TransactionAborted { .. } => Ok(()),
Response::Error { message } => Err(Error::ServerError(message)),
_ => Err(Error::InvalidResponse),
}
}
}
/// Producer identity and sequencing state used by the idempotent and
/// transactional publish methods.
///
/// Returned by `Client::init_producer_id` with `next_sequence` set to 0.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProducerState {
    /// Producer id assigned by the server.
    pub producer_id: u64,
    /// Producer epoch returned together with the id.
    pub producer_epoch: u16,
    /// Sequence number the next idempotent/transactional publish will send.
    pub next_sequence: i32,
}
/// Outcome of `Client::alter_topic_config`, as reported by the server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AlterTopicConfigResult {
    /// Topic whose configuration was altered.
    pub topic: String,
    /// Number of configuration entries the server reports as changed.
    pub changed_count: usize,
}
pub use rivven_protocol::DeleteRecordsResult;
/// Session returned by a successful `authenticate` / `authenticate_scram` call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AuthSession {
    /// Session identifier issued by the server.
    pub session_id: String,
    /// Session lifetime as reported by the server (units not visible here —
    /// presumably seconds; confirm server-side).
    pub expires_in: u64,
}
/// Produces a fresh SCRAM client nonce: 24 bytes drawn from the thread-local
/// RNG, base64-encoded (the standard alphabet contains no `,`, so the nonce is
/// safe inside SCRAM attribute lists).
fn generate_nonce() -> String {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    let raw: Vec<u8> = std::iter::repeat_with(|| rng.gen::<u8>())
        .take(24)
        .collect();
    base64_encode(&raw)
}
/// Escapes a username for the SCRAM `n=` attribute: `=` becomes `=3D` and `,`
/// becomes `=2C`; every other character is copied unchanged.
fn escape_username(username: &str) -> String {
    let mut escaped = String::with_capacity(username.len());
    for ch in username.chars() {
        match ch {
            '=' => escaped.push_str("=3D"),
            ',' => escaped.push_str("=2C"),
            other => escaped.push(other),
        }
    }
    escaped
}
fn parse_server_first(server_first: &str) -> Result<(String, String, u32)> {
let mut nonce = None;
let mut salt = None;
let mut iterations = None;
for attr in server_first.split(',') {
if let Some(value) = attr.strip_prefix("r=") {
nonce = Some(value.to_string());
} else if let Some(value) = attr.strip_prefix("s=") {
salt = Some(value.to_string());
} else if let Some(value) = attr.strip_prefix("i=") {
iterations = Some(
value
.parse::<u32>()
.map_err(|_| Error::AuthenticationFailed("Invalid iteration count".into()))?,
);
}
}
let nonce = nonce.ok_or_else(|| Error::AuthenticationFailed("Missing nonce".into()))?;
let salt = salt.ok_or_else(|| Error::AuthenticationFailed("Missing salt".into()))?;
let iterations =
iterations.ok_or_else(|| Error::AuthenticationFailed("Missing iterations".into()))?;
Ok((nonce, salt, iterations))
}
/// PBKDF2-HMAC-SHA256 restricted to a single output block (32-byte key).
///
/// Computes `U1 = HMAC(password, salt || INT(1))`, then `Ui = HMAC(password, U(i-1))`,
/// and XORs all `Ui` together. An `iterations` value of 0 yields the same
/// result as 1 (only `U1` is used).
fn pbkdf2_sha256(password: &[u8], salt: &[u8], iterations: u32) -> Vec<u8> {
    // Block index 1, big-endian, appended to the salt for the first round.
    let mut first_input = Vec::with_capacity(salt.len() + 4);
    first_input.extend_from_slice(salt);
    first_input.extend_from_slice(&1u32.to_be_bytes());

    let mut prev = PasswordHash::hmac_sha256(password, &first_input);
    let mut derived = vec![0u8; 32];
    derived.copy_from_slice(&prev);

    for _round in 1..iterations {
        prev = PasswordHash::hmac_sha256(password, &prev);
        derived
            .iter_mut()
            .zip(prev.iter())
            .for_each(|(acc, byte)| *acc ^= byte);
    }
    derived
}
/// Returns the 32-byte SHA-256 digest of `data` as an owned vector.
fn sha256(data: &[u8]) -> Vec<u8> {
    Sha256::digest(data).to_vec()
}
/// XORs two byte slices element-wise; the result is as long as the shorter input.
fn xor_bytes(a: &[u8], b: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(a.len().min(b.len()));
    for (lhs, rhs) in a.iter().zip(b) {
        out.push(lhs ^ rhs);
    }
    out
}
/// Encodes `data` with the base64 crate's `STANDARD` engine (standard alphabet,
/// padded).
fn base64_encode(data: &[u8]) -> String {
    use base64::Engine as _;
    base64::engine::general_purpose::STANDARD.encode(data)
}
/// Decodes `data` with the base64 crate's `STANDARD` engine, returning the raw
/// bytes or the crate's decode error.
fn base64_decode(data: &str) -> std::result::Result<Vec<u8>, base64::DecodeError> {
    use base64::Engine as _;
    base64::engine::general_purpose::STANDARD.decode(data)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escape_username() {
        assert_eq!(escape_username("alice"), "alice");
        assert_eq!(escape_username("user=name"), "user=3Dname");
        assert_eq!(escape_username("user,name"), "user=2Cname");
        assert_eq!(escape_username("user=,name"), "user=3D=2Cname");
    }

    #[test]
    fn test_parse_server_first() {
        let server_first = "r=clientnonce+servernonce,s=c2FsdA==,i=4096";
        let (nonce, salt, iterations) = parse_server_first(server_first).unwrap();
        assert_eq!(nonce, "clientnonce+servernonce");
        assert_eq!(salt, "c2FsdA==");
        assert_eq!(iterations, 4096);
    }

    #[test]
    fn test_parse_server_first_missing_nonce() {
        let server_first = "s=c2FsdA==,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_salt() {
        let server_first = "r=nonce,i=4096";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_parse_server_first_missing_iterations() {
        let server_first = "r=nonce,s=c2FsdA==";
        assert!(parse_server_first(server_first).is_err());
    }

    #[test]
    fn test_xor_bytes() {
        assert_eq!(xor_bytes(&[0xFF, 0x00], &[0xFF, 0xFF]), vec![0x00, 0xFF]);
        assert_eq!(xor_bytes(&[0x12, 0x34], &[0x12, 0x34]), vec![0x00, 0x00]);
    }

    #[test]
    fn test_base64_roundtrip() {
        let data = b"hello world";
        let encoded = base64_encode(data);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, data);
    }

    #[test]
    fn test_sha256() {
        let hash = sha256(b"");
        assert_eq!(hash.len(), 32);
        assert_eq!(
            hex::encode(&hash),
            "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );
    }

    #[test]
    fn test_pbkdf2_sha256() {
        let password = b"password";
        let salt = b"salt";
        let iterations = 1;
        let result = pbkdf2_sha256(password, salt, iterations);
        assert_eq!(result.len(), 32);
        let result2 = pbkdf2_sha256(password, salt, iterations);
        assert_eq!(result, result2);
    }

    /// Pins the derivation against published PBKDF2-HMAC-SHA256 vectors
    /// (P = "password", S = "salt", dkLen = 32) rather than only checking
    /// determinism, so a wrong HMAC key/message order or XOR bug is caught.
    #[test]
    fn test_pbkdf2_sha256_known_answers() {
        let cases: [(u32, &str); 3] = [
            (
                1,
                "120fb6cffcf8b32c43e7225256c4f837a86548c92ccc35480805987cb70be17b",
            ),
            (
                2,
                "ae4d0c95af6b46d32d0adff928f06dd02a303f8ef3c251dfd6e2d85a95474c43",
            ),
            (
                4096,
                "c5e478d59288c841aa530db6845c4c8d962893a001ce4e11a4963873aa98134a",
            ),
        ];
        for (iterations, expected_hex) in cases {
            let derived = pbkdf2_sha256(b"password", b"salt", iterations);
            assert_eq!(hex::encode(&derived), expected_hex, "c = {}", iterations);
        }
    }

    #[test]
    fn test_generate_nonce() {
        let nonce1 = generate_nonce();
        let nonce2 = generate_nonce();
        assert!(!nonce1.is_empty());
        assert!(!nonce2.is_empty());
        assert_ne!(nonce1, nonce2);
        assert!(base64_decode(&nonce1).is_ok());
        // SCRAM attributes are comma-separated; a nonce must never contain one.
        assert!(!nonce1.contains(','));
    }
}