//! preemptive_threads — `scheduler.rs`: fixed-capacity thread table, run queue and
//! priority-based round-robin scheduling.

use core::cell::UnsafeCell;

use crate::error::{ThreadError, ThreadResult};
use crate::thread::{Thread, ThreadId, ThreadState};
4
5const MAX_THREADS: usize = 32;
6
7pub struct Scheduler {
8    threads: [Option<Thread>; MAX_THREADS],
9    current_thread: Option<ThreadId>,
10    next_thread_id: ThreadId,
11    run_queue: [Option<ThreadId>; MAX_THREADS],
12    run_queue_head: usize,
13    run_queue_tail: usize,
14    run_queue_count: usize,
15}
16
17pub struct SchedulerCell(UnsafeCell<Scheduler>);
18
19unsafe impl Sync for SchedulerCell {}
20
21impl Default for SchedulerCell {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl SchedulerCell {
28    pub const fn new() -> Self {
29        SchedulerCell(UnsafeCell::new(Scheduler::new()))
30    }
31
32    /// # Safety
33    /// Returns mutable reference to scheduler. Caller must ensure thread safety.
34    #[allow(clippy::mut_from_ref)]
35    pub unsafe fn get(&self) -> &mut Scheduler {
36        &mut *self.0.get()
37    }
38}
39
40pub static SCHEDULER: SchedulerCell = SchedulerCell::new();
41
42impl Default for Scheduler {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl Scheduler {
49    pub const fn new() -> Self {
50        Scheduler {
51            threads: [const { None }; MAX_THREADS],
52            current_thread: None,
53            next_thread_id: 0,
54            run_queue: [None; MAX_THREADS],
55            run_queue_head: 0,
56            run_queue_tail: 0,
57            run_queue_count: 0,
58        }
59    }
60
61    pub fn spawn_thread(
62        &mut self,
63        stack: &'static mut [u8],
64        entry_point: fn(),
65        priority: u8,
66    ) -> ThreadResult<ThreadId> {
67        let thread_id = self.next_thread_id;
68
69        if thread_id >= MAX_THREADS {
70            return Err(ThreadError::MaxThreadsReached);
71        }
72
73        let thread = Thread::new(thread_id, stack, entry_point, priority);
74        self.threads[thread_id] = Some(thread);
75        self.next_thread_id += 1;
76
77        self.enqueue_thread(thread_id)?;
78        Ok(thread_id)
79    }
80
81    fn enqueue_thread(&mut self, thread_id: ThreadId) -> ThreadResult<()> {
82        if self.run_queue_count >= MAX_THREADS {
83            return Err(ThreadError::SchedulerFull);
84        }
85
86        self.run_queue[self.run_queue_tail] = Some(thread_id);
87        self.run_queue_tail = (self.run_queue_tail + 1) % MAX_THREADS;
88        self.run_queue_count += 1;
89        Ok(())
90    }
91
92    pub fn schedule(&mut self) -> Option<ThreadId> {
93        if let Some(current) = self.current_thread {
94            if let Some(thread) = &mut self.threads[current] {
95                if thread.state == ThreadState::Running {
96                    // Set back to Ready when yielding
97                    thread.state = ThreadState::Ready;
98                    let _ = self.enqueue_thread(current);
99                }
100            }
101        }
102
103        self.schedule_with_priority()
104    }
105
106    fn schedule_with_priority(&mut self) -> Option<ThreadId> {
107        if self.run_queue_count == 0 {
108            return None;
109        }
110
111        let mut best_thread = None;
112        let mut highest_priority = 0u8;
113        let mut best_index = None;
114
115        // First pass: find the highest priority
116        for i in 0..self.run_queue_count {
117            let queue_index = (self.run_queue_head + i) % MAX_THREADS;
118
119            if let Some(thread_id) = self.run_queue[queue_index] {
120                if let Some(thread) = &self.threads[thread_id] {
121                    if thread.is_runnable() && thread.priority > highest_priority {
122                        highest_priority = thread.priority;
123                    }
124                }
125            }
126        }
127
128        // Second pass: find the first thread with highest priority (round-robin for equal priorities)
129        for i in 0..self.run_queue_count {
130            let queue_index = (self.run_queue_head + i) % MAX_THREADS;
131
132            if let Some(thread_id) = self.run_queue[queue_index] {
133                if let Some(thread) = &self.threads[thread_id] {
134                    if thread.is_runnable() && thread.priority == highest_priority {
135                        best_thread = Some(thread_id);
136                        best_index = Some(queue_index);
137                        break; // Take the first one we find for round-robin
138                    }
139                }
140            }
141        }
142
143        if let (Some(thread_id), Some(index)) = (best_thread, best_index) {
144            // Remove from queue and compact
145            self.run_queue[index] = None;
146
147            let mut read_pos = (index + 1) % MAX_THREADS;
148            let mut write_pos = index;
149
150            while read_pos != self.run_queue_tail {
151                self.run_queue[write_pos] = self.run_queue[read_pos];
152                self.run_queue[read_pos] = None;
153                write_pos = (write_pos + 1) % MAX_THREADS;
154                read_pos = (read_pos + 1) % MAX_THREADS;
155            }
156
157            self.run_queue_tail = write_pos;
158            self.run_queue_count -= 1;
159
160            return Some(thread_id);
161        }
162
163        None
164    }
165
166    pub fn get_current_thread(&self) -> Option<ThreadId> {
167        self.current_thread
168    }
169
170    pub fn set_current_thread(&mut self, thread_id: Option<ThreadId>) {
171        if let Some(old_id) = self.current_thread {
172            if let Some(thread) = &mut self.threads[old_id] {
173                if thread.state == ThreadState::Running {
174                    thread.state = ThreadState::Ready;
175                }
176            }
177        }
178
179        self.current_thread = thread_id;
180
181        if let Some(new_id) = thread_id {
182            if let Some(thread) = &mut self.threads[new_id] {
183                thread.state = ThreadState::Running;
184            }
185        }
186    }
187
188    pub fn exit_current_thread(&mut self) {
189        if let Some(current) = self.current_thread {
190            let mut waiters_to_wake = [None; 4];
191
192            if let Some(thread) = &mut self.threads[current] {
193                thread.state = ThreadState::Finished;
194                waiters_to_wake = thread.join_waiters;
195            }
196
197            // Wake up any threads waiting to join this one
198            for waiter in waiters_to_wake.iter().flatten() {
199                if let Some(waiter_thread) = &mut self.threads[*waiter] {
200                    if waiter_thread.state == ThreadState::Blocked {
201                        waiter_thread.state = ThreadState::Ready;
202                        let _ = self.enqueue_thread(*waiter);
203                    }
204                }
205            }
206        }
207    }
208
209    pub fn join_thread(&mut self, target_id: ThreadId, current_id: ThreadId) -> ThreadResult<()> {
210        if target_id >= MAX_THREADS {
211            return Err(ThreadError::InvalidThreadId);
212        }
213
214        if let Some(target_thread) = &mut self.threads[target_id] {
215            if target_thread.state == ThreadState::Finished {
216                return Ok(()); // Already finished
217            }
218
219            // Add current thread to join waiters
220            for slot in &mut target_thread.join_waiters {
221                if slot.is_none() {
222                    *slot = Some(current_id);
223
224                    // Block current thread
225                    if let Some(current_thread) = &mut self.threads[current_id] {
226                        current_thread.state = ThreadState::Blocked;
227                    }
228
229                    return Ok(());
230                }
231            }
232
233            Err(ThreadError::SchedulerFull)
234        } else {
235            Err(ThreadError::InvalidThreadId)
236        }
237    }
238
239    pub fn get_thread(&self, thread_id: ThreadId) -> Option<&Thread> {
240        if thread_id >= MAX_THREADS {
241            return None;
242        }
243        self.threads[thread_id].as_ref()
244    }
245
246    pub fn get_thread_mut(&mut self, thread_id: ThreadId) -> Option<&mut Thread> {
247        if thread_id >= MAX_THREADS {
248            return None;
249        }
250        self.threads[thread_id].as_mut()
251    }
252
253    pub fn switch_context(&mut self, from_id: ThreadId, to_id: ThreadId) -> ThreadResult<()> {
254        if let Some(from_thread) = self.get_thread(from_id) {
255            if from_thread.check_stack_overflow() {
256                return Err(ThreadError::StackOverflow);
257            }
258        }
259
260        let from_thread = self.get_thread_mut(from_id);
261        let from_context = if let Some(thread) = from_thread {
262            &mut thread.context as *mut _
263        } else {
264            return Err(ThreadError::InvalidThreadId);
265        };
266
267        let to_thread = self.get_thread_mut(to_id);
268        let to_context = if let Some(thread) = to_thread {
269            &thread.context as *const _
270        } else {
271            return Err(ThreadError::InvalidThreadId);
272        };
273
274        unsafe {
275            crate::context::switch_context(from_context, to_context);
276        }
277
278        Ok(())
279    }
280}