sequoia-keystore 0.3.0

Sequoia's private key store server.
Documentation
use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;

use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;

use capnp::any_pointer;
use capnp::capability;
use capnp::message;

use crate::keystore_protocol_capnp::keystore;

use sequoia_ipc as ipc;
use ipc::capnp_rpc::rpc_twoparty_capnp::Side;
use ipc::Descriptor;

use crate::Result;

// A capability reference.
//
// When this goes out of scope, we add the capability to the dead
// capability list so that it can be deallocated.
//
// NOTE(review): this type derives `Clone`.  Dropping a clone would
// report the same index a second time, and the worker expects every
// reported index to still be present in its table.  In this file a
// `CapRef` is only ever shared through an `Arc`, never cloned
// directly.
#[derive(Clone, Debug)]
struct CapRef {
    // The key of the capability in `CapTable::caps`.
    index: u64,
    // Where to report `index` when this reference is dropped.  The
    // drop handler takes the value out; if it is `None`, nothing is
    // reported.
    dead_caps: Option<Arc<Mutex<mpsc::UnboundedSender<u64>>>>,
}

impl Drop for CapRef {
    /// Reports this capability's index on the dead capability channel
    /// so that the owner of the capability table can release it.
    ///
    /// This never panics: a panic in a destructor that runs while the
    /// thread is already unwinding aborts the whole process.
    fn drop(&mut self) {
        log::trace!("Dropping cap {}", self.index);
        if let Some(dead_caps) = self.dead_caps.take() {
            // A poisoned mutex only tells us that some other thread
            // panicked while holding the lock.  The sender itself is
            // still usable, so recover the guard instead of
            // panicking.
            let sender = dead_caps
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());

            // Sending only fails if the receiver is gone, in which
            // case there is nobody left to clean up for.
            let _ = sender.send(self.index);
        }
    }
}

/// A local capability.
///
/// We wrap capnproto capabilities, because they are not Send + Sync.
///
/// A `Cap` does not contain the capnproto capability itself, only the
/// key under which it is stored in a [`CapTable`].  Clones share the
/// same underlying reference, so the capability is only reported as
/// dead once the last clone has been dropped.
#[derive(Clone, Debug)]
pub struct Cap {
    // The index into CapTable::caps.
    index: u64,
    // We don't read from this, but we rely on its drop method to
    // clean up the actual capability.
    //
    // This is `None` for the root capability, which is never cleaned
    // up.
    #[allow(unused)]
    cap_ref: Option<Arc<CapRef>>,
}

/// A capability table.
///
/// Maps the indices stored in [`Cap`]s to the capnproto capabilities
/// that they stand for.
pub struct CapTable {
    // The index that the next inserted capability gets.  Starts at 1;
    // index 0 is reserved for the root capability.
    next_index: u64,
    // The capnproto capabilities, keyed by index.
    caps: HashMap<u64, capability::Client>,
    // The sending half of the dead capability channel.  Every
    // capability reference handed out by `insert` gets a handle to
    // it, and reports its index on it when dropped.
    dead_caps: Arc<Mutex<mpsc::UnboundedSender<u64>>>,
}

impl CapTable {
    /// Creates an empty capability table.
    ///
    /// Besides the table, this returns the receiving half of the dead
    /// capability channel: the index of every capability handed out
    /// by [`CapTable::insert`] shows up there once the last reference
    /// to it has been dropped.
    ///
    /// Normally, there is one capability table per root capability.
    pub fn new() -> (Self, mpsc::UnboundedReceiver<u64>) {
        let (dead_caps_sender, dead_caps_receiver) = mpsc::unbounded_channel();

        let table = CapTable {
            // Index 0 is reserved for the root capability.
            next_index: 1,
            caps: HashMap::new(),
            dead_caps: Arc::new(Mutex::new(dead_caps_sender)),
        };

        (table, dead_caps_receiver)
    }

    /// Stores the capnproto capability in the table and returns a
    /// local capability that refers to it.
    pub fn insert(&mut self, cap: capability::Client) -> Cap {
        let index = self.next_index;
        self.next_index += 1;
        self.caps.insert(index, cap);

        // Dropping the last reference reports `index` on the dead
        // capability channel.
        let cap_ref = CapRef {
            index,
            dead_caps: Some(Arc::clone(&self.dead_caps)),
        };

        Cap {
            index,
            cap_ref: Some(Arc::new(cap_ref)),
        }
    }

