// procref 0.1.0
//
// Cross-platform process reference counting for shared service lifecycle management.
//! Linux implementation using System V Semaphores with SEM_UNDO.
//!
//! System V semaphores with SEM_UNDO flag ensure that the kernel automatically
//! reverses semaphore operations when a process exits (normally or crashes).
//!
//! Key points:
//! - `semget()` creates or opens a semaphore set
//! - `semop()` with SEM_UNDO flag ensures kernel tracks and reverses on exit
//! - Semaphore persists until explicitly removed with `semctl(IPC_RMID)`

use crate::{Error, RefCounter, Result};
use std::ffi::CString;
use std::sync::atomic::{AtomicBool, Ordering};

/// Linux reference counter backed by a System V semaphore set.
///
/// The set holds two semaphores: index 0 is the shared reference count and index 1 is the
/// startup lock (a binary semaphore). The two flags record what *this instance* currently
/// holds, so repeated `acquire` / `try_lock` calls on one instance are not counted twice.
#[derive(Debug)]
pub struct LinuxRefCounter {
    /// Identifier returned by `semget` for the two-semaphore set.
    sem_id: i32,
    /// Whether this instance currently holds one unit of the reference count.
    has_reference: AtomicBool,
    /// Whether this instance currently holds the startup lock.
    has_lock: AtomicBool,
    /// Service name this counter was created for.
    // Not read after construction (the IPC key is derived from the argument in `new`);
    // kept so the counter is identifiable, e.g. in its `Debug` output.
    #[allow(dead_code)]
    name: String,
}

impl LinuxRefCounter {
    /// Create a new reference counter for the given service name.
    pub fn new(name: &str) -> Result<Self> {
        let key = Self::name_to_key(name)?;

        // Create or get semaphore set with 2 semaphores:
        // [0] = reference count
        // [1] = startup lock (binary semaphore)
        let sem_id = unsafe {
            libc::semget(
                key,
                2,
                libc::IPC_CREAT | libc::IPC_EXCL | 0o666,
            )
        };

        let sem_id = if sem_id < 0 {
            // Semaphore exists, just open it
            let errno = std::io::Error::last_os_error();
            if errno.raw_os_error() == Some(libc::EEXIST) {
                let id = unsafe { libc::semget(key, 2, 0o666) };
                if id < 0 {
                    return Err(Error::RefCounterInit(format!(
                        "Failed to open semaphore: {}",
                        std::io::Error::last_os_error()
                    )));
                }
                id
            } else {
                return Err(Error::RefCounterInit(format!(
                    "Failed to create semaphore: {}",
                    errno
                )));
            }
        } else {
            // New semaphore created, initialize the startup lock to 1 (available)
            let arg = libc::semun { val: 1 };
            unsafe {
                if libc::semctl(sem_id, 1, libc::SETVAL, arg) < 0 {
                    return Err(Error::RefCounterInit(format!(
                        "Failed to initialize lock semaphore: {}",
                        std::io::Error::last_os_error()
                    )));
                }
            }
            sem_id
        };

        Ok(Self {
            sem_id,
            has_reference: AtomicBool::new(false),
            has_lock: AtomicBool::new(false),
            name: name.to_string(),
        })
    }

    /// Convert service name to System V IPC key.
    fn name_to_key(name: &str) -> Result<libc::key_t> {
        // Use ftok with a well-known path and project ID derived from name
        let path = CString::new("/tmp").unwrap();
        let proj_id = name.bytes().fold(1u8, |acc, b| acc.wrapping_add(b)) as i32;

        let key = unsafe { libc::ftok(path.as_ptr(), proj_id) };
        if key < 0 {
            return Err(Error::RefCounterInit(format!(
                "ftok failed: {}",
                std::io::Error::last_os_error()
            )));
        }
        Ok(key)
    }

    /// Perform semaphore operation with SEM_UNDO.
    fn sem_op(&self, sem_num: u16, op: i16, flags: i16) -> Result<()> {
        let mut sop = libc::sembuf {
            sem_num,
            sem_op: op,
            sem_flg: flags,
        };

        let result = unsafe { libc::semop(self.sem_id, &mut sop, 1) };
        if result < 0 {
            return Err(Error::Platform(format!(
                "semop failed: {}",
                std::io::Error::last_os_error()
            )));
        }
        Ok(())
    }

    /// Get current value of a semaphore.
    fn sem_val(&self, sem_num: i32) -> Result<i32> {
        let val = unsafe { libc::semctl(self.sem_id, sem_num, libc::GETVAL) };
        if val < 0 {
            return Err(Error::Count(format!(
                "semctl GETVAL failed: {}",
                std::io::Error::last_os_error()
            )));
        }
        Ok(val)
    }
}

