use crate::fd_readable_set::{FdReadableSet, Timeout};
use crate::fds::{self, make_fd_nonblocking, AutoClosePipes};
use crate::flog::{flog, FloggableDebug};
use fish_util::perror;
use fish_widestring::WString;
use nix::errno::Errno;
use nix::unistd;
use std::cell::Cell;
use std::os::fd::AsRawFd as _;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::{Condvar, Mutex, MutexGuard};
#[cfg(target_os = "linux")]
use std::{cell::UnsafeCell, pin::Pin};
/// The kinds of events that can be posted to a [`TopicMonitor`].
///
/// Each variant's discriminant doubles as its bit index in the monitor's status word
/// (`topic_to_bit` computes `1 << discriminant`), so the values must stay below 7.
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum Topic {
    /// Presumably posted on SIGHUP, SIGINT or SIGTERM (per the name).
    SigHupIntTerm = 0,
    /// Presumably posted on SIGCHLD (per the name).
    SigChld = 1,
    /// Presumably posted when an internal job/process exits (per the name).
    InternalExit = 2,
}
/// One generation counter per [`Topic`].
///
/// Each counter is a `Cell`, so the list can be read and modified through a shared reference.
#[derive(Clone, Debug, Default, PartialEq, PartialOrd, Eq, Ord)]
pub struct GenerationsList {
    pub sighupintterm: Cell<u64>,
    pub sigchld: Cell<u64>,
    pub internal_exit: Cell<u64>,
}

impl GenerationsList {
    /// Overwrite every counter in `self` with the matching counter from `other`.
    pub fn update(&self, other: &Self) {
        let Self {
            sighupintterm,
            sigchld,
            internal_exit,
        } = other;
        self.sighupintterm.set(sighupintterm.get());
        self.sigchld.set(sigchld.get());
        self.internal_exit.set(internal_exit.get());
    }
}
/// A per-topic generation counter value.
pub type Generation = u64;

/// Sentinel generation meaning "this topic is not being tracked".
///
/// `GenerationsList::is_valid` reports such topics as invalid, and `TopicMonitor::check`
/// skips them.
pub const INVALID_GENERATION: Generation = Generation::MAX;

// Allows a `Topic` to be passed as an argument to `flog!`.
impl FloggableDebug for Topic {}

/// Every topic, in discriminant order.
pub fn all_topics() -> [Topic; 3] {
    [
        Topic::SigHupIntTerm,
        Topic::SigChld,
        Topic::InternalExit,
    ]
}
impl GenerationsList {
    /// Create a list with every generation set to zero.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a list with every generation set to [`INVALID_GENERATION`], i.e. tracking
    /// no topics.
    pub fn invalid() -> GenerationsList {
        GenerationsList {
            sighupintterm: INVALID_GENERATION.into(),
            sigchld: INVALID_GENERATION.into(),
            internal_exit: INVALID_GENERATION.into(),
        }
    }

