use commonware_codec::{DecodeExt, Encode, EncodeSize, Error, Read, ReadExt, ReadRangeExt, Write};
use commonware_runtime::{Buf, BufMut};
use std::mem::size_of;
/// Upper bound, in bytes, on the `message` field accepted when decoding an
/// [`ErrorResponse`] (10 MiB).
pub const MAX_MESSAGE_SIZE: u32 = 10 * 1024 * 1024;
pub mod request_id;
pub use request_id::RequestId;
pub mod io;
pub mod resolver;
pub mod wire;
pub use resolver::Resolver;
/// A wire message handled by the networking layer.
///
/// Implementors must be encodable, decodable without any configuration
/// (`DecodeExt<()>`), and safe to move/share across threads (`Send + Sync + 'static`).
pub(super) trait Message: Encode + DecodeExt<()> + Sized + Send + Sync + 'static {
    /// Returns the [`RequestId`] carried by this message
    /// (presumably used to correlate responses with requests — confirm with callers).
    fn request_id(&self) -> RequestId;
}
/// Category of failure reported in an [`ErrorResponse`].
///
/// Each variant is encoded on the wire as a single `u8` tag (see the `Write`
/// and `Read` impls); the tag for each variant is noted below.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// Wire tag `0`.
    InvalidRequest,
    /// Wire tag `1`.
    DatabaseError,
    /// Wire tag `2`.
    NetworkError,
    /// Wire tag `3`.
    Timeout,
    /// Wire tag `4`.
    InternalError,
}
impl Write for ErrorCode {
fn write(&self, buf: &mut impl BufMut) {
let discriminant = match self {
Self::InvalidRequest => 0u8,
Self::DatabaseError => 1u8,
Self::NetworkError => 2u8,
Self::Timeout => 3u8,
Self::InternalError => 4u8,
};
discriminant.write(buf);
}
}
impl EncodeSize for ErrorCode {
    /// Every variant is a single `u8` tag, so the encoded size is constant.
    fn encode_size(&self) -> usize {
        const TAG_LEN: usize = size_of::<u8>();
        TAG_LEN
    }
}
impl Read for ErrorCode {
type Cfg = ();
fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
let discriminant = u8::read(buf)?;
match discriminant {
0 => Ok(Self::InvalidRequest),
1 => Ok(Self::DatabaseError),
2 => Ok(Self::NetworkError),
3 => Ok(Self::Timeout),
4 => Ok(Self::InternalError),
_ => Err(Error::InvalidEnum(discriminant)),
}
}
}
/// An error reply, encoded as request id, then error code, then message.
#[derive(Debug, Clone)]
pub struct ErrorResponse {
    /// Request identifier carried with the error (presumably that of the failed request).
    pub request_id: RequestId,
    /// Category of the failure.
    pub error_code: ErrorCode,
    /// Human-readable description; must be valid UTF-8 and at most
    /// `MAX_MESSAGE_SIZE` bytes to be accepted by the decoder.
    pub message: String,
}
impl Write for ErrorResponse {
    /// Serializes the fields in order: request id, error code, then the message
    /// bytes as a `Vec<u8>`. The `Vec<u8>` form is what the decoder reads back
    /// via `read_range`.
    // NOTE(review): the message length is not checked against MAX_MESSAGE_SIZE
    // here, while the decoder rejects longer messages.
    fn write(&self, buf: &mut impl BufMut) {
        let Self {
            request_id,
            error_code,
            message,
        } = self;
        request_id.write(buf);
        error_code.write(buf);
        let message_bytes: Vec<u8> = message.as_bytes().to_vec();
        message_bytes.write(buf);
    }
}
impl EncodeSize for ErrorResponse {
    /// Sum of the encoded sizes of the three fields, with the message sized as
    /// the same `Vec<u8>` that `write` emits.
    fn encode_size(&self) -> usize {
        let Self {
            request_id,
            error_code,
            message,
        } = self;
        let message_size = message.as_bytes().to_vec().encode_size();
        request_id.encode_size() + error_code.encode_size() + message_size
    }
}
impl Read for ErrorResponse {
type Cfg = ();
fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
let request_id = RequestId::read_cfg(buf, &())?;
let error_code = ErrorCode::read(buf)?;
let message_bytes = Vec::<u8>::read_range(buf, 0..=MAX_MESSAGE_SIZE as usize)?;
let message = String::from_utf8(message_bytes)
.map_err(|_| Error::Invalid("ErrorResponse", "invalid UTF-8 in message"))?;
Ok(Self {
request_id,
error_code,
message,
})
}
}
#[cfg(test)]
mod tests {
    use crate::net::{
        request_id::Generator, wire::GetOperationsRequest, ErrorCode, ErrorResponse,
    };
    use commonware_codec::{DecodeExt as _, Encode as _, Error as CodecError};
    use commonware_storage::mmr::Location;
    use commonware_utils::NZU64;
    use rstest::rstest;

