use core::any::{Any, TypeId};
use std::cell::RefCell;
use std::collections::hash_map::DefaultHasher;
use std::collections::HashMap;
use std::fmt::{self, Debug};
use std::hash::{Hash, Hasher};
use std::ops::Deref;
use std::rc::Rc;

use serde::de::*;
use serde::ser::*;
/// Identity of a `Shared` pointer. `Shared::new`/`Shared::from_rc` compute it
/// as the `DefaultHasher` hash of the pointee; (de)serialization carries it
/// through verbatim rather than recomputing it.
pub type Id = u64;
#[derive(Clone)]
pub struct Shared<T> {
ptr:SharedPtr<T>
}
/// Serialized form of a `Shared` pointer.
///
/// `Rc` carries the pointee in full together with its id. `Copy` is a
/// back-reference to a pointee with the same id that was written earlier in
/// the same (de)serialization pass. Every constructor of `Shared` in this
/// file produces `Rc`; `Copy` only appears while (de)serializing.
#[derive(Serialize, Deserialize)]
enum SharedPtr<T> {
    Rc(Id, Rc<T>),
    Copy(Id),
}

// Manual impl so that cloning does not require `T: Clone`: only the id is
// copied and the `Rc` count bumped.
impl<T> Clone for SharedPtr<T> {
    fn clone(&self) -> SharedPtr<T> {
        match *self {
            SharedPtr::Rc(id, ref rc) => SharedPtr::Rc(id, Rc::clone(rc)),
            SharedPtr::Copy(id) => SharedPtr::Copy(id),
        }
    }
}
impl<T:Hash+'static> Shared<T> {
pub fn id(&self) -> Id {
match self.ptr {
SharedPtr::Rc(id, _) => id.clone(),
SharedPtr::Copy(id) => id.clone(),
}
}
pub fn new(t:T) -> Shared<T> {
let mut hasher = DefaultHasher::new();
t.hash(&mut hasher);
let id = hasher.finish();
Shared{ptr:SharedPtr::Rc(id, Rc::new(t))}
}
pub fn from_rc(rc:Rc<T>) -> Shared<T> {
let mut hasher = DefaultHasher::new();
rc.hash(&mut hasher);
let id = hasher.finish();
Shared{ptr:SharedPtr::Rc(id, rc)}
}
}
/// Two `Shared` pointers are equal exactly when their ids are equal.
///
/// The pointees are never compared, so this is a single `u64` comparison no
/// matter how large or deeply nested the values are. The trade-off is that
/// two distinct values whose hashes collide compare equal.
impl<T: PartialEq + 'static> PartialEq for Shared<T> {
    fn eq(&self, other: &Self) -> bool {
        match (&self.ptr, &other.ptr) {
            (&SharedPtr::Rc(id1, _), &SharedPtr::Rc(id2, _)) => id1 == id2,
            // Every constructor yields `SharedPtr::Rc`; `Copy` only exists
            // transiently during (de)serialization.
            _ => unreachable!(),
        }
    }
}

/// Equality on ids is an equivalence relation, and it agrees with the `Hash`
/// impl, which also hashes only the id.
impl<T: PartialEq + 'static> Eq for Shared<T> {}
impl<T:'static+Hash> Hash for Shared<T> {
fn hash<H>(&self, state: &mut H) where H: Hasher {
self.id().hash(state)
}
}
impl<T: Debug> fmt::Debug for Shared<T> {
    /// Formats exactly as the pointee's own `Debug` would; the id is not shown.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let inner: &T = match self.ptr {
            SharedPtr::Rc(_, ref rc) => &**rc,
            SharedPtr::Copy(_) => unreachable!(),
        };
        fmt::Debug::fmt(inner, f)
    }
}
impl<T> Deref for Shared<T> {
    type Target = T;

