use crate::auth::types::AuthMessage;
#[cfg(feature = "http")]
use crate::auth::types::MessageType;
use crate::{Error, Result};
use async_trait::async_trait;
use std::sync::Arc;
use std::sync::RwLock as StdRwLock;
use tokio::sync::RwLock;
/// A message transport used by the auth layer to exchange [`AuthMessage`]s with a peer.
///
/// The implementations in this module deliver messages received in reply to an outgoing
/// message to the callback registered with [`Transport::set_callback`].
#[async_trait]
pub trait Transport: Send + Sync {
/// Sends `message` to the remote peer.
async fn send(&self, message: &AuthMessage) -> Result<()>;
/// Registers the callback to invoke for incoming messages.
fn set_callback(&self, callback: Box<TransportCallback>);
/// Removes the registered callback.
fn clear_callback(&self);
}
/// Callback invoked with each incoming [`AuthMessage`].
///
/// It returns a boxed, `Send` future that resolves to the result of handling the message.
/// The callback itself must be `Send + Sync` so transports can share it across tasks.
pub type TransportCallback = dyn Fn(AuthMessage) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
+ Send
+ Sync;
/// HTTP header names used to carry auth message fields over HTTP.
///
/// All names are lowercase.
pub mod headers {
/// Carries `AuthMessage::version`.
pub const VERSION: &str = "x-bsv-auth-version";
/// Carries the sender's identity public key, hex-encoded.
pub const IDENTITY_KEY: &str = "x-bsv-auth-identity-key";
/// Carries `AuthMessage::nonce`.
pub const NONCE: &str = "x-bsv-auth-nonce";
/// Carries `AuthMessage::your_nonce`.
pub const YOUR_NONCE: &str = "x-bsv-auth-your-nonce";
/// Carries the message signature.
///
/// NOTE(review): it is hex-encoded by `Transport::send` but base64-encoded by
/// `message_to_headers`; confirm which encoding the peer expects.
pub const SIGNATURE: &str = "x-bsv-auth-signature";
/// Carries the message type string (e.g. `certificateRequest`).
pub const MESSAGE_TYPE: &str = "x-bsv-auth-message-type";
/// Carries the 32-byte request id, base64-encoded.
pub const REQUEST_ID: &str = "x-bsv-auth-request-id";
/// Carries the requested certificates, serialized as JSON.
pub const REQUESTED_CERTIFICATES: &str = "x-bsv-auth-requested-certificates";
}
/// An HTTP request carried inside the payload of a `General` auth message.
#[derive(Debug, Clone)]
pub struct HttpRequest {
    /// 32-byte identifier used to match this request with its `HttpResponse`.
    pub request_id: [u8; 32],
    /// HTTP method, e.g. `"POST"`.
    pub method: String,
    /// Request path, e.g. `"/api/v1/users"`.
    pub path: String,
    /// Query string, appended verbatim after `path`; it therefore includes any leading `?`.
    pub search: String,
    /// Request headers as `(name, value)` pairs.
    pub headers: Vec<(String, String)>,
    /// Raw request body.
    pub body: Vec<u8>,
}

