// aranya_runtime/sync/requester.rs

1use alloc::vec;
2
3use aranya_crypto::Csprng;
4use buggy::BugExt as _;
5use heapless::Vec;
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7
8use super::{
9    COMMAND_RESPONSE_MAX, COMMAND_SAMPLE_MAX, PEER_HEAD_MAX, PeerCache, REQUEST_MISSING_MAX,
10    SyncCommand, SyncError, dispatcher::SyncType, responder::SyncResponseMessage,
11};
12use crate::{
13    Address, Command as _, GraphId, Location,
14    storage::{Segment as _, Storage as _, StorageError, StorageProvider},
15};
16
17// TODO: Use compile-time args. This initial definition results in this clippy warning:
18// https://rust-lang.github.io/rust-clippy/master/index.html#large_enum_variant.
19// As the buffer consts will be compile-time variables in the future, we will be
20// able to tune these buffers for smaller footprints. Right now, this enum is not
21// suitable for small devices (`SyncRequest` is 4080 bytes).
22/// Messages sent from the requester to the responder.
23#[derive(Serialize, Deserialize, Debug)]
24#[allow(clippy::large_enum_variant)]
25pub enum SyncRequestMessage {
26    /// Initiate a new Sync
27    SyncRequest {
28        /// A new random value produced by a cryptographically secure RNG.
29        session_id: u128,
30        /// Specifies the graph to be synced.
31        storage_id: GraphId,
32        /// Specifies the maximum number of bytes worth of commands that
33        /// the requester wishes to receive.
34        max_bytes: u64,
35        /// Sample of the commands held by the requester. The responder should
36        /// respond with any commands that the requester may not have based on
37        /// the provided sample. When sending commands ancestors must be sent
38        /// before descendents.
39        commands: Vec<Address, COMMAND_SAMPLE_MAX>,
40    },
41
42    /// Sent by the requester if it deduces a `SyncResponse` message has been
43    /// dropped.
44    RequestMissing {
45        /// A random-value produced by a cryptographically secure RNG
46        /// corresponding to the `session_id` in the initial `SyncRequest`.
47        session_id: u128,
48        /// `SyncResponse` indexes that the requester has not received.
49        indexes: Vec<u64, REQUEST_MISSING_MAX>,
50    },
51
52    /// Message to request the responder resumes sending `SyncResponse`s
53    /// following the specified message. This may be sent after a requester
54    /// timeout or after a `SyncEnd` has been sent.
55    SyncResume {
56        /// A random-value produced by a cryptographically secure RNG
57        /// corresponding to the `session_id` in the initial `SyncRequest`.
58        session_id: u128,
59        /// Indicates the last response message the requester received.
60        response_index: u64,
61        /// Updates the maximum number of bytes worth of commands that
62        /// the requester wishes to receive.
63        max_bytes: u64,
64    },
65
66    /// Message sent by either requester or responder to indicate the session
67    /// has been terminated or the `session_id` is no longer valid.
68    EndSession { session_id: u128 },
69}
70
71impl SyncRequestMessage {
72    pub fn session_id(&self) -> u128 {
73        match self {
74            Self::SyncRequest { session_id, .. } => *session_id,
75            Self::RequestMissing { session_id, .. } => *session_id,
76            Self::SyncResume { session_id, .. } => *session_id,
77            Self::EndSession { session_id, .. } => *session_id,
78        }
79    }
80}
81
/// Position of a [`SyncRequester`] in its session state machine.
#[derive(Clone, Debug, PartialEq, Eq)]
enum SyncRequesterState {
    /// Just constructed; nothing has been sent. The next `poll` writes the
    /// initial `SyncRequest`.
    New,
    /// The initial `SyncRequest` was written but no response has been
    /// handled yet.
    Start,
    /// At least one response has been handled (or a resume was sent) and the
    /// session is still in progress.
    Waiting,
    /// Sync is complete, but the byte budget has not yet been reached. An
    /// `Offer` from the responder is only accepted in this state.
    Idle,
    /// The session is over; no further messages should be processed.
    Closed,
    /// A response arrived out of order (or an `Offer` was accepted); the next
    /// `poll` writes a `SyncResume`.
    Resync,
    /// A `SyncEnd` was received, but the requester may not have received
    /// everything yet.
    PartialSync,
    /// The session is being torn down; the next `poll` writes an
    /// `EndSession` and moves to `Closed`.
    Reset,
}
101
102pub struct SyncRequester<A> {
103    session_id: u128,
104    storage_id: GraphId,
105    state: SyncRequesterState,
106    max_bytes: u64,
107    next_message_index: u64,
108    server_address: A,
109}
110
111impl<A: DeserializeOwned + Serialize + Clone> SyncRequester<A> {
112    /// Create a new [`SyncRequester`] with a random session ID.
113    pub fn new<R: Csprng>(storage_id: GraphId, rng: &mut R, server_address: A) -> Self {
114        // Randomly generate session id.
115        let mut dst = [0u8; 16];
116        rng.fill_bytes(&mut dst);
117        let session_id = u128::from_le_bytes(dst);
118
119        Self {
120            session_id,
121            storage_id,
122            state: SyncRequesterState::New,
123            max_bytes: 0,
124            next_message_index: 0,
125            server_address,
126        }
127    }
128
129    /// Create a new [`SyncRequester`] for an existing session.
130    pub fn new_session_id(storage_id: GraphId, session_id: u128, server_address: A) -> Self {
131        Self {
132            session_id,
133            storage_id,
134            state: SyncRequesterState::Waiting,
135            max_bytes: 0,
136            next_message_index: 0,
137            server_address,
138        }
139    }
140
141    /// Returns the server address.
142    pub fn server_addr(&self) -> A {
143        self.server_address.clone()
144    }
145
146    /// Returns true if [`Self::poll`] would produce a message.
147    pub fn ready(&self) -> bool {
148        use SyncRequesterState as S;
149        match self.state {
150            S::New | S::Resync | S::Reset => true,
151            S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => false,
152        }
153    }
154
155    /// Write a sync message in to the target buffer. Returns the number
156    /// of bytes written and the number of commands sent in the sample.
157    pub fn poll(
158        &mut self,
159        target: &mut [u8],
160        provider: &mut impl StorageProvider,
161        heads: &mut PeerCache,
162    ) -> Result<(usize, usize), SyncError> {
163        use SyncRequesterState as S;
164        let result = match self.state {
165            S::Start | S::Waiting | S::Idle | S::Closed | S::PartialSync => {
166                return Err(SyncError::NotReady);
167            }
168            S::New => {
169                self.state = S::Start;
170                self.start(self.max_bytes, target, provider, heads)?
171            }
172            S::Resync => self.resume(self.max_bytes, target)?,
173            S::Reset => {
174                self.state = S::Closed;
175                self.end_session(target)?
176            }
177        };
178
179        Ok(result)
180    }
181
182    /// Receive a sync message. Returns parsed sync commands.
183    pub fn receive<'a>(
184        &mut self,
185        data: &'a [u8],
186    ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
187        let (message, remaining): (SyncResponseMessage, &'a [u8]) =
188            postcard::take_from_bytes(data)?;
189
190        self.get_sync_commands(message, remaining)
191    }
192
193    /// Extract SyncCommands from a SyncResponseMessage and remaining bytes.
194    pub fn get_sync_commands<'a>(
195        &mut self,
196        message: SyncResponseMessage,
197        remaining: &'a [u8],
198    ) -> Result<Option<Vec<SyncCommand<'a>, COMMAND_RESPONSE_MAX>>, SyncError> {
199        if message.session_id() != self.session_id {
200            return Err(SyncError::SessionMismatch);
201        }
202
203        let result = match message {
204            SyncResponseMessage::SyncResponse {
205                response_index,
206                commands,
207                ..
208            } => {
209                if !matches!(
210                    self.state,
211                    SyncRequesterState::Start | SyncRequesterState::Waiting
212                ) {
213                    return Err(SyncError::SessionState);
214                }
215
216                if response_index != self.next_message_index {
217                    self.state = SyncRequesterState::Resync;
218                    return Err(SyncError::MissingSyncResponse);
219                }
220                self.next_message_index = self
221                    .next_message_index
222                    .checked_add(1)
223                    .assume("next_message_index + 1 mustn't overflow")?;
224                self.state = SyncRequesterState::Waiting;
225
226                let mut result = Vec::new();
227                let mut start: usize = 0;
228                for meta in commands {
229                    let policy_len = meta.policy_length as usize;
230
231                    let policy = match policy_len == 0 {
232                        true => None,
233                        false => {
234                            let end = start
235                                .checked_add(policy_len)
236                                .assume("start + policy_len mustn't overflow")?;
237                            let policy = &remaining[start..end];
238                            start = end;
239                            Some(policy)
240                        }
241                    };
242
243                    let len = meta.length as usize;
244                    let end = start
245                        .checked_add(len)
246                        .assume("start + len mustn't overflow")?;
247                    let payload = &remaining[start..end];
248                    start = end;
249
250                    let command = SyncCommand {
251                        id: meta.id,
252                        priority: meta.priority,
253                        parent: meta.parent,
254                        policy,
255                        data: payload,
256                        max_cut: meta.max_cut,
257                    };
258
259                    result
260                        .push(command)
261                        .ok()
262                        .assume("commands is not larger than result")?;
263                }
264
265                Some(result)
266            }
267
268            SyncResponseMessage::SyncEnd { max_index, .. } => {
269                if !matches!(
270                    self.state,
271                    SyncRequesterState::Start | SyncRequesterState::Waiting
272                ) {
273                    return Err(SyncError::SessionState);
274                }
275
276                if max_index != self.next_message_index {
277                    self.state = SyncRequesterState::Resync;
278                    return Err(SyncError::MissingSyncResponse);
279                }
280
281                self.state = SyncRequesterState::PartialSync;
282
283                None
284            }
285
286            SyncResponseMessage::Offer { .. } => {
287                if self.state != SyncRequesterState::Idle {
288                    return Err(SyncError::SessionState);
289                }
290                self.state = SyncRequesterState::Resync;
291
292                None
293            }
294
295            SyncResponseMessage::EndSession { .. } => {
296                self.state = SyncRequesterState::Closed;
297                None
298            }
299        };
300
301        Ok(result)
302    }
303
304    fn write(target: &mut [u8], msg: SyncType<A>) -> Result<usize, SyncError> {
305        Ok(postcard::to_slice(&msg, target)?.len())
306    }
307
308    fn end_session(&mut self, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
309        Ok((
310            Self::write(
311                target,
312                SyncType::Poll {
313                    request: SyncRequestMessage::EndSession {
314                        session_id: self.session_id,
315                    },
316                    address: self.server_address.clone(),
317                },
318            )?,
319            0,
320        ))
321    }
322
323    fn resume(&mut self, max_bytes: u64, target: &mut [u8]) -> Result<(usize, usize), SyncError> {
324        if !matches!(
325            self.state,
326            SyncRequesterState::Resync | SyncRequesterState::Idle
327        ) {
328            return Err(SyncError::SessionState);
329        }
330
331        self.state = SyncRequesterState::Waiting;
332        let message = SyncType::Poll {
333            request: SyncRequestMessage::SyncResume {
334                session_id: self.session_id,
335                response_index: self
336                    .next_message_index
337                    .checked_sub(1)
338                    .assume("next_index must be positive")?,
339                max_bytes,
340            },
341            address: self.server_address.clone(),
342        };
343
344        Ok((Self::write(target, message)?, 0))
345    }
346
347    fn get_commands(
348        &self,
349        provider: &mut impl StorageProvider,
350        peer_cache: &mut PeerCache,
351    ) -> Result<Vec<Address, COMMAND_SAMPLE_MAX>, SyncError> {
352        let mut commands: Vec<Address, COMMAND_SAMPLE_MAX> = Vec::new();
353
354        match provider.get_storage(self.storage_id) {
355            Err(StorageError::NoSuchStorage) => (),
356            Err(err) => {
357                return Err(SyncError::Storage(err));
358            }
359            Ok(storage) => {
360                let mut command_locations: Vec<Location, PEER_HEAD_MAX> = Vec::new();
361                for address in peer_cache.heads() {
362                    command_locations
363                        .push(
364                            storage
365                                .get_location(*address)?
366                                .assume("location must exist")?,
367                        )
368                        .ok()
369                        .assume("command locations should not be full")?;
370                    if commands.len() < COMMAND_SAMPLE_MAX {
371                        commands
372                            .push(*address)
373                            .map_err(|_| SyncError::CommandOverflow)?;
374                    }
375                }
376                let head = storage.get_head()?;
377
378                let mut current = vec![head];
379
380                // Here we just get the first command from the most reaseant
381                // COMMAND_SAMPLE_MAX segments in the graph. This is probbly
382                // not the best strategy as if you are far enough ahead of
383                // the other client they will just send you everything they have.
384                while commands.len() < COMMAND_SAMPLE_MAX && !current.is_empty() {
385                    let mut next = vec::Vec::new(); //BUG not constant memory
386
387                    'current: for &location in &current {
388                        let segment = storage.get_segment(location)?;
389
390                        let head = segment.head()?;
391                        let head_address = head.address()?;
392                        for loc in &command_locations {
393                            if loc.segment == location.segment {
394                                continue 'current;
395                            }
396                        }
397                        // TODO(chip): check that this is not an ancestor of a head in the PeerCache
398                        commands
399                            .push(head_address)
400                            .map_err(|_| SyncError::CommandOverflow)?;
401                        next.extend(segment.prior());
402                        if commands.len() >= COMMAND_SAMPLE_MAX {
403                            break 'current;
404                        }
405                    }
406
407                    current.clone_from(&next);
408                }
409            }
410        }
411        Ok(commands)
412    }
413
414    /// Writes a Subscribe message to target.
415    pub fn subscribe(
416        &mut self,
417        target: &mut [u8],
418        provider: &mut impl StorageProvider,
419        heads: &mut PeerCache,
420        remain_open: u64,
421        max_bytes: u64,
422    ) -> Result<usize, SyncError> {
423        let commands = self.get_commands(provider, heads)?;
424        let message = SyncType::Subscribe {
425            remain_open,
426            max_bytes,
427            commands,
428            address: self.server_address.clone(),
429            storage_id: self.storage_id,
430        };
431
432        Self::write(target, message)
433    }
434
435    /// Writes an Unsubscribe message to target.
436    pub fn unsubscribe(&mut self, target: &mut [u8]) -> Result<usize, SyncError> {
437        let message = SyncType::Unsubscribe {
438            address: self.server_address.clone(),
439        };
440
441        Self::write(target, message)
442    }
443
444    fn start(
445        &mut self,
446        max_bytes: u64,
447        target: &mut [u8],
448        provider: &mut impl StorageProvider,
449        heads: &mut PeerCache,
450    ) -> Result<(usize, usize), SyncError> {
451        if !matches!(
452            self.state,
453            SyncRequesterState::Start | SyncRequesterState::New
454        ) {
455            self.state = SyncRequesterState::Reset;
456            return Err(SyncError::SessionState);
457        }
458
459        self.state = SyncRequesterState::Start;
460        self.max_bytes = max_bytes;
461
462        let command_sample = self.get_commands(provider, heads)?;
463
464        let sent = command_sample.len();
465        let message = SyncType::Poll {
466            request: SyncRequestMessage::SyncRequest {
467                session_id: self.session_id,
468                storage_id: self.storage_id,
469                max_bytes,
470                commands: command_sample,
471            },
472            address: self.server_address.clone(),
473        };
474
475        Ok((Self::write(target, message)?, sent))
476    }
477}