use std::{
fmt,
os::fd::{AsRawFd, RawFd},
sync::{Arc, OnceLock, RwLock},
};
use nix::{errno::Errno, unistd::Pid};
use serde::{Serialize, Serializer};
use crate::hash::SydHashMap;
pub(crate) mod abi;
pub(crate) mod api;
/// Per-thread record of the KCOV instance associated with the current thread.
///
/// Stored in the thread-local `TLS_SINK` slot by `set_tls_sink`, read back by
/// `get_tls_sink`, and removed by `clear_tls_sink`.
#[derive(Clone, Copy, Debug)]
pub(crate) struct TlsSink {
/// ID of the KCOV instance associated with this thread.
pub(crate) id: KcovId,
}
thread_local! {
/// KCOV instance bound to the current thread, if any.
///
/// The value is thread-local, yet it is wrapped in an `RwLock`; all accessors in
/// this file recover from poisoning via `into_inner`.
static TLS_SINK: RwLock<Option<TlsSink>> = const { RwLock::new(None) };
/// While `true`, `get_tls_sink` reports no sink for the current thread.
/// NOTE(review): never set in this file; presumably raised by another module
/// (e.g. around coverage recording) to prevent re-entrancy — confirm.
static RECURSION_GUARD: RwLock<bool> = const { RwLock::new(false) };
}
/// Return the KCOV ID bound to the current thread.
///
/// Yields `None` when no sink is set, when the recursion guard is raised, or
/// when the thread-local storage is no longer accessible (thread teardown).
pub(crate) fn get_tls_sink() -> Option<KcovId> {
    // An inaccessible guard slot (TLS destroyed) means "no sink".
    let guarded = RECURSION_GUARD
        .try_with(|flag| *flag.read().unwrap_or_else(|e| e.into_inner()))
        .ok()?;
    if guarded {
        return None;
    }

    TLS_SINK
        .try_with(|slot| {
            slot.read()
                .unwrap_or_else(|e| e.into_inner())
                .map(|sink| sink.id)
        })
        .ok()
        .flatten()
}
/// Bind `id` as the current thread's KCOV sink.
///
/// Silently does nothing if the thread-local slot is no longer accessible.
pub(crate) fn set_tls_sink(id: KcovId) {
    let sink = TlsSink { id };
    // `try_with` fails only during thread teardown; best-effort by design.
    let _ = TLS_SINK.try_with(|slot| {
        let mut guard = slot.write().unwrap_or_else(|e| e.into_inner());
        *guard = Some(sink);
    });
}
/// Remove any KCOV sink bound to the current thread.
///
/// Silently does nothing if the thread-local slot is no longer accessible.
pub(crate) fn clear_tls_sink() {
    // `try_with` fails only during thread teardown; best-effort by design.
    let _ = TLS_SINK.try_with(|slot| {
        let mut guard = slot.write().unwrap_or_else(|e| e.into_inner());
        *guard = None;
    });
}
/// Global map from thread ID to the KCOV instance bound to that thread.
///
/// The `bool` marks a remote binding. Remote bindings:
/// - are hidden from `get_kcov_tid`;
/// - never replace a local binding in `set_kcov_tid`;
/// - are not propagated by `inherit_kcov_tid`.
///
/// Created lazily by `kcov_tid_map`.
#[expect(clippy::type_complexity)]
static KCOV_TID_MAP: OnceLock<RwLock<SydHashMap<Pid, (KcovId, bool)>>> = OnceLock::new();
/// Access the global TID map, creating an empty one on first use.
#[expect(clippy::type_complexity)]
fn kcov_tid_map() -> &'static RwLock<SydHashMap<Pid, (KcovId, bool)>> {
    // `RwLock<T: Default>` defaults to `RwLock::new(T::default())`.
    KCOV_TID_MAP.get_or_init(Default::default)
}
/// Bind `tid` to the KCOV instance `id`.
///
/// A remote binding (`is_remote == true`) never replaces an existing local
/// binding for the same thread; in every other case the entry is overwritten.
pub(crate) fn set_kcov_tid(tid: Pid, id: KcovId, is_remote: bool) {
    let mut map = kcov_tid_map().write().unwrap_or_else(|e| e.into_inner());

    // Local bindings take precedence over remote ones.
    let shadowed = is_remote && matches!(map.get(&tid), Some(&(_, false)));
    if !shadowed {
        map.insert(tid, (id, is_remote));
    }
}
/// Return the KCOV ID locally bound to `tid`.
///
/// Remote bindings and unknown threads both yield `None`.
pub(crate) fn get_kcov_tid(tid: Pid) -> Option<KcovId> {
    let map = kcov_tid_map().read().unwrap_or_else(|e| e.into_inner());
    map.get(&tid)
        .filter(|&&(_, is_remote)| !is_remote)
        .map(|&(id, _)| id)
}
/// Drop whatever binding `tid` has; unknown threads are ignored.
pub(crate) fn remove_kcov_tid(tid: Pid) {
    kcov_tid_map()
        .write()
        .unwrap_or_else(|e| e.into_inner())
        .remove(&tid);
}
/// Copy the parent's local KCOV binding to `child_tid`.
///
/// Remote parent bindings and parents without a binding leave the child
/// untouched.
pub(crate) fn inherit_kcov_tid(parent_tid: Pid, child_tid: Pid) {
    // Scope the read guard so it is released before `set_kcov_tid` takes the
    // write lock on the same map (holding both would deadlock).
    let parent = {
        let guard = kcov_tid_map().read().unwrap_or_else(|e| e.into_inner());
        guard.get(&parent_tid).copied()
    };

    if let Some((id, false)) = parent {
        set_kcov_tid(child_tid, id, false);
    }
}
/// KCOV collection mode requested when enabling an instance.
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub(crate) enum TraceMode {
    /// Program-counter tracing.
    Pc,
    /// Comparison-operand tracing.
    Cmp,
}