impl HttpRequest {
    /// Returns the URL suffix that follows the origin: `path` immediately followed by `search`.
    pub fn url_postfix(&self) -> String {
        [self.path.as_str(), self.search.as_str()].concat()
    }
}
impl HttpRequest {
pub fn from_payload(payload: &[u8]) -> Result<Self> {
let mut cursor = 0;
if payload.len() < 32 {
return Err(Error::AuthError("Payload too short for request ID".into()));
}
let mut request_id = [0u8; 32];
request_id.copy_from_slice(&payload[..32]);
cursor += 32;
let (method_len, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let method = if method_len > 0 {
let len = method_len as usize;
if cursor + len > payload.len() {
return Err(Error::AuthError("Payload too short for method".into()));
}
let s = String::from_utf8(payload[cursor..cursor + len].to_vec())
.map_err(|e| Error::AuthError(format!("Invalid method UTF-8: {}", e)))?;
cursor += len;
s
} else {
"GET".to_string()
};
let (path_len, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let path = if path_len > 0 {
let len = path_len as usize;
if cursor + len > payload.len() {
return Err(Error::AuthError("Payload too short for path".into()));
}
let s = String::from_utf8(payload[cursor..cursor + len].to_vec())
.map_err(|e| Error::AuthError(format!("Invalid path UTF-8: {}", e)))?;
cursor += len;
s
} else {
String::new()
};
let (search_len, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let search = if search_len > 0 {
let len = search_len as usize;
if cursor + len > payload.len() {
return Err(Error::AuthError("Payload too short for search".into()));
}
let s = String::from_utf8(payload[cursor..cursor + len].to_vec())
.map_err(|e| Error::AuthError(format!("Invalid search UTF-8: {}", e)))?;
cursor += len;
s
} else {
String::new()
};
let (header_count, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let count = if header_count > 0 {
header_count as usize
} else {
0
};
let mut headers = Vec::with_capacity(count);
for _ in 0..count {
let (key_len, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let klen = if key_len > 0 { key_len as usize } else { 0 };
if cursor + klen > payload.len() {
return Err(Error::AuthError("Payload too short for header key".into()));
}
let key = String::from_utf8(payload[cursor..cursor + klen].to_vec())
.map_err(|e| Error::AuthError(format!("Invalid header key UTF-8: {}", e)))?;
cursor += klen;
let (val_len, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let vlen = if val_len > 0 { val_len as usize } else { 0 };
if cursor + vlen > payload.len() {
return Err(Error::AuthError(
"Payload too short for header value".into(),
));
}
let value = String::from_utf8(payload[cursor..cursor + vlen].to_vec())
.map_err(|e| Error::AuthError(format!("Invalid header value UTF-8: {}", e)))?;
cursor += vlen;
headers.push((key, value));
}
let (body_len, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let body = if body_len > 0 {
let len = body_len as usize;
if cursor + len > payload.len() {
return Err(Error::AuthError("Payload too short for body".into()));
}
payload[cursor..cursor + len].to_vec()
} else {
Vec::new()
};
Ok(Self {
request_id,
method,
path,
search,
headers,
body,
})
}
pub fn to_payload(&self) -> Vec<u8> {
let mut payload = Vec::new();
payload.extend_from_slice(&self.request_id);
let method_bytes = self.method.as_bytes();
payload.extend(write_varint(method_bytes.len() as i64));
payload.extend_from_slice(method_bytes);
if self.path.is_empty() {
payload.extend(write_varint(-1));
} else {
let path_bytes = self.path.as_bytes();
payload.extend(write_varint(path_bytes.len() as i64));
payload.extend_from_slice(path_bytes);
}
if self.search.is_empty() {
payload.extend(write_varint(-1));
} else {
let search_bytes = self.search.as_bytes();
payload.extend(write_varint(search_bytes.len() as i64));
payload.extend_from_slice(search_bytes);
}
payload.extend(write_varint(self.headers.len() as i64));
for (key, value) in &self.headers {
let key_bytes = key.as_bytes();
payload.extend(write_varint(key_bytes.len() as i64));
payload.extend_from_slice(key_bytes);
let val_bytes = value.as_bytes();
payload.extend(write_varint(val_bytes.len() as i64));
payload.extend_from_slice(val_bytes);
}
if self.body.is_empty() {
payload.extend(write_varint(-1));
} else {
payload.extend(write_varint(self.body.len() as i64));
payload.extend_from_slice(&self.body);
}
payload
}
}
#[derive(Debug, Clone)]
pub struct HttpResponse {
pub request_id: [u8; 32],
pub status: u16,
pub headers: Vec<(String, String)>,
pub body: Vec<u8>,
}
impl HttpResponse {
pub fn to_payload(&self) -> Vec<u8> {
let mut payload = Vec::new();
payload.extend_from_slice(&self.request_id);
payload.extend(write_varint(self.status as i64));
payload.extend(write_varint(self.headers.len() as i64));
for (key, value) in &self.headers {
let key_bytes = key.as_bytes();
payload.extend(write_varint(key_bytes.len() as i64));
payload.extend_from_slice(key_bytes);
let val_bytes = value.as_bytes();
payload.extend(write_varint(val_bytes.len() as i64));
payload.extend_from_slice(val_bytes);
}
if self.body.is_empty() {
payload.extend(write_varint(-1));
} else {
payload.extend(write_varint(self.body.len() as i64));
payload.extend_from_slice(&self.body);
}
payload
}
pub fn from_payload(payload: &[u8]) -> Result<Self> {
let mut cursor = 0;
if payload.len() < 32 {
return Err(Error::AuthError(
"Response payload too short for request ID".into(),
));
}
let mut request_id = [0u8; 32];
request_id.copy_from_slice(&payload[..32]);
cursor += 32;
let (status, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let status = status as u16;
let (header_count, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let count = if header_count > 0 {
header_count as usize
} else {
0
};
let mut headers = Vec::with_capacity(count);
for _ in 0..count {
let (key_len, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let klen = if key_len > 0 { key_len as usize } else { 0 };
if cursor + klen > payload.len() {
return Err(Error::AuthError(
"Response payload too short for header key".into(),
));
}
let key = String::from_utf8(payload[cursor..cursor + klen].to_vec())
.map_err(|e| Error::AuthError(format!("Invalid header key UTF-8: {}", e)))?;
cursor += klen;
let (val_len, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let vlen = if val_len > 0 { val_len as usize } else { 0 };
if cursor + vlen > payload.len() {
return Err(Error::AuthError(
"Response payload too short for header value".into(),
));
}
let value = String::from_utf8(payload[cursor..cursor + vlen].to_vec())
.map_err(|e| Error::AuthError(format!("Invalid header value UTF-8: {}", e)))?;
cursor += vlen;
headers.push((key, value));
}
let (body_len, bytes_read) = read_varint(&payload[cursor..])?;
cursor += bytes_read;
let body = if body_len > 0 {
let len = body_len as usize;
if cursor + len <= payload.len() {
payload[cursor..cursor + len].to_vec()
} else {
Vec::new()
}
} else {
Vec::new()
};
Ok(Self {
request_id,
status,
headers,
body,
})
}
}
fn read_varint(bytes: &[u8]) -> Result<(i64, usize)> {
if bytes.is_empty() {
return Err(Error::AuthError("Empty varint".into()));
}
let first = bytes[0];
if first < 253 {
Ok((first as i64, 1))
} else if first == 253 {
if bytes.len() < 3 {
return Err(Error::AuthError("Incomplete varint (fd)".into()));
}
let value = u16::from_le_bytes([bytes[1], bytes[2]]);
Ok((value as i64, 3))
} else if first == 254 {
if bytes.len() < 5 {
return Err(Error::AuthError("Incomplete varint (fe)".into()));
}
let value = u32::from_le_bytes([bytes[1], bytes[2], bytes[3], bytes[4]]);
Ok((value as i64, 5))
} else {
if bytes.len() < 9 {
return Err(Error::AuthError("Incomplete varint (ff)".into()));
}
let value = u64::from_le_bytes([
bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], bytes[8],
]);
if value == u64::MAX {
Ok((-1, 9))
} else {
Ok((value as i64, 9))
}
}
}
/// Encodes `value` as a variable-length integer (the inverse of `read_varint`).
///
/// - Values below 253 take one byte.
/// - Larger values take a `0xFD`/`0xFE`/`0xFF` prefix followed by a little-endian
///   `u16`/`u32`/`u64`.
/// - Any negative value is written as nine `0xFF` bytes, the "absent" marker that
///   decodes to `-1`.
fn write_varint(value: i64) -> Vec<u8> {
    match value {
        i64::MIN..=-1 => vec![0xFF; 9],
        0..=252 => vec![value as u8],
        253..=0xFFFF => {
            let mut encoded = vec![0xFD];
            encoded.extend_from_slice(&(value as u16).to_le_bytes());
            encoded
        }
        0x1_0000..=0xFFFF_FFFF => {
            let mut encoded = vec![0xFE];
            encoded.extend_from_slice(&(value as u32).to_le_bytes());
            encoded
        }
        _ => {
            let mut encoded = vec![0xFF];
            encoded.extend_from_slice(&(value as u64).to_le_bytes());
            encoded
        }
    }
}
/// Auth [`Transport`] that talks to a server over HTTP.
///
/// - Handshake and certificate messages are POSTed as JSON to
///   `<base_url>/.well-known/auth`.
/// - General messages are replayed as ordinary HTTP requests against `base_url`.
/// - Sending requires the `http` feature.
pub struct SimplifiedFetchTransport {
    /// Server base URL with trailing `/` characters removed.
    base_url: String,
    /// HTTP client used for every request.
    #[cfg(feature = "http")]
    client: reqwest::Client,
    /// Callback for incoming messages, guarded by a std `RwLock` because the trait's
    /// `set_callback` / `clear_callback` are synchronous.
    callback: Arc<StdRwLock<Option<Box<TransportCallback>>>>,
}

impl std::fmt::Debug for SimplifiedFetchTransport {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the base URL is shown.
        let mut repr = formatter.debug_struct("SimplifiedFetchTransport");
        repr.field("base_url", &self.base_url);
        repr.finish()
    }
}
impl SimplifiedFetchTransport {
pub fn new(base_url: &str) -> Self {
Self {
base_url: base_url.trim_end_matches('/').to_string(),
#[cfg(feature = "http")]
client: reqwest::Client::new(),
callback: Arc::new(StdRwLock::new(None)),
}
}
pub fn base_url(&self) -> &str {
&self.base_url
}
#[cfg_attr(not(feature = "http"), allow(dead_code))]
fn auth_url(&self) -> String {
format!("{}/.well-known/auth", self.base_url)
}
pub fn message_to_headers(&self, message: &AuthMessage) -> Vec<(String, String)> {
let mut headers = Vec::new();
headers.push((headers::VERSION.to_string(), message.version.clone()));
headers.push((
headers::IDENTITY_KEY.to_string(),
message.identity_key.to_hex(),
));
headers.push((
headers::MESSAGE_TYPE.to_string(),
message.message_type.as_str().to_string(),
));
if let Some(ref nonce) = message.nonce {
headers.push((headers::NONCE.to_string(), nonce.clone()));
}
if let Some(ref your_nonce) = message.your_nonce {
headers.push((headers::YOUR_NONCE.to_string(), your_nonce.clone()));
}
if let Some(ref sig) = message.signature {
headers.push((
headers::SIGNATURE.to_string(),
crate::primitives::to_base64(sig),
));
}
if let Some(ref req_certs) = message.requested_certificates {
if let Ok(json) = serde_json::to_string(req_certs) {
headers.push((headers::REQUESTED_CERTIFICATES.to_string(), json));
}
}
headers
}
pub fn headers_to_message_fields(
&self,
header_map: &[(String, String)],
) -> Result<(Option<String>, Option<String>, Option<Vec<u8>>)> {
#![allow(clippy::type_complexity)]
let mut nonce = None;
let mut your_nonce = None;
let mut signature = None;
for (key, value) in header_map {
match key.to_lowercase().as_str() {
k if k == headers::NONCE.to_lowercase() => {
nonce = Some(value.clone());
}
k if k == headers::YOUR_NONCE.to_lowercase() => {
your_nonce = Some(value.clone());
}
k if k == headers::SIGNATURE.to_lowercase() => {
signature = Some(crate::primitives::from_base64(value)?);
}
_ => {}
}
}
Ok((nonce, your_nonce, signature))
}
#[cfg_attr(not(feature = "http"), allow(dead_code))]
async fn invoke_callback(&self, message: AuthMessage) -> Result<()> {
let future_opt = {
let guard = self
.callback
.read()
.map_err(|_| Error::AuthError("Failed to acquire callback lock".into()))?;
(*guard).as_ref().map(|cb| cb(message))
};
if let Some(future) = future_opt {
future.await?;
}
Ok(())
}
}
// HTTP implementation of `Transport`. Without the `http` feature, `send` always errors.
#[async_trait]
impl Transport for SimplifiedFetchTransport {
// Sends `message` over HTTP and passes the server's reply to the registered callback.
async fn send(&self, message: &AuthMessage) -> Result<()> {
#[cfg(not(feature = "http"))]
{
// No HTTP client is compiled in, so sending is impossible.
let _ = message;
return Err(Error::AuthError(
"HTTP transport requires the 'http' feature".into(),
));
}
#[cfg(feature = "http")]
{
match message.message_type {
// Handshake and certificate messages: POST the message as JSON to the auth
// endpoint. A non-2xx status is an error. Otherwise the JSON body is parsed as
// the reply AuthMessage and handed to the callback.
MessageType::InitialRequest
| MessageType::InitialResponse
| MessageType::CertificateRequest
| MessageType::CertificateResponse => {
let response = self
.client
.post(self.auth_url())
.json(message)
.send()
.await
.map_err(|e| Error::AuthError(format!("HTTP request failed: {}", e)))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(Error::AuthError(format!(
"Auth endpoint returned {}: {}",
status, body
)));
}
let response_text = response.text().await.map_err(|e| {
Error::AuthError(format!("Failed to read auth response: {}", e))
})?;
let response_message: AuthMessage = serde_json::from_str(&response_text)
.map_err(|e| {
Error::AuthError(format!(
"Failed to parse auth response: {} - body: {}",
e, response_text
))
})?;
self.invoke_callback(response_message).await?;
}
// General messages: the payload is a serialized HttpRequest, which is replayed
// as a real HTTP request against base_url. The auth fields travel as headers.
MessageType::General => {
let payload = message.payload.as_ref().ok_or_else(|| {
Error::AuthError("General message must have payload".into())
})?;
let http_request = HttpRequest::from_payload(payload)?;
let url = format!("{}{}", self.base_url, http_request.url_postfix());
// NOTE(review): a method string reqwest rejects silently falls back to GET;
// consider returning an error instead.
let mut request_builder = match http_request.method.to_uppercase().as_str() {
"GET" => self.client.get(&url),
"POST" => self.client.post(&url),
"PUT" => self.client.put(&url),
"DELETE" => self.client.delete(&url),
"PATCH" => self.client.patch(&url),
"HEAD" => self.client.head(&url),
_ => self.client.request(
reqwest::Method::from_bytes(http_request.method.as_bytes())
.unwrap_or(reqwest::Method::GET),
&url,
),
};
request_builder = request_builder
.header(headers::VERSION, &message.version)
.header(headers::IDENTITY_KEY, message.identity_key.to_hex());
if let Some(ref nonce) = message.nonce {
request_builder = request_builder.header(headers::NONCE, nonce);
}
if let Some(ref your_nonce) = message.your_nonce {
request_builder = request_builder.header(headers::YOUR_NONCE, your_nonce);
}
// The signature is hex-encoded here, while message_to_headers uses base64.
// NOTE(review): confirm the server expects hex.
if let Some(ref sig) = message.signature {
request_builder = request_builder
.header(headers::SIGNATURE, crate::primitives::to_hex(sig));
}
request_builder = request_builder.header(
headers::REQUEST_ID,
crate::primitives::to_base64(&http_request.request_id),
);
// Forward only x-bsv-* (except x-bsv-auth*), authorization and content-type
// headers from the embedded request; all other headers are dropped.
for (key, value) in &http_request.headers {
let lower_key = key.to_lowercase();
if lower_key.starts_with("x-bsv-")
|| lower_key == "authorization"
|| lower_key == "content-type"
{
if !lower_key.starts_with("x-bsv-auth") {
request_builder =
request_builder.header(key.as_str(), value.as_str());
}
}
}
if !http_request.body.is_empty() {
request_builder = request_builder.body(http_request.body.clone());
}
let response = request_builder
.send()
.await
.map_err(|e| Error::AuthError(format!("HTTP request failed: {}", e)))?;
let response_status = response.status().as_u16();
let response_headers = response.headers().clone();
let response_body = response
.bytes()
.await
.map_err(|e| Error::AuthError(format!("Failed to read response: {}", e)))?
.to_vec();
// Prefer the request id the server echoed (base64, must be 32 bytes);
// without that header, reuse the outgoing request's id.
let response_request_id: [u8; 32] =
if let Some(rid) = response_headers.get(headers::REQUEST_ID) {
let rid_str = rid.to_str().unwrap_or_default();
crate::primitives::from_base64(rid_str)?
.try_into()
.map_err(|_| Error::AuthError("Invalid request ID length".into()))?
} else {
http_request.request_id };
// Keep x-bsv-* (except x-bsv-auth*) and authorization response headers,
// lowercased and sorted by name. Non-UTF-8 header values are skipped.
let mut included_headers: Vec<(String, String)> = Vec::new();
for (key, value) in response_headers.iter() {
let key_str = key.as_str().to_lowercase();
if (key_str.starts_with("x-bsv-") || key_str == "authorization")
&& !key_str.starts_with("x-bsv-auth")
{
if let Ok(v) = value.to_str() {
included_headers.push((key_str, v.to_string()));
}
}
}
included_headers.sort_by(|a, b| a.0.cmp(&b.0));
let http_response = HttpResponse {
request_id: response_request_id,
status: response_status,
headers: included_headers,
body: response_body,
};
// Without an identity-key header, the reply is attributed to our own identity key.
// NOTE(review): confirm this fallback is intended.
let response_identity = if let Some(resp_identity_key) = response_headers
.get(headers::IDENTITY_KEY)
.and_then(|v| v.to_str().ok())
{
crate::primitives::PublicKey::from_hex(resp_identity_key)?
} else {
message.identity_key.clone()
};
let mut response_message =
AuthMessage::new(MessageType::General, response_identity);
response_message.payload = Some(http_response.to_payload());
if let Some(nonce) = response_headers.get(headers::NONCE) {
response_message.nonce = nonce.to_str().ok().map(String::from);
}
if let Some(your_nonce) = response_headers.get(headers::YOUR_NONCE) {
response_message.your_nonce = your_nonce.to_str().ok().map(String::from);
}
// Decode as hex first, then base64; an undecodable signature becomes None.
if let Some(sig) = response_headers.get(headers::SIGNATURE) {
let sig_str = sig.to_str().unwrap_or_default();
response_message.signature = crate::primitives::from_hex(sig_str)
.or_else(|_| crate::primitives::from_base64(sig_str))
.ok();
}
// Requested certificates are only read when the server marks the reply as
// a certificateRequest. Unparseable JSON is ignored.
if let Some(msg_type) = response_headers.get(headers::MESSAGE_TYPE) {
if msg_type.to_str().ok() == Some("certificateRequest") {
if let Some(req_certs) =
response_headers.get(headers::REQUESTED_CERTIFICATES)
{
if let Ok(requested) =
serde_json::from_str(req_certs.to_str().unwrap_or("{}"))
{
response_message.requested_certificates = Some(requested);
}
}
}
}
self.invoke_callback(response_message).await?;
}
}
Ok(())
}
}
// Replaces the callback. NOTE(review): a poisoned lock is silently ignored, leaving the
// old callback in place.
fn set_callback(&self, callback: Box<TransportCallback>) {
if let Ok(mut cb) = self.callback.write() {
*cb = Some(callback);
}
}
// Removes the callback. A poisoned lock is silently ignored.
fn clear_callback(&self) {
if let Ok(mut cb) = self.callback.write() {
*cb = None;
}
}
}
/// In-memory [`Transport`] for tests.
///
/// It records every sent message and replays queued responses to the registered callback.
#[derive(Default)]
pub struct MockTransport {
    /// Every message passed to `send`, in order.
    sent_messages: Arc<RwLock<Vec<AuthMessage>>>,
    /// Responses delivered to the callback, one per `send`, oldest first.
    response_queue: Arc<RwLock<Vec<AuthMessage>>>,
    /// Callback receiving replayed and injected messages.
    callback: Arc<RwLock<Option<Box<TransportCallback>>>>,
}

impl std::fmt::Debug for MockTransport {
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every field is printed as a fixed placeholder.
        let mut repr = out.debug_struct("MockTransport");
        for (name, placeholder) in [
            ("sent_messages", "<messages>"),
            ("response_queue", "<queue>"),
            ("callback", "<callback>"),
        ] {
            repr.field(name, &placeholder);
        }
        repr.finish()
    }
}
impl MockTransport {
    /// Creates a mock with no recorded messages, no queued responses and no callback.
    pub fn new() -> Self {
        Default::default()
    }

    /// Queues `message` to be handed to the callback by a later `send`.
    pub async fn queue_response(&self, message: AuthMessage) {
        self.response_queue.write().await.push(message);
    }

    /// Returns a snapshot of every message sent so far.
    pub async fn get_sent_messages(&self) -> Vec<AuthMessage> {
        self.sent_messages.read().await.clone()
    }

    /// Forgets all recorded sent messages.
    pub async fn clear_sent(&self) {
        self.sent_messages.write().await.clear();
    }

    /// Delivers `message` to the registered callback as if it had arrived from the peer.
    ///
    /// Does nothing when no callback is set. The read lock is held while the callback's
    /// future runs.
    pub async fn receive_message(&self, message: AuthMessage) -> Result<()> {
        let slot = self.callback.read().await;
        match &*slot {
            Some(callback) => callback(message).await,
            None => Ok(()),
        }
    }
}
#[async_trait]
impl Transport for MockTransport {
async fn send(&self, message: &AuthMessage) -> Result<()> {
{
let mut sent = self.sent_messages.write().await;
sent.push(message.clone());
}
let response = {
let mut queue = self.response_queue.write().await;
if !queue.is_empty() {
Some(queue.remove(0))
} else {
None
}
};
if let Some(resp) = response {
let callback = self.callback.read().await;
if let Some(ref cb) = *callback {
cb(resp).await?;
}
}
Ok(())
}
fn set_callback(&self, callback: Box<TransportCallback>) {
let callback_store = self.callback.clone();
tokio::spawn(async move {
let mut cb = callback_store.write().await;
*cb = Some(callback);
});
}
fn clear_callback(&self) {
let callback_store = self.callback.clone();
tokio::spawn(async move {
let mut cb = callback_store.write().await;
*cb = None;
});
}
}
// Unit tests for the transports and for the varint / HttpRequest / HttpResponse codecs.
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::types::MessageType;
// A single trailing slash is stripped and the auth endpoint is derived from the base URL.
#[test]
fn test_simplified_fetch_transport_new() {
let transport = SimplifiedFetchTransport::new("https://example.com/");
assert_eq!(transport.base_url(), "https://example.com");
assert_eq!(transport.auth_url(), "https://example.com/.well-known/auth");
}
// Repeated trailing slashes are all stripped.
#[test]
fn test_simplified_fetch_transport_trailing_slash() {
let transport = SimplifiedFetchTransport::new("https://example.com///");
assert_eq!(transport.base_url(), "https://example.com");
}
// A queued response is delivered to the callback when a message is sent.
// The first sleep gives set_callback time to store the callback, which it may do on a
// spawned task. NOTE(review): `send` awaits the callback inline, so the second sleep
// looks unnecessary.
#[tokio::test]
async fn test_mock_transport_send_and_receive() {
let transport = MockTransport::new();
let received = Arc::new(RwLock::new(Vec::new()));
let received_clone = received.clone();
transport.set_callback(Box::new(move |msg| {
let received = received_clone.clone();
Box::pin(async move {
let mut r = received.write().await;
r.push(msg);
Ok(())
})
}));
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let response = AuthMessage::new(
MessageType::InitialResponse,
crate::primitives::PrivateKey::random().public_key(),
);
transport.queue_response(response.clone()).await;
let request = AuthMessage::new(
MessageType::InitialRequest,
crate::primitives::PrivateKey::random().public_key(),
);
transport.send(&request).await.unwrap();
let sent = transport.get_sent_messages().await;
assert_eq!(sent.len(), 1);
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
let recv = received.read().await;
assert_eq!(recv.len(), 1);
}
// Round-trips values on both sides of each prefix boundary (252/253, 65535/65536),
// plus the -1 "absent" marker.
#[test]
fn test_varint_roundtrip() {
let test_values: Vec<i64> = vec![0, 1, 127, 128, 252, 253, 255, 256, 65535, 65536, 1000000];
for value in test_values {
let encoded = write_varint(value);
let (decoded, _bytes_read) = read_varint(&encoded).unwrap();
assert_eq!(decoded, value, "Varint roundtrip failed for {}", value);
}
let encoded = write_varint(-1);
let (decoded, _bytes_read) = read_varint(&encoded).unwrap();
assert_eq!(decoded, -1, "Varint roundtrip failed for -1");
}
// Encoded sizes: 1 byte, 3 bytes (0xFD), 5 bytes (0xFE), and 9 bytes for the -1 marker.
#[test]
fn test_varint_encoding_sizes() {
assert_eq!(write_varint(0).len(), 1);
assert_eq!(write_varint(127).len(), 1);
assert_eq!(write_varint(252).len(), 1);
assert_eq!(write_varint(253).len(), 3);
assert_eq!(write_varint(65535).len(), 3);
assert_eq!(write_varint(65536).len(), 5);
assert_eq!(write_varint(-1).len(), 9);
}
// Decoding an empty slice is an error, not a panic.
#[test]
fn test_varint_empty_error() {
let result = read_varint(&[]);
assert!(result.is_err());
}
// Every HttpRequest field survives to_payload -> from_payload.
#[test]
fn test_http_request_roundtrip() {
let request = HttpRequest {
request_id: [42u8; 32],
method: "POST".to_string(),
path: "/api/v1/users".to_string(),
search: "?foo=bar".to_string(),
headers: vec![
("content-type".to_string(), "application/json".to_string()),
("x-bsv-custom".to_string(), "value".to_string()),
],
body: b"hello world".to_vec(),
};
let payload = request.to_payload();
let decoded = HttpRequest::from_payload(&payload).unwrap();
assert_eq!(decoded.request_id, request.request_id);
assert_eq!(decoded.method, request.method);
assert_eq!(decoded.path, request.path);
assert_eq!(decoded.search, request.search);
assert_eq!(decoded.url_postfix(), "/api/v1/users?foo=bar");
assert_eq!(decoded.headers, request.headers);
assert_eq!(decoded.body, request.body);
}
// Hand-built payload: zero-length method, absent path/search, no headers and an
// absent body. A zero-length method must decode as "GET".
#[test]
fn test_http_request_get_default() {
let mut payload = vec![0u8; 32]; payload.extend(write_varint(0)); payload.extend(write_varint(-1)); payload.extend(write_varint(-1)); payload.extend(write_varint(0)); payload.extend(write_varint(-1));
let request = HttpRequest::from_payload(&payload).unwrap();
assert_eq!(request.method, "GET");
}
// A 10,000-byte body needs the 3-byte (0xFD) length prefix.
#[test]
fn test_http_request_with_large_body() {
let request = HttpRequest {
request_id: [1u8; 32],
method: "PUT".to_string(),
path: "/data".to_string(),
search: String::new(),
headers: vec![],
body: vec![0xAB; 10000], };
let payload = request.to_payload();
let decoded = HttpRequest::from_payload(&payload).unwrap();
assert_eq!(decoded.body.len(), 10000);
assert_eq!(decoded.body[0], 0xAB);
}
// Fewer than 32 bytes cannot hold the request id.
#[test]
fn test_http_request_payload_too_short() {
let payload = vec![0u8; 10]; let result = HttpRequest::from_payload(&payload);
assert!(result.is_err());
}
// Every HttpResponse field survives to_payload -> from_payload.
#[test]
fn test_http_response_roundtrip() {
let response = HttpResponse {
request_id: [99u8; 32],
status: 200,
headers: vec![
("content-type".to_string(), "text/plain".to_string()),
("x-bsv-data".to_string(), "abc".to_string()),
],
body: b"OK".to_vec(),
};
let payload = response.to_payload();
let decoded = HttpResponse::from_payload(&payload).unwrap();
assert_eq!(decoded.request_id, response.request_id);
assert_eq!(decoded.status, response.status);
assert_eq!(decoded.headers, response.headers);
assert_eq!(decoded.body, response.body);
}
// Common status codes, including ones of 253 and above (3-byte varints), round-trip.
#[test]
fn test_http_response_status_codes() {
for status in [200, 201, 400, 401, 403, 404, 500, 503] {
let response = HttpResponse {
request_id: [0u8; 32],
status,
headers: vec![],
body: vec![],
};
let payload = response.to_payload();
let decoded = HttpResponse::from_payload(&payload).unwrap();
assert_eq!(decoded.status, status);
}
}
// An empty body is written as the -1 marker and decodes back to empty.
#[test]
fn test_http_response_empty_body() {
let response = HttpResponse {
request_id: [0u8; 32],
status: 204, headers: vec![],
body: vec![],
};
let payload = response.to_payload();
let decoded = HttpResponse::from_payload(&payload).unwrap();
assert!(decoded.body.is_empty());
}
// message_to_headers emits version, identity key, nonces and signature.
// NOTE(review): assumes AuthMessage::new defaults the version to "0.1".
#[test]
fn test_message_to_headers() {
let transport = SimplifiedFetchTransport::new("https://example.com");
let key = crate::primitives::PrivateKey::random().public_key();
let mut msg = AuthMessage::new(MessageType::General, key.clone());
msg.nonce = Some("test-nonce".to_string());
msg.your_nonce = Some("peer-nonce".to_string());
msg.signature = Some(vec![0x30, 0x44]);
let headers = transport.message_to_headers(&msg);
let headers_map: std::collections::HashMap<_, _> = headers.into_iter().collect();
assert_eq!(headers_map.get(headers::VERSION), Some(&"0.1".to_string()));
assert_eq!(headers_map.get(headers::IDENTITY_KEY), Some(&key.to_hex()));
assert_eq!(
headers_map.get(headers::NONCE),
Some(&"test-nonce".to_string())
);
assert_eq!(
headers_map.get(headers::YOUR_NONCE),
Some(&"peer-nonce".to_string())
);
assert!(headers_map.contains_key(headers::SIGNATURE));
}
// The signature header is supplied as base64, matching what message_to_headers emits.
#[test]
fn test_headers_to_message_fields() {
let transport = SimplifiedFetchTransport::new("https://example.com");
let headers = vec![
(headers::NONCE.to_string(), "test-nonce".to_string()),
(headers::YOUR_NONCE.to_string(), "peer-nonce".to_string()),
(
headers::SIGNATURE.to_string(),
crate::primitives::to_base64(&[0x30, 0x44]),
),
];
let (nonce, your_nonce, signature) = transport.headers_to_message_fields(&headers).unwrap();
assert_eq!(nonce, Some("test-nonce".to_string()));
assert_eq!(your_nonce, Some("peer-nonce".to_string()));
assert_eq!(signature, Some(vec![0x30, 0x44]));
}
}