1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
// Copyright 2016 Amanieu d'Antras
//
// Licensed under the Apache License, Version 2.0, <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

use std::sync::atomic::{AtomicUsize, AtomicPtr, Ordering};
use std::time::Instant;
use std::cell::Cell;
use std::ptr;
use std::mem;
use thread_parker::ThreadParker;
use word_lock::WordLock;

// Number of live `ThreadData` objects. Incremented in `ThreadData::new` and
// decremented in its `Drop` impl; it is only used to size the hash table.
static NUM_THREADS: AtomicUsize = AtomicUsize::new(0);
// Pointer to the current hash table. Null until the first `ThreadData` is
// created; `grow_hashtable` replaces it with a larger table when needed (old
// tables are not freed).
static HASHTABLE: AtomicPtr<HashTable> = AtomicPtr::new(ptr::null_mut());
// Per-thread parking state, created lazily on first access (e.g. from `park`).
thread_local!(static THREAD_DATA: ThreadData = ThreadData::new());

// Number of buckets to allocate per live thread, before rounding up to a power
// of two. Even with 3x more buckets than threads, the memory overhead is still
// only a few hundred bytes per thread.
const LOAD_FACTOR: usize = 3;

// A fixed-size table of wait-queue buckets, indexed by `hash(key, hash_bits)`.
// A table is never resized in place: `grow_hashtable` builds a larger one and
// publishes it through `HASHTABLE`.
struct HashTable {
    // Hash buckets for the table; the length is always a power of two.
    entries: Box<[Bucket]>,

    // Number of bits used for the hash function (log2 of `entries.len()`)
    hash_bits: u32,

    // Previous table. Old tables are never freed, because other threads may
    // still hold a pointer to them while locking one of their buckets. This
    // pointer is only kept to keep leak detectors happy.
    _prev: *const HashTable,
}

impl HashTable {
    // Creates a table sized for `num_threads` live threads.
    //
    // The bucket count is `num_threads * LOAD_FACTOR` rounded up to the next
    // power of two, so `hash_bits` is exactly log2 of the table length. `prev`
    // is the table being replaced (null for the very first table); it is only
    // stored, never dereferenced here.
    fn new(num_threads: usize, prev: *const HashTable) -> Box<HashTable> {
        let new_size = (num_threads * LOAD_FACTOR).next_power_of_two();

        // `new_size` is a power of two, so its trailing-zero count is its log2.
        let hash_bits = new_size.trailing_zeros();

        // The padding is never read, but it must still be initialized: creating
        // uninitialized integers (as `mem::uninitialized` did) is undefined
        // behavior. Zeroing 64 bytes per bucket is negligible on this rare path.
        let bucket = Bucket {
            mutex: WordLock::new(),
            queue_head: Cell::new(ptr::null()),
            queue_tail: Cell::new(ptr::null()),
            _padding: [0; 64],
        };
        Box::new(HashTable {
            entries: vec![bucket; new_size].into_boxed_slice(),
            hash_bits: hash_bits,
            _prev: prev,
        })
    }
}

// One slot of the hash table: a lock plus an intrusive singly-linked queue of
// the threads parked on keys that hash to this slot. `park` appends at the
// tail and the unpark functions scan from the head.
struct Bucket {
    // Lock protecting the queue
    mutex: WordLock,

    // Linked list of threads waiting on this bucket, chained through
    // `ThreadData::next_in_queue`. Both are null when the queue is empty.
    // Accessed only with `mutex` held (or before the table is published).
    queue_head: Cell<*const ThreadData>,
    queue_tail: Cell<*const ThreadData>,

    // Padding to avoid false sharing between buckets. Ideally we would just
    // align the bucket structure to 64 bytes, but Rust doesn't support that yet.
    // NOTE(review): `#[repr(align(64))]` is stable since Rust 1.25 and could
    // replace this field once all constructors are updated.
    _padding: [u8; 64],
}

// Implementation of Clone for Bucket, needed to make vec![] work.
//
// A "clone" is always a fresh bucket: a new unlocked mutex and an empty queue.
// Lock state and queued threads are deliberately never duplicated.
impl Clone for Bucket {
    fn clone(&self) -> Bucket {
        Bucket {
            mutex: WordLock::new(),
            queue_head: Cell::new(ptr::null()),
            queue_tail: Cell::new(ptr::null()),
            // Zero the padding: leaving integer arrays uninitialized via
            // `mem::uninitialized` is undefined behavior (and deprecated).
            _padding: [0; 64],
        }
    }
}

// Per-thread state used to park and unpark this thread.
struct ThreadData {
    // Puts this thread to sleep and wakes it up again
    parker: ThreadParker,
    // Key this thread is parked on; set by `park` before the thread is queued
    key: Cell<usize>,
    // Next thread in the bucket's queue, or null if this is the last one
    next_in_queue: Cell<*const ThreadData>,
}

