use crate::{
PROCESSOR, Sv39, Sv39Manager, build_flags, fs::Fd, map_portal, parse_flags,
processor::ProcessorInner,
};
use alloc::{alloc::alloc_zeroed, boxed::Box, collections::BTreeMap, sync::Arc, vec, vec::Vec};
use core::alloc::Layout;
use spin::Mutex;
use tg_kernel_context::{LocalContext, foreign::ForeignContext};
use tg_kernel_vm::{
AddressSpace,
page_table::{MmuMeta, PPN, VAddr, VPN},
};
use tg_signal::Signal;
use tg_signal_impl::SignalImpl;
use tg_sync::{Condvar, Mutex as MutexTrait, Semaphore};
use tg_task_manage::{ProcId, ThreadId};
use xmas_elf::{
ElfFile,
header::{self, HeaderPt2, Machine},
program,
};
/// A user thread: its thread id plus the saved user context paired with the
/// `satp` value it must run under.
pub struct Thread {
    /// Identifier allocated by `ThreadId::new()` in [`Thread::new`].
    pub tid: ThreadId,
    /// Saved register context together with the `satp` of its address space.
    pub context: ForeignContext,
}
impl Thread {
    /// Builds a thread with a freshly allocated id that will run `context`
    /// under the address space selected by `satp`.
    pub fn new(satp: usize, context: LocalContext) -> Self {
        let tid = ThreadId::new();
        let context = ForeignContext { context, satp };
        Self { tid, context }
    }
}
/// A user process: its address space, open files, signal state, heap bounds,
/// and the per-process synchronization objects with their deadlock bookkeeping.
pub struct Process {
    /// Process identifier.
    pub pid: ProcId,
    /// Sv39 page tables and mappings of this process.
    pub address_space: AddressSpace<Sv39, Sv39Manager>,
    /// File descriptor table indexed by fd number; `None` presumably marks a
    /// closed/free slot — confirm against the fs syscalls.
    pub fd_table: Vec<Option<Mutex<Fd>>>,
    /// Per-process signal state.
    pub signal: Box<dyn Signal>,
    /// Lowest heap address: the page-aligned end of the highest loaded segment.
    pub heap_bottom: usize,
    /// Current program break; `change_program_brk` never lets it go below
    /// `heap_bottom`.
    pub program_brk: usize,
    /// Semaphores owned by this process (`None` presumably = freed slot).
    pub semaphore_list: Vec<Option<Arc<Semaphore>>>,
    /// Mutexes owned by this process (`None` presumably = freed slot).
    pub mutex_list: Vec<Option<Arc<dyn MutexTrait>>>,
    /// Condition variables owned by this process (`None` presumably = freed slot).
    pub condvar_list: Vec<Option<Arc<Condvar>>>,
    /// Deadlock-detection bookkeeping for this process's semaphores and mutexes.
    pub deadlock: DeadlockState,
}
/// Bookkeeping used to detect deadlocks: a banker's-style safety check for
/// semaphores and a wait-for chain walk for mutexes.
#[derive(Default)]
pub struct DeadlockState {
    /// Whether detection is switched on. Not consulted by the methods of this
    /// type; callers presumably check it before querying — confirm.
    pub enabled: bool,
    /// Initial count of every registered semaphore, indexed by semaphore id.
    semaphore_total: Vec<usize>,
    /// Units of each semaphore currently held, per thread (one column per
    /// registered semaphore; all-zero rows are removed).
    semaphore_alloc: BTreeMap<ThreadId, Vec<usize>>,
    /// Units of each semaphore a blocked thread is waiting for, per thread.
    semaphore_need: BTreeMap<ThreadId, Vec<usize>>,
    /// Current holder of every registered mutex, indexed by mutex id.
    mutex_owner: Vec<Option<ThreadId>>,
    /// The mutex each blocked thread is waiting on.
    mutex_wait: BTreeMap<ThreadId, usize>,
}
impl DeadlockState {
    /// Returns `tid`'s row of a per-thread semaphore matrix.
    ///
    /// If there is no row, an all-zero row of `width` columns is inserted. A
    /// shorter existing row is zero-extended to `width`.
    fn ensure_sem_row(
        rows: &mut BTreeMap<ThreadId, Vec<usize>>,
        tid: ThreadId,
        width: usize,
    ) -> &mut Vec<usize> {
        let row = rows.entry(tid).or_insert_with(|| vec![0; width]);
        if row.len() < width {
            row.resize(width, 0);
        }
        row
    }

