commonware_sync/
protocol.rs

//! Network protocol definitions for syncing a [commonware_storage::adb::any::Any] database.
//!
//! This module defines the network protocol used for syncing a [commonware_storage::adb::any::Any]
//! database to a server's database state. It includes message types, error handling, and validation
//! logic for safe network communication.
//!
//! The protocol supports:
//! - Getting server metadata (target hash and operation bounds: oldest retained and latest
//!   operation locations)
//! - Fetching operations with cryptographic proofs
//! - Error handling
11
12use crate::Operation;
13use bytes::{Buf, BufMut};
14use commonware_codec::{
15    EncodeSize, Error as CodecError, RangeCfg, Read, ReadExt, ReadRangeExt as _, Write,
16};
17use commonware_cryptography::sha256::Digest;
18use commonware_storage::mmr::verification::Proof;
19use std::num::NonZeroU64;
20use thiserror::Error;
21
/// Upper bound, in bytes, on a protocol message: 10 MiB (10 × 2^20 bytes).
pub const MAX_MESSAGE_SIZE: usize = 10 << 20;

/// Upper bound on the number of digests accepted when decoding a proof.
///
/// The same bound also caps how many operations a decoded `GetOperationsResponse` may carry.
const MAX_DIGESTS: usize = 10_000;
27
28/// Network protocol messages for syncing a [commonware_storage::adb::any::Any] database.
29#[derive(Debug, Clone)]
30pub enum Message {
31    /// Request operations from the server.
32    GetOperationsRequest(GetOperationsRequest),
33    /// Response with operations and proof.
34    GetOperationsResponse(GetOperationsResponse),
35    /// Request server metadata (target hash, bounds, etc.).
36    GetServerMetadataRequest,
37    /// Response with server metadata.
38    GetServerMetadataResponse(GetServerMetadataResponse),
39    /// Error response.
40    /// Note that, in this example, the server sends an error response to the client in the event
41    /// of an invalid request or internal error. In a real-world application, this may be inadvisable.
42    /// A server may want to simply ignore the client's faulty request and close the connection
43    /// to the client. Similarly, a client may not care about the reason for the server's error.
44    Error(ErrorResponse),
45}
46
/// Asks the server for a contiguous run of operations.
#[derive(Debug, Clone)]
pub struct GetOperationsRequest {
    /// Database size at the root being synced to.
    pub size: u64,
    /// Location of the first requested operation.
    pub start_loc: u64,
    /// Upper bound on how many operations the server should return.
    pub max_ops: NonZeroU64,
}
57
58/// Response with operations and proof.
59#[derive(Debug, Clone)]
60pub struct GetOperationsResponse {
61    /// Serialized proof that the operations were in the database.
62    pub proof: Proof<Digest>,
63    /// Serialized operations in the requested range.
64    pub operations: Vec<Operation>,
65}
66
67/// Response with server metadata.
68#[derive(Debug, Clone)]
69pub struct GetServerMetadataResponse {
70    /// Target hash of the database.
71    pub target_hash: Digest,
72    /// Oldest retained operation location.
73    pub oldest_retained_loc: u64,
74    /// Latest operation location.
75    pub latest_op_loc: u64,
76}
77
78/// Error response.
79#[derive(Debug, Clone)]
80pub struct ErrorResponse {
81    /// Error code.
82    pub error_code: ErrorCode,
83    /// Human-readable error message.
84    pub message: String,
85}
86
/// Error codes for protocol errors.
///
/// Each code is encoded on the wire as a single byte. `PartialEq`/`Eq` are derived so callers
/// can compare codes directly (e.g. `resp.error_code == ErrorCode::Timeout`) instead of
/// pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// Invalid request parameters.
    InvalidRequest,
    /// Database error occurred.
    DatabaseError,
    /// Network error occurred.
    NetworkError,
    /// Request timeout.
    Timeout,
    /// Internal server error.
    InternalError,
}
101
102/// Errors that can occur during protocol operations.
103#[derive(Debug, Error)]
104pub enum ProtocolError {
105    #[error("Invalid request: {message}")]
106    InvalidRequest { message: String },
107
108    #[error("Database error: {0}")]
109    DatabaseError(#[from] commonware_storage::adb::Error),
110
111    #[error("Network error: {0}")]
112    NetworkError(String),
113}
114
115impl Write for Message {
116    fn write(&self, buf: &mut impl BufMut) {
117        match self {
118            Message::GetOperationsRequest(req) => {
119                0u8.write(buf);
120                req.write(buf);
121            }
122            Message::GetOperationsResponse(resp) => {
123                1u8.write(buf);
124                resp.write(buf);
125            }
126            Message::GetServerMetadataRequest => {
127                2u8.write(buf);
128            }
129            Message::GetServerMetadataResponse(resp) => {
130                3u8.write(buf);
131                resp.write(buf);
132            }
133            Message::Error(err) => {
134                4u8.write(buf);
135                err.write(buf);
136            }
137        }
138    }
139}
140
141impl EncodeSize for Message {
142    fn encode_size(&self) -> usize {
143        // 1 byte for the discriminant
144        1 + match self {
145            Message::GetOperationsRequest(req) => req.encode_size(),
146            Message::GetOperationsResponse(resp) => resp.encode_size(),
147            Message::GetServerMetadataRequest => 0,
148            Message::GetServerMetadataResponse(resp) => resp.encode_size(),
149            Message::Error(err) => err.encode_size(),
150        }
151    }
152}
153
154impl Read for Message {
155    type Cfg = ();
156
157    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
158        let discriminant = u8::read(buf)?;
159        match discriminant {
160            0 => Ok(Message::GetOperationsRequest(GetOperationsRequest::read(
161                buf,
162            )?)),
163            1 => Ok(Message::GetOperationsResponse(GetOperationsResponse::read(
164                buf,
165            )?)),
166            2 => Ok(Message::GetServerMetadataRequest),
167            3 => Ok(Message::GetServerMetadataResponse(
168                GetServerMetadataResponse::read(buf)?,
169            )),
170            4 => Ok(Message::Error(ErrorResponse::read(buf)?)),
171            _ => Err(CodecError::InvalidEnum(discriminant)),
172        }
173    }
174}
175
176impl Write for GetOperationsRequest {
177    fn write(&self, buf: &mut impl BufMut) {
178        self.size.write(buf);
179        self.start_loc.write(buf);
180        self.max_ops.get().write(buf);
181    }
182}
183
184impl EncodeSize for GetOperationsRequest {
185    fn encode_size(&self) -> usize {
186        self.size.encode_size() + self.start_loc.encode_size() + self.max_ops.get().encode_size()
187    }
188}
189
190impl Read for GetOperationsRequest {
191    type Cfg = ();
192
193    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
194        let size = u64::read(buf)?;
195        let start_loc = u64::read(buf)?;
196        let max_ops_raw = u64::read(buf)?;
197        let max_ops = NonZeroU64::new(max_ops_raw)
198            .ok_or_else(|| CodecError::Invalid("GetOperationsRequest", "max_ops cannot be zero"))?;
199        Ok(Self {
200            size,
201            start_loc,
202            max_ops,
203        })
204    }
205}
206
207impl Write for GetOperationsResponse {
208    fn write(&self, buf: &mut impl BufMut) {
209        self.proof.write(buf);
210        self.operations.write(buf);
211    }
212}
213
214impl EncodeSize for GetOperationsResponse {
215    fn encode_size(&self) -> usize {
216        self.proof.encode_size() + self.operations.encode_size()
217    }
218}
219
220impl Read for GetOperationsResponse {
221    type Cfg = ();
222
223    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
224        let proof = Proof::read_cfg(buf, &MAX_DIGESTS)?;
225        let operations = {
226            let range_cfg = RangeCfg::from(0..=MAX_DIGESTS);
227            Vec::<Operation>::read_cfg(buf, &(range_cfg, ()))?
228        };
229
230        Ok(Self { proof, operations })
231    }
232}
233
234impl Write for GetServerMetadataResponse {
235    fn write(&self, buf: &mut impl BufMut) {
236        self.target_hash.write(buf);
237        self.oldest_retained_loc.write(buf);
238        self.latest_op_loc.write(buf);
239    }
240}
241
242impl EncodeSize for GetServerMetadataResponse {
243    fn encode_size(&self) -> usize {
244        self.target_hash.encode_size()
245            + self.oldest_retained_loc.encode_size()
246            + self.latest_op_loc.encode_size()
247    }
248}
249
250impl Read for GetServerMetadataResponse {
251    type Cfg = ();
252
253    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
254        let target_hash = Digest::read(buf)?;
255        let oldest_retained_loc = u64::read(buf)?;
256        let latest_op_loc = u64::read(buf)?;
257        Ok(Self {
258            target_hash,
259            oldest_retained_loc,
260            latest_op_loc,
261        })
262    }
263}
264
265impl Write for ErrorResponse {
266    fn write(&self, buf: &mut impl BufMut) {
267        self.error_code.write(buf);
268        self.message.as_bytes().to_vec().write(buf);
269    }
270}
271
272impl EncodeSize for ErrorResponse {
273    fn encode_size(&self) -> usize {
274        self.error_code.encode_size() + self.message.as_bytes().to_vec().encode_size()
275    }
276}
277
278impl Read for ErrorResponse {
279    type Cfg = ();
280
281    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
282        let error_code = ErrorCode::read(buf)?;
283        // Read string as Vec<u8> and convert to String
284        let message_bytes = Vec::<u8>::read_range(buf, 0..=MAX_MESSAGE_SIZE)?;
285        let message = String::from_utf8(message_bytes)
286            .map_err(|_| CodecError::Invalid("ErrorResponse", "invalid UTF-8 in message"))?;
287        Ok(Self {
288            error_code,
289            message,
290        })
291    }
292}
293
294impl Write for ErrorCode {
295    fn write(&self, buf: &mut impl BufMut) {
296        let discriminant = match self {
297            ErrorCode::InvalidRequest => 0u8,
298            ErrorCode::DatabaseError => 1u8,
299            ErrorCode::NetworkError => 2u8,
300            ErrorCode::Timeout => 3u8,
301            ErrorCode::InternalError => 4u8,
302        };
303        discriminant.write(buf);
304    }
305}
306
307impl EncodeSize for ErrorCode {
308    fn encode_size(&self) -> usize {
309        size_of::<u8>()
310    }
311}
312
313impl Read for ErrorCode {
314    type Cfg = ();
315
316    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
317        let discriminant = u8::read(buf)?;
318        match discriminant {
319            0 => Ok(ErrorCode::InvalidRequest),
320            1 => Ok(ErrorCode::DatabaseError),
321            2 => Ok(ErrorCode::NetworkError),
322            3 => Ok(ErrorCode::Timeout),
323            4 => Ok(ErrorCode::InternalError),
324            _ => Err(CodecError::InvalidEnum(discriminant)),
325        }
326    }
327}
328
329impl From<ProtocolError> for ErrorResponse {
330    fn from(error: ProtocolError) -> Self {
331        let (error_code, message) = match error {
332            ProtocolError::InvalidRequest { message } => (ErrorCode::InvalidRequest, message),
333            ProtocolError::DatabaseError(e) => (ErrorCode::DatabaseError, e.to_string()),
334            ProtocolError::NetworkError(e) => (ErrorCode::NetworkError, e),
335        };
336
337        ErrorResponse {
338            error_code,
339            message,
340        }
341    }
342}
343
344impl GetOperationsRequest {
345    /// Validate the request parameters.
346    pub fn validate(&self) -> Result<(), ProtocolError> {
347        if self.start_loc >= self.size {
348            return Err(ProtocolError::InvalidRequest {
349                message: format!("start_loc >= size ({}) >= ({})", self.start_loc, self.size),
350            });
351        }
352
353        if self.max_ops.get() == 0 {
354            return Err(ProtocolError::InvalidRequest {
355                message: "max_ops cannot be zero".to_string(),
356            });
357        }
358
359        Ok(())
360    }
361}
362
#[cfg(test)]
mod tests {
    use super::*;
    use commonware_utils::NZU64;