    /// Borrows the pointee held by the underlying `Rc`.
    fn deref(&self) -> &T {
        if let SharedPtr::Rc(_, ref rc) = self.ptr {
            &**rc
        } else {
            unreachable!()
        }
    }
}
impl<T: Serialize + Hash + 'static> Serialize for Shared<T> {
    /// Writes the pointee in full (as `SharedPtr::Rc`) the first time its id
    /// is seen, recording it in the thread-local table. Every later occurrence
    /// is written as a `SharedPtr::Copy(id)` back-reference and bumps the copy
    /// count.
    ///
    /// The table is only reset by `clear()`. Call it between independent
    /// serializations; otherwise a second pass writes only back-references to
    /// values emitted by the first.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let (id, rc) = match self.ptr {
            SharedPtr::Rc(id, ref rc) => (id, rc),
            SharedPtr::Copy(_) => unreachable!(),
        };
        let already_written: Option<Rc<T>> = table_get(&id);
        if already_written.is_some() {
            table_inc_copy_count();
            SharedPtr::<T>::Copy(id).serialize(serializer)
        } else {
            table_put(id, Rc::clone(rc));
            self.ptr.serialize(serializer)
        }
    }
}
impl<'de,T:Deserialize<'de>+Hash+'static> Deserialize<'de> for Shared<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
match SharedPtr::<T>::deserialize(deserializer) {
Ok(SharedPtr::Copy(id)) => {
match table_get(&id) {
None => unreachable!(),
Some(rc) => {
table_inc_copy_count();
Ok(Shared{ptr:SharedPtr::Rc(id, rc)})
}
}
}
Ok(SharedPtr::Rc(id, rc)) => {
table_put(id.clone(), rc.clone());
Ok(Shared{ptr:SharedPtr::Rc(id, rc)})
},
Err(err) => Err(err),
}
}
}
struct Table {
copy_count:usize,
table:HashMap<Id,Box<Rc<Any>>>
}
thread_local!(static TABLE:
RefCell<Table> =
RefCell::new(Table{
copy_count:0,
table:HashMap::new()
}));
fn table_put<T:Any+'static>(id:Id, x:Rc<T>) {
TABLE.with(|t| {
drop(t.borrow_mut().table.insert(id, Box::new(x)))
})
}
fn table_inc_copy_count() {
TABLE.with(|t| {
t.borrow_mut().copy_count += 1;
})
}
fn table_get<T:'static>(id:&Id) -> Option<Rc<T>> {
TABLE.with(|t| {
match t.borrow().table.get(id) {
Some(ref brc) => {
let x : &Rc<Any> = &**brc;
let y : Result<Rc<T>, Rc<Any>> = (x.clone()).downcast::<T>();
match y {
Err(_) => {
panic!("downcast failed for id {:?}", id)
}
Ok(ref rc) => Some((*rc).clone())
}
}
None => None,
}
})
}
pub fn clear() -> usize {
let copy_count =
TABLE.with(|t| {
let c = t.borrow().copy_count;
t.borrow_mut().table.clear();
t.borrow_mut().copy_count = 0;
c
});
copy_count
}
// Cons-list example exercising `Shared` and its serde round trip. It is used
// only by the tests below, so it is compiled only for test builds.
#[cfg(test)]
mod list_example {
    use super::Shared;

    #[derive(Hash, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
    enum List {
        Nil,
        Cons(usize, Shared<List>),
    }

    fn nil() -> List {
        List::Nil
    }

    fn cons(h: usize, t: List) -> List {
        List::Cons(h, Shared::new(t))
    }

    /// Sums the elements of `l`.
    fn sum(l: &List) -> usize {
        match *l {
            List::Nil => 0,
            List::Cons(h, ref t) => h + sum(t),
        }
    }

    /// Conses each element of `v` in order onto an empty list, so the result
    /// holds the elements in reverse order.
    fn from_vec(v: &[usize]) -> List {
        v.iter().fold(nil(), |l, &x| cons(x, l))
    }

    #[test]
    fn test_elim_forms() {
        let x = from_vec(&vec![1, 2, 3]);
        assert_eq!(1 + 2 + 3, sum(&x))
    }

    #[test]
    fn test_intro_forms() {
        let x = nil();
        let x = cons(1, x);
        let y = cons(2, x.clone());
        let z = cons(3, x.clone());
        drop((x, y, z))
    }

    #[test]
    fn test_serde() {
        use serde_json;
        // `y` and `z` both point at `x`, and `x` points at `Nil`. Serializing
        // (x, y, z) therefore writes `Nil` and `x` once each, then emits two
        // back-references: `Nil` inside `y`'s copy of `x`, and `x` under `z`.
        let (value, expected_copy_count) = {
            let x = nil();
            let x = cons(1, x);
            let y = cons(2, x.clone());
            let z = cons(3, x.clone());
            ((x, y, z), 2)
        };
        let serialized = serde_json::to_string(&value).unwrap();
        let copy_count1 = super::clear();
        println!("serialized = {}", serialized);
        println!("copy_count1 = {}", copy_count1);
        assert_eq!(copy_count1, expected_copy_count);
        let deserialized: (List, List, List) = serde_json::from_str(&serialized).unwrap();
        let copy_count2 = super::clear();
        println!("copy_count2 = {}", copy_count2);
        assert_eq!(copy_count2, expected_copy_count);
        assert_eq!(deserialized, value);
    }
}