1use capnp::capability::Promise;
26use capnp::message::ReaderOptions;
27use futures::channel::oneshot;
28use futures::{AsyncRead, AsyncWrite, FutureExt, TryFutureExt};
29
30use std::cell::RefCell;
31use std::rc::{Rc, Weak};
32
33pub type VatId = crate::rpc_twoparty_capnp::Side;
34
35struct IncomingMessage {
36 message: ::capnp::message::Reader<capnp::serialize::OwnedSegments>,
37}
38
39impl IncomingMessage {
40 pub fn new(message: ::capnp::message::Reader<capnp::serialize::OwnedSegments>) -> Self {
41 Self { message }
42 }
43}
44
45impl crate::IncomingMessage for IncomingMessage {
46 fn get_body(&self) -> ::capnp::Result<::capnp::any_pointer::Reader<'_>> {
47 self.message.get_root()
48 }
49}
50
51struct OutgoingMessage {
52 message: ::capnp::message::Builder<::capnp::message::HeapAllocator>,
53 sender: ::capnp_futures::Sender<Rc<::capnp::message::Builder<::capnp::message::HeapAllocator>>>,
54}
55
56impl crate::OutgoingMessage for OutgoingMessage {
57 fn get_body(&mut self) -> ::capnp::Result<::capnp::any_pointer::Builder<'_>> {
58 self.message.get_root()
59 }
60
61 fn get_body_as_reader(&self) -> ::capnp::Result<::capnp::any_pointer::Reader<'_>> {
62 self.message.get_root_as_reader()
63 }
64
65 fn send(
66 self: Box<Self>,
67 ) -> (
68 Promise<(), ::capnp::Error>,
69 Rc<::capnp::message::Builder<::capnp::message::HeapAllocator>>,
70 ) {
71 let tmp = *self;
72 let Self {
73 message,
74 mut sender,
75 } = tmp;
76 let m = Rc::new(message);
77 (
78 Promise::from_future(sender.send(m.clone()).map_ok(|_| ())),
79 m,
80 )
81 }
82
83 fn take(self: Box<Self>) -> ::capnp::message::Builder<::capnp::message::HeapAllocator> {
84 self.message
85 }
86
87 fn size_in_words(&self) -> usize {
88 self.message.size_in_words()
89 }
90}
91
92struct ConnectionInner<T>
93where
94 T: AsyncRead + 'static,
95{
96 input_stream: Rc<RefCell<Option<T>>>,
97 sender: ::capnp_futures::Sender<Rc<::capnp::message::Builder<::capnp::message::HeapAllocator>>>,
98 side: crate::rpc_twoparty_capnp::Side,
99 receive_options: ReaderOptions,
100 on_disconnect_fulfiller: Option<oneshot::Sender<()>>,
101 window_size_in_bytes: usize,
102}
103
104struct Connection<T>
105where
106 T: AsyncRead + 'static,
107{
108 inner: Rc<RefCell<ConnectionInner<T>>>,
109}
110
111impl<T> Drop for ConnectionInner<T>
112where
113 T: AsyncRead,
114{
115 fn drop(&mut self) {
116 match self.on_disconnect_fulfiller.take() {
117 Some(fulfiller) => {
118 let _ = fulfiller.send(());
119 }
120 None => unreachable!(),
121 }
122 }
123}
124
125impl<T> Connection<T>
126where
127 T: AsyncRead,
128{
129 fn new(
130 input_stream: T,
131 sender: ::capnp_futures::Sender<
132 Rc<::capnp::message::Builder<::capnp::message::HeapAllocator>>,
133 >,
134 side: crate::rpc_twoparty_capnp::Side,
135 receive_options: ReaderOptions,
136 on_disconnect_fulfiller: oneshot::Sender<()>,
137 ) -> Self {
138 Self {
139 inner: Rc::new(RefCell::new(ConnectionInner {
140 input_stream: Rc::new(RefCell::new(Some(input_stream))),
141 sender,
142 side,
143 receive_options,
144 on_disconnect_fulfiller: Some(on_disconnect_fulfiller),
145 window_size_in_bytes: crate::flow_control::DEFAULT_WINDOW_SIZE,
146 })),
147 }
148 }
149}
150
151impl<T> crate::Connection<crate::rpc_twoparty_capnp::Side> for Connection<T>
152where
153 T: AsyncRead + Unpin,
154{
155 fn get_peer_vat_id(&self) -> crate::rpc_twoparty_capnp::Side {
156 self.inner.borrow().side
157 }
158
159 fn new_outgoing_message(
160 &mut self,
161 first_segment_word_size: u32,
162 ) -> Box<dyn crate::OutgoingMessage> {
163 let message = ::capnp::message::Builder::new(
164 ::capnp::message::HeapAllocator::new().first_segment_words(first_segment_word_size),
165 );
166 Box::new(OutgoingMessage {
167 message,
168 sender: self.inner.borrow().sender.clone(),
169 })
170 }
171
172 fn receive_incoming_message(
173 &mut self,
174 ) -> Promise<Option<Box<dyn crate::IncomingMessage + 'static>>, ::capnp::Error> {
175 let inner = self.inner.borrow_mut();
176
177 let maybe_input_stream = inner.input_stream.borrow_mut().take();
178 let return_it_here = inner.input_stream.clone();
179 match maybe_input_stream {
180 Some(mut s) => {
181 let receive_options = inner.receive_options;
182 Promise::from_future(async move {
183 let maybe_message =
184 ::capnp_futures::serialize::try_read_message(&mut s, receive_options)
185 .await?;
186 *return_it_here.borrow_mut() = Some(s);
187 Ok(maybe_message.map(|message| {
188 Box::new(IncomingMessage::new(message)) as Box<dyn crate::IncomingMessage>
189 }))
190 })
191 }
192 None => {
193 Promise::err(::capnp::Error::failed(
194 "this should not be possible".to_string(),
195 ))
196 }
198 }
199 }
200
201 fn new_stream(&mut self) -> (Box<dyn crate::FlowController>, Promise<(), capnp::Error>) {
202 let (fc, f) = crate::flow_control::FixedWindowFlowController::new(
203 self.inner.borrow().window_size_in_bytes,
204 );
205 (Box::new(fc), f)
206 }
207
208 fn shutdown(&mut self, result: ::capnp::Result<()>) -> Promise<(), ::capnp::Error> {
209 Promise::from_future(self.inner.borrow_mut().sender.terminate(result))
210 }
211}
212
213pub struct VatNetwork<T>
215where
216 T: AsyncRead + 'static + Unpin,
217{
218 connection: Option<Connection<T>>,
220
221 weak_connection_inner: Weak<RefCell<ConnectionInner<T>>>,
223
224 execution_driver: futures::future::Shared<Promise<(), ::capnp::Error>>,
225 side: crate::rpc_twoparty_capnp::Side,
226}
227
228impl<T> VatNetwork<T>
230where
231 T: AsyncRead + Unpin,
232{
233 pub fn new<U>(
244 input_stream: T,
245 output_stream: U,
246 side: crate::rpc_twoparty_capnp::Side,
247 receive_options: ReaderOptions,
248 ) -> Self
249 where
250 U: AsyncWrite + 'static + Unpin,
251 {
252 let (fulfiller, disconnect_promise) = oneshot::channel();
253 let disconnect_promise =
254 disconnect_promise.map_err(|_| ::capnp::Error::disconnected("disconnected".into()));
255
256 let (execution_driver, sender) = {
257 let (tx, write_queue) = ::capnp_futures::write_queue(output_stream);
258
259 (
262 Promise::from_future(write_queue.then(move |r| {
263 disconnect_promise
264 .then(move |_| futures::future::ready(r))
265 .map_ok(|_| ())
266 }))
267 .shared(),
268 tx,
269 )
270 };
271
272 let connection = Connection::new(input_stream, sender, side, receive_options, fulfiller);
273 let weak_inner = Rc::downgrade(&connection.inner);
274 Self {
275 connection: Some(connection),
276 weak_connection_inner: weak_inner,
277 execution_driver,
278 side,
279 }
280 }
281
282 pub fn set_window_size(&mut self, window_size: usize) {
285 if let Some(ref mut conn) = self.connection {
286 conn.inner.borrow_mut().window_size_in_bytes = window_size;
287 }
288 }
289}
290
291impl<T> crate::VatNetwork<VatId> for VatNetwork<T>
292where
293 T: AsyncRead + Unpin,
294{
295 fn connect(&mut self, host_id: VatId) -> Option<Box<dyn crate::Connection<VatId>>> {
296 if host_id == self.side {
297 None
298 } else {
299 match self.weak_connection_inner.upgrade() {
300 Some(connection_inner) => Some(Box::new(Connection {
301 inner: connection_inner,
302 })),
303 None => {
304 panic!("tried to reconnect a disconnected twoparty vat network.")
305 }
306 }
307 }
308 }
309
310 fn accept(&mut self) -> Promise<Box<dyn crate::Connection<VatId>>, ::capnp::Error> {
311 match self.connection.take() {
312 Some(c) => Promise::ok(Box::new(c) as Box<dyn crate::Connection<VatId>>),
313 None => Promise::from_future(::futures::future::pending()),
314 }
315 }
316
317 fn drive_until_shutdown(&mut self) -> Promise<(), ::capnp::Error> {
318 Promise::from_future(self.execution_driver.clone())
319 }
320}