impl ThreadData {
    // Registers one more live thread, makes sure the global hash table is
    // large enough for all live threads, and returns idle per-thread state
    // (key 0, not linked into any queue).
    fn new() -> ThreadData {
        // `fetch_add` yields the count before our increment, so add ourselves.
        let live_threads = 1 + NUM_THREADS.fetch_add(1, Ordering::Relaxed);

        // Resize before this thread can ever be queued in the table.
        unsafe { grow_hashtable(live_threads) };

        ThreadData {
            parker: ThreadParker::new(),
            key: Cell::new(0),
            next_in_queue: Cell::new(ptr::null()),
        }
    }
}

impl Drop for ThreadData {
    // Removes this thread from the live-thread count. The hash table is only
    // ever grown, never shrunk, so nothing else needs to happen here.
    fn drop(&mut self) {
        let _previous_count = NUM_THREADS.fetch_sub(1, Ordering::Relaxed);
    }
}

// Grow the hash table so that it is big enough for the given number of threads.
// This isn't performance-critical since it is only done when a ThreadData is
// created, which only happens once per thread.
//
// Resize protocol: lock *every* bucket of the current table, re-check that the
// table is still the current one, move all queued threads into a new table,
// publish it, and only then unlock the old buckets. `lock_bucket` re-checks
// `HASHTABLE` after acquiring a bucket lock, so a thread that was waiting on an
// old bucket retries against the new table. Old tables are never freed.
unsafe fn grow_hashtable(num_threads: usize) {
    // If there is no table, create one
    if HASHTABLE.load(Ordering::Relaxed).is_null() {
        let new_table = Box::into_raw(HashTable::new(num_threads, ptr::null()));

        // If this fails then it means some other thread created the hash
        // table first.
        if HASHTABLE.compare_exchange(ptr::null_mut(),
                              new_table,
                              Ordering::Release,
                              Ordering::Relaxed)
            .is_ok() {
            return;
        }

        // Free the table we created (the reconstructed Box is dropped
        // immediately). Then fall through and check whether the winner's table
        // is big enough for us.
        Box::from_raw(new_table);
    }

    let mut old_table;
    loop {
        old_table = HASHTABLE.load(Ordering::Acquire);

        // Check if we need to resize the existing table
        if (*old_table).entries.len() >= LOAD_FACTOR * num_threads {
            return;
        }

        // Lock all buckets in the old table
        for b in &(*old_table).entries[..] {
            b.mutex.lock();
        }

        // Now check if our table is still the latest one. Another thread could
        // have grown the hash table between us reading HASHTABLE and locking
        // the buckets.
        if HASHTABLE.load(Ordering::Relaxed) == old_table {
            break;
        }

        // Unlock buckets and try again
        for b in &(*old_table).entries[..] {
            b.mutex.unlock();
        }
    }

    // Create the new table. It keeps a pointer to the old one in `_prev`.
    let new_table = HashTable::new(num_threads, old_table);

    // Move the entries from the old table to the new one. Each thread is
    // re-hashed by its key and appended to the tail of its new bucket, so
    // threads that shared an old bucket keep their relative order. The new
    // table's buckets need no locking: it is not yet visible to other threads.
    for b in &(*old_table).entries[..] {
        let mut current = b.queue_head.get();
        while !current.is_null() {
            let next = (*current).next_in_queue.get();
            let hash = hash((*current).key.get(), new_table.hash_bits);
            if new_table.entries[hash].queue_tail.get().is_null() {
                new_table.entries[hash].queue_head.set(current);
            } else {
                (*new_table.entries[hash].queue_tail.get()).next_in_queue.set(current);
            }
            new_table.entries[hash].queue_tail.set(current);
            (*current).next_in_queue.set(ptr::null());
            current = next;
        }
    }

    // Publish the new table. No races are possible at this point because
    // any other thread trying to grow the hash table is blocked on the bucket
    // locks in the old table.
    HASHTABLE.store(Box::into_raw(new_table), Ordering::Release);

    // Unlock all buckets in the old table. Threads blocked on these locks will
    // see that HASHTABLE changed and retry on the new table.
    for b in &(*old_table).entries[..] {
        b.mutex.unlock();
    }
}

// Hash function for addresses (Fibonacci hashing).
//
// The key is multiplied by 2^W / phi (W = pointer width, wrapping on
// overflow), and the top `bits` bits of the product become the bucket index.
// The result is therefore always below 2^bits. Requires 0 < bits <= W;
// otherwise the shift amount is out of range.
#[cfg(target_pointer_width = "32")]
fn hash(key: usize, bits: u32) -> usize {
    const GOLDEN_RATIO_32: usize = 0x9E3779B9;
    let product = key.wrapping_mul(GOLDEN_RATIO_32);
    product >> (32 - bits)
}
#[cfg(target_pointer_width = "64")]
fn hash(key: usize, bits: u32) -> usize {
    const GOLDEN_RATIO_64: usize = 0x9E3779B97F4A7C15;
    let product = key.wrapping_mul(GOLDEN_RATIO_64);
    product >> (64 - bits)
}

