1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
//! Thread-specific objects.
//!
//! This is an abstraction over the usual thread-local storage, adding a special type that holds a
//! separate value for every thread.
//!
//! This means that you can create TLS variables dynamically, as opposed to the classical fixed
//! `static` variable: you can store such an object in a struct, and have many of them in the same
//! thread.
//!
//! It works by keeping a thread-local B-tree map that associates each object's unique ID with that
//! thread's boxed copy of the object's value.
//!
//! Performance-wise this is suboptimal, but unlike most other approaches it is portable.

#![feature(const_fn)]

use std::any::Any;
use std::cell::RefCell;
use std::collections::BTreeMap;
use std::mem;
use std::sync::atomic;

/// Global source of object IDs.
///
/// Creating an object takes the counter's current value as that object's ID and advances the
/// counter, so IDs are handed out in increasing order starting from zero.
static ID_COUNTER: atomic::AtomicUsize = atomic::AtomicUsize::new(0);

thread_local! {
    /// This thread's map of thread-object values.
    ///
    /// Each entry maps an object ID to this thread's boxed copy of that object's value. Values are
    /// stored type-erased as `dyn Any` so objects of different types can share one map; they are
    /// downcast back to their concrete type when accessed.
    static THREAD_OBJECTS: RefCell<BTreeMap<usize, Box<dyn Any>>> = RefCell::new(BTreeMap::new());
}

/// A value with a separate copy for every thread.
///
/// The object stores an initial value. The first time a thread accesses the object, that thread
/// receives its own clone of the initial value. Reads and writes made by one thread are invisible
/// to all other threads, hence "many-faced": each thread sees its own face of the object.
///
/// Copies and clones of an `Object` carry the same ID, so they refer to the same per-thread value.
#[derive(Clone, Copy)]
pub struct Object<T> {
    /// Value cloned into a thread's storage on that thread's first access.
    initial: T,
    /// Identifier assigned from the global counter at creation. It is shared by copies/clones and
    /// is the key into each thread's value map.
    id: usize,
}

impl<T> Object<T> {
    /// Create a new thread object with some initial value.
    ///
    /// The specified value `initial` will be the value assigned when new threads read the object.
    ///
    /// # Panics
    ///
    /// Panics if every possible object ID has already been handed out. Panicking is preferred over
    /// silently wrapping around and reusing an ID that a live object may still hold.
    pub fn new(initial: T) -> Object<T> {
        // Claim the counter's current value and advance it. A compare-and-swap loop (rather than
        // `fetch_add`) lets us refuse to wrap: a wrapped counter would reissue an old ID. Two live
        // objects sharing an ID would read each other's per-thread values, or fail the downcast
        // on access if their types differ.
        //
        // Relaxed ordering suffices: all operations on a single atomic are totally ordered, which
        // is what guarantees uniqueness, and no other memory is published through the counter.
        let mut current = ID_COUNTER.load(atomic::Ordering::Relaxed);
        let id = loop {
            let next = current.checked_add(1).expect("thread object IDs exhausted");
            match ID_COUNTER.compare_exchange_weak(current,
                                                   next,
                                                   atomic::Ordering::Relaxed,
                                                   atomic::Ordering::Relaxed) {
                Ok(claimed) => break claimed,
                // Another thread advanced the counter (or a spurious failure occurred); retry
                // with the value we observed.
                Err(actual) => current = actual,
            }
        };

        Object { initial, id }
    }
}

