// Documentation
// core cpu affinity operations

use crate::error::CpuAffinityError;

#[cfg(target_os = "linux")]
use std::{fs, io};

// number of cpu ids that fit in a cpu_set_t; valid ids for CPU_SET are
// 0..CPU_SETSIZE (exclusive), so the largest usable id is CPU_SETSIZE - 1
// standard linux value in glibc - fixed at 1024 across major distros
#[cfg(target_os = "linux")]
const CPU_SETSIZE: usize = 1024;

// thread identifier accepted by the affinity functions in this module.
// on linux it is the libc pid_t type.
#[cfg(target_os = "linux")]
pub type ThreadId = libc::pid_t;

// placeholder type on other platforms so the public signatures still compile;
// the non-linux functions in this module all return NotSupported.
#[cfg(not(target_os = "linux"))]
pub type ThreadId = i32;

// summary of a cpu list: the highest cpu id seen and how many cpus it covers.
#[cfg(target_os = "linux")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct CpuListInfo {
    // largest cpu id present in the list
    max_id: usize,
    // total number of cpus in the list (sum of the sizes of all ranges)
    count: usize,
}

// get thread id for the calling thread.
#[cfg(target_os = "linux")]
pub fn current_thread_id() -> Result<ThreadId, CpuAffinityError> {
    let thread_id = unsafe { libc::syscall(libc::SYS_gettid) };

    if thread_id <= 0 {
        return Err(CpuAffinityError::Io(io::Error::last_os_error()));
    }

    Ok(thread_id as ThreadId)
}

// non-linux stub: thread ids are not provided here, so this always
// returns CpuAffinityError::NotSupported.
#[cfg(not(target_os = "linux"))]
pub fn current_thread_id() -> Result<ThreadId, CpuAffinityError> {
    Err(CpuAffinityError::NotSupported)
}

// set cpu affinity for a specific thread id.
//
// every cpu id yielded by `cpus` is validated against the highest online cpu
// id and against CPU_SETSIZE; duplicates are ignored. the resulting set is
// applied with sched_setaffinity.
//
// errors:
// - InvalidCpu   a cpu id is above the highest online id or does not fit in cpu_set_t
// - EmptyCpuList `cpus` yielded no ids
// - Io           looking up the online cpus or sched_setaffinity itself failed
#[cfg(target_os = "linux")]
pub fn set_cpu_affinity(
    thread_id: ThreadId,
    cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
    // safety: cpu_set_t is pod type, zero-initialization standard
    let mut cpu_set: libc::cpu_set_t = unsafe { std::mem::zeroed() };
    let mut max_online: Option<usize> = None;
    let mut has_cpus = false;

    // validate, deduplicate via CPU_ISSET, and set cpus
    for cpu in cpus {
        // look up the highest online cpu id lazily and only once: an empty
        // iterator never triggers the lookup, and later iterations reuse the
        // cached value instead of querying the system again
        let max_cpu = match max_online {
            Some(max) => max,
            None => {
                let max = max_cpu_id()?;
                max_online = Some(max);
                max
            }
        };

        if cpu > max_cpu {
            return Err(CpuAffinityError::InvalidCpu { cpu, max: max_cpu });
        }
        // the online maximum may itself exceed what cpu_set_t can hold
        if cpu >= CPU_SETSIZE {
            return Err(CpuAffinityError::InvalidCpu {
                cpu,
                max: CPU_SETSIZE - 1,
            });
        }

        // safety: CPU_ISSET safe after validation
        if unsafe { libc::CPU_ISSET(cpu, &cpu_set) } {
            continue;
        }

        // safety: validated cpu within range
        unsafe {
            libc::CPU_SET(cpu, &mut cpu_set);
        }
        has_cpus = true;
    }

    if !has_cpus {
        return Err(CpuAffinityError::EmptyCpuList);
    }

    // safety: sched_setaffinity safe with valid parameters
    let result = unsafe {
        libc::sched_setaffinity(thread_id, std::mem::size_of::<libc::cpu_set_t>(), &cpu_set)
    };

    if result != 0 {
        return Err(CpuAffinityError::Io(io::Error::last_os_error()));
    }

    Ok(())
}

// non-linux stub: both arguments are ignored and the call always
// returns CpuAffinityError::NotSupported.
#[cfg(not(target_os = "linux"))]
pub fn set_cpu_affinity(
    _thread_id: ThreadId,
    _cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
    Err(CpuAffinityError::NotSupported)
}

// set cpu affinity for the calling thread.
// delegates to set_cpu_affinity with thread id 0, which is forwarded
// unchanged to sched_setaffinity; errors are those of set_cpu_affinity.
#[cfg(target_os = "linux")]
pub fn set_current_thread_affinity(
    cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
    // thread id 0 stands for the calling thread
    set_cpu_affinity(0, cpus)
}

