// capnp_rpc/flow_control.rs

1use capnp::capability::Promise;
2use capnp::Error;
3
4use futures::channel::oneshot;
5use futures::TryFutureExt;
6use std::cell::RefCell;
7use std::rc::Rc;
8
9use crate::task_set::{TaskReaper, TaskSet, TaskSetHandle};
10
/// Default flow-control window size: 64 KiB (65536).
pub const DEFAULT_WINDOW_SIZE: usize = 64 * 1024;
12
13enum State {
14    Running(Vec<oneshot::Sender<Result<(), Error>>>),
15    Failed(Error),
16}
17
18struct FixedWindowFlowControllerInner {
19    window_size: usize,
20    in_flight: usize,
21    max_message_size: usize,
22    state: State,
23    empty_fulfiller: Option<oneshot::Sender<Promise<(), Error>>>,
24}
25
26impl FixedWindowFlowControllerInner {
27    fn is_ready(&self) -> bool {
28        // We extend the window by maxMessageSize to avoid a pathological situation when a message
29        // is larger than the window size. Otherwise, after sending that message, we would end up
30        // not sending any others until the ack was received, wasting a round trip's worth of
31        // bandwidth.
32
33        self.in_flight < self.window_size + self.max_message_size
34    }
35}
36
37pub struct FixedWindowFlowController {
38    inner: Rc<RefCell<FixedWindowFlowControllerInner>>,
39    tasks: TaskSetHandle<Error>,
40}
41
42struct Reaper {
43    inner: Rc<RefCell<FixedWindowFlowControllerInner>>,
44}
45
46impl TaskReaper<Error> for Reaper {
47    fn task_failed(&mut self, error: Error) {
48        let mut inner = self.inner.borrow_mut();
49        if let State::Running(ref mut blocked_sends) = &mut inner.state {
50            for s in std::mem::take(blocked_sends) {
51                let _ = s.send(Err(error.clone()));
52            }
53            inner.state = State::Failed(error)
54        }
55    }
56}
57
58impl FixedWindowFlowController {
59    pub fn new(window_size: usize) -> (Self, Promise<(), Error>) {
60        let inner = FixedWindowFlowControllerInner {
61            window_size,
62            in_flight: 0,
63            max_message_size: 0,
64            state: State::Running(vec![]),
65            empty_fulfiller: None,
66        };
67        let inner = Rc::new(RefCell::new(inner));
68        let (tasks, task_future) = TaskSet::new(Box::new(Reaper {
69            inner: inner.clone(),
70        }));
71        (Self { inner, tasks }, Promise::from_future(task_future))
72    }
73}
74
75impl crate::FlowController for FixedWindowFlowController {
76    fn send(
77        &mut self,
78        message: Box<dyn crate::OutgoingMessage>,
79        ack: Promise<(), Error>,
80    ) -> Promise<(), Error> {
81        let size = message.size_in_words() * 8;
82        {
83            let mut inner = self.inner.borrow_mut();
84            let prev_max_size = inner.max_message_size;
85            inner.max_message_size = usize::max(size, prev_max_size);
86
87            // We are REQUIRED to send the message NOW to maintain correct ordering.
88            let _ = message.send();
89
90            inner.in_flight += size;
91        }
92        let inner = self.inner.clone();
93        let mut tasks = self.tasks.clone();
94        self.tasks.add(async move {
95            ack.await?;
96            let mut inner = inner.borrow_mut();
97            inner.in_flight -= size;
98            let is_ready = inner.is_ready();
99            match inner.state {
100                State::Running(ref mut blocked_sends) => {
101                    if is_ready {
102                        for s in std::mem::take(blocked_sends) {
103                            let _ = s.send(Ok(()));
104                        }
105                    }
106
107                    if inner.in_flight == 0 {
108                        if let Some(f) = inner.empty_fulfiller.take() {
109                            let _ = f.send(Promise::from_future(
110                                tasks.on_empty().map_err(crate::canceled_to_error),
111                            ));
112                        }
113                    }
114                }
115                State::Failed(_) => {
116                    // A previous call failed, but this one -- which was already in-flight at the
117                    // time -- ended up succeeding. That may indicate that the server side is not
118                    // properly handling streaming error propagation. Nothing much we can do about
119                    // it here though.
120                }
121            }
122            Ok(())
123        });
124
125        let mut inner = self.inner.borrow_mut();
126        let is_ready = inner.is_ready();
127        match inner.state {
128            State::Running(ref mut blocked_sends) => {
129                if is_ready {
130                    Promise::ok(())
131                } else {
132                    let (snd, rcv) = oneshot::channel();
133                    blocked_sends.push(snd);
134                    Promise::from_future(async {
135                        match rcv.await {
136                            Ok(r) => r,
137                            Err(e) => Err(crate::canceled_to_error(e)),
138                        }
139                    })
140                }
141            }
142            State::Failed(ref e) => Promise::err(e.clone()),
143        }
144    }
145
146    fn wait_all_acked(&mut self) -> Promise<(), Error> {
147        let mut inner = self.inner.borrow_mut();
148        if let State::Running(ref blocked_sends) = inner.state {
149            if !blocked_sends.is_empty() {
150                let (snd, rcv) = oneshot::channel();
151                inner.empty_fulfiller = Some(snd);
152                return Promise::from_future(async move {
153                    match rcv.await {
154                        Ok(r) => r.await,
155                        Err(e) => Err(crate::canceled_to_error(e)),
156                    }
157                });
158            }
159        }
160        Promise::from_future(self.tasks.on_empty().map_err(crate::canceled_to_error))
161    }
162}