commonware_sync/
protocol.rs

1//! Network protocol definitions for syncing a [commonware_storage::adb::any::Any] database.
2//!
3//! This module defines the network protocol used for syncing a [commonware_storage::adb::any::Any]
4//! database to a server's database state. It includes message types, error handling, and validation
5//! logic for safe network communication.
6//!
7//! The protocol supports:
8//! - Getting server metadata (database size, root digest, operation bounds)
9//! - Fetching operations with cryptographic proofs
10//! - Getting target updates for dynamic sync
11//! - Error handling
12
13use crate::Operation;
14use bytes::{Buf, BufMut};
15use commonware_codec::{
16    EncodeSize, Error as CodecError, RangeCfg, Read, ReadExt, ReadRangeExt as _, Write,
17};
18use commonware_cryptography::sha256::Digest;
19use commonware_storage::{adb::any::sync::SyncTarget, mmr::verification::Proof};
20use std::{
21    mem::size_of,
22    num::NonZeroU64,
23    sync::atomic::{AtomicU64, Ordering},
24};
25
/// Upper bound on the size of a network message, in bytes (10 MiB = 10 * 1024 * 1024).
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;

/// Upper bound on the number of digests accepted when decoding a proof.
///
/// This module also reuses it as the cap on the number of operations in a decoded
/// operations response.
const MAX_DIGESTS: usize = 10_000;
31
/// Identifier used to pair a response with the request that triggered it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequestId(u64);

impl Default for RequestId {
    /// Equivalent to [`RequestId::new`]: allocates a fresh identifier.
    fn default() -> Self {
        RequestId::new()
    }
}

impl RequestId {
    /// Allocates the next identifier from a process-wide counter.
    ///
    /// The counter starts at 1 and each call advances it by one. `Relaxed` ordering is
    /// sufficient: only the atomicity of the increment matters, not its ordering relative
    /// to other memory operations.
    pub fn new() -> Self {
        static NEXT_ID: AtomicU64 = AtomicU64::new(1);
        let allocated = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        Self(allocated)
    }

