daemonize_me/
ffi.rs

1#![deny(warnings)]
2#![allow(unsafe_code)]
3extern crate libc;
4
5use std::ffi::{CStr, CString, OsStr, OsString};
6use std::os::unix::ffi::OsStrExt;
7
8#[cfg(target_os = "linux")]
9use {
10    crate::DaemonError::{GetPasswdRecord, SetProcName},
11    libc::{PR_SET_NAME, prctl},
12};
13
14use crate::{DaemonError, Result};
15#[cfg(not(target_os = "linux"))]
16use crate::DaemonError::{GetPasswdRecord, SetProcName, UnsupportedOnOS};
17use crate::DaemonError::InvalidProcName;
18
/// C-layout mirror of the group record that `getgrnam`/`getgrgid` return a pointer to.
///
/// Values of this type are never constructed in Rust; the struct is only read through
/// the raw pointers handed back by the `extern "C"` lookups declared in this file.
#[repr(C)]
#[allow(dead_code)]
struct FFIGroup {
    /// Group name as a C string; copied into `GroupRecord::gr_name`.
    gr_name: *const libc::c_char,
    /// Group password field as a C string; copied into `GroupRecord::gr_passwd`.
    gr_passwd: *const libc::c_char,
    /// Numeric group id; copied into `GroupRecord::gr_gid`.
    gr_gid: libc::gid_t,
    /// Pointer to an array of C strings. Not read anywhere in this file; it is declared
    /// so the field list matches the C struct.
    gr_mem: *const *const libc::c_char,
}
27
/// C-layout mirror of the passwd record that `getpwnam`/`getpwuid` return a pointer to
/// (Linux field list).
///
/// Values of this type are never constructed in Rust; the struct is only read through
/// the raw pointers handed back by the `extern "C"` lookups declared in this file.
#[cfg(target_os = "linux")]
#[repr(C)]
#[allow(dead_code)]
struct FFIPasswd {
    /// User name as a C string; copied into `PasswdRecord::pw_name`.
    pw_name: *const libc::c_char,
    /// Password field as a C string; copied into `PasswdRecord::pw_passwd`.
    pw_passwd: *const libc::c_char,
    /// Numeric user id; copied into `PasswdRecord::pw_uid`.
    pw_uid: libc::uid_t,
    /// Numeric primary group id; copied into `PasswdRecord::pw_gid`.
    pw_gid: libc::gid_t,
    /// GECOS field as a C string; copied into `PasswdRecord::pw_gecos`.
    pw_gecos: *const libc::c_char,
    /// Home directory as a C string; copied into `PasswdRecord::pw_dir`.
    pw_dir: *const libc::c_char,
    /// Login shell as a C string; copied into `PasswdRecord::pw_shell`.
    pw_shell: *const libc::c_char,
}
40
// Used on macOS and FreeBSD
/// C-layout mirror of the passwd record that `getpwnam`/`getpwuid` return a pointer to
/// (macOS / FreeBSD field list).
///
/// Both systems share every member up to and including `pw_expire`; FreeBSD appends one
/// extra trailing member, which is gated per-field below so the declared layout matches
/// the platform's real `struct passwd` on each target.
#[cfg(any(target_os = "macos", target_os = "freebsd"))]
#[repr(C)]
#[allow(dead_code)]
struct FFIPasswd {
    /// User name as a C string; copied into `PasswdRecord::pw_name`.
    pw_name: *const libc::c_char,
    /// Password field as a C string; copied into `PasswdRecord::pw_passwd`.
    pw_passwd: *const libc::c_char,
    /// Numeric user id; copied into `PasswdRecord::pw_uid`.
    pw_uid: libc::uid_t,
    /// Numeric primary group id; copied into `PasswdRecord::pw_gid`.
    pw_gid: libc::gid_t,
    /// Not read in this file; present to keep the following fields at the right offsets.
    pw_change: libc::time_t,
    /// Not read in this file; present to keep the following fields at the right offsets.
    pw_class: *const libc::c_char,
    /// GECOS field as a C string; copied into `PasswdRecord::pw_gecos`.
    pw_gecos: *const libc::c_char,
    /// Home directory as a C string; copied into `PasswdRecord::pw_dir`.
    pw_dir: *const libc::c_char,
    /// Login shell as a C string; copied into `PasswdRecord::pw_shell`.
    pw_shell: *const libc::c_char,
    /// Not read in this file.
    pw_expire: libc::time_t,
    /// FreeBSD-only trailing member. Darwin's `struct passwd` ends at `pw_expire`, so
    /// declaring this field there would make the Rust type larger than the C object a
    /// returned pointer actually refers to.
    #[cfg(target_os = "freebsd")]
    pw_fields: libc::c_int,
}
58
// Used on the other two supported BSDs
/// C-layout mirror of the passwd record that `getpwnam`/`getpwuid` return a pointer to
/// (OpenBSD / NetBSD field list).
///
/// Values of this type are never constructed in Rust; the struct is only read through
/// the raw pointers handed back by the `extern "C"` lookups declared in this file.
#[cfg(any(target_os = "openbsd", target_os = "netbsd"))]
#[repr(C)]
#[allow(dead_code)]
struct FFIPasswd {
    /// User name as a C string; copied into `PasswdRecord::pw_name`.
    pw_name: *const libc::c_char,
    /// Password field as a C string; copied into `PasswdRecord::pw_passwd`.
    pw_passwd: *const libc::c_char,
    /// Numeric user id; copied into `PasswdRecord::pw_uid`.
    pw_uid: libc::uid_t,
    /// Numeric primary group id; copied into `PasswdRecord::pw_gid`.
    pw_gid: libc::gid_t,
    /// Not read in this file; present to keep the following fields at the right offsets.
    pw_change: libc::time_t,
    /// Not read in this file; present to keep the following fields at the right offsets.
    pw_class: *const libc::c_char,
    /// GECOS field as a C string; copied into `PasswdRecord::pw_gecos`.
    pw_gecos: *const libc::c_char,
    /// Home directory as a C string; copied into `PasswdRecord::pw_dir`.
    pw_dir: *const libc::c_char,
    /// Login shell as a C string; copied into `PasswdRecord::pw_shell`.
    pw_shell: *const libc::c_char,
    /// Not read in this file.
    pw_expire: libc::time_t,
}
75
76#[allow(dead_code)]
77extern "C" {
78    fn getgrnam(name: *const libc::c_char) -> *const FFIGroup;
79    fn getgrgid(name: libc::gid_t) -> *const FFIGroup;
80    fn getpwnam(name: *const libc::c_char) -> *const FFIPasswd;
81    fn getpwuid(name: libc::uid_t) -> *const FFIPasswd;
82}
83
/// Owned, safe copy of a group database entry.
///
/// Produced by `GroupRecord::lookup_record_by_name` / `lookup_record_by_id`. The string
/// fields are converted from the C strings with a lossy UTF-8 conversion, so they may
/// contain U+FFFD replacement characters if the source bytes were not valid UTF-8.
#[derive(Debug)]
#[allow(dead_code)]
pub struct GroupRecord {
    /// Group name.
    pub gr_name: String,
    /// Contents of the group's password field.
    pub gr_passwd: String,
    /// Numeric group id.
    pub gr_gid: u32,
}
91
/// Owned, safe copy of a passwd database entry.
///
/// Produced by `PasswdRecord::lookup_record_by_name` / `lookup_record_by_id`. The string
/// fields are converted from the C strings with a lossy UTF-8 conversion, so they may
/// contain U+FFFD replacement characters if the source bytes were not valid UTF-8.
#[derive(Debug)]
#[allow(dead_code)]
pub struct PasswdRecord {
    /// User name.
    pub pw_name: String,
    /// Contents of the entry's password field.
    pub pw_passwd: String,
    /// Numeric user id.
    pub pw_uid: u32,
    /// Numeric primary group id.
    pub pw_gid: u32,
    /// Contents of the GECOS field.
    pub pw_gecos: String,
    /// Home directory.
    pub pw_dir: String,
    /// Login shell.
    pub pw_shell: String,
}
103
104unsafe fn check_group_record(grp: *const FFIGroup) -> Result<GroupRecord> {
105    return if grp.is_null() {
106        Err(DaemonError::GetGrRecord)
107    } else {
108        let gr = &*grp;
109        let sgr = GroupRecord {
110            gr_name: CStr::from_ptr(gr.gr_name).to_string_lossy().to_string(),
111            gr_passwd: CStr::from_ptr(gr.gr_passwd).to_string_lossy().to_string(),
112            gr_gid: gr.gr_gid as u32,
113        };
114        Ok(sgr)
115    };
116}
117
118unsafe fn check_passwd_record(passwd: *const FFIPasswd) -> Result<PasswdRecord> {
119    return if passwd.is_null() {
120        Err(GetPasswdRecord)
121    } else {
122        let pw = &*passwd;
123        let pwr = PasswdRecord {
124            pw_name: CStr::from_ptr(pw.pw_name).to_string_lossy().to_string(),
125            pw_passwd: CStr::from_ptr(pw.pw_passwd).to_string_lossy().to_string(),
126            pw_uid: pw.pw_uid as u32,
127            pw_gid: pw.pw_gid as u32,
128            pw_gecos: CStr::from_ptr(pw.pw_gecos).to_string_lossy().to_string(),
129            pw_dir: CStr::from_ptr(pw.pw_dir).to_string_lossy().to_string(),
130            pw_shell: CStr::from_ptr(pw.pw_shell).to_string_lossy().to_string(),
131        };
132        Ok(pwr)
133    };
134}
135
136#[allow(dead_code)]
137impl GroupRecord {
138    pub fn lookup_record_by_name(name: &str) -> Result<GroupRecord> {
139        let record_name = match CString::new(name) {
140            Ok(s) => s,
141            Err(_) => return Err(DaemonError::InvalidCstr),
142        };
143
144        unsafe {
145            let raw_grp = getgrnam(record_name.as_ptr());
146            return check_group_record(raw_grp);
147        };
148    }
149
150    pub fn lookup_record_by_id(gid: u32) -> Result<GroupRecord> {
151        let record_id = gid as libc::uid_t;
152
153        unsafe {
154            let raw_grp = getgrgid(record_id);
155            return check_group_record(raw_grp);
156        };
157    }
158}
159
160impl PasswdRecord {
161    pub fn lookup_record_by_name(name: &str) -> Result<PasswdRecord> {
162        let record_name = match CString::new(name) {
163            Ok(s) => s,
164            Err(_) => return Err(DaemonError::InvalidCstr),
165        };
166
167        unsafe {
168            let raw_passwd = getpwnam(record_name.as_ptr());
169            return check_passwd_record(raw_passwd);
170        };
171    }
172
173    pub fn lookup_record_by_id(uid: u32) -> Result<PasswdRecord> {
174        let record_id = uid as libc::uid_t;
175
176        unsafe {
177            let raw_passwd = getpwuid(record_id);
178            return check_passwd_record(raw_passwd);
179        };
180    }
181}
182
183#[cfg(target_os = "linux")]
184/// Safe wrapper to the prctl(2) call
185pub fn set_proc_name(name: &OsStr) -> Result<()> {
186    let name_truncated = match CString::new(OsString::from(name).as_bytes()) {
187        Ok(procname) => procname,
188        Err(_) => return Err(InvalidProcName),
189    };
190    unsafe {
191        if prctl(PR_SET_NAME, name_truncated.as_bytes_with_nul()) < 0 {
192            Err(SetProcName)
193        } else {
194            Ok(())
195        }
196    }
197}
198
// TODO: Implement this for non linux targets
#[cfg(not(target_os = "linux"))]
/// Placeholder for targets other than Linux: setting the process name is not
/// implemented there.
///
/// # Errors
///
/// Always returns `UnsupportedOnOS`.
pub fn set_proc_name(name: &OsStr) -> Result<()> {
    // Explicitly consume the argument: an unused parameter triggers `unused_variables`,
    // which this file's `#![deny(warnings)]` turns into a hard compile error.
    let _ = name;
    Err(UnsupportedOnOS)
}
204
#[cfg(test)]
mod tests {
    // TODO: Improve testing because of unsafe code
    use super::*;

