use crate::{
Result,
sync::{AtomicUsize, Ordering},
tls::{TlsIndex, TlsInfo, TlsResolver},
tls_error,
};
use alloc::{
alloc::{alloc, dealloc, handle_alloc_error},
boxed::Box,
collections::BTreeMap,
vec::Vec,
};
use core::{alloc::Layout, ffi::c_void};
use spin::{Mutex, RwLock};
/// One entry of the global module registry, indexed by module ID.
#[derive(Debug)]
struct ModuleSlot {
    /// Value of `GLOBAL_GENERATION` stamped when this slot was last registered or
    /// unregistered; threads compare it with their DTV generation to spot stale entries.
    generation: usize,
    /// Template of the module occupying this slot, or `None` when the slot is vacant
    /// (never used, slot 0, or unregistered).
    template: Option<ModuleTlsTemplate>,
}
/// Per-module TLS description captured from the `TlsInfo` passed at registration.
#[derive(Debug, Clone)]
struct ModuleTlsTemplate {
    /// Initialization image copied into the start of each dynamically allocated block.
    image: &'static [u8],
    /// Size in bytes of one thread's block; bytes past `image` are zero-filled.
    memsz: usize,
    /// Alignment requested for one thread's block.
    align: usize,
    /// Offset of this module's block from the thread pointer when it was registered via
    /// `add_static_tls`; `None` for modules whose blocks are allocated on demand.
    tp_offset: Option<isize>,
}
/// Registered modules indexed by module ID; `register_module` never assigns slot 0.
static MODULE_REGISTRY: RwLock<Vec<ModuleSlot>> = RwLock::new(Vec::new());
/// Source of fresh module IDs, used only when `register_module` finds no vacant slot.
static NEXT_MODULE_ID: AtomicUsize = AtomicUsize::new(1);
/// Incremented on every register/unregister while the registry write lock is held;
/// threads compare it with their DTV generation to decide whether to resynchronize.
static GLOBAL_GENERATION: AtomicUsize = AtomicUsize::new(0);
/// Records `tls_info` in the module registry and returns the module ID assigned to it.
///
/// The lowest vacant slot above 0 is reused when one exists; otherwise a new ID is drawn
/// from `NEXT_MODULE_ID` and the registry grows to cover it. `tp_offset` is `Some` for
/// modules whose block lives at a fixed offset from the thread pointer. The slot is
/// stamped with a freshly bumped global generation, so threads holding a block cached
/// under this ID discard it on their next synchronization.
fn register_module(tls_info: &TlsInfo, tp_offset: Option<isize>) -> usize {
    let mut registry = MODULE_REGISTRY.write();

    let vacant = (1..registry.len()).find(|&id| registry[id].template.is_none());
    let mod_id = match vacant {
        Some(id) => id,
        None => NEXT_MODULE_ID.fetch_add(1, Ordering::SeqCst),
    };

    if registry.len() <= mod_id {
        registry.resize_with(mod_id + 1, || ModuleSlot {
            generation: 0,
            template: None,
        });
    }

    let generation = GLOBAL_GENERATION.fetch_add(1, Ordering::SeqCst) + 1;
    registry[mod_id] = ModuleSlot {
        generation,
        template: Some(ModuleTlsTemplate {
            image: tls_info.image,
            memsz: tls_info.memsz,
            align: tls_info.align,
            tp_offset,
        }),
    };

    #[cfg(feature = "log")]
    log::debug!(
        "Registered TLS module: ID {}, memsz {}, align {}, tp_offset {:?}",
        mod_id,
        tls_info.memsz,
        tls_info.align,
        tp_offset
    );

    mod_id
}
fn unregister_module(mod_id: usize) {
let mut registry = MODULE_REGISTRY.write();
assert!(mod_id < registry.len(), "Invalid module ID");
let new_gen = GLOBAL_GENERATION.fetch_add(1, Ordering::SeqCst) + 1;
registry[mod_id] = ModuleSlot {
generation: new_gen,
template: None, };
#[cfg(feature = "log")]
log::debug!("Unregistered TLS module: ID {}", mod_id);
}
/// Returns a copy of the template registered under `mod_id`, or `None` when the ID is
/// out of range or its slot is vacant.
fn get_module_template(mod_id: usize) -> Option<ModuleTlsTemplate> {
    let guard = MODULE_REGISTRY.read();
    let slot = guard.get(mod_id)?;
    slot.template.clone()
}
/// A thread's handle to one module's TLS block.
#[derive(Debug)]
enum DtvEntry {
    /// Heap block owned by this entry; released with `layout` when the entry is dropped.
    Allocated {
        ptr: *mut u8,
        layout: Layout, },
    /// Block at `thread_pointer + tp_offset`; not owned by this entry and never freed by it.
    Static {
        ptr: *mut u8,
    },
}
// SAFETY: `DtvEntry` holds only raw pointers. Entries live inside boxed `ThreadDtv`s stored
// in the `static THREAD_DTVS` mutex; a static must be `Sync`, and `spin::Mutex<T>` is `Sync`
// only when `T: Send`, which requires `DtvEntry: Send`.
// NOTE(review): soundness relies on each entry's pointer being dereferenced only on the
// thread that owns the DTV. `Sync` does not appear to be required by anything visible in
// this file — confirm before relying on it.
unsafe impl Send for DtvEntry {}
unsafe impl Sync for DtvEntry {}
impl DtvEntry {
    /// Base address of the module's TLS block, regardless of who owns the memory.
    fn ptr(&self) -> *mut u8 {
        match self {
            DtvEntry::Allocated { ptr, .. } | DtvEntry::Static { ptr } => *ptr,
        }
    }
}
impl Drop for DtvEntry {
    /// Frees the heap block of an `Allocated` entry. `Static` entries point into memory
    /// this entry does not own, so they are left alone.
    fn drop(&mut self) {
        match *self {
            DtvEntry::Allocated { ptr, layout } => {
                // SAFETY: `Allocated` entries are built from a successful `alloc(layout)`
                // and own that block exclusively, so freeing it once here is sound.
                unsafe { dealloc(ptr, layout) }
            }
            DtvEntry::Static { .. } => {}
        }
    }
}
/// Per-thread dynamic thread vector: at most one TLS block handle per module ID.
struct ThreadDtv {
    /// Global generation this DTV was last reconciled against.
    generation: usize,
    /// Indexed by module ID; `None` means no block has been resolved for this thread yet.
    dtv: Vec<Option<DtvEntry>>,
}
impl ThreadDtv {
fn new() -> Self {
let registry = MODULE_REGISTRY.read();
let mut dtv = Vec::with_capacity(registry.len());
for slot in registry.iter() {
let entry = slot.template.as_ref().and_then(|t| {
if let Some(offset) = t.tp_offset {
unsafe {
let tp = crate::arch::get_thread_pointer();
Some(DtvEntry::Static {
ptr: tp.offset(offset),
})
}
} else {
None
}
});
dtv.push(entry);
}
Self {
generation: GLOBAL_GENERATION.load(Ordering::Acquire),
dtv,
}
}
fn synchronize(&mut self, global_gen: usize) {
let registry = MODULE_REGISTRY.read();
let check_len = core::cmp::min(self.dtv.len(), registry.len());
for (mod_id, slot_val) in self.dtv.iter_mut().enumerate().take(check_len) {
let registry_slot = ®istry[mod_id];
if registry_slot.generation > self.generation {
*slot_val = None;
}
}
self.generation = global_gen;
}
fn get_or_allocate(&mut self, mod_id: usize) -> Option<*mut u8> {
let global_gen = GLOBAL_GENERATION.load(Ordering::Acquire);
if self.generation < global_gen {
self.synchronize(global_gen);
}
if mod_id >= self.dtv.len() {
self.dtv.resize_with(mod_id + 1, || None);
}
if let Some(entry) = &self.dtv[mod_id] {
return Some(entry.ptr());
}
let template = get_module_template(mod_id)?;
let layout = Layout::from_size_align(template.memsz, template.align).ok()?;
let ptr = unsafe { alloc(layout) };
if ptr.is_null() {
handle_alloc_error(layout);
}
unsafe {
let slice = core::slice::from_raw_parts_mut(ptr, template.memsz);
let image_len = template.image.len();
slice[..image_len].copy_from_slice(template.image);
slice[image_len..].fill(0);
}
self.dtv[mod_id] = Some(DtvEntry::Allocated { ptr, layout });
Some(ptr)
}
fn get(&self, mod_id: usize) -> Option<*mut u8> {
let entry = self.dtv.get(mod_id)?.as_ref()?;
let global_gen = GLOBAL_GENERATION.load(Ordering::Acquire);
if self.generation < global_gen {
let registry = MODULE_REGISTRY.read();
match registry.get(mod_id) {
Some(slot) if slot.generation <= self.generation => {
}
_ => return None,
}
}
Some(entry.ptr())
}
}
/// Key type of `THREAD_DTVS`: the value returned by `crate::os::current_thread_id()`.
type ThreadId = usize;
/// Every thread's DTV, keyed by thread ID. Each `ThreadDtv` is boxed so the raw pointer
/// handed to `register_thread_destructor` keeps pointing at it while the map changes.
static THREAD_DTVS: Mutex<BTreeMap<ThreadId, Box<ThreadDtv>>> = Mutex::new(BTreeMap::new());
/// Thread-exit hook registered by `with_current_dtv`.
///
/// The pointer argument is ignored; the calling thread's DTV is looked up and released
/// by `cleanup_current_thread_tls`.
unsafe extern "C" fn dtv_destructor(_dtv: *mut c_void) {
    cleanup_current_thread_tls()
}
/// Runs `f` with mutable access to the calling thread's DTV, creating the DTV on first use.
///
/// Fast path: a non-null value from `crate::os::get_thread_local_ptr()` is used directly as
/// the thread's `ThreadDtv`.
///
/// Slow path: the DTV is fetched from (or inserted into) `THREAD_DTVS` under the calling
/// thread's ID, and `dtv_destructor` is registered with its address. The global lock
/// covers only that lookup. `f` runs afterwards, so first-use block allocation in one
/// thread does not stall every other thread spinning on `THREAD_DTVS`.
fn with_current_dtv<F, R>(f: F) -> R
where
    F: FnOnce(&mut ThreadDtv) -> R,
{
    // SAFETY: NOTE(review) — assumes a non-null OS thread-local value is this thread's
    // `ThreadDtv` pointer (the one passed to `register_thread_destructor` below).
    let cached = unsafe { crate::os::get_thread_local_ptr() };
    if !cached.is_null() {
        return f(unsafe { &mut *(cached as *mut ThreadDtv) });
    }

    let tid = crate::os::current_thread_id();
    let dtv_ptr = {
        let mut map = THREAD_DTVS.lock();
        let dtv = map.entry(tid).or_insert_with(|| Box::new(ThreadDtv::new()));
        &mut **dtv as *mut ThreadDtv
    };
    unsafe {
        crate::os::register_thread_destructor(dtv_destructor, dtv_ptr as *mut _);
    }
    // SAFETY: the `ThreadDtv` is boxed, so its address is unaffected by other threads
    // inserting or removing their own map entries. It is freed only when the calling
    // thread's entry is removed (`cleanup_current_thread_tls`), which cannot overlap this
    // call on the same thread. Assumes thread IDs are unique among live threads — confirm.
    f(unsafe { &mut *dtv_ptr })
}
/// TLS resolver backed by the global module registry and lazily populated per-thread DTVs.
///
/// The type carries no state; all bookkeeping lives in module-level statics.
#[derive(Debug, Default)]
pub struct DefaultTlsResolver;

