libp2p_bitswap_next/
query.rs

1use crate::stats::{REQUESTS_TOTAL, REQUEST_DURATION_SECONDS};
2use fnv::{FnvHashMap, FnvHashSet};
3use libipld::Cid;
4use libp2p::PeerId;
5use prometheus::HistogramTimer;
6use std::collections::{HashSet, VecDeque};
7
/// Opaque identifier of a query.
///
/// Wraps a `u64`; equality, ordering and hashing are those of the wrapped integer.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct QueryId(u64);

impl std::fmt::Display for QueryId {
    /// Formats the id exactly like the wrapped integer, honouring any width, fill
    /// or alignment flags set on the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Fully qualified on purpose: no `fmt` trait is imported in this file, so
        // the method-call form `self.0.fmt(f)` does not resolve.
        std::fmt::Display::fmt(&self.0, f)
    }
}
17
18/// Request.
19#[derive(Debug, Eq, PartialEq)]
20pub enum Request {
21    /// Have query.
22    Have(PeerId, Cid),
23    /// Block query.
24    Block(PeerId, Cid),
25    /// Missing blocks query.
26    MissingBlocks(Cid),
27}
28
29impl std::fmt::Display for Request {
30    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
31        match self {
32            Self::Have(_, _) => write!(f, "have"),
33            Self::Block(_, _) => write!(f, "block"),
34            Self::MissingBlocks(_) => write!(f, "missing-blocks"),
35        }
36    }
37}
38
39/// Response.
40#[derive(Debug)]
41pub enum Response {
42    /// Have query.
43    Have(PeerId, bool),
44    /// Block query.
45    Block(PeerId, bool),
46    /// Missing blocks query.
47    MissingBlocks(Vec<Cid>),
48}
49
50impl std::fmt::Display for Response {
51    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
52        match self {
53            Self::Have(_, have) => write!(f, "have {}", have),
54            Self::Block(_, block) => write!(f, "block {}", block),
55            Self::MissingBlocks(missing) => write!(f, "missing-blocks {}", missing.len()),
56        }
57    }
58}
59
60/// Event emitted by a query.
61#[derive(Debug)]
62pub enum QueryEvent {
63    /// A subquery to run.
64    Request(QueryId, Request),
65    /// A progress event.
66    Progress(QueryId, usize),
67    /// Complete event.
68    Complete(QueryId, Result<(), Cid>),
69}
70
71#[derive(Debug)]
72pub struct Header {
73    /// Query id.
74    pub id: QueryId,
75    /// Root query id.
76    pub root: QueryId,
77    /// Parent.
78    pub parent: Option<QueryId>,
79    /// Cid.
80    pub cid: Cid,
81    /// Timer.
82    pub timer: HistogramTimer,
83    /// Type.
84    pub label: &'static str,
85}
86
87impl Drop for Header {
88    fn drop(&mut self) {
89        REQUESTS_TOTAL.with_label_values(&[self.label]).inc();
90    }
91}
92
93/// Query.
94#[derive(Debug)]
95struct Query {
96    /// Header.
97    hdr: Header,
98    /// State.
99    state: State,
100}
101
102#[derive(Debug)]
103enum State {
104    None,
105    Get(GetState),
106    Sync(SyncState),
107}
108
109#[derive(Debug, Default)]
110struct GetState {
111    have: FnvHashSet<QueryId>,
112    block: Option<QueryId>,
113    providers: Vec<PeerId>,
114}
115
116#[derive(Debug, Default)]
117struct SyncState {
118    missing: FnvHashSet<QueryId>,
119    children: FnvHashSet<QueryId>,
120    providers: Vec<PeerId>,
121}
122
/// Outcome of advancing a query state machine by one step.
enum Transition<S, C> {
    /// The query keeps running with the given state.
    Next(S),
    /// The query finished with the given result.
    Complete(C),
}
127
128#[derive(Default)]
129pub struct QueryManager {
130    id_counter: u64,
131    peers: HashSet<PeerId>,
132    queries: FnvHashMap<QueryId, Query>,
133    events: VecDeque<QueryEvent>,
134}
135
136impl QueryManager {
137    /// Start a new subquery.
138    fn start_query(
139        &mut self,
140        root: QueryId,
141        parent: Option<QueryId>,
142        cid: Cid,
143        req: Request,
144        label: &'static str,
145    ) -> QueryId {
146        let timer = REQUEST_DURATION_SECONDS
147            .with_label_values(&[label])
148            .start_timer();
149        let id = QueryId(self.id_counter);
150        self.id_counter += 1;
151        let query = Query {
152            hdr: Header {
153                id,
154                root,
155                parent,
156                cid,
157                timer,
158                label,
159            },
160            state: State::None,
161        };
162        self.queries.insert(id, query);
163        tracing::trace!("{} {} {}", root, id, req);
164        self.events.push_back(QueryEvent::Request(id, req));
165        id
166    }
167
168    /// Starts a new have query to ask a peer if it has a block.
169    fn have(&mut self, root: QueryId, parent: QueryId, peer_id: PeerId, cid: Cid) -> QueryId {
170        self.start_query(root, Some(parent), cid, Request::Have(peer_id, cid), "have")
171    }
172
173    /// Starts a new block query to request a block from a peer.
174    fn block(&mut self, root: QueryId, parent: QueryId, peer_id: PeerId, cid: Cid) -> QueryId {
175        self.start_query(
176            root,
177            Some(parent),
178            cid,
179            Request::Block(peer_id, cid),
180            "block",
181        )
182    }
183
184    /// Starts a query to determine the missing blocks of a dag.
185    fn missing_blocks(&mut self, parent: QueryId, cid: Cid) -> QueryId {
186        self.start_query(
187            parent,
188            Some(parent),
189            cid,
190            Request::MissingBlocks(cid),
191            "missing-blocks",
192        )
193    }
194
195    pub fn add_peer(&mut self, peer_id: &PeerId) {
196        self.peers.insert(*peer_id);
197    }
198
199    pub fn remove_peer(&mut self, peer_id: &PeerId) {
200        self.peers.remove(peer_id);
201    }
202
203    /// Starts a query to locate and retrieve a block. Panics if no providers are supplied.
204    pub fn get(
205        &mut self,
206        parent: Option<QueryId>,
207        cid: Cid,
208        providers: impl Iterator<Item = PeerId>,
209    ) -> QueryId {
210        let timer = REQUEST_DURATION_SECONDS
211            .with_label_values(&["get"])
212            .start_timer();
213        let id = QueryId(self.id_counter);
214        self.id_counter += 1;
215        let root = parent.unwrap_or(id);
216        tracing::trace!("{} {} get", root, id);
217        let mut state = GetState::default();
218
219        for peer in providers {
220            if state.block.is_none() {
221                state.block = Some(self.block(root, id, peer, cid));
222            } else {
223                state.have.insert(self.have(root, id, peer, cid));
224            }
225        }
226
227        if state.block.is_none() && !self.peers.is_empty() {
228            let peers = self.peers.clone();
229            for peer in peers {
230                if state.block.is_none() {
231                    state.block = Some(self.block(root, id, peer, cid));
232                } else {
233                    state.have.insert(self.have(root, id, peer, cid));
234                }
235            }
236        }
237
238        assert!(state.block.is_some());
239        let query = Query {
240            hdr: Header {
241                id,
242                root,
243                parent,
244                cid,
245                timer,
246                label: "get",
247            },
248            state: State::Get(state),
249        };
250        self.queries.insert(id, query);
251        id
252    }
253
254    /// Starts a query to recursively retrieve a dag. The missing blocks are the first
255    /// blocks that need to be retrieved.
256    pub fn sync(
257        &mut self,
258        cid: Cid,
259        providers: Vec<PeerId>,
260        missing: impl Iterator<Item = Cid>,
261    ) -> QueryId {
262        let timer = REQUEST_DURATION_SECONDS
263            .with_label_values(&["sync"])
264            .start_timer();
265        let id = QueryId(self.id_counter);
266        self.id_counter += 1;
267        tracing::trace!("{} {} sync", id, id);
268        let mut state = SyncState::default();
269        for cid in missing {
270            state
271                .missing
272                .insert(self.get(Some(id), cid, providers.iter().copied()));
273        }
274        if state.missing.is_empty() {
275            state.children.insert(self.missing_blocks(id, cid));
276        }
277        state.providers = providers;
278        let query = Query {
279            hdr: Header {
280                id,
281                root: id,
282                parent: None,
283                cid,
284                timer,
285                label: "sync",
286            },
287            state: State::Sync(state),
288        };
289        self.queries.insert(id, query);
290        id
291    }
292
293    /// Cancels an in progress query.
294    pub fn cancel(&mut self, root: QueryId) -> bool {
295        let query = if let Some(query) = self.queries.remove(&root) {
296            query
297        } else {
298            return false;
299        };
300        let queries = &self.queries;
301        self.events.retain(|event| {
302            let (id, req) = match event {
303                QueryEvent::Request(id, req) => (id, req),
304                QueryEvent::Progress(id, _) => return *id != root,
305                QueryEvent::Complete(_, _) => return true,
306            };
307            if queries.get(id).map(|q| q.hdr.root) != Some(root) {
308                return true;
309            }
310            tracing::trace!("{} {} {} cancel", root, id, req);
311            false
312        });
313        match query.state {
314            State::Get(_) => {
315                tracing::trace!("{} {} get cancel", root, root);
316                true
317            }
318            State::Sync(state) => {
319                for id in state.missing {
320                    tracing::trace!("{} {} get cancel", root, id);
321                    self.queries.remove(&id);
322                }
323                tracing::trace!("{} {} sync cancel", root, root);
324                true
325            }
326            State::None => {
327                self.queries.insert(root, query);
328                false
329            }
330        }
331    }
332
333    /// Advances a get query state machine using a transition function.
334    fn get_query<F>(&mut self, id: QueryId, f: F)
335    where
336        F: FnOnce(&mut Self, &Header, GetState) -> Transition<GetState, Result<(), Cid>>,
337    {
338        if let Some(mut parent) = self.queries.remove(&id) {
339            let state = if let State::Get(state) = parent.state {
340                state
341            } else {
342                return;
343            };
344            match f(self, &parent.hdr, state) {
345                Transition::Next(state) => {
346                    parent.state = State::Get(state);
347                    self.queries.insert(id, parent);
348                }
349                Transition::Complete(res) => {
350                    match res {
351                        Ok(()) => tracing::trace!("{} {} get ok", parent.hdr.root, parent.hdr.id),
352                        Err(_) => tracing::trace!("{} {} get err", parent.hdr.root, parent.hdr.id),
353                    }
354                    self.recv_get(parent.hdr, res);
355                }
356            }
357        }
358    }
359
360    /// Advances a sync query state machine using a transition function.
361    fn sync_query<F>(&mut self, id: QueryId, f: F)
362    where
363        F: FnOnce(&mut Self, &Header, SyncState) -> Transition<SyncState, Result<(), Cid>>,
364    {
365        if let Some(mut parent) = self.queries.remove(&id) {
366            let state = if let State::Sync(state) = parent.state {
367                state
368            } else {
369                return;
370            };
371            match f(self, &parent.hdr, state) {
372                Transition::Next(state) => {
373                    parent.state = State::Sync(state);
374                    self.queries.insert(id, parent);
375                }
376                Transition::Complete(res) => {
377                    if res.is_ok() {
378                        tracing::trace!("{} {} sync ok", parent.hdr.root, parent.hdr.id);
379                    } else {
380                        tracing::trace!("{} {} sync err", parent.hdr.root, parent.hdr.id);
381                    }
382                    self.recv_sync(parent.hdr, res);
383                }
384            }
385        }
386    }
387
388    /// Processes the response of a have query.
389    ///
390    /// Marks the in progress query as complete and updates the set of peers that have
391    /// a block. If there isn't an in progress block query a new block query will be
392    /// started. If no block query can be started either a provider query is started or
393    /// the get query is marked as complete with a block-not-found error.
394    fn recv_have(&mut self, query: Header, peer_id: PeerId, have: bool) {
395        self.get_query(query.parent.unwrap(), |mgr, parent, mut state| {
396            state.have.remove(&query.id);
397            if state.block == Some(query.id) {
398                state.block = None;
399            }
400            if have {
401                state.providers.push(peer_id);
402            }
403            if state.block.is_none() && !state.providers.is_empty() {
404                state.block = Some(mgr.block(
405                    parent.root,
406                    parent.id,
407                    state.providers.pop().unwrap(),
408                    query.cid,
409                ));
410            }
411            if state.have.is_empty() && state.block.is_none() && state.providers.is_empty() {
412                if state.providers.is_empty() {
413                    return Transition::Complete(Err(query.cid));
414                } else {
415                    return Transition::Complete(Ok(()));
416                }
417            }
418            Transition::Next(state)
419        });
420    }
421
422    /// Processes the response of a block query.
423    ///
424    /// Either completes the get query or processes it like a have query response.
425    fn recv_block(&mut self, query: Header, peer_id: PeerId, block: bool) {
426        if block {
427            self.get_query(query.parent.unwrap(), |_mgr, _parent, mut state| {
428                state.providers.push(peer_id);
429                Transition::Complete(Ok(()))
430            });
431        } else {
432            self.recv_have(query, peer_id, block);
433        }
434    }
435
436    /// Processes the response of a missing blocks query.
437    ///
438    /// Starts a get query for each missing block. If there are no in progress queries
439    /// the sync query is marked as complete.
440    fn recv_missing_blocks(&mut self, query: Header, missing: Vec<Cid>) {
441        let mut num_missing = 0;
442        let num_missing_ref = &mut num_missing;
443        self.sync_query(query.parent.unwrap(), |mgr, parent, mut state| {
444            state.children.remove(&query.id);
445            for cid in missing {
446                state.missing.insert(mgr.get(
447                    Some(parent.root),
448                    cid,
449                    state.providers.iter().copied(),
450                ));
451            }
452            *num_missing_ref = state.missing.len();
453            if state.missing.is_empty() && state.children.is_empty() {
454                Transition::Complete(Ok(()))
455            } else {
456                Transition::Next(state)
457            }
458        });
459        if num_missing != 0 {
460            self.events
461                .push_back(QueryEvent::Progress(query.root, num_missing));
462        }
463    }
464
465    /// Processes the response of a get query.
466    ///
467    /// If it is part of a sync query a new missing blocks query is started. Otherwise
468    /// the get query emits a `complete` event.
469    fn recv_get(&mut self, query: Header, res: Result<(), Cid>) {
470        if let Some(id) = query.parent {
471            self.sync_query(id, |mgr, parent, mut state| {
472                state.missing.remove(&query.id);
473                if res.is_err() {
474                    Transition::Complete(res)
475                } else {
476                    state
477                        .children
478                        .insert(mgr.missing_blocks(parent.root, query.cid));
479                    Transition::Next(state)
480                }
481            });
482        } else {
483            self.events.push_back(QueryEvent::Complete(query.id, res));
484        }
485    }
486
487    /// Processes the response of a sync query.
488    ///
489    /// The sync query emits a `complete` event.
490    fn recv_sync(&mut self, query: Header, res: Result<(), Cid>) {
491        self.events.push_back(QueryEvent::Complete(query.id, res));
492    }
493
494    /// Dispatches the response to a query handler.
495    pub fn inject_response(&mut self, id: QueryId, res: Response) {
496        let query = if let Some(query) = self.queries.remove(&id) {
497            query.hdr
498        } else {
499            return;
500        };
501        tracing::trace!("{} {} {}", query.root, query.id, res);
502        match res {
503            Response::Have(peer, have) => {
504                self.recv_have(query, peer, have);
505            }
506            Response::Block(peer, block) => {
507                self.recv_block(query, peer, block);
508            }
509            Response::MissingBlocks(cids) => {
510                self.recv_missing_blocks(query, cids);
511            }
512        }
513    }
514
515    /// Returns the header of a query.
516    pub fn query_info(&self, id: QueryId) -> Option<&Header> {
517        self.queries.get(&id).map(|q| &q.hdr)
518    }
519
520    /// Retrieves the next query event.
521    pub fn next(&mut self) -> Option<QueryEvent> {
522        self.events.pop_front()
523    }
524}
525
526#[cfg(test)]
527mod tests {
528    use super::*;
529    use tracing_subscriber::fmt::TestWriter;
530
531    fn tracing_try_init() {
532        tracing_subscriber::fmt()
533            .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
534            .with_writer(TestWriter::new())
535            .try_init()
536            .ok();
537    }
538
539    fn gen_peers(n: usize) -> Vec<PeerId> {
540        let mut peers = Vec::with_capacity(n);
541        for _ in 0..n {
542            peers.push(PeerId::random());
543        }
544        peers
545    }
546
547    fn assert_request(event: Option<QueryEvent>, req: Request) -> QueryId {
548        if let Some(QueryEvent::Request(id, req2)) = event {
549            assert_eq!(req2, req);
550            id
551        } else {
552            panic!("{:?} is not a request", event);
553        }
554    }
555
556    fn assert_complete(event: Option<QueryEvent>, id: QueryId, res: Result<(), Cid>) {
557        if let Some(QueryEvent::Complete(id2, res2)) = event {
558            assert_eq!(id, id2);
559            assert_eq!(res, res2);
560        } else {
561            panic!("{:?} is not a complete event", event);
562        }
563    }
564
/// A get query fails with the cid when no provider has the block.
#[test]
fn test_get_query_block_not_found() {
    let cid = Cid::default();
    let peers = gen_peers(3);
    let mut mgr = QueryManager::default();

    let get_id = mgr.get(None, cid, peers.iter().copied());

    // The first provider is asked for the block, the others only for a have.
    let block_id = assert_request(mgr.next(), Request::Block(peers[0], cid));
    let have_a = assert_request(mgr.next(), Request::Have(peers[1], cid));
    let have_b = assert_request(mgr.next(), Request::Have(peers[2], cid));

    // Every peer answers that it does not have the block.
    for (query, peer) in [block_id, have_a, have_b].iter().zip(&peers) {
        mgr.inject_response(*query, Response::Have(*peer, false));
    }

    assert_complete(mgr.next(), get_id, Err(cid));
}
583
/// A get query succeeds as soon as the block request is answered positively;
/// later have responses do not change the outcome.
#[test]
fn test_cid_query_block_found() {
    let cid = Cid::default();
    let peers = gen_peers(3);
    let mut mgr = QueryManager::default();

    let get_id = mgr.get(None, cid, peers.iter().copied());

    let block_id = assert_request(mgr.next(), Request::Block(peers[0], cid));
    let have_a = assert_request(mgr.next(), Request::Have(peers[1], cid));
    let have_b = assert_request(mgr.next(), Request::Have(peers[2], cid));

    mgr.inject_response(block_id, Response::Block(peers[0], true));
    for (query, peer) in [have_a, have_b].iter().zip(&peers[1..]) {
        mgr.inject_response(*query, Response::Have(*peer, false));
    }

    assert_complete(mgr.next(), get_id, Ok(()));
}
602
/// When the first block request fails, the block is requested from a peer that
/// answered its have request positively.
#[test]
fn test_get_query_gets_from_spare_if_block_request_fails() {
    let cid = Cid::default();
    let peers = gen_peers(3);
    let mut mgr = QueryManager::default();

    let get_id = mgr.get(None, cid, peers.iter().copied());

    let block_id = assert_request(mgr.next(), Request::Block(peers[0], cid));
    let have_a = assert_request(mgr.next(), Request::Have(peers[1], cid));
    let have_b = assert_request(mgr.next(), Request::Have(peers[2], cid));

    mgr.inject_response(block_id, Response::Block(peers[0], false));
    mgr.inject_response(have_a, Response::Have(peers[1], true));
    mgr.inject_response(have_b, Response::Have(peers[2], false));

    // The second peer reported having the block, so it is asked for it next.
    let retry_id = assert_request(mgr.next(), Request::Block(peers[1], cid));
    mgr.inject_response(retry_id, Response::Block(peers[1], true));

    assert_complete(mgr.next(), get_id, Ok(()));
}
624
/// When block requests keep failing, the remaining providers are tried one after
/// the other until one of them delivers the block.
#[test]
fn test_get_query_gets_from_spare_if_block_request_fails_after_have_is_received() {
    let cid = Cid::default();
    let peers = gen_peers(3);
    let mut mgr = QueryManager::default();

    let get_id = mgr.get(None, cid, peers.iter().copied());

    let block_id = assert_request(mgr.next(), Request::Block(peers[0], cid));
    let have_a = assert_request(mgr.next(), Request::Have(peers[1], cid));
    let have_b = assert_request(mgr.next(), Request::Have(peers[2], cid));

    mgr.inject_response(block_id, Response::Block(peers[0], false));
    mgr.inject_response(have_a, Response::Have(peers[1], true));
    mgr.inject_response(have_b, Response::Have(peers[2], true));

    // First retry goes to the second peer and fails as well.
    let first_retry = assert_request(mgr.next(), Request::Block(peers[1], cid));
    mgr.inject_response(first_retry, Response::Block(peers[1], false));

    // Second retry goes to the third peer and succeeds.
    let second_retry = assert_request(mgr.next(), Request::Block(peers[2], cid));
    mgr.inject_response(second_retry, Response::Block(peers[2], true));

    assert_complete(mgr.next(), get_id, Ok(()));
}
649
/// A sync query with one missing block retrieves it, asks for further missing
/// blocks and completes once none are reported.
#[test]
fn test_sync_query() {
    tracing_try_init();
    let cid = Cid::default();
    let providers = gen_peers(3);
    let mut mgr = QueryManager::default();

    let sync_id = mgr.sync(cid, providers.clone(), std::iter::once(cid));

    let block_id = assert_request(mgr.next(), Request::Block(providers[0], cid));
    let have_a = assert_request(mgr.next(), Request::Have(providers[1], cid));
    let have_b = assert_request(mgr.next(), Request::Have(providers[2], cid));

    mgr.inject_response(block_id, Response::Block(providers[0], true));
    mgr.inject_response(have_a, Response::Have(providers[1], false));
    mgr.inject_response(have_b, Response::Have(providers[2], false));

    // After the block arrived the sync query asks which blocks are still missing.
    let missing_id = assert_request(mgr.next(), Request::MissingBlocks(cid));
    mgr.inject_response(missing_id, Response::MissingBlocks(Vec::new()));

    assert_complete(mgr.next(), sync_id, Ok(()));
}
672
/// A sync query without providers and without initially missing blocks only asks
/// for the missing blocks and completes when none are reported.
#[test]
fn test_sync_query_empty() {
    tracing_try_init();
    let cid = Cid::default();
    let mut mgr = QueryManager::default();

    let sync_id = mgr.sync(cid, Vec::new(), std::iter::empty());

    let missing_id = assert_request(mgr.next(), Request::MissingBlocks(cid));
    mgr.inject_response(missing_id, Response::MissingBlocks(Vec::new()));

    assert_complete(mgr.next(), sync_id, Ok(()));
}
683}