    /// Resolves a local capability to the capnproto capability.
    ///
    /// Returns `None` if the table has no entry for the capability's
    /// index.
    fn lookup<C>(&self, cap: C) -> Option<&capability::Client>
        where C: std::borrow::Borrow<Cap>
    {
        let index = cap.borrow().index;
        self.caps.get(&index)
    }

    /// Sets the root capability, which is stored at index 0.
    ///
    /// This may only be called once.
    ///
    /// # Panics
    ///
    /// Panics if the root capability has already been set.
    fn set_root(&mut self, cap: capability::Client) {
        let already_set = self.caps.insert(0, cap).is_some();
        if already_set {
            panic!("root capability already set");
        }
    }

    /// Returns a local capability that refers to the root capability.
    pub fn root() -> Cap {
        Cap {
            index: 0,
            // The root is always live, so there is nothing to clean
            // up when the handle is dropped.
            cap_ref: None,
        }
    }
}

// The outcome of a relayed RPC, as sent back to the caller: either
// the type-erased value produced by the request's extraction
// function, or an error.
type CapnProtoRelayResponse
    = Result<Box<dyn Any + Send + Sync>>;

// A request for the worker thread to perform an RPC.
type CapnProtoRelayRequest = (
    Cap, // The capability to invoke; resolved using the worker's CapTable.
    u64, // Interface ID.
    u16, // Method ID.
    // The parameters; the message's root is copied into the request.
    message::Reader<message::Builder<message::HeapAllocator>>,
    // Response from the server.
    mpsc::Sender<CapnProtoRelayResponse>,
    // Function to extract the result.  It is run on the worker
    // thread with the relay, the worker's capability table, and the
    // RPC's response.
    Box<dyn FnOnce(Arc<Mutex<CapnProtoRelay>>,
                   &mut CapTable,
                   capability::Response<any_pointer::Owned>)
                   -> Result<Box<dyn Any + Send + Sync>>
                             + Send + Sync>
);

/// A capnproto relay.
///
/// We want capabilities to be Send + Sync, but capnproto capabilities
/// aren't.  This relay makes that possible by wrapping capabilities
/// and forwarding messages to a dedicated thread that is responsible
/// for dispatching RPCs.
///
/// The worker thread is started by [`CapnProtoRelay::new`].  It exits
/// once the relay has been dropped.
pub struct CapnProtoRelay {
    // The end point used to send a message to the worker thread.
    sender: mpsc::Sender<CapnProtoRelayRequest>,
}

impl CapnProtoRelay {
    /// Creates a relay and starts its worker thread.
    ///
    /// The worker thread connects to the server described by
    /// `descriptor`.  This function blocks until the worker reports
    /// whether that worked.
    ///
    /// # Errors
    ///
    /// Returns an error if the worker fails to start up, or goes away
    /// without reporting its start-up status.
    pub fn new(descriptor: Descriptor) -> Result<Arc<Mutex<Self>>> {
        log::trace!("CapnProtoRelay::new");

        // The worker uses this channel to tell us how its
        // initialization went.
        let (init_tx, init_rx) = std::sync::mpsc::channel::<Result<()>>();

        // A small queue is enough: whoever sends a request waits for
        // the reply right afterwards.  Whether the request sits in
        // the queue in the meantime, or the sender first has to wait
        // for a free slot, makes no difference.
        let (request_tx, request_rx)
            = mpsc::channel::<CapnProtoRelayRequest>(8);

        let relay = Arc::new(Mutex::new(CapnProtoRelay {
            sender: request_tx,
        }));

        // The worker only gets a weak reference.  It stops once the
        // relay can no longer be upgraded.
        let weak_relay = Arc::downgrade(&relay);
        thread::spawn(move || {
            CapnProtoRelay::worker(weak_relay, init_tx, descriptor, request_rx)
        });

        // Block until the worker has reported its start-up status.
        // The first `?` covers the worker going away without
        // reporting, the second a start-up error.
        init_rx.recv()??;

        Ok(relay)
    }

    /// Returns the root capability.
    ///
    /// This is the same handle that [`CapTable::root`] returns; it
    /// does not depend on the state of the relay.
    pub fn root(&self) -> Cap {
        CapTable::root()
    }