impl DefaultTlsResolver {
    /// Creates a resolver. Equivalent to `DefaultTlsResolver::default()`.
    pub fn new() -> Self {
        Self
    }

    /// Returns the calling thread's thread pointer as reported by `crate::arch`.
    pub fn get_thread_pointer() -> *mut u8 {
        // SAFETY: NOTE(review) — the contract of `crate::arch::get_thread_pointer` is not
        // visible here; this wrapper assumes it is always safe to call on the current thread.
        unsafe { crate::arch::get_thread_pointer() }
    }

    /// Returns the calling thread's block for `mod_id` if it has already been set up and
    /// is still current. Never allocates; yields `None` for blocks not yet created.
    pub fn get_ptr(mod_id: usize) -> Option<*mut u8> {
        with_current_dtv(|dtv| dtv.get(mod_id))
    }

    /// Returns the calling thread's block for `mod_id` as a byte slice of length `memsz`.
    ///
    /// NOTE(review): despite the `'static` lifetime, the memory may be freed in two cases:
    /// when this thread's DTV is cleaned up, or when a later sync drops the entry after the
    /// module slot changes. Do not keep the slice past either event.
    pub fn get_tls_data(mod_id: usize) -> Option<&'static [u8]> {
        let memsz = get_module_template(mod_id)?.memsz;
        Self::get_ptr(mod_id).map(|ptr| unsafe { core::slice::from_raw_parts(ptr, memsz) })
    }

    /// Mutable counterpart of [`DefaultTlsResolver::get_tls_data`].
    ///
    /// NOTE(review): the same lifetime caveat applies. In addition, calling this twice for
    /// the same module yields two aliasing `&mut` slices, so callers must ensure only one is
    /// live at a time.
    pub fn get_tls_data_mut(mod_id: usize) -> Option<&'static mut [u8]> {
        let memsz = get_module_template(mod_id)?.memsz;
        Self::get_ptr(mod_id).map(|ptr| unsafe { core::slice::from_raw_parts_mut(ptr, memsz) })
    }
}
impl TlsResolver for DefaultTlsResolver {
    /// Registers a module whose per-thread blocks are allocated on first access.
    fn register(tls_info: &TlsInfo) -> Result<usize> {
        Ok(register_module(tls_info, None))
    }

