commonware_sync/net/
mod.rs1use bytes::{Buf, BufMut};
2use commonware_codec::{DecodeExt, Encode, EncodeSize, Error, Read, ReadExt, ReadRangeExt, Write};
3use std::mem::size_of;
4
/// Upper bound, in bytes, on variable-length message payloads: 10 MiB (10 × 2^20).
///
/// In this module it caps the byte length of `ErrorResponse::message` during decoding.
pub const MAX_MESSAGE_SIZE: u32 = 10 << 20;
7
8pub mod request_id;
9pub use request_id::RequestId;
10pub mod io;
11pub mod resolver;
12pub mod wire;
13pub use resolver::Resolver;
14
/// A wire message tied to a [`RequestId`].
///
/// Implementors can be encoded, decoded with the unit (`()`) configuration, and shared
/// across threads (`Send + Sync + 'static`).
pub(super) trait Message: Encode + DecodeExt<()> + Sized + Send + Sync + 'static {
    /// Returns the [`RequestId`] associated with this message.
    fn request_id(&self) -> RequestId;
}
19
/// Category of an error carried in an `ErrorResponse`.
///
/// On the wire each variant is a single `u8` tag assigned in declaration order
/// (`InvalidRequest` = 0 through `InternalError` = 4). Decoding any other tag fails with
/// `Error::InvalidEnum`.
///
/// `PartialEq`/`Eq` are derived so codes can be compared directly (e.g. `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// The request was invalid.
    InvalidRequest,
    /// A database error occurred.
    DatabaseError,
    /// A network error occurred.
    NetworkError,
    /// The operation timed out.
    Timeout,
    /// An internal error occurred.
    InternalError,
}
34
35impl Write for ErrorCode {
36 fn write(&self, buf: &mut impl BufMut) {
37 let discriminant = match self {
38 Self::InvalidRequest => 0u8,
39 Self::DatabaseError => 1u8,
40 Self::NetworkError => 2u8,
41 Self::Timeout => 3u8,
42 Self::InternalError => 4u8,
43 };
44 discriminant.write(buf);
45 }
46}
47
48impl EncodeSize for ErrorCode {
49 fn encode_size(&self) -> usize {
50 size_of::<u8>()
51 }
52}
53
54impl Read for ErrorCode {
55 type Cfg = ();
56
57 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
58 let discriminant = u8::read(buf)?;
59 match discriminant {
60 0 => Ok(Self::InvalidRequest),
61 1 => Ok(Self::DatabaseError),
62 2 => Ok(Self::NetworkError),
63 3 => Ok(Self::Timeout),
64 4 => Ok(Self::InternalError),
65 _ => Err(Error::InvalidEnum(discriminant)),
66 }
67 }
68}
69
/// An error reply tied to a specific request.
///
/// Wire layout, in order: `request_id`, `error_code` (one tag byte), then `message`
/// encoded as a `Vec<u8>`.
#[derive(Debug, Clone)]
pub struct ErrorResponse {
    /// The [`RequestId`] this error refers to.
    pub request_id: RequestId,
    /// Category of the error.
    pub error_code: ErrorCode,
    /// Free-form description of the error. When decoding, messages longer than
    /// [`MAX_MESSAGE_SIZE`] bytes or not valid UTF-8 are rejected.
    pub message: String,
}
80
81impl Write for ErrorResponse {
82 fn write(&self, buf: &mut impl BufMut) {
83 self.request_id.write(buf);
84 self.error_code.write(buf);
85 self.message.as_bytes().to_vec().write(buf);
86 }
87}
88
89impl EncodeSize for ErrorResponse {
90 fn encode_size(&self) -> usize {
91 self.request_id.encode_size()
92 + self.error_code.encode_size()
93 + self.message.as_bytes().to_vec().encode_size()
94 }
95}
96
97impl Read for ErrorResponse {
98 type Cfg = ();
99
100 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
101 let request_id = RequestId::read_cfg(buf, &())?;
102 let error_code = ErrorCode::read(buf)?;
103 let message_bytes = Vec::<u8>::read_range(buf, 0..=MAX_MESSAGE_SIZE as usize)?;
104 let message = String::from_utf8(message_bytes)
105 .map_err(|_| Error::Invalid("ErrorResponse", "invalid UTF-8 in message"))?;
106 Ok(Self {
107 request_id,
108 error_code,
109 message,
110 })
111 }
112}
113
#[cfg(test)]
mod tests {
    use crate::net::{
        request_id::Generator, wire::GetOperationsRequest, ErrorCode, ErrorResponse,
    };
    use commonware_codec::{DecodeExt as _, Encode as _, EncodeSize as _};
    use commonware_storage::mmr::Location;
    use commonware_utils::NZU64;
    use rstest::rstest;
    use std::mem::discriminant;

