preemptive_threads/
scheduler.rs1use crate::error::{ThreadError, ThreadResult};
2use crate::thread::{Thread, ThreadId, ThreadState};
3use core::cell::UnsafeCell;
4
/// Maximum number of threads the scheduler can ever create.
///
/// This is also the capacity of the thread table and of the run queue. Thread IDs
/// are never reused, so it bounds the total number of spawns, not just live threads.
const MAX_THREADS: usize = 32;
6
7pub struct Scheduler {
8 threads: [Option<Thread>; MAX_THREADS],
9 current_thread: Option<ThreadId>,
10 next_thread_id: ThreadId,
11 run_queue: [Option<ThreadId>; MAX_THREADS],
12 run_queue_head: usize,
13 run_queue_tail: usize,
14 run_queue_count: usize,
15}
16
17pub struct SchedulerCell(UnsafeCell<Scheduler>);
18
19unsafe impl Sync for SchedulerCell {}
20
21impl Default for SchedulerCell {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl SchedulerCell {
28 pub const fn new() -> Self {
29 SchedulerCell(UnsafeCell::new(Scheduler::new()))
30 }
31
32 #[allow(clippy::mut_from_ref)]
35 pub unsafe fn get(&self) -> &mut Scheduler {
36 &mut *self.0.get()
37 }
38}
39
40pub static SCHEDULER: SchedulerCell = SchedulerCell::new();
41
42impl Default for Scheduler {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl Scheduler {
49 pub const fn new() -> Self {
50 Scheduler {
51 threads: [const { None }; MAX_THREADS],
52 current_thread: None,
53 next_thread_id: 0,
54 run_queue: [None; MAX_THREADS],
55 run_queue_head: 0,
56 run_queue_tail: 0,
57 run_queue_count: 0,
58 }
59 }
60
61 pub fn spawn_thread(
62 &mut self,
63 stack: &'static mut [u8],
64 entry_point: fn(),
65 priority: u8,
66 ) -> ThreadResult<ThreadId> {
67 let thread_id = self.next_thread_id;
68
69 if thread_id >= MAX_THREADS {
70 return Err(ThreadError::MaxThreadsReached);
71 }
72
73 let thread = Thread::new(thread_id, stack, entry_point, priority);
74 self.threads[thread_id] = Some(thread);
75 self.next_thread_id += 1;
76
77 self.enqueue_thread(thread_id)?;
78 Ok(thread_id)
79 }
80
81 fn enqueue_thread(&mut self, thread_id: ThreadId) -> ThreadResult<()> {
82 if self.run_queue_count >= MAX_THREADS {
83 return Err(ThreadError::SchedulerFull);
84 }
85
86 self.run_queue[self.run_queue_tail] = Some(thread_id);
87 self.run_queue_tail = (self.run_queue_tail + 1) % MAX_THREADS;
88 self.run_queue_count += 1;
89 Ok(())
90 }
91
92 pub fn schedule(&mut self) -> Option<ThreadId> {
93 if let Some(current) = self.current_thread {
94 if let Some(thread) = &mut self.threads[current] {
95 if thread.state == ThreadState::Running {
96 thread.state = ThreadState::Ready;
98 let _ = self.enqueue_thread(current);
99 }
100 }
101 }
102
103 self.schedule_with_priority()
104 }
105
106 fn schedule_with_priority(&mut self) -> Option<ThreadId> {
107 if self.run_queue_count == 0 {
108 return None;
109 }
110
111 let mut best_thread = None;
112 let mut highest_priority = 0u8;
113 let mut best_index = None;
114
115 for i in 0..self.run_queue_count {
117 let queue_index = (self.run_queue_head + i) % MAX_THREADS;
118
119 if let Some(thread_id) = self.run_queue[queue_index] {
120 if let Some(thread) = &self.threads[thread_id] {
121 if thread.is_runnable() && thread.priority > highest_priority {
122 highest_priority = thread.priority;
123 }
124 }
125 }
126 }
127
128 for i in 0..self.run_queue_count {
130 let queue_index = (self.run_queue_head + i) % MAX_THREADS;
131
132 if let Some(thread_id) = self.run_queue[queue_index] {
133 if let Some(thread) = &self.threads[thread_id] {
134 if thread.is_runnable() && thread.priority == highest_priority {
135 best_thread = Some(thread_id);
136 best_index = Some(queue_index);
137 break; }
139 }
140 }
141 }
142
143 if let (Some(thread_id), Some(index)) = (best_thread, best_index) {
144 self.run_queue[index] = None;
146
147 let mut read_pos = (index + 1) % MAX_THREADS;
148 let mut write_pos = index;
149
150 while read_pos != self.run_queue_tail {
151 self.run_queue[write_pos] = self.run_queue[read_pos];
152 self.run_queue[read_pos] = None;
153 write_pos = (write_pos + 1) % MAX_THREADS;
154 read_pos = (read_pos + 1) % MAX_THREADS;
155 }
156
157 self.run_queue_tail = write_pos;
158 self.run_queue_count -= 1;
159
160 return Some(thread_id);
161 }
162
163 None
164 }
165
166 pub fn get_current_thread(&self) -> Option<ThreadId> {
167 self.current_thread
168 }
169
170 pub fn set_current_thread(&mut self, thread_id: Option<ThreadId>) {
171 if let Some(old_id) = self.current_thread {
172 if let Some(thread) = &mut self.threads[old_id] {
173 if thread.state == ThreadState::Running {
174 thread.state = ThreadState::Ready;
175 }
176 }
177 }
178
179 self.current_thread = thread_id;
180
181 if let Some(new_id) = thread_id {
182 if let Some(thread) = &mut self.threads[new_id] {
183 thread.state = ThreadState::Running;
184 }
185 }
186 }
187
188 pub fn exit_current_thread(&mut self) {
189 if let Some(current) = self.current_thread {
190 let mut waiters_to_wake = [None; 4];
191
192 if let Some(thread) = &mut self.threads[current] {
193 thread.state = ThreadState::Finished;
194 waiters_to_wake = thread.join_waiters;
195 }
196
197 for waiter in waiters_to_wake.iter().flatten() {
199 if let Some(waiter_thread) = &mut self.threads[*waiter] {
200 if waiter_thread.state == ThreadState::Blocked {
201 waiter_thread.state = ThreadState::Ready;
202 let _ = self.enqueue_thread(*waiter);
203 }
204 }
205 }
206 }
207 }
208
209 pub fn join_thread(&mut self, target_id: ThreadId, current_id: ThreadId) -> ThreadResult<()> {
210 if target_id >= MAX_THREADS {
211 return Err(ThreadError::InvalidThreadId);
212 }
213
214 if let Some(target_thread) = &mut self.threads[target_id] {
215 if target_thread.state == ThreadState::Finished {
216 return Ok(()); }
218
219 for slot in &mut target_thread.join_waiters {
221 if slot.is_none() {
222 *slot = Some(current_id);
223
224 if let Some(current_thread) = &mut self.threads[current_id] {
226 current_thread.state = ThreadState::Blocked;
227 }
228
229 return Ok(());
230 }
231 }
232
233 Err(ThreadError::SchedulerFull)
234 } else {
235 Err(ThreadError::InvalidThreadId)
236 }
237 }
238
239 pub fn get_thread(&self, thread_id: ThreadId) -> Option<&Thread> {
240 if thread_id >= MAX_THREADS {
241 return None;
242 }
243 self.threads[thread_id].as_ref()
244 }
245
246 pub fn get_thread_mut(&mut self, thread_id: ThreadId) -> Option<&mut Thread> {
247 if thread_id >= MAX_THREADS {
248 return None;
249 }
250 self.threads[thread_id].as_mut()
251 }
252
253 pub fn switch_context(&mut self, from_id: ThreadId, to_id: ThreadId) -> ThreadResult<()> {
254 if let Some(from_thread) = self.get_thread(from_id) {
255 if from_thread.check_stack_overflow() {
256 return Err(ThreadError::StackOverflow);
257 }
258 }
259
260 let from_thread = self.get_thread_mut(from_id);
261 let from_context = if let Some(thread) = from_thread {
262 &mut thread.context as *mut _
263 } else {
264 return Err(ThreadError::InvalidThreadId);
265 };
266
267 let to_thread = self.get_thread_mut(to_id);
268 let to_context = if let Some(thread) = to_thread {
269 &thread.context as *const _
270 } else {
271 return Err(ThreadError::InvalidThreadId);
272 };
273
274 unsafe {
275 crate::context::switch_context(from_context, to_context);
276 }
277
278 Ok(())
279 }
280}