    /// Returns the raw numeric value of this identifier.
    pub fn value(&self) -> u64 {
        let RequestId(raw) = *self;
        raw
    }
}
52
53impl Write for RequestId {
54    fn write(&self, buf: &mut impl BufMut) {
55        self.0.write(buf);
56    }
57}
58
59impl EncodeSize for RequestId {
60    fn encode_size(&self) -> usize {
61        self.0.encode_size()
62    }
63}
64
65impl Read for RequestId {
66    type Cfg = ();
67
68    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
69        Ok(RequestId(u64::read(buf)?))
70    }
71}
72
73/// Network protocol messages for syncing a [commonware_storage::adb::any::Any] database.
74#[derive(Debug, Clone)]
75pub enum Message {
76    /// Request operations from the server.
77    GetOperationsRequest(GetOperationsRequest),
78    /// Response with operations and proof.
79    GetOperationsResponse(GetOperationsResponse),
80    /// Request sync target from server.
81    GetSyncTargetRequest(GetSyncTargetRequest),
82    /// Response with sync target.
83    GetSyncTargetResponse(GetSyncTargetResponse),
84    /// Error response.
85    /// Note that, in this example, the server sends an error response to the client in the event
86    /// of an invalid request or internal error. In a real-world application, this may be inadvisable.
87    /// A server may want to simply ignore the client's faulty request and close the connection
88    /// to the client. Similarly, a client may not care about the reason for the server's error.
89    Error(ErrorResponse),
90}
91
92impl Message {
93    pub fn request_id(&self) -> RequestId {
94        match self {
95            Message::GetOperationsRequest(req) => req.request_id,
96            Message::GetOperationsResponse(resp) => resp.request_id,
97            Message::GetSyncTargetRequest(req) => req.request_id,
98            Message::GetSyncTargetResponse(resp) => resp.request_id,
99            Message::Error(err) => err.request_id,
100        }
101    }
102}
103
104/// Request for operations from the server.
105#[derive(Debug, Clone)]
106pub struct GetOperationsRequest {
107    /// Unique identifier for this request.
108    pub request_id: RequestId,
109    /// Size of the database at the root we are syncing to.
110    pub size: u64,
111    /// Starting location for the operations.
112    pub start_loc: u64,
113    /// Maximum number of operations to return.
114    pub max_ops: NonZeroU64,
115}
116
117/// Response with operations and proof.
118#[derive(Debug, Clone)]
119pub struct GetOperationsResponse {
120    /// Unique identifier matching the original request.
121    pub request_id: RequestId,
122    /// Serialized proof that the operations were in the database.
123    pub proof: Proof<Digest>,
124    /// Serialized operations in the requested range.
125    pub operations: Vec<Operation>,
126}
127
128/// Request for sync target from server.
129#[derive(Debug, Clone)]
130pub struct GetSyncTargetRequest {
131    /// Unique identifier for this request.
132    pub request_id: RequestId,
133}
134
135/// Response with sync target.
136#[derive(Debug, Clone)]
137pub struct GetSyncTargetResponse {
138    /// Unique identifier matching the original request.
139    pub request_id: RequestId,
140    /// Sync target information.
141    pub target: SyncTarget<Digest>,
142}
143
144/// Error response.
145#[derive(Debug, Clone)]
146pub struct ErrorResponse {
147    /// Unique identifier matching the original request.
148    pub request_id: RequestId,
149    /// Error code.
150    pub error_code: ErrorCode,
151    /// Human-readable error message.
152    pub message: String,
153}
154
/// Category of a protocol error reported in an error response.
///
/// Derives `PartialEq`/`Eq` so codes can be compared directly (e.g. with `assert_eq!`)
/// instead of via exhaustive pairwise `match`es.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// Invalid request parameters.
    InvalidRequest,
    /// Database error occurred.
    DatabaseError,
    /// Network error occurred.
    NetworkError,
    /// Request timeout.
    Timeout,
    /// Internal server error.
    InternalError,
}
169
170impl Write for Message {
171    fn write(&self, buf: &mut impl BufMut) {
172        match self {
173            Message::GetOperationsRequest(req) => {
174                0u8.write(buf);
175                req.write(buf);
176            }
177            Message::GetOperationsResponse(resp) => {
178                1u8.write(buf);
179                resp.write(buf);
180            }
181            Message::GetSyncTargetRequest(req) => {
182                2u8.write(buf);
183                req.write(buf);
184            }
185            Message::GetSyncTargetResponse(resp) => {
186                3u8.write(buf);
187                resp.write(buf);
188            }
189            Message::Error(err) => {
190                4u8.write(buf);
191                err.write(buf);
192            }
193        }
194    }
195}
196
197impl EncodeSize for Message {
198    fn encode_size(&self) -> usize {
199        // 1 byte for the discriminant
200        1 + match self {
201            Message::GetOperationsRequest(req) => req.encode_size(),
202            Message::GetOperationsResponse(resp) => resp.encode_size(),
203            Message::GetSyncTargetRequest(req) => req.encode_size(),
204            Message::GetSyncTargetResponse(resp) => resp.encode_size(),
205            Message::Error(err) => err.encode_size(),
206        }
207    }
208}
209
210impl Read for Message {
211    type Cfg = ();
212
213    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
214        let discriminant = u8::read(buf)?;
215        match discriminant {
216            0 => Ok(Message::GetOperationsRequest(GetOperationsRequest::read(
217                buf,
218            )?)),
219            1 => Ok(Message::GetOperationsResponse(GetOperationsResponse::read(
220                buf,
221            )?)),
222            2 => Ok(Message::GetSyncTargetRequest(GetSyncTargetRequest::read(
223                buf,
224            )?)),
225            3 => Ok(Message::GetSyncTargetResponse(GetSyncTargetResponse::read(
226                buf,
227            )?)),
228            4 => Ok(Message::Error(ErrorResponse::read(buf)?)),
229            _ => Err(CodecError::InvalidEnum(discriminant)),
230        }
231    }
232}
233
234impl Write for GetOperationsRequest {
235    fn write(&self, buf: &mut impl BufMut) {
236        self.request_id.write(buf);
237        self.size.write(buf);
238        self.start_loc.write(buf);
239        self.max_ops.get().write(buf);
240    }
241}
242
243impl EncodeSize for GetOperationsRequest {
244    fn encode_size(&self) -> usize {
245        self.request_id.encode_size()
246            + self.size.encode_size()
247            + self.start_loc.encode_size()
248            + self.max_ops.get().encode_size()
249    }
250}
251
252impl Read for GetOperationsRequest {
253    type Cfg = ();
254
255    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
256        let request_id = RequestId::read_cfg(buf, &())?;
257        let size = u64::read(buf)?;
258        let start_loc = u64::read(buf)?;
259        let max_ops_raw = u64::read(buf)?;
260        let max_ops = NonZeroU64::new(max_ops_raw)
261            .ok_or_else(|| CodecError::Invalid("GetOperationsRequest", "max_ops cannot be zero"))?;
262        Ok(Self {
263            request_id,
264            size,
265            start_loc,
266            max_ops,
267        })
268    }
269}
270
271impl Write for GetOperationsResponse {
272    fn write(&self, buf: &mut impl BufMut) {
273        self.request_id.write(buf);
274        self.proof.write(buf);
275        self.operations.write(buf);
276    }
277}
278
279impl EncodeSize for GetOperationsResponse {
280    fn encode_size(&self) -> usize {
281        self.request_id.encode_size() + self.proof.encode_size() + self.operations.encode_size()
282    }
283}
284
285impl Read for GetOperationsResponse {
286    type Cfg = ();
287
288    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
289        let request_id = RequestId::read_cfg(buf, &())?;
290        let proof = Proof::read_cfg(buf, &MAX_DIGESTS)?;
291        let operations = {
292            let range_cfg = RangeCfg::from(0..=MAX_DIGESTS);
293            Vec::<Operation>::read_cfg(buf, &(range_cfg, ()))?
294        };
295        Ok(Self {
296            request_id,
297            proof,
298            operations,
299        })
300    }
301}
302
303impl Write for GetSyncTargetRequest {
304    fn write(&self, buf: &mut impl BufMut) {
305        self.request_id.write(buf);
306    }
307}
308
309impl EncodeSize for GetSyncTargetRequest {
310    fn encode_size(&self) -> usize {
311        self.request_id.encode_size()
312    }
313}
314
315impl Read for GetSyncTargetRequest {
316    type Cfg = ();
317
318    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
319        let request_id = RequestId::read_cfg(buf, &())?;
320        Ok(Self { request_id })
321    }
322}
323
324impl Write for GetSyncTargetResponse {
325    fn write(&self, buf: &mut impl BufMut) {
326        self.request_id.write(buf);
327        self.target.write(buf);
328    }
329}
330
331impl EncodeSize for GetSyncTargetResponse {
332    fn encode_size(&self) -> usize {
333        self.request_id.encode_size() + self.target.encode_size()
334    }
335}
336
337impl Read for GetSyncTargetResponse {
338    type Cfg = ();
339
340    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
341        let request_id = RequestId::read_cfg(buf, &())?;
342        let target = SyncTarget::read_cfg(buf, &())?;
343        Ok(Self { request_id, target })
344    }
345}
346
347impl Write for ErrorResponse {
348    fn write(&self, buf: &mut impl BufMut) {
349        self.request_id.write(buf);
350        self.error_code.write(buf);
351        self.message.as_bytes().to_vec().write(buf);
352    }
353}
354
355impl EncodeSize for ErrorResponse {
356    fn encode_size(&self) -> usize {
357        self.request_id.encode_size()
358            + self.error_code.encode_size()
359            + self.message.as_bytes().to_vec().encode_size()
360    }
361}
362
363impl Read for ErrorResponse {
364    type Cfg = ();
365
366    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
367        let request_id = RequestId::read_cfg(buf, &())?;
368        let error_code = ErrorCode::read(buf)?;
369        // Read string as Vec<u8> and convert to String
370        let message_bytes = Vec::<u8>::read_range(buf, 0..=MAX_MESSAGE_SIZE)?;
371        let message = String::from_utf8(message_bytes)
372            .map_err(|_| CodecError::Invalid("ErrorResponse", "invalid UTF-8 in message"))?;
373        Ok(Self {
374            request_id,
375            error_code,
376            message,
377        })
378    }
379}
380
381impl Write for ErrorCode {
382    fn write(&self, buf: &mut impl BufMut) {
383        let discriminant = match self {
384            ErrorCode::InvalidRequest => 0u8,
385            ErrorCode::DatabaseError => 1u8,
386            ErrorCode::NetworkError => 2u8,
387            ErrorCode::Timeout => 3u8,
388            ErrorCode::InternalError => 4u8,
389        };
390        discriminant.write(buf);
391    }
392}
393
394impl EncodeSize for ErrorCode {
395    fn encode_size(&self) -> usize {
396        size_of::<u8>()
397    }
398}
399
400impl Read for ErrorCode {
401    type Cfg = ();
402
403    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
404        let discriminant = u8::read(buf)?;
405        match discriminant {
406            0 => Ok(ErrorCode::InvalidRequest),
407            1 => Ok(ErrorCode::DatabaseError),
408            2 => Ok(ErrorCode::NetworkError),
409            3 => Ok(ErrorCode::Timeout),
410            4 => Ok(ErrorCode::InternalError),
411            _ => Err(CodecError::InvalidEnum(discriminant)),
412        }
413    }
414}
415
416impl GetOperationsRequest {
417    /// Validate the request parameters.
418    pub fn validate(&self) -> Result<(), crate::Error> {
419        if self.start_loc >= self.size {
420            return Err(crate::Error::InvalidRequest(format!(
421                "start_loc >= size ({}) >= ({})",
422                self.start_loc, self.size
423            )));
424        }
425
426        if self.max_ops.get() == 0 {
427            return Err(crate::Error::InvalidRequest(
428                "max_ops cannot be zero".to_string(),
429            ));
430        }
431
432        Ok(())
433    }
434}
435
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_utils::NZU64;

    #[test]
    fn test_request_id_generation() {
        let id1 = RequestId::new();
        let id2 = RequestId::new();
        let id3 = RequestId::new();

        // IDs come from a single process-wide counter. The test harness runs tests on
        // parallel threads, and other tests also allocate IDs, so values between these calls
        // may be skipped. Only strict monotonicity within this thread is guaranteed;
        // asserting consecutive values would make this test flaky.
        assert!(id2.value() > id1.value());
        assert!(id3.value() > id2.value());
    }

    #[test]
    fn test_error_code_roundtrip_serialization() {
        use commonware_codec::{DecodeExt, Encode};

        let test_cases = vec![
            ErrorCode::InvalidRequest,
            ErrorCode::DatabaseError,
            ErrorCode::NetworkError,
            ErrorCode::Timeout,
            ErrorCode::InternalError,
        ];

        for error_code in test_cases {
            // Serialize
            let encoded = error_code.encode().to_vec();

            // Deserialize
            let decoded = ErrorCode::decode(&encoded[..]).expect("Failed to decode ErrorCode");

            // Verify they match
            match (&error_code, &decoded) {
                (ErrorCode::InvalidRequest, ErrorCode::InvalidRequest) => {}
                (ErrorCode::DatabaseError, ErrorCode::DatabaseError) => {}
                (ErrorCode::NetworkError, ErrorCode::NetworkError) => {}
                (ErrorCode::Timeout, ErrorCode::Timeout) => {}
                (ErrorCode::InternalError, ErrorCode::InternalError) => {}
                _ => panic!("ErrorCode roundtrip failed: {error_code:?} != {decoded:?}"),
            }
        }
    }

    #[test]
    fn test_get_operations_request_validation() {
        // Valid request
        let request = GetOperationsRequest {
            request_id: RequestId::new(),
            size: 100,
            start_loc: 10,
            max_ops: NZU64!(50),
        };
        assert!(request.validate().is_ok());

        // Last valid start location (boundary)
        let request = GetOperationsRequest {
            request_id: RequestId::new(),
            size: 100,
            start_loc: 99,
            max_ops: NZU64!(50),
        };
        assert!(request.validate().is_ok());

        // Invalid start_loc
        let request = GetOperationsRequest {
            request_id: RequestId::new(),
            size: 100,
            start_loc: 100,
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));

        // start_loc beyond size
        let request = GetOperationsRequest {
            request_id: RequestId::new(),
            size: 100,
            start_loc: 150,
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));

        // Empty database: no start location is valid
        let request = GetOperationsRequest {
            request_id: RequestId::new(),
            size: 0,
            start_loc: 0,
            max_ops: NZU64!(1),
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));
    }
}