Skip to main content

kmod_loader/arch/
mod.rs

1#![allow(unused)]
2
// Select exactly one architecture backend at compile time and re-export
// everything it defines, so the rest of the crate can refer to `arch::*`
// without naming the concrete target. Building for any target other than the
// four listed below is rejected with a compile error.
cfg_if::cfg_if! {
    if #[cfg(target_arch = "aarch64")] {
        mod aarch64;
        pub use aarch64::*;
    } else if #[cfg(target_arch = "loongarch64")] {
        mod loongarch64;
        pub use loongarch64::*;
    } else if #[cfg(target_arch = "riscv64")] {
        mod riscv64;
        pub use riscv64::*;
    } else if #[cfg(target_arch = "x86_64")] {
        mod x86_64;
        pub use x86_64::*;
    } else {
        compile_error!("Unsupported architecture");
    }
}
20
// Byte-size constants (powers of two) available to the architecture backends.
/// 128 MiB (`0x0800_0000` bytes).
const SZ_128M: u64 = 0x08000000;
/// 512 KiB (`0x0008_0000` bytes).
const SZ_512K: u64 = 0x00080000;
/// 128 KiB (`0x0002_0000` bytes).
const SZ_128K: u64 = 0x00020000;
/// 2 KiB (`0x0000_0800` bytes).
const SZ_2K: u64 = 0x00000800;
25
/// Sign-extends `value` to 64 bits, treating bit `index` as the sign bit.
///
/// Bits above `index` in `value` are discarded; the result has bit `index`
/// replicated into every higher position.
///
/// * `value` - the raw bits to sign extend.
/// * `index` - zero-based position of the sign bit (`0 <= index < 64`).
pub const fn sign_extend64(value: u64, index: u32) -> i64 {
    // Number of bit positions between the sign bit and the top of the word.
    let headroom = 63 - index;
    // Move the sign bit into bit 63, then let the arithmetic right shift of
    // the signed value smear it back down.
    let top_aligned = (value << headroom) as i64;
    top_aligned >> headroom
}
35
/// Returns the relocation type, i.e. the low 32 bits of an `Elf64_Rela`
/// `r_info` field.
const fn get_rela_type(r_info: u64) -> u32 {
    // A truncating cast keeps exactly the low 32 bits.
    r_info as u32
}
40
/// Returns the symbol table index, i.e. the high 32 bits of an `Elf64_Rela`
/// `r_info` field.
const fn get_rela_sym_idx(r_info: u64) -> usize {
    let high_half = r_info >> 32;
    high_half as usize
}
45
/// A raw 64-bit address with small helpers for typed reads and writes.
///
/// None of the methods validate the address. The caller is responsible for
/// making sure it is mapped, suitably aligned for `T`, and not aliased in a
/// way that conflicts with the access being performed.
#[derive(Debug, Clone, Copy)]
struct Ptr(u64);

impl Ptr {
    /// Reinterprets the stored address as a mutable raw pointer to `T`.
    fn as_ptr<T>(&self) -> *mut T {
        let Ptr(addr) = *self;
        addr as *mut T
    }

    /// Stores `value` at the address, overwriting whatever was there without
    /// reading or dropping it.
    pub fn write<T>(&self, value: T) {
        let dst = self.as_ptr::<T>();
        // SAFETY: not checked here; relies on the caller having supplied an
        // address that is valid and aligned for a write of `T`.
        unsafe { dst.write(value) }
    }

    /// Loads a `T` from the address.
    pub fn read<T>(&self) -> T {
        let src = self.as_ptr::<T>();
        // SAFETY: not checked here; relies on the caller having supplied an
        // address that is valid and aligned for a read of `T`.
        unsafe { src.read() }
    }

    /// Returns a new `Ptr` advanced by `offset` bytes.
    pub fn add(&self, offset: usize) -> Ptr {
        let Ptr(base) = *self;
        Ptr(base + offset as u64)
    }

