use std::sync::{Arc, Weak};
use rand::RngExt;
use slotmap_careful::{Key as _, KeyData, SlotMap};
use tor_rpcbase as rpc;
pub(crate) mod methods;
// Key type used to index the `SlotMap` inside `ObjMap`.
//
// It is declared through `slotmap_careful`'s key-type macro; the byte and
// string encodings for it live in `impl GenIdx` in this file.
slotmap_careful::new_key_type! {
    pub(crate) struct GenIdx;
}
/// A reference to an RPC object as stored in an `ObjMap`: either owning or
/// non-owning.
///
/// NOTE(review): the `derive_more::From` derive presumably provides `From`
/// conversions from each variant's payload type; nothing in this file relies
/// on them directly.
#[derive(Clone, derive_more::From)]
enum ObjectRef {
    /// An owning reference; the object stays alive while this value exists.
    Strong(Arc<dyn rpc::Object>),
    /// A non-owning reference; it stops resolving once every `Arc` to the
    /// object has been dropped.
    Weak(Weak<dyn rpc::Object>),
}
impl ObjectRef {
    /// Try to obtain a strong reference to the underlying object.
    ///
    /// A `Strong` entry always yields `Some`.  A `Weak` entry yields `None`
    /// if it can no longer be upgraded.
    fn get(&self) -> Option<Arc<dyn rpc::Object>> {
        match self {
            Self::Weak(weak) => weak.upgrade(),
            Self::Strong(strong) => Some(Arc::clone(strong)),
        }
    }
}
/// A table of RPC objects, each addressed by a `GenIdx` key.
///
/// Entries are stored as `ObjectRef`s, so any given object may be held either
/// strongly or weakly.
#[derive(Default)]
pub(crate) struct ObjMap {
    /// Backing slot map from key to stored reference.
    arena: SlotMap<GenIdx, ObjectRef>,
}
impl GenIdx {
    /// The length of a byte-encoded (but not base64-encoded) `GenIdx`.
    pub(crate) const BYTE_LEN: usize = 16;

    /// Encode this index as an [`rpc::ObjectId`], masking it with a random
    /// value drawn from `rand::rng()`.
    ///
    /// Because a fresh random mask is drawn on every call, repeated encodings
    /// of the same index will normally differ; each of them decodes back to
    /// the same index with [`GenIdx::try_decode`].
    pub(crate) fn encode(self) -> rpc::ObjectId {
        self.encode_with_rng(&mut rand::rng())
    }

    /// As [`GenIdx::encode`], but draw the random mask from `rng`.
    fn encode_with_rng<R: rand::Rng>(self, rng: &mut R) -> rpc::ObjectId {
        use base64ct::Encoding;
        let bytes = self.to_bytes(rng);
        rpc::ObjectId::from(base64ct::Base64UrlUnpadded::encode_string(&bytes[..]))
    }

    /// Encode this index as `BYTE_LEN` bytes, using a random mask from `rng`.
    ///
    /// Layout (both words big-endian):
    ///
    /// * bytes `0..8`: a random 64-bit value `x`;
    /// * bytes `8..16`: the key's FFI representation plus `x`, wrapping on
    ///   overflow.
    ///
    /// [`GenIdx::from_bytes`] is the inverse.
    pub(crate) fn to_bytes<R: rand::Rng>(self, rng: &mut R) -> [u8; Self::BYTE_LEN] {
        let ffi_idx = self.data().as_ffi();
        let x = rng.random::<u64>();
        let masked = ffi_idx.wrapping_add(x);
        // Pack both words into one u128 so that the big-endian byte view has
        // `x` first and the masked index second.  The result is a fixed-size
        // array: no allocation, and the length is checked at compile time
        // against `BYTE_LEN`.
        let packed = (u128::from(x) << 64) | u128::from(masked);
        packed.to_be_bytes()
    }

    /// Decode an [`rpc::ObjectId`] produced by [`GenIdx::encode`].
    ///
    /// # Errors
    ///
    /// Returns `rpc::LookupError::NoObject` (carrying a copy of `id`) if `id`
    /// is not valid unpadded URL-safe base64, or if it does not decode to
    /// exactly `BYTE_LEN` bytes.
    ///
    /// This only checks the encoding; it does not check that the resulting
    /// index refers to anything.
    pub(crate) fn try_decode(id: &rpc::ObjectId) -> Result<Self, rpc::LookupError> {
        use base64ct::Encoding;
        // Both failure modes are reported identically to the caller.
        let no_object = || rpc::LookupError::NoObject(id.clone());
        let bytes =
            base64ct::Base64UrlUnpadded::decode_vec(id.as_ref()).map_err(|_| no_object())?;
        Self::from_bytes(&bytes).ok_or_else(no_object)
    }

