preemptive_threads/
scheduler.rs

1use crate::thread::{Thread, ThreadId, ThreadState};
2use crate::error::{ThreadError, ThreadResult};
3use core::cell::UnsafeCell;
4
5const MAX_THREADS: usize = 32;
6
7pub struct Scheduler {
8    threads: [Option<Thread>; MAX_THREADS],
9    current_thread: Option<ThreadId>,
10    next_thread_id: ThreadId,
11    run_queue: [Option<ThreadId>; MAX_THREADS],
12    run_queue_head: usize,
13    run_queue_tail: usize,
14}
15
/// Interior-mutability wrapper that lets the global [`Scheduler`] live in a `static`.
pub struct SchedulerCell(UnsafeCell<Scheduler>);

// SAFETY (review): `SchedulerCell` performs no synchronization of its own, so this
// impl is only sound if every caller of `SchedulerCell::get` guarantees exclusive
// access to the scheduler (presumably by disabling preemption/interrupts around
// scheduler use — TODO confirm against callers).
unsafe impl Sync for SchedulerCell {}
19
20impl Default for SchedulerCell {
21    fn default() -> Self {
22        Self::new()
23    }
24}
25
26impl SchedulerCell {
27    pub const fn new() -> Self {
28        SchedulerCell(UnsafeCell::new(Scheduler::new()))
29    }
30
31    /// # Safety
32    /// Returns mutable reference to scheduler. Caller must ensure thread safety.
33    #[allow(clippy::mut_from_ref)]
34    pub unsafe fn get(&self) -> &mut Scheduler {
35        &mut *self.0.get()
36    }
37}
38
/// Global scheduler instance. Mutable access goes through the `unsafe`
/// [`SchedulerCell::get`], whose caller must guarantee exclusive access.
pub static SCHEDULER: SchedulerCell = SchedulerCell::new();
40
41impl Default for Scheduler {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl Scheduler {
48    pub const fn new() -> Self {
49        Scheduler {
50            threads: [const { None }; MAX_THREADS],
51            current_thread: None,
52            next_thread_id: 0,
53            run_queue: [None; MAX_THREADS],
54            run_queue_head: 0,
55            run_queue_tail: 0,
56        }
57    }
58
59    pub fn spawn_thread(
60        &mut self,
61        stack: &'static mut [u8],
62        entry_point: fn(),
63        priority: u8,
64    ) -> ThreadResult<ThreadId> {
65        let thread_id = self.next_thread_id;
66        
67        if thread_id >= MAX_THREADS {
68            return Err(ThreadError::MaxThreadsReached);
69        }
70
71        let thread = Thread::new(thread_id, stack, entry_point, priority);
72        self.threads[thread_id] = Some(thread);
73        self.next_thread_id += 1;
74
75        self.enqueue_thread(thread_id)?;
76        Ok(thread_id)
77    }
78
79    fn enqueue_thread(&mut self, thread_id: ThreadId) -> ThreadResult<()> {
80        let next_tail = (self.run_queue_tail + 1) % MAX_THREADS;
81        if next_tail == self.run_queue_head {
82            return Err(ThreadError::SchedulerFull);
83        }
84
85        self.run_queue[self.run_queue_tail] = Some(thread_id);
86        self.run_queue_tail = next_tail;
87        Ok(())
88    }
89
90
91    pub fn schedule(&mut self) -> Option<ThreadId> {
92        if let Some(current) = self.current_thread {
93            if let Some(thread) = &self.threads[current] {
94                if thread.state == ThreadState::Running {
95                    let _ = self.enqueue_thread(current);
96                }
97            }
98        }
99
100        self.schedule_with_priority()
101    }
102
103    fn schedule_with_priority(&mut self) -> Option<ThreadId> {
104        let mut best_thread = None;
105        let mut highest_priority = 0u8;
106        let mut best_index = None;
107
108        for i in 0..MAX_THREADS {
109            let queue_index = (self.run_queue_head + i) % MAX_THREADS;
110            if queue_index == self.run_queue_tail {
111                break;
112            }
113
114            if let Some(thread_id) = self.run_queue[queue_index] {
115                if let Some(thread) = &self.threads[thread_id] {
116                    if thread.is_runnable() && thread.priority >= highest_priority {
117                        highest_priority = thread.priority;
118                        best_thread = Some(thread_id);
119                        best_index = Some(queue_index);
120                    }
121                }
122            }
123        }
124
125        if let (Some(thread_id), Some(index)) = (best_thread, best_index) {
126            self.run_queue[index] = None;
127            
128            let mut read_pos = (index + 1) % MAX_THREADS;
129            let mut write_pos = index;
130            
131            while read_pos != self.run_queue_tail {
132                self.run_queue[write_pos] = self.run_queue[read_pos];
133                write_pos = (write_pos + 1) % MAX_THREADS;
134                read_pos = (read_pos + 1) % MAX_THREADS;
135            }
136            
137            self.run_queue_tail = write_pos;
138            
139            return Some(thread_id);
140        }
141
142        None
143    }
144
145    pub fn get_current_thread(&self) -> Option<ThreadId> {
146        self.current_thread
147    }
148
149    pub fn set_current_thread(&mut self, thread_id: Option<ThreadId>) {
150        if let Some(old_id) = self.current_thread {
151            if let Some(thread) = &mut self.threads[old_id] {
152                if thread.state == ThreadState::Running {
153                    thread.state = ThreadState::Ready;
154                }
155            }
156        }
157
158        self.current_thread = thread_id;
159
160        if let Some(new_id) = thread_id {
161            if let Some(thread) = &mut self.threads[new_id] {
162                thread.state = ThreadState::Running;
163            }
164        }
165    }
166
167    pub fn exit_current_thread(&mut self) {
168        if let Some(current) = self.current_thread {
169            let mut waiters_to_wake = [None; 4];
170            
171            if let Some(thread) = &mut self.threads[current] {
172                thread.state = ThreadState::Finished;
173                waiters_to_wake = thread.join_waiters;
174            }
175            
176            // Wake up any threads waiting to join this one
177            for waiter in waiters_to_wake.iter().flatten() {
178                if let Some(waiter_thread) = &mut self.threads[*waiter] {
179                    if waiter_thread.state == ThreadState::Blocked {
180                        waiter_thread.state = ThreadState::Ready;
181                        let _ = self.enqueue_thread(*waiter);
182                    }
183                }
184            }
185        }
186    }
187
188    pub fn join_thread(&mut self, target_id: ThreadId, current_id: ThreadId) -> ThreadResult<()> {
189        if target_id >= MAX_THREADS {
190            return Err(ThreadError::InvalidThreadId);
191        }
192
193        if let Some(target_thread) = &mut self.threads[target_id] {
194            if target_thread.state == ThreadState::Finished {
195                return Ok(()); // Already finished
196            }
197
198            // Add current thread to join waiters
199            for slot in &mut target_thread.join_waiters {
200                if slot.is_none() {
201                    *slot = Some(current_id);
202                    
203                    // Block current thread
204                    if let Some(current_thread) = &mut self.threads[current_id] {
205                        current_thread.state = ThreadState::Blocked;
206                    }
207                    
208                    return Ok(());
209                }
210            }
211            
212            Err(ThreadError::SchedulerFull)
213        } else {
214            Err(ThreadError::InvalidThreadId)
215        }
216    }
217
218    pub fn get_thread(&self, thread_id: ThreadId) -> Option<&Thread> {
219        self.threads[thread_id].as_ref()
220    }
221
222    pub fn get_thread_mut(&mut self, thread_id: ThreadId) -> Option<&mut Thread> {
223        self.threads[thread_id].as_mut()
224    }
225
226    pub fn switch_context(&mut self, from_id: ThreadId, to_id: ThreadId) -> ThreadResult<()> {
227        if let Some(from_thread) = self.get_thread(from_id) {
228            if from_thread.check_stack_overflow() {
229                return Err(ThreadError::StackOverflow);
230            }
231        }
232
233        let from_thread = self.get_thread_mut(from_id);
234        let from_context = if let Some(thread) = from_thread {
235            &mut thread.context as *mut _
236        } else {
237            return Err(ThreadError::InvalidThreadId);
238        };
239
240        let to_thread = self.get_thread_mut(to_id);
241        let to_context = if let Some(thread) = to_thread {
242            &thread.context as *const _
243        } else {
244            return Err(ThreadError::InvalidThreadId);
245        };
246
247        unsafe {
248            crate::context::switch_context(from_context, to_context);
249        }
250        
251        Ok(())
252    }
253}