impl fmt::Display for TraceMode {
    /// Render as `"pc"` or `"cmp"`; width/fill flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Pc => "pc",
            Self::Cmp => "cmp",
        })
    }
}
impl Serialize for TraceMode {
    /// Serialize as the `Display` name of the mode (`"pc"` or `"cmp"`).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let name = self.to_string();
        serializer.serialize_str(name.as_str())
    }
}
/// Opaque identifier of a KCOV instance, wrapping a raw `u64`.
#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug)]
pub(crate) struct KcovId(u64);

impl KcovId {
    /// Wrap the raw value `id`.
    pub(crate) const fn new(id: u64) -> Self {
        KcovId(id)
    }
}
impl AsRawFd for KcovId {
    /// Resolve this ID to the `syd_fd` stored for it in the KCOV registry.
    ///
    /// # Panics
    ///
    /// Panics if the ID has no registry entry; that is treated as a bug.
    #[allow(clippy::disallowed_methods)]
    fn as_raw_fd(&self) -> RawFd {
        let registry = crate::kcov::abi::kcov_reg()
            .read()
            .unwrap_or_else(|e| e.into_inner());
        let ctx = registry
            .get(self)
            .expect("BUG: missing ID in KCOV registry, report a bug!");
        ctx.syd_fd.as_raw_fd()
    }
}
/// Table of KCOV instances keyed by `KcovId`.
///
/// Entries are shared as `Arc<State>` so an instance's state can be used after
/// the table lock has been released.
pub(crate) struct Kcov {
map: RwLock<SydHashMap<KcovId, Arc<State>>>,
}
impl Kcov {
    /// Create an empty instance table.
    pub(crate) fn new() -> Self {
        let map = RwLock::new(SydHashMap::default());
        Self { map }
    }

    /// Register a fresh instance (in the `Disabled` phase) under `kcov_id`.
    ///
    /// Any instance already stored under the same ID is replaced. Never fails.
    pub(crate) fn open(&self, kcov_id: u64) -> Result<(), Errno> {
        let id = KcovId(kcov_id);
        let fresh = Arc::new(State::new());
        self.map
            .write()
            .unwrap_or_else(|e| e.into_inner())
            .insert(id, fresh);
        Ok(())
    }

    /// Initialize the trace buffer size of instance `kcov_id`.
    ///
    /// Returns `EBADF` for an unknown ID, otherwise the result of the
    /// instance's own `init_trace`.
    pub(crate) fn init_trace(&self, kcov_id: KcovId, words: u64) -> Result<(), Errno> {
        let state = self.get(kcov_id)?;
        state.init_trace(words)
    }

    /// Enable instance `id` in `mode`, then record it as this thread's sink.
    pub(crate) fn enable(&self, id: KcovId, mode: TraceMode) -> Result<(), Errno> {
        self.get(id)?.enable(mode)?;
        set_tls_sink(id);
        Ok(())
    }

    /// Disable instance `id`, then clear this thread's sink.
    pub(crate) fn disable(&self, id: KcovId) -> Result<(), Errno> {
        self.get(id)?.disable()?;
        clear_tls_sink();
        Ok(())
    }

    /// Look up an instance; the returned `Arc` outlives the table lock.
    ///
    /// Returns `EBADF` when no instance is registered under `kcov_id`.
    fn get(&self, kcov_id: KcovId) -> Result<Arc<State>, Errno> {
        let table = self.map.read().unwrap_or_else(|e| e.into_inner());
        match table.get(&kcov_id) {
            Some(state) => Ok(Arc::clone(state)),
            None => Err(Errno::EBADF),
        }
    }
}
/// Lifecycle phase of a KCOV instance.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum Phase {
    /// Freshly opened; no trace buffer size configured yet.
    Disabled,
    /// Trace size configured; tracing not running.
    Init,
    /// Tracing running.
    Enabled,
}