    /// Removes `tid`'s row from `rows` once every entry in it is zero, so the
    /// matrices only contain threads that hold or wait for something.
    fn trim_sem_row(rows: &mut BTreeMap<ThreadId, Vec<usize>>, tid: ThreadId) {
        if rows
            .get(&tid)
            .is_some_and(|row| row.iter().all(|&v| v == 0))
        {
            rows.remove(&tid);
        }
    }

    /// Registers a semaphore with initial count `total` and returns its id.
    ///
    /// The id is the index of its column. Every existing row is widened, so all
    /// rows keep one column per registered semaphore.
    pub fn register_semaphore(&mut self, total: usize) -> usize {
        self.semaphore_total.push(total);
        let new_width = self.semaphore_total.len();
        self.semaphore_alloc
            .values_mut()
            .for_each(|row| row.resize(new_width, 0));
        self.semaphore_need
            .values_mut()
            .for_each(|row| row.resize(new_width, 0));
        new_width - 1
    }

    /// Registers a mutex with no owner and returns its id.
    pub fn register_mutex(&mut self) -> usize {
        self.mutex_owner.push(None);
        self.mutex_owner.len() - 1
    }

    /// Records that `tid` now holds one more unit of semaphore `sem_id`.
    ///
    /// # Panics
    /// Panics if `sem_id` was not returned by [`Self::register_semaphore`].
    pub fn semaphore_acquired(&mut self, tid: ThreadId, sem_id: usize) {
        let width = self.semaphore_total.len();
        let row = Self::ensure_sem_row(&mut self.semaphore_alloc, tid, width);
        row[sem_id] += 1;
    }

    /// Records that `tid` is blocked waiting for one unit of semaphore `sem_id`.
    ///
    /// # Panics
    /// Panics if `sem_id` was not returned by [`Self::register_semaphore`].
    pub fn semaphore_wait(&mut self, tid: ThreadId, sem_id: usize) {
        let width = self.semaphore_total.len();
        let row = Self::ensure_sem_row(&mut self.semaphore_need, tid, width);
        row[sem_id] += 1;
    }

    /// Undoes one [`Self::semaphore_wait`] for `tid`. A missing or zero entry
    /// is ignored.
    pub fn semaphore_cancel_wait(&mut self, tid: ThreadId, sem_id: usize) {
        if let Some(row) = self.semaphore_need.get_mut(&tid) {
            if sem_id < row.len() && row[sem_id] > 0 {
                row[sem_id] -= 1;
            }
        }
        Self::trim_sem_row(&mut self.semaphore_need, tid);
    }

    /// Records that `tid` gave back one unit of semaphore `sem_id`.
    ///
    /// When the unit is handed straight to a blocked thread (`waking_tid`),
    /// that thread's wait is cleared and the unit is recorded as held by it.
    pub fn semaphore_release(
        &mut self,
        tid: ThreadId,
        sem_id: usize,
        waking_tid: Option<ThreadId>,
    ) {
        if let Some(row) = self.semaphore_alloc.get_mut(&tid) {
            if sem_id < row.len() && row[sem_id] > 0 {
                row[sem_id] -= 1;
            }
        }
        Self::trim_sem_row(&mut self.semaphore_alloc, tid);
        if let Some(waking_tid) = waking_tid {
            self.semaphore_cancel_wait(waking_tid, sem_id);
            self.semaphore_acquired(waking_tid, sem_id);
        }
    }

    /// Units of each semaphore not held by any thread: `total - Σ alloc`,
    /// clamped at zero.
    fn semaphore_available(&self) -> Vec<usize> {
        let mut available = self.semaphore_total.clone();
        for row in self.semaphore_alloc.values() {
            for (idx, count) in row.iter().enumerate() {
                available[idx] = available[idx].saturating_sub(*count);
            }
        }
        available
    }

