sandbox_rs/
utils.rs

1//! Utility functions for sandbox operations
2
3use crate::errors::{Result, SandboxError};
4#[cfg(test)]
5use std::cell::Cell;
6use std::path::Path;
7
#[cfg(test)]
thread_local! {
    /// Test-only, per-thread override consulted by `is_root`.
    ///
    /// `None` (the initial value) means "use the real effective UID";
    /// `Some(v)` makes `is_root` return `v` on this thread. Written by
    /// `set_root_override`.
    static ROOT_OVERRIDE: Cell<Option<bool>> = const { Cell::new(None) };
}
12
13/// Check if running as root
14pub fn is_root() -> bool {
15    #[cfg(test)]
16    {
17        if let Some(value) = ROOT_OVERRIDE.with(|cell| cell.get()) {
18            return value;
19        }
20    }
21
22    unsafe { libc::geteuid() == 0 }
23}
24
25/// Get current UID
26pub fn get_uid() -> u32 {
27    unsafe { libc::geteuid() }
28}
29
30/// Get current GID
31pub fn get_gid() -> u32 {
32    unsafe { libc::getegid() }
33}
34
35/// Ensure we have root privileges
36pub fn require_root() -> Result<()> {
37    if !is_root() {
38        Err(SandboxError::PermissionDenied(
39            "This operation requires root privileges".to_string(),
40        ))
41    } else {
42        Ok(())
43    }
44}
45
/// Report whether a cgroup v2 hierarchy is available at `/sys/fs/cgroup`.
///
/// The probe is the root-level `cgroup.controllers` file exposed by the
/// cgroup v2 interface. Only this exact path is checked; a v2 hierarchy
/// mounted elsewhere is not detected. Any error while querying the path
/// (including permission errors) yields `false`.
pub fn has_cgroup_v2() -> bool {
    let controllers = Path::new("/sys/fs/cgroup/cgroup.controllers");
    controllers.exists()
}
50
/// Return `true` if anything exists at `path`.
///
/// Symbolic links are followed, and any error while reading metadata
/// (including permission errors) is reported as `false`, exactly as
/// [`Path::exists`] does. No check is made that the path actually lives on a
/// cgroup filesystem.
pub fn cgroup_exists(path: &Path) -> bool {
    Path::exists(path)
}
55
56/// Parse memory size string (e.g., "100M", "1G")
57pub fn parse_memory_size(s: &str) -> Result<u64> {
58    let s = s.trim().to_uppercase();
59
60    let (num_str, multiplier) = if s.ends_with("G") {
61        (&s[..s.len() - 1], 1024u64 * 1024 * 1024)
62    } else if s.ends_with("M") {
63        (&s[..s.len() - 1], 1024u64 * 1024)
64    } else if s.ends_with("K") {
65        (&s[..s.len() - 1], 1024u64)
66    } else if s.ends_with("B") {
67        (&s[..s.len() - 1], 1u64)
68    } else {
69        (s.as_str(), 1u64)
70    };
71
72    let num: u64 = num_str
73        .parse()
74        .map_err(|_| SandboxError::InvalidConfig(format!("Invalid memory size: {}", s)))?;
75
76    Ok(num * multiplier)
77}
78
/// Test-only hook that forces `is_root` to a fixed answer on the current thread.
///
/// `Some(v)` makes `is_root()` return `v`; `None` removes the override so the
/// real effective UID is consulted again. Other threads are unaffected.
#[cfg(test)]
pub fn set_root_override(value: Option<bool>) {
    ROOT_OVERRIDE.with(|slot| slot.set(value));
}
83
#[cfg(test)]
mod tests {
    use super::*;

    /// Clears the thread-local root override when dropped, so a failed
    /// assertion in an override test cannot leave a forced value behind.
    struct RootOverrideGuard;

    impl Drop for RootOverrideGuard {
        fn drop(&mut self) {
            set_root_override(None);
        }
    }

    #[test]
    fn test_parse_memory_size_bytes() {
        assert_eq!(parse_memory_size("100").unwrap(), 100);
        assert_eq!(parse_memory_size("100B").unwrap(), 100);
    }

    #[test]
    fn test_parse_memory_size_kilobytes() {
        assert_eq!(parse_memory_size("1K").unwrap(), 1024);
        assert_eq!(parse_memory_size("10K").unwrap(), 10 * 1024);
    }

    #[test]
    fn test_parse_memory_size_megabytes() {
        assert_eq!(parse_memory_size("1M").unwrap(), 1024 * 1024);
        assert_eq!(parse_memory_size("100M").unwrap(), 100 * 1024 * 1024);
    }

    #[test]
    fn test_parse_memory_size_gigabytes() {
        assert_eq!(parse_memory_size("1G").unwrap(), 1024 * 1024 * 1024);
        assert_eq!(parse_memory_size("2G").unwrap(), 2 * 1024 * 1024 * 1024);
    }

    #[test]
    fn test_parse_memory_size_case_insensitive() {
        assert_eq!(parse_memory_size("1m").unwrap(), 1024 * 1024);
        assert_eq!(parse_memory_size("1g").unwrap(), 1024 * 1024 * 1024);
    }

    #[test]
    fn test_parse_memory_size_whitespace() {
        assert_eq!(parse_memory_size("  100M  ").unwrap(), 100 * 1024 * 1024);
    }

    #[test]
    fn test_parse_memory_size_invalid() {
        assert!(parse_memory_size("not_a_number").is_err());
        assert!(parse_memory_size("10X").is_err());
        assert!(parse_memory_size("1.5G").is_err());
        assert!(parse_memory_size("-1M").is_err());
    }

    #[test]
    fn test_parse_memory_size_empty() {
        // A bare suffix or blank input leaves nothing to parse as a number.
        assert!(parse_memory_size("").is_err());
        assert!(parse_memory_size("   ").is_err());
        assert!(parse_memory_size("M").is_err());
    }

    #[test]
    fn test_get_uid_gid() {
        let uid = get_uid();
        let gid = get_gid();
        assert!(uid < u32::MAX);
        assert!(gid < u32::MAX);
    }

    #[test]
    fn test_is_root() {
        let is_root = is_root();
        assert_eq!(is_root, get_uid() == 0);
    }

    #[test]
    fn test_root_override() {
        let _guard = RootOverrideGuard;
        set_root_override(Some(true));
        assert!(is_root());
        set_root_override(Some(false));
        assert!(!is_root());
        set_root_override(None);
        assert_eq!(is_root(), get_uid() == 0);
    }

    #[test]
    fn test_require_root_follows_override() {
        let _guard = RootOverrideGuard;
        set_root_override(Some(true));
        assert!(require_root().is_ok());
        set_root_override(Some(false));
        assert!(matches!(
            require_root(),
            Err(SandboxError::PermissionDenied(_))
        ));
    }

    #[test]
    fn test_has_cgroup_v2() {
        use std::path::Path;
        // The previous version matched `true | false` and so asserted nothing;
        // compare against the probe the function is documented to perform.
        let expected = Path::new("/sys/fs/cgroup/cgroup.controllers").exists();
        assert_eq!(has_cgroup_v2(), expected);
    }

    #[test]
    fn test_cgroup_exists() {
        use std::path::Path;
        assert!(cgroup_exists(Path::new("/")));
        assert!(!cgroup_exists(Path::new(
            "/nonexistent/path/that/should/not/exist"
        )));
    }
}