    /// Static TLS reservation is not offered by this resolver; always returns an error.
    fn register_static(_tls_info: &TlsInfo) -> Result<(usize, isize)> {
        Err(tls_error("unsupport static tls"))
    }

    /// Registers a module whose block already lives at `offset` from the thread pointer.
    fn add_static_tls(tls_info: &TlsInfo, offset: isize) -> Result<usize> {
        Ok(register_module(tls_info, Some(offset)))
    }

    /// Vacates `mod_id`'s registry slot. Threads drop their cached block for it the next
    /// time `get_or_allocate` resynchronizes their DTV.
    fn unregister(mod_id: usize) {
        unregister_module(mod_id)
    }

    /// Resolves a `TlsIndex` to an address inside the calling thread's block for that
    /// module, creating the block on first use.
    ///
    /// # Panics
    /// Panics if the module cannot be resolved or its block cannot be created.
    extern "C" fn tls_get_addr(ti: *const TlsIndex) -> *mut u8 {
        // SAFETY: NOTE(review) — assumes `ti` is non-null and points to a valid `TlsIndex`;
        // this is not checked.
        let index = unsafe { &*ti };
        let resolved = with_current_dtv(|dtv| dtv.get_or_allocate(index.ti_module));
        match resolved {
            // SAFETY: NOTE(review) — assumes `ti_offset` lies within the module's block.
            Some(base) => unsafe { base.add(index.ti_offset) },
            None => panic!(
                "__tls_get_addr: Failed to allocate TLS for module {}",
                index.ti_module
            ),
        }
    }
}
/// Removes the calling thread's DTV from `THREAD_DTVS`, freeing every heap block it owns.
///
/// Invoked by `dtv_destructor` at thread exit, and public so it can also be called
/// directly.
///
/// NOTE(review): this does not reset the value returned by
/// `crate::os::get_thread_local_ptr`. If that value refers to the removed DTV (presumably
/// it is set by `register_thread_destructor` — confirm), a later `with_current_dtv` on
/// this thread would dereference freed memory.
pub fn cleanup_current_thread_tls() {
    let current = crate::os::current_thread_id();
    THREAD_DTVS.lock().remove(&current);
}