    /// Runs the banker's safety check, assuming `tid` requests one more unit of
    /// semaphore `sem_id`.
    ///
    /// The check starts from the currently available units. It repeatedly picks
    /// an unfinished thread of `active_threads` whose outstanding need fits, then
    /// reclaims that thread's allocation. The result is `true` if some active
    /// thread can never be satisfied. Threads not listed in `active_threads` are
    /// not simulated, but their allocations still reduce availability. An
    /// unregistered `sem_id` yields `false`.
    pub fn semaphore_would_deadlock(
        &self,
        active_threads: &[ThreadId],
        tid: ThreadId,
        sem_id: usize,
    ) -> bool {
        let width = self.semaphore_total.len();
        if sem_id >= width {
            return false;
        }
        let mut work = self.semaphore_available();
        let mut finish = vec![false; active_threads.len()];
        loop {
            let mut progressed = false;
            for (idx, thread_id) in active_threads.iter().copied().enumerate() {
                if finish[idx] {
                    continue;
                }
                let mut need = self
                    .semaphore_need
                    .get(&thread_id)
                    .cloned()
                    .unwrap_or_else(|| vec![0; width]);
                // The pending request of `tid` is treated as an outstanding need.
                if thread_id == tid {
                    need[sem_id] += 1;
                }
                if need
                    .iter()
                    .zip(work.iter())
                    .all(|(need, work)| need <= work)
                {
                    // This thread can run to completion and return what it holds.
                    let alloc = self
                        .semaphore_alloc
                        .get(&thread_id)
                        .cloned()
                        .unwrap_or_else(|| vec![0; width]);
                    for (work_item, alloc_item) in work.iter_mut().zip(alloc.iter()) {
                        *work_item += *alloc_item;
                    }
                    finish[idx] = true;
                    progressed = true;
                }
            }
            if !progressed {
                break;
            }
        }
        finish.iter().any(|finished| !finished)
    }

    /// Records that `tid` now owns `mutex_id` and is no longer waiting.
    ///
    /// # Panics
    /// Panics if `mutex_id` was not returned by [`Self::register_mutex`].
    pub fn mutex_acquired(&mut self, tid: ThreadId, mutex_id: usize) {
        self.mutex_wait.remove(&tid);
        self.mutex_owner[mutex_id] = Some(tid);
    }

    /// Records that `tid` is blocked waiting for `mutex_id`.
    pub fn mutex_wait(&mut self, tid: ThreadId, mutex_id: usize) {
        self.mutex_wait.insert(tid, mutex_id);
    }

    /// Records that `tid` unlocked `mutex_id`.
    ///
    /// Ownership passes directly to `waking_tid` (whose wait is cleared) if one
    /// is given; otherwise the mutex becomes unowned. Ignored unless `tid` is
    /// the recorded owner.
    pub fn mutex_release(&mut self, tid: ThreadId, mutex_id: usize, waking_tid: Option<ThreadId>) {
        if self.mutex_owner.get(mutex_id).copied().flatten() == Some(tid) {
            if let Some(waking_tid) = waking_tid {
                self.mutex_wait.remove(&waking_tid);
                self.mutex_owner[mutex_id] = Some(waking_tid);
            } else {
                self.mutex_owner[mutex_id] = None;
            }
        }
    }

    /// Returns `true` if blocking `tid` on `mutex_id` would never end.
    ///
    /// The check follows the chain owner → awaited mutex → owner starting at
    /// `mutex_id`. The result is `true` when the chain leads back to `tid`, or
    /// into a cycle among other threads that can never be broken. It is `false`
    /// when the chain reaches an unowned mutex or a thread that is not waiting.
    pub fn mutex_would_deadlock(&self, tid: ThreadId, mutex_id: usize) -> bool {
        let Some(mut holder) = self.mutex_owner.get(mutex_id).copied().flatten() else {
            return false;
        };
        // Each hop consumes a distinct waiting thread unless the chain loops, so
        // a walk longer than the number of waiting threads has entered a cycle
        // that excludes `tid` (e.g. one formed while detection was disabled).
        // Bounding the walk keeps the kernel from spinning forever on it.
        for _ in 0..=self.mutex_wait.len() {
            if holder == tid {
                return true;
            }
            let Some(wait_mutex_id) = self.mutex_wait.get(&holder).copied() else {
                return false;
            };
            let Some(next_holder) = self.mutex_owner.get(wait_mutex_id).copied().flatten() else {
                return false;
            };
            holder = next_holder;
        }
        // The chain ends in an existing cycle: `tid` would wait on it forever.
        true
    }

    /// Forgets what `tid` was waiting for.
    ///
    /// Its semaphore allocations and mutex ownerships are left untouched, so
    /// units it still holds keep counting as unavailable.
    pub fn on_thread_exit(&mut self, tid: ThreadId) {
        self.semaphore_need.remove(&tid);
        self.mutex_wait.remove(&tid);
    }
}
impl Process {
    /// Returns the `satp` value that enables Sv39 translation (MODE field = 8)
    /// rooted at `address_space`'s page table.
    fn satp_for(address_space: &AddressSpace<Sv39, Sv39Manager>) -> usize {
        (8 << 60) | address_space.root_ppn().val()
    }

