use crate::{build_flags, parse_flags, Sv39, Sv39Manager};
use alloc::alloc::alloc_zeroed;
use alloc::collections::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use core::alloc::Layout;
use tg_console::log;
use tg_kernel_context::{foreign::ForeignContext, LocalContext};
use tg_kernel_vm::{
page_table::{MmuMeta, VAddr, VmFlags, PPN, VPN},
AddressSpace,
};
use xmas_elf::{
header::{self, HeaderPt2, Machine},
program, ElfFile,
};
/// Number of per-process syscall counters kept in `Process::syscall_times`.
///
/// Syscall ids greater than or equal to this value are not counted by
/// `Process::record_syscall` and are reported as `0` by `Process::syscall_times`.
const MAX_SYSCALL_NUM: usize = 512;
/// A user process: its saved execution context, its address space and the
/// bookkeeping needed for the program break and syscall statistics.
pub struct Process {
/// User context (created at the ELF entry point) paired with the `satp`
/// value that selects this process's page table.
pub context: ForeignContext,
/// Sv39 address space holding the ELF load segments, the heap and the user stack.
pub address_space: AddressSpace<Sv39, Sv39Manager>,
/// Page-aligned address just past the highest load segment; the program
/// break is never allowed to drop below it.
pub heap_bottom: usize,
/// Current program break; starts out equal to `heap_bottom`.
pub program_brk: usize,
/// Call counters indexed by syscall id, created with `MAX_SYSCALL_NUM` entries.
syscall_times: Vec<usize>,
}
impl Process {
/// Builds a user process from a parsed ELF image.
///
/// Every `LOAD` segment is mapped into a fresh Sv39 address space. Segments
/// may share a page; the permissions of such a page are the union of the
/// permissions of all segments touching it. After the segments are copied
/// in, a zeroed two-page user stack is mapped at the top of the lower half
/// of the address space and the initial context is set up to start at the
/// ELF entry point.
///
/// Returns `None` when the image cannot be turned into a process:
/// * it is not a 64-bit RISC-V executable,
/// * a `LOAD` segment is malformed (address range overflows, file offset
///   and virtual address disagree modulo the page size, file size exceeds
///   memory size, or the file range lies outside the image),
/// * the user stack cannot be allocated.
pub fn new(elf: ElfFile) -> Option<Self> {
    let entry = match elf.header.pt2 {
        HeaderPt2::Header64(pt2)
            if pt2.type_.as_type() == header::Type::Executable
                && pt2.machine.as_machine() == Machine::RISC_V =>
        {
            pt2.entry_point as usize
        }
        _ => return None,
    };

    const PAGE_SIZE: usize = 1 << Sv39::PAGE_BITS;
    const PAGE_MASK: usize = PAGE_SIZE - 1;

    let mut address_space = AddressSpace::new();
    let mut max_end_va: usize = 0;
    // Permissions per virtual page, merged across all segments touching it.
    let mut page_flags = BTreeMap::<VPN<Sv39>, VmFlags<Sv39>>::new();
    // (destination virtual address, file bytes) for every LOAD segment.
    let mut load_segments = Vec::new();

    for program in elf.program_iter() {
        if !matches!(program.get_type(), Ok(program::Type::Load)) {
            continue;
        }

        let off_file = program.offset() as usize;
        let len_file = program.file_size() as usize;
        let off_mem = program.virtual_addr() as usize;
        let len_mem = program.mem_size() as usize;

        // Reject malformed segments instead of panicking inside the kernel.
        let end_mem = off_mem.checked_add(len_mem)?;
        if (off_file & PAGE_MASK) != (off_mem & PAGE_MASK) || len_file > len_mem {
            return None;
        }
        let data = elf.input.get(off_file..)?.get(..len_file)?;

        max_end_va = max_end_va.max(end_mem);

        // Flag string layout expected by `parse_flags`: U, X, W, R, V.
        let mut flags: [u8; 5] = *b"U___V";
        if program.flags().is_execute() {
            flags[1] = b'X';
        }
        if program.flags().is_write() {
            flags[2] = b'W';
        }
        if program.flags().is_read() {
            flags[3] = b'R';
        }
        let flags = core::str::from_utf8(&flags).expect("flag bytes are ASCII");
        let flags = parse_flags(flags).expect("flag string is well-formed");

        let vpn_range = VAddr::new(off_mem).floor()..VAddr::new(end_mem).ceil();
        let mut vpn = vpn_range.start;
        while vpn < vpn_range.end {
            page_flags
                .entry(vpn)
                .and_modify(|page_flag| *page_flag |= flags)
                .or_insert(flags);
            vpn = vpn + 1;
        }

        load_segments.push((off_mem, data));
    }

    // Map each page exactly once, with its merged permissions and no
    // initial contents; the segment bytes are copied in below.
    for (vpn, flags) in page_flags {
        address_space.map(vpn..vpn + 1, &[], 0, flags);
    }

    // Copy the file contents page by page, since consecutive virtual pages
    // are translated individually.
    for (start, data) in load_segments {
        let mut copied = 0usize;
        while copied < data.len() {
            let addr = VAddr::<Sv39>::new(start + copied);
            let copy_len = (PAGE_SIZE - addr.offset()).min(data.len() - copied);
            let mut ptr = address_space
                .translate::<u8>(addr, VmFlags::VALID)
                .expect("pages of a load segment were mapped above");
            // SAFETY: `ptr` comes from translating a page mapped above, and
            // `copy_len` never crosses the end of that page, so the
            // destination range stays inside it.
            unsafe {
                core::slice::from_raw_parts_mut(ptr.as_mut(), copy_len)
                    .copy_from_slice(&data[copied..copied + copy_len]);
            }
            copied += copy_len;
        }
    }

    let heap_bottom = VAddr::<Sv39>::new(max_end_va).ceil().base().val();

    // Two zeroed, page-aligned pages for the user stack.
    let stack_layout =
        Layout::from_size_align(2 << Sv39::PAGE_BITS, 1 << Sv39::PAGE_BITS).ok()?;
    // SAFETY: `stack_layout` has a non-zero size.
    let stack = unsafe { alloc_zeroed(stack_layout) };
    if stack.is_null() {
        // Without this check a failed allocation would map physical page 0
        // as the user stack.
        return None;
    }
    address_space.map_extern(
        VPN::new((1 << 26) - 2)..VPN::new(1 << 26),
        PPN::new(stack as usize >> Sv39::PAGE_BITS),
        build_flags("U_WRV"),
    );

    log::info!(
        "process entry = {:#x}, heap_bottom = {:#x}",
        entry,
        heap_bottom
    );

    let mut context = LocalContext::user(entry);
    // satp MODE field 8 selects Sv39 (RISC-V privileged specification).
    let satp = (8 << 60) | address_space.root_ppn().val();
    // The stack pointer starts at the end of the stack mapping (VPN 1 << 26).
    *context.sp_mut() = 1 << 38;

    Some(Self {
        context: ForeignContext { context, satp },
        address_space,
        heap_bottom,
        program_brk: heap_bottom,
        syscall_times: vec![0; MAX_SYSCALL_NUM],
    })
}
pub fn change_program_brk(&mut self, size: isize) -> Option<usize> {
let old_brk = self.program_brk;
let new_brk = self.program_brk as isize + size;
if new_brk < self.heap_bottom as isize {
return None;
}
let new_brk = new_brk as usize;
let old_brk_ceil = VAddr::<Sv39>::new(old_brk).ceil();
let new_brk_ceil = VAddr::<Sv39>::new(new_brk).ceil();
if size > 0 {
if new_brk_ceil.val() > old_brk_ceil.val() {
self.address_space
.map(old_brk_ceil..new_brk_ceil, &[], 0, build_flags("U_WRV"));
}
} else if size < 0 {
if old_brk_ceil.val() > new_brk_ceil.val() {
self.address_space.unmap(new_brk_ceil..old_brk_ceil);
}
}
self.program_brk = new_brk;
Some(old_brk)
}
/// Increments the call counter for `syscall_id`.
///
/// Ids at or beyond `MAX_SYSCALL_NUM` are silently ignored.
pub fn record_syscall(&mut self, syscall_id: usize) {
    if syscall_id >= MAX_SYSCALL_NUM {
        return;
    }
    let counter = &mut self.syscall_times[syscall_id];
    *counter += 1;
}
/// Returns how many times `syscall_id` has been recorded for this process.
///
/// Ids at or beyond `MAX_SYSCALL_NUM` always report `0`.
pub fn syscall_times(&self, syscall_id: usize) -> usize {
    match syscall_id {
        id if id < MAX_SYSCALL_NUM => self.syscall_times[id],
        _ => 0,
    }
}
}