    /// Render the generations as a comma-separated string in `as_array` order. Invalid
    /// generations are shown as `-1`.
    // NOTE(review): no caller in this file; presumably kept as a debugging aid.
    #[allow(dead_code)]
    fn describe(&self) -> WString {
        let mut result = WString::new();
        for r#gen in self.as_array() {
            if !result.is_empty() {
                result.push(',');
            }
            if r#gen == INVALID_GENERATION {
                result.push_str("-1");
            } else {
                result.push_str(&r#gen.to_string());
            }
        }
        result
    }

    /// Set the generation for `topic`.
    pub fn set(&self, topic: Topic, value: Generation) {
        match topic {
            Topic::SigHupIntTerm => self.sighupintterm.set(value),
            Topic::SigChld => self.sigchld.set(value),
            Topic::InternalExit => self.internal_exit.set(value),
        }
    }

    /// Return the generation for `topic`.
    pub fn get(&self, topic: Topic) -> Generation {
        match topic {
            Topic::SigHupIntTerm => self.sighupintterm.get(),
            Topic::SigChld => self.sigchld.get(),
            Topic::InternalExit => self.internal_exit.get(),
        }
    }

    /// Return the generations as an array ordered `[sighupintterm, sigchld, internal_exit]`.
    /// This matches the order of `all_topics`.
    pub fn as_array(&self) -> [Generation; 3] {
        [
            self.sighupintterm.get(),
            self.sigchld.get(),
            self.internal_exit.get(),
        ]
    }

    /// Lower the generation for `topic` to `other`'s value if `other`'s is smaller.
    ///
    /// Because [`INVALID_GENERATION`] is `u64::MAX`, an invalid value in `other` never
    /// replaces a valid one in `self`.
    // NOTE(review): only `&self` is required (the fields are `Cell`s); `&mut self` is kept
    // so existing callers' bindings are unaffected.
    pub fn set_min_from(&mut self, topic: Topic, other: &Self) {
        let lowest = self.get(topic).min(other.get(topic));
        self.set(topic, lowest);
    }

    /// Whether `topic` holds a real generation (not [`INVALID_GENERATION`]).
    pub fn is_valid(&self, topic: Topic) -> bool {
        self.get(topic) != INVALID_GENERATION
    }

    /// Whether at least one topic holds a real generation.
    pub fn any_valid(&self) -> bool {
        // `any` stops at the first valid generation.
        self.as_array()
            .into_iter()
            .any(|r#gen| r#gen != INVALID_GENERATION)
    }
}
/// A semaphore used to wake a single blocked thread.
///
/// On Linux, `BinarySemaphore::new` prefers an unnamed POSIX semaphore. It falls back to a
/// pipe if `sem_init` fails, and always uses a pipe on other platforms.
pub enum BinarySemaphore {
    /// A POSIX `sem_t`, pinned on the heap so its address never changes after `sem_init`.
    #[cfg(target_os = "linux")]
    Semaphore(Pin<Box<UnsafeCell<libc::sem_t>>>),
    /// A pipe: posting writes one byte, waiting reads one byte.
    Pipes(AutoClosePipes),
}
impl BinarySemaphore {
    /// Create a semaphore with an initial count of zero.
    ///
    /// On Linux this first tries `sem_init` with `pshared == 0` (shared only between the
    /// threads of this process). If that fails, or on other platforms, it uses a pipe.
    ///
    /// # Panics
    /// Panics if the fallback pipe cannot be created.
    pub fn new() -> BinarySemaphore {
        #[cfg(target_os = "linux")]
        {
            // SAFETY: `sem_t` is a plain C struct; the zeroed bytes are only a placeholder
            // that `sem_init` overwrites before any other use.
            let sem = Box::pin(UnsafeCell::new(unsafe { std::mem::zeroed() }));
            // SAFETY: `sem.get()` points to pinned heap storage that stays put for the
            // lifetime of `Self::Semaphore`.
            let res = unsafe { libc::sem_init(sem.get(), 0, 0) };
            if res == 0 {
                return Self::Semaphore(sem);
            }
            // Initialization failed: the zeroed box is simply dropped (nothing to destroy).
        }
        let pipes = fds::make_autoclose_pipes().expect("Failed to make pubsub pipes");
        if cfg!(feature = "tsan") {
            // Under TSan the read end is non-blocking; `wait` instead blocks by polling for
            // readability and tolerates EAGAIN.
            let _ = make_fd_nonblocking(pipes.read.as_raw_fd());
        }
        Self::Pipes(pipes)
    }

    /// Increment the semaphore, releasing one thread blocked in [`BinarySemaphore::wait`].
    ///
    /// # Panics
    /// Panics (via [`BinarySemaphore::die`]) if `sem_post` fails, or if `write` fails with
    /// anything other than `EINTR`.
    pub fn post(&self) {
        match self {
            #[cfg(target_os = "linux")]
            Self::Semaphore(sem) => {
                // SAFETY: the semaphore was successfully initialized in `new` and is only
                // destroyed in `Drop`.
                let res = unsafe { libc::sem_post(sem.get()) };
                if res < 0 {
                    self.die("sem_post");
                }
            }
            Self::Pipes(pipes) => loop {
                match unistd::write(&pipes.write, &[0]) {
                    Err(Errno::EINTR) => continue,
                    Err(_) => self.die("write"),
                    Ok(_) => break,
                }
            },
        }
    }

