Skip to main content

commonware_sync/net/
wire.rs

1use crate::net::{ErrorResponse, RequestId};
2use commonware_codec::{
3    Encode, EncodeSize, Error as CodecError, IsUnit, RangeCfg, Read, ReadExt as _, Write,
4};
5use commonware_cryptography::Digest;
6use commonware_runtime::{Buf, BufMut};
7use commonware_storage::{
8    mmr::{Location, Proof},
9    qmdb::sync::Target,
10};
11use std::num::NonZeroU64;
12
/// Upper bound on the number of digests accepted in a decoded proof.
///
/// Within this file the same value is also reused as the upper bound on the
/// number of operations accepted when decoding a `GetOperationsResponse`.
pub const MAX_DIGESTS: usize = 10_000;
15
/// Upper bound on the number of pinned nodes accepted in a decoded response.
///
/// There is one pinned node per MMR peak, and the number of peaks is bounded
/// by the maximum tree height.
pub const MAX_PINNED_NODES: usize = 64;
18
19/// Request for operations from the server.
20#[derive(Debug)]
21pub struct GetOperationsRequest {
22    pub request_id: RequestId,
23    pub op_count: Location,
24    pub start_loc: Location,
25    pub max_ops: NonZeroU64,
26    pub include_pinned_nodes: bool,
27}
28
29/// Response with operations and proof.
30#[derive(Debug)]
31pub struct GetOperationsResponse<Op, D>
32where
33    D: Digest,
34{
35    pub request_id: RequestId,
36    pub proof: Proof<D>,
37    pub operations: Vec<Op>,
38    pub pinned_nodes: Option<Vec<D>>,
39}
40
41/// Request for sync target from server.
42#[derive(Debug)]
43pub struct GetSyncTargetRequest {
44    pub request_id: RequestId,
45}
46
47/// Response with sync target.
48#[derive(Debug)]
49pub struct GetSyncTargetResponse<D>
50where
51    D: Digest,
52{
53    pub request_id: RequestId,
54    pub target: Target<D>,
55}
56
57/// Messages that can be sent over the wire.
58#[derive(Debug)]
59pub enum Message<Op, D>
60where
61    D: Digest,
62{
63    GetOperationsRequest(GetOperationsRequest),
64    GetOperationsResponse(GetOperationsResponse<Op, D>),
65    GetSyncTargetRequest(GetSyncTargetRequest),
66    GetSyncTargetResponse(GetSyncTargetResponse<D>),
67    Error(ErrorResponse),
68}
69
70impl<Op, D> Message<Op, D>
71where
72    D: Digest,
73{
74    pub const fn request_id(&self) -> RequestId {
75        match self {
76            Self::GetOperationsRequest(r) => r.request_id,
77            Self::GetOperationsResponse(r) => r.request_id,
78            Self::GetSyncTargetRequest(r) => r.request_id,
79            Self::GetSyncTargetResponse(r) => r.request_id,
80            Self::Error(e) => e.request_id,
81        }
82    }
83}
84
85impl<Op, D> super::Message for Message<Op, D>
86where
87    Op: Encode + Read + Send + Sync + 'static,
88    Op::Cfg: IsUnit,
89    D: Digest,
90{
91    fn request_id(&self) -> RequestId {
92        self.request_id()
93    }
94}
95
96impl<Op, D> Write for Message<Op, D>
97where
98    Op: Write,
99    D: Digest,
100{
101    fn write(&self, buf: &mut impl BufMut) {
102        match self {
103            Self::GetOperationsRequest(req) => {
104                0u8.write(buf);
105                req.write(buf);
106            }
107            Self::GetOperationsResponse(resp) => {
108                1u8.write(buf);
109                resp.write(buf);
110            }
111            Self::GetSyncTargetRequest(req) => {
112                2u8.write(buf);
113                req.write(buf);
114            }
115            Self::GetSyncTargetResponse(resp) => {
116                3u8.write(buf);
117                resp.write(buf);
118            }
119            Self::Error(err) => {
120                4u8.write(buf);
121                err.write(buf);
122            }
123        }
124    }
125}
126
127impl<Op, D> EncodeSize for Message<Op, D>
128where
129    Op: EncodeSize,
130    D: Digest,
131{
132    fn encode_size(&self) -> usize {
133        1 + match self {
134            Self::GetOperationsRequest(req) => req.encode_size(),
135            Self::GetOperationsResponse(resp) => resp.encode_size(),
136            Self::GetSyncTargetRequest(req) => req.encode_size(),
137            Self::GetSyncTargetResponse(resp) => resp.encode_size(),
138            Self::Error(err) => err.encode_size(),
139        }
140    }
141}
142
143impl<Op, D> Read for Message<Op, D>
144where
145    Op: Read,
146    Op::Cfg: IsUnit,
147    D: Digest,
148{
149    type Cfg = ();
150    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
151        let tag = u8::read(buf)?;
152        match tag {
153            0 => Ok(Self::GetOperationsRequest(GetOperationsRequest::read(buf)?)),
154            1 => Ok(Self::GetOperationsResponse(GetOperationsResponse::read(
155                buf,
156            )?)),
157            2 => Ok(Self::GetSyncTargetRequest(GetSyncTargetRequest::read(buf)?)),
158            3 => Ok(Self::GetSyncTargetResponse(GetSyncTargetResponse::read(
159                buf,
160            )?)),
161            4 => Ok(Self::Error(ErrorResponse::read(buf)?)),
162            d => Err(CodecError::InvalidEnum(d)),
163        }
164    }
165}
166
167impl Write for GetOperationsRequest {
168    fn write(&self, buf: &mut impl BufMut) {
169        self.request_id.write(buf);
170        self.op_count.write(buf);
171        self.start_loc.write(buf);
172        self.max_ops.get().write(buf);
173        (self.include_pinned_nodes as u8).write(buf);
174    }
175}
176
177impl EncodeSize for GetOperationsRequest {
178    fn encode_size(&self) -> usize {
179        self.request_id.encode_size()
180            + self.op_count.encode_size()
181            + self.start_loc.encode_size()
182            + self.max_ops.get().encode_size()
183            + 1u8.encode_size()
184    }
185}
186
187impl Read for GetOperationsRequest {
188    type Cfg = ();
189    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
190        let request_id = RequestId::read_cfg(buf, &())?;
191        let op_count = Location::read(buf)?;
192        let start_loc = Location::read(buf)?;
193        let max_ops = u64::read(buf)?;
194        let Some(max_ops) = NonZeroU64::new(max_ops) else {
195            return Err(CodecError::Invalid(
196                "GetOperationsRequest",
197                "max_ops cannot be zero",
198            ));
199        };
200        let include_pinned_nodes = u8::read(buf)? != 0;
201        Ok(Self {
202            request_id,
203            op_count,
204            start_loc,
205            max_ops,
206            include_pinned_nodes,
207        })
208    }
209}
210
211impl GetOperationsRequest {
212    pub fn validate(&self) -> Result<(), crate::Error> {
213        if self.start_loc >= self.op_count {
214            return Err(crate::Error::InvalidRequest(format!(
215                "start_loc >= size ({}) >= ({})",
216                self.start_loc, self.op_count
217            )));
218        }
219        Ok(())
220    }
221}
222
223impl<Op, D> Write for GetOperationsResponse<Op, D>
224where
225    Op: Write,
226    D: Digest,
227{
228    fn write(&self, buf: &mut impl BufMut) {
229        self.request_id.write(buf);
230        self.proof.write(buf);
231        self.operations.write(buf);
232        match &self.pinned_nodes {
233            Some(nodes) => {
234                1u8.write(buf);
235                nodes.write(buf);
236            }
237            None => {
238                0u8.write(buf);
239            }
240        }
241    }
242}
243
244impl<Op, D> EncodeSize for GetOperationsResponse<Op, D>
245where
246    Op: EncodeSize,
247    D: Digest,
248{
249    fn encode_size(&self) -> usize {
250        self.request_id.encode_size()
251            + self.proof.encode_size()
252            + self.operations.encode_size()
253            + 1u8.encode_size()
254            + self
255                .pinned_nodes
256                .as_ref()
257                .map_or(0, |nodes| nodes.encode_size())
258    }
259}
260
261impl<Op, D> Read for GetOperationsResponse<Op, D>
262where
263    Op: Read,
264    Op::Cfg: IsUnit,
265    D: Digest,
266{
267    type Cfg = ();
268    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
269        let request_id = RequestId::read_cfg(buf, &())?;
270        let proof = Proof::<D>::read_cfg(buf, &MAX_DIGESTS)?;
271        let operations = {
272            let range_cfg = RangeCfg::from(0..=MAX_DIGESTS);
273            Vec::<Op>::read_cfg(buf, &(range_cfg, Op::Cfg::default()))?
274        };
275        let has_pinned_nodes = u8::read(buf)? != 0;
276        let pinned_nodes = if has_pinned_nodes {
277            let range_cfg = RangeCfg::from(0..=MAX_PINNED_NODES);
278            Some(Vec::<D>::read_cfg(buf, &(range_cfg, ()))?)
279        } else {
280            None
281        };
282        Ok(Self {
283            request_id,
284            proof,
285            operations,
286            pinned_nodes,
287        })
288    }
289}
290
291impl Write for GetSyncTargetRequest {
292    fn write(&self, buf: &mut impl BufMut) {
293        self.request_id.write(buf);
294    }
295}
296
297impl EncodeSize for GetSyncTargetRequest {
298    fn encode_size(&self) -> usize {
299        self.request_id.encode_size()
300    }
301}
302
303impl Read for GetSyncTargetRequest {
304    type Cfg = ();
305    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
306        let request_id = RequestId::read_cfg(buf, &())?;
307        Ok(Self { request_id })
308    }
309}
310
311impl<D> Write for GetSyncTargetResponse<D>
312where
313    D: Digest,
314{
315    fn write(&self, buf: &mut impl BufMut) {
316        self.request_id.write(buf);
317        self.target.write(buf);
318    }
319}
320
321impl<D> EncodeSize for GetSyncTargetResponse<D>
322where
323    D: Digest,
324{
325    fn encode_size(&self) -> usize {
326        self.request_id.encode_size() + self.target.encode_size()
327    }
328}
329
330impl<D> Read for GetSyncTargetResponse<D>
331where
332    D: Digest,
333{
334    type Cfg = ();
335    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, CodecError> {
336        let request_id = RequestId::read_cfg(buf, &())?;
337        let target = Target::<D>::read_cfg(buf, &())?;
338        Ok(Self { request_id, target })
339    }
340}