1use commonware_codec::{DecodeExt, Encode, EncodeSize, Error, Read, ReadExt, ReadRangeExt, Write};
2use commonware_runtime::{Buf, BufMut};
3use std::mem::size_of;
4
/// Maximum number of message bytes accepted when decoding an [`ErrorResponse`]
/// (10 MiB). In this module the bound is applied on the read path only.
pub const MAX_MESSAGE_SIZE: u32 = 10 << 20;
7
8pub mod request_id;
9pub use request_id::RequestId;
10pub mod io;
11pub mod resolver;
12pub mod wire;
13pub use resolver::Resolver;
14
15pub(super) trait Message: Encode + DecodeExt<()> + Sized + Send + Sync + 'static {
17 fn request_id(&self) -> RequestId;
18}
19
/// Category of failure reported in an [`ErrorResponse`].
///
/// On the wire each variant is a single `u8` discriminant, assigned in
/// declaration order starting at 0. `PartialEq`/`Eq` are derived so that
/// callers and tests can compare codes directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// The request was rejected as invalid (discriminant 0).
    InvalidRequest,
    /// A database-level failure (discriminant 1).
    DatabaseError,
    /// A network-level failure (discriminant 2).
    NetworkError,
    /// The requested target is stale (discriminant 3).
    StaleTarget,
    /// The operation timed out (discriminant 4).
    Timeout,
    /// Any other internal failure (discriminant 5).
    InternalError,
}
36
37impl Write for ErrorCode {
38 fn write(&self, buf: &mut impl BufMut) {
39 let discriminant = match self {
40 Self::InvalidRequest => 0u8,
41 Self::DatabaseError => 1u8,
42 Self::NetworkError => 2u8,
43 Self::StaleTarget => 3u8,
44 Self::Timeout => 4u8,
45 Self::InternalError => 5u8,
46 };
47 discriminant.write(buf);
48 }
49}
50
51impl EncodeSize for ErrorCode {
52 fn encode_size(&self) -> usize {
53 size_of::<u8>()
54 }
55}
56
57impl Read for ErrorCode {
58 type Cfg = ();
59
60 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
61 let discriminant = u8::read(buf)?;
62 match discriminant {
63 0 => Ok(Self::InvalidRequest),
64 1 => Ok(Self::DatabaseError),
65 2 => Ok(Self::NetworkError),
66 3 => Ok(Self::StaleTarget),
67 4 => Ok(Self::Timeout),
68 5 => Ok(Self::InternalError),
69 _ => Err(Error::InvalidEnum(discriminant)),
70 }
71 }
72}
73
/// Error reply: the [`RequestId`] it refers to, an [`ErrorCode`], and a
/// free-form text message.
///
/// Fields are encoded in declaration order (`request_id`, `error_code`,
/// `message`).
#[derive(Debug, Clone)]
pub struct ErrorResponse {
    /// Identifier of the request this error refers to.
    pub request_id: RequestId,
    /// Category of the failure.
    pub error_code: ErrorCode,
    /// UTF-8 description of the failure, encoded as a length-prefixed byte vector.
    ///
    /// NOTE: decoding rejects messages longer than [`MAX_MESSAGE_SIZE`] bytes,
    /// but encoding does not check this bound.
    pub message: String,
}
84
85impl Write for ErrorResponse {
86 fn write(&self, buf: &mut impl BufMut) {
87 self.request_id.write(buf);
88 self.error_code.write(buf);
89 self.message.as_bytes().to_vec().write(buf);
90 }
91}
92
93impl EncodeSize for ErrorResponse {
94 fn encode_size(&self) -> usize {
95 self.request_id.encode_size()
96 + self.error_code.encode_size()
97 + self.message.as_bytes().to_vec().encode_size()
98 }
99}
100
101impl Read for ErrorResponse {
102 type Cfg = ();
103
104 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
105 let request_id = RequestId::read_cfg(buf, &())?;
106 let error_code = ErrorCode::read(buf)?;
107 let message_bytes = Vec::<u8>::read_range(buf, 0..=MAX_MESSAGE_SIZE as usize)?;
108 let message = String::from_utf8(message_bytes)
109 .map_err(|_| Error::Invalid("ErrorResponse", "invalid UTF-8 in message"))?;
110 Ok(Self {
111 request_id,
112 error_code,
113 message,
114 })
115 }
116}
117
#[cfg(test)]
mod tests {
    use crate::{
        keyless_compact,
        net::{
            request_id::Generator, wire, wire::GetOperationsRequest, ErrorCode, ErrorResponse,
        },
    };
    use commonware_codec::{DecodeExt as _, Encode as _, EncodeSize as _};
    use commonware_cryptography::sha256;
    use commonware_storage::{mmr::Location, qmdb::sync::compact::State};
    use commonware_utils::NZU64;
    use rstest::rstest;
    use std::mem::discriminant;

