capnp_rpc/
twoparty.rs

1// Copyright (c) 2015 Sandstorm Development Group, Inc. and contributors
2// Licensed under the MIT License:
3//
4// Permission is hereby granted, free of charge, to any person obtaining a copy
5// of this software and associated documentation files (the "Software"), to deal
6// in the Software without restriction, including without limitation the rights
7// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8// copies of the Software, and to permit persons to whom the Software is
9// furnished to do so, subject to the following conditions:
10//
11// The above copyright notice and this permission notice shall be included in
12// all copies or substantial portions of the Software.
13//
14// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
17// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20// THE SOFTWARE.
21
22//! An implementation of [`VatNetwork`](crate::VatNetwork) for the common case
23//! of a client-server connection.
24
25use capnp::capability::Promise;
26use capnp::message::ReaderOptions;
27use futures::channel::oneshot;
28use futures::{AsyncRead, AsyncWrite, FutureExt, TryFutureExt};
29
30use std::cell::RefCell;
31use std::rc::{Rc, Weak};
32
33pub type VatId = crate::rpc_twoparty_capnp::Side;
34
35struct IncomingMessage {
36    message: ::capnp::message::Reader<capnp::serialize::OwnedSegments>,
37}
38
39impl IncomingMessage {
40    pub fn new(message: ::capnp::message::Reader<capnp::serialize::OwnedSegments>) -> Self {
41        Self { message }
42    }
43}
44
45impl crate::IncomingMessage for IncomingMessage {
46    fn get_body(&self) -> ::capnp::Result<::capnp::any_pointer::Reader<'_>> {
47        self.message.get_root()
48    }
49}
50
51struct OutgoingMessage {
52    message: ::capnp::message::Builder<::capnp::message::HeapAllocator>,
53    sender: ::capnp_futures::Sender<Rc<::capnp::message::Builder<::capnp::message::HeapAllocator>>>,
54}
55
56impl crate::OutgoingMessage for OutgoingMessage {
57    fn get_body(&mut self) -> ::capnp::Result<::capnp::any_pointer::Builder<'_>> {
58        self.message.get_root()
59    }
60
61    fn get_body_as_reader(&self) -> ::capnp::Result<::capnp::any_pointer::Reader<'_>> {
62        self.message.get_root_as_reader()
63    }
64
65    fn send(
66        self: Box<Self>,
67    ) -> (
68        Promise<(), ::capnp::Error>,
69        Rc<::capnp::message::Builder<::capnp::message::HeapAllocator>>,
70    ) {
71        let tmp = *self;
72        let Self {
73            message,
74            mut sender,
75        } = tmp;
76        let m = Rc::new(message);
77        (
78            Promise::from_future(sender.send(m.clone()).map_ok(|_| ())),
79            m,
80        )
81    }
82
83    fn take(self: Box<Self>) -> ::capnp::message::Builder<::capnp::message::HeapAllocator> {
84        self.message
85    }
86
87    fn size_in_words(&self) -> usize {
88        self.message.size_in_words()
89    }
90}
91
92struct ConnectionInner<T>
93where
94    T: AsyncRead + 'static,
95{
96    input_stream: Rc<RefCell<Option<T>>>,
97    sender: ::capnp_futures::Sender<Rc<::capnp::message::Builder<::capnp::message::HeapAllocator>>>,
98    side: crate::rpc_twoparty_capnp::Side,
99    receive_options: ReaderOptions,
100    on_disconnect_fulfiller: Option<oneshot::Sender<()>>,
101    window_size_in_bytes: usize,
102}
103
104struct Connection<T>
105where
106    T: AsyncRead + 'static,
107{
108    inner: Rc<RefCell<ConnectionInner<T>>>,
109}
110
111impl<T> Drop for ConnectionInner<T>
112where
113    T: AsyncRead,
114{
115    fn drop(&mut self) {
116        match self.on_disconnect_fulfiller.take() {
117            Some(fulfiller) => {
118                let _ = fulfiller.send(());
119            }
120            None => unreachable!(),
121        }
122    }
123}
124
125impl<T> Connection<T>
126where
127    T: AsyncRead,
128{
129    fn new(
130        input_stream: T,
131        sender: ::capnp_futures::Sender<
132            Rc<::capnp::message::Builder<::capnp::message::HeapAllocator>>,
133        >,
134        side: crate::rpc_twoparty_capnp::Side,
135        receive_options: ReaderOptions,
136        on_disconnect_fulfiller: oneshot::Sender<()>,
137    ) -> Self {
138        Self {
139            inner: Rc::new(RefCell::new(ConnectionInner {
140                input_stream: Rc::new(RefCell::new(Some(input_stream))),
141                sender,
142                side,
143                receive_options,
144                on_disconnect_fulfiller: Some(on_disconnect_fulfiller),
145                window_size_in_bytes: crate::flow_control::DEFAULT_WINDOW_SIZE,
146            })),
147        }
148    }
149}
150
151impl<T> crate::Connection<crate::rpc_twoparty_capnp::Side> for Connection<T>
152where
153    T: AsyncRead + Unpin,
154{
155    fn get_peer_vat_id(&self) -> crate::rpc_twoparty_capnp::Side {
156        self.inner.borrow().side
157    }
158
159    fn new_outgoing_message(
160        &mut self,
161        first_segment_word_size: u32,
162    ) -> Box<dyn crate::OutgoingMessage> {
163        let message = ::capnp::message::Builder::new(
164            ::capnp::message::HeapAllocator::new().first_segment_words(first_segment_word_size),
165        );
166        Box::new(OutgoingMessage {
167            message,
168            sender: self.inner.borrow().sender.clone(),
169        })
170    }
171
172    fn receive_incoming_message(
173        &mut self,
174    ) -> Promise<Option<Box<dyn crate::IncomingMessage + 'static>>, ::capnp::Error> {
175        let inner = self.inner.borrow_mut();
176
177        let maybe_input_stream = inner.input_stream.borrow_mut().take();
178        let return_it_here = inner.input_stream.clone();
179        match maybe_input_stream {
180            Some(mut s) => {
181                let receive_options = inner.receive_options;
182                Promise::from_future(async move {
183                    let maybe_message =
184                        ::capnp_futures::serialize::try_read_message(&mut s, receive_options)
185                            .await?;
186                    *return_it_here.borrow_mut() = Some(s);
187                    Ok(maybe_message.map(|message| {
188                        Box::new(IncomingMessage::new(message)) as Box<dyn crate::IncomingMessage>
189                    }))
190                })
191            }
192            None => {
193                Promise::err(::capnp::Error::failed(
194                    "this should not be possible".to_string(),
195                ))
196                //   unreachable!(),
197            }
198        }
199    }
200
201    fn new_stream(&mut self) -> (Box<dyn crate::FlowController>, Promise<(), capnp::Error>) {
202        let (fc, f) = crate::flow_control::FixedWindowFlowController::new(
203            self.inner.borrow().window_size_in_bytes,
204        );
205        (Box::new(fc), f)
206    }
207
208    fn shutdown(&mut self, result: ::capnp::Result<()>) -> Promise<(), ::capnp::Error> {
209        Promise::from_future(self.inner.borrow_mut().sender.terminate(result))
210    }
211}
212
213/// A vat network with two parties, the client and the server.
214pub struct VatNetwork<T>
215where
216    T: AsyncRead + 'static + Unpin,
217{
218    // connection handle that we will return on accept()
219    connection: Option<Connection<T>>,
220
221    // connection handle that we will return on connect()
222    weak_connection_inner: Weak<RefCell<ConnectionInner<T>>>,
223
224    execution_driver: futures::future::Shared<Promise<(), ::capnp::Error>>,
225    side: crate::rpc_twoparty_capnp::Side,
226}
227
228/// A two-party vat `VatNetwork` implementation.
229impl<T> VatNetwork<T>
230where
231    T: AsyncRead + Unpin,
232{
233    /// Creates a new two-party vat network that will receive data on `input_stream` and send data on
234    /// `output_stream`. (Typically, performance is best if these streams are buffered, possibly via
235    /// `futures::io::BufReader` and `futures::io::BufWriter`.)
236    ///
237    /// `side` indicates whether this is the client or the server side of the connection. This has no
238    /// effect on the data sent over the connection; it merely exists so that `RpcNetwork::bootstrap` knows
239    /// whether to return the local or the remote bootstrap capability. `VatId` parameters like this one
240    /// will make more sense once we have vat networks with more than two parties.
241    ///
242    /// The options in `receive_options` will be used when reading the messages that come in on `input_stream`.
243    pub fn new<U>(
244        input_stream: T,
245        output_stream: U,
246        side: crate::rpc_twoparty_capnp::Side,
247        receive_options: ReaderOptions,
248    ) -> Self
249    where
250        U: AsyncWrite + 'static + Unpin,
251    {
252        let (fulfiller, disconnect_promise) = oneshot::channel();
253        let disconnect_promise =
254            disconnect_promise.map_err(|_| ::capnp::Error::disconnected("disconnected".into()));
255
256        let (execution_driver, sender) = {
257            let (tx, write_queue) = ::capnp_futures::write_queue(output_stream);
258
259            // Don't use `.join()` here because we need to make sure to wait for `disconnect_promise` to
260            // resolve even if `write_queue` resolves to an error.
261            (
262                Promise::from_future(write_queue.then(move |r| {
263                    disconnect_promise
264                        .then(move |_| futures::future::ready(r))
265                        .map_ok(|_| ())
266                }))
267                .shared(),
268                tx,
269            )
270        };
271
272        let connection = Connection::new(input_stream, sender, side, receive_options, fulfiller);
273        let weak_inner = Rc::downgrade(&connection.inner);
274        Self {
275            connection: Some(connection),
276            weak_connection_inner: weak_inner,
277            execution_driver,
278            side,
279        }
280    }
281
282    /// Set the number of bytes in the flow control window for each stream created
283    /// on this connection.
284    pub fn set_window_size(&mut self, window_size: usize) {
285        if let Some(ref mut conn) = self.connection {
286            conn.inner.borrow_mut().window_size_in_bytes = window_size;
287        }
288    }
289}
290
291impl<T> crate::VatNetwork<VatId> for VatNetwork<T>
292where
293    T: AsyncRead + Unpin,
294{
295    fn connect(&mut self, host_id: VatId) -> Option<Box<dyn crate::Connection<VatId>>> {
296        if host_id == self.side {
297            None
298        } else {
299            match self.weak_connection_inner.upgrade() {
300                Some(connection_inner) => Some(Box::new(Connection {
301                    inner: connection_inner,
302                })),
303                None => {
304                    panic!("tried to reconnect a disconnected twoparty vat network.")
305                }
306            }
307        }
308    }
309
310    fn accept(&mut self) -> Promise<Box<dyn crate::Connection<VatId>>, ::capnp::Error> {
311        match self.connection.take() {
312            Some(c) => Promise::ok(Box::new(c) as Box<dyn crate::Connection<VatId>>),
313            None => Promise::from_future(::futures::future::pending()),
314        }
315    }
316
317    fn drive_until_shutdown(&mut self) -> Promise<(), ::capnp::Error> {
318        Promise::from_future(self.execution_driver.clone())
319    }
320}