Skip to main content

commonware_sync/net/
mod.rs

1use commonware_codec::{DecodeExt, Encode, EncodeSize, Error, Read, ReadExt, ReadRangeExt, Write};
2use commonware_runtime::{Buf, BufMut};
3use std::mem::size_of;
4
/// Maximum message size in bytes: 10 MiB (10 * 2^20 = 10,485,760 bytes).
pub const MAX_MESSAGE_SIZE: u32 = 10 << 20;
7
8pub mod request_id;
9pub use request_id::RequestId;
10pub mod io;
11pub mod resolver;
12pub mod wire;
13pub use resolver::Resolver;
14
15/// A message that can be sent over the wire.
16pub(super) trait Message: Encode + DecodeExt<()> + Sized + Send + Sync + 'static {
17    fn request_id(&self) -> RequestId;
18}
19
/// Error codes for protocol errors.
///
/// Each code is encoded on the wire as a single `u8` discriminant:
/// `InvalidRequest` = 0, `DatabaseError` = 1, `NetworkError` = 2, `StaleTarget` = 3,
/// `Timeout` = 4, `InternalError` = 5. Decoding any other byte fails with
/// `Error::InvalidEnum`.
///
/// `PartialEq`/`Eq` are derived so codes can be compared directly (e.g. in `assert_eq!`)
/// instead of via exhaustive pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ErrorCode {
    /// Invalid request parameters.
    InvalidRequest,
    /// Database error occurred.
    DatabaseError,
    /// Network error occurred.
    NetworkError,
    /// Compact target went stale and should be retried.
    StaleTarget,
    /// Request timeout.
    Timeout,
    /// Internal server error.
    InternalError,
}
36
37impl Write for ErrorCode {
38    fn write(&self, buf: &mut impl BufMut) {
39        let discriminant = match self {
40            Self::InvalidRequest => 0u8,
41            Self::DatabaseError => 1u8,
42            Self::NetworkError => 2u8,
43            Self::StaleTarget => 3u8,
44            Self::Timeout => 4u8,
45            Self::InternalError => 5u8,
46        };
47        discriminant.write(buf);
48    }
49}
50
51impl EncodeSize for ErrorCode {
52    fn encode_size(&self) -> usize {
53        size_of::<u8>()
54    }
55}
56
57impl Read for ErrorCode {
58    type Cfg = ();
59
60    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
61        let discriminant = u8::read(buf)?;
62        match discriminant {
63            0 => Ok(Self::InvalidRequest),
64            1 => Ok(Self::DatabaseError),
65            2 => Ok(Self::NetworkError),
66            3 => Ok(Self::StaleTarget),
67            4 => Ok(Self::Timeout),
68            5 => Ok(Self::InternalError),
69            _ => Err(Error::InvalidEnum(discriminant)),
70        }
71    }
72}
73
74/// Error from the server.
75#[derive(Debug, Clone)]
76pub struct ErrorResponse {
77    /// Unique identifier matching the original request.
78    pub request_id: RequestId,
79    /// Error code.
80    pub error_code: ErrorCode,
81    /// Human-readable error message.
82    pub message: String,
83}
84
85impl Write for ErrorResponse {
86    fn write(&self, buf: &mut impl BufMut) {
87        self.request_id.write(buf);
88        self.error_code.write(buf);
89        self.message.as_bytes().to_vec().write(buf);
90    }
91}
92
93impl EncodeSize for ErrorResponse {
94    fn encode_size(&self) -> usize {
95        self.request_id.encode_size()
96            + self.error_code.encode_size()
97            + self.message.as_bytes().to_vec().encode_size()
98    }
99}
100
101impl Read for ErrorResponse {
102    type Cfg = ();
103
104    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
105        let request_id = RequestId::read_cfg(buf, &())?;
106        let error_code = ErrorCode::read(buf)?;
107        let message_bytes = Vec::<u8>::read_range(buf, 0..=MAX_MESSAGE_SIZE as usize)?;
108        let message = String::from_utf8(message_bytes)
109            .map_err(|_| Error::Invalid("ErrorResponse", "invalid UTF-8 in message"))?;
110        Ok(Self {
111            request_id,
112            error_code,
113            message,
114        })
115    }
116}
117
#[cfg(test)]
mod tests {
    use crate::{
        keyless_compact,
        net::{
            request_id::Generator, wire, wire::GetOperationsRequest, ErrorCode, ErrorResponse,
        },
    };
    use commonware_codec::{DecodeExt as _, Encode as _, EncodeSize as _, Error as CodecError};
    use commonware_cryptography::sha256;
    use commonware_storage::{mmr::Location, qmdb::sync::compact::State};
    use commonware_utils::NZU64;
    use rstest::rstest;
    use std::mem::discriminant;