// non-linux stub: the cpu list is ignored and the call always
// returns CpuAffinityError::NotSupported.
#[cfg(not(target_os = "linux"))]
pub fn set_current_thread_affinity(
    _cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
    Err(CpuAffinityError::NotSupported)
}

// set cpu affinity for a specific thread id.
// platform-independent alias: forwards to whichever set_cpu_affinity is
// compiled for the target, so on non-linux targets it returns NotSupported.
pub fn set_thread_affinity(
    thread_id: ThreadId,
    cpus: impl IntoIterator<Item = usize>,
) -> Result<(), CpuAffinityError> {
    set_cpu_affinity(thread_id, cpus)
}

// calls `visit` once for every online cpu id, in the order the ids are listed.
//
// the ids come from /sys/devices/system/cpu/online when that file can be read
// and parsed; otherwise every id from 0 up to the sysconf online count minus
// one is visited. an error is returned only when the sysconf fallback fails.
#[cfg(target_os = "linux")]
pub fn for_each_online_cpu(mut visit: impl FnMut(usize)) -> Result<(), CpuAffinityError> {
    let sysfs_list = fs::read_to_string("/sys/devices/system/cpu/online").ok();

    match sysfs_list.as_deref().map(str::trim) {
        // the list is validated as a whole first, because fold_cpu_list
        // reports ranges while it parses and would otherwise let `visit`
        // see the leading part of a malformed list
        Some(list) if parse_cpu_list(list).is_some() => {
            let _ = fold_cpu_list(list, |first, last| (first..=last).for_each(&mut visit));
        }
        // file unreadable or malformed: fall back to the sysconf count
        _ => {
            let fallback = sysconf_cpu_info()?;
            (0..=fallback.max_id).for_each(&mut visit);
        }
    }

    Ok(())
}

// non-linux stub: the callback is never invoked and the call always
// returns CpuAffinityError::NotSupported.
#[cfg(not(target_os = "linux"))]
pub fn for_each_online_cpu(_visit: impl FnMut(usize)) -> Result<(), CpuAffinityError> {
    Err(CpuAffinityError::NotSupported)
}

// highest cpu id among the online cpus, as reported by online_cpu_info.
#[cfg(target_os = "linux")]
pub fn max_cpu_id() -> Result<usize, CpuAffinityError> {
    online_cpu_info().map(|info| info.max_id)
}

// collects max id and count of the online cpus.
// prefers /sys/devices/system/cpu/online; if that file cannot be read or does
// not parse as a cpu list, the sysconf based values are used instead.
#[cfg(target_os = "linux")]
fn online_cpu_info() -> Result<CpuListInfo, CpuAffinityError> {
    let from_sysfs = fs::read_to_string("/sys/devices/system/cpu/online")
        .ok()
        .and_then(|raw| parse_cpu_list(raw.trim()));

    match from_sysfs {
        Some(info) => Ok(info),
        None => sysconf_cpu_info(),
    }
}

#[cfg(target_os = "linux")]
fn sysconf_cpu_info() -> Result<CpuListInfo, CpuAffinityError> {
    let count = unsafe { libc::sysconf(libc::_SC_NPROCESSORS_ONLN) };

    if count <= 0 {
        return Err(CpuAffinityError::Io(io::Error::last_os_error()));
    }

    let count = count as usize;
    Ok(CpuListInfo {
        max_id: count - 1,
        count,
    })
}

// parses a cpu list such as "0-3,8-11" into its max id and cpu count.
// returns None when the list is empty or any part of it is malformed.
#[cfg(target_os = "linux")]
fn parse_cpu_list(input: &str) -> Option<CpuListInfo> {
    // same walk as fold_cpu_list, with a callback that ignores the ranges
    fold_cpu_list(input, |_, _| {})
}

// walks a comma separated cpu list, calling `range(start, end)` for every
// entry (a single id n is reported as (n, n)) and accumulating the highest
// id and the total number of cpus.
//
// returns None for an empty list, a malformed entry, or arithmetic overflow.
// note that `range` has already been called for the entries preceding a
// malformed one by the time None is returned.
#[cfg(target_os = "linux")]
fn fold_cpu_list(input: &str, mut range: impl FnMut(usize, usize)) -> Option<CpuListInfo> {
    let list = input.trim();
    if list.is_empty() {
        return None;
    }

    let (max_id, count) =
        list.split(',')
            .try_fold((0usize, 0usize), |(highest, total), entry| {
                let (start, end) = parse_cpu_range(entry.trim())?;
                // inclusive range size, guarded against overflow
                let span = end.checked_sub(start)?.checked_add(1)?;
                let total = total.checked_add(span)?;
                range(start, end);
                Some((highest.max(end), total))
            })?;

    Some(CpuListInfo { max_id, count })
}