    /// The body of the relay's worker thread.
    ///
    /// Sets up a current-thread Tokio runtime, connects to the server
    /// described by `descriptor`, and reports the outcome of the
    /// start up on `init_sender`.  If that went well, it then
    /// processes the requests arriving on `rpc_receiver` one at a
    /// time until either all senders are gone, or `relay` can no
    /// longer be upgraded.
    ///
    /// Before waiting for the next request, the capabilities that
    /// were reported as dead in the meantime are removed from the
    /// capability table.
    fn worker(relay: std::sync::Weak<std::sync::Mutex<crate::capnp_relay::CapnProtoRelay>>,
              init_sender: std::sync::mpsc::Sender<Result<()>>,
              descriptor: ipc::Descriptor,
              mut rpc_receiver: mpsc::Receiver<CapnProtoRelayRequest>) {
        log::trace!("CapnProtoRelay::worker");

        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_io()
            .build();
        let rt = match rt {
            Ok(rt) => rt,
            Err(err) => {
                // Report the failure.  If nobody is listening
                // anymore, there is nothing more we can do.
                let _ = init_sender.send(Err(err.into()));
                return;
            }
        };

        // Need to enter the Tokio context due to Tokio TcpStream
        // creation binding eagerly to an I/O reactor.
        let rpc_system = rt.block_on(async { descriptor.connect() });
        let mut rpc_system = match rpc_system {
            Ok(rpc_system) => rpc_system,
            Err(err) =>  {
                let _ = init_sender.send(Err(err));
                return;
            }
        };

        // Start up went well.
        init_sender.send(Ok(())).expect("CapnProtoRelay::new waiting");
        drop(init_sender);

        let root: keystore::Client = rpc_system.bootstrap(Side::Server);

        // The capability table lives on this thread.  The root
        // capability is stored at index 0.
        let (mut caps, dead_caps_receiver) = CapTable::new();
        caps.set_root(root.client);

        // Since RpcSystem is explicitly `!Send`, we need to spawn it
        // on a `LocalSet`.
        let rpc_task = tokio::task::LocalSet::new();
        rpc_task.spawn_local(rpc_system);

        rpc_task.block_on(&rt, async move {
            // Set to `None` once the dead capability channel has been
            // closed so that we stop polling it.
            let mut dead_caps_receiver = Some(dead_caps_receiver);

            'message: loop {
                // Clean up any dead capabilities.
                if let Some(recv) = dead_caps_receiver.as_mut() {
                    // Drain the channel without blocking.
                    'dead_caps: loop {
                        match recv.try_recv() {
                            Ok(i) => {
                                log::trace!("CapnProtoRelay::worker: \
                                             Removing dead cap {}", i);
                                // Panics if the index is not in the
                                // table, e.g., because it was
                                // reported twice.
                                caps.caps.remove(&i).expect("valid capability");
                            }
                            Err(TryRecvError::Empty) => {
                                // No pending messages.
                                break 'dead_caps;
                            }
                            Err(err) => {
                                log::trace!("CapnProtoRelay::worker: \
                                             dead caps receiver is... dead ({})",
                                            err);
                                dead_caps_receiver = None;
                                break 'dead_caps;
                            }
                        }
                    }
                }

                // Process a message.
                log::trace!("CapnProtoRelay::worker: waiting for a message");
                let (cap, interface_id, method_id, params, reply, f) =
                    match rpc_receiver.recv().await {
                        None => {
                            // This only happens if there are no
                            // senders.  In that case, CapnProtoRelay
                            // has been dropped and we should exit.
                            log::debug!("CapnProtoRelay::worker: \
                                         no senders, exiting");
                            break 'message;
                        }
                        Some(r) => r,
                    };

                log::trace!("CapnProtoRelay::worker: processing message");
                // Note: if the relay is gone, we exit without
                // answering this request; `reply` is simply dropped.
                let relay = if let Some(relay) = relay.upgrade() {
                    relay
                } else {
                    log::debug!("CapnProtoRelay::worker: \
                                 Aborting.  Capability relay dropped.");
                    break 'message;
                };

                let r = Self::relay(
                    relay,
                    &mut caps, cap,
                    interface_id, method_id,
                    params, f).await;
                if let Err(ref err) = r {
                    log::debug!("CapnProtoRelay::worker: \
                                 error relaying rpc {:x}/{}: {}",
                                interface_id, method_id, err);
                }
                // Forward the result, whether it is a success or an
                // error.  Failing to do so is not fatal: we just log
                // it and carry on with the next request.
                if let Err(err) = reply.send(r).await {
                    log::debug!("CapnProtoRelay::worker: \
                                 error forwarding reply {:x}/{}: {}",
                                interface_id, method_id, err);
                } else {
                    log::trace!("CapnProtoRelay::worker: \
                                 forwarded reply: {:x}/{}",
                                interface_id, method_id);
                }
            }

            log::trace!("CapnProtoRelay::worker: exited message processing loop");
        });

        log::trace!("CapnProtoRelay::worker: exiting function");
    }

    /// Performs a single RPC.
    ///
    /// Resolves `cap` using `caps`, calls the method identified by
    /// `interface_id` and `method_id` on it with the root of `params`
    /// as parameters, waits for the response, and returns whatever
    /// `f` makes of it.  `f` is passed `relay` and `caps` along with
    /// the response.
    ///
    /// # Errors
    ///
    /// Returns an error if `cap` is not in `caps`, if the parameters
    /// cannot be read or copied into the request, if the RPC fails,
    /// or if `f` returns an error.
    async fn relay<F>(
        relay: std::sync::Arc<std::sync::Mutex<crate::capnp_relay::CapnProtoRelay>>,
        caps: &mut CapTable,
        cap: Cap,
        interface_id: u64, method_id: u16,
        params: message::Reader<message::Builder<message::HeapAllocator>>,
        f: F)
        -> Result<Box<dyn Any + Send + Sync>>
        where F: FnOnce(std::sync::Arc<std::sync::Mutex<crate::capnp_relay::CapnProtoRelay>>,
                        &mut CapTable,
                        capability::Response<any_pointer::Owned>)
                     -> Result<Box<dyn Any + Send + Sync>>
    {
        log::trace!("CapnProtoRelay::relay(interface: {}, method: {})",
                    interface_id, method_id);

        // Only build the error if the lookup actually fails; this is
        // on the path of every RPC.
        let client = caps.lookup(&cap)
            // XXX: Use an Error variant.
            .ok_or_else(|| {
                anyhow::anyhow!("Invalid capability ({})", cap.index)
            })?;

        let mut request: capability::Request<any_pointer::Owned,
                                             any_pointer::Owned>
            = client.new_call(interface_id, method_id, None);
        request.set(params.get_root()?)?;

        log::trace!("CapnProtoRelay::relay: sending RPC");
        let response: Result<capability::Response<any_pointer::Owned>, _>
            = request.send().promise.await;
        log::trace!("CapnProtoRelay::relay: got response");

        f(relay, caps, response?)
    }

    /// Sends an RPC and returns a handle that can be waited on.
    ///
    /// Use `await_reply` to wait on the returned handle.
    pub fn send_rpc<F>(&self,
                   cap: Cap,
                   interface_id: u64,
                   method_id: u16,
                   message: message::Reader<
                           message::Builder<message::HeapAllocator>>,
                   f: F)
        -> Result<mpsc::Receiver<CapnProtoRelayResponse>>
        where F: 'static
            + FnOnce(std::sync::Arc<std::sync::Mutex<crate::capnp_relay::CapnProtoRelay>>,
                     &mut CapTable,
                     capability::Response<any_pointer::Owned>)
                     -> Result<Box<dyn Any + Send + Sync + 'static>> + Send + Sync
    {
        log::trace!("CapnProtoRelay::send_rpc");

        let (sender, receiver) = mpsc::channel(1);

        log::trace!("CapnProtoRelay::send_rpc: forwarding to relay thread");

        if let Err(_) = self.sender.blocking_send((cap,
                                                   interface_id, method_id,
                                                   message, sender,
                                                   Box::new(f)))
        {
            // sending can only fail if the receiver disconnected.  We
            // can't get here if the relay thread is not running.
            panic!("capnproto worker threader died");
        }

        Ok(receiver)
    }

    /// Returns the response to an RPC.
    ///
    /// The handle is as returned by `send_rpc`.
    pub fn await_reply(mut handle: mpsc::Receiver<CapnProtoRelayResponse>)
        -> Result<Box<dyn Any + Send + Sync>>
    {
        log::trace!("CapnProtoRelay::await_reply: waiting for relay thread");

        let response = handle.blocking_recv().expect("worker is alive")?;

        log::trace!("CapnProtoRelay::await_reply: relay thread responded");

        Ok(response)
    }
}