use crate::buffer::{Reader, Writer};
use crate::error::Error;
use crate::types::RefFlag;
use std::any::Any;
use std::collections::HashMap;
use std::rc::Rc;
use std::sync::Arc;
/// Writer-side reference tracker.
///
/// Remembers which shared pointers (`Rc`/`Arc`) have already been written,
/// keyed by their address, together with the numeric reference ID that was
/// assigned to each of them.
#[derive(Default)]
pub struct RefWriter {
/// Pointer address (cast to `usize`) of every pointer seen so far, mapped to its reference ID.
refs: HashMap<usize, u32>,
/// ID that will be handed out to the next newly seen or explicitly reserved reference.
next_ref_id: u32,
}
/// One-shot deferred action stored by `RefReader::add_callback`; it is invoked
/// with a shared reference to the reader when `RefReader::resolve_callbacks` runs.
type UpdateCallback = Box<dyn FnOnce(&RefReader)>;
impl RefWriter {
    /// Creates an empty tracker; the first reference ID handed out is `0`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Writes the reference header for `rc` and reports whether it was already seen.
    ///
    /// * Already seen: writes `RefFlag::Ref` followed by the previously assigned
    ///   ID as a var-uint32 and returns `true`.
    /// * First occurrence: assigns the next ID, records it, writes
    ///   `RefFlag::RefValue` and returns `false`.
    ///
    /// Identity is the address of the pointed-to allocation; any fat-pointer
    /// metadata is discarded by the cast to `*const ()`.
    #[inline]
    pub fn try_write_rc_ref<T: ?Sized>(&mut self, writer: &mut Writer, rc: &Rc<T>) -> bool {
        self.try_write_ptr_ref(writer, Rc::as_ptr(rc) as *const () as usize)
    }

    /// Writes the reference header for `arc` and reports whether it was already seen.
    ///
    /// Behaves exactly like [`RefWriter::try_write_rc_ref`], but for `Arc`:
    /// `true` means a back-reference (`RefFlag::Ref` + ID) was written, `false`
    /// means this is the first occurrence and only `RefFlag::RefValue` was written.
    #[inline]
    pub fn try_write_arc_ref<T: ?Sized>(&mut self, writer: &mut Writer, arc: &Arc<T>) -> bool {
        self.try_write_ptr_ref(writer, Arc::as_ptr(arc) as *const () as usize)
    }

    /// Shared implementation for `Rc` and `Arc`, keyed on the raw address.
    ///
    /// Uses the map's entry API so the address is hashed and looked up once,
    /// whether or not it was already registered.
    #[inline]
    fn try_write_ptr_ref(&mut self, writer: &mut Writer, ptr_addr: usize) -> bool {
        match self.refs.entry(ptr_addr) {
            std::collections::hash_map::Entry::Occupied(seen) => {
                writer.write_i8(RefFlag::Ref as i8);
                writer.write_var_uint32(*seen.get());
                true
            }
            std::collections::hash_map::Entry::Vacant(slot) => {
                // Same order as before: take the ID, advance the counter, then record it.
                let ref_id = self.next_ref_id;
                self.next_ref_id += 1;
                slot.insert(ref_id);
                writer.write_i8(RefFlag::RefValue as i8);
                false
            }
        }
    }

    /// Consumes and returns the next reference ID without associating it with
    /// any pointer address.
    #[inline(always)]
    pub fn reserve_ref_id(&mut self) -> u32 {
        let ref_id = self.next_ref_id;
        self.next_ref_id += 1;
        ref_id
    }

    /// Forgets every tracked address and restarts ID assignment at `0`.
    #[inline(always)]
    pub fn reset(&mut self) {
        self.refs.clear();
        self.next_ref_id = 0;
    }
}
/// Reader-side reference table.
///
/// Stores previously read shared pointers, indexed by reference ID, plus a
/// list of deferred callbacks to run once references can be resolved.
#[derive(Default)]
pub struct RefReader {
/// Type-erased stored values; the index in this vector is the reference ID.
refs: Vec<Box<dyn Any>>,
/// Pending one-shot callbacks, executed by `resolve_callbacks`.
callbacks: Vec<UpdateCallback>,
}
// Manual thread-safety assertions for `RefReader`. Its fields hold
// `Box<dyn Any>` values and boxed `FnOnce` callbacks without `Send`/`Sync`
// bounds, so the compiler would not implement these traits automatically.
// NOTE(review): `store_rc_ref` can place `Rc` values in the table; these impls
// are presumably sound only if a populated reader is never actually shared or
// moved across threads — TODO confirm and record the invariant as a
// `// SAFETY:` justification.
unsafe impl Send for RefReader {}
unsafe impl Sync for RefReader {}
impl RefReader {
pub fn new() -> Self {
Self::default()
}
#[inline(always)]
pub fn reserve_ref_id(&mut self) -> u32 {
let ref_id = self.refs.len() as u32;
self.refs.push(Box::new(()));
ref_id
}
#[inline(always)]
pub fn store_rc_ref_at<T: 'static + ?Sized>(&mut self, ref_id: u32, rc: Rc<T>) {
self.refs[ref_id as usize] = Box::new(rc);
}
#[inline(always)]
pub fn store_rc_ref<T: 'static + ?Sized>(&mut self, rc: Rc<T>) -> u32 {
let ref_id = self.refs.len() as u32;
self.refs.push(Box::new(rc));
ref_id
}
pub fn store_arc_ref_at<T: 'static + ?Sized>(&mut self, ref_id: u32, arc: Arc<T>) {
self.refs[ref_id as usize] = Box::new(arc);
}
#[inline(always)]
pub fn store_arc_ref<T: 'static + ?Sized>(&mut self, arc: Arc<T>) -> u32 {
let ref_id = self.refs.len() as u32;
self.refs.push(Box::new(arc));
ref_id
}
#[inline(always)]
pub fn get_rc_ref<T: 'static + ?Sized>(&self, ref_id: u32) -> Option<Rc<T>> {
let any_box = self.refs.get(ref_id as usize)?;
any_box.downcast_ref::<Rc<T>>().cloned()
}
#[inline(always)]
pub fn get_arc_ref<T: 'static + ?Sized>(&self, ref_id: u32) -> Option<Arc<T>> {
let any_box = self.refs.get(ref_id as usize)?;
any_box.downcast_ref::<Arc<T>>().cloned()
}
#[inline(always)]
pub fn add_callback(&mut self, callback: UpdateCallback) {
self.callbacks.push(callback);
}
#[inline(always)]
pub fn read_ref_flag(&self, reader: &mut Reader) -> Result<RefFlag, Error> {
let flag_value = reader.read_i8()?;
Ok(match flag_value {
-3 => RefFlag::Null,
-2 => RefFlag::Ref,
-1 => RefFlag::NotNullValue,
0 => RefFlag::RefValue,
_ => Err(Error::invalid_ref(format!(
"Invalid reference flag: {}",
flag_value
)))?,
})
}
#[inline(always)]
pub fn read_ref_id(&self, reader: &mut Reader) -> Result<u32, Error> {
reader.read_varuint32()
}
#[inline(always)]
pub fn resolve_callbacks(&mut self) {
let callbacks = std::mem::take(&mut self.callbacks);
for callback in callbacks {
callback(self);
}
}
#[inline(always)]
pub fn reset(&mut self) {
self.resolve_callbacks();
self.refs.clear();
self.callbacks.clear();
}
}