// parses one cpu list entry: either a single id ("5") or an inclusive
// range ("2-7"). returns the (start, end) pair, with start == end for a
// single id, or None when the entry is empty, not numeric, or descending.
#[cfg(target_os = "linux")]
fn parse_cpu_range(input: &str) -> Option<(usize, usize)> {
    if input.is_empty() {
        return None;
    }

    let bounds = if let Some((low, high)) = input.split_once('-') {
        // whitespace around either bound of a range is tolerated
        let low: usize = low.trim().parse().ok()?;
        let high: usize = high.trim().parse().ok()?;
        (low, high)
    } else {
        let single: usize = input.parse().ok()?;
        (single, single)
    };

    if bounds.0 > bounds.1 {
        None
    } else {
        Some(bounds)
    }
}

// non-linux stub: always returns CpuAffinityError::NotSupported.
#[cfg(not(target_os = "linux"))]
pub fn max_cpu_id() -> Result<usize, CpuAffinityError> {
    Err(CpuAffinityError::NotSupported)
}

// number of online cpus on the system, as reported by online_cpu_info.
// counts online logical cpus (includes hyperthreads)
#[cfg(target_os = "linux")]
pub fn cpu_count() -> Result<usize, CpuAffinityError> {
    online_cpu_info().map(|info| info.count)
}

// non-linux stub: always returns CpuAffinityError::NotSupported.
#[cfg(not(target_os = "linux"))]
pub fn cpu_count() -> Result<usize, CpuAffinityError> {
    Err(CpuAffinityError::NotSupported)
}

// unit tests for the linux implementation: cpu list parsing, argument
// validation in set_cpu_affinity, and thread id retrieval.
#[cfg(test)]
#[cfg(target_os = "linux")]
mod tests {
    use super::{CpuListInfo, ThreadId, current_thread_id, parse_cpu_list, set_cpu_affinity};
    use crate::CpuAffinityError;
    use core::mem::size_of;
    use std::{sync::mpsc, thread};

    // a lone id yields that id as max and a count of one
    #[test]
    fn parses_single_cpu() {
        assert_eq!(
            parse_cpu_list("0"),
            Some(CpuListInfo {
                max_id: 0,
                count: 1
            })
        );
    }

    // an inclusive range counts both of its endpoints
    #[test]
    fn parses_dense_range() {
        assert_eq!(
            parse_cpu_list("0-3"),
            Some(CpuListInfo {
                max_id: 3,
                count: 4
            })
        );
    }

    // with gaps between ranges, count is the sum of the range sizes
    // and max_id is the highest id of any range
    #[test]
    fn parses_sparse_ranges() {
        assert_eq!(
            parse_cpu_list("0-3,8-11"),
            Some(CpuListInfo {
                max_id: 11,
                count: 8
            })
        );
    }

    // descending ranges, empty input, and a trailing comma (which leaves an
    // empty entry) are all rejected
    #[test]
    fn rejects_invalid_ranges() {
        assert_eq!(parse_cpu_list("3-1"), None);
        assert_eq!(parse_cpu_list(""), None);
        assert_eq!(parse_cpu_list("0-3,"), None);
    }

    // an empty cpu list must produce EmptyCpuList rather than an io error
    #[test]
    fn set_cpu_affinity_rejects_empty_cpu_list_before_syscall() {
        assert!(matches!(
            set_cpu_affinity(current_thread_id().unwrap(), []),
            Err(CpuAffinityError::EmptyCpuList)
        ));
    }

    // ThreadId must have the same size as libc::pid_t
    #[test]
    fn thread_id_matches_linux_pid_type() {
        assert_eq!(size_of::<ThreadId>(), size_of::<libc::pid_t>());
    }

    // a successfully retrieved thread id is strictly positive
    #[test]
    fn current_thread_id_is_positive() {
        assert!(current_thread_id().unwrap() > 0);
    }

    // the id is per thread: a spawned thread reports a different id
    // than the thread that spawned it
    #[test]
    fn spawned_thread_has_distinct_thread_id() {
        let main_thread_id = current_thread_id().unwrap();
        let (sender, receiver) = mpsc::channel();

        let worker = thread::spawn(move || {
            sender.send(current_thread_id().unwrap()).unwrap();
        });

        // receive the worker's id before joining
        let worker_thread_id = receiver.recv().unwrap();
        worker.join().unwrap();

        assert_ne!(main_thread_id, worker_thread_id);
    }
}