// Lock the bucket for the given key.
//
// Returns `None` if no hash table exists yet. Otherwise returns the bucket
// that `key` hashes to in the *current* table, with its mutex held; the caller
// is responsible for unlocking it. While any bucket of the current table is
// locked, `grow_hashtable` cannot finish locking all buckets, so the table
// cannot be replaced underneath the caller.
unsafe fn lock_bucket<'a>(key: usize) -> Option<&'a Bucket> {
    loop {
        let table = HASHTABLE.load(Ordering::Acquire);

        // Without a hash table nobody can be parked, so there is no bucket.
        if table.is_null() {
            return None;
        }

        let index = hash(key, (*table).hash_bits);
        let candidate: &'a Bucket = &(*table).entries[index];
        candidate.mutex.lock();

        // A concurrent resize may have published a new table between our load
        // and acquiring the lock; in that case this bucket is stale.
        let still_current = HASHTABLE.load(Ordering::Relaxed) == table;
        if still_current {
            return Some(candidate);
        }

        // Stale bucket: release it and retry against the new table.
        candidate.mutex.unlock();
    }
}

/// Result of an `unpark_one` operation.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub enum UnparkResult {
    /// No parked threads were found for the given key.
    NoParkedThreads,

    /// One thread was unparked and it was the last one in the queue for the
    /// given key.
    UnparkedLast,

    /// One thread was unparked but there are more threads in the queue for the
    /// given key.
    UnparkedNotLast,
}

/// Parks the current thread in the queue associated with the given key.
///
/// The `validate` function is called while the queue is locked and can abort
/// the operation by returning false. If `validate` returns true then the
/// current thread is appended to the queue and the queue is unlocked.
///
/// The `before_sleep` function is called after the queue is unlocked but before
/// the thread is put to sleep. The thread will then sleep until it is unparked
/// or the given timeout is reached.
///
/// This function returns `true` if the thread was unparked by a call to
/// `unpark_one` or `unpark_all`, and `false` if the validation function failed
/// or the timeout was reached.
///
/// # Safety
///
/// You should only call this function with an address that you control, since
/// you could otherwise interfere with the operation of other synchronization
/// primitives.
///
/// The `validate` function is called while the queue is locked and must not
/// panic or call into any function in `parking_lot`.
///
/// The `before_sleep` function is called outside the queue lock and is allowed
/// to call `unpark_one` or `unpark_all`, but it is not allowed to call `park`
/// or panic.
pub unsafe fn park(key: usize,
                   validate: &mut FnMut() -> bool,
                   before_sleep: &mut FnMut(),
                   timeout: Option<Instant>)
                   -> bool {
    // Grab our thread data, this also ensures that the hash table exists
    // (creating a ThreadData calls grow_hashtable).
    THREAD_DATA.with(|thread_data| {
        // Lock the bucket for the given key. The unwrap cannot fail because
        // the hash table was created together with our ThreadData.
        let bucket = lock_bucket(key).unwrap();

        // If the validation function fails, just return
        if !validate() {
            bucket.mutex.unlock();
            return false;
        }

        // Append our thread data to the queue and unlock the bucket. The parker
        // is prepared while the lock is still held, before any unparker can
        // find us in the queue.
        thread_data.next_in_queue.set(ptr::null());
        thread_data.key.set(key);
        thread_data.parker.prepare_park();
        if !bucket.queue_head.get().is_null() {
            (*bucket.queue_tail.get()).next_in_queue.set(thread_data);
        } else {
            bucket.queue_head.set(thread_data);
        }
        bucket.queue_tail.set(thread_data);
        bucket.mutex.unlock();

        // Invoke the pre-sleep callback
        before_sleep();

        // Park our thread and determine whether we were woken up by an unpark
        // or by our timeout. Note that this isn't precise: we can still be
        // unparked since we are still in the queue.
        let unparked = match timeout {
            Some(timeout) => thread_data.parker.park_until(timeout),
            None => {
                thread_data.parker.park();
                true
            }
        };

        // If we were unparked, return now. The unparker has already removed us
        // from the queue.
        if unparked {
            return true;
        }

        // Lock our bucket again. Note that the hashtable may have been rehashed
        // in the meantime, which is why the bucket is looked up again by key.
        let bucket = lock_bucket(key).unwrap();

        // Now we need to check again if we were unparked or timed out. Unlike
        // the last check this is precise because we hold the bucket lock.
        if !thread_data.parker.timed_out() {
            bucket.mutex.unlock();
            return true;
        }

        // We timed out, so we now need to remove our thread from the queue.
        // `link` always refers to the Cell that points at `current`, so
        // unlinking is a single `link.set(next)`.
        let mut link = &bucket.queue_head;
        let mut current = bucket.queue_head.get();
        let mut previous = ptr::null();
        while !current.is_null() {
            if current == thread_data {
                let next = (*current).next_in_queue.get();
                link.set(next);
                if bucket.queue_tail.get() == current {
                    bucket.queue_tail.set(previous);
                }
                break;
            } else {
                link = &(*current).next_in_queue;
                previous = current;
                current = link.get();
            }
        }

        // Unlock the bucket, we are done
        bucket.mutex.unlock();
        false
    })
}

