use std::sync::{Arc, Mutex, RwLock, Weak};
use rand::Rng;
use rpc::InvalidRpcIdentifier;
use tor_rpcbase as rpc;
use tracing::warn;
use crate::{
RpcAuthentication,
connection::{Connection, ConnectionId},
globalid::{GlobalId, MacKey},
};
/// Shorthand for [`weak_table::WeakValueHashMap`] with the hasher fixed to the
/// standard library's [`RandomState`](std::hash::RandomState).
type WeakValueHashMap<K, V> = weak_table::WeakValueHashMap<K, V, std::hash::RandomState>;
/// Shared manager for RPC state.
///
/// It owns the key used to decode global object identifiers, the method
/// dispatch table that is shared with every connection, and a registry of
/// the connections it has created.
pub struct RpcMgr {
/// Key used when decoding a [`GlobalId`] from an object ID; a clone of it
/// is also passed to every [`Connection`] created by this manager.
global_id_mac_key: MacKey,
/// Method dispatch table, behind a read/write lock and shared (via `Arc`)
/// with every connection created by this manager.
dispatch_table: Arc<RwLock<rpc::DispatchTable>>,
/// Mutable state, guarded by a mutex.
inner: Mutex<Inner>,
}
/// The mutex-protected portion of an [`RpcMgr`].
pub(crate) struct Inner {
/// Connections created by the manager, indexed by their [`ConnectionId`].
///
/// The map stores only `Weak` references to the connections.
connections: WeakValueHashMap<ConnectionId, Weak<Connection>>,
}
/// An error reported by [`RpcMgr`].
#[derive(Clone, Debug, thiserror::Error)]
#[non_exhaustive]
pub enum RpcMgrError {
/// An RPC method had a name that failed validation.
///
/// Carries the underlying [`InvalidRpcIdentifier`] (exposed as the error
/// source) and the offending method name.
#[error("Method {1} had an invalid name")]
InvalidMethodName(#[source] InvalidRpcIdentifier, String),
}
/// An RPC object paired with a context, in `(context, object)` order.
///
/// When produced by `RpcMgr::lookup_by_global_id`, the context slot is filled
/// with the connection on which the object was found.
type ObjectWithContext = (Arc<dyn rpc::Context>, Arc<dyn rpc::Object>);
impl RpcMgr {
pub fn new() -> Result<Arc<Self>, RpcMgrError> {
let problems = rpc::check_method_names([]);
for (m, err) in &problems {
warn!("Internal issue: Invalid RPC method name {m:?}: {err}");
}
let fatal_problem = problems
.into_iter()
.find(|(_, err)| !matches!(err, InvalidRpcIdentifier::UnrecognizedNamespace));
if let Some((name, err)) = fatal_problem {
return Err(RpcMgrError::InvalidMethodName(err, name.to_owned()));
}
Ok(Arc::new(RpcMgr {
global_id_mac_key: MacKey::new(&mut rand::rng()),
dispatch_table: Arc::new(RwLock::new(rpc::DispatchTable::from_inventory())),
inner: Mutex::new(Inner {
connections: WeakValueHashMap::new(),
}),
}))
}
pub fn register_rpc_methods<I>(&self, entries: I)
where
I: IntoIterator<Item = rpc::dispatch::InvokerEnt>,
{
self.with_dispatch_table(|table| table.extend(entries));
}
pub fn with_dispatch_table<F, T>(&self, func: F) -> T
where
F: FnOnce(&mut rpc::DispatchTable) -> T,
{
let mut table = self.dispatch_table.write().expect("poisoned lock");
func(&mut table)
}
pub fn new_connection<F>(
self: &Arc<Self>,
require_auth: tor_rpc_connect::auth::RpcAuth,
create_session: F,
) -> Arc<Connection>
where
F: Fn(&RpcAuthentication) -> Arc<dyn rpc::Object> + Send + Sync + 'static,
{
let connection_id = ConnectionId::from(rand::rng().random::<[u8; 16]>());
let connection = Connection::new(
connection_id,
self.dispatch_table.clone(),
self.global_id_mac_key.clone(),
require_auth,
Box::new(create_session) as _,
);
let mut inner = self.inner.lock().expect("poisoned lock");
let old = inner.connections.insert(connection_id, connection.clone());
assert!(
old.is_none(),
"connection ID collision detected; this is phenomenally unlikely!",
);
connection
}
pub fn lookup_object(&self, id: &rpc::ObjectId) -> Result<ObjectWithContext, rpc::LookupError> {
GlobalId::try_decode(&self.global_id_mac_key, id)?
.and_then(|global_id| self.lookup_by_global_id(&global_id))
.ok_or_else(|| rpc::LookupError::NoObject(id.clone()))
}
pub(crate) fn lookup_by_global_id(&self, id: &GlobalId) -> Option<ObjectWithContext> {
let connection = {
let inner = self.inner.lock().expect("lock poisoned");
let connection = inner.connections.get(&id.connection)?;
drop(inner);
connection
};
let obj = connection.lookup_by_idx(id.local_id)?;
Some((connection, obj))
}
}