preemptive_threads/atomic_scheduler.rs

use crate::error::{ThreadError, ThreadResult};
use crate::thread::{Thread, ThreadId, ThreadState};
use core::cell::UnsafeCell;
use core::mem::MaybeUninit;
use core::sync::atomic::{AtomicBool, AtomicU32, AtomicUsize, Ordering};

const MAX_THREADS: usize = 32;
const PRIORITY_LEVELS: usize = 8;

/// Lock-free priority queue for thread scheduling
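///
/// Higher numeric priority is served first: `dequeue` always takes from the
/// highest non-empty level, found via the most significant set bit of
/// `priority_bitmap`.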
pub struct PriorityQueue {
    /// Per-priority circular buffers
    queues: [CircularBuffer; PRIORITY_LEVELS],
    /// Bitmap of non-empty priority levels
    priority_bitmap: AtomicU32,
}

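/// Lock-free ring buffer of thread IDs; empty slots hold `usize::MAX`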
struct CircularBuffer {
    buffer: [AtomicUsize; MAX_THREADS],
    head: AtomicUsize,
    tail: AtomicUsize,
}

impl CircularBuffer {
    #[allow(dead_code)]
    fn new() -> Self {
        Self {
            buffer: core::array::from_fn(|_| AtomicUsize::new(usize::MAX)),
            head: AtomicUsize::new(0),
            tail: AtomicUsize::new(0),
        }
    }

    fn enqueue(&self, thread_id: ThreadId) -> bool {
        let mut tail = self.tail.load(Ordering::Acquire);

        loop {
            let next_tail = (tail + 1) % MAX_THREADS;
            let head = self.head.load(Ordering::Acquire);

            if next_tail == head {
                return false; // Queue full
            }

            match self.tail.compare_exchange_weak(
                tail,
                next_tail,
                Ordering::Release,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    self.buffer[tail].store(thread_id, Ordering::Release);
                    return true;
                }
                Err(actual) => tail = actual,
            }
        }
    }

    fn dequeue(&self) -> Option<ThreadId> {
        let mut head = self.head.load(Ordering::Acquire);

        loop {
            let tail = self.tail.load(Ordering::Acquire);

            if head == tail {
                return None; // Queue empty
            }

            let thread_id = self.buffer[head].load(Ordering::Acquire);
            if thread_id == usize::MAX {
                // Slot reserved by an in-flight enqueue whose store has not
                // landed yet; wait for it instead of skipping past entries
                core::hint::spin_loop();
                head = self.head.load(Ordering::Acquire);
                continue;
            }

            let next_head = (head + 1) % MAX_THREADS;

            match self.head.compare_exchange_weak(
                head,
                next_head,
                Ordering::Release,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    self.buffer[head].store(usize::MAX, Ordering::Release);
                    return Some(thread_id);
                }
                Err(actual) => head = actual,
            }
        }
    }

    fn is_empty(&self) -> bool {
        self.head.load(Ordering::Acquire) == self.tail.load(Ordering::Acquire)
    }
}

impl PriorityQueue {
    #[allow(dead_code)]
    fn new() -> Self {
        Self {
            queues: core::array::from_fn(|_| CircularBuffer::new()),
            priority_bitmap: AtomicU32::new(0),
        }
    }

    fn enqueue(&self, thread_id: ThreadId, priority: u8) -> bool {
        let priority_level = (priority as usize).min(PRIORITY_LEVELS - 1);

        if self.queues[priority_level].enqueue(thread_id) {
            // Set bit in bitmap to indicate non-empty queue
            self.priority_bitmap
                .fetch_or(1 << priority_level, Ordering::Release);
            true
        } else {
            false
        }
    }

    fn dequeue(&self) -> Option<ThreadId> {
        let mut bitmap = self.priority_bitmap.load(Ordering::Acquire);

        while bitmap != 0 {
            // Find highest priority non-empty queue (MSB)
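            // e.g. bitmap = 0b0010_0100 (levels 2 and 5 non-empty):
            // leading_zeros() = 26, so priority_level = 31 - 26 = 5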
            let priority_level = 31 - bitmap.leading_zeros() as usize;

            if let Some(thread_id) = self.queues[priority_level].dequeue() {
                // Check if queue is now empty and clear bit
                if self.queues[priority_level].is_empty() {
                    self.priority_bitmap
                        .fetch_and(!(1 << priority_level), Ordering::Release);
                }
                return Some(thread_id);
            }

            // Queue was empty, clear bit and retry
            bitmap &= !(1 << priority_level);
        }

        None
    }
}

/// Per-CPU scheduler state
pub struct CpuScheduler {
    /// Current running thread on this CPU
    current_thread: AtomicUsize,
    /// CPU-local run queue for better cache locality
    local_queue: UnsafeCell<CircularBuffer>,
    /// Is this CPU idle?
    idle: AtomicBool,
}

unsafe impl Sync for CpuScheduler {}

impl CpuScheduler {
    const fn new() -> Self {
        Self {
            current_thread: AtomicUsize::new(usize::MAX),
            local_queue: UnsafeCell::new(CircularBuffer {
                buffer: [const { AtomicUsize::new(usize::MAX) }; MAX_THREADS],
                head: AtomicUsize::new(0),
                tail: AtomicUsize::new(0),
            }),
            idle: AtomicBool::new(true),
        }
    }
}

/// Thread-safe atomic scheduler
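///
/// Owns a fixed pool of `MAX_THREADS` thread slots guarded by an allocation
/// bitmap, plus a global priority queue and per-CPU scheduler state.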
pub struct AtomicScheduler {
    /// Thread pool
    threads: [UnsafeCell<MaybeUninit<Thread>>; MAX_THREADS],
    /// Thread allocation bitmap
    thread_bitmap: AtomicU32,
    /// Next thread ID counter
    #[allow(dead_code)]
    next_thread_id: AtomicUsize,
    /// Global priority queue
    global_queue: UnsafeCell<PriorityQueue>,
    /// Per-CPU schedulers (we'll use just one for now)
    cpu_schedulers: [CpuScheduler; 1],
    /// Scheduler lock for critical sections
    scheduler_lock: AtomicBool,
}

unsafe impl Sync for AtomicScheduler {}

impl Default for AtomicScheduler {
    fn default() -> Self {
        Self::new()
    }
}

impl AtomicScheduler {
    pub const fn new() -> Self {
        Self {
            threads: [const { UnsafeCell::new(MaybeUninit::uninit()) }; MAX_THREADS],
            thread_bitmap: AtomicU32::new(0),
            next_thread_id: AtomicUsize::new(0),
            global_queue: UnsafeCell::new(PriorityQueue {
                queues: [const {
                    CircularBuffer {
                        buffer: [const { AtomicUsize::new(usize::MAX) }; MAX_THREADS],
                        head: AtomicUsize::new(0),
                        tail: AtomicUsize::new(0),
                    }
                }; PRIORITY_LEVELS],
                priority_bitmap: AtomicU32::new(0),
            }),
            cpu_schedulers: [CpuScheduler::new()],
            scheduler_lock: AtomicBool::new(false),
        }
    }

    /// Acquire scheduler lock with exponential backoff
    fn acquire_lock(&self) {
        let mut backoff = 1;
        while self
            .scheduler_lock
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            for _ in 0..backoff {
                core::hint::spin_loop();
            }
            backoff = (backoff * 2).min(1024);
        }
    }

    /// Release scheduler lock
    fn release_lock(&self) {
        self.scheduler_lock.store(false, Ordering::Release);
    }

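    /// Spawn a new thread on a caller-provided static stack and enqueue it on
    /// the global run queue.
    ///
    /// A minimal usage sketch (the stack, entry function, and priority below
    /// are illustrative, not part of this module):
    ///
    /// ```ignore
    /// static mut WORKER_STACK: [u8; 4096] = [0; 4096];
    ///
    /// fn worker() {
    ///     // thread body
    /// }
    ///
    /// let tid = ATOMIC_SCHEDULER
    ///     .spawn_thread(unsafe { &mut WORKER_STACK }, worker, 3)
    ///     .expect("no free thread slot");
    /// ```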
    pub fn spawn_thread(
        &self,
        stack: &'static mut [u8],
        entry_point: fn(),
        priority: u8,
    ) -> ThreadResult<ThreadId> {
        // Find and atomically claim a free thread slot (a set bit marks a slot in use)
        let mut slot = None;

        while slot.is_none() {
            let bitmap = self.thread_bitmap.load(Ordering::Acquire);

            // The lowest clear bit in this snapshot is the first free slot
            let Some(i) = (0..MAX_THREADS).find(|&i| bitmap & (1 << i) == 0) else {
                break; // All slots in use
            };

            // Try to claim it; if the bitmap changed under us, rescan from a
            // fresh snapshot so slots freed or claimed meanwhile are not skipped
            if self
                .thread_bitmap
                .compare_exchange(
                    bitmap,
                    bitmap | (1 << i),
                    Ordering::AcqRel,
                    Ordering::Acquire,
                )
                .is_ok()
            {
                slot = Some(i);
            }
        }

        let slot = slot.ok_or(ThreadError::MaxThreadsReached)?;
        let thread_id = slot;

        // Initialize thread
        let thread = Thread::new(thread_id, stack, entry_point, priority);

        unsafe {
            (*self.threads[slot].get()).write(thread);
        }

        // Add to run queue
        let global_queue = unsafe { &*self.global_queue.get() };
        if !global_queue.enqueue(thread_id, priority) {
            // Failed to enqueue, free the slot
            self.thread_bitmap
                .fetch_and(!(1 << slot), Ordering::Release);
            return Err(ThreadError::SchedulerFull);
        }

        Ok(thread_id)
    }

    pub fn schedule(&self) -> Option<ThreadId> {
        let cpu = &self.cpu_schedulers[0];

        // Try local queue first for better cache locality
        let local_queue = unsafe { &*cpu.local_queue.get() };

        if let Some(thread_id) = local_queue.dequeue() {
            return Some(thread_id);
        }

        // Fall back to global queue
        let global_queue = unsafe { &*self.global_queue.get() };
        global_queue.dequeue()
    }

    pub fn get_current_thread(&self) -> Option<ThreadId> {
        let cpu = &self.cpu_schedulers[0];
        let current = cpu.current_thread.load(Ordering::Acquire);

        if current == usize::MAX {
            None
        } else {
            Some(current)
        }
    }

    pub fn set_current_thread(&self, thread_id: Option<ThreadId>) {
        let cpu = &self.cpu_schedulers[0];

        match thread_id {
            Some(id) => {
                cpu.current_thread.store(id, Ordering::Release);
                cpu.idle.store(false, Ordering::Release);
            }
            None => {
                cpu.current_thread.store(usize::MAX, Ordering::Release);
                cpu.idle.store(true, Ordering::Release);
            }
        }
    }

    pub fn get_thread(&self, thread_id: ThreadId) -> Option<&Thread> {
        if thread_id >= MAX_THREADS {
            return None;
        }

        let bitmap = self.thread_bitmap.load(Ordering::Acquire);
        if bitmap & (1 << thread_id) == 0 {
            return None;
        }

        unsafe { Some((*self.threads[thread_id].get()).assume_init_ref()) }
    }

    /// # Safety
    /// Returns mutable reference to thread. Caller must ensure thread safety.
    #[allow(clippy::mut_from_ref)]
    pub unsafe fn get_thread_mut(&self, thread_id: ThreadId) -> Option<&mut Thread> {
        if thread_id >= MAX_THREADS {
            return None;
        }

        let bitmap = self.thread_bitmap.load(Ordering::Acquire);
        if bitmap & (1 << thread_id) == 0 {
            return None;
        }

        Some((*self.threads[thread_id].get()).assume_init_mut())
    }

    pub fn exit_current_thread(&self) {
        if let Some(thread_id) = self.get_current_thread() {
            self.acquire_lock();

            unsafe {
                if let Some(thread) = self.get_thread_mut(thread_id) {
                    thread.state = ThreadState::Finished;

                    // Wake up joiners
                    for waiter_id in thread.join_waiters.iter().flatten() {
                        if let Some(waiter) = self.get_thread_mut(*waiter_id) {
                            if waiter.state == ThreadState::Blocked {
                                waiter.state = ThreadState::Ready;
                                let global_queue = &*self.global_queue.get();
                                let _ = global_queue.enqueue(*waiter_id, waiter.priority);
                            }
                        }
                    }
                }
            }

            self.release_lock();

            // Clear current thread
            self.set_current_thread(None);
        }
    }

    pub fn switch_context(&self, from_id: ThreadId, to_id: ThreadId) -> ThreadResult<()> {
        // Validate thread IDs
        if from_id >= MAX_THREADS || to_id >= MAX_THREADS {
            return Err(ThreadError::InvalidThreadId);
        }

        // Switching a thread to itself is a no-op; return early so we never
        // hold a mutable and a shared reference to the same Thread below
        if from_id == to_id {
            return Ok(());
        }

        let bitmap = self.thread_bitmap.load(Ordering::Acquire);
        if bitmap & (1 << from_id) == 0 || bitmap & (1 << to_id) == 0 {
            return Err(ThreadError::InvalidThreadId);
        }

        unsafe {
            // Get thread pointers
            let from_thread = self
                .get_thread_mut(from_id)
                .ok_or(ThreadError::InvalidThreadId)?;
            let from_context = &mut from_thread.context as *mut _;

            let to_thread = self.get_thread(to_id).ok_or(ThreadError::InvalidThreadId)?;
            let to_context = &to_thread.context as *const _;

            crate::context::switch_context(from_context, to_context);
        }

        Ok(())
    }
}

pub static ATOMIC_SCHEDULER: AtomicScheduler = AtomicScheduler::new();
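
// A minimal sketch (not part of this module) of how a timer-driven preemption
// path might use the scheduler; `timer_preempt` is a hypothetical handler and
// re-queueing the preempted thread is omitted for brevity:
//
// fn timer_preempt() {
//     if let Some(next) = ATOMIC_SCHEDULER.schedule() {
//         let prev = ATOMIC_SCHEDULER.get_current_thread();
//         ATOMIC_SCHEDULER.set_current_thread(Some(next));
//         if let Some(prev) = prev {
//             let _ = ATOMIC_SCHEDULER.switch_context(prev, next);
//         }
//     }
// }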