commonware_sync/net/
mod.rs

1use bytes::{Buf, BufMut};
2use commonware_codec::{DecodeExt, Encode, EncodeSize, Error, Read, ReadExt, ReadRangeExt, Write};
3use std::mem::size_of;
4
/// Maximum message size in bytes (10 MiB).
///
/// Also used as the upper bound on the length of a decoded
/// [`ErrorResponse`] message string.
pub const MAX_MESSAGE_SIZE: usize = 10 * (1 << 20);
7
8pub mod request_id;
9pub use request_id::RequestId;
10pub mod io;
11pub mod resolver;
12pub mod wire;
13pub use resolver::Resolver;
14
15/// A message that can be sent over the wire.
16pub(super) trait Message: Encode + DecodeExt<()> + Sized + Send + Sync + 'static {
17    fn request_id(&self) -> RequestId;
18}
19
/// Error codes for protocol errors.
///
/// Each variant is encoded on the wire as a single-byte tag; see the
/// `Write` and `Read` implementations for the exact mapping.
///
/// Implements `PartialEq`/`Eq` so codes can be compared directly (for
/// example with `assert_eq!`) instead of pattern-matching on pairs of variants.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// Invalid request parameters.
    InvalidRequest,
    /// Database error occurred.
    DatabaseError,
    /// Network error occurred.
    NetworkError,
    /// Request timeout.
    Timeout,
    /// Internal server error.
    InternalError,
}
34
35impl Write for ErrorCode {
36    fn write(&self, buf: &mut impl BufMut) {
37        let discriminant = match self {
38            ErrorCode::InvalidRequest => 0u8,
39            ErrorCode::DatabaseError => 1u8,
40            ErrorCode::NetworkError => 2u8,
41            ErrorCode::Timeout => 3u8,
42            ErrorCode::InternalError => 4u8,
43        };
44        discriminant.write(buf);
45    }
46}
47
48impl EncodeSize for ErrorCode {
49    fn encode_size(&self) -> usize {
50        size_of::<u8>()
51    }
52}
53
54impl Read for ErrorCode {
55    type Cfg = ();
56
57    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
58        let discriminant = u8::read(buf)?;
59        match discriminant {
60            0 => Ok(ErrorCode::InvalidRequest),
61            1 => Ok(ErrorCode::DatabaseError),
62            2 => Ok(ErrorCode::NetworkError),
63            3 => Ok(ErrorCode::Timeout),
64            4 => Ok(ErrorCode::InternalError),
65            _ => Err(Error::InvalidEnum(discriminant)),
66        }
67    }
68}
69
70/// Error from the server.
71#[derive(Debug, Clone)]
72pub struct ErrorResponse {
73    /// Unique identifier matching the original request.
74    pub request_id: RequestId,
75    /// Error code.
76    pub error_code: ErrorCode,
77    /// Human-readable error message.
78    pub message: String,
79}
80
81impl Write for ErrorResponse {
82    fn write(&self, buf: &mut impl BufMut) {
83        self.request_id.write(buf);
84        self.error_code.write(buf);
85        self.message.as_bytes().to_vec().write(buf);
86    }
87}
88
89impl EncodeSize for ErrorResponse {
90    fn encode_size(&self) -> usize {
91        self.request_id.encode_size()
92            + self.error_code.encode_size()
93            + self.message.as_bytes().to_vec().encode_size()
94    }
95}
96
97impl Read for ErrorResponse {
98    type Cfg = ();
99
100    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
101        let request_id = RequestId::read_cfg(buf, &())?;
102        let error_code = ErrorCode::read(buf)?;
103        let message_bytes = Vec::<u8>::read_range(buf, 0..=MAX_MESSAGE_SIZE)?;
104        let message = String::from_utf8(message_bytes)
105            .map_err(|_| Error::Invalid("ErrorResponse", "invalid UTF-8 in message"))?;
106        Ok(Self {
107            request_id,
108            error_code,
109            message,
110        })
111    }
112}
113
#[cfg(test)]
mod tests {
    use crate::net::{request_id::Generator, wire::GetOperationsRequest, ErrorCode};
    use commonware_codec::{DecodeExt as _, Encode as _};
    use commonware_storage::mmr::Location;
    use commonware_utils::NZU64;

    /// Every `ErrorCode` variant must decode back to the same variant.
    #[test]
    fn test_error_code_roundtrip_serialization() {
        let variants = [
            ErrorCode::InvalidRequest,
            ErrorCode::DatabaseError,
            ErrorCode::NetworkError,
            ErrorCode::Timeout,
            ErrorCode::InternalError,
        ];

        for error_code in variants {
            let encoded = error_code.encode().to_vec();
            let decoded = ErrorCode::decode(&encoded[..]).expect("Failed to decode ErrorCode");

            // Compare by discriminant so the check does not rely on an
            // equality operator being available for `ErrorCode`.
            assert!(
                std::mem::discriminant(&error_code) == std::mem::discriminant(&decoded),
                "ErrorCode roundtrip failed: {error_code:?} != {decoded:?}"
            );
        }
    }

    /// `validate` accepts a `start_loc` below `op_count` and rejects a
    /// `start_loc` that is equal to or greater than `op_count`.
    #[test]
    fn test_get_operations_request_validation() {
        let requester = Generator::new();

        // start_loc (10) below op_count (100): accepted.
        let in_range = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100).unwrap(),
            start_loc: Location::new(10).unwrap(),
            max_ops: NZU64!(50),
        };
        assert!(in_range.validate().is_ok());

        // start_loc (100) equal to op_count (100): rejected.
        let at_end = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100).unwrap(),
            start_loc: Location::new(100).unwrap(),
            max_ops: NZU64!(50),
        };
        let at_end_result = at_end.validate();
        assert!(matches!(
            at_end_result,
            Err(crate::Error::InvalidRequest(_))
        ));

        // start_loc (150) beyond op_count (100): rejected.
        let past_end = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100).unwrap(),
            start_loc: Location::new(150).unwrap(),
            max_ops: NZU64!(50),
        };
        let past_end_result = past_end.validate();
        assert!(matches!(
            past_end_result,
            Err(crate::Error::InvalidRequest(_))
        ));
    }
}