Skip to main content

hiver_runtime/scheduler/
queue.rs

1//! Local task queue for thread-per-core scheduler
2//! thread-per-core调度器的本地任务队列
3
4use std::cell::UnsafeCell;
5use std::mem::MaybeUninit;
6use std::sync::atomic::{AtomicUsize, Ordering};
7
8use super::RawTask;
9
10/// Local queue for thread-per-core scheduler
11/// thread-per-core调度器的本地队列
12///
13/// Uses a bounded ring buffer optimized for single consumer (the scheduler thread)
14/// with support for external producers (work stealing injectors).
15/// Uses interior mutability for thread-safe operations.
16///
17/// 使用为单个消费者(调度器线程)优化的有界环形缓冲区,
18/// 支持外部生产者(工作窃取注入器)。
19/// 使用内部可变性实现线程安全操作。
20pub struct LocalQueue {
21    /// Ring buffer for task pointers / 任务指针的环形缓冲区
22    buffer: Box<[UnsafeCell<MaybeUninit<RawTask>>]>,
23    /// Queue capacity (must be power of 2) / 队列容量(必须是2的幂)
24    capacity: usize,
25    /// Capacity mask for fast modulo / 快速取模的容量掩码
26    mask: usize,
27    /// Head index (consumer) / 头索引(消费者)
28    head: AtomicUsize,
29    /// Tail index (producer) / 尾索引(生产者)
30    tail: AtomicUsize,
31}
32
33// Safety: The queue uses atomic operations for thread safety
34// and UnsafeCell for interior mutability
35// 队列使用原子操作和UnsafeCell实现线程安全
36unsafe impl Send for LocalQueue {}
37unsafe impl Sync for LocalQueue {}
38
39impl LocalQueue {
40    /// Create a new local queue with the specified capacity
41    /// 创建具有指定容量的新本地队列
42    ///
43    /// The capacity will be rounded up to the next power of 2.
44    /// 容量将向上舍入到下一个2的幂。
45    #[must_use]
46    pub fn new(capacity: usize) -> Self {
47        let capacity = capacity.next_power_of_two().max(2);
48        let mask = capacity - 1;
49
50        // Initialize buffer with MaybeUninit (more efficient than Vec<Option>)
51        // 使用MaybeUninit初始化缓冲区(比Vec<Option>更高效)
52        let buffer = (0..capacity)
53            .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
54            .collect();
55
56        Self {
57            buffer,
58            capacity,
59            mask,
60            head: AtomicUsize::new(0),
61            tail: AtomicUsize::new(0),
62        }
63    }
64
65    /// Push a task to the back of the queue
66    /// 将任务推入队列后部
67    ///
68    /// Returns `true` if successful, `false` if the queue is full.
69    /// 成功返回 `true`,队列已满返回 `false`。
70    #[inline]
71    pub fn push(&self, task: RawTask) -> bool {
72        loop {
73            let tail = self.tail.load(Ordering::Relaxed);
74            let head = self.head.load(Ordering::Acquire);
75
76            // Check if queue is full
77            // 检查队列是否已满
78            if tail - head >= self.capacity {
79                return false;
80            }
81
82            let pos = tail & self.mask;
83            // SAFETY: pos is within bounds and we have exclusive access to this slot
84            // 通过环形缓冲区规则对此位置拥有独占访问权
85            unsafe {
86                self.buffer[pos].get().write(MaybeUninit::new(task));
87            }
88
89            if self
90                .tail
91                .compare_exchange(tail, tail + 1, Ordering::AcqRel, Ordering::Relaxed)
92                .is_ok()
93            {
94                return true;
95            }
96        }
97    }
98
99    /// Pop a task from the front of the queue
100    /// 从队列前部弹出一个任务
101    ///
102    /// Returns `Some(task)` if available, `None` if the queue is empty.
103    /// 有可用任务时返回 `Some(task)`,队列为空时返回 `None`。
104    #[inline]
105    pub fn pop(&self) -> Option<RawTask> {
106        loop {
107            let head = self.head.load(Ordering::Relaxed);
108            let tail = self.tail.load(Ordering::Acquire);
109
110            if head == tail {
111                return None;
112            }
113
114            let pos = head & self.mask;
115            // SAFETY: pos is within bounds and we have exclusive access to this slot
116            // The value was initialized by push, so assume_init is safe
117            // 通过环形缓冲区规则对此位置拥有独占访问权
118            // 该值由push初始化,因此assume_init是安全的
119            let task = unsafe { self.buffer[pos].get().read().assume_init() };
120
121            if let Ok(_) =
122                self.head
123                    .compare_exchange(head, head + 1, Ordering::AcqRel, Ordering::Relaxed)
124            {
125                return Some(task);
126            }
127            // Put the task back — another thread claimed this slot
128            unsafe {
129                self.buffer[pos].get().write(MaybeUninit::new(task));
130            }
131        }
132    }
133
134    /// Get the current length of the queue
135    /// 获取队列当前长度
136    #[inline]
137    #[must_use]
138    pub fn len(&self) -> usize {
139        let tail = self.tail.load(Ordering::Relaxed);
140        let head = self.head.load(Ordering::Relaxed);
141        tail.saturating_sub(head)
142    }
143
144    /// Check if the queue is empty
145    /// 检查队列是否为空
146    #[inline]
147    #[must_use]
148    pub fn is_empty(&self) -> bool {
149        self.len() == 0
150    }
151
152    /// Get the queue capacity
153    /// 获取队列容量
154    #[inline]
155    #[must_use]
156    pub const fn capacity(&self) -> usize {
157        self.capacity
158    }
159
160    /// Steal half of the tasks from this queue
161    /// 从此队列窃取一半任务
162    ///
163    /// Used for work stealing between scheduler threads.
164    /// 用于调度器线程间的工作窃取。
165    ///
166    /// Returns the number of tasks stolen.
167    /// 返回被窃取的任务数量。
168    pub fn steal(&self, dest: &LocalQueue) -> usize {
169        let head = self.head.load(Ordering::Relaxed);
170        let tail = self.tail.load(Ordering::Acquire);
171
172        let len = tail.saturating_sub(head);
173        if len == 0 {
174            return 0;
175        }
176
177        // Steal half (rounding down)
178        // 窃取一半(向下取整)
179        let steal_count = len / 2;
180        if steal_count == 0 {
181            return 0;
182        }
183
184        let mut stolen = 0;
185        for i in 0..steal_count {
186            let pos = (head + i) & self.mask;
187            // SAFETY: pos is within bounds and value was initialized by push
188            // SAFETY: pos在范围内且值由push初始化
189            let task = unsafe { self.buffer[pos].get().read().assume_init() };
190
191            if dest.push(task) {
192                stolen += 1;
193                // Update head to reflect stolen tasks
194                // 更新head以反映被窃取的任务
195                self.head.store(head + i + 1, Ordering::Release);
196            } else {
197                // Destination full, put back remaining
198                // 目标已满,放回剩余任务
199                break;
200            }
201        }
202
203        stolen
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Tasks come back out in the order they were pushed.
    #[test]
    fn test_queue_push_pop() {
        let queue = LocalQueue::new(16);

        let task1 = 0x1000 as RawTask;
        let task2 = 0x2000 as RawTask;

        assert!(queue.push(task1));
        assert!(queue.push(task2));

        assert_eq!(queue.pop(), Some(task1));
        assert_eq!(queue.pop(), Some(task2));
        assert_eq!(queue.pop(), None);
    }

    /// `len`/`is_empty` track the contents and a full queue rejects pushes.
    #[test]
    fn test_queue_empty_full() {
        let queue = LocalQueue::new(4);

        assert!(queue.is_empty());
        assert_eq!(queue.len(), 0);

        // Fill the queue
        for i in 0..4 {
            assert!(queue.push(i as RawTask));
        }

        // Queue should be full now
        assert!(!queue.push(99 as RawTask));
        assert_eq!(queue.len(), 4);

        // Empty the queue
        for i in 0..4 {
            assert_eq!(queue.pop(), Some(i as RawTask));
        }

        assert!(queue.is_empty());
    }

    /// Positions keep mapping onto the right slots after the ring wraps.
    #[test]
    fn test_queue_wrap_around() {
        let queue = LocalQueue::new(4);

        // Fill and drain multiple times to test wrap-around
        for round in 0..3 {
            for i in 0..4 {
                assert!(queue.push((round * 4 + i) as RawTask));
            }

            for i in 0..4 {
                assert_eq!(queue.pop(), Some((round * 4 + i) as RawTask));
            }
        }
    }

    /// Requested capacities are rounded up to the next power of 2.
    #[test]
    fn test_queue_capacity_power_of_two() {
        let q = LocalQueue::new(5);
        assert_eq!(q.capacity(), 8);

        let q = LocalQueue::new(100);
        assert_eq!(q.capacity(), 128);
    }

    /// Tiny requests are clamped to the minimum capacity of 2.
    #[test]
    fn test_queue_minimum_capacity() {
        assert_eq!(LocalQueue::new(0).capacity(), 2);
        assert_eq!(LocalQueue::new(1).capacity(), 2);
    }

    /// `steal` moves the oldest half to the destination and keeps the rest.
    #[test]
    fn test_queue_steal_half() {
        let src = LocalQueue::new(8);
        let dst = LocalQueue::new(8);

        for i in 1..=6usize {
            assert!(src.push(i as RawTask));
        }

        assert_eq!(src.steal(&dst), 3);
        assert_eq!(src.len(), 3);
        assert_eq!(dst.len(), 3);

        for i in 1..=3usize {
            assert_eq!(dst.pop(), Some(i as RawTask));
        }
        for i in 4..=6usize {
            assert_eq!(src.pop(), Some(i as RawTask));
        }
    }

    /// Half is rounded down, so empty and single-task queues yield nothing.
    #[test]
    fn test_queue_steal_from_short_queue() {
        let src = LocalQueue::new(4);
        let dst = LocalQueue::new(4);

        assert_eq!(src.steal(&dst), 0);

        assert!(src.push(7usize as RawTask));
        assert_eq!(src.steal(&dst), 0);
        assert!(dst.is_empty());
        assert_eq!(src.pop(), Some(7usize as RawTask));
    }

    /// A full destination receives nothing and the source loses nothing.
    #[test]
    fn test_queue_steal_into_full_destination() {
        let src = LocalQueue::new(4);
        let dst = LocalQueue::new(2);

        for i in 1..=4usize {
            assert!(src.push(i as RawTask));
        }
        assert!(dst.push(10usize as RawTask));
        assert!(dst.push(11usize as RawTask));

        assert_eq!(src.steal(&dst), 0);
        assert_eq!(src.len(), 4);

        assert_eq!(dst.pop(), Some(10usize as RawTask));
        assert_eq!(dst.pop(), Some(11usize as RawTask));
        assert_eq!(dst.pop(), None);

        for i in 1..=4usize {
            assert_eq!(src.pop(), Some(i as RawTask));
        }
    }
}