1use crate::Operation;
14use bytes::{Buf, BufMut};
15use commonware_codec::{
16 EncodeSize, Error as CodecError, RangeCfg, Read, ReadExt, ReadRangeExt as _, Write,
17};
18use commonware_cryptography::sha256::Digest;
19use commonware_storage::{adb::any::sync::SyncTarget, mmr::verification::Proof};
20use std::{
21 mem::size_of,
22 num::NonZeroU64,
23 sync::atomic::{AtomicU64, Ordering},
24};
25
/// Size limit of 10 MiB (`10 << 20` bytes).
///
/// In this module it bounds the byte length accepted for [`ErrorResponse::message`]
/// when decoding. Its name suggests it is also the overall message-size cap used by
/// the transport — NOTE(review): confirm against callers.
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;
28
/// Decoding bound shared by [`GetOperationsResponse`]: the maximum number of digests
/// accepted in its proof and the maximum number of operations accepted in its batch.
const MAX_DIGESTS: usize = 10 * 1_000;
31
/// Identifier carried by every protocol message, used to associate replies with
/// the request they belong to.
///
/// Fresh values come from a single process-wide atomic counter starting at 1, so
/// every call to [`RequestId::new`] yields a distinct value (until the `u64`
/// counter wraps), and successive calls on one thread yield increasing values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestId(u64);

impl Default for RequestId {
    /// Allocates a fresh identifier; identical to calling [`RequestId::new`].
    fn default() -> Self {
        RequestId::new()
    }
}

impl RequestId {
    /// Allocates the next identifier from the process-wide counter.
    pub fn new() -> Self {
        // Only uniqueness is required and no other memory is published through this
        // counter, so a single relaxed read-modify-write is sufficient.
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
        let allocated = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(allocated)
    }