    /// Replaces this process's program image with the one in `elf`.
    ///
    /// Only the address space and heap bounds come from the freshly loaded
    /// image. The pid, fd table, signal state and synchronization objects of
    /// `self` are kept. The context of the first thread the processor lists for
    /// this process is overwritten with the new image's initial context.
    ///
    /// # Panics
    /// Panics if `elf` cannot be loaded by [`Process::from_elf`], or if the
    /// processor has no thread/task registered for this process.
    pub fn exec(&mut self, elf: ElfFile) {
        let (proc, thread) = Process::from_elf(elf).expect("exec: unloadable ELF image");
        self.address_space = proc.address_space;
        self.heap_bottom = proc.heap_bottom;
        self.program_brk = proc.program_brk;
        let processor: *mut ProcessorInner = PROCESSOR.get_mut() as *mut ProcessorInner;
        // SAFETY: assumes the processor state is not accessed concurrently while
        // this syscall runs (single hart / interrupts off) — TODO confirm.
        unsafe {
            let pthreads = (*processor).get_thread(self.pid).unwrap();
            (*processor).get_task(pthreads[0]).unwrap().context = thread.context;
        }
    }

    /// Creates a child process that duplicates this one.
    ///
    /// The child gets:
    /// - a new pid;
    /// - a copy of the parent's address space (via `cloneself`), with the portal mapped;
    /// - a clone of every open fd;
    /// - the signal state produced by `from_fork`;
    /// - the same heap bounds.
    ///
    /// Semaphores, mutexes, condvars and deadlock bookkeeping are not inherited.
    /// The returned thread copies the context of the parent's first listed
    /// thread and runs under the child's `satp`. Currently always returns `Some`.
    ///
    /// # Panics
    /// Panics if the processor has no thread/task registered for this process.
    pub fn fork(&mut self) -> Option<(Self, Thread)> {
        let pid = ProcId::new();
        let mut address_space: AddressSpace<Sv39, Sv39Manager> = AddressSpace::new();
        self.address_space.cloneself(&mut address_space);
        map_portal(&address_space);
        let processor: *mut ProcessorInner = PROCESSOR.get_mut() as *mut ProcessorInner;
        // SAFETY: assumes the processor state is not accessed concurrently while
        // this syscall runs (single hart / interrupts off) — TODO confirm.
        let pthreads = unsafe { (*processor).get_thread(self.pid).unwrap() };
        // SAFETY: as above.
        let context = unsafe {
            (*processor)
                .get_task(pthreads[0])
                .unwrap()
                .context
                .context
                .clone()
        };
        let thread = Thread::new(Self::satp_for(&address_space), context);
        let new_fd_table: Vec<Option<Mutex<Fd>>> = self
            .fd_table
            .iter()
            .map(|fd| fd.as_ref().map(|f| Mutex::new(f.lock().clone())))
            .collect();
        Some((
            Self {
                pid,
                address_space,
                fd_table: new_fd_table,
                signal: self.signal.from_fork(),
                heap_bottom: self.heap_bottom,
                program_brk: self.program_brk,
                semaphore_list: Vec::new(),
                mutex_list: Vec::new(),
                condvar_list: Vec::new(),
                deadlock: DeadlockState::default(),
            },
            thread,
        ))
    }