impl fmt::Display for Phase {
    /// Render the phase as its lowercase name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Phase::Disabled => f.write_str("disabled"),
            Phase::Init => f.write_str("init"),
            Phase::Enabled => f.write_str("enabled"),
        }
    }
}
impl Serialize for Phase {
    /// Serialize as the lowercase `Display` name of the phase.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let name = self.to_string();
        serializer.serialize_str(name.as_str())
    }
}
/// Per-instance KCOV state; all mutable fields live behind one lock.
struct State {
core: RwLock<Core>,
}
/// Mutable fields of a `State`, protected by `State::core`.
struct Core {
/// Trace mode chosen by the last successful enable; reset to `None` by `init_trace`.
mode: Option<TraceMode>,
/// Current lifecycle phase of the instance.
phase: Phase,
}
impl State {
    /// Largest accepted trace buffer size, in 8-byte words.
    ///
    /// Mirrors the Linux KCOV_INIT_TRACE bound `INT_MAX / sizeof(unsigned long)`
    /// on 64-bit targets.
    const MAX_TRACE_WORDS: u64 = (i32::MAX as u64) / 8;

    /// Create an instance in the `Disabled` phase with no trace mode.
    fn new() -> Self {
        Self {
            core: RwLock::new(Core {
                mode: None,
                phase: Phase::Disabled,
            }),
        }
    }

    /// Configure the trace buffer size, moving `Disabled` -> `Init`.
    ///
    /// # Errors
    ///
    /// * `EINVAL` if `words` is outside `2..=MAX_TRACE_WORDS` (checked first).
    /// * `EBUSY` if the instance is not in the `Disabled` phase.
    fn init_trace(&self, words: u64) -> Result<(), Errno> {
        if !(2..=Self::MAX_TRACE_WORDS).contains(&words) {
            return Err(Errno::EINVAL);
        }
        let mut core = self.core.write().unwrap_or_else(|e| e.into_inner());
        if core.phase != Phase::Disabled {
            return Err(Errno::EBUSY);
        }
        core.mode = None;
        core.phase = Phase::Init;
        Ok(())
    }

    /// Start tracing in `mode`, moving `Init` -> `Enabled`.
    ///
    /// Re-enabling an already enabled instance with the same mode succeeds
    /// without changes.
    ///
    /// # Errors
    ///
    /// `EBUSY` if the instance is `Disabled`, or `Enabled` with a different mode.
    fn enable(&self, mode: TraceMode) -> Result<(), Errno> {
        let mut core = self.core.write().unwrap_or_else(|e| e.into_inner());
        match core.phase {
            Phase::Init => {
                core.mode = Some(mode);
                core.phase = Phase::Enabled;
                Ok(())
            }
            Phase::Enabled if core.mode == Some(mode) => Ok(()),
            _ => Err(Errno::EBUSY),
        }
    }

    /// Stop tracing, moving `Enabled` -> `Init` so the instance can be re-enabled
    /// without another `init_trace`. Disabling an `Init` instance is a no-op.
    ///
    /// # Errors
    ///
    /// `EINVAL` if the instance was never initialized (`Disabled`). The old
    /// behavior jumped straight to `Init` in that case, which let a later
    /// `enable` bypass the size validation in `init_trace`. Linux likewise
    /// rejects KCOV_DISABLE with `EINVAL` on a non-enabled descriptor.
    fn disable(&self) -> Result<(), Errno> {
        let mut core = self.core.write().unwrap_or_else(|e| e.into_inner());
        match core.phase {
            Phase::Enabled => {
                core.phase = Phase::Init;
                Ok(())
            }
            Phase::Init => Ok(()),
            Phase::Disabled => Err(Errno::EINVAL),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_kcov_id_new_1() {
        let expected = KcovId(42);
        assert_eq!(KcovId::new(42), expected);
    }

    #[test]
    fn test_kcov_id_eq_1() {
        let (lhs, rhs) = (KcovId::new(1), KcovId::new(1));
        assert_eq!(lhs, rhs);
    }

    #[test]
    fn test_kcov_id_ne_1() {
        let (lhs, rhs) = (KcovId::new(1), KcovId::new(2));
        assert_ne!(lhs, rhs);
    }

    #[test]
    fn test_tls_sink_none_by_default_1() {
        // Start from a known-empty slot, then confirm nothing is reported.
        clear_tls_sink();
        assert_eq!(get_tls_sink(), None);
    }

    #[test]
    fn test_tls_sink_set_get_1() {
        let sink = KcovId::new(99);
        set_tls_sink(sink);
        let observed = get_tls_sink();
        clear_tls_sink();
        assert_eq!(observed, Some(sink));
    }

    #[test]
    fn test_tls_sink_clear_1() {
        set_tls_sink(KcovId::new(77));
        clear_tls_sink();
        assert_eq!(get_tls_sink(), None);
    }
}