use std::any::Any;
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::thread;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TryRecvError;
use capnp::any_pointer;
use capnp::capability;
use capnp::message;
use crate::keystore_protocol_capnp::keystore;
use sequoia_ipc as ipc;
use ipc::capnp_rpc::rpc_twoparty_capnp::Side;
use ipc::Descriptor;
use crate::Result;
/// Drop guard that reports a capability index as dead.
///
/// Every non-root `Cap` handed out by `CapTable::insert` owns one
/// `CapRef` through an `Arc`.  When the last clone of that `Cap` is
/// dropped, this `CapRef` is dropped too and sends its index to the
/// relay worker, which then removes the matching client from its table.
///
/// This type deliberately does not implement `Clone`: a cloned `CapRef`
/// would report the same index twice, and the worker treats removing an
/// index that is no longer in its table as a bug (it `expect`s a hit).
#[derive(Debug)]
struct CapRef {
    index: u64,
    dead_caps: Option<Arc<Mutex<mpsc::UnboundedSender<u64>>>>,
}

impl Drop for CapRef {
    fn drop(&mut self) {
        log::trace!("Dropping cap {}", self.index);
        if let Some(dead_caps) = self.dead_caps.take() {
            // Recover from a poisoned lock instead of unwrapping it.  A
            // panic inside `drop` while the thread is already unwinding
            // aborts the process, and a sender has no state that a
            // poisoned lock could have left inconsistent.
            let sender = dead_caps.lock().unwrap_or_else(|e| e.into_inner());
            // If the worker has already exited, there is nothing left to
            // clean up, so a failed send is deliberately ignored.
            let _ = sender.send(self.index);
        }
    }
}
/// Handle to a capability held in the relay worker's `CapTable`.
///
/// `index` is the capability's key in the table.  The root capability
/// uses index 0 and carries no `CapRef`.  Clones share the same `CapRef`
/// (through the `Arc`), so the worker is told to release the capability
/// only once the last clone has been dropped.
#[derive(Clone, Debug)]
pub struct Cap {
index: u64,
// Never read: kept only so that `CapRef`'s `Drop` runs once the last
// clone of this handle goes away.
#[allow(unused)]
cap_ref: Option<Arc<CapRef>>,
}
/// Maps capability indices to the live capnp clients they stand for.
///
/// Index 0 is reserved for the root capability (see `set_root`).  Every
/// other entry gets the next sequential index from `insert`, starting at
/// 1.  When the last handle to such an entry is dropped, its index is
/// posted to the channel whose receiving end `new` returns.
pub struct CapTable {
    next_index: u64,
    caps: HashMap<u64, capability::Client>,
    dead_caps: Arc<Mutex<mpsc::UnboundedSender<u64>>>,
}

impl CapTable {
    /// Creates an empty table, together with the receiver on which the
    /// indices of dropped capabilities arrive.
    pub fn new() -> (Self, mpsc::UnboundedReceiver<u64>) {
        let (tx, rx) = mpsc::unbounded_channel();
        let table = CapTable {
            next_index: 1,
            caps: HashMap::new(),
            dead_caps: Arc::new(Mutex::new(tx)),
        };
        (table, rx)
    }

    /// Stores `cap` under a fresh index and returns a handle to it.
    ///
    /// Dropping the last clone of the returned handle reports the index
    /// on the dead-capability channel.
    pub fn insert(&mut self, cap: capability::Client) -> Cap {
        let index = self.next_index;
        self.next_index = index + 1;
        self.caps.insert(index, cap);

        let tracker = CapRef {
            index,
            dead_caps: Some(Arc::clone(&self.dead_caps)),
        };
        Cap { index, cap_ref: Some(Arc::new(tracker)) }
    }

    /// Returns the client registered for `cap`, if any.
    fn lookup<C>(&self, cap: C) -> Option<&capability::Client>
        where C: std::borrow::Borrow<Cap>
    {
        let index = cap.borrow().index;
        self.caps.get(&index)
    }

    /// Registers `cap` as the root capability (index 0).
    ///
    /// # Panics
    ///
    /// Panics if a root capability was already registered.
    fn set_root(&mut self, cap: capability::Client) {
        let previous = self.caps.insert(0, cap);
        if previous.is_some() {
            panic!("root capability already set");
        }
    }