    /// Every `ErrorCode` variant survives an encode/decode round trip, and its
    /// encoding is exactly `encode_size` bytes long.
    #[rstest]
    #[case(ErrorCode::InvalidRequest)]
    #[case(ErrorCode::DatabaseError)]
    #[case(ErrorCode::NetworkError)]
    #[case(ErrorCode::StaleTarget)]
    #[case(ErrorCode::Timeout)]
    #[case(ErrorCode::InternalError)]
    fn test_error_code_roundtrip_serialization(#[case] error_code: ErrorCode) {
        let encoded = error_code.encode().to_vec();
        assert_eq!(encoded.len(), error_code.encode_size());

        let decoded = ErrorCode::decode(&encoded[..]).expect("Failed to decode ErrorCode");

        // Compare variants by discriminant so this test does not depend on
        // `ErrorCode: PartialEq`.
        assert_eq!(
            discriminant(&error_code),
            discriminant(&decoded),
            "ErrorCode roundtrip failed: {error_code:?} != {decoded:?}"
        );
    }

    /// A tag byte outside the known range must be rejected, not mapped to a variant.
    #[test]
    fn test_error_code_rejects_unknown_discriminant() {
        assert!(matches!(
            ErrorCode::decode(&[6u8][..]),
            Err(commonware_codec::Error::InvalidEnum(6))
        ));
    }

    /// An `ErrorResponse` round-trips field-for-field, and its encoding matches
    /// `encode_size`.
    #[test]
    fn test_error_response_roundtrip() {
        let response = ErrorResponse {
            request_id: Generator::new().next(),
            error_code: ErrorCode::StaleTarget,
            message: "target moved: ünïcode".to_string(),
        };

        let encoded = response.encode().to_vec();
        assert_eq!(encoded.len(), response.encode_size());

        let decoded =
            ErrorResponse::decode(&encoded[..]).expect("failed to decode ErrorResponse");
        assert_eq!(decoded.request_id, response.request_id);
        assert_eq!(
            discriminant(&decoded.error_code),
            discriminant(&response.error_code)
        );
        assert_eq!(decoded.message, response.message);
    }

    /// Message bytes that are not valid UTF-8 must fail decoding.
    #[test]
    fn test_error_response_rejects_invalid_utf8() {
        let response = ErrorResponse {
            request_id: Generator::new().next(),
            error_code: ErrorCode::InvalidRequest,
            message: "ab".to_string(),
        };
        let mut encoded = response.encode().to_vec();
        // The message is the last field, so its two bytes are the trailing bytes.
        let len = encoded.len();
        encoded[len - 2..].copy_from_slice(&[0xff, 0xff]);

        assert!(matches!(
            ErrorResponse::decode(&encoded[..]),
            Err(commonware_codec::Error::Invalid("ErrorResponse", _))
        ));
    }

    /// `start_loc` must be strictly below `op_count`; equal or greater is rejected.
    #[test]
    fn test_get_operations_request_validation() {
        let requester = Generator::new();

        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100),
            start_loc: Location::new(10),
            max_ops: NZU64!(50),
            include_pinned_nodes: false,
        };
        assert!(request.validate().is_ok());

        // start_loc == op_count is rejected.
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100),
            start_loc: Location::new(100),
            max_ops: NZU64!(50),
            include_pinned_nodes: false,
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));

        // start_loc > op_count is rejected.
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100),
            start_loc: Location::new(150),
            max_ops: NZU64!(50),
            include_pinned_nodes: false,
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));
    }

    /// A compact-state response survives an encode/decode round trip through
    /// the `wire::Message` envelope.
    #[test]
    fn test_get_compact_state_response_roundtrip() {
        let request_id = Generator::new().next();
        let digest_a = sha256::Digest::from([7; 32]);
        let digest_b = sha256::Digest::from([8; 32]);
        let digest_c = sha256::Digest::from([10; 32]);
        let message = wire::Message::GetCompactStateResponse(wire::GetCompactStateResponse {
            request_id,
            state: State {
                leaf_count: Location::new(11),
                pinned_nodes: vec![digest_a, digest_b],
                last_commit_op: keyless_compact::Operation::Commit(None, Location::new(0)),
                last_commit_proof: commonware_storage::mmr::Proof {
                    leaves: Location::new(11),
                    inactive_peaks: 0,
                    digests: vec![digest_c],
                },
            },
        });

        let encoded = message.encode().to_vec();
        let decoded =
            wire::Message::<keyless_compact::Operation, sha256::Digest>::decode(&encoded[..])
                .expect("failed to decode compact response");

        match decoded {
            wire::Message::GetCompactStateResponse(response) => {
                assert_eq!(response.request_id, request_id);
                assert_eq!(response.state.leaf_count, Location::new(11));
                assert_eq!(response.state.pinned_nodes.len(), 2);
            }
            other => panic!("unexpected message variant: {other:?}"),
        }
    }
}
232}