Skip to main content

tg_kernel_context/
lib.rs

1//! 内核上下文控制。
2//!
3//! 教程阅读建议:
4//!
5//! 1. 先看 [`LocalContext`] 字段与 `user/thread` 构造;
6//! 2. 再看 `execute()`:理解 Rust 侧如何准备 CSR 和跳入裸汇编;
7//! 3. 最后看 `execute_naked()`:理解“保存调度上下文 <-> 恢复线程上下文”的对称流程。
8
9#![no_std]
10// #![deny(warnings)]
11#![deny(missing_docs)]
12
13/// 不同地址空间的上下文控制。
14#[cfg(feature = "foreign")]
15pub mod foreign;
16
17/// 线程上下文。
18#[derive(Clone)]
19#[repr(C)]
20pub struct LocalContext {
21    /// 调度上下文保存区指针(由裸汇编切换时使用)。
22    sctx: usize,
23    /// 通用寄存器 x1..x31 的镜像(x0 恒为 0,不保存)。
24    x: [usize; 31],
25    /// 返回用户/内核线程时的 PC(对应 sepc)。
26    sepc: usize,
27    /// 是否以特权态切换。
28    pub supervisor: bool,
29    /// 线程中断是否开启。
30    pub interrupt: bool,
31}
32
33impl LocalContext {
34    /// 创建空白上下文。
35    #[inline]
36    pub const fn empty() -> Self {
37        Self {
38            sctx: 0,
39            x: [0; 31],
40            supervisor: false,
41            interrupt: false,
42            sepc: 0,
43        }
44    }
45
46    /// 初始化指定入口的用户上下文。
47    ///
48    /// 切换到用户态时会打开内核中断。
49    #[inline]
50    pub const fn user(pc: usize) -> Self {
51        Self {
52            sctx: 0,
53            x: [0; 31],
54            supervisor: false,
55            interrupt: true,
56            sepc: pc,
57        }
58    }
59
60    /// 初始化指定入口的内核上下文。
61    #[inline]
62    pub const fn thread(pc: usize, interrupt: bool) -> Self {
63        Self {
64            sctx: 0,
65            x: [0; 31],
66            supervisor: true,
67            interrupt,
68            sepc: pc,
69        }
70    }
71
72    /// 读取用户通用寄存器。
73    #[inline]
74    pub fn x(&self, n: usize) -> usize {
75        self.x[n - 1]
76    }
77
78    /// 修改用户通用寄存器。
79    #[inline]
80    pub fn x_mut(&mut self, n: usize) -> &mut usize {
81        &mut self.x[n - 1]
82    }
83
84    /// 读取用户参数寄存器。
85    #[inline]
86    pub fn a(&self, n: usize) -> usize {
87        self.x(n + 10)
88    }
89
90    /// 修改用户参数寄存器。
91    #[inline]
92    pub fn a_mut(&mut self, n: usize) -> &mut usize {
93        self.x_mut(n + 10)
94    }
95
96    /// 读取用户栈指针。
97    #[inline]
98    pub fn ra(&self) -> usize {
99        self.x(1)
100    }
101
102    /// 读取用户栈指针。
103    #[inline]
104    pub fn sp(&self) -> usize {
105        self.x(2)
106    }
107
108    /// 修改用户栈指针。
109    #[inline]
110    pub fn sp_mut(&mut self) -> &mut usize {
111        self.x_mut(2)
112    }
113
114    /// 当前上下文的 pc。
115    #[inline]
116    pub fn pc(&self) -> usize {
117        self.sepc
118    }
119
120    /// 修改上下文的 pc。
121    #[inline]
122    pub fn pc_mut(&mut self) -> &mut usize {
123        &mut self.sepc
124    }
125
126    /// 将 pc 移至下一条指令。
127    ///
128    /// # Notice
129    ///
130    /// 假设这一条指令不是压缩版本。
131    #[inline]
132    pub fn move_next(&mut self) {
133        self.sepc = self.sepc.wrapping_add(4);
134    }
135
136    /// 执行此线程,并返回 `sstatus`。
137    ///
138    /// # Safety
139    ///
140    /// 将修改 `sscratch`、`sepc`、`sstatus` 和 `stvec`。
141    /// 调用者需要确保:
142    /// - 当前处于 S 模式
143    /// - `stvec` 可以被安全地修改
144    /// - 上下文中的 `sepc` 指向有效的代码地址
145    #[inline(never)]
146    pub unsafe fn execute(&mut self) -> usize {
147        #[cfg(target_arch = "riscv64")]
148        {
149            // 第一步:根据目标线程属性构造 sstatus(SPP/SPIE)。
150            let mut sstatus = build_sstatus(self.supervisor, self.interrupt);
151            // 保存 self 指针和 sepc,避免 release 模式下 csrrw 破坏寄存器后的问题
152            let ctx_ptr = self as *mut Self;
153            let mut sepc = self.sepc;
154            let old_sscratch: usize;
155            // 第二步:切换到 execute_naked,执行真正的上下文保存/恢复。
156            // SAFETY: 内联汇编执行上下文切换,调用者已确保处于 S 模式且 CSR 可被修改
157            core::arch::asm!(
158                "   csrrw {old_ss}, sscratch, {ctx}
159                    csrw  sepc    , {sepc}
160                    csrw  sstatus , {sstatus}
161                    addi  sp, sp, -8
162                    sd    ra, (sp)
163                    call  {execute_naked}
164                    ld    ra, (sp)
165                    addi  sp, sp,  8
166                    csrw  sscratch, {old_ss}
167                    csrr  {sepc}   , sepc
168                    csrr  {sstatus}, sstatus
169                ",
170                ctx           = in       (reg) ctx_ptr,
171                old_ss        = out      (reg) old_sscratch,
172                sepc          = inlateout(reg) sepc,
173                sstatus       = inlateout(reg) sstatus,
174                execute_naked = sym execute_naked,
175            );
176            let _ = old_sscratch; // suppress unused warning
177            // 第三步:取回线程返回后的 sepc(比如 trap 后已更新到下一条指令)。
178            (*ctx_ptr).sepc = sepc;
179            sstatus
180        }
181        #[cfg(not(target_arch = "riscv64"))]
182        unimplemented!("LocalContext::execute() is only supported on riscv64")
183    }
184}
185
186#[cfg(target_arch = "riscv64")]
187#[inline]
188fn build_sstatus(supervisor: bool, interrupt: bool) -> usize {
189    let mut sstatus: usize;
190    // SAFETY: 只是读取 sstatus CSR,不会产生副作用
191    unsafe { core::arch::asm!("csrr {}, sstatus", out(reg) sstatus) };
192    const PREVILEGE_BIT: usize = 1 << 8;
193    const INTERRUPT_BIT: usize = 1 << 5;
194    match supervisor {
195        false => sstatus &= !PREVILEGE_BIT,
196        true => sstatus |= PREVILEGE_BIT,
197    }
198    match interrupt {
199        false => sstatus &= !INTERRUPT_BIT,
200        true => sstatus |= INTERRUPT_BIT,
201    }
202    sstatus
203}
204
205#[cfg(not(target_arch = "riscv64"))]
206#[allow(dead_code)]
207#[inline]
208fn build_sstatus(_supervisor: bool, _interrupt: bool) -> usize {
209    unimplemented!("build_sstatus() is only supported on riscv64")
210}
211
212/// 线程切换核心部分。
213///
214/// 通用寄存器压栈,然后从预存在 `sscratch` 里的上下文指针恢复线程通用寄存器。
215///
216/// # Safety
217///
218/// 这是一个裸函数,只能由 `LocalContext::execute()` 调用。
219/// 调用前必须确保:
220/// - `sscratch` 中存放了有效的 `LocalContext` 指针
221/// - `sepc` 和 `sstatus` 已正确设置
222/// - 栈指针有效且有足够空间保存寄存器
223#[cfg(target_arch = "riscv64")]
224#[unsafe(naked)]
225unsafe extern "C" fn execute_naked() {
226    core::arch::naked_asm!(
227        r"  .altmacro
228            .macro SAVE n
229                sd x\n, \n*8(sp)
230            .endm
231            .macro SAVE_ALL
232                sd x1, 1*8(sp)
233                .set n, 3
234                .rept 29
235                    SAVE %n
236                    .set n, n+1
237                .endr
238            .endm
239
240            .macro LOAD n
241                ld x\n, \n*8(sp)
242            .endm
243            .macro LOAD_ALL
244                ld x1, 1*8(sp)
245                .set n, 3
246                .rept 29
247                    LOAD %n
248                    .set n, n+1
249                .endr
250            .endm
251        ",
252        // 位置无关加载
253        "   .option push
254            .option nopic
255        ",
256        // 保存调度上下文
257        "   addi sp, sp, -32*8
258            SAVE_ALL
259        ",
260        // 设置陷入入口
261        "   la   t0, 1f
262            csrw stvec, t0
263        ",
264        // 保存调度上下文地址并切换上下文
265        "   csrr t0, sscratch
266            sd   sp, (t0)
267            mv   sp, t0
268        ",
269        // 恢复线程上下文
270        "   LOAD_ALL
271            ld   sp, 2*8(sp)
272        ",
273        // 执行线程
274        "   sret",
275        // 陷入
276        "   .align 2",
277        // 切换上下文
278        "1: csrrw sp, sscratch, sp",
279        // 保存线程上下文
280        "   SAVE_ALL
281            csrrw t0, sscratch, sp
282            sd    t0, 2*8(sp)
283        ",
284        // 切换上下文
285        "   ld sp, (sp)",
286        // 恢复调度上下文
287        "   LOAD_ALL
288            addi sp, sp, 32*8
289        ",
290        // 返回调度
291        "   ret",
292        "   .option pop",
293    )
294}
295
296#[cfg(not(target_arch = "riscv64"))]
297#[allow(dead_code)]
298unsafe extern "C" fn execute_naked() {
299    unimplemented!("execute_naked() is only supported on riscv64")
300}