    /// Decode an index from the layout written by [`GenIdx::to_bytes`].
    ///
    /// Returns `None` unless `bytes` is exactly `BYTE_LEN` bytes long.
    pub(crate) fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let bytes = <[u8; Self::BYTE_LEN]>::try_from(bytes).ok()?;
        let packed = u128::from_be_bytes(bytes);
        // The truncating casts are intentional: they select the high and low
        // 64-bit halves of the packed value.
        let x = (packed >> 64) as u64;
        let masked = packed as u64;
        // Undo the masking applied by `to_bytes`.
        let ffi_idx = masked.wrapping_sub(x);
        Some(GenIdx::from(KeyData::from_ffi(ffi_idx)))
    }
}
impl ObjMap {
pub(crate) fn new() -> Self {
Self::default()
}
pub(crate) fn insert_strong(&mut self, value: Arc<dyn rpc::Object>) -> GenIdx {
self.arena.insert(ObjectRef::Strong(value))
}
pub(crate) fn insert_weak(&mut self, value: &Arc<dyn rpc::Object>) -> GenIdx {
self.arena.insert(ObjectRef::Weak(Arc::downgrade(value)))
}
pub(crate) fn lookup(&self, idx: GenIdx) -> Result<Arc<dyn rpc::Object>, LookupError> {
self.arena
.get(idx)
.ok_or(LookupError::NoObject)?
.get()
.ok_or(LookupError::Expired)
}
pub(crate) fn remove(&mut self, idx: GenIdx) -> bool {
self.arena.remove(idx).is_some()
}
#[cfg(test)]
fn assert_okay(&self) {}
}
/// Reasons why `ObjMap::lookup` can fail to produce an object.
///
/// Unlike `rpc::LookupError`, these variants do not carry the object ID; see
/// `LookupError::to_rpc_lookup_error` for the conversion.
#[derive(Clone, Debug, thiserror::Error)]
pub(crate) enum LookupError {
    /// No entry is stored under the requested index.
    #[error("Object not found")]
    NoObject,
    /// An entry exists, but its reference could no longer be upgraded to a
    /// live object.
    #[error("Object expired")]
    Expired,
}
impl LookupError {
    /// Convert this error into the `rpc::LookupError` variant of the same
    /// name, attaching `id` as the ID that failed to resolve.
    pub(crate) fn to_rpc_lookup_error(&self, id: rpc::ObjectId) -> rpc::LookupError {
        match *self {
            Self::Expired => rpc::LookupError::Expired(id),
            Self::NoObject => rpc::LookupError::NoObject(id),
        }
    }
}
// Unit tests for `ObjMap` and for the `GenIdx` <-> `rpc::ObjectId` encoding.
#[cfg(test)]
mod test {
    // Lints relaxed for test code.
    // NOTE(review): this list looks like shared boilerplate; confirm how it is
    // maintained before editing it by hand.
    #![allow(clippy::bool_assert_comparison)]
    #![allow(clippy::clone_on_copy)]
    #![allow(clippy::dbg_macro)]
    #![allow(clippy::mixed_attributes_style)]
    #![allow(clippy::print_stderr)]
    #![allow(clippy::print_stdout)]
    #![allow(clippy::single_char_pattern)]
    #![allow(clippy::unwrap_used)]
    #![allow(clippy::unchecked_time_subtraction)]
    #![allow(clippy::useless_vec)]
    #![allow(clippy::needless_pass_by_value)]
    #![allow(clippy::string_slice)]
    use super::*;
    use derive_deftly::Deftly;
    use tor_rpcbase::templates::*;