    /// Returns a handle to the root capability.
    ///
    /// The handle has no `CapRef`, so dropping it never reports the root
    /// as dead.
    pub fn root() -> Cap {
        Cap { index: 0, cap_ref: None }
    }
}
/// Reply delivered to a caller of `CapnProtoRelay::send_rpc`: the value
/// returned by the request's callback, or the error that occurred while
/// relaying the request.
type CapnProtoRelayResponse
= Result<Box<dyn Any + Send + Sync>>;
/// A request queued for the relay worker thread.
///
/// The callback runs on the worker thread with the relay, the worker's
/// capability table, and the RPC's response.  Its result is what the
/// worker sends on the reply channel.  The callback is not run if the
/// capability lookup or the RPC itself fails; that error is sent instead.
type CapnProtoRelayRequest = (
// Target capability, interface id, method id, and call parameters.
Cap, u64, u16, message::Reader<message::Builder<message::HeapAllocator>>,
// Channel on which the worker sends the reply.
mpsc::Sender<CapnProtoRelayResponse>,
// Callback that turns the RPC's response into the reply value.
Box<dyn FnOnce(Arc<Mutex<CapnProtoRelay>>,
&mut CapTable,
capability::Response<any_pointer::Owned>)
-> Result<Box<dyn Any + Send + Sync>>
+ Send + Sync>
);
/// Client-side handle to the relay worker thread.
///
/// Holds only the sending end of the worker's request queue.  RPCs are
/// submitted with `send_rpc`, and their results are collected with
/// `await_reply`.
pub struct CapnProtoRelay {
// Request queue consumed by the worker thread started in `new`.
sender: mpsc::Sender<CapnProtoRelayRequest>,
}
impl CapnProtoRelay {
pub fn new(descriptor: Descriptor) -> Result<Arc<Mutex<Self>>> {
log::trace!("CapnProtoRelay::new");
let (init_sender, init_receiver): (std::sync::mpsc::Sender<Result<()>>,
std::sync::mpsc::Receiver<Result<()>>)
= std::sync::mpsc::channel();
let (rpc_sender, rpc_receiver): (mpsc::Sender<CapnProtoRelayRequest>,
mpsc::Receiver<CapnProtoRelayRequest>)
= mpsc::channel(8);
let relay = Arc::new(Mutex::new(CapnProtoRelay {
sender: rpc_sender,
}));
let relay_ref = Arc::downgrade(&relay);
thread::spawn(move || {
CapnProtoRelay::worker(
relay_ref, init_sender, descriptor, rpc_receiver);
});
let () = init_receiver.recv()??;
Ok(relay)
}
pub fn root(&self) -> Cap {
CapTable::root()
}
fn worker(relay: std::sync::Weak<std::sync::Mutex<crate::capnp_relay::CapnProtoRelay>>,
init_sender: std::sync::mpsc::Sender<Result<()>>,
descriptor: ipc::Descriptor,
mut rpc_receiver: mpsc::Receiver<CapnProtoRelayRequest>) {
log::trace!("CapnProtoRelay::worker");
let rt = tokio::runtime::Builder::new_current_thread()
.enable_io()
.build();
let rt = match rt {
Ok(rt) => rt,
Err(err) => {
let _ = init_sender.send(Err(err.into()));
return;
}
};
let rpc_system = rt.block_on(async { descriptor.connect() });
let mut rpc_system = match rpc_system {
Ok(rpc_system) => rpc_system,
Err(err) => {
let _ = init_sender.send(Err(err));
return;
}
};
init_sender.send(Ok(())).expect("CapnProtoRelay::new waiting");
drop(init_sender);
let root: keystore::Client = rpc_system.bootstrap(Side::Server);
let (mut caps, dead_caps_receiver) = CapTable::new();
caps.set_root(root.client);
let rpc_task = tokio::task::LocalSet::new();
rpc_task.spawn_local(rpc_system);
rpc_task.block_on(&rt, async move {
let mut dead_caps_receiver = Some(dead_caps_receiver);
'message: loop {
if let Some(recv) = dead_caps_receiver.as_mut() {
'dead_caps: loop {
match recv.try_recv() {
Ok(i) => {
log::trace!("CapnProtoRelay::worker: \
Removing dead cap {}", i);
caps.caps.remove(&i).expect("valid capability");
}
Err(TryRecvError::Empty) => {
break 'dead_caps;
}
Err(err) => {
log::trace!("CapnProtoRelay::worker: \
dead caps receiver is... dead ({})",
err);
dead_caps_receiver = None;
break 'dead_caps;
}
}
}
}
log::trace!("CapnProtoRelay::worker: waiting for a message");
let (cap, interface_id, method_id, params, reply, f) =
match rpc_receiver.recv().await {
None => {
log::debug!("CapnProtoRelay::worker: \
no senders, exiting");
break 'message;
}
Some(r) => r,
};
log::trace!("CapnProtoRelay::worker: processing message");
let relay = if let Some(relay) = relay.upgrade() {
relay
} else {
log::debug!("CapnProtoRelay::worker: \
Aborting. Capability relay dropped.");
break 'message;
};
let r = Self::relay(
relay,
&mut caps, cap,
interface_id, method_id,
params, f).await;
if let Err(ref err) = r {
log::debug!("CapnProtoRelay::worker: \
error relaying rpc {:x}/{}: {}",
interface_id, method_id, err);
}
if let Err(err) = reply.send(r).await {
log::debug!("CapnProtoRelay::worker: \
error forwarding reply {:x}/{}: {}",
interface_id, method_id, err);
} else {
log::trace!("CapnProtoRelay::worker: \
forwarded reply: {:x}/{}",
interface_id, method_id);
}
}
log::trace!("CapnProtoRelay::worker: exited message processing loop");
});
log::trace!("CapnProtoRelay::worker: exiting function");
}
async fn relay<F>(
relay: std::sync::Arc<std::sync::Mutex<crate::capnp_relay::CapnProtoRelay>>,
caps: &mut CapTable,
cap: Cap,
interface_id: u64, method_id: u16,
params: message::Reader<message::Builder<message::HeapAllocator>>,
f: F)
-> Result<Box<dyn Any + Send + Sync>>
where F: FnOnce(std::sync::Arc<std::sync::Mutex<crate::capnp_relay::CapnProtoRelay>>,
&mut CapTable,
capability::Response<any_pointer::Owned>)
-> Result<Box<dyn Any + Send + Sync>>
{
log::trace!("CapnProtoRelay::relay(interface: {}, method: {})",
interface_id, method_id);
let cap = caps.lookup(&cap)
.ok_or(anyhow::anyhow!("Invalid capability ({})", cap.index))?;
let mut request: capability::Request<any_pointer::Owned,
any_pointer::Owned>
= cap.new_call(interface_id, method_id, None);
request.set(params.get_root()?)?;
log::trace!("CapnProtoRelay::relay: sending RPC");
let response: Result<capability::Response<any_pointer::Owned>, _>
= request.send().promise.await;
log::trace!("CapnProtoRelay::relay: got response");
f(relay, caps, response?)
}
pub fn send_rpc<F>(&self,
cap: Cap,
interface_id: u64,
method_id: u16,
message: message::Reader<
message::Builder<message::HeapAllocator>>,
f: F)
-> Result<mpsc::Receiver<CapnProtoRelayResponse>>
where F: 'static
+ FnOnce(std::sync::Arc<std::sync::Mutex<crate::capnp_relay::CapnProtoRelay>>,
&mut CapTable,
capability::Response<any_pointer::Owned>)
-> Result<Box<dyn Any + Send + Sync + 'static>> + Send + Sync
{
log::trace!("CapnProtoRelay::send_rpc");
let (sender, receiver) = mpsc::channel(1);
log::trace!("CapnProtoRelay::send_rpc: forwarding to relay thread");
if let Err(_) = self.sender.blocking_send((cap,
interface_id, method_id,
message, sender,
Box::new(f)))
{
panic!("capnproto worker threader died");
}
Ok(receiver)
}
pub fn await_reply(mut handle: mpsc::Receiver<CapnProtoRelayResponse>)
-> Result<Box<dyn Any + Send + Sync>>
{
log::trace!("CapnProtoRelay::await_reply: waiting for relay thread");
let response = handle.blocking_recv().expect("worker is alive")?;
log::trace!("CapnProtoRelay::await_reply: relay thread responded");
Ok(response)
}
}