libp2p_bitswap/
query.rs

1use crate::stats::{REQUESTS_TOTAL, REQUEST_DURATION_SECONDS};
2use fnv::{FnvHashMap, FnvHashSet};
3use libipld::Cid;
4use libp2p::PeerId;
5use prometheus::HistogramTimer;
6use std::collections::VecDeque;
7
/// Identifier of a query tracked by the [`QueryManager`].
///
/// Wraps the value of the manager's `u64` id counter at the time the query was created.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct QueryId(u64);

impl std::fmt::Display for QueryId {
    /// Formats the id exactly like the wrapped integer, forwarding the formatter as is.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let QueryId(raw) = self;
        std::fmt::Display::fmt(raw, f)
    }
}
17
18/// Request.
19#[derive(Debug, Eq, PartialEq)]
20pub enum Request {
21    /// Have query.
22    Have(PeerId, Cid),
23    /// Block query.
24    Block(PeerId, Cid),
25    /// Missing blocks query.
26    MissingBlocks(Cid),
27}
28
29impl std::fmt::Display for Request {
30    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
31        match self {
32            Self::Have(_, _) => write!(f, "have"),
33            Self::Block(_, _) => write!(f, "block"),
34            Self::MissingBlocks(_) => write!(f, "missing-blocks"),
35        }
36    }
37}
38
39/// Response.
40#[derive(Debug)]
41pub enum Response {
42    /// Have query.
43    Have(PeerId, bool),
44    /// Block query.
45    Block(PeerId, bool),
46    /// Missing blocks query.
47    MissingBlocks(Vec<Cid>),
48}
49
50impl std::fmt::Display for Response {
51    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
52        match self {
53            Self::Have(_, have) => write!(f, "have {}", have),
54            Self::Block(_, block) => write!(f, "block {}", block),
55            Self::MissingBlocks(missing) => write!(f, "missing-blocks {}", missing.len()),
56        }
57    }
58}
59
60/// Event emitted by a query.
61#[derive(Debug)]
62pub enum QueryEvent {
63    /// A subquery to run.
64    Request(QueryId, Request),
65    /// A progress event.
66    Progress(QueryId, usize),
67    /// Complete event.
68    Complete(QueryId, Result<(), Cid>),
69}
70
71#[derive(Debug)]
72pub struct Header {
73    /// Query id.
74    pub id: QueryId,
75    /// Root query id.
76    pub root: QueryId,
77    /// Parent.
78    pub parent: Option<QueryId>,
79    /// Cid.
80    pub cid: Cid,
81    /// Timer.
82    pub timer: HistogramTimer,
83    /// Type.
84    pub label: &'static str,
85}
86
87impl Drop for Header {
88    fn drop(&mut self) {
89        REQUESTS_TOTAL.with_label_values(&[self.label]).inc();
90    }
91}
92
93/// Query.
94#[derive(Debug)]
95struct Query {
96    /// Header.
97    hdr: Header,
98    /// State.
99    state: State,
100}
101
102#[derive(Debug)]
103enum State {
104    None,
105    Get(GetState),
106    Sync(SyncState),
107}
108
109#[derive(Debug, Default)]
110struct GetState {
111    have: FnvHashSet<QueryId>,
112    block: Option<QueryId>,
113    providers: Vec<PeerId>,
114}
115
116#[derive(Debug, Default)]
117struct SyncState {
118    missing: FnvHashSet<QueryId>,
119    children: FnvHashSet<QueryId>,
120    providers: Vec<PeerId>,
121}
122
/// Outcome of one step of a query state machine.
enum Transition<St, Out> {
    /// Keep the query registered with the given state.
    Next(St),
    /// Finish the query with the given result.
    Complete(Out),
}
127
128#[derive(Default)]
129pub struct QueryManager {
130    id_counter: u64,
131    queries: FnvHashMap<QueryId, Query>,
132    events: VecDeque<QueryEvent>,
133}
134
135impl QueryManager {
136    /// Start a new subquery.
137    fn start_query(
138        &mut self,
139        root: QueryId,
140        parent: Option<QueryId>,
141        cid: Cid,
142        req: Request,
143        label: &'static str,
144    ) -> QueryId {
145        let timer = REQUEST_DURATION_SECONDS
146            .with_label_values(&[label])
147            .start_timer();
148        let id = QueryId(self.id_counter);
149        self.id_counter += 1;
150        let query = Query {
151            hdr: Header {
152                id,
153                root,
154                parent,
155                cid,
156                timer,
157                label,
158            },
159            state: State::None,
160        };
161        self.queries.insert(id, query);
162        tracing::trace!("{} {} {}", root, id, req);
163        self.events.push_back(QueryEvent::Request(id, req));
164        id
165    }
166
167    /// Starts a new have query to ask a peer if it has a block.
168    fn have(&mut self, root: QueryId, parent: QueryId, peer_id: PeerId, cid: Cid) -> QueryId {
169        self.start_query(root, Some(parent), cid, Request::Have(peer_id, cid), "have")
170    }
171
172    /// Starts a new block query to request a block from a peer.
173    fn block(&mut self, root: QueryId, parent: QueryId, peer_id: PeerId, cid: Cid) -> QueryId {
174        self.start_query(
175            root,
176            Some(parent),
177            cid,
178            Request::Block(peer_id, cid),
179            "block",
180        )
181    }
182
183    /// Starts a query to determine the missing blocks of a dag.
184    fn missing_blocks(&mut self, parent: QueryId, cid: Cid) -> QueryId {
185        self.start_query(
186            parent,
187            Some(parent),
188            cid,
189            Request::MissingBlocks(cid),
190            "missing-blocks",
191        )
192    }
193
194    /// Starts a query to locate and retrieve a block. Panics if no providers are supplied.
195    pub fn get(
196        &mut self,
197        parent: Option<QueryId>,
198        cid: Cid,
199        providers: impl Iterator<Item = PeerId>,
200    ) -> QueryId {
201        let timer = REQUEST_DURATION_SECONDS
202            .with_label_values(&["get"])
203            .start_timer();
204        let id = QueryId(self.id_counter);
205        self.id_counter += 1;
206        let root = parent.unwrap_or(id);
207        tracing::trace!("{} {} get", root, id);
208        let mut state = GetState::default();
209        for peer in providers {
210            if state.block.is_none() {
211                state.block = Some(self.block(root, id, peer, cid));
212            } else {
213                state.have.insert(self.have(root, id, peer, cid));
214            }
215        }
216        assert!(state.block.is_some());
217        let query = Query {
218            hdr: Header {
219                id,
220                root,
221                parent,
222                cid,
223                timer,
224                label: "get",
225            },
226            state: State::Get(state),
227        };
228        self.queries.insert(id, query);
229        id
230    }
231
232    /// Starts a query to recursively retrieve a dag. The missing blocks are the first
233    /// blocks that need to be retrieved.
234    pub fn sync(
235        &mut self,
236        cid: Cid,
237        providers: Vec<PeerId>,
238        missing: impl Iterator<Item = Cid>,
239    ) -> QueryId {
240        let timer = REQUEST_DURATION_SECONDS
241            .with_label_values(&["sync"])
242            .start_timer();
243        let id = QueryId(self.id_counter);
244        self.id_counter += 1;
245        tracing::trace!("{} {} sync", id, id);
246        let mut state = SyncState::default();
247        for cid in missing {
248            state
249                .missing
250                .insert(self.get(Some(id), cid, providers.iter().copied()));
251        }
252        if state.missing.is_empty() {
253            state.children.insert(self.missing_blocks(id, cid));
254        }
255        state.providers = providers;
256        let query = Query {
257            hdr: Header {
258                id,
259                root: id,
260                parent: None,
261                cid,
262                timer,
263                label: "sync",
264            },
265            state: State::Sync(state),
266        };
267        self.queries.insert(id, query);
268        id
269    }
270
271    /// Cancels an in progress query.
272    pub fn cancel(&mut self, root: QueryId) -> bool {
273        let query = if let Some(query) = self.queries.remove(&root) {
274            query
275        } else {
276            return false;
277        };
278        let queries = &self.queries;
279        self.events.retain(|event| {
280            let (id, req) = match event {
281                QueryEvent::Request(id, req) => (id, req),
282                QueryEvent::Progress(id, _) => return *id != root,
283                QueryEvent::Complete(_, _) => return true,
284            };
285            if queries.get(id).map(|q| q.hdr.root) != Some(root) {
286                return true;
287            }
288            tracing::trace!("{} {} {} cancel", root, id, req);
289            false
290        });
291        match query.state {
292            State::Get(_) => {
293                tracing::trace!("{} {} get cancel", root, root);
294                true
295            }
296            State::Sync(state) => {
297                for id in state.missing {
298                    tracing::trace!("{} {} get cancel", root, id);
299                    self.queries.remove(&id);
300                }
301                tracing::trace!("{} {} sync cancel", root, root);
302                true
303            }
304            State::None => {
305                self.queries.insert(root, query);
306                false
307            }
308        }
309    }
310
311    /// Advances a get query state machine using a transition function.
312    fn get_query<F>(&mut self, id: QueryId, f: F)
313    where
314        F: FnOnce(&mut Self, &Header, GetState) -> Transition<GetState, Result<(), Cid>>,
315    {
316        if let Some(mut parent) = self.queries.remove(&id) {
317            let state = if let State::Get(state) = parent.state {
318                state
319            } else {
320                return;
321            };
322            match f(self, &parent.hdr, state) {
323                Transition::Next(state) => {
324                    parent.state = State::Get(state);
325                    self.queries.insert(id, parent);
326                }
327                Transition::Complete(res) => {
328                    match res {
329                        Ok(()) => tracing::trace!("{} {} get ok", parent.hdr.root, parent.hdr.id),
330                        Err(_) => tracing::trace!("{} {} get err", parent.hdr.root, parent.hdr.id),
331                    }
332                    self.recv_get(parent.hdr, res);
333                }
334            }
335        }
336    }
337
338    /// Advances a sync query state machine using a transition function.
339    fn sync_query<F>(&mut self, id: QueryId, f: F)
340    where
341        F: FnOnce(&mut Self, &Header, SyncState) -> Transition<SyncState, Result<(), Cid>>,
342    {
343        if let Some(mut parent) = self.queries.remove(&id) {
344            let state = if let State::Sync(state) = parent.state {
345                state
346            } else {
347                return;
348            };
349            match f(self, &parent.hdr, state) {
350                Transition::Next(state) => {
351                    parent.state = State::Sync(state);
352                    self.queries.insert(id, parent);
353                }
354                Transition::Complete(res) => {
355                    if res.is_ok() {
356                        tracing::trace!("{} {} sync ok", parent.hdr.root, parent.hdr.id);
357                    } else {
358                        tracing::trace!("{} {} sync err", parent.hdr.root, parent.hdr.id);
359                    }
360                    self.recv_sync(parent.hdr, res);
361                }
362            }
363        }
364    }
365
366    /// Processes the response of a have query.
367    ///
368    /// Marks the in progress query as complete and updates the set of peers that have
369    /// a block. If there isn't an in progress block query a new block query will be
370    /// started. If no block query can be started either a provider query is started or
371    /// the get query is marked as complete with a block-not-found error.
372    fn recv_have(&mut self, query: Header, peer_id: PeerId, have: bool) {
373        self.get_query(query.parent.unwrap(), |mgr, parent, mut state| {
374            state.have.remove(&query.id);
375            if state.block == Some(query.id) {
376                state.block = None;
377            }
378            if have {
379                state.providers.push(peer_id);
380            }
381            if state.block.is_none() && !state.providers.is_empty() {
382                state.block = Some(mgr.block(
383                    parent.root,
384                    parent.id,
385                    state.providers.pop().unwrap(),
386                    query.cid,
387                ));
388            }
389            if state.have.is_empty() && state.block.is_none() && state.providers.is_empty() {
390                if state.providers.is_empty() {
391                    return Transition::Complete(Err(query.cid));
392                } else {
393                    return Transition::Complete(Ok(()));
394                }
395            }
396            Transition::Next(state)
397        });
398    }
399
400    /// Processes the response of a block query.
401    ///
402    /// Either completes the get query or processes it like a have query response.
403    fn recv_block(&mut self, query: Header, peer_id: PeerId, block: bool) {
404        if block {
405            self.get_query(query.parent.unwrap(), |_mgr, _parent, mut state| {
406                state.providers.push(peer_id);
407                Transition::Complete(Ok(()))
408            });
409        } else {
410            self.recv_have(query, peer_id, block);
411        }
412    }
413
414    /// Processes the response of a missing blocks query.
415    ///
416    /// Starts a get query for each missing block. If there are no in progress queries
417    /// the sync query is marked as complete.
418    fn recv_missing_blocks(&mut self, query: Header, missing: Vec<Cid>) {
419        let mut num_missing = 0;
420        let num_missing_ref = &mut num_missing;
421        self.sync_query(query.parent.unwrap(), |mgr, parent, mut state| {
422            state.children.remove(&query.id);
423            for cid in missing {
424                state.missing.insert(mgr.get(
425                    Some(parent.root),
426                    cid,
427                    state.providers.iter().copied(),
428                ));
429            }
430            *num_missing_ref = state.missing.len();
431            if state.missing.is_empty() && state.children.is_empty() {
432                Transition::Complete(Ok(()))
433            } else {
434                Transition::Next(state)
435            }
436        });
437        if num_missing != 0 {
438            self.events
439                .push_back(QueryEvent::Progress(query.root, num_missing));
440        }
441    }
442
443    /// Processes the response of a get query.
444    ///
445    /// If it is part of a sync query a new missing blocks query is started. Otherwise
446    /// the get query emits a `complete` event.
447    fn recv_get(&mut self, query: Header, res: Result<(), Cid>) {
448        if let Some(id) = query.parent {
449            self.sync_query(id, |mgr, parent, mut state| {
450                state.missing.remove(&query.id);
451                if res.is_err() {
452                    Transition::Complete(res)
453                } else {
454                    state
455                        .children
456                        .insert(mgr.missing_blocks(parent.root, query.cid));
457                    Transition::Next(state)
458                }
459            });
460        } else {
461            self.events.push_back(QueryEvent::Complete(query.id, res));
462        }
463    }
464
465    /// Processes the response of a sync query.
466    ///
467    /// The sync query emits a `complete` event.
468    fn recv_sync(&mut self, query: Header, res: Result<(), Cid>) {
469        self.events.push_back(QueryEvent::Complete(query.id, res));
470    }
471
472    /// Dispatches the response to a query handler.
473    pub fn inject_response(&mut self, id: QueryId, res: Response) {
474        let query = if let Some(query) = self.queries.remove(&id) {
475            query.hdr
476        } else {
477            return;
478        };
479        tracing::trace!("{} {} {}", query.root, query.id, res);
480        match res {
481            Response::Have(peer, have) => {
482                self.recv_have(query, peer, have);
483            }
484            Response::Block(peer, block) => {
485                self.recv_block(query, peer, block);
486            }
487            Response::MissingBlocks(cids) => {
488                self.recv_missing_blocks(query, cids);
489            }
490        }
491    }
492
493    /// Returns the header of a query.
494    pub fn query_info(&self, id: QueryId) -> Option<&Header> {
495        self.queries.get(&id).map(|q| &q.hdr)
496    }
497
498    /// Retrieves the next query event.
499    pub fn next(&mut self) -> Option<QueryEvent> {
500        self.events.pop_front()
501    }
502}
503
504#[cfg(test)]
505mod tests {
506    use super::*;
507
508    fn tracing_try_init() {
509        tracing_subscriber::fmt()
510            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
511            .try_init()
512            .ok();
513    }
514
515    fn gen_peers(n: usize) -> Vec<PeerId> {
516        let mut peers = Vec::with_capacity(n);
517        for _ in 0..n {
518            peers.push(PeerId::random());
519        }
520        peers
521    }
522
523    fn assert_request(event: Option<QueryEvent>, req: Request) -> QueryId {
524        if let Some(QueryEvent::Request(id, req2)) = event {
525            assert_eq!(req2, req);
526            id
527        } else {
528            panic!("{:?} is not a request", event);
529        }
530    }
531
532    fn assert_complete(event: Option<QueryEvent>, id: QueryId, res: Result<(), Cid>) {
533        if let Some(QueryEvent::Complete(id2, res2)) = event {
534            assert_eq!(id, id2);
535            assert_eq!(res, res2);
536        } else {
537            panic!("{:?} is not a complete event", event);
538        }
539    }
540
541    #[test]
542    fn test_get_query_block_not_found() {
543        let mut mgr = QueryManager::default();
544        let initial_set = gen_peers(3);
545        let cid = Cid::default();
546
547        let id = mgr.get(None, cid, initial_set.iter().copied());
548
549        let id1 = assert_request(mgr.next(), Request::Block(initial_set[0], cid));
550        let id2 = assert_request(mgr.next(), Request::Have(initial_set[1], cid));
551        let id3 = assert_request(mgr.next(), Request::Have(initial_set[2], cid));
552
553        mgr.inject_response(id1, Response::Have(initial_set[0], false));
554        mgr.inject_response(id2, Response::Have(initial_set[1], false));
555        mgr.inject_response(id3, Response::Have(initial_set[2], false));
556
557        assert_complete(mgr.next(), id, Err(cid));
558    }
559
560    #[test]
561    fn test_cid_query_block_found() {
562        let mut mgr = QueryManager::default();
563        let initial_set = gen_peers(3);
564        let cid = Cid::default();
565
566        let id = mgr.get(None, cid, initial_set.iter().copied());
567
568        let id1 = assert_request(mgr.next(), Request::Block(initial_set[0], cid));
569        let id2 = assert_request(mgr.next(), Request::Have(initial_set[1], cid));
570        let id3 = assert_request(mgr.next(), Request::Have(initial_set[2], cid));
571
572        mgr.inject_response(id1, Response::Block(initial_set[0], true));
573        mgr.inject_response(id2, Response::Have(initial_set[1], false));
574        mgr.inject_response(id3, Response::Have(initial_set[2], false));
575
576        assert_complete(mgr.next(), id, Ok(()));
577    }
578
579    #[test]
580    fn test_get_query_gets_from_spare_if_block_request_fails() {
581        let mut mgr = QueryManager::default();
582        let initial_set = gen_peers(3);
583        let cid = Cid::default();
584
585        let id = mgr.get(None, cid, initial_set.iter().copied());
586
587        let id1 = assert_request(mgr.next(), Request::Block(initial_set[0], cid));
588        let id2 = assert_request(mgr.next(), Request::Have(initial_set[1], cid));
589        let id3 = assert_request(mgr.next(), Request::Have(initial_set[2], cid));
590
591        mgr.inject_response(id1, Response::Block(initial_set[0], false));
592        mgr.inject_response(id2, Response::Have(initial_set[1], true));
593        mgr.inject_response(id3, Response::Have(initial_set[2], false));
594
595        let id1 = assert_request(mgr.next(), Request::Block(initial_set[1], cid));
596        mgr.inject_response(id1, Response::Block(initial_set[1], true));
597
598        assert_complete(mgr.next(), id, Ok(()));
599    }
600
601    #[test]
602    fn test_get_query_gets_from_spare_if_block_request_fails_after_have_is_received() {
603        let mut mgr = QueryManager::default();
604        let initial_set = gen_peers(3);
605        let cid = Cid::default();
606
607        let id = mgr.get(None, cid, initial_set.iter().copied());
608
609        let id1 = assert_request(mgr.next(), Request::Block(initial_set[0], cid));
610        let id2 = assert_request(mgr.next(), Request::Have(initial_set[1], cid));
611        let id3 = assert_request(mgr.next(), Request::Have(initial_set[2], cid));
612
613        mgr.inject_response(id1, Response::Block(initial_set[0], false));
614        mgr.inject_response(id2, Response::Have(initial_set[1], true));
615        mgr.inject_response(id3, Response::Have(initial_set[2], true));
616
617        let id1 = assert_request(mgr.next(), Request::Block(initial_set[1], cid));
618        mgr.inject_response(id1, Response::Block(initial_set[1], false));
619
620        let id1 = assert_request(mgr.next(), Request::Block(initial_set[2], cid));
621        mgr.inject_response(id1, Response::Block(initial_set[2], true));
622
623        assert_complete(mgr.next(), id, Ok(()));
624    }
625
626    #[test]
627    fn test_sync_query() {
628        tracing_try_init();
629        let mut mgr = QueryManager::default();
630        let providers = gen_peers(3);
631        let cid = Cid::default();
632
633        let id = mgr.sync(cid, providers.clone(), std::iter::once(cid));
634
635        let id1 = assert_request(mgr.next(), Request::Block(providers[0], cid));
636        let id2 = assert_request(mgr.next(), Request::Have(providers[1], cid));
637        let id3 = assert_request(mgr.next(), Request::Have(providers[2], cid));
638
639        mgr.inject_response(id1, Response::Block(providers[0], true));
640        mgr.inject_response(id2, Response::Have(providers[1], false));
641        mgr.inject_response(id3, Response::Have(providers[2], false));
642
643        let id1 = assert_request(mgr.next(), Request::MissingBlocks(cid));
644        mgr.inject_response(id1, Response::MissingBlocks(vec![]));
645
646        assert_complete(mgr.next(), id, Ok(()));
647    }
648
649    #[test]
650    fn test_sync_query_empty() {
651        tracing_try_init();
652        let mut mgr = QueryManager::default();
653        let cid = Cid::default();
654        let id = mgr.sync(cid, vec![], std::iter::empty());
655        let id1 = assert_request(mgr.next(), Request::MissingBlocks(cid));
656        mgr.inject_response(id1, Response::MissingBlocks(vec![]));
657        assert_complete(mgr.next(), id, Ok(()));
658    }
659}