    /// Block until the semaphore can be decremented, then decrement it.
    ///
    /// # Panics
    /// Panics (via [`BinarySemaphore::die`]) on any error other than `EINTR` (and, for the
    /// pipe, `EAGAIN`).
    pub fn wait(&self) {
        match self {
            #[cfg(target_os = "linux")]
            Self::Semaphore(sem) => loop {
                // SAFETY: the semaphore was successfully initialized in `new` and is only
                // destroyed in `Drop`.
                match unsafe { libc::sem_wait(sem.get()) } {
                    0.. => break,
                    _ if Errno::last() == Errno::EINTR => continue,
                    _ => self.die("sem_wait"),
                }
            },
            Self::Pipes(pipes) => {
                let fd = pipes.read.as_raw_fd();
                loop {
                    if cfg!(feature = "tsan") {
                        // The read end is non-blocking under TSan; block here instead.
                        let _ = FdReadableSet::is_fd_readable(fd, Timeout::Forever);
                    }
                    // Each wait consumes exactly one byte.
                    let mut ignored: u8 = 0;
                    match unistd::read(&pipes.read, std::slice::from_mut(&mut ignored)) {
                        Ok(1) => break,
                        // NOTE(review): Ok(0) (EOF) would spin; the write end is owned by
                        // `self`, so presumably it cannot be closed while we wait.
                        Ok(_) => continue,
                        Err(Errno::EINTR) | Err(Errno::EAGAIN) => continue,
                        Err(_) => self.die("read"),
                    }
                }
            }
        }
    }

    /// Report `msg` via `perror` and panic.
    ///
    /// `perror` presumably appends the current `errno` description, like its C namesake.
    ///
    /// # Panics
    /// Always; this function never returns.
    pub fn die(&self, msg: &str) -> ! {
        perror(msg);
        panic!("die");
    }
}
#[cfg(target_os = "linux")]
impl Drop for BinarySemaphore {
    /// Destroy the POSIX semaphore, if that is the active variant. The pipe variant has
    /// nothing to do here (presumably `AutoClosePipes` closes its fds on drop).
    fn drop(&mut self) {
        match self {
            Self::Semaphore(sem) => {
                // SAFETY: this variant is only constructed after a successful `sem_init`,
                // and nothing can use the semaphore after `drop`. Errors are ignored.
                let _ = unsafe { libc::sem_destroy(sem.get()) };
            }
            Self::Pipes(_) => {}
        }
    }
}
impl Default for BinarySemaphore {
fn default() -> Self {
Self::new()
}
}
/// A set of topics: bit `n` is set for the topic whose discriminant is `n`.
type TopicBitmask = u8;

/// The single bit representing `t` in a [`TopicBitmask`].
fn topic_to_bit(t: Topic) -> TopicBitmask {
    let shift = t as u8;
    1 << shift
}
/// The mutex-protected part of a `TopicMonitor`.
#[derive(Default)]
struct data_t {
    /// The latest published generation of each topic.
    current: GenerationsList,
    /// Whether some thread is currently the reader, i.e. the one that waits on the
    /// monitor's semaphore.
    has_reader: bool,
}
/// The raw value of `TopicMonitor::status_`: pending topic bits, or [`STATUS_NEEDS_WAKEUP`].
type StatusBits = u8;