    /// Views `len` consecutive `T` values starting at the address as a slice.
    pub fn as_slice<T>(&self, len: usize) -> &[T] {
        let first = self.as_ptr::<T>();
        // SAFETY: not checked here; relies on the caller guaranteeing that
        // `len` elements of `T` are readable at this address.
        unsafe { core::slice::from_raw_parts(first, len) }
    }
}
79
/// Expands to a `u32` with only bit number `$bit` set (`1u32 << $bit`).
#[macro_export]
macro_rules! BIT {
    ($bit:expr) => {
        (1_u32 << $bit)
    };
}
86
/// Expands to a `u64` with only bit number `$bit` set (`1u64 << $bit`).
#[macro_export]
macro_rules! BIT_U64 {
    ($bit:expr) => {
        (1_u64 << $bit)
    };
}
93
94#[cfg(any(target_arch = "loongarch64", target_arch = "riscv64"))]
95pub use common::*;
96
97#[cfg(any(target_arch = "loongarch64", target_arch = "riscv64"))]
98mod common {
99    use goblin::elf::{Elf, Reloc, RelocSection, SectionHeaders};
100
101    use crate::{KernelModuleHelper, ModuleErr, ModuleOwner, Result, arch::PltEntry};
    /// Per-module bookkeeping for the three loader-managed ELF sections.
    ///
    /// Each field records which section header is used and how many of its
    /// slots have been filled so far (see [`ModSection`]).
    #[derive(Debug, Clone, Copy, Default)]
    #[repr(C)]
    pub struct ModuleArchSpecific {
        /// The `.got` section, an array of [`GotEntry`].
        got: ModSection,
        /// The `.plt` section, an array of `PltEntry`.
        plt: ModSection,
        /// The companion section of the PLT, an array of [`PltIdxEntry`].
        /// Its section name is supplied by the architecture backend.
        plt_idx: ModSection,
    }
109
    /// Describes one loader-managed section viewed as a fixed-capacity table.
    #[derive(Debug, Clone, Copy, Default)]
    #[repr(C)]
    pub struct ModSection {
        /// Index of the section in the ELF section header table.
        shndx: usize,
        /// Number of table slots that have been written so far.
        num_entries: usize,
        /// Capacity of the table, as counted from the module's relocations.
        max_entries: usize,
    }
117
    /// One slot of the `.got` table: a single 64-bit symbol address.
    #[derive(Debug, Clone, Copy)]
    #[repr(C)]
    pub struct GotEntry {
        /// Address stored in this slot; also the key used to find duplicates.
        symbol_addr: u64,
    }
123
    /// One slot of the PLT companion table: the 64-bit target address that the
    /// PLT entry at the same index was emitted for.
    #[derive(Debug, Clone, Copy)]
    #[repr(C)]
    pub struct PltIdxEntry {
        /// Address stored in this slot; also the key used to find the PLT
        /// entry that already serves a given target.
        symbol_addr: u64,
    }
129
130    pub fn duplicate_rela(rela_sec: &RelocSection, idx: usize) -> bool {
131        let rela_now = rela_sec.get(idx).expect("Invalid relocation index");
132        for i in 0..idx {
133            let rela_prev = rela_sec.get(i).expect("Invalid relocation index");
134            if is_rela_equal(&rela_now, &rela_prev) {
135                return true;
136            }
137        }
138        false
139    }
140
141    fn is_rela_equal(rela1: &Reloc, rela2: &Reloc) -> bool {
142        rela1.r_addend == rela2.r_addend
143            && rela1.r_type == rela2.r_type
144            && rela1.r_sym == rela2.r_sym
145    }
146
147    fn get_got_entry(
148        address: u64,
149        sechdrs: &SectionHeaders,
150        sec: &ModSection,
151    ) -> Option<&'static mut GotEntry> {
152        let got_entries_addr = sechdrs[sec.shndx].sh_addr;
153        let got_entries = unsafe {
154            core::slice::from_raw_parts_mut(
155                got_entries_addr as *mut GotEntry,
156                sec.max_entries as usize,
157            )
158        };
159
160        got_entries[0..sec.num_entries as usize]
161            .iter_mut()
162            .find(|entry| entry.symbol_addr == address)
163    }
164
165    fn get_plt_idx(address: u64, sechdrs: &SectionHeaders, sec: &ModSection) -> Option<usize> {
166        let plt_idx_addr = sechdrs[sec.shndx].sh_addr;
167        let plt_idx_entries = unsafe {
168            core::slice::from_raw_parts_mut(
169                plt_idx_addr as *mut PltIdxEntry,
170                sec.max_entries as usize,
171            )
172        };
173        plt_idx_entries[0..sec.num_entries as usize]
174            .iter()
175            .position(|entry| entry.symbol_addr == address)
176    }
177
178    fn get_plt_entry(
179        address: u64,
180        sechdrs: &SectionHeaders,
181        plt_sec: &ModSection,
182        plt_idx_sec: &ModSection,
183    ) -> Option<&'static mut PltEntry> {
184        let plt_idx = get_plt_idx(address, sechdrs, plt_idx_sec);
185        if plt_idx.is_none() {
186            return None;
187        }
188        let plt_idx = plt_idx.unwrap();
189
190        let plt_entries_addr = sechdrs[plt_sec.shndx].sh_addr;
191        let plt_entries = unsafe {
192            core::slice::from_raw_parts_mut(
193                plt_entries_addr as *mut PltEntry,
194                plt_sec.max_entries as usize,
195            )
196        };
197        Some(&mut plt_entries[plt_idx])
198    }
199
200    fn emit_got_entry(address: u64) -> GotEntry {
201        GotEntry {
202            symbol_addr: address,
203        }
204    }
205
206    fn emit_plt_idx_entry(address: u64) -> PltIdxEntry {
207        PltIdxEntry {
208            symbol_addr: address,
209        }
210    }
211
212    pub fn common_module_emit_got_entry(
213        module: &mut ModuleOwner<impl KernelModuleHelper>,
214        sechdrs: &SectionHeaders,
215        address: u64,
216    ) -> Option<&'static mut GotEntry> {
217        let got_sec = &mut module.arch.got;
218        let idx = got_sec.num_entries;
219        let got = get_got_entry(address, sechdrs, got_sec);
220        if got.is_some() {
221            return got;
222        }
223        // There is no GOT entry for val yet, create a new one.
224        let got_entries_addr = sechdrs[got_sec.shndx].sh_addr;
225        let got_entries = unsafe {
226            core::slice::from_raw_parts_mut(
227                got_entries_addr as *mut GotEntry,
228                got_sec.max_entries as usize,
229            )
230        };
231        got_entries[idx as usize] = emit_got_entry(address);
232        got_sec.num_entries += 1;
233        if got_sec.num_entries > got_sec.max_entries {
234            panic!("{}: GOT entries exceed the maximum limit", module.name());
235        }
236        return Some(&mut got_entries[idx as usize]);
237    }
238
    /// Architecture hook that builds the contents of one PLT entry.
    ///
    /// It receives the branch target `address`, the address at which the PLT
    /// entry itself will live, and the address of the companion slot that
    /// stores the target, and returns the entry to be written.
    type ArchEmitPltEntryFunc =
        fn(address: u64, plt_entry_addr: u64, plt_idx_entry_addr: u64) -> PltEntry;