    /// Every `ErrorCode` variant survives an encode/decode roundtrip, and its encoding is
    /// exactly `encode_size()` bytes long.
    #[rstest]
    #[case(ErrorCode::InvalidRequest)]
    #[case(ErrorCode::DatabaseError)]
    #[case(ErrorCode::NetworkError)]
    #[case(ErrorCode::Timeout)]
    #[case(ErrorCode::InternalError)]
    fn test_error_code_roundtrip_serialization(#[case] error_code: ErrorCode) {
        let encoded = error_code.encode().to_vec();
        assert_eq!(encoded.len(), error_code.encode_size());

        let decoded = ErrorCode::decode(&encoded[..]).expect("Failed to decode ErrorCode");

        // Compare variants by discriminant. A hand-written match over every
        // (expected, actual) pair would spuriously fail as soon as a variant is added.
        assert_eq!(
            discriminant(&error_code),
            discriminant(&decoded),
            "ErrorCode roundtrip failed: {error_code:?} != {decoded:?}"
        );
    }

    /// Tags outside the assigned range are rejected rather than mapped to a variant.
    #[test]
    fn test_error_code_rejects_unknown_discriminant() {
        for tag in [5u8, u8::MAX] {
            let result = ErrorCode::decode(&[tag][..]);
            assert!(
                matches!(result, Err(commonware_codec::Error::InvalidEnum(t)) if t == tag),
                "tag {tag} should be rejected"
            );
        }
    }

    /// An `ErrorResponse` (including a multi-byte UTF-8 message) survives a roundtrip.
    #[test]
    fn test_error_response_roundtrip_serialization() {
        let requester = Generator::new();
        let response = ErrorResponse {
            request_id: requester.next(),
            error_code: ErrorCode::DatabaseError,
            message: "database unavailable: ünïcödé".to_string(),
        };
        let encoded = response.encode().to_vec();
        assert_eq!(encoded.len(), response.encode_size());

        let decoded =
            ErrorResponse::decode(&encoded[..]).expect("Failed to decode ErrorResponse");
        assert_eq!(decoded.message, response.message);
        assert_eq!(
            discriminant(&decoded.error_code),
            discriminant(&response.error_code)
        );
        // Re-encoding reproduces the original bytes, which also covers `request_id`.
        assert_eq!(decoded.encode().to_vec(), encoded);
    }

    /// A message whose bytes are not valid UTF-8 is rejected on decode.
    #[test]
    fn test_error_response_rejects_invalid_utf8() {
        let requester = Generator::new();
        let response = ErrorResponse {
            request_id: requester.next(),
            error_code: ErrorCode::Timeout,
            message: "x".to_string(),
        };
        let mut encoded = response.encode().to_vec();
        // The message is written last, so the final byte is its single character.
        // 0xFF never appears in well-formed UTF-8.
        *encoded.last_mut().expect("encoding is non-empty") = 0xFF;
        assert!(matches!(
            ErrorResponse::decode(&encoded[..]),
            Err(commonware_codec::Error::Invalid("ErrorResponse", _))
        ));
    }

    /// `start_loc` below `op_count` is accepted. `start_loc` equal to or beyond
    /// `op_count` is rejected as `InvalidRequest`.
    #[test]
    fn test_get_operations_request_validation() {
        let requester = Generator::new();

        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100).unwrap(),
            start_loc: Location::new(10).unwrap(),
            max_ops: NZU64!(50),
        };
        assert!(request.validate().is_ok());

        // start_loc == op_count
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100).unwrap(),
            start_loc: Location::new(100).unwrap(),
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));

        // start_loc > op_count
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100).unwrap(),
            start_loc: Location::new(150).unwrap(),
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));
    }
}