capnp_rpc/
flow_control.rs1use capnp::capability::Promise;
2use capnp::Error;
3
4use futures::channel::oneshot;
5use futures::TryFutureExt;
6use std::cell::RefCell;
7use std::rc::Rc;
8
9use crate::task_set::{TaskReaper, TaskSet, TaskSetHandle};
10
/// Default flow-control window: 64 KiB.
///
/// NOTE(review): presumably meant as the `window_size` argument of
/// `FixedWindowFlowController::new`, which is compared against in-flight message sizes measured
/// in bytes — callers are not visible here, confirm.
pub const DEFAULT_WINDOW_SIZE: usize = 64 * 1024;
12
13enum State {
14 Running(Vec<oneshot::Sender<Result<(), Error>>>),
15 Failed(Error),
16}
17
18struct FixedWindowFlowControllerInner {
19 window_size: usize,
20 in_flight: usize,
21 max_message_size: usize,
22 state: State,
23 empty_fulfiller: Option<oneshot::Sender<Promise<(), Error>>>,
24}
25
26impl FixedWindowFlowControllerInner {
27 fn is_ready(&self) -> bool {
28 self.in_flight < self.window_size + self.max_message_size
34 }
35}
36
37pub struct FixedWindowFlowController {
38 inner: Rc<RefCell<FixedWindowFlowControllerInner>>,
39 tasks: TaskSetHandle<Error>,
40}
41
42struct Reaper {
43 inner: Rc<RefCell<FixedWindowFlowControllerInner>>,
44}
45
46impl TaskReaper<Error> for Reaper {
47 fn task_failed(&mut self, error: Error) {
48 let mut inner = self.inner.borrow_mut();
49 if let State::Running(ref mut blocked_sends) = &mut inner.state {
50 for s in std::mem::take(blocked_sends) {
51 let _ = s.send(Err(error.clone()));
52 }
53 inner.state = State::Failed(error)
54 }
55 }
56}
57
58impl FixedWindowFlowController {
59 pub fn new(window_size: usize) -> (Self, Promise<(), Error>) {
60 let inner = FixedWindowFlowControllerInner {
61 window_size,
62 in_flight: 0,
63 max_message_size: 0,
64 state: State::Running(vec![]),
65 empty_fulfiller: None,
66 };
67 let inner = Rc::new(RefCell::new(inner));
68 let (tasks, task_future) = TaskSet::new(Box::new(Reaper {
69 inner: inner.clone(),
70 }));
71 (Self { inner, tasks }, Promise::from_future(task_future))
72 }
73}
74
75impl crate::FlowController for FixedWindowFlowController {
76 fn send(
77 &mut self,
78 message: Box<dyn crate::OutgoingMessage>,
79 ack: Promise<(), Error>,
80 ) -> Promise<(), Error> {
81 let size = message.size_in_words() * 8;
82 {
83 let mut inner = self.inner.borrow_mut();
84 let prev_max_size = inner.max_message_size;
85 inner.max_message_size = usize::max(size, prev_max_size);
86
87 let _ = message.send();
89
90 inner.in_flight += size;
91 }
92 let inner = self.inner.clone();
93 let mut tasks = self.tasks.clone();
94 self.tasks.add(async move {
95 ack.await?;
96 let mut inner = inner.borrow_mut();
97 inner.in_flight -= size;
98 let is_ready = inner.is_ready();
99 match inner.state {
100 State::Running(ref mut blocked_sends) => {
101 if is_ready {
102 for s in std::mem::take(blocked_sends) {
103 let _ = s.send(Ok(()));
104 }
105 }
106
107 if inner.in_flight == 0 {
108 if let Some(f) = inner.empty_fulfiller.take() {
109 let _ = f.send(Promise::from_future(
110 tasks.on_empty().map_err(crate::canceled_to_error),
111 ));
112 }
113 }
114 }
115 State::Failed(_) => {
116 }
121 }
122 Ok(())
123 });
124
125 let mut inner = self.inner.borrow_mut();
126 let is_ready = inner.is_ready();
127 match inner.state {
128 State::Running(ref mut blocked_sends) => {
129 if is_ready {
130 Promise::ok(())
131 } else {
132 let (snd, rcv) = oneshot::channel();
133 blocked_sends.push(snd);
134 Promise::from_future(async {
135 match rcv.await {
136 Ok(r) => r,
137 Err(e) => Err(crate::canceled_to_error(e)),
138 }
139 })
140 }
141 }
142 State::Failed(ref e) => Promise::err(e.clone()),
143 }
144 }
145
146 fn wait_all_acked(&mut self) -> Promise<(), Error> {
147 let mut inner = self.inner.borrow_mut();
148 if let State::Running(ref blocked_sends) = inner.state {
149 if !blocked_sends.is_empty() {
150 let (snd, rcv) = oneshot::channel();
151 inner.empty_fulfiller = Some(snd);
152 return Promise::from_future(async move {
153 match rcv.await {
154 Ok(r) => r.await,
155 Err(e) => Err(crate::canceled_to_error(e)),
156 }
157 });
158 }
159 }
160 Promise::from_future(self.tasks.on_empty().map_err(crate::canceled_to_error))
161 }
162}