preemptive_threads/
scheduler.rs1use crate::thread::{Thread, ThreadId, ThreadState};
2use crate::error::{ThreadError, ThreadResult};
3use core::cell::UnsafeCell;
4
5const MAX_THREADS: usize = 32;
6
7pub struct Scheduler {
8 threads: [Option<Thread>; MAX_THREADS],
9 current_thread: Option<ThreadId>,
10 next_thread_id: ThreadId,
11 run_queue: [Option<ThreadId>; MAX_THREADS],
12 run_queue_head: usize,
13 run_queue_tail: usize,
14}
15
16pub struct SchedulerCell(UnsafeCell<Scheduler>);
17
18unsafe impl Sync for SchedulerCell {}
19
20impl Default for SchedulerCell {
21 fn default() -> Self {
22 Self::new()
23 }
24}
25
26impl SchedulerCell {
27 pub const fn new() -> Self {
28 SchedulerCell(UnsafeCell::new(Scheduler::new()))
29 }
30
31 #[allow(clippy::mut_from_ref)]
34 pub unsafe fn get(&self) -> &mut Scheduler {
35 &mut *self.0.get()
36 }
37}
38
39pub static SCHEDULER: SchedulerCell = SchedulerCell::new();
40
41impl Default for Scheduler {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl Scheduler {
48 pub const fn new() -> Self {
49 Scheduler {
50 threads: [const { None }; MAX_THREADS],
51 current_thread: None,
52 next_thread_id: 0,
53 run_queue: [None; MAX_THREADS],
54 run_queue_head: 0,
55 run_queue_tail: 0,
56 }
57 }
58
59 pub fn spawn_thread(
60 &mut self,
61 stack: &'static mut [u8],
62 entry_point: fn(),
63 priority: u8,
64 ) -> ThreadResult<ThreadId> {
65 let thread_id = self.next_thread_id;
66
67 if thread_id >= MAX_THREADS {
68 return Err(ThreadError::MaxThreadsReached);
69 }
70
71 let thread = Thread::new(thread_id, stack, entry_point, priority);
72 self.threads[thread_id] = Some(thread);
73 self.next_thread_id += 1;
74
75 self.enqueue_thread(thread_id)?;
76 Ok(thread_id)
77 }
78
79 fn enqueue_thread(&mut self, thread_id: ThreadId) -> ThreadResult<()> {
80 let next_tail = (self.run_queue_tail + 1) % MAX_THREADS;
81 if next_tail == self.run_queue_head {
82 return Err(ThreadError::SchedulerFull);
83 }
84
85 self.run_queue[self.run_queue_tail] = Some(thread_id);
86 self.run_queue_tail = next_tail;
87 Ok(())
88 }
89
90
91 pub fn schedule(&mut self) -> Option<ThreadId> {
92 if let Some(current) = self.current_thread {
93 if let Some(thread) = &self.threads[current] {
94 if thread.state == ThreadState::Running {
95 let _ = self.enqueue_thread(current);
96 }
97 }
98 }
99
100 self.schedule_with_priority()
101 }
102
103 fn schedule_with_priority(&mut self) -> Option<ThreadId> {
104 let mut best_thread = None;
105 let mut highest_priority = 0u8;
106 let mut best_index = None;
107
108 for i in 0..MAX_THREADS {
109 let queue_index = (self.run_queue_head + i) % MAX_THREADS;
110 if queue_index == self.run_queue_tail {
111 break;
112 }
113
114 if let Some(thread_id) = self.run_queue[queue_index] {
115 if let Some(thread) = &self.threads[thread_id] {
116 if thread.is_runnable() && thread.priority >= highest_priority {
117 highest_priority = thread.priority;
118 best_thread = Some(thread_id);
119 best_index = Some(queue_index);
120 }
121 }
122 }
123 }
124
125 if let (Some(thread_id), Some(index)) = (best_thread, best_index) {
126 self.run_queue[index] = None;
127
128 let mut read_pos = (index + 1) % MAX_THREADS;
129 let mut write_pos = index;
130
131 while read_pos != self.run_queue_tail {
132 self.run_queue[write_pos] = self.run_queue[read_pos];
133 write_pos = (write_pos + 1) % MAX_THREADS;
134 read_pos = (read_pos + 1) % MAX_THREADS;
135 }
136
137 self.run_queue_tail = write_pos;
138
139 return Some(thread_id);
140 }
141
142 None
143 }
144
145 pub fn get_current_thread(&self) -> Option<ThreadId> {
146 self.current_thread
147 }
148
149 pub fn set_current_thread(&mut self, thread_id: Option<ThreadId>) {
150 if let Some(old_id) = self.current_thread {
151 if let Some(thread) = &mut self.threads[old_id] {
152 if thread.state == ThreadState::Running {
153 thread.state = ThreadState::Ready;
154 }
155 }
156 }
157
158 self.current_thread = thread_id;
159
160 if let Some(new_id) = thread_id {
161 if let Some(thread) = &mut self.threads[new_id] {
162 thread.state = ThreadState::Running;
163 }
164 }
165 }
166
167 pub fn exit_current_thread(&mut self) {
168 if let Some(current) = self.current_thread {
169 let mut waiters_to_wake = [None; 4];
170
171 if let Some(thread) = &mut self.threads[current] {
172 thread.state = ThreadState::Finished;
173 waiters_to_wake = thread.join_waiters;
174 }
175
176 for waiter in waiters_to_wake.iter().flatten() {
178 if let Some(waiter_thread) = &mut self.threads[*waiter] {
179 if waiter_thread.state == ThreadState::Blocked {
180 waiter_thread.state = ThreadState::Ready;
181 let _ = self.enqueue_thread(*waiter);
182 }
183 }
184 }
185 }
186 }
187
188 pub fn join_thread(&mut self, target_id: ThreadId, current_id: ThreadId) -> ThreadResult<()> {
189 if target_id >= MAX_THREADS {
190 return Err(ThreadError::InvalidThreadId);
191 }
192
193 if let Some(target_thread) = &mut self.threads[target_id] {
194 if target_thread.state == ThreadState::Finished {
195 return Ok(()); }
197
198 for slot in &mut target_thread.join_waiters {
200 if slot.is_none() {
201 *slot = Some(current_id);
202
203 if let Some(current_thread) = &mut self.threads[current_id] {
205 current_thread.state = ThreadState::Blocked;
206 }
207
208 return Ok(());
209 }
210 }
211
212 Err(ThreadError::SchedulerFull)
213 } else {
214 Err(ThreadError::InvalidThreadId)
215 }
216 }
217
218 pub fn get_thread(&self, thread_id: ThreadId) -> Option<&Thread> {
219 self.threads[thread_id].as_ref()
220 }
221
222 pub fn get_thread_mut(&mut self, thread_id: ThreadId) -> Option<&mut Thread> {
223 self.threads[thread_id].as_mut()
224 }
225
226 pub fn switch_context(&mut self, from_id: ThreadId, to_id: ThreadId) -> ThreadResult<()> {
227 if let Some(from_thread) = self.get_thread(from_id) {
228 if from_thread.check_stack_overflow() {
229 return Err(ThreadError::StackOverflow);
230 }
231 }
232
233 let from_thread = self.get_thread_mut(from_id);
234 let from_context = if let Some(thread) = from_thread {
235 &mut thread.context as *mut _
236 } else {
237 return Err(ThreadError::InvalidThreadId);
238 };
239
240 let to_thread = self.get_thread_mut(to_id);
241 let to_context = if let Some(thread) = to_thread {
242 &thread.context as *const _
243 } else {
244 return Err(ThreadError::InvalidThreadId);
245 };
246
247 unsafe {
248 crate::context::switch_context(from_context, to_context);
249 }
250
251 Ok(())
252 }
253}