kitsune2_core/factories/
core_fetch.rs

use back_off::BackOffList;
use kitsune2_api::*;
use message_handler::FetchMessageHandler;
use std::{
    collections::{HashMap, HashSet},
    sync::{Arc, Mutex},
    time::Duration,
};
use tokio::{
    sync::mpsc::{channel, Receiver, Sender},
    task::JoinHandle,
};

mod back_off;
mod message_handler;

#[cfg(test)]
mod test;

/// CoreFetch module name.
pub const MOD_NAME: &str = "Fetch";

/// CoreFetch configuration types.
mod config {
    /// Configuration parameters for [CoreFetchFactory](super::CoreFetchFactory).
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CoreFetchConfig {
        /// How many parallel op fetch requests can be made at once. Default: 2.
        pub parallel_request_count: u8,
        /// Delay before an op request is re-inserted into the outgoing request
        /// queue. Default: 30 s.
        pub re_insert_outgoing_request_delay_ms: u32,
        /// Duration of the first interval to back off an unresponsive peer. Default: 20 s.
        pub first_back_off_interval_ms: u32,
        /// Duration of the last interval to back off an unresponsive peer. Default: 10 min.
        pub last_back_off_interval_ms: u32,
        /// Number of back off intervals. Default: 4.
        pub num_back_off_intervals: usize,
    }

    impl Default for CoreFetchConfig {
        // Maximum back off is 11:40 min.
        fn default() -> Self {
            Self {
                parallel_request_count: 2,
                re_insert_outgoing_request_delay_ms: 30_000,
                first_back_off_interval_ms: 1000 * 20,
                last_back_off_interval_ms: 1000 * 60 * 10,
                num_back_off_intervals: 4,
            }
        }
    }

    /// Module-level configuration for CoreFetch.
    #[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CoreFetchModConfig {
        /// CoreFetch configuration.
        pub core_fetch: CoreFetchConfig,
    }
}
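
// For reference, a sketch of the serialized module config shape (camelCase per
// the serde attributes above), with the default values filled in. Illustrative
// only; the exact embedding depends on how `Config::set_module_config` stores
// module configs:
//
// {
//   "coreFetch": {
//     "parallelRequestCount": 2,
//     "reInsertOutgoingRequestDelayMs": 30000,
//     "firstBackOffIntervalMs": 20000,
//     "lastBackOffIntervalMs": 600000,
//     "numBackOffIntervals": 4
//   }
// }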

pub use config::*;

/// A production-ready fetch module.
#[derive(Debug)]
pub struct CoreFetchFactory {}

impl CoreFetchFactory {
    /// Construct a new CoreFetchFactory.
    pub fn create() -> DynFetchFactory {
        Arc::new(Self {})
    }
}

impl FetchFactory for CoreFetchFactory {
    fn default_config(&self, config: &mut Config) -> K2Result<()> {
        config.set_module_config(&CoreFetchModConfig::default())?;
        Ok(())
    }

    fn validate_config(&self, _config: &Config) -> K2Result<()> {
        Ok(())
    }

    fn create(
        &self,
        builder: Arc<Builder>,
        space_id: SpaceId,
        op_store: DynOpStore,
        transport: DynTransport,
    ) -> BoxFut<'static, K2Result<DynFetch>> {
        Box::pin(async move {
            let config: CoreFetchModConfig =
                builder.config.get_module_config()?;
            let out: DynFetch = Arc::new(CoreFetch::new(
                config.core_fetch,
                space_id,
                op_store,
                transport,
            ));
            Ok(out)
        })
    }
}

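// Message shapes flowing through the internal fetch queues:
// - `OutgoingRequest`: one op id to request, paired with the peer to ask.
// - `IncomingRequest`: a batch of op ids a peer has asked us for, with the
//   requester's URL.
// - `IncomingResponse`: the ops received in answer to one of our requests.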
type OutgoingRequest = (OpId, Url);
type IncomingRequest = (Vec<OpId>, Url);
type IncomingResponse = Vec<Op>;

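/// Shared state of the fetch tasks: the set of outstanding (op id, peer url)
/// requests, per-peer back off tracking, and listeners waiting to be notified
/// once the request set drains.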
#[derive(Debug)]
struct State {
    requests: HashSet<OutgoingRequest>,
    back_off_list: BackOffList,
    notify_when_drained_senders: Vec<futures::channel::oneshot::Sender<()>>,
}

impl State {
    fn summary(&self) -> FetchStateSummary {
        FetchStateSummary {
            pending_requests: self.requests.iter().fold(
                HashMap::new(),
                |mut acc, (op_id, peer_url)| {
                    acc.entry(op_id.clone())
                        .or_default()
                        .push(peer_url.clone());
                    acc
                },
            ),
            peers_on_backoff: self
                .back_off_list
                .state
                .iter()
                .map(|(peer_url, backoff)| {
                    (peer_url.clone(), backoff.current_backoff_expiry())
                })
                .collect(),
        }
    }
}

#[derive(Debug)]
struct CoreFetch {
    state: Arc<Mutex<State>>,
    outgoing_request_tx: Sender<OutgoingRequest>,
    tasks: Vec<JoinHandle<()>>,
    op_store: DynOpStore,
    #[cfg(test)]
    message_handler: DynTxModuleHandler,
}

impl CoreFetch {
    fn new(
        config: CoreFetchConfig,
        space_id: SpaceId,
        op_store: DynOpStore,
        transport: DynTransport,
    ) -> Self {
        Self::spawn_tasks(config, space_id, op_store, transport)
    }
}

impl Fetch for CoreFetch {
    fn request_ops(
        &self,
        op_ids: Vec<OpId>,
        source: Url,
    ) -> BoxFut<'_, K2Result<()>> {
        Box::pin(async move {
            // Filter out requests for ops that are already in the op store.
            let new_op_ids =
                self.op_store.filter_out_existing_ops(op_ids).await?;

            // Add requests to the set.
            {
                let requests = &mut self.state.lock().unwrap().requests;
                requests.extend(
                    new_op_ids
                        .iter()
                        .map(|op_id| (op_id.clone(), source.clone())),
                );
            }
            // Insert requests into the fetch queue.
            for op_id in new_op_ids {
                if let Err(err) =
                    self.outgoing_request_tx.send((op_id, source.clone())).await
                {
                    tracing::warn!(
                        "could not insert fetch request into fetch queue: {err}"
                    );
                }
            }

            Ok(())
        })
    }

    fn notify_on_drained(&self, notify: futures::channel::oneshot::Sender<()>) {
        let mut lock = self.state.lock().expect("poisoned");
        if lock.requests.is_empty() {
            if let Err(err) = notify.send(()) {
                tracing::warn!(?err, "Failed to send notification on drained");
            }
        } else {
            lock.notify_when_drained_senders.push(notify);
        }
    }
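
    // A minimal usage sketch for the drained notification (the oneshot channel
    // type matches the `notify` parameter above; everything else is assumed
    // caller-side context):
    //
    // let (tx, rx) = futures::channel::oneshot::channel();
    // fetch.notify_on_drained(tx);
    // let _ = rx.await; // resolves once the request set is empty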

    fn get_state_summary(&self) -> BoxFut<'_, K2Result<FetchStateSummary>> {
        Box::pin(async move { Ok(self.state.lock().unwrap().summary()) })
    }
}

impl CoreFetch {
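    /// Spawn all background tasks: `parallel_request_count` outgoing request
    /// workers, one task serving peers' incoming requests, and one task
    /// processing incoming responses; then register the transport module
    /// handler that feeds the two incoming queues.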
    pub fn spawn_tasks(
        config: CoreFetchConfig,
        space_id: SpaceId,
        op_store: DynOpStore,
        transport: DynTransport,
    ) -> Self {
        // Create a queue to process outgoing op requests. Requests are sent to peers.
        let (outgoing_request_tx, outgoing_request_rx) =
            channel::<OutgoingRequest>(16_384);
        let outgoing_request_rx =
            Arc::new(tokio::sync::Mutex::new(outgoing_request_rx));

        // Create a queue to process incoming op requests. Requested ops are retrieved from the
        // store and returned to the requester.
        let (incoming_request_tx, incoming_request_rx) =
            channel::<IncomingRequest>(16_384);

        // Create a queue to process incoming op responses. Ops are passed to the op store and op
        // ids are removed from the set of ops to fetch.
        let (incoming_response_tx, incoming_response_rx) =
            channel::<IncomingResponse>(16_384);

        let state = Arc::new(Mutex::new(State {
            requests: HashSet::new(),
            back_off_list: BackOffList::new(
                config.first_back_off_interval_ms,
                config.last_back_off_interval_ms,
                config.num_back_off_intervals,
            ),
            notify_when_drained_senders: vec![],
        }));

        let mut tasks =
            Vec::with_capacity(config.parallel_request_count as usize + 2);
        // Spawn request tasks.
        for _ in 0..config.parallel_request_count {
            let request_task =
                tokio::task::spawn(CoreFetch::outgoing_request_task(
                    state.clone(),
                    outgoing_request_tx.clone(),
                    outgoing_request_rx.clone(),
                    space_id.clone(),
                    Arc::downgrade(&transport),
                    config.re_insert_outgoing_request_delay_ms,
                ));
            tasks.push(request_task);
        }

        // Spawn incoming request task.
        let incoming_request_task =
            tokio::task::spawn(CoreFetch::incoming_request_task(
                incoming_request_rx,
                op_store.clone(),
                Arc::downgrade(&transport),
                space_id.clone(),
            ));
        tasks.push(incoming_request_task);

        // Spawn incoming response task.
        let incoming_response_task =
            tokio::task::spawn(CoreFetch::incoming_response_task(
                incoming_response_rx,
                op_store.clone(),
                state.clone(),
            ));
        tasks.push(incoming_response_task);

        // Register transport module handler for incoming op requests and responses.
        let message_handler = Arc::new(FetchMessageHandler {
            incoming_request_tx,
            incoming_response_tx,
        });
        transport.register_module_handler(
            space_id.clone(),
            MOD_NAME.to_string(),
            message_handler.clone(),
        );

        Self {
            state,
            outgoing_request_tx,
            tasks,
            op_store,
            #[cfg(test)]
            message_handler,
        }
    }
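    /// Long-running worker: takes one (op id, peer url) pair at a time from
    /// the outgoing queue, skips peers currently on back off, sends the fetch
    /// request over the transport, and re-inserts the pair after a delay so
    /// the request is retried until either the op arrives (the incoming
    /// response task removes fulfilled requests from the state) or the peer's
    /// last back off interval has expired.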
    async fn outgoing_request_task(
        state: Arc<Mutex<State>>,
        outgoing_request_tx: Sender<OutgoingRequest>,
        outgoing_request_rx: Arc<tokio::sync::Mutex<Receiver<OutgoingRequest>>>,
        space_id: SpaceId,
        transport: WeakDynTransport,
        re_insert_outgoing_request_delay: u32,
    ) {
        while let Some((op_id, peer_url)) =
            outgoing_request_rx.lock().await.recv().await
        {
            let Some(transport) = transport.upgrade() else {
                tracing::info!(
                    "Transport dropped, stopping outgoing request task"
                );
                break;
            };

            let is_peer_on_back_off = {
                let mut lock = state.lock().unwrap();

                // Do nothing if the op id is no longer in the set of requests to send.
                //
                // Note that because this request is no longer part of the state,
                // it is safe to skip the drained check below.
                if !lock.requests.contains(&(op_id.clone(), peer_url.clone())) {
                    continue;
                }

                lock.back_off_list.is_peer_on_back_off(&peer_url)
            };

            // Send the request if the peer is not on the back off list.
            if !is_peer_on_back_off {
                tracing::debug!(
                    ?peer_url,
                    ?space_id,
                    ?op_id,
                    "sending fetch request"
                );

                // Send fetch request to peer.
                let data = serialize_request_message(vec![op_id.clone()]);
                match transport
                    .send_module(
                        peer_url.clone(),
                        space_id.clone(),
                        MOD_NAME.to_string(),
                        data,
                    )
                    .await
                {
                    Ok(()) => {
                        // If the peer was on the back off list, remove them.
                        state
                            .lock()
                            .unwrap()
                            .back_off_list
                            .remove_peer(&peer_url);
                    }
                    Err(err) => {
                        tracing::warn!(
                            ?op_id,
                            ?peer_url,
                            "could not send fetch request: {err}. Putting peer on back off list."
                        );
                        let mut lock = state.lock().unwrap();
                        lock.back_off_list.back_off_peer(&peer_url);

                        // If the maximum back off interval has expired for the peer,
                        // give up on requesting ops from them.
                        if lock
                            .back_off_list
                            .has_last_back_off_expired(&peer_url)
                        {
                            lock.requests.retain(|(_, url)| *url != peer_url);
                        }
                    }
                }
            }

            // After processing this request, check whether the fetch queue is drained.
            //
            // Note that `continue` statements above bypass this check, so only
            // `continue` when it is safe to do so.
            {
                let mut lock = state.lock().expect("poisoned");
                if lock.requests.is_empty() {
                    // Notify all listeners that the fetch queue is drained.
                    for notify in lock.notify_when_drained_senders.drain(..) {
                        if notify.send(()).is_err() {
                            tracing::warn!(
                                "Failed to send notification on drained"
                            );
                        }
                    }
                }
            }

            // Re-insert the fetch request into the queue after a delay.
            let outgoing_request_tx = outgoing_request_tx.clone();

            tokio::task::spawn({
                let state = state.clone();
                async move {
                    tokio::time::sleep(Duration::from_millis(
                        re_insert_outgoing_request_delay as u64,
                    ))
                    .await;
                    if let Err(err) = outgoing_request_tx
                        .try_send((op_id.clone(), peer_url.clone()))
                    {
                        tracing::warn!(
                            "could not re-insert fetch request for op {op_id} to peer {peer_url} into queue: {err}"
                        );
                        // Remove the op id/peer url from the set to prevent build-up of state.
                        state
                            .lock()
                            .unwrap()
                            .requests
                            .remove(&(op_id, peer_url));
                    }
                }
            });
        }
    }
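    /// Long-running worker: serves peers' fetch requests by looking up the
    /// requested ops in the op store and sending back whichever ops were
    /// found. No response is sent when none of the requested ops are held.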
    async fn incoming_request_task(
        mut incoming_request_rx: Receiver<IncomingRequest>,
        op_store: DynOpStore,
        transport: WeakDynTransport,
        space_id: SpaceId,
    ) {
        while let Some((op_ids, peer)) = incoming_request_rx.recv().await {
            tracing::debug!(?peer, ?op_ids, "incoming request");

            let Some(transport) = transport.upgrade() else {
                tracing::info!(
                    "Transport dropped, stopping incoming request task"
                );
                break;
            };

            // Retrieve the ops to send from the store.
            let ops = match op_store.retrieve_ops(op_ids.clone()).await {
                Err(err) => {
                    tracing::error!("could not read ops from store: {err}");
                    continue;
                }
                Ok(ops) => {
                    ops.into_iter().map(|op| op.op_data).collect::<Vec<_>>()
                }
            };

            if ops.is_empty() {
                tracing::info!(
                    "none of the ops requested from {peer} found in store"
                );
                // Do not send a response when no ops could be retrieved.
                continue;
            }

            let data = serialize_response_message(ops);
            if let Err(err) = transport
                .send_module(
                    peer.clone(),
                    space_id.clone(),
                    MOD_NAME.to_string(),
                    data,
                )
                .await
            {
                tracing::warn!(
                    ?op_ids,
                    ?peer,
                    "could not send ops to requesting peer: {err}"
                );
            }
        }
    }
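    /// Long-running worker: hands received ops to the op store and removes the
    /// successfully processed op ids from the set of pending requests, which
    /// stops them from being re-requested.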
    async fn incoming_response_task(
        mut incoming_response_rx: Receiver<IncomingResponse>,
        op_store: DynOpStore,
        state: Arc<Mutex<State>>,
    ) {
        while let Some(ops) = incoming_response_rx.recv().await {
            let op_count = ops.len();
            tracing::debug!(?op_count, "incoming op response");
            let ops_data = ops.into_iter().map(|op| op.data).collect();
            match op_store.process_incoming_ops(ops_data).await {
                Err(err) => {
                    tracing::error!("could not process incoming ops: {err}");
                    // Ops could not be written to the op store. Their ids remain in the set of ops
                    // to fetch.
                    continue;
                }
                Ok(processed_op_ids) => {
                    tracing::debug!(
                        "processed incoming ops with op ids {processed_op_ids:?}"
                    );
                    // Ops were processed successfully by the op store, which returns their ids.
                    // The op ids are removed from the set of ops to fetch.
                    let mut lock = state.lock().unwrap();
                    lock.requests
                        .retain(|(op_id, _)| !processed_op_ids.contains(op_id));
                }
            }
        }
    }
}
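/// Abort all spawned background tasks when the fetch module is dropped.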
impl Drop for CoreFetch {
    fn drop(&mut self) {
        for t in self.tasks.iter() {
            t.abort();
        }
    }
}