241
242    pub fn common_module_emit_plt_entry(
243        module: &mut ModuleOwner<impl KernelModuleHelper>,
244        sechdrs: &SectionHeaders,
245        address: u64,
246        arch_emit_plt_entry_func: ArchEmitPltEntryFunc,
247    ) -> Option<&'static mut PltEntry> {
248        let plt_sec = &mut module.arch.plt;
249        let plt_idx_sec = &mut module.arch.plt_idx;
250        let plt = get_plt_entry(address, sechdrs, plt_sec, plt_idx_sec);
251        if plt.is_some() {
252            return plt;
253        }
254        let nr = plt_sec.num_entries;
255        // There is no duplicate entry, create a new one
256        let plt_idx_addr = sechdrs[plt_idx_sec.shndx].sh_addr;
257        let plt_idx_entries = unsafe {
258            core::slice::from_raw_parts_mut(
259                plt_idx_addr as *mut PltIdxEntry,
260                plt_idx_sec.max_entries as usize,
261            )
262        };
263        // write the PLT.IDX(loongarch64)/GOT.PLT(riscv64) entry
264        plt_idx_entries[nr] = emit_plt_idx_entry(address);
265
266        let plt_entries_addr = sechdrs[plt_sec.shndx].sh_addr;
267        let plt_entries = unsafe {
268            core::slice::from_raw_parts_mut(
269                plt_entries_addr as *mut PltEntry,
270                plt_sec.max_entries as usize,
271            )
272        };
273        let plt_entry_addr = &plt_entries[nr] as *const PltEntry as u64;
274        let plt_idx_entry_addr = &plt_idx_entries[nr] as *const PltIdxEntry as u64;
275
276        // write the PLT entry
277        plt_entries[nr] = arch_emit_plt_entry_func(address, plt_entry_addr, plt_idx_entry_addr);
278
279        plt_sec.num_entries += 1;
280        plt_idx_sec.num_entries += 1;
281
282        if plt_sec.num_entries > plt_sec.max_entries {
283            panic!("{}: too many PLT entries", module.name());
284        }
285
286        return Some(&mut plt_entries[nr]);
287    }
288
    /// Architecture hook that inspects one relocation section and returns how
    /// many table slots its relocations need, as `(plt_entries, got_entries)`.
    pub type ArchGotPltCounterFunc = fn(rela_sec: &RelocSection) -> (usize, usize);