    /// Returns the raw numeric value of this identifier.
    pub fn value(&self) -> u64 {
        let RequestId(raw) = *self;
        raw
    }
}
52
53impl Write for RequestId {
54 fn write(&self, buf: &mut impl BufMut) {
55 self.0.write(buf);
56 }
57}
58
59impl EncodeSize for RequestId {
60 fn encode_size(&self) -> usize {
61 self.0.encode_size()
62 }
63}
64
65impl Read for RequestId {
66 type Cfg = ();
67
68 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
69 Ok(RequestId(u64::read(buf)?))
70 }
71}
72
73#[derive(Debug, Clone)]
75pub enum Message {
76 GetOperationsRequest(GetOperationsRequest),
78 GetOperationsResponse(GetOperationsResponse),
80 GetSyncTargetRequest(GetSyncTargetRequest),
82 GetSyncTargetResponse(GetSyncTargetResponse),
84 Error(ErrorResponse),
90}
91
92impl Message {
93 pub fn request_id(&self) -> RequestId {
94 match self {
95 Message::GetOperationsRequest(req) => req.request_id,
96 Message::GetOperationsResponse(resp) => resp.request_id,
97 Message::GetSyncTargetRequest(req) => req.request_id,
98 Message::GetSyncTargetResponse(resp) => resp.request_id,
99 Message::Error(err) => err.request_id,
100 }
101 }
102}
103
/// Request for up to `max_ops` operations beginning at `start_loc`, with a proof
/// relative to a database of `size` operations (TODO confirm exact `size` semantics
/// with the responder).
#[derive(Debug, Clone)]
pub struct GetOperationsRequest {
    /// Identifier for this request.
    pub request_id: RequestId,
    /// Database size the request refers to; `validate` requires `start_loc < size`.
    pub size: u64,
    /// Location of the first requested operation.
    pub start_loc: u64,
    /// Maximum number of operations to return; non-zero by type, and decoding rejects 0.
    pub max_ops: NonZeroU64,
}
116
/// Reply to a [`GetOperationsRequest`]: a batch of operations together with an MMR
/// proof (presumably proving those operations; verify against the responder).
#[derive(Debug, Clone)]
pub struct GetOperationsResponse {
    /// Identifier of the request being answered (presumably echoed from it).
    pub request_id: RequestId,
    /// MMR proof over SHA-256 digests; decoding caps it at `MAX_DIGESTS` digests.
    pub proof: Proof<Digest>,
    /// Returned operations; decoding caps the count at `MAX_DIGESTS`.
    pub operations: Vec<Operation>,
}
127
/// Request for the current sync target; it carries nothing but its identifier.
#[derive(Debug, Clone)]
pub struct GetSyncTargetRequest {
    /// Identifier for this request.
    pub request_id: RequestId,
}
134
/// Reply to a [`GetSyncTargetRequest`] carrying a sync target.
#[derive(Debug, Clone)]
pub struct GetSyncTargetResponse {
    /// Identifier of the request being answered (presumably echoed from it).
    pub request_id: RequestId,
    /// The sync target, keyed by SHA-256 digests.
    pub target: SyncTarget<Digest>,
}
143
/// Error reply associated with a request.
#[derive(Debug, Clone)]
pub struct ErrorResponse {
    /// Identifier of the request that failed (presumably echoed from it).
    pub request_id: RequestId,
    /// Machine-readable error category.
    pub error_code: ErrorCode,
    /// Human-readable description; decoding requires valid UTF-8 of at most
    /// `MAX_MESSAGE_SIZE` bytes.
    pub message: String,
}
154
/// Machine-readable category attached to an [`ErrorResponse`].
///
/// Encoded on the wire as a single tag byte (0 through 4, in declaration order).
/// `PartialEq`/`Eq` are derived so codes can be compared directly instead of via
/// hand-written pairwise `match` expressions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// The request was rejected as invalid (tag 0).
    InvalidRequest,
    /// A database-level failure (tag 1).
    DatabaseError,
    /// A network-level failure (tag 2).
    NetworkError,
    /// The operation timed out (tag 3).
    Timeout,
    /// Any other internal failure (tag 4).
    InternalError,
}
169
170impl Write for Message {
171 fn write(&self, buf: &mut impl BufMut) {
172 match self {
173 Message::GetOperationsRequest(req) => {
174 0u8.write(buf);
175 req.write(buf);
176 }
177 Message::GetOperationsResponse(resp) => {
178 1u8.write(buf);
179 resp.write(buf);
180 }
181 Message::GetSyncTargetRequest(req) => {
182 2u8.write(buf);
183 req.write(buf);
184 }
185 Message::GetSyncTargetResponse(resp) => {
186 3u8.write(buf);
187 resp.write(buf);
188 }
189 Message::Error(err) => {
190 4u8.write(buf);
191 err.write(buf);
192 }
193 }
194 }
195}
196
197impl EncodeSize for Message {
198 fn encode_size(&self) -> usize {
199 1 + match self {
201 Message::GetOperationsRequest(req) => req.encode_size(),
202 Message::GetOperationsResponse(resp) => resp.encode_size(),
203 Message::GetSyncTargetRequest(req) => req.encode_size(),
204 Message::GetSyncTargetResponse(resp) => resp.encode_size(),
205 Message::Error(err) => err.encode_size(),
206 }
207 }
208}
209
210impl Read for Message {
211 type Cfg = ();
212
213 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
214 let discriminant = u8::read(buf)?;
215 match discriminant {
216 0 => Ok(Message::GetOperationsRequest(GetOperationsRequest::read(
217 buf,
218 )?)),
219 1 => Ok(Message::GetOperationsResponse(GetOperationsResponse::read(
220 buf,
221 )?)),
222 2 => Ok(Message::GetSyncTargetRequest(GetSyncTargetRequest::read(
223 buf,
224 )?)),
225 3 => Ok(Message::GetSyncTargetResponse(GetSyncTargetResponse::read(
226 buf,
227 )?)),
228 4 => Ok(Message::Error(ErrorResponse::read(buf)?)),
229 _ => Err(CodecError::InvalidEnum(discriminant)),
230 }
231 }
232}
233
234impl Write for GetOperationsRequest {
235 fn write(&self, buf: &mut impl BufMut) {
236 self.request_id.write(buf);
237 self.size.write(buf);
238 self.start_loc.write(buf);
239 self.max_ops.get().write(buf);
240 }
241}
242
243impl EncodeSize for GetOperationsRequest {
244 fn encode_size(&self) -> usize {
245 self.request_id.encode_size()
246 + self.size.encode_size()
247 + self.start_loc.encode_size()
248 + self.max_ops.get().encode_size()
249 }
250}
251
252impl Read for GetOperationsRequest {
253 type Cfg = ();
254
255 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
256 let request_id = RequestId::read_cfg(buf, &())?;
257 let size = u64::read(buf)?;
258 let start_loc = u64::read(buf)?;
259 let max_ops_raw = u64::read(buf)?;
260 let max_ops = NonZeroU64::new(max_ops_raw)
261 .ok_or_else(|| CodecError::Invalid("GetOperationsRequest", "max_ops cannot be zero"))?;
262 Ok(Self {
263 request_id,
264 size,
265 start_loc,
266 max_ops,
267 })
268 }
269}
270
271impl Write for GetOperationsResponse {
272 fn write(&self, buf: &mut impl BufMut) {
273 self.request_id.write(buf);
274 self.proof.write(buf);
275 self.operations.write(buf);
276 }
277}
278
279impl EncodeSize for GetOperationsResponse {
280 fn encode_size(&self) -> usize {
281 self.request_id.encode_size() + self.proof.encode_size() + self.operations.encode_size()
282 }
283}
284
285impl Read for GetOperationsResponse {
286 type Cfg = ();
287
288 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
289 let request_id = RequestId::read_cfg(buf, &())?;
290 let proof = Proof::read_cfg(buf, &MAX_DIGESTS)?;
291 let operations = {
292 let range_cfg = RangeCfg::from(0..=MAX_DIGESTS);
293 Vec::<Operation>::read_cfg(buf, &(range_cfg, ()))?
294 };
295 Ok(Self {
296 request_id,
297 proof,
298 operations,
299 })
300 }
301}
302
303impl Write for GetSyncTargetRequest {
304 fn write(&self, buf: &mut impl BufMut) {
305 self.request_id.write(buf);
306 }
307}
308
309impl EncodeSize for GetSyncTargetRequest {
310 fn encode_size(&self) -> usize {
311 self.request_id.encode_size()
312 }
313}
314
315impl Read for GetSyncTargetRequest {
316 type Cfg = ();
317
318 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
319 let request_id = RequestId::read_cfg(buf, &())?;
320 Ok(Self { request_id })
321 }
322}
323
324impl Write for GetSyncTargetResponse {
325 fn write(&self, buf: &mut impl BufMut) {
326 self.request_id.write(buf);
327 self.target.write(buf);
328 }
329}
330
331impl EncodeSize for GetSyncTargetResponse {
332 fn encode_size(&self) -> usize {
333 self.request_id.encode_size() + self.target.encode_size()
334 }
335}
336
337impl Read for GetSyncTargetResponse {
338 type Cfg = ();
339
340 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
341 let request_id = RequestId::read_cfg(buf, &())?;
342 let target = SyncTarget::read_cfg(buf, &())?;
343 Ok(Self { request_id, target })
344 }
345}
346
347impl Write for ErrorResponse {
348 fn write(&self, buf: &mut impl BufMut) {
349 self.request_id.write(buf);
350 self.error_code.write(buf);
351 self.message.as_bytes().to_vec().write(buf);
352 }
353}
354
355impl EncodeSize for ErrorResponse {
356 fn encode_size(&self) -> usize {
357 self.request_id.encode_size()
358 + self.error_code.encode_size()
359 + self.message.as_bytes().to_vec().encode_size()
360 }
361}
362
363impl Read for ErrorResponse {
364 type Cfg = ();
365
366 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
367 let request_id = RequestId::read_cfg(buf, &())?;
368 let error_code = ErrorCode::read(buf)?;
369 let message_bytes = Vec::<u8>::read_range(buf, 0..=MAX_MESSAGE_SIZE)?;
371 let message = String::from_utf8(message_bytes)
372 .map_err(|_| CodecError::Invalid("ErrorResponse", "invalid UTF-8 in message"))?;
373 Ok(Self {
374 request_id,
375 error_code,
376 message,
377 })
378 }
379}
380
381impl Write for ErrorCode {
382 fn write(&self, buf: &mut impl BufMut) {
383 let discriminant = match self {
384 ErrorCode::InvalidRequest => 0u8,
385 ErrorCode::DatabaseError => 1u8,
386 ErrorCode::NetworkError => 2u8,
387 ErrorCode::Timeout => 3u8,
388 ErrorCode::InternalError => 4u8,
389 };
390 discriminant.write(buf);
391 }
392}
393
394impl EncodeSize for ErrorCode {
395 fn encode_size(&self) -> usize {
396 size_of::<u8>()
397 }
398}
399
400impl Read for ErrorCode {
401 type Cfg = ();
402
403 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
404 let discriminant = u8::read(buf)?;
405 match discriminant {
406 0 => Ok(ErrorCode::InvalidRequest),
407 1 => Ok(ErrorCode::DatabaseError),
408 2 => Ok(ErrorCode::NetworkError),
409 3 => Ok(ErrorCode::Timeout),
410 4 => Ok(ErrorCode::InternalError),
411 _ => Err(CodecError::InvalidEnum(discriminant)),
412 }
413 }
414}
415
416impl GetOperationsRequest {
417 pub fn validate(&self) -> Result<(), crate::Error> {
419 if self.start_loc >= self.size {
420 return Err(crate::Error::InvalidRequest(format!(
421 "start_loc >= size ({}) >= ({})",
422 self.start_loc, self.size
423 )));
424 }
425
426 if self.max_ops.get() == 0 {
427 return Err(crate::Error::InvalidRequest(
428 "max_ops cannot be zero".to_string(),
429 ));
430 }
431
432 Ok(())
433 }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_utils::NZU64;

    #[test]
    fn test_request_id_generation() {
        // `RequestId::new` draws from one process-wide counter that other tests,
        // running in parallel threads, also advance. Only distinctness and per-thread
        // strict monotonicity are guaranteed. Asserting consecutive values
        // (`id1 + 1`) would make this test flaky.
        let id1 = RequestId::new();
        let id2 = RequestId::new();
        let id3 = RequestId::new();

        assert!(id2.value() > id1.value());
        assert!(id3.value() > id2.value());

        assert_ne!(id1, id2);
        assert_ne!(id2, id3);
    }

    #[test]
    fn test_error_code_roundtrip_serialization() {
        use commonware_codec::{DecodeExt, Encode};

        let test_cases = vec![
            ErrorCode::InvalidRequest,
            ErrorCode::DatabaseError,
            ErrorCode::NetworkError,
            ErrorCode::Timeout,
            ErrorCode::InternalError,
        ];

        for error_code in test_cases {
            let encoded = error_code.encode().to_vec();

            let decoded = ErrorCode::decode(&encoded[..]).expect("Failed to decode ErrorCode");

            match (&error_code, &decoded) {
                (ErrorCode::InvalidRequest, ErrorCode::InvalidRequest) => {}
                (ErrorCode::DatabaseError, ErrorCode::DatabaseError) => {}
                (ErrorCode::NetworkError, ErrorCode::NetworkError) => {}
                (ErrorCode::Timeout, ErrorCode::Timeout) => {}
                (ErrorCode::InternalError, ErrorCode::InternalError) => {}
                _ => panic!("ErrorCode roundtrip failed: {error_code:?} != {decoded:?}"),
            }
        }
    }

    #[test]
    fn test_get_operations_request_validation() {
        // Valid: start_loc < size.
        let request = GetOperationsRequest {
            request_id: RequestId::new(),
            size: 100,
            start_loc: 10,
            max_ops: NZU64!(50),
        };
        assert!(request.validate().is_ok());

        // Invalid: start_loc == size.
        let request = GetOperationsRequest {
            request_id: RequestId::new(),
            size: 100,
            start_loc: 100,
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));

        // Invalid: start_loc > size.
        let request = GetOperationsRequest {
            request_id: RequestId::new(),
            size: 100,
            start_loc: 150,
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));
    }
}
522}