    #[rstest]
    #[case(ErrorCode::InvalidRequest)]
    #[case(ErrorCode::DatabaseError)]
    #[case(ErrorCode::NetworkError)]
    #[case(ErrorCode::Timeout)]
    #[case(ErrorCode::InternalError)]
    fn test_error_code_roundtrip_serialization(#[case] error_code: ErrorCode) {
        let encoded = error_code.encode().to_vec();
        let decoded = ErrorCode::decode(&encoded[..]).expect("Failed to decode ErrorCode");
        // Compare variants without relying on `PartialEq` being derived.
        assert_eq!(
            std::mem::discriminant(&error_code),
            std::mem::discriminant(&decoded),
            "ErrorCode roundtrip failed: {error_code:?} != {decoded:?}"
        );
    }

    #[test]
    fn test_error_code_rejects_unknown_discriminant() {
        // Tags 0..=4 are assigned; 5 is the first unassigned value.
        let result = ErrorCode::decode(&[5u8][..]);
        assert!(matches!(result, Err(CodecError::InvalidEnum(5))));
    }

    #[rstest]
    #[case(ErrorCode::InvalidRequest, "")]
    #[case(ErrorCode::Timeout, "request timed out")]
    #[case(ErrorCode::InternalError, "ünïcödé ✓")]
    fn test_error_response_roundtrip_serialization(
        #[case] error_code: ErrorCode,
        #[case] message: &str,
    ) {
        let response = ErrorResponse {
            request_id: Generator::new().next(),
            error_code,
            message: message.to_string(),
        };
        let encoded = response.encode().to_vec();
        let decoded =
            ErrorResponse::decode(&encoded[..]).expect("Failed to decode ErrorResponse");
        assert_eq!(decoded.message, response.message);
        // Re-encoding must reproduce the exact bytes (covers request id and code).
        assert_eq!(decoded.encode().to_vec(), encoded);
    }

    #[test]
    fn test_error_response_rejects_invalid_utf8() {
        let mut encoded = Generator::new().next().encode().to_vec();
        encoded.extend_from_slice(&ErrorCode::InternalError.encode());
        // 0xFF can never appear in well-formed UTF-8.
        encoded.extend_from_slice(&Vec::<u8>::from([0xFF, 0xFE]).encode());
        let result = ErrorResponse::decode(&encoded[..]);
        assert!(matches!(
            result,
            Err(CodecError::Invalid("ErrorResponse", _))
        ));
    }

    #[test]
    fn test_get_operations_request_validation() {
        let requester = Generator::new();
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100),
            start_loc: Location::new(10),
            max_ops: NZU64!(50),
            include_pinned_nodes: false,
        };
        assert!(request.validate().is_ok());
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100),
            start_loc: Location::new(100),
            max_ops: NZU64!(50),
            include_pinned_nodes: false,
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100),
            start_loc: Location::new(150),
            max_ops: NZU64!(50),
            include_pinned_nodes: false,
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));
    }
}