    /// Encodes `value` with its codec `Write` impl and checks the byte count matches
    /// `encode_size`.
    fn encode<T: Write + EncodeSize>(value: &T) -> Vec<u8> {
        let mut buf: Vec<u8> = Vec::with_capacity(value.encode_size());
        value.write(&mut buf);
        assert_eq!(buf.len(), value.encode_size());
        buf
    }

    #[test]
    fn test_get_operations_request_validation() {
        // Valid request
        let request = GetOperationsRequest {
            size: 100,
            start_loc: 10,
            max_ops: NZU64!(50),
        };
        assert!(request.validate().is_ok());

        // Invalid start_loc
        let request = GetOperationsRequest {
            size: 100,
            start_loc: 100,
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(ProtocolError::InvalidRequest { .. })
        ));

        // start_loc beyond size
        let request = GetOperationsRequest {
            size: 100,
            start_loc: 150,
            max_ops: NZU64!(50),
        };
        assert!(matches!(
            request.validate(),
            Err(ProtocolError::InvalidRequest { .. })
        ));
    }

    #[test]
    fn test_get_operations_request_roundtrip() {
        let request = GetOperationsRequest {
            size: 100,
            start_loc: 10,
            max_ops: NZU64!(50),
        };
        let encoded = encode(&Message::GetOperationsRequest(request));
        let Message::GetOperationsRequest(decoded) =
            Message::read(&mut encoded.as_slice()).unwrap()
        else {
            panic!("expected a GetOperationsRequest message");
        };
        assert_eq!(decoded.size, 100);
        assert_eq!(decoded.start_loc, 10);
        assert_eq!(decoded.max_ops.get(), 50);
    }

    #[test]
    fn test_get_operations_request_rejects_zero_max_ops() {
        let mut encoded = Vec::<u8>::new();
        100u64.write(&mut encoded);
        10u64.write(&mut encoded);
        0u64.write(&mut encoded);
        assert!(matches!(
            GetOperationsRequest::read(&mut encoded.as_slice()),
            Err(CodecError::Invalid(_, _))
        ));
    }

    #[test]
    fn test_server_metadata_request_is_tag_only() {
        let encoded = encode(&Message::GetServerMetadataRequest);
        assert_eq!(encoded, vec![2u8]);
        assert!(matches!(
            Message::read(&mut encoded.as_slice()),
            Ok(Message::GetServerMetadataRequest)
        ));
    }

    #[test]
    fn test_message_rejects_unknown_discriminant() {
        let encoded = [5u8];
        assert!(matches!(
            Message::read(&mut encoded.as_slice()),
            Err(CodecError::InvalidEnum(5))
        ));
    }

    #[test]
    fn test_error_response_roundtrip() {
        let response = ErrorResponse::from(ProtocolError::InvalidRequest {
            message: "bad range".to_string(),
        });
        let encoded = encode(&Message::Error(response));
        let Message::Error(decoded) = Message::read(&mut encoded.as_slice()).unwrap() else {
            panic!("expected an Error message");
        };
        assert!(matches!(decoded.error_code, ErrorCode::InvalidRequest));
        assert_eq!(decoded.message, "bad range");
    }
}