/// Status bit (the top bit) set on its own by the thread that becomes the reader, just
/// before it blocks on the semaphore.
///
/// The next `post` clears this bit and posts the semaphore to wake that reader.
const STATUS_NEEDS_WAKEUP: u8 = 1 << 7;
/// Tracks a generation counter per [`Topic`] and lets threads block until one changes.
///
/// Posting only touches the atomic `status_` word and possibly the semaphore; it never locks
/// `data_`. Readers fold pending posts into `data_` lazily. At most one waiting thread (the
/// "reader") blocks on `sema_`; the others wait on `data_notifier_`.
#[derive(Default)]
pub struct TopicMonitor {
    /// Published generations plus reader bookkeeping.
    data_: Mutex<data_t>,
    /// Notified whenever generations are bumped or the reader gives up its role.
    data_notifier_: Condvar,
    /// Topic bits posted but not yet folded into `data_`, or `STATUS_NEEDS_WAKEUP` alone.
    status_: AtomicU8,
    /// Posted to wake the reader.
    sema_: BinarySemaphore,
}
// SAFETY(review): `TopicMonitor` is not automatically `Sync` (presumably because
// `BinarySemaphore` can hold an `UnsafeCell`). This impl exists only in test builds, so the
// tests below can share a monitor across threads through an `Arc`.
#[cfg(test)]
unsafe impl Sync for TopicMonitor {}
/// The principal topic monitor: null until `TopicMonitor::initialize` stores a leaked
/// `Box` here (never freed). Read by `topic_monitor_principal`.
static mut PRINCIPAL: *const TopicMonitor = std::ptr::null();
impl TopicMonitor {
    /// Create the principal topic monitor if it does not exist yet, and return it.
    ///
    /// The monitor comes from `Box::into_raw` and is never freed, so the returned reference
    /// lives for the rest of the program.
    ///
    /// NOTE(review): the null check and the store to `PRINCIPAL` are not synchronized, so this
    /// presumably must finish before any other thread reads `PRINCIPAL`. Confirm with callers.
    pub fn initialize() -> &'static Self {
        // SAFETY: once set, `PRINCIPAL` points to a leaked, never-freed `TopicMonitor`;
        // callers must not race this initialization (see note above).
        unsafe {
            if PRINCIPAL.is_null() {
                PRINCIPAL = Box::into_raw(Box::default());
            }
            &*PRINCIPAL
        }
    }

    /// Post to `topic`. Its generation is bumped the next time any thread reads the
    /// generations, and the reader is woken if it is blocked on the semaphore.
    ///
    /// This never locks `data_`; it only updates `status_` and possibly posts `sema_`.
    /// Repeated posts to a topic whose bit is still pending collapse into one.
    pub fn post(&self, topic: Topic) {
        let topicbit = topic_to_bit(topic);
        let relaxed = Ordering::Relaxed;

        // Atomically set our topic bit and clear STATUS_NEEDS_WAKEUP, remembering the
        // status we replaced.
        let mut oldstatus: StatusBits = 0;
        let mut cas_success = false;
        while !cas_success {
            oldstatus = self.status_.load(relaxed);
            let mut newstatus = oldstatus;
            newstatus &= !STATUS_NEEDS_WAKEUP;
            newstatus |= topicbit;
            cas_success = self
                .status_
                .compare_exchange_weak(oldstatus, newstatus, relaxed, relaxed)
                .is_ok();
        }
        // Invariant: STATUS_NEEDS_WAKEUP is only ever set on its own.
        assert_eq!(
            (oldstatus == STATUS_NEEDS_WAKEUP),
            ((oldstatus & STATUS_NEEDS_WAKEUP) != 0),
            "If STATUS_NEEDS_WAKEUP is set no other bits should be set"
        );

        // Our bit was already pending: an earlier post has not been consumed yet, and
        // (by the invariant above) nobody was waiting for a wakeup.
        if (oldstatus & topicbit) != 0 {
            return;
        }

        // A reader announced it is about to block on the semaphore; wake it.
        if (oldstatus & STATUS_NEEDS_WAKEUP) != 0 {
            std::sync::atomic::fence(Ordering::Release);
            self.sema_.post();
        }
    }

    /// Fold pending topic bits from `status_` into `data.current` and return a copy of the
    /// resulting generations.
    ///
    /// Each pending topic is bumped by exactly one, however many posts were coalesced into
    /// its bit, and `status_` is reset to 0. If nothing is pending (the status is 0 or just
    /// `STATUS_NEEDS_WAKEUP`), `status_` is left untouched. Waiters on `data_notifier_` are
    /// notified whenever pending bits were consumed.
    fn updated_gens_in_data(&self, data: &mut MutexGuard<data_t>) -> GenerationsList {
        let relaxed = Ordering::Relaxed;
        let mut changed_topic_bits: TopicBitmask = 0;
        let mut cas_success = false;
        while !cas_success {
            changed_topic_bits = self.status_.load(relaxed);
            if changed_topic_bits == 0 || changed_topic_bits == STATUS_NEEDS_WAKEUP {
                // Nothing new; in particular, keep a waiting reader's wakeup request intact.
                return data.current.clone();
            }
            cas_success = self
                .status_
                .compare_exchange_weak(changed_topic_bits, 0, relaxed, relaxed)
                .is_ok();
        }
        assert_eq!(
            changed_topic_bits & STATUS_NEEDS_WAKEUP,
            0,
            "Thread waiting bit should not be set"
        );

        for topic in all_topics() {
            if changed_topic_bits & topic_to_bit(topic) != 0 {
                data.current.set(topic, data.current.get(topic) + 1);
                flog!(
                    topic_monitor,
                    "Updating topic",
                    topic,
                    "to",
                    data.current.get(topic)
                );
            }
        }
        // Wake threads waiting in `try_update_gens_maybe_becoming_reader`.
        self.data_notifier_.notify_all();
        data.current.clone()
    }

    /// Lock `data_`, fold in pending posts, and return the current generations.
    fn updated_gens(&self) -> GenerationsList {
        let mut data = self.data_.lock().unwrap();
        self.updated_gens_in_data(&mut data)
    }

    /// Return the current generation of every topic, folding in any pending posts.
    pub fn current_generations(&self) -> GenerationsList {
        self.updated_gens()
    }

    /// Return the current generation of `topic`, folding in any pending posts.
    pub fn generation_for_topic(&self, topic: Topic) -> Generation {
        self.current_generations().get(topic)
    }

    /// Either observe a change to `gens` or become the reader.
    ///
    /// Returns `false` after overwriting `gens` with newer published generations. Returns
    /// `true`, leaving `gens` unchanged, if this thread became the reader. In that case it
    /// has set `has_reader` and moved `status_` from 0 to `STATUS_NEEDS_WAKEUP`, and it must
    /// now wait on `sema_` and then give up the reader role. While another thread is the
    /// reader, this blocks on `data_notifier_`.
    fn try_update_gens_maybe_becoming_reader(&self, gens: &mut GenerationsList) -> bool {
        let mut become_reader = false;
        let mut data = self.data_.lock().unwrap();
        loop {
            let current = self.updated_gens_in_data(&mut data);
            if *gens != current {
                *gens = current;
                break;
            }

            if data.has_reader {
                // Someone else is the reader; wait for it to publish or step down.
                data = self.data_notifier_.wait(data).unwrap();
                continue;
            } else {
                assert_eq!(
                    self.status_.load(Ordering::Relaxed) & STATUS_NEEDS_WAKEUP,
                    0,
                    "No thread should be waiting"
                );
                // Request a wakeup. This fails if a post landed since we last looked; in that
                // case loop around and fold it in instead of blocking.
                let expected_old: StatusBits = 0;
                if self
                    .status_
                    .compare_exchange(
                        expected_old,
                        STATUS_NEEDS_WAKEUP,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    )
                    .is_err()
                {
                    continue;
                }
                become_reader = true;
                data.has_reader = true;
                break;
            }
        }
        become_reader
    }

    /// Block until the published generations differ from `input_gens`, and return them.
    ///
    /// Either another thread publishes a change, or this thread becomes the reader, blocks on
    /// `sema_` until a `post` wakes it, and then gives up the reader role.
    fn await_gens(&self, input_gens: &GenerationsList) -> GenerationsList {
        let mut gens = input_gens.clone();
        while &gens == input_gens {
            let become_reader = self.try_update_gens_maybe_becoming_reader(&mut gens);
            if become_reader {
                assert_eq!(
                    gens, *input_gens,
                    "Generations should not have changed if we are the reader."
                );
                self.sema_.wait();
                // Step down as reader and let everyone re-check.
                let mut data = self.data_.lock().unwrap();
                gens = data.current.clone();
                assert!(data.has_reader, "We should be the reader");
                data.has_reader = false;
                self.data_notifier_.notify_all();
            }
        }
        gens
    }

    /// Bring each valid topic in `gens` up to its current generation, optionally blocking
    /// until at least one of them changes.
    ///
    /// Returns `true` if any valid topic in `gens` was updated. If `gens` has no valid topics,
    /// returns `false` immediately, even when `wait` is true.
    ///
    /// # Panics
    /// Panics if a valid generation in `gens` is greater than the published one.
    pub fn check(&self, gens: &GenerationsList, wait: bool) -> bool {
        if !gens.any_valid() {
            return false;
        }

        let mut current: GenerationsList = self.updated_gens();
        let mut changed = false;
        loop {
            // Pick up newer generations for the topics the caller tracks.
            for topic in all_topics() {
                if gens.is_valid(topic) {
                    assert!(
                        gens.get(topic) <= current.get(topic),
                        "Incoming gen count exceeded published count"
                    );
                    if gens.get(topic) < current.get(topic) {
                        gens.set(topic, current.get(topic));
                        changed = true;
                    }
                }
            }

            if !wait || changed {
                break;
            }

            // Block until any generation moves. It may be a topic the caller does not track,
            // in which case the loop simply waits again.
            current = self.await_gens(&current);
        }
        changed
    }
}
/// Create the principal topic monitor if it does not exist yet.
///
/// Same as calling `TopicMonitor::initialize` and discarding the returned reference.
pub fn topic_monitor_init() {
    let _ = TopicMonitor::initialize();
}
/// Return the principal topic monitor.
///
/// # Panics
/// Panics if the principal monitor has not been created yet (via `topic_monitor_init` or
/// `TopicMonitor::initialize`).
pub fn topic_monitor_principal() -> &'static TopicMonitor {
    // SAFETY: this only copies the pointer value out of `PRINCIPAL`, which is either null or
    // a pointer leaked by `TopicMonitor::initialize`.
    let principal: *const TopicMonitor = unsafe { PRINCIPAL };
    assert!(
        !principal.is_null(),
        "Principal topic monitor not initialized"
    );
    // SAFETY: a non-null `PRINCIPAL` comes from `Box::into_raw` and is never freed.
    unsafe { &*principal }
}
#[cfg(test)]
mod tests {
use super::{GenerationsList, Topic, TopicMonitor};
use crate::portable_atomic::AtomicU64;
use crate::tests::prelude::*;
use std::sync::{
atomic::{AtomicU32, Ordering},
Arc,
};
/// Single-threaded check of `post` combined with non-waiting and waiting `check`.
#[test]
#[serial]
fn test_topic_monitor() {
let _cleanup = test_init();
let monitor = TopicMonitor::default();
let gens = GenerationsList::new();
let t = Topic::SigChld;
gens.sigchld.set(0);
assert_eq!(monitor.generation_for_topic(t), 0);
// Nothing posted yet: a non-waiting check reports no change.
let changed = monitor.check(&gens, false );
assert!(!changed);
assert_eq!(gens.sigchld.get(), 0);
// One post bumps the generation to 1; a waiting check returns immediately.
monitor.post(t);
let changed = monitor.check(&gens, true );
assert!(changed);
assert_eq!(gens.get(t), 1);
assert_eq!(monitor.generation_for_topic(t), 1);
// A second post is folded in by `generation_for_topic` itself.
monitor.post(t);
assert_eq!(monitor.generation_for_topic(t), 2);
let changed = monitor.check(&gens, true );
assert!(changed);
assert_eq!(gens.sigchld.get(), 2);
}
/// Many threads repeatedly wait on `t1` while the main thread keeps posting it.
/// Each waiting check must observe strictly newer generations that never exceed the number
/// of posts so far, and `t2` (never posted) must stay at 0.
#[test]
#[serial]
fn test_topic_monitor_torture() {
let _cleanup = test_init();
let monitor = Arc::new(TopicMonitor::default());
const THREAD_COUNT: usize = 64;
let t1 = Topic::SigChld;
let t2 = Topic::SigHupIntTerm;
let mut gens_list = vec![GenerationsList::invalid(); THREAD_COUNT];
let post_count = Arc::new(AtomicU64::new(0));
// Take staggered snapshots, posting once after each, so every thread's first `check`
// is guaranteed to see a newer `t1` generation.
for r#gen in &mut gens_list {
*r#gen = monitor.current_generations();
post_count.fetch_add(1, Ordering::Relaxed);
monitor.post(t1);
}
let completed = Arc::new(AtomicU32::new(0));
let mut threads = vec![];
for gens in gens_list {
// Threads hold `Weak` handles; the main thread keeps the `Arc`s alive until the
// threads are joined, so `upgrade().unwrap()` cannot fail.
let monitor = Arc::downgrade(&monitor);
let post_count = Arc::downgrade(&post_count);
let completed = Arc::downgrade(&completed);
threads.push(std::thread::spawn(move || {
for _ in 0..1 << 11 {
let before = gens.clone();
let _changed = monitor.upgrade().unwrap().check(&gens, true );
assert!(before.get(t1) < gens.get(t1));
// `post_count` is incremented before each post, so it bounds the generation.
assert!(gens.get(t1) <= post_count.upgrade().unwrap().load(Ordering::Relaxed));
assert_eq!(gens.get(t2), 0);
}
let _amt = completed.upgrade().unwrap().fetch_add(1, Ordering::Relaxed);
}));
}
// Keep posting until every thread has finished its iterations.
while completed.load(Ordering::Relaxed) < THREAD_COUNT.try_into().unwrap() {
post_count.fetch_add(1, Ordering::Relaxed);
monitor.post(t1);
std::thread::yield_now();
}
for t in threads {
t.join().unwrap();
}
}
}