    /// Every `ErrorCode` variant survives an encode/decode roundtrip as a single byte.
    #[rstest]
    #[case(ErrorCode::InvalidRequest)]
    #[case(ErrorCode::DatabaseError)]
    #[case(ErrorCode::NetworkError)]
    #[case(ErrorCode::StaleTarget)]
    #[case(ErrorCode::Timeout)]
    #[case(ErrorCode::InternalError)]
    fn test_error_code_roundtrip_serialization(#[case] error_code: ErrorCode) {
        let encoded = error_code.encode().to_vec();
        assert_eq!(encoded.len(), 1, "ErrorCode must encode to exactly one byte");

        let decoded = ErrorCode::decode(&encoded[..]).expect("Failed to decode ErrorCode");

        // Compare variants structurally; independent of which traits `ErrorCode` derives.
        assert_eq!(
            discriminant(&error_code),
            discriminant(&decoded),
            "ErrorCode roundtrip failed: {error_code:?} != {decoded:?}"
        );
    }

    /// Bytes outside the known discriminant range are rejected, not silently mapped.
    #[rstest]
    #[case(6)]
    #[case(u8::MAX)]
    fn test_error_code_rejects_unknown_discriminant(#[case] tag: u8) {
        let result = ErrorCode::decode(&[tag][..]);
        assert!(
            matches!(result, Err(CodecError::InvalidEnum(got)) if got == tag),
            "unexpected decode result for tag {tag}: {result:?}"
        );
    }

    /// `ErrorResponse` roundtrips all fields, including empty and non-ASCII messages,
    /// and its `encode_size` agrees with the bytes actually written.
    #[rstest]
    #[case(ErrorCode::StaleTarget, "")]
    #[case(ErrorCode::InvalidRequest, "start_loc out of range")]
    #[case(ErrorCode::InternalError, "non-ASCII message: héllo ✓")]
    fn test_error_response_roundtrip(#[case] error_code: ErrorCode, #[case] message: &str) {
        let response = ErrorResponse {
            request_id: Generator::new().next(),
            error_code,
            message: message.to_string(),
        };

        let encoded = response.encode().to_vec();
        assert_eq!(encoded.len(), response.encode_size());

        let decoded =
            ErrorResponse::decode(&encoded[..]).expect("failed to decode ErrorResponse");
        assert_eq!(decoded.request_id, response.request_id);
        assert_eq!(
            discriminant(&decoded.error_code),
            discriminant(&response.error_code)
        );
        assert_eq!(decoded.message, response.message);
    }

    /// A message payload that is not valid UTF-8 is rejected with `Error::Invalid`.
    #[test]
    fn test_error_response_rejects_invalid_utf8() {
        let response = ErrorResponse {
            request_id: Generator::new().next(),
            error_code: ErrorCode::InternalError,
            message: "x".to_string(),
        };
        let mut encoded = response.encode().to_vec();

        // The message bytes are written last, so overwriting the final byte corrupts
        // the one-byte message payload without touching its length prefix.
        *encoded.last_mut().expect("encoding is non-empty") = 0xFF;

        assert!(matches!(
            ErrorResponse::decode(&encoded[..]),
            Err(CodecError::Invalid("ErrorResponse", _))
        ));
    }

    #[test]
    fn test_get_operations_request_validation() {
        // Valid request
        let requester = Generator::new();
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100),
            start_loc: Location::new(10),
            max_ops: NZU64!(50),
            include_pinned_nodes: false,
        };
        assert!(request.validate().is_ok());

        // start_loc equal to op_count is rejected
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100),
            start_loc: Location::new(100),
            max_ops: NZU64!(50),
            include_pinned_nodes: false,
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));

        // start_loc greater than op_count is rejected
        let request = GetOperationsRequest {
            request_id: requester.next(),
            op_count: Location::new(100),
            start_loc: Location::new(150),
            max_ops: NZU64!(50),
            include_pinned_nodes: false,
        };
        assert!(matches!(
            request.validate(),
            Err(crate::Error::InvalidRequest(_))
        ));
    }

    #[test]
    fn test_get_compact_state_response_roundtrip() {
        let request_id = Generator::new().next();
        let digest_a = sha256::Digest::from([7; 32]);
        let digest_b = sha256::Digest::from([8; 32]);
        let digest_c = sha256::Digest::from([10; 32]);
        let message = wire::Message::GetCompactStateResponse(wire::GetCompactStateResponse {
            request_id,
            state: State {
                leaf_count: Location::new(11),
                pinned_nodes: vec![digest_a, digest_b],
                last_commit_op: keyless_compact::Operation::Commit(None, Location::new(0)),
                last_commit_proof: commonware_storage::mmr::Proof {
                    leaves: Location::new(11),
                    inactive_peaks: 0,
                    digests: vec![digest_c],
                },
            },
        });

        let encoded = message.encode().to_vec();
        let decoded = wire::Message::<keyless_compact::Operation, sha256::Digest>::decode(
            &encoded[..],
        )
        .expect("failed to decode compact response");

        match decoded {
            wire::Message::GetCompactStateResponse(response) => {
                assert_eq!(response.request_id, request_id);
                assert_eq!(response.state.leaf_count, Location::new(11));
                assert_eq!(response.state.pinned_nodes.len(), 2);
            }
            other => panic!("unexpected message variant: {other:?}"),
        }
    }
}