    /// Name of the group with gid 0 on the host. Linux calls it "root"; the BSD-derived
    /// targets this file supports (macOS, FreeBSD, OpenBSD, NetBSD) call it "wheel".
    #[cfg(target_os = "linux")]
    const ROOT_GROUP: &str = "root";
    #[cfg(not(target_os = "linux"))]
    const ROOT_GROUP: &str = "wheel";

    #[test]
    /// Asserts that the uid returned for the user name "root" is 0
    fn test_passwd_by_name() {
        let root = PasswdRecord::lookup_record_by_name("root").unwrap();
        assert_eq!(root.pw_uid, 0)
    }

    #[test]
    /// Asserts that the user name returned for uid 0 is "root"
    fn test_passwd_by_uid() {
        let root = PasswdRecord::lookup_record_by_id(0).unwrap();
        assert_eq!(root.pw_name, "root")
    }

    #[test]
    /// Asserts that the gid returned for the platform's root group name is 0
    fn test_gr_by_name() {
        let root = GroupRecord::lookup_record_by_name(ROOT_GROUP).unwrap();
        assert_eq!(root.gr_gid, 0)
    }

    #[test]
    /// Asserts that the group name returned for gid 0 is the platform's root group name
    fn test_gr_by_gid() {
        let root = GroupRecord::lookup_record_by_id(0).unwrap();
        assert_eq!(root.gr_name, ROOT_GROUP)
    }
}
237}