290
291    fn check_got_plt<H: KernelModuleHelper>(
292        elf: &mut Elf,
293        owner: &mut ModuleOwner<H>,
294        plt_idx_name: &str,
295    ) -> Result<()> {
296        let mut got_section_idx = None;
297        let mut plt_section_idx = None;
298        let mut plt_idx_section_idx = None;
299        // Find the empty .plt sections.
300        for (idx, shdr) in elf.section_headers.iter_mut().enumerate() {
301            let sec_name = elf.shdr_strtab.get_at(shdr.sh_name).unwrap_or("<unknown>");
302            if sec_name == ".got" {
303                got_section_idx = Some(idx);
304            } else if sec_name == ".plt" {
305                plt_section_idx = Some(idx);
306            } else if sec_name == plt_idx_name {
307                plt_idx_section_idx = Some(idx);
308            }
309        }
310        if got_section_idx.is_none() {
311            log::error!("{:?}: module .GOT section(s) missing", owner.name());
312            return Err(ModuleErr::ENOEXEC);
313        }
314        if plt_section_idx.is_none() {
315            log::error!("{:?}: module .PLT section(s) missing", owner.name());
316            return Err(ModuleErr::ENOEXEC);
317        }
318        if plt_idx_section_idx.is_none() {
319            log::error!(
320                "{:?}: module {} section(s) missing",
321                owner.name(),
322                plt_idx_name.to_uppercase()
323            );
324            return Err(ModuleErr::ENOEXEC);
325        }
326
327        owner.arch.got.shndx = got_section_idx.unwrap();
328        owner.arch.plt.shndx = plt_section_idx.unwrap();
329        owner.arch.plt_idx.shndx = plt_idx_section_idx.unwrap();
330
331        Ok(())
332    }
333
334    pub fn common_module_frob_arch_sections<H: KernelModuleHelper>(
335        elf: &mut Elf,
336        owner: &mut ModuleOwner<H>,
337        got_plt_counter_func: ArchGotPltCounterFunc,
338        plt_idx_name: &str,
339    ) -> Result<()> {
340        let mut num_plts = 0;
341        let mut num_gots = 0;
342        // Calculate the maxinum number of entries
343        for (idx, rela_sec) in elf.shdr_relocs.iter() {
344            let shdr = &elf.section_headers[*idx];
345            if shdr.sh_type != goblin::elf::section_header::SHT_RELA {
346                continue;
347            }
348            let infosec = shdr.sh_info;
349            let to_section = &elf.section_headers[infosec as usize];
350            // ignore relocations that operate on non-exec sections
351            if to_section.sh_flags & goblin::elf::section_header::SHF_EXECINSTR as u64 == 0 {
352                continue;
353            }
354            let (plt_entries, got_entries) = got_plt_counter_func(rela_sec);
355            num_plts += plt_entries;
356            num_gots += got_entries;
357        }
358
359        log::info!(
360            "[{:?}]: Need {} PLT entries and {} GOT entries",
361            owner.name(),
362            num_plts,
363            num_gots
364        );
365        check_got_plt(elf, owner, plt_idx_name)?;
366
367        let got_section_idx = owner.arch.got.shndx;
368        let plt_section_idx = owner.arch.plt.shndx;
369        let plt_idx_section_idx = owner.arch.plt_idx.shndx;
370
371        {
372            let got_sec = &mut elf.section_headers[got_section_idx];
373            got_sec.sh_type = goblin::elf::section_header::SHT_NOBITS;
374            got_sec.sh_flags = goblin::elf::section_header::SHF_ALLOC as u64;
375            got_sec.sh_addralign = 64; // TODO: L1_CACHE_BYTES
376            got_sec.sh_size = (num_gots as u64 + 1) * size_of::<GotEntry>() as u64;
377            owner.arch.got.num_entries = 0;
378            owner.arch.got.max_entries = num_gots;
379        }
380
381        {
382            let plt_sec = &mut elf.section_headers[plt_section_idx];
383            plt_sec.sh_type = goblin::elf::section_header::SHT_PROGBITS;
384            plt_sec.sh_flags = (goblin::elf::section_header::SHF_ALLOC
385                | goblin::elf::section_header::SHF_EXECINSTR) as u64;
386            plt_sec.sh_addralign = 64;
387            plt_sec.sh_size = (num_plts as u64 + 1) * size_of::<PltEntry>() as u64;
388            owner.arch.plt.num_entries = 0;
389            owner.arch.plt.max_entries = num_plts;
390        }
391
392        {
393            let plt_idx_sec = &mut elf.section_headers[plt_idx_section_idx];
394            plt_idx_sec.sh_type = goblin::elf::section_header::SHT_PROGBITS;
395            plt_idx_sec.sh_flags = goblin::elf::section_header::SHF_ALLOC as u64;
396            plt_idx_sec.sh_addralign = 64;
397            plt_idx_sec.sh_size = (num_plts as u64 + 1) * size_of::<PltIdxEntry>() as u64;
398            owner.arch.plt_idx.num_entries = 0;
399            owner.arch.plt_idx.max_entries = num_plts;
400        }
401        Ok(())
402    }
403}