use std::io::Read;
use std::net::Ipv4Addr;
use std::str::from_utf8;
use bytebuffer::ByteBuffer;
use sha2::{Digest, Sha256};
use crate::encode::VoltError;
use crate::node::ConnInfo;
/// Protocol version byte written into the authentication request built by
/// [`build_auth_message`].
pub const PROTOCOL_VERSION: u8 = 1;
/// Reserved client handle with the value `2^62`.
///
/// NOTE(review): presumably used to tag keep-alive/ping requests so their
/// responses can be distinguished from procedure-call responses — confirm at the
/// call sites.
pub const PING_HANDLE: i64 = 1_i64 << 62;
/// Serializes the authentication request sent when a connection is opened.
///
/// The returned frame starts with a big-endian `u32` holding the number of bytes
/// that follow it. The body then contains, in order:
/// 1. [`PROTOCOL_VERSION`], written as two consecutive single bytes;
/// 2. the string `"database"`, written via `ByteBuffer::write_string`;
/// 3. the user name via `ByteBuffer::write_string` (empty when `user` is `None`);
/// 4. the 32-byte SHA-256 digest of the password bytes (the digest of the empty
///    input when `pass` is `None`).
///
/// NOTE(review): the second "version" byte may really be a separate field (for
/// example a password-hash scheme id) that happens to equal `PROTOCOL_VERSION` —
/// confirm against the wire specification before changing either value.
///
/// # Errors
/// As written this never returns `Err`; the `Result` leaves room for fallible encoding.
pub fn build_auth_message(user: Option<&str>, pass: Option<&str>) -> Result<Vec<u8>, VoltError> {
    let mut frame = ByteBuffer::new();
    // Placeholder for the length header; back-patched once the body is complete.
    frame.write_u32(0);
    frame.write_bytes(&[PROTOCOL_VERSION, PROTOCOL_VERSION]);
    frame.write_string("database");
    frame.write_string(user.unwrap_or(""));
    let password_digest = Sha256::digest(pass.unwrap_or("").as_bytes());
    frame.write_bytes(&password_digest);
    // Everything after the 4-byte header counts toward the length.
    let body_len = (frame.len() - 4) as u32;
    frame.set_wpos(0);
    frame.write_u32(body_len);
    Ok(frame.into_vec())
}
/// Decodes the server's reply to the authentication request.
///
/// Fields are read in this order: a version byte (ignored), an auth-result byte,
/// `host_id` (`i32`), `connection` (`i64`), one `i64` that is skipped, the leader's
/// IPv4 address packed into an `i32`, and finally an `i32` byte count followed by
/// that many bytes of UTF-8 build string.
///
/// # Errors
/// - [`VoltError::AuthFailed`] when the auth-result byte is non-zero.
/// - [`VoltError::Other`] when the build-string length is negative or larger than
///   the bytes remaining in `data`.
/// - Any read error (e.g. truncated input) or UTF-8 decoding error, converted into
///   `VoltError`.
pub fn parse_auth_response(data: &[u8]) -> Result<ConnInfo, VoltError> {
    let mut res = ByteBuffer::from_bytes(data);
    let _version = res.read_u8()?;
    let auth = res.read_u8()?;
    if auth != 0 {
        return Err(VoltError::AuthFailed);
    }
    let host_id = res.read_i32()?;
    let connection = res.read_i64()?;
    // NOTE(review): skipped field; presumably a cluster start timestamp — confirm.
    let _ = res.read_i64()?;
    let leader = res.read_i32()?;
    // `From<u32>` interprets the integer as big-endian octets a.b.c.d.
    let leader_addr = Ipv4Addr::from(leader as u32);
    let raw_len = res.read_i32()?;
    // The length comes straight off the wire: a negative value would wrap to a huge
    // `usize` and panic inside `vec!`, and an oversized one would allocate before the
    // read fails. Reject both before allocating.
    let remaining = data.len().saturating_sub(res.get_rpos());
    if raw_len < 0 || raw_len as usize > remaining {
        return Err(VoltError::Other(format!(
            "Invalid build string length {} ({} bytes remaining)",
            raw_len, remaining
        )));
    }
    let mut build = vec![0; raw_len as usize];
    res.read_exact(&mut build)?;
    let build = from_utf8(&build)?.to_owned();
    Ok(ConnInfo {
        host_id,
        connection,
        leader_addr,
        build,
    })
}
/// Reads one length-prefixed frame from `reader` and returns its payload.
///
/// A frame is a big-endian `u32` byte count followed by exactly that many bytes.
/// A zero count yields an empty vector without reading any further bytes. Bytes
/// after the frame are left unread in `reader`.
///
/// # Errors
/// Propagates I/O errors from `reader` (converted into `VoltError`), including an
/// `UnexpectedEof` error when the stream ends before the whole payload is read.
pub fn read_message<R: Read>(reader: &mut R) -> Result<Vec<u8>, VoltError> {
    use byteorder::{BigEndian, ReadBytesExt};
    // Cap on the buffer reserved before any payload arrives. The length prefix comes
    // off the wire, and trusting it blindly lets a corrupt or hostile frame force an
    // allocation of up to 4 GiB. Larger payloads still work: the buffer grows as
    // bytes are actually received.
    const MAX_PREALLOC: usize = 64 * 1024;

    let len = reader.read_u32::<BigEndian>()?;
    let expected = len as usize;
    let mut data = Vec::with_capacity(expected.min(MAX_PREALLOC));
    // `take` bounds the read so bytes belonging to the next frame stay in `reader`.
    reader.by_ref().take(u64::from(len)).read_to_end(&mut data)?;
    if data.len() != expected {
        return Err(std::io::Error::new(
            std::io::ErrorKind::UnexpectedEof,
            "stream ended before the full message was read",
        )
        .into());
    }
    Ok(data)
}
/// Extracts the client handle from a raw response payload.
///
/// The handle is the big-endian `i64` stored in bytes `1..9`; byte 0 is skipped
/// and any bytes after index 8 are ignored.
///
/// # Errors
/// Returns `VoltError::Other` with the message `"Response too short"` when `data`
/// holds fewer than 9 bytes.
pub fn parse_response_handle(data: &[u8]) -> Result<i64, VoltError> {
    match data.get(1..9) {
        Some(handle_bytes) => {
            let mut raw = [0u8; 8];
            raw.copy_from_slice(handle_bytes);
            Ok(i64::from_be_bytes(raw))
        }
        None => Err(VoltError::Other("Response too short".to_string())),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Decodes the big-endian `u32` length header at the front of a built message.
    fn length_header(msg: &[u8]) -> usize {
        u32::from_be_bytes([msg[0], msg[1], msg[2], msg[3]]) as usize
    }

    /// Encodes an authentication response in the order `parse_auth_response` reads it.
    fn encode_auth_response(
        auth: u8,
        host_id: i32,
        connection: i64,
        leader: [u8; 4],
        build: &[u8],
    ) -> Vec<u8> {
        let mut data = vec![PROTOCOL_VERSION, auth];
        data.extend_from_slice(&host_id.to_be_bytes());
        data.extend_from_slice(&connection.to_be_bytes());
        // The i64 field that the parser skips.
        data.extend_from_slice(&0i64.to_be_bytes());
        data.extend_from_slice(&leader);
        data.extend_from_slice(&(build.len() as i32).to_be_bytes());
        data.extend_from_slice(build);
        data
    }

    #[test]
    fn test_build_auth_message_no_credentials() {
        let msg = build_auth_message(None, None).unwrap();
        assert!(!msg.is_empty());
        assert_eq!(length_header(&msg), msg.len() - 4);
    }

    #[test]
    fn test_build_auth_message_with_credentials() {
        let msg = build_auth_message(Some("admin"), Some("password")).unwrap();
        assert!(!msg.is_empty());
        let no_cred_msg = build_auth_message(None, None).unwrap();
        assert!(msg.len() > no_cred_msg.len());
    }

    #[test]
    fn test_build_auth_message_user_only() {
        let msg = build_auth_message(Some("testuser"), None).unwrap();
        assert!(!msg.is_empty());
        assert_eq!(length_header(&msg), msg.len() - 4);
    }

    #[test]
    fn test_build_auth_message_contains_protocol_version() {
        let msg = build_auth_message(None, None).unwrap();
        assert_eq!(msg[4], PROTOCOL_VERSION);
        assert_eq!(msg[5], PROTOCOL_VERSION);
    }

    #[test]
    fn test_build_auth_message_contains_database() {
        let msg = build_auth_message(None, None).unwrap();
        let msg_str = String::from_utf8_lossy(&msg);
        assert!(msg_str.contains("database"));
    }

    #[test]
    fn test_build_auth_message_ends_with_password_digest() {
        let msg = build_auth_message(Some("admin"), Some("password")).unwrap();
        let expected = Sha256::digest(b"password");
        assert_eq!(&msg[msg.len() - 32..], &expected[..]);

        let anonymous = build_auth_message(None, None).unwrap();
        let empty = Sha256::digest(b"");
        assert_eq!(&anonymous[anonymous.len() - 32..], &empty[..]);
    }

    #[test]
    fn test_parse_response_handle_valid() {
        let data = vec![0u8, 0, 0, 0, 0, 0, 0, 0, 42];
        let handle = parse_response_handle(&data).unwrap();
        assert_eq!(handle, 42);
    }

    #[test]
    fn test_parse_response_handle_negative() {
        let data = vec![0u8, 255, 255, 255, 255, 255, 255, 255, 255];
        let handle = parse_response_handle(&data).unwrap();
        assert_eq!(handle, -1);
    }

    #[test]
    fn test_parse_response_handle_too_short() {
        let data = vec![0u8, 1, 2, 3];
        let result = parse_response_handle(&data);
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_auth_response_success() {
        let data = encode_auth_response(0, 7, 99, [10, 0, 0, 1], b"v5.0");
        let info = parse_auth_response(&data).unwrap();
        assert_eq!(info.host_id, 7);
        assert_eq!(info.connection, 99);
        assert_eq!(info.leader_addr, Ipv4Addr::new(10, 0, 0, 1));
        assert_eq!(info.build, "v5.0");
    }

    #[test]
    fn test_parse_auth_response_invalid_auth() {
        let mut data = vec![1u8, 1];
        data.extend_from_slice(&[0u8; 50]);
        let result = parse_auth_response(&data);
        assert!(matches!(result, Err(VoltError::AuthFailed)));
    }

    #[test]
    fn test_parse_auth_response_truncated() {
        let data = encode_auth_response(0, 7, 99, [10, 0, 0, 1], b"v5.0");
        // Build string one byte short of its declared length.
        assert!(parse_auth_response(&data[..data.len() - 1]).is_err());
        // Cut off inside the fixed-size header.
        assert!(parse_auth_response(&data[..4]).is_err());
    }

    #[test]
    fn test_read_message_consecutive_frames() {
        let mut stream = Cursor::new(vec![0u8, 0, 0, 2, 1, 2, 0, 0, 0, 1, 9]);
        assert_eq!(read_message(&mut stream).unwrap(), vec![1, 2]);
        assert_eq!(read_message(&mut stream).unwrap(), vec![9]);
    }

    #[test]
    fn test_read_message_zero_length() {
        let mut stream = Cursor::new(vec![0u8, 0, 0, 0]);
        assert!(read_message(&mut stream).unwrap().is_empty());
    }

    #[test]
    fn test_read_message_truncated_payload() {
        let mut stream = Cursor::new(vec![0u8, 0, 0, 5, 1, 2]);
        assert!(read_message(&mut stream).is_err());
    }

    #[test]
    fn test_ping_handle_constant() {
        assert_eq!(PING_HANDLE, 1i64 << 62);
    }
}