1#![deny(missing_docs)]
23#![cfg_attr(feature = "nightly", feature(test))]
24
25use std::collections::hash_map::{HashMap, Entry};
26use std::fmt;
27use std::sync::Mutex;
28use std::sync::atomic::{AtomicUsize, ATOMIC_USIZE_INIT};
29use std::sync::atomic::Ordering::Relaxed;
30
// Process-wide counter used to hand out thread ids. It starts at zero and is
// bumped once for every thread that initializes `THREAD_ID`.
static COUNTER: AtomicUsize = ATOMIC_USIZE_INIT;

thread_local! {
    // The calling thread's id, assigned lazily the first time the thread
    // reads it. The `+ 1` means the first id handed out is 1, never 0.
    static THREAD_ID: usize = COUNTER.fetch_add(1, Relaxed) + 1;
}
40
/// The type of the factory closure a `Pool` uses to create its values.
///
/// The closure takes no arguments and returns a fresh `T`. It is boxed and
/// must be `Send + Sync + 'static`.
pub type CreateFn<T> = Box<Fn() -> T + Send + Sync + 'static>;
43
/// A pool that hands each calling thread a reference to its own value of
/// type `T`.
///
/// One value is created eagerly in `Pool::new` and given to the first thread
/// that claims ownership; every other thread gets a value created on demand
/// and kept in a map keyed by thread id.
pub struct Pool<T: Send> {
    // Factory used to build the eager owner value and any per-thread values.
    create: CreateFn<T>,
    // Id of the thread that owns `owner_val`; 0 means no thread has claimed
    // ownership yet.
    owner: AtomicUsize,
    // The value returned to the owning thread on the fast path of `get`.
    owner_val: T,
    // Values for all non-owner threads, keyed by thread id. Each value is
    // boxed and the map is guarded by a mutex.
    global: Mutex<HashMap<usize, Box<T>>>,
}
51
// SAFETY (as far as this file shows): `Pool::get` returns `owner_val` only to
// the thread whose id is stored in `owner`, and every other thread receives
// its own boxed value from the mutex-protected `global` map, keyed by its
// thread id. `get` therefore never hands the same value to two threads, which
// is why `T: Send` (rather than `T: Sync`) is the only bound required here.
// NOTE(review): this argument only covers accesses that go through `get`; any
// code that reads `owner_val` directly from an arbitrary thread is outside it.
unsafe impl<T: Send> Sync for Pool<T> {}
53
54impl<T: fmt::Debug + Send + 'static> fmt::Debug for Pool<T> {
55 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
56 write!(f, "Pool({:?})", self.owner_val)
57 }
58}
59
60impl<T: Send> Pool<T> {
61 pub fn new(create: CreateFn<T>) -> Pool<T> {
63 let owner_val = (create)();
64 Pool {
65 create: create,
66 owner: AtomicUsize::new(0),
67 owner_val: owner_val,
68 global: Mutex::new(HashMap::new()),
69 }
70 }
71
72 #[inline(always)]
82 pub fn get(&self) -> &T {
83 let id = THREAD_ID.with(|id| *id);
84 let owner = self.owner.load(Relaxed);
85 if owner == id {
88 return &self.owner_val;
89 }
90 self.get_slow(owner, id)
91 }
92
93 #[cold]
94 fn get_slow(&self, owner: usize, thread_id: usize) -> &T {
95 if owner == 0 {
96 if self.owner.compare_and_swap(0, thread_id, Relaxed) == 0 {
97 return &self.owner_val;
98 }
99 }
100 let mut global = self.global.lock().unwrap();
101 match global.entry(thread_id) {
102 Entry::Occupied(ref e) => {
103 let p: *const T = &**e.get();
104 unsafe { &*p }
105 }
106 Entry::Vacant(e) => {
107 let t = Box::new((self.create)());
108 let p: *const T = &*t;
109 e.insert(t);
110 unsafe { &*p }
111 }
112 }
113 }
114}
115
// Benchmarks live in a separate module that is only compiled for test builds
// with the `nightly` feature enabled.
#[cfg(test)]
#[cfg(feature = "nightly")]
mod bench;
119
#[cfg(test)]
mod tests {
    use std::cell::RefCell;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering::SeqCst;
    use std::thread;

    use super::{CreateFn, Pool};

    // Test payload that records how many values the factory had already
    // produced when this one was created.
    #[derive(Debug, Eq, PartialEq)]
    struct Dummy(usize);

    // Builds a factory whose n-th call (counting from zero) yields `Dummy(n)`.
    fn dummy() -> CreateFn<Dummy> {
        let created = AtomicUsize::new(0);
        Box::new(move || Dummy(created.fetch_add(1, SeqCst)))
    }

    // A fresh pool hands the eagerly created value to the first caller.
    #[test]
    fn empty() {
        let pool = Pool::new(dummy());
        assert_eq!(&Dummy(0), pool.get());
    }

    // Repeated calls from the same thread keep returning the same value.
    #[test]
    fn reuse() {
        let pool = Pool::new(dummy());
        for _ in 0..3 {
            assert_eq!(&Dummy(0), pool.get());
        }
    }

    // A second thread gets a newly created value instead of the owner's.
    #[test]
    fn no_reuse() {
        let shared = Arc::new(Pool::new(dummy()));
        let first = shared.get();
        assert_eq!(&Dummy(0), first);

        let remote = shared.clone();
        let worker = thread::spawn(move || {
            assert_eq!(&Dummy(1), remote.get());
        });
        worker.join().unwrap();
    }

    // The pool is `Sync` even when the pooled type itself is not.
    #[test]
    fn is_sync() {
        fn assert_sync<P: Sync>() {}
        assert_sync::<Pool<String>>();
        assert_sync::<Pool<RefCell<String>>>();
    }
}