sandbox_rs/
utils.rs

1//! Utility functions for sandbox operations
2
3use crate::errors::{Result, SandboxError};
4#[cfg(test)]
5use std::cell::Cell;
6use std::path::Path;
7
#[cfg(test)]
thread_local! {
    // Test-only, per-thread answer that `is_root()` returns instead of
    // querying the real effective UID. `None` (the initial value) means
    // "no override"; tests change it through `set_root_override`.
    static ROOT_OVERRIDE: Cell<Option<bool>> = const { Cell::new(None) };
}
12
13/// Check if running as root
14pub fn is_root() -> bool {
15    #[cfg(test)]
16    {
17        if let Some(value) = ROOT_OVERRIDE.with(|cell| cell.get()) {
18            return value;
19        }
20    }
21
22    unsafe { libc::geteuid() == 0 }
23}
24
25/// Get current UID
26pub fn get_uid() -> u32 {
27    unsafe { libc::geteuid() }
28}
29
30/// Get current GID
31pub fn get_gid() -> u32 {
32    unsafe { libc::getegid() }
33}
34
35/// Ensure we have root privileges
36pub fn require_root() -> Result<()> {
37    if !is_root() {
38        Err(SandboxError::PermissionDenied(
39            "This operation requires root privileges".to_string(),
40        ))
41    } else {
42        Ok(())
43    }
44}
45
/// Reports whether cgroup v2 appears to be available on this host.
///
/// Detection is a single existence check on
/// `/sys/fs/cgroup/cgroup.controllers`; no controller list is read.
pub fn has_cgroup_v2() -> bool {
    let controllers = Path::new("/sys/fs/cgroup/cgroup.controllers");
    controllers.exists()
}
50
/// Reports whether `path` refers to an existing filesystem entry.
///
/// Any failure to stat the path (missing entry, broken symlink, permission
/// error) is reported as `false`.
pub fn cgroup_exists(path: &Path) -> bool {
    // Same check `Path::exists` performs: the path exists iff its metadata
    // can be queried.
    std::fs::metadata(path).is_ok()
}
55
56/// Parse memory size string (e.g., "100M", "1G")
57pub fn parse_memory_size(s: &str) -> Result<u64> {
58    let s = s.trim().to_uppercase();
59
60    let (num_str, multiplier) = if s.ends_with("G") {
61        (&s[..s.len() - 1], 1024u64 * 1024 * 1024)
62    } else if s.ends_with("M") {
63        (&s[..s.len() - 1], 1024u64 * 1024)
64    } else if s.ends_with("K") {
65        (&s[..s.len() - 1], 1024u64)
66    } else if s.ends_with("B") {
67        (&s[..s.len() - 1], 1u64)
68    } else {
69        (s.as_str(), 1u64)
70    };
71
72    let num: u64 = num_str
73        .parse()
74        .map_err(|_| SandboxError::InvalidConfig(format!("Invalid memory size: {}", s)))?;
75
76    num.checked_mul(multiplier)
77        .ok_or_else(|| SandboxError::InvalidConfig(format!("Memory size overflow: {}", s)))
78}
79
/// Installs (or clears, with `None`) the calling thread's override for the
/// answer returned by `is_root`. Only compiled in test builds.
#[cfg(test)]
pub fn set_root_override(value: Option<bool>) {
    ROOT_OVERRIDE.with(|slot| {
        slot.set(value);
    });
}
84
#[cfg(test)]
mod tests {
    use super::*;

    /// Clears the calling thread's root override when dropped, so a failing
    /// assertion cannot leave the override installed.
    struct OverrideReset;

    impl Drop for OverrideReset {
        fn drop(&mut self) {
            set_root_override(None);
        }
    }

    #[test]
    fn test_parse_memory_size_bytes() {
        assert_eq!(parse_memory_size("100").unwrap(), 100);
        assert_eq!(parse_memory_size("100B").unwrap(), 100);
    }

    #[test]
    fn test_parse_memory_size_kilobytes() {
        assert_eq!(parse_memory_size("1K").unwrap(), 1024);
        assert_eq!(parse_memory_size("10K").unwrap(), 10 * 1024);
    }

    #[test]
    fn test_parse_memory_size_megabytes() {
        assert_eq!(parse_memory_size("1M").unwrap(), 1024 * 1024);
        assert_eq!(parse_memory_size("100M").unwrap(), 100 * 1024 * 1024);
    }

    #[test]
    fn test_parse_memory_size_gigabytes() {
        assert_eq!(parse_memory_size("1G").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(parse_memory_size("2G").unwrap(), 2 * 1024 * 1024 * 1024);
    }

    #[test]
    fn test_parse_memory_size_case_insensitive() {
        assert_eq!(parse_memory_size("1m").unwrap(), 1024 * 1024);
        assert_eq!(parse_memory_size("1g").unwrap(), 1024 * 1024 * 1024);
    }

    #[test]
    fn test_parse_memory_size_whitespace() {
        assert_eq!(parse_memory_size("  100M  ").unwrap(), 100 * 1024 * 1024);
    }

    #[test]
    fn test_parse_memory_size_invalid() {
        assert!(parse_memory_size("not_a_number").is_err());
        assert!(parse_memory_size("10X").is_err());
    }

    #[test]
    fn test_parse_memory_size_missing_number() {
        assert!(parse_memory_size("").is_err());
        assert!(parse_memory_size("   ").is_err());
        // A unit with no digits in front of it is not a size.
        assert!(parse_memory_size("M").is_err());
    }

    #[test]
    fn test_parse_memory_size_overflow() {
        // u64::MAX kibibytes cannot be expressed as a u64 byte count.
        assert!(parse_memory_size("18446744073709551615K").is_err());
        // More digits than a u64 can hold, even before scaling.
        assert!(parse_memory_size("99999999999999999999").is_err());
    }

    #[test]
    fn test_get_uid_gid() {
        let uid = get_uid();
        let gid = get_gid();
        assert!(uid < u32::MAX);
        assert!(gid < u32::MAX);
    }

    #[test]
    fn test_is_root() {
        // No override is installed on this test's thread, so `is_root` must
        // agree with the real effective UID.
        let is_root = is_root();
        assert_eq!(is_root, get_uid() == 0);
    }

    #[test]
    fn test_root_override() {
        let _reset = OverrideReset;
        set_root_override(Some(true));
        assert!(is_root());
        set_root_override(Some(false));
        assert!(!is_root());
    }

    #[test]
    fn test_require_root_follows_override() {
        let _reset = OverrideReset;
        set_root_override(Some(true));
        assert!(require_root().is_ok());
        set_root_override(Some(false));
        assert!(matches!(
            require_root(),
            Err(SandboxError::PermissionDenied(_))
        ));
    }

    #[test]
    fn test_has_cgroup_v2() {
        use std::path::Path;
        // The result is host-dependent, so pin it to the probe it is
        // supposed to perform rather than to a fixed value.
        assert_eq!(
            has_cgroup_v2(),
            Path::new("/sys/fs/cgroup/cgroup.controllers").exists()
        );
    }

    #[test]
    fn test_cgroup_exists() {
        use std::path::Path;
        assert!(cgroup_exists(Path::new("/")));
        assert!(!cgroup_exists(Path::new(
            "/nonexistent/path/that/should/not/exist"
        )));
    }
}