use back_off::BackOffList;
use kitsune2_api::*;
use message_handler::FetchMessageHandler;
use std::collections::HashMap;
use std::{
    collections::HashSet,
    sync::{Arc, Mutex},
    time::Duration,
};
use tokio::{
    sync::mpsc::{channel, Receiver, Sender},
    task::JoinHandle,
};

mod back_off;
mod message_handler;

#[cfg(test)]
mod test;

/// Name of the fetch module, used when registering its message handler with the transport.
pub const MOD_NAME: &str = "Fetch";

mod config {
    /// Configuration of the fetch module.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CoreFetchConfig {
        /// Number of tasks that send outgoing fetch requests in parallel. Default: 2.
        pub parallel_request_count: u8,
        /// Delay in ms before an outgoing request is re-inserted into the queue for a retry. Default: 30,000.
        pub re_insert_outgoing_request_delay_ms: u32,
        /// Duration in ms of the first back off interval for an unresponsive peer. Default: 20,000 (20 s).
        pub first_back_off_interval_ms: u32,
        /// Duration in ms of the last back off interval for an unresponsive peer. Default: 600,000 (10 min).
        pub last_back_off_interval_ms: u32,
        /// Number of back off intervals applied to an unresponsive peer before its pending requests are dropped. Default: 4.
        pub num_back_off_intervals: usize,
    }

    impl Default for CoreFetchConfig {
        fn default() -> Self {
            Self {
                parallel_request_count: 2,
                re_insert_outgoing_request_delay_ms: 30000,
                first_back_off_interval_ms: 1000 * 20,
                last_back_off_interval_ms: 1000 * 60 * 10,
                num_back_off_intervals: 4,
            }
        }
    }

    /// Module-level configuration wrapper for [`CoreFetchConfig`].
    #[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
    #[serde(rename_all = "camelCase")]
    pub struct CoreFetchModConfig {
        /// The fetch module configuration.
        pub core_fetch: CoreFetchConfig,
    }
}
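
// A minimal sketch of how the module config above serializes, following the
// `camelCase` renaming and the defaults in `CoreFetchConfig::default()`. The JSON
// shown is only `CoreFetchModConfig` itself; how it nests into the overall
// Kitsune2 configuration is assumed here.
//
// {
//   "coreFetch": {
//     "parallelRequestCount": 2,
//     "reInsertOutgoingRequestDelayMs": 30000,
//     "firstBackOffIntervalMs": 20000,
//     "lastBackOffIntervalMs": 600000,
//     "numBackOffIntervals": 4
//   }
// }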

pub use config::*;

/// A factory for creating [`CoreFetch`] instances.
#[derive(Debug)]
pub struct CoreFetchFactory {}

impl CoreFetchFactory {
    /// Construct a new factory.
    pub fn create() -> DynFetchFactory {
        Arc::new(Self {})
    }
}

impl FetchFactory for CoreFetchFactory {
    fn default_config(&self, config: &mut Config) -> K2Result<()> {
        config.set_module_config(&CoreFetchModConfig::default())?;
        Ok(())
    }

    fn validate_config(&self, _config: &Config) -> K2Result<()> {
        Ok(())
    }

    fn create(
        &self,
        builder: Arc<Builder>,
        space_id: SpaceId,
        op_store: DynOpStore,
        transport: DynTransport,
    ) -> BoxFut<'static, K2Result<DynFetch>> {
        Box::pin(async move {
            let config: CoreFetchModConfig =
                builder.config.get_module_config()?;
            let out: DynFetch = Arc::new(CoreFetch::new(
                config.core_fetch,
                space_id,
                op_store,
                transport,
            ));
            Ok(out)
        })
    }
}
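
// A minimal sketch of how the factory is used; `builder`, `space_id`, `op_store`
// and `transport` are hypothetical values provided by the surrounding Kitsune2
// setup, not defined in this module.
//
//     let factory = CoreFetchFactory::create();
//     let fetch: DynFetch = factory
//         .create(builder, space_id, op_store, transport)
//         .await?;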

/// An outgoing fetch request: the op id to fetch and the peer to fetch it from.
type OutgoingRequest = (OpId, Url);
/// An incoming fetch request: op ids requested by a peer.
type IncomingRequest = (Vec<OpId>, Url);
/// An incoming response: the ops received from a peer.
type IncomingResponse = Vec<Op>;

#[derive(Debug)]
struct State {
    /// Set of outstanding (op id, peer) fetch requests.
    requests: HashSet<OutgoingRequest>,
    /// Peers that are currently on back off after failed sends.
    back_off_list: BackOffList,
    /// Senders to notify once the request set has drained.
    notify_when_drained_senders: Vec<futures::channel::oneshot::Sender<()>>,
}

impl State {
    fn summary(&self) -> FetchStateSummary {
        FetchStateSummary {
            pending_requests: self.requests.iter().fold(
                HashMap::new(),
                |mut acc, (op_id, peer_url)| {
                    acc.entry(op_id.clone())
                        .or_default()
                        .push(peer_url.clone());
                    acc
                },
            ),
            peers_on_backoff: self
                .back_off_list
                .state
                .iter()
                .map(|(peer_url, backoff)| {
                    (peer_url.clone(), backoff.current_backoff_expiry())
                })
                .collect(),
        }
    }
}

#[derive(Debug)]
struct CoreFetch {
    state: Arc<Mutex<State>>,
    outgoing_request_tx: Sender<OutgoingRequest>,
    tasks: Vec<JoinHandle<()>>,
    op_store: DynOpStore,
    #[cfg(test)]
    message_handler: DynTxModuleHandler,
}

impl CoreFetch {
    fn new(
        config: CoreFetchConfig,
        space_id: SpaceId,
        op_store: DynOpStore,
        transport: DynTransport,
    ) -> Self {
        Self::spawn_tasks(config, space_id, op_store, transport)
    }
}

impl Fetch for CoreFetch {
    fn request_ops(
        &self,
        op_ids: Vec<OpId>,
        source: Url,
    ) -> BoxFut<'_, K2Result<()>> {
        Box::pin(async move {
            // Filter out op ids that are already present in the op store.
            let new_op_ids =
                self.op_store.filter_out_existing_ops(op_ids).await?;

            // Add requests to the set of pending requests.
            {
                let requests = &mut self.state.lock().unwrap().requests;
                requests.extend(
                    new_op_ids
                        .clone()
                        .into_iter()
                        .map(|op_id| (op_id.clone(), source.clone())),
                );
            }

            // Pass requests to the outgoing request tasks.
            for op_id in new_op_ids {
                if let Err(err) =
                    self.outgoing_request_tx.send((op_id, source.clone())).await
                {
                    tracing::warn!(
                        "could not insert fetch request into fetch queue: {err}"
                    );
                }
            }

            Ok(())
        })
    }

    fn notify_on_drained(&self, notify: futures::channel::oneshot::Sender<()>) {
        let mut lock = self.state.lock().expect("poisoned");
        if lock.requests.is_empty() {
            if let Err(err) = notify.send(()) {
                tracing::warn!(?err, "Failed to send notification on drained");
            }
        } else {
            lock.notify_when_drained_senders.push(notify);
        }
    }

    fn get_state_summary(&self) -> BoxFut<'_, K2Result<FetchStateSummary>> {
        Box::pin(async move { Ok(self.state.lock().unwrap().summary()) })
    }
}

impl CoreFetch {
    pub fn spawn_tasks(
        config: CoreFetchConfig,
        space_id: SpaceId,
        op_store: DynOpStore,
        transport: DynTransport,
    ) -> Self {
        // Create a queue of outgoing requests to be sent to peers.
        let (outgoing_request_tx, outgoing_request_rx) =
            channel::<OutgoingRequest>(16_384);
        let outgoing_request_rx =
            Arc::new(tokio::sync::Mutex::new(outgoing_request_rx));

        // Create a queue of incoming requests; requested ops are retrieved from
        // the op store and sent back to the requester.
        let (incoming_request_tx, incoming_request_rx) =
            channel::<IncomingRequest>(16_384);

        // Create a queue of incoming responses; received ops are passed to the
        // op store and the corresponding requests are removed.
        let (incoming_response_tx, incoming_response_rx) =
            channel::<IncomingResponse>(16_384);

        let state = Arc::new(Mutex::new(State {
            requests: HashSet::new(),
            back_off_list: BackOffList::new(
                config.first_back_off_interval_ms,
                config.last_back_off_interval_ms,
                config.num_back_off_intervals,
            ),
            notify_when_drained_senders: vec![],
        }));

        // Spawn the configured number of parallel outgoing request tasks.
        let mut tasks =
            Vec::with_capacity(config.parallel_request_count as usize);
        for _ in 0..config.parallel_request_count {
            let request_task =
                tokio::task::spawn(CoreFetch::outgoing_request_task(
                    state.clone(),
                    outgoing_request_tx.clone(),
                    outgoing_request_rx.clone(),
                    space_id.clone(),
                    Arc::downgrade(&transport),
                    config.re_insert_outgoing_request_delay_ms,
                ));
            tasks.push(request_task);
        }

        // Spawn the task that serves requests from other peers.
        let incoming_request_task =
            tokio::task::spawn(CoreFetch::incoming_request_task(
                incoming_request_rx,
                op_store.clone(),
                Arc::downgrade(&transport),
                space_id.clone(),
            ));
        tasks.push(incoming_request_task);

        // Spawn the task that processes ops received from other peers.
        let incoming_response_task =
            tokio::task::spawn(CoreFetch::incoming_response_task(
                incoming_response_rx,
                op_store.clone(),
                state.clone(),
            ));
        tasks.push(incoming_response_task);

        // Register the fetch message handler with the transport.
        let message_handler = Arc::new(FetchMessageHandler {
            incoming_request_tx,
            incoming_response_tx,
        });
        transport.register_module_handler(
            space_id.clone(),
            MOD_NAME.to_string(),
            message_handler.clone(),
        );

        Self {
            state,
            outgoing_request_tx,
            tasks,
            op_store,
            #[cfg(test)]
            message_handler,
        }
    }

    async fn outgoing_request_task(
        state: Arc<Mutex<State>>,
        outgoing_request_tx: Sender<OutgoingRequest>,
        outgoing_request_rx: Arc<tokio::sync::Mutex<Receiver<OutgoingRequest>>>,
        space_id: SpaceId,
        transport: WeakDynTransport,
        re_insert_outgoing_request_delay: u32,
    ) {
        while let Some((op_id, peer_url)) =
            outgoing_request_rx.lock().await.recv().await
        {
            let Some(transport) = transport.upgrade() else {
                tracing::info!(
                    "Transport dropped, stopping outgoing request task"
                );
                break;
            };

            let is_peer_on_back_off = {
                let mut lock = state.lock().unwrap();

                // Do nothing if the request is no longer in the set, e.g.
                // because the op has already been received from another peer.
                if !lock.requests.contains(&(op_id.clone(), peer_url.clone())) {
                    continue;
                }

                lock.back_off_list.is_peer_on_back_off(&peer_url)
            };

            // Send the request, unless the peer is currently on back off.
            if !is_peer_on_back_off {
                tracing::debug!(
                    ?peer_url,
                    ?space_id,
                    ?op_id,
                    "sending fetch request"
                );

                let data = serialize_request_message(vec![op_id.clone()]);
                match transport
                    .send_module(
                        peer_url.clone(),
                        space_id.clone(),
                        MOD_NAME.to_string(),
                        data,
                    )
                    .await
                {
                    Ok(()) => {
                        // Sending succeeded, so remove the peer from the back off list.
                        state
                            .lock()
                            .unwrap()
                            .back_off_list
                            .remove_peer(&peer_url);
                    }
                    Err(err) => {
                        tracing::warn!(
                            ?op_id,
                            ?peer_url,
                            "could not send fetch request: {err}. Putting peer on back off list."
                        );
                        let mut lock = state.lock().unwrap();
                        lock.back_off_list.back_off_peer(&peer_url);

                        // Give up on all requests to this peer once the last
                        // back off interval has expired.
                        if lock
                            .back_off_list
                            .has_last_back_off_expired(&peer_url)
                        {
                            lock.requests.retain(|(_, a)| *a != peer_url);
                        }
                    }
                }
            }

            // Notify any waiting senders if the request set has drained.
            {
                let mut lock = state.lock().expect("poisoned");
                if lock.requests.is_empty() {
                    for notify in lock.notify_when_drained_senders.drain(..) {
                        if notify.send(()).is_err() {
                            tracing::warn!(
                                "Failed to send notification on drained"
                            );
                        }
                    }
                }
            }

            // Re-insert the request into the outgoing queue after a delay, in
            // case no response has been received by then.
            let outgoing_request_tx = outgoing_request_tx.clone();

            tokio::task::spawn({
                let state = state.clone();
                async move {
                    tokio::time::sleep(Duration::from_millis(
                        re_insert_outgoing_request_delay as u64,
                    ))
                    .await;
                    if let Err(err) = outgoing_request_tx
                        .try_send((op_id.clone(), peer_url.clone()))
                    {
                        tracing::warn!(
                            "could not re-insert fetch request for op {op_id} to peer {peer_url} into queue: {err}"
                        );
                        // Remove the request from the set, as it can no longer
                        // be retried.
                        state
                            .lock()
                            .unwrap()
                            .requests
                            .remove(&(op_id, peer_url));
                    }
                }
            });
        }
    }

    async fn incoming_request_task(
        mut response_rx: Receiver<IncomingRequest>,
        op_store: DynOpStore,
        transport: WeakDynTransport,
        space_id: SpaceId,
    ) {
        while let Some((op_ids, peer)) = response_rx.recv().await {
            tracing::debug!(?peer, ?op_ids, "incoming request");

            let Some(transport) = transport.upgrade() else {
                tracing::info!(
                    "Transport dropped, stopping incoming request task"
                );
                break;
            };

            // Retrieve the requested ops from the op store.
            let ops = match op_store.retrieve_ops(op_ids.clone()).await {
                Err(err) => {
                    tracing::error!("could not read ops from store: {err}");
                    continue;
                }
                Ok(ops) => {
                    ops.into_iter().map(|op| op.op_data).collect::<Vec<_>>()
                }
            };

            if ops.is_empty() {
                tracing::info!(
                    "none of the ops requested from {peer} found in store"
                );
                continue;
            }

            // Send the ops back to the requesting peer.
            let data = serialize_response_message(ops);
            if let Err(err) = transport
                .send_module(
                    peer.clone(),
                    space_id.clone(),
                    MOD_NAME.to_string(),
                    data,
                )
                .await
            {
                tracing::warn!(
                    ?op_ids,
                    ?peer,
                    "could not send ops to requesting peer: {err}"
                );
            }
        }
    }

    async fn incoming_response_task(
        mut incoming_response_rx: Receiver<IncomingResponse>,
        op_store: DynOpStore,
        state: Arc<Mutex<State>>,
    ) {
        while let Some(ops) = incoming_response_rx.recv().await {
            let op_count = ops.len();
            tracing::debug!(?op_count, "incoming op response");
            let ops_data = ops.clone().into_iter().map(|op| op.data).collect();
            match op_store.process_incoming_ops(ops_data).await {
                Err(err) => {
                    tracing::error!("could not process incoming ops: {err}");
                    continue;
                }
                Ok(processed_op_ids) => {
                    tracing::debug!(
                        "processed incoming ops with op ids {processed_op_ids:?}"
                    );
                    // Remove the requests for all successfully processed ops
                    // from the set of pending requests.
                    let mut lock = state.lock().unwrap();
                    lock.requests
                        .retain(|(op_id, _)| !processed_op_ids.contains(op_id));
                }
            }
        }
    }
}

impl Drop for CoreFetch {
    fn drop(&mut self) {
        for t in self.tasks.iter() {
            t.abort();
        }
    }
}