    /// Minimal object type used as a stand-in `rpc::Object` in these tests.
    #[derive(Clone, Debug, Deftly)]
    #[derive_deftly(Object)]
    struct ExampleObject(#[allow(unused)] String);

    /// Insert, look up, and remove strong entries.
    #[test]
    fn map_basics() {
        let obj1 = Arc::new(ExampleObject("abcdef".to_string()));
        let mut map = ObjMap::new();
        map.assert_okay();
        // Inserting the same object twice must yield two distinct keys.
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_strong(obj1.clone());
        assert_ne!(id1, id2);
        // Shadow as a trait object so that `Arc::ptr_eq` compares like with like.
        let obj1: Arc<dyn rpc::Object> = obj1;
        let obj_out1 = map.lookup(id1).unwrap();
        let obj_out2 = map.lookup(id2).unwrap();
        assert!(Arc::ptr_eq(&obj1, &obj_out1));
        assert!(Arc::ptr_eq(&obj1, &obj_out2));
        map.assert_okay();
        // Removing one key must not disturb the other.
        map.remove(id1);
        assert!(map.lookup(id1).is_err());
        let obj_out2b = map.lookup(id2).unwrap();
        assert!(Arc::ptr_eq(&obj_out2, &obj_out2b));
        map.assert_okay();
    }

    /// A strong entry keeps its object alive; a weak entry does not.
    #[test]
    fn strong_and_weak() {
        let obj1: Arc<dyn rpc::Object> = Arc::new(ExampleObject("hello".to_string()));
        let obj2: Arc<dyn rpc::Object> = Arc::new(ExampleObject("world".to_string()));
        let mut map = ObjMap::new();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_weak(&obj2);
        {
            // The looked-up `Arc`s are scoped so that they are dropped before
            // the originals are.
            let out1 = map.lookup(id1).unwrap();
            let out2 = map.lookup(id2).unwrap();
            assert!(Arc::ptr_eq(&obj1, &out1));
            assert!(Arc::ptr_eq(&obj2, &out2));
        }
        map.assert_okay();
        // Drop every reference held outside the map.
        drop(obj1);
        drop(obj2);
        {
            let out1 = map.lookup(id1);
            let out2 = map.lookup(id2);
            // The strongly held object survives; the weakly held one is gone.
            assert!(out1.is_ok());
            assert!(out2.is_err());
        }
        map.assert_okay();
    }

    /// Removal works for both strong and weak entries.
    #[test]
    fn remove() {
        let obj1: Arc<dyn rpc::Object> = Arc::new(ExampleObject("hello".to_string()));
        let obj2: Arc<dyn rpc::Object> = Arc::new(ExampleObject("world".to_string()));
        let mut map = ObjMap::new();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_weak(&obj2);
        map.assert_okay();
        map.remove(id1);
        map.assert_okay();
        assert!(map.lookup(id1).is_err());
        assert!(map.lookup(id2).is_ok());
        map.remove(id2);
        map.assert_okay();
        assert!(map.lookup(id1).is_err());
        assert!(map.lookup(id2).is_err());
    }

    /// Re-inserting an already-present object always produces a fresh key.
    #[test]
    fn duplicates() {
        let obj1: Arc<dyn rpc::Object> = Arc::new(ExampleObject("hello".to_string()));
        let obj2: Arc<dyn rpc::Object> = Arc::new(ExampleObject("world".to_string()));
        let mut map = ObjMap::new();
        let id1 = map.insert_strong(obj1.clone());
        let id2 = map.insert_weak(&obj2);
        {
            // Additional weak insertions.
            assert_ne!(id2, map.insert_weak(&obj1));
            assert_ne!(id2, map.insert_weak(&obj2));
        }
        {
            // Additional strong insertions.
            assert_ne!(id1, map.insert_strong(obj1.clone()));
            assert_ne!(id2, map.insert_strong(obj2.clone()));
        }
    }

    /// Encoded object IDs are randomized, yet decode back to the same index.
    #[test]
    fn objid_encoding() {
        /// Build an index from the two halves `a` and `b`, encode it twice,
        /// and check that both encodings differ but decode to the original.
        fn test_roundtrip(a: u32, b: u32, rng: &mut tor_basic_utils::test_rng::TestingRng) {
            let a: u64 = a.into();
            let b: u64 = b.into();
            // Bit 32 is always set here.
            // NOTE(review): presumably this keeps `KeyData::from_ffi` from
            // altering the value, so that the equality checks below hold;
            // confirm against the slotmap documentation.
            let data = KeyData::from_ffi((a << 33) | (1_u64 << 32) | b);
            let idx = GenIdx::from(data);
            let s1 = idx.encode_with_rng(rng);
            let s2 = idx.encode_with_rng(rng);
            // Each encoding draws a fresh random mask.
            assert_ne!(s1, s2);
            assert_eq!(idx, GenIdx::try_decode(&s1).unwrap());
            assert_eq!(idx, GenIdx::try_decode(&s2).unwrap());
        }
        let mut rng = tor_basic_utils::test_rng::testing_rng();
        // Hand-picked small and extreme values first, then random ones.
        test_roundtrip(0, 1, &mut rng);
        test_roundtrip(0, 2, &mut rng);
        test_roundtrip(1, 1, &mut rng);
        test_roundtrip(0xffffffff, 0xffffffff, &mut rng);
        for _ in 0..256 {
            test_roundtrip(rng.random(), rng.random(), &mut rng);
        }
    }
}