/// Unparks one thread from the queue associated with the given key.
///
/// The `callback` function is called while the queue is locked and before the
/// target thread is woken up. The `UnparkResult` argument to the function
/// indicates whether a thread was found in the queue and whether this was the
/// last thread in the queue. This value is also returned by `unpark_one`.
///
/// # Safety
///
/// You should only call this function with an address that you control, since
/// you could otherwise interfere with the operation of other synchronization
/// primitives.
///
/// The `callback` function is called while the queue is locked and must not
/// panic or call into any function in `parking_lot`.
pub unsafe fn unpark_one(key: usize, callback: &mut FnMut(UnparkResult)) -> UnparkResult {
    // Lock the bucket for the given key
    let bucket = match lock_bucket(key) {
        Some(x) => x,
        None => {
            // If there is no hash table then there is nothing to unpark
            callback(UnparkResult::NoParkedThreads);
            return UnparkResult::NoParkedThreads;
        }
    };

    // Find a thread with a matching key and remove it from the queue. The
    // bucket may also hold threads parked on other keys that hash to the same
    // slot, so each entry's key is compared. `link` always refers to the Cell
    // that points at `current`.
    let mut link = &bucket.queue_head;
    let mut current = bucket.queue_head.get();
    let mut previous = ptr::null();
    while !current.is_null() {
        if (*current).key.get() == key {
            // Remove the thread from the queue
            let next = (*current).next_in_queue.get();
            link.set(next);
            let mut result = UnparkResult::UnparkedLast;
            if bucket.queue_tail.get() == current {
                // `current` was the tail, so nothing follows it.
                bucket.queue_tail.set(previous);
            } else {
                // Scan the rest of the queue to see if there are any other
                // entries with the given key.
                let mut scan = next;
                while !scan.is_null() {
                    if (*scan).key.get() == key {
                        result = UnparkResult::UnparkedNotLast;
                        break;
                    }
                    scan = (*scan).next_in_queue.get();
                }
            }

            // Invoke the callback before waking up the thread
            callback(result);

            // Unpark the thread while holding the bucket lock to avoid race
            // conditions with timeouts. Once unparked, the thread will act as
            // if it was woken up by an unpark even if it reached its timeout.
            // `current` is not dereferenced after this call: the woken thread
            // may already have returned from `park`.
            (*current).parker.unpark();
            bucket.mutex.unlock();
            return result;
        } else {
            link = &(*current).next_in_queue;
            previous = current;
            current = link.get();
        }
    }

    // No threads with a matching key were found in the bucket
    callback(UnparkResult::NoParkedThreads);
    bucket.mutex.unlock();
    UnparkResult::NoParkedThreads
}

/// Unparks all threads in the queue associated with the given key.
///
/// This function returns the number of threads that were unparked.
///
/// # Safety
///
/// You should only call this function with an address that you control, since
/// you could otherwise interfere with the operation of other synchronization
/// primitives.
pub unsafe fn unpark_all(key: usize) -> usize {
    // Lock the bucket for the given key
    let bucket = match lock_bucket(key) {
        Some(x) => x,
        // If there is no hash table then there is nothing to unpark
        None => return 0,
    };

    // Remove all threads with the given key in the bucket
    let mut link = &bucket.queue_head;
    let mut current = bucket.queue_head.get();
    let mut previous = ptr::null();
    let mut num_threads = 0;
    while !current.is_null() {
        if (*current).key.get() == key {
            // Remove the thread from the queue
            let next = (*current).next_in_queue.get();
            link.set(next);
            if bucket.queue_tail.get() == current {
                bucket.queue_tail.set(previous);
            }

            // Unpark the thread while holding the bucket lock to avoid race
            // conditions with timeouts. Once unparked, the thread will act as
            // if it was woken up by an unpark even if it reached its timeout.
            (*current).parker.unpark();

            num_threads += 1;
            current = next;
        } else {
            link = &(*current).next_in_queue;
            previous = current;
            current = link.get();
        }
    }

    // Unlock the bucket
    bucket.mutex.unlock();

    num_threads
}