1use crate::Operation;
13use bytes::{Buf, BufMut};
14use commonware_codec::{
15 EncodeSize, Error as CodecError, RangeCfg, Read, ReadExt, ReadRangeExt as _, Write,
16};
17use commonware_cryptography::sha256::Digest;
18use commonware_storage::mmr::verification::Proof;
19use std::num::NonZeroU64;
20use thiserror::Error;
21
/// Maximum message size in bytes (10 MiB).
///
/// Within this module it caps the byte length of an [`ErrorResponse`] message accepted
/// when decoding.
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;

/// Limit passed as the decode config for a [`GetOperationsResponse`]'s proof, and reused as
/// the maximum number of operations accepted in that response.
const MAX_DIGESTS: usize = 10_000;
27
28#[derive(Debug, Clone)]
30pub enum Message {
31 GetOperationsRequest(GetOperationsRequest),
33 GetOperationsResponse(GetOperationsResponse),
35 GetServerMetadataRequest,
37 GetServerMetadataResponse(GetServerMetadataResponse),
39 Error(ErrorResponse),
45}
46
/// Request for a contiguous run of operations.
///
/// Encoded as three `u64` values in field order: `size`, `start_loc`, `max_ops`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GetOperationsRequest {
    /// Size the request is made against. `validate` requires `start_loc < size`.
    /// (Presumably the operation count of the target state — TODO confirm.)
    pub size: u64,
    /// Location of the first operation requested.
    pub start_loc: u64,
    /// Maximum number of operations to return; zero is unrepresentable.
    pub max_ops: NonZeroU64,
}
57
58#[derive(Debug, Clone)]
60pub struct GetOperationsResponse {
61 pub proof: Proof<Digest>,
63 pub operations: Vec<Operation>,
65}
66
67#[derive(Debug, Clone)]
69pub struct GetServerMetadataResponse {
70 pub target_hash: Digest,
72 pub oldest_retained_loc: u64,
74 pub latest_op_loc: u64,
76}
77
78#[derive(Debug, Clone)]
80pub struct ErrorResponse {
81 pub error_code: ErrorCode,
83 pub message: String,
85}
86
/// Category of an [`ErrorResponse`].
///
/// Encoded as a single tag byte: `0` = `InvalidRequest`, `1` = `DatabaseError`,
/// `2` = `NetworkError`, `3` = `Timeout`, `4` = `InternalError`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// The request was rejected as invalid.
    InvalidRequest,
    /// The database layer reported an error.
    DatabaseError,
    /// A network-level failure.
    NetworkError,
    /// The operation timed out.
    Timeout,
    /// Any other internal failure.
    InternalError,
}
101
102#[derive(Debug, Error)]
104pub enum ProtocolError {
105 #[error("Invalid request: {message}")]
106 InvalidRequest { message: String },
107
108 #[error("Database error: {0}")]
109 DatabaseError(#[from] commonware_storage::adb::Error),
110
111 #[error("Network error: {0}")]
112 NetworkError(String),
113}
114
115impl Write for Message {
116 fn write(&self, buf: &mut impl BufMut) {
117 match self {
118 Message::GetOperationsRequest(req) => {
119 0u8.write(buf);
120 req.write(buf);
121 }
122 Message::GetOperationsResponse(resp) => {
123 1u8.write(buf);
124 resp.write(buf);
125 }
126 Message::GetServerMetadataRequest => {
127 2u8.write(buf);
128 }
129 Message::GetServerMetadataResponse(resp) => {
130 3u8.write(buf);
131 resp.write(buf);
132 }
133 Message::Error(err) => {
134 4u8.write(buf);
135 err.write(buf);
136 }
137 }
138 }
139}
140
141impl EncodeSize for Message {
142 fn encode_size(&self) -> usize {
143 1 + match self {
145 Message::GetOperationsRequest(req) => req.encode_size(),
146 Message::GetOperationsResponse(resp) => resp.encode_size(),
147 Message::GetServerMetadataRequest => 0,
148 Message::GetServerMetadataResponse(resp) => resp.encode_size(),
149 Message::Error(err) => err.encode_size(),
150 }
151 }
152}
153
154impl Read for Message {
155 type Cfg = ();
156
157 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
158 let discriminant = u8::read(buf)?;
159 match discriminant {
160 0 => Ok(Message::GetOperationsRequest(GetOperationsRequest::read(
161 buf,
162 )?)),
163 1 => Ok(Message::GetOperationsResponse(GetOperationsResponse::read(
164 buf,
165 )?)),
166 2 => Ok(Message::GetServerMetadataRequest),
167 3 => Ok(Message::GetServerMetadataResponse(
168 GetServerMetadataResponse::read(buf)?,
169 )),
170 4 => Ok(Message::Error(ErrorResponse::read(buf)?)),
171 _ => Err(CodecError::InvalidEnum(discriminant)),
172 }
173 }
174}
175
176impl Write for GetOperationsRequest {
177 fn write(&self, buf: &mut impl BufMut) {
178 self.size.write(buf);
179 self.start_loc.write(buf);
180 self.max_ops.get().write(buf);
181 }
182}
183
184impl EncodeSize for GetOperationsRequest {
185 fn encode_size(&self) -> usize {
186 self.size.encode_size() + self.start_loc.encode_size() + self.max_ops.get().encode_size()
187 }
188}
189
190impl Read for GetOperationsRequest {
191 type Cfg = ();
192
193 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
194 let size = u64::read(buf)?;
195 let start_loc = u64::read(buf)?;
196 let max_ops_raw = u64::read(buf)?;
197 let max_ops = NonZeroU64::new(max_ops_raw)
198 .ok_or_else(|| CodecError::Invalid("GetOperationsRequest", "max_ops cannot be zero"))?;
199 Ok(Self {
200 size,
201 start_loc,
202 max_ops,
203 })
204 }
205}
206
207impl Write for GetOperationsResponse {
208 fn write(&self, buf: &mut impl BufMut) {
209 self.proof.write(buf);
210 self.operations.write(buf);
211 }
212}
213
214impl EncodeSize for GetOperationsResponse {
215 fn encode_size(&self) -> usize {
216 self.proof.encode_size() + self.operations.encode_size()
217 }
218}
219
220impl Read for GetOperationsResponse {
221 type Cfg = ();
222
223 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
224 let proof = Proof::read_cfg(buf, &MAX_DIGESTS)?;
225 let operations = {
226 let range_cfg = RangeCfg::from(0..=MAX_DIGESTS);
227 Vec::<Operation>::read_cfg(buf, &(range_cfg, ()))?
228 };
229
230 Ok(Self { proof, operations })
231 }
232}
233
234impl Write for GetServerMetadataResponse {
235 fn write(&self, buf: &mut impl BufMut) {
236 self.target_hash.write(buf);
237 self.oldest_retained_loc.write(buf);
238 self.latest_op_loc.write(buf);
239 }
240}
241
242impl EncodeSize for GetServerMetadataResponse {
243 fn encode_size(&self) -> usize {
244 self.target_hash.encode_size()
245 + self.oldest_retained_loc.encode_size()
246 + self.latest_op_loc.encode_size()
247 }
248}
249
250impl Read for GetServerMetadataResponse {
251 type Cfg = ();
252
253 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
254 let target_hash = Digest::read(buf)?;
255 let oldest_retained_loc = u64::read(buf)?;
256 let latest_op_loc = u64::read(buf)?;
257 Ok(Self {
258 target_hash,
259 oldest_retained_loc,
260 latest_op_loc,
261 })
262 }
263}
264
265impl Write for ErrorResponse {
266 fn write(&self, buf: &mut impl BufMut) {
267 self.error_code.write(buf);
268 self.message.as_bytes().to_vec().write(buf);
269 }
270}
271
272impl EncodeSize for ErrorResponse {
273 fn encode_size(&self) -> usize {
274 self.error_code.encode_size() + self.message.as_bytes().to_vec().encode_size()
275 }
276}
277
278impl Read for ErrorResponse {
279 type Cfg = ();
280
281 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
282 let error_code = ErrorCode::read(buf)?;
283 let message_bytes = Vec::<u8>::read_range(buf, 0..=MAX_MESSAGE_SIZE)?;
285 let message = String::from_utf8(message_bytes)
286 .map_err(|_| CodecError::Invalid("ErrorResponse", "invalid UTF-8 in message"))?;
287 Ok(Self {
288 error_code,
289 message,
290 })
291 }
292}
293
294impl Write for ErrorCode {
295 fn write(&self, buf: &mut impl BufMut) {
296 let discriminant = match self {
297 ErrorCode::InvalidRequest => 0u8,
298 ErrorCode::DatabaseError => 1u8,
299 ErrorCode::NetworkError => 2u8,
300 ErrorCode::Timeout => 3u8,
301 ErrorCode::InternalError => 4u8,
302 };
303 discriminant.write(buf);
304 }
305}
306
307impl EncodeSize for ErrorCode {
308 fn encode_size(&self) -> usize {
309 size_of::<u8>()
310 }
311}
312
313impl Read for ErrorCode {
314 type Cfg = ();
315
316 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
317 let discriminant = u8::read(buf)?;
318 match discriminant {
319 0 => Ok(ErrorCode::InvalidRequest),
320 1 => Ok(ErrorCode::DatabaseError),
321 2 => Ok(ErrorCode::NetworkError),
322 3 => Ok(ErrorCode::Timeout),
323 4 => Ok(ErrorCode::InternalError),
324 _ => Err(CodecError::InvalidEnum(discriminant)),
325 }
326 }
327}
328
329impl From<ProtocolError> for ErrorResponse {
330 fn from(error: ProtocolError) -> Self {
331 let (error_code, message) = match error {
332 ProtocolError::InvalidRequest { message } => (ErrorCode::InvalidRequest, message),
333 ProtocolError::DatabaseError(e) => (ErrorCode::DatabaseError, e.to_string()),
334 ProtocolError::NetworkError(e) => (ErrorCode::NetworkError, e),
335 };
336
337 ErrorResponse {
338 error_code,
339 message,
340 }
341 }
342}
343
344impl GetOperationsRequest {
345 pub fn validate(&self) -> Result<(), ProtocolError> {
347 if self.start_loc >= self.size {
348 return Err(ProtocolError::InvalidRequest {
349 message: format!("start_loc >= size ({}) >= ({})", self.start_loc, self.size),
350 });
351 }
352
353 if self.max_ops.get() == 0 {
354 return Err(ProtocolError::InvalidRequest {
355 message: "max_ops cannot be zero".to_string(),
356 });
357 }
358
359 Ok(())
360 }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_utils::NZU64;

    /// Encodes `value` into a fresh buffer and checks `encode_size` matches the bytes written.
    fn encode_checked<T: Write + EncodeSize>(value: &T) -> Vec<u8> {
        let mut buf = Vec::with_capacity(value.encode_size());
        value.write(&mut buf);
        assert_eq!(buf.len(), value.encode_size());
        buf
    }

    #[test]
    fn test_get_operations_request_validation() {
        // start_loc < size is accepted.
        let request = GetOperationsRequest {
            size: 100,
            start_loc: 10,
            max_ops: NZU64!(50),
        };
        assert!(request.validate().is_ok());

        // start_loc == size is rejected.
        let request = GetOperationsRequest {
            size: 100,
            start_loc: 100,
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(ProtocolError::InvalidRequest { .. })
        ));

        // start_loc > size is rejected.
        let request = GetOperationsRequest {
            size: 100,
            start_loc: 150,
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(ProtocolError::InvalidRequest { .. })
        ));
    }

    #[test]
    fn test_get_operations_request_roundtrip() {
        let request = GetOperationsRequest {
            size: 100,
            start_loc: 10,
            max_ops: NZU64!(50),
        };
        let bytes = encode_checked(&Message::GetOperationsRequest(request));
        match Message::read(&mut bytes.as_slice()).unwrap() {
            Message::GetOperationsRequest(decoded) => {
                assert_eq!(decoded.size, 100);
                assert_eq!(decoded.start_loc, 10);
                assert_eq!(decoded.max_ops.get(), 50);
            }
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn test_get_operations_request_rejects_zero_max_ops() {
        let mut bytes: Vec<u8> = Vec::new();
        100u64.write(&mut bytes);
        10u64.write(&mut bytes);
        0u64.write(&mut bytes);
        assert!(matches!(
            GetOperationsRequest::read(&mut bytes.as_slice()),
            Err(CodecError::Invalid(..))
        ));
    }

    #[test]
    fn test_server_metadata_request_roundtrip() {
        let bytes = encode_checked(&Message::GetServerMetadataRequest);
        assert_eq!(bytes, vec![2u8]);
        assert!(matches!(
            Message::read(&mut bytes.as_slice()),
            Ok(Message::GetServerMetadataRequest)
        ));
    }

    #[test]
    fn test_message_rejects_unknown_tag() {
        let mut buf: &[u8] = &[5];
        assert!(matches!(
            Message::read(&mut buf),
            Err(CodecError::InvalidEnum(5))
        ));
    }

    #[test]
    fn test_error_response_roundtrip() {
        let response = ErrorResponse::from(ProtocolError::InvalidRequest {
            message: "bad range".to_string(),
        });
        let bytes = encode_checked(&Message::Error(response));
        match Message::read(&mut bytes.as_slice()).unwrap() {
            Message::Error(decoded) => {
                assert!(matches!(decoded.error_code, ErrorCode::InvalidRequest));
                assert_eq!(decoded.message, "bad range");
            }
            other => panic!("unexpected message: {other:?}"),
        }
    }

    #[test]
    fn test_error_response_rejects_invalid_utf8() {
        let mut bytes: Vec<u8> = Vec::new();
        ErrorCode::Timeout.write(&mut bytes);
        vec![0xffu8].write(&mut bytes);
        assert!(matches!(
            ErrorResponse::read(&mut bytes.as_slice()),
            Err(CodecError::Invalid(..))
        ));
    }

    #[test]
    fn test_error_code_roundtrip_all_variants() {
        for code in [
            ErrorCode::InvalidRequest,
            ErrorCode::DatabaseError,
            ErrorCode::NetworkError,
            ErrorCode::Timeout,
            ErrorCode::InternalError,
        ] {
            let bytes = encode_checked(&code);
            let decoded = ErrorCode::read(&mut bytes.as_slice()).unwrap();
            // Compare via Debug so this test does not depend on ErrorCode: PartialEq.
            assert_eq!(format!("{decoded:?}"), format!("{code:?}"));
        }
    }
}