    /// Loads a RISC-V executable ELF into a new address space and builds the
    /// process plus its initial thread.
    ///
    /// - Every `PT_LOAD` segment is mapped with user access and the R/W/X
    ///   permissions from its flags.
    /// - The heap starts at the first page boundary at or above the highest
    ///   segment end.
    /// - A two-page zeroed user stack is mapped directly below virtual address
    ///   `1 << 38`, and the portal is mapped.
    /// - fd 0 is `Fd::Empty` read-only; fds 1 and 2 are `Fd::Empty` write-only.
    ///
    /// Returns `None` in these cases:
    /// - the header is not a 64-bit RISC-V executable;
    /// - a segment's file range lies outside the image, or its end address overflows;
    /// - a segment's file offset and virtual address have different in-page offsets;
    /// - the stack cannot be allocated.
    pub fn from_elf(elf: ElfFile) -> Option<(Self, Thread)> {
        let entry = match elf.header.pt2 {
            HeaderPt2::Header64(pt2)
                if pt2.type_.as_type() == header::Type::Executable
                    && pt2.machine.as_machine() == Machine::RISC_V =>
            {
                pt2.entry_point as usize
            }
            _ => return None,
        };
        const PAGE_SIZE: usize = 1 << Sv39::PAGE_BITS;
        const PAGE_MASK: usize = PAGE_SIZE - 1;
        let mut address_space: AddressSpace<Sv39, Sv39Manager> = AddressSpace::new();
        let mut max_end_va: usize = 0;
        for program in elf.program_iter() {
            if !matches!(program.get_type(), Ok(program::Type::Load)) {
                continue;
            }
            let off_file = program.offset() as usize;
            let len_file = program.file_size() as usize;
            let off_mem = program.virtual_addr() as usize;
            let end_mem = off_mem.checked_add(program.mem_size() as usize)?;
            // Segments are mapped page by page, so the data must start at the
            // same in-page offset in the file as in memory.
            if off_file & PAGE_MASK != off_mem & PAGE_MASK {
                return None;
            }
            // A malformed header must not make the kernel panic on the slice.
            let data = elf.input.get(off_file..off_file.checked_add(len_file)?)?;
            max_end_va = max_end_va.max(end_mem);
            let mut flags: [u8; 5] = *b"U___V";
            if program.flags().is_execute() {
                flags[1] = b'X';
            }
            if program.flags().is_write() {
                flags[2] = b'W';
            }
            if program.flags().is_read() {
                flags[3] = b'R';
            }
            // SAFETY: every byte of `flags` is ASCII, hence valid UTF-8.
            let flags_str = unsafe { core::str::from_utf8_unchecked(&flags) };
            address_space.map(
                VAddr::new(off_mem).floor()..VAddr::new(end_mem).ceil(),
                data,
                off_mem & PAGE_MASK,
                parse_flags(flags_str).unwrap(),
            );
        }
        let heap_bottom = VAddr::<Sv39>::new(max_end_va).ceil().base().val();
        // Two zeroed, page-aligned pages for the user stack.
        let stack_layout =
            Layout::from_size_align(2 << Sv39::PAGE_BITS, 1 << Sv39::PAGE_BITS).ok()?;
        // SAFETY: `stack_layout` has a non-zero size.
        let stack = unsafe { alloc_zeroed(stack_layout) };
        // A failed allocation must not be mapped as physical page 0.
        if stack.is_null() {
            return None;
        }
        address_space.map_extern(
            VPN::new((1 << 26) - 2)..VPN::new(1 << 26),
            // NOTE(review): uses the kernel pointer as a physical address, which
            // assumes the kernel heap is identity-mapped — confirm.
            PPN::new(stack as usize >> Sv39::PAGE_BITS),
            build_flags("U_WRV"),
        );
        map_portal(&address_space);
        let mut context = LocalContext::user(entry);
        // Top of the stack mapping above: VPN `1 << 26` with Sv39's 4 KiB pages.
        *context.sp_mut() = 1 << 38;
        let thread = Thread::new(Self::satp_for(&address_space), context);
        Some((
            Self {
                pid: ProcId::new(),
                address_space,
                fd_table: vec![
                    Some(Mutex::new(Fd::Empty {
                        read: true,
                        write: false,
                    })),
                    Some(Mutex::new(Fd::Empty {
                        read: false,
                        write: true,
                    })),
                    Some(Mutex::new(Fd::Empty {
                        read: false,
                        write: true,
                    })),
                ],
                signal: Box::new(SignalImpl::new()),
                heap_bottom,
                program_brk: heap_bottom,
                semaphore_list: Vec::new(),
                mutex_list: Vec::new(),
                condvar_list: Vec::new(),
                deadlock: DeadlockState::default(),
            },
            thread,
        ))
    }

    /// Moves the program break by `size` bytes and returns the previous break.
    ///
    /// User R/W pages are mapped when the page-rounded break grows and unmapped
    /// when it shrinks. Returns `None`, leaving the break unchanged, if the new
    /// break would fall below `heap_bottom` or the addition overflows.
    pub fn change_program_brk(&mut self, size: isize) -> Option<usize> {
        let old_brk = self.program_brk;
        // `size` is user-controlled: reject overflow instead of panicking in
        // debug builds or wrapping in release builds.
        let new_brk = (old_brk as isize).checked_add(size)?;
        if new_brk < self.heap_bottom as isize {
            return None;
        }
        let new_brk = new_brk as usize;
        let old_brk_ceil = VAddr::<Sv39>::new(old_brk).ceil();
        let new_brk_ceil = VAddr::<Sv39>::new(new_brk).ceil();
        // Page rounding is monotonic, so growth can only map and shrinkage can
        // only unmap.
        if new_brk_ceil.val() > old_brk_ceil.val() {
            self.address_space
                .map(old_brk_ceil..new_brk_ceil, &[], 0, build_flags("U_WRV"));
        } else if new_brk_ceil.val() < old_brk_ceil.val() {
            self.address_space.unmap(new_brk_ceil..old_brk_ceil);
        }
        self.program_brk = new_brk;
        Some(old_brk)
    }
}