impl RefCounter for LinuxRefCounter {
    fn acquire(&self) -> Result<u32> {
        if self.has_reference.load(Ordering::SeqCst) {
            // Already have a reference, just return current count
            return self.count();
        }

        // Increment semaphore 0 with SEM_UNDO
        // SEM_UNDO ensures kernel will decrement on process exit
        self.sem_op(0, 1, libc::SEM_UNDO as i16)?;
        self.has_reference.store(true, Ordering::SeqCst);

        self.count()
    }

    fn release(&self) -> Result<u32> {
        if !self.has_reference.load(Ordering::SeqCst) {
            // Don't have a reference
            return self.count();
        }

        // Decrement semaphore 0 with SEM_UNDO
        // This cancels out the earlier increment in kernel's undo list
        self.sem_op(0, -1, libc::SEM_UNDO as i16)?;
        self.has_reference.store(false, Ordering::SeqCst);

        self.count()
    }

    fn count(&self) -> Result<u32> {
        let val = self.sem_val(0)?;
        Ok(val as u32)
    }

    fn try_lock(&self) -> Result<bool> {
        if self.has_lock.load(Ordering::SeqCst) {
            return Ok(true);
        }

        // Try to decrement lock semaphore (non-blocking)
        // IPC_NOWAIT makes it return immediately if can't acquire
        let result = self.sem_op(1, -1, (libc::IPC_NOWAIT | libc::SEM_UNDO) as i16);

        match result {
            Ok(()) => {
                self.has_lock.store(true, Ordering::SeqCst);
                Ok(true)
            }
            Err(_) => Ok(false), // Lock held by another process
        }
    }

    fn unlock(&self) -> Result<()> {
        if !self.has_lock.load(Ordering::SeqCst) {
            return Ok(());
        }

        // Release lock semaphore
        self.sem_op(1, 1, libc::SEM_UNDO as i16)?;
        self.has_lock.store(false, Ordering::SeqCst);
        Ok(())
    }
}

impl Drop for LinuxRefCounter {
    /// Give back anything this instance still holds: first the startup lock, then its
    /// reference.
    ///
    /// Errors are ignored because `drop` cannot report them. Both the lock and the reference
    /// were taken with SEM_UNDO, so anything not given back here is still reverted by the
    /// kernel when the process exits.
    ///
    /// The semaphore set itself is deliberately not removed (`IPC_RMID`), because other
    /// processes may still be using it.
    fn drop(&mut self) {
        if self.has_lock.load(Ordering::SeqCst) {
            let _ = self.unlock();
        }

        // Without this, an instance dropped while holding its reference would keep the shared
        // count raised for the rest of this process's lifetime, with no handle left to lower
        // it. When the reference was already released this is skipped.
        if self.has_reference.load(Ordering::SeqCst) {
            let _ = self.release();
        }
    }
}

// Local `libc` shim: re-exports the whole `libc` crate and adds `union semun`. `semctl`
// takes that union as its fourth argument (e.g. for `SETVAL`), but the crate does not always
// define it. Because this module is named `libc`, the `libc::` paths in this file resolve
// here rather than to the crate directly.
mod libc {
    pub use ::libc::*;

    /// Fourth argument to `semctl`; per `semctl(2)` the caller must define this union itself.
    #[repr(C)]
    // Only `val` is used (for SETVAL). `buf` and `array` are kept so the union has the size
    // and layout the C API expects.
    #[allow(dead_code)]
    pub union semun {
        pub val: i32,
        pub buf: *mut semid_ds,
        pub array: *mut u16,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linux_ref_counter_basic() {
        let counter = LinuxRefCounter::new("test-procref-basic").unwrap();

        // Nobody holds a reference yet.
        assert_eq!(counter.count().unwrap(), 0);

        // First acquire increments the shared count.
        assert_eq!(counter.acquire().unwrap(), 1);

        // A second acquire on the same instance is idempotent.
        assert_eq!(counter.acquire().unwrap(), 1);

        // Release drops it back to zero.
        assert_eq!(counter.release().unwrap(), 0);

        // Releasing again without holding a reference must not decrement further.
        assert_eq!(counter.release().unwrap(), 0);
    }

    #[test]
    fn test_linux_lock() {
        let counter = LinuxRefCounter::new("test-procref-lock").unwrap();

        // The lock is free, so the first attempt succeeds...
        assert!(counter.try_lock().unwrap());
        // ...and asking again while holding it still reports success.
        assert!(counter.try_lock().unwrap());

        // A second instance for the same name must be shut out while the lock is held.
        let other = LinuxRefCounter::new("test-procref-lock").unwrap();
        assert!(!other.try_lock().unwrap());

        // Once released, the lock becomes available to the other instance.
        counter.unlock().unwrap();
        assert!(other.try_lock().unwrap());
        other.unlock().unwrap();
    }
}