tg_kernel_context/
lib.rs

1//! 内核上下文控制。
2
3#![no_std]
4// #![deny(warnings)]
5#![deny(missing_docs)]
6
7/// 不同地址空间的上下文控制。
8#[cfg(feature = "foreign")]
9pub mod foreign;
10
11/// 线程上下文。
12#[derive(Clone)]
13#[repr(C)]
14pub struct LocalContext {
15    sctx: usize,
16    x: [usize; 31],
17    sepc: usize,
18    /// 是否以特权态切换。
19    pub supervisor: bool,
20    /// 线程中断是否开启。
21    pub interrupt: bool,
22}
23
24impl LocalContext {
25    /// 创建空白上下文。
26    #[inline]
27    pub const fn empty() -> Self {
28        Self {
29            sctx: 0,
30            x: [0; 31],
31            supervisor: false,
32            interrupt: false,
33            sepc: 0,
34        }
35    }
36
37    /// 初始化指定入口的用户上下文。
38    ///
39    /// 切换到用户态时会打开内核中断。
40    #[inline]
41    pub const fn user(pc: usize) -> Self {
42        Self {
43            sctx: 0,
44            x: [0; 31],
45            supervisor: false,
46            interrupt: true,
47            sepc: pc,
48        }
49    }
50
51    /// 初始化指定入口的内核上下文。
52    #[inline]
53    pub const fn thread(pc: usize, interrupt: bool) -> Self {
54        Self {
55            sctx: 0,
56            x: [0; 31],
57            supervisor: true,
58            interrupt,
59            sepc: pc,
60        }
61    }
62
63    /// 读取用户通用寄存器。
64    #[inline]
65    pub fn x(&self, n: usize) -> usize {
66        self.x[n - 1]
67    }
68
69    /// 修改用户通用寄存器。
70    #[inline]
71    pub fn x_mut(&mut self, n: usize) -> &mut usize {
72        &mut self.x[n - 1]
73    }
74
75    /// 读取用户参数寄存器。
76    #[inline]
77    pub fn a(&self, n: usize) -> usize {
78        self.x(n + 10)
79    }
80
81    /// 修改用户参数寄存器。
82    #[inline]
83    pub fn a_mut(&mut self, n: usize) -> &mut usize {
84        self.x_mut(n + 10)
85    }
86
87    /// 读取用户栈指针。
88    #[inline]
89    pub fn ra(&self) -> usize {
90        self.x(1)
91    }
92
93    /// 读取用户栈指针。
94    #[inline]
95    pub fn sp(&self) -> usize {
96        self.x(2)
97    }
98
99    /// 修改用户栈指针。
100    #[inline]
101    pub fn sp_mut(&mut self) -> &mut usize {
102        self.x_mut(2)
103    }
104
105    /// 当前上下文的 pc。
106    #[inline]
107    pub fn pc(&self) -> usize {
108        self.sepc
109    }
110
111    /// 修改上下文的 pc。
112    #[inline]
113    pub fn pc_mut(&mut self) -> &mut usize {
114        &mut self.sepc
115    }
116
117    /// 将 pc 移至下一条指令。
118    ///
119    /// # Notice
120    ///
121    /// 假设这一条指令不是压缩版本。
122    #[inline]
123    pub fn move_next(&mut self) {
124        self.sepc = self.sepc.wrapping_add(4);
125    }
126
127    /// 执行此线程,并返回 `sstatus`。
128    ///
129    /// # Safety
130    ///
131    /// 将修改 `sscratch`、`sepc`、`sstatus` 和 `stvec`。
132    /// 调用者需要确保:
133    /// - 当前处于 S 模式
134    /// - `stvec` 可以被安全地修改
135    /// - 上下文中的 `sepc` 指向有效的代码地址
136    #[inline(never)]
137    pub unsafe fn execute(&mut self) -> usize {
138        #[cfg(target_arch = "riscv64")]
139        {
140            let mut sstatus = build_sstatus(self.supervisor, self.interrupt);
141            // 保存 self 指针和 sepc,避免 release 模式下 csrrw 破坏寄存器后的问题
142            let ctx_ptr = self as *mut Self;
143            let mut sepc = self.sepc;
144            let old_sscratch: usize;
145            // SAFETY: 内联汇编执行上下文切换,调用者已确保处于 S 模式且 CSR 可被修改
146            core::arch::asm!(
147                "   csrrw {old_ss}, sscratch, {ctx}
148                    csrw  sepc    , {sepc}
149                    csrw  sstatus , {sstatus}
150                    addi  sp, sp, -8
151                    sd    ra, (sp)
152                    call  {execute_naked}
153                    ld    ra, (sp)
154                    addi  sp, sp,  8
155                    csrw  sscratch, {old_ss}
156                    csrr  {sepc}   , sepc
157                    csrr  {sstatus}, sstatus
158                ",
159                ctx           = in       (reg) ctx_ptr,
160                old_ss        = out      (reg) old_sscratch,
161                sepc          = inlateout(reg) sepc,
162                sstatus       = inlateout(reg) sstatus,
163                execute_naked = sym execute_naked,
164            );
165            let _ = old_sscratch; // suppress unused warning
166            (*ctx_ptr).sepc = sepc;
167            sstatus
168        }
169        #[cfg(not(target_arch = "riscv64"))]
170        unimplemented!("LocalContext::execute() is only supported on riscv64")
171    }
172}
173
174#[cfg(target_arch = "riscv64")]
175#[inline]
176fn build_sstatus(supervisor: bool, interrupt: bool) -> usize {
177    let mut sstatus: usize;
178    // SAFETY: 只是读取 sstatus CSR,不会产生副作用
179    unsafe { core::arch::asm!("csrr {}, sstatus", out(reg) sstatus) };
180    const PREVILEGE_BIT: usize = 1 << 8;
181    const INTERRUPT_BIT: usize = 1 << 5;
182    match supervisor {
183        false => sstatus &= !PREVILEGE_BIT,
184        true => sstatus |= PREVILEGE_BIT,
185    }
186    match interrupt {
187        false => sstatus &= !INTERRUPT_BIT,
188        true => sstatus |= INTERRUPT_BIT,
189    }
190    sstatus
191}
192
193#[cfg(not(target_arch = "riscv64"))]
194#[allow(dead_code)]
195#[inline]
196fn build_sstatus(_supervisor: bool, _interrupt: bool) -> usize {
197    unimplemented!("build_sstatus() is only supported on riscv64")
198}
199
200/// 线程切换核心部分。
201///
202/// 通用寄存器压栈,然后从预存在 `sscratch` 里的上下文指针恢复线程通用寄存器。
203///
204/// # Safety
205///
206/// 这是一个裸函数,只能由 `LocalContext::execute()` 调用。
207/// 调用前必须确保:
208/// - `sscratch` 中存放了有效的 `LocalContext` 指针
209/// - `sepc` 和 `sstatus` 已正确设置
210/// - 栈指针有效且有足够空间保存寄存器
211#[cfg(target_arch = "riscv64")]
212#[unsafe(naked)]
213unsafe extern "C" fn execute_naked() {
214    core::arch::naked_asm!(
215        r"  .altmacro
216            .macro SAVE n
217                sd x\n, \n*8(sp)
218            .endm
219            .macro SAVE_ALL
220                sd x1, 1*8(sp)
221                .set n, 3
222                .rept 29
223                    SAVE %n
224                    .set n, n+1
225                .endr
226            .endm
227
228            .macro LOAD n
229                ld x\n, \n*8(sp)
230            .endm
231            .macro LOAD_ALL
232                ld x1, 1*8(sp)
233                .set n, 3
234                .rept 29
235                    LOAD %n
236                    .set n, n+1
237                .endr
238            .endm
239        ",
240        // 位置无关加载
241        "   .option push
242            .option nopic
243        ",
244        // 保存调度上下文
245        "   addi sp, sp, -32*8
246            SAVE_ALL
247        ",
248        // 设置陷入入口
249        "   la   t0, 1f
250            csrw stvec, t0
251        ",
252        // 保存调度上下文地址并切换上下文
253        "   csrr t0, sscratch
254            sd   sp, (t0)
255            mv   sp, t0
256        ",
257        // 恢复线程上下文
258        "   LOAD_ALL
259            ld   sp, 2*8(sp)
260        ",
261        // 执行线程
262        "   sret",
263        // 陷入
264        "   .align 2",
265        // 切换上下文
266        "1: csrrw sp, sscratch, sp",
267        // 保存线程上下文
268        "   SAVE_ALL
269            csrrw t0, sscratch, sp
270            sd    t0, 2*8(sp)
271        ",
272        // 切换上下文
273        "   ld sp, (sp)",
274        // 恢复调度上下文
275        "   LOAD_ALL
276            addi sp, sp, 32*8
277        ",
278        // 返回调度
279        "   ret",
280        "   .option pop",
281    )
282}
283
284#[cfg(not(target_arch = "riscv64"))]
285#[allow(dead_code)]
286unsafe extern "C" fn execute_naked() {
287    unimplemented!("execute_naked() is only supported on riscv64")
288}