impl<T: Clone + Any> Object<T> {
    /// Read and/or modify the value associated with this thread.
    ///
    /// This looks up the current thread's value for the object. On the thread's first access, it
    /// initializes that value with a clone of the initial value. A mutable reference to the value
    /// is passed to the closure `f`, and the closure's return value is returned.
    ///
    /// A closure is used, rather than returning a reference, so the reference cannot escape. It is
    /// only valid on the current thread and only for the duration of this call.
    ///
    /// The thread's map is not borrowed while `f` (or `T::clone`) runs, so `f` may freely access
    /// *other* thread objects. If `f` panics, the value is kept, including any changes `f` made
    /// before panicking.
    ///
    /// # Panics
    ///
    /// Panics if called from inside another `with` call on the same object (or a copy/clone of it)
    /// on the same thread, since that would create two mutable references to one value.
    pub fn with<F, R>(&self, f: F) -> R where F: FnOnce(&mut T) -> R {
        /// Marker stored in the map while the real value is checked out by an active `with` call.
        struct CheckedOut;

        /// Returns a checked-out value to this thread's map when dropped, including while
        /// unwinding from a panic. If no value was ever produced (the initial clone panicked), the
        /// marker is removed instead, so the next access initializes afresh.
        struct Restore {
            id: usize,
            value: Option<Box<dyn Any>>,
        }

        impl Drop for Restore {
            fn drop(&mut self) {
                let id = self.id;
                let value = self.value.take();
                THREAD_OBJECTS.with(|map| {
                    let mut map = map.borrow_mut();
                    match value {
                        Some(value) => {
                            map.insert(id, value);
                        }
                        None => {
                            map.remove(&id);
                        }
                    }
                });
            }
        }

        // Swap the marker in and take the current value out. The map is borrowed only for this
        // single operation.
        let previous = THREAD_OBJECTS.with(|map| {
            map.borrow_mut().insert(self.id, Box::new(CheckedOut))
        });

        if previous.as_ref().map_or(false, |value| value.is::<CheckedOut>()) {
            // The marker was already present, so an enclosing `with` call holds this value. The
            // marker stays in place for that caller; we refuse to alias the value.
            panic!("thread object accessed reentrantly from within its own `with` call");
        }

        let mut restore = Restore { id: self.id, value: previous };
        if restore.value.is_none() {
            // First access on this thread. If `clone` panics, `restore` removes the marker.
            let fresh: Box<dyn Any> = Box::new(self.initial.clone());
            restore.value = Some(fresh);
        }

        let result = {
            let value = restore.value
                               .as_mut()
                               .expect("value was just checked out or initialized");
            // IDs are unique per `Object::new` and copies share `T`, so the stored value always
            // has type `T`.
            f(value.downcast_mut().expect("thread object ID maps to a value of another type"))
        };

        // Put the (possibly modified) value back into this thread's map.
        mem::drop(restore);
        result
    }

    /// Replace the inner value.
    ///
    /// This stores `new` as the current thread's value and returns the value it replaced.
    pub fn replace(&self, new: T) -> T {
        self.with(|slot| mem::replace(slot, new))
    }

    /// Copy the inner value for the current thread.
    pub fn get(&self) -> T where T: Copy {
        self.with(|value| *value)
    }
}

impl<T: Default> Default for Object<T> {
    fn default() -> Object<T> {
        Object::new(T::default())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::{Arc, Mutex};
    use std::thread;

    #[test]
    fn initial_value() {
        let obj = Object::new(23);
        obj.with(|value| assert_eq!(*value, 23));

        let read = obj.with(|value| *value);
        assert_eq!(read, 23);
    }

    #[test]
    fn string() {
        let obj = Object::new(String::new());

        obj.with(|s| {
            assert!(s.is_empty());
            s.push('b');
        });
        obj.with(|s| {
            assert_eq!(s, "b");
            s.push('a');
        });
        obj.with(|s| assert_eq!(s, "ba"));
    }

    #[test]
    fn multiple_objects() {
        let first = Object::new(0);
        let second = Object::new(0);

        second.with(|value| *value = 1);

        // Writing to one object must not affect another.
        assert_eq!(first.get(), 0);
        assert_eq!(second.get(), 1);
    }

    #[test]
    fn multi_thread() {
        let obj = Object::new(0);

        // A write on a spawned thread is invisible here.
        thread::spawn(move || obj.with(|value| *value = 1)).join().unwrap();
        assert_eq!(obj.get(), 0);

        // A fresh thread starts from the initial value, and its write stays on that thread.
        thread::spawn(move || {
            assert_eq!(obj.get(), 0);
            obj.with(|value| *value = 2);
        }).join().unwrap();
        assert_eq!(obj.get(), 0);
    }

    #[test]
    fn replace() {
        let obj = Object::new(420);
        assert_eq!(obj.replace(42), 420);
        assert_eq!(obj.replace(32), 42);
        assert_eq!(obj.replace(0), 32);
    }

    #[test]
    fn default() {
        let obj = Object::<usize>::default();
        assert_eq!(obj.get(), 0);
    }

    /// Sets a shared flag when dropped.
    #[derive(Clone)]
    struct Dropper {
        is_dropped: Arc<Mutex<bool>>,
    }

    impl Drop for Dropper {
        fn drop(&mut self) {
            *self.is_dropped.lock().unwrap() = true;
        }
    }

    #[test]
    fn drop() {
        let flag = Arc::new(Mutex::new(false));
        let thread_flag = flag.clone();

        thread::spawn(move || {
            let obj = Object::new(Dropper { is_dropped: thread_flag });
            // Force this thread to clone the initial value into its storage.
            obj.with(|_| {});
            // Leak the initial value so only the thread's own copy can set the flag. That copy
            // should be dropped when the thread's storage is torn down at exit.
            mem::forget(obj);
        }).join().unwrap();

        assert!(*flag.lock().unwrap());
    }
}