use std::cell::Cell;
use std::cmp::Reverse;
use std::collections::BinaryHeap;
use std::sync::{Mutex, OnceLock};
/// Hands out small, reusable thread IDs.
///
/// Released IDs are parked in a min-heap and reused (smallest first) before
/// any never-used ID is minted from `free_from`.
#[derive(Default)]
struct ThreadIdManager {
    /// Smallest ID that has never been handed out.
    free_from: usize,
    /// Released IDs; wrapping in `Reverse` turns `BinaryHeap`'s max-heap
    /// into a min-heap, so `pop` yields the smallest released ID.
    free_list: BinaryHeap<Reverse<usize>>,
}

impl ThreadIdManager {
    /// Returns an ID, preferring the smallest released one; otherwise mints
    /// the next never-used ID.
    ///
    /// # Panics
    ///
    /// Panics with "Ran out of thread IDs" when `free_from` cannot be advanced
    /// without overflowing `usize`.
    fn alloc(&mut self) -> usize {
        match self.free_list.pop() {
            Some(Reverse(recycled)) => recycled,
            None => {
                let fresh = self.free_from;
                self.free_from = fresh.checked_add(1).expect("Ran out of thread IDs");
                fresh
            }
        }
    }

    /// Returns `id` to the pool so a later `alloc` may hand it out again.
    ///
    /// No duplicate check is performed: releasing the same ID twice makes
    /// it available twice.
    fn free(&mut self, id: usize) {
        self.free_list.push(Reverse(id));
    }
}
fn thread_id_manager() -> &'static Mutex<ThreadIdManager> {
static THREAD_ID_MANAGER: OnceLock<Mutex<ThreadIdManager>> = OnceLock::new();
THREAD_ID_MANAGER.get_or_init(Default::default)
}
/// A thread ID together with its precomputed bucket position.
///
/// `bucket` and `entry` are derived from `id` by [`Thread::new`]: `bucket`
/// selects a bucket whose size is [`Thread::bucket_capacity`]`(bucket)`, and
/// `entry` is the offset of this ID within that bucket.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Thread {
    /// Thread ID as handed out by the global ID manager.
    pub id: usize,
    /// Offset of this ID within `bucket`.
    pub entry: usize,
    /// Index of the bucket containing this ID.
    pub bucket: usize,
}
/// Offset added to every thread ID before bucketing, so that the first
/// bucket holds `1 << ZERO_BUCKET` (= 32) entries instead of a single one.
const ZERO_ENTRY: usize = 31;
/// Bit length of `ZERO_ENTRY` (5): `log2` of the first bucket's capacity.
const ZERO_BUCKET: usize = usize::BITS as usize - ZERO_ENTRY.leading_zeros() as usize;
/// Number of buckets needed to cover every ID up to `MAX_INDEX`.
pub const BUCKETS: usize = usize::BITS as usize - ZERO_BUCKET;
/// Largest ID accepted by `Thread::new`; keeps `id + ZERO_ENTRY + 1` from
/// overflowing `usize`.
const MAX_INDEX: usize = usize::MAX - (ZERO_ENTRY + 1);
impl Thread {
    /// Computes the bucket/entry position of thread ID `id`.
    ///
    /// IDs are shifted by `ZERO_ENTRY` and split into power-of-two buckets:
    /// bucket `b` holds `bucket_capacity(b)` consecutive IDs, so bucket 0
    /// holds IDs `0..32`, bucket 1 holds `32..96`, bucket 2 `96..224`, etc.
    ///
    /// # Panics
    ///
    /// Panics with "exceeded maximum thread count" if `id > MAX_INDEX`.
    #[inline]
    pub fn new(id: usize) -> Thread {
        assert!(id <= MAX_INDEX, "exceeded maximum thread count");
        // The MAX_INDEX check guarantees neither this sum nor `index + 1`
        // below can overflow.
        let index = id + ZERO_ENTRY;
        // Indices in bucket `b` satisfy `2^(b + ZERO_BUCKET) <= index + 1 <
        // 2^(b + ZERO_BUCKET + 1)`, so the bucket follows from the bit length
        // (i.e. the leading-zero count) of `index + 1`.
        let bucket = BUCKETS - ((index + 1).leading_zeros() as usize) - 1;
        // Bucket `b` begins at shifted index `bucket_capacity(b) - 1`.
        let entry = index - (Thread::bucket_capacity(bucket) - 1);
        Thread { id, bucket, entry }
    }

    /// Number of entries in bucket `bucket`: `2^(bucket + ZERO_BUCKET)`,
    /// i.e. 32, 64, 128, ...
    ///
    /// Only meaningful for `bucket < BUCKETS`; larger values overflow the shift.
    #[inline]
    pub fn bucket_capacity(bucket: usize) -> usize {
        1 << (bucket + ZERO_BUCKET)
    }

    /// Returns the calling thread's `Thread`, allocating an ID on first use.
    ///
    /// The value is cached in the `THREAD` thread-local, so only the first
    /// call on each thread (or the first after the ID was released) takes
    /// the slow path.
    #[inline]
    pub fn current() -> Thread {
        THREAD.with(|thread| {
            if let Some(thread) = thread.get() {
                thread
            } else {
                Thread::init_slow(thread)
            }
        })
    }

    /// Slow path of [`Thread::current`]: allocates an ID, caches it in
    /// `thread`, and records it in `THREAD_GUARD` so that the guard's
    /// destructor releases the ID when the thread's locals are torn down.
    ///
    /// # Panics
    ///
    /// `LocalKey::with` panics if `THREAD_GUARD` has already been destroyed
    /// (e.g. when reached from another thread-local's destructor).
    #[cold]
    #[inline(never)]
    fn init_slow(thread: &Cell<Option<Thread>>) -> Thread {
        let new = Thread::create();
        thread.set(Some(new));
        THREAD_GUARD.with(|guard| guard.id.set(new.id));
        new
    }

    /// Allocates a fresh ID from the global manager.
    ///
    /// Unlike [`Thread::current`], the ID is not registered with any guard,
    /// so it is never released automatically; the caller owns it and must
    /// eventually pass it to [`Thread::free`] if it should be reused.
    ///
    /// # Panics
    ///
    /// Panics if the ID manager's mutex is poisoned, if the manager runs out
    /// of IDs, or if the allocated ID exceeds `MAX_INDEX`.
    pub fn create() -> Thread {
        Thread::new(thread_id_manager().lock().unwrap().alloc())
    }

    /// Returns `id` to the global manager so a later allocation can reuse it.
    ///
    /// # Safety
    ///
    /// Once freed, `id` may be handed to another thread. The caller must
    /// guarantee that `id` was obtained from this module's allocator, that
    /// nothing will keep using it as its own ID afterwards, and that it is
    /// freed only once — the manager does not detect double frees, so a
    /// repeated free would let two threads receive the same ID.
    ///
    /// # Panics
    ///
    /// Panics if the ID manager's mutex is poisoned.
    pub unsafe fn free(id: usize) {
        thread_id_manager().lock().unwrap().free(id);
    }
}
thread_local! {
    /// Cached `Thread` for the current thread; `None` until first use and
    /// again after the guard has released the ID.
    static THREAD: Cell<Option<Thread>> = const { Cell::new(None) };

    /// Owner of the current thread's ID; its destructor hands the ID back to
    /// the global manager. Holds 0 until `Thread::init_slow` stores the ID.
    static THREAD_GUARD: ThreadGuard = const { ThreadGuard { id: Cell::new(0) } };
}
/// Releases the owning thread's ID when its thread-locals are destroyed.
struct ThreadGuard {
    /// ID to release on drop; written by `Thread::init_slow`.
    id: Cell<usize>,
}

impl Drop for ThreadGuard {
    fn drop(&mut self) {
        // Forget the cached `Thread` first so a later `Thread::current` on
        // this thread goes through `init_slow` instead of reusing the
        // released ID. If `THREAD` is already destroyed there is nothing to
        // clear, so the access error is deliberately ignored.
        let _ = THREAD.try_with(|cached| cached.set(None));
        let released = self.id.get();
        // SAFETY: `released` was allocated for this thread by `init_slow`,
        // this guard is the only place that frees it, and a thread-local's
        // destructor runs at most once per thread.
        unsafe { Thread::free(released) };
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the ID -> (bucket, entry) mapping across the first three buckets.
    #[test]
    fn thread() {
        assert_eq!(Thread::bucket_capacity(0), 32);
        for i in 0..32 {
            let thread = Thread::new(i);
            assert_eq!(thread.id, i);
            assert_eq!(thread.bucket, 0);
            assert_eq!(thread.entry, i);
        }
        assert_eq!(Thread::bucket_capacity(1), 64);
        // Bucket 1 starts immediately after bucket 0, at ID 32 (entry 0).
        for i in 32..96 {
            let thread = Thread::new(i);
            assert_eq!(thread.id, i);
            assert_eq!(thread.bucket, 1);
            assert_eq!(thread.entry, i - 32);
        }
        assert_eq!(Thread::bucket_capacity(2), 128);
        for i in 96..224 {
            let thread = Thread::new(i);
            assert_eq!(thread.id, i);
            assert_eq!(thread.bucket, 2);
            assert_eq!(thread.entry, i - 96);
        }
    }

    /// The buckets together hold exactly `MAX_INDEX + 1` IDs, and the largest
    /// ID lands on the last entry of the last bucket.
    #[test]
    fn max_entries() {
        let mut entries = 0;
        for i in 0..BUCKETS {
            entries += Thread::bucket_capacity(i);
        }
        assert_eq!(entries, MAX_INDEX + 1);
        let max = Thread::new(MAX_INDEX);
        assert_eq!(max.id, MAX_INDEX);
        assert_eq!(max.bucket, BUCKETS - 1);
        assert_eq!(Thread::bucket_capacity(BUCKETS - 1), 1 << (usize::BITS - 1));
        assert_eq!(max.entry, (1 << (usize::BITS - 1)) - 1);
    }

    /// IDs past `MAX_INDEX` are rejected rather than overflowing.
    #[test]
    #[should_panic(expected = "exceeded maximum thread count")]
    fn beyond_max_index_panics() {
        let _ = Thread::new(MAX_INDEX + 1);
    }

    /// Released IDs are reused smallest-first before new ones are minted.
    #[test]
    fn id_manager_reuses_lowest_first() {
        let mut manager = ThreadIdManager::default();
        assert_eq!(manager.alloc(), 0);
        assert_eq!(manager.alloc(), 1);
        assert_eq!(manager.alloc(), 2);
        manager.free(2);
        manager.free(0);
        assert_eq!(manager.alloc(), 0);
        assert_eq!(manager.alloc(), 2);
        assert_eq!(manager.alloc(), 3);
    }
}