fd_passing/
functions.rs

1use std::mem::{size_of, size_of_val, zeroed};
2use std::os::unix::io::AsRawFd;
3
4use std::os::unix::net::UnixStream;
5
6use libc::{self, c_void, c_int, cmsghdr, iovec, msghdr, SOL_SOCKET};
7
8#[cfg(not(target_os = "macos"))]
9use libc::__errno_location;
10#[cfg(target_os = "macos")]
11use libc::__error as __errno_location;
12
13use FdImplementor;
14use utils::{compute_bufspace, compute_msglen};
15#[cfg(any(debug, test))]
16use utils::dump_msg;
17
// Foreign helper that reports the platform's `SCM_RIGHTS` control-message type.
// NOTE(review): the C definition is not visible in this file; presumably it is a
// small shim compiled by the build script — confirm there.
extern "C" {
    /// Returns the `cmsg_type` value used to pass file descriptors (`SCM_RIGHTS`).
    #[doc(hidden)]
    // The name mirrors the C constant it exposes, hence the non-snake-case spelling.
    #[allow(non_snake_case)]
    fn get_SCM_RIGHTS() -> ::std::os::raw::c_int;
}
22
/// Adapts a length value to the type of the `msghdr`/`cmsghdr` length fields.
///
/// * `auto_cast!(expr)` — off macOS, yields `expr` unchanged; on macOS, yields
///   `expr as libc::c_uint`.
/// * `auto_cast!(expr, Ty)` — off macOS, yields `expr` unchanged; on macOS, yields
///   `expr as Ty`.
///
/// NOTE(review): only macOS is special-cased. Other targets whose length fields are
/// not the type of the expression passed in would need the cast as well — confirm
/// the set of supported targets.
macro_rules! auto_cast {
    ($right:expr) => {
        // Single-argument form: the macOS cast target defaults to `c_uint`.
        auto_cast!($right, libc::c_uint)
    };
    ($right:expr, $cast:ty) => {{
        #[cfg(target_os = "macos")]
        {
            $right as $cast
        }
        #[cfg(not(target_os = "macos"))]
        {
            $right
        }
    }};
}
45
46// TODO: impl this as a method to FdImplementor or UnixStream.
47pub fn send(channel: &UnixStream, wrapped: FdImplementor) -> Result<(), String> {
48    let (rawfd, mut fdtype) = wrapped.to();
49
50    let mut controlbuf = vec![0u8; compute_bufspace(size_of::<c_int>())];
51    let mut iov : iovec = unsafe { zeroed() };
52    let mut message : msghdr = unsafe { zeroed() };
53
54    iov.iov_base = &mut fdtype as *mut i32 as *mut c_void;
55    iov.iov_len = size_of_val(&fdtype);
56
57    message.msg_control = controlbuf.as_mut_ptr() as *mut c_void;
58    message.msg_controllen = auto_cast!(controlbuf.len());
59    message.msg_iov = &mut iov;
60    message.msg_iovlen = 1;
61
62    unsafe {
63        let controlp : *mut cmsghdr = message.msg_control as *mut cmsghdr;
64        (*controlp).cmsg_level = SOL_SOCKET;
65        (*controlp).cmsg_type = get_SCM_RIGHTS();
66        (*controlp).cmsg_len = auto_cast!(compute_msglen(size_of::<c_int>()));
67
68        let datap : *mut c_int = controlp.offset(1) as *mut c_int; 
69        *datap = rawfd;
70
71        #[cfg(any(debug, test))]
72        dump_msg(&message);
73
74        match libc::sendmsg(channel.as_raw_fd(), &mut message, 0) {
75            x if x == size_of::<c_int>() as isize => Ok(()),
76            -1 => {
77                let s = libc::strerror(*__errno_location());
78                let slen = libc::strlen(s);
79                let serr = String::from_raw_parts(s as *mut u8, slen, slen);
80                let rerr = serr.clone();
81                ::std::mem::forget(serr);
82                Err(rerr)
83            },
84            _  => Err("Incomplete message sent".to_owned()),
85        }
86    }
87}
88
89// TODO: impl this as a method to FdImplementor or UnixStream.
90pub fn receive(channel: &UnixStream) -> Result<FdImplementor, String> {
91    let mut fdtype : c_int = -1;
92    let mut controlbuf = vec![0u8; compute_bufspace(size_of::<c_int>())];
93    let mut iov : iovec = unsafe { zeroed() };
94    let mut message : msghdr = unsafe { zeroed() };
95
96    iov.iov_base = &mut fdtype as *mut i32 as *mut c_void;
97    iov.iov_len = size_of::<c_int>();
98
99    message.msg_control = controlbuf.as_mut_ptr() as *mut c_void;
100    message.msg_controllen = auto_cast!(size_of_val(&controlbuf));
101    message.msg_iov = &mut iov;
102    message.msg_iovlen = 1;
103
104    unsafe {
105        let read = libc::recvmsg(channel.as_raw_fd(), &mut message, 0);
106        match read {
107            x if x == size_of::<c_int>() as isize => {
108                let controlp : *mut cmsghdr =
109                    if message.msg_controllen >= auto_cast!(size_of::<cmsghdr>()) {
110                        message.msg_control as *mut cmsghdr
111                    } else {
112                        ::std::ptr::null_mut()
113                    };
114                // The cmsghdr struct is made so that multiple ones can be chained.
115                // Here we ensure that we only received one, otherwise we fail
116                // explicitly to ensure consistency with the provided server
117                // API (which only sends one int packed with only one cmsghdr struct).
118                if (*controlp).cmsg_level != libc::SOL_SOCKET
119                   || (*controlp).cmsg_type != get_SCM_RIGHTS() {
120                    return Err("Message was not the expected command: format mismatch".to_owned());
121                }
122                if message.msg_controllen > auto_cast!(compute_bufspace(size_of::<c_int>())) {
123                    return Err("Message read was longer than expected: format mismatch".to_owned());
124                }
125                if message.msg_controllen < auto_cast!(compute_bufspace(size_of::<c_int>())) {
126                    return Err("Message read was shorter than expected: format mismatch".to_owned());
127                }
128                let rawfd = *((message.msg_control as *mut cmsghdr).offset(1) as *mut c_int);
129                FdImplementor::from(fdtype, rawfd).ok_or("Unexpected file descriptor type".to_owned())
130            },
131            -1 => {
132                let s = libc::strerror(*__errno_location());
133                let slen = libc::strlen(s);
134                let serr = String::from_raw_parts(s as *mut u8, slen, slen);
135                let rerr = serr.clone();
136                ::std::mem::forget(serr);
137                Err(rerr)
138            },
139            _ => Err("Message data was not of the expected size".to_owned()),
140        }
141    }
142}
143
#[cfg(test)]
mod tests {
    extern crate tempdir;

    use std::io::{Write, Read};
    use std::os::unix::net::{UnixStream, UnixListener};
    use std::fs::{self, remove_file};
    use std::{thread, time};
    use ::{FdImplementor, receive, send};
    use std::path::Path;
    use std::fmt::Debug;
    use self::tempdir::TempDir;

    /// Location of the child's progress log (kept outside the temp dir for debugging).
    const CHILD_LOG: &'static str = "/tmp/rust-fd-passing-child-log.txt";

    /// End-to-end check of `send`/`receive` across a `fork`.
    ///
    /// The parent creates a file, writes the first line, and passes the descriptor to
    /// the child over a Unix socket. The child appends the second line through the
    /// received descriptor. The parent then checks that the file holds both lines.
    #[test]
    fn run() {
        let tmp_dir = TempDir::new("tmp").expect("create temp dir");
        let sockpath = tmp_dir.path().join("rust-fd-passing.test.sock");
        let fpath = tmp_dir.path().join("rust-fd-passing.test.txt");
        let lineone = String::from("Imperio is such an ass!\n");
        let linetwo = String::from("So true...\n");
        let mut text = String::new();
        text.push_str(&lineone);
        text.push_str(&linetwo);
        cleanup();
        match unsafe { ::libc::fork() } {
            -1 => panic!("fork failed"),
            0 => {
                // The parent owns the temp dir and removes it once it is done.
                ::std::mem::forget(tmp_dir);
                // The child must never return into the test harness: turn any panic
                // into an exit status for the parent and leave with `_exit`, so no
                // harness code or inherited destructors run a second time.
                let outcome = ::std::panic::catch_unwind(::std::panic::AssertUnwindSafe(|| {
                    run_child(&sockpath, &linetwo)
                }));
                let code = if outcome.is_ok() { 0 } else { 1 };
                unsafe { ::libc::_exit(code) }
            }
            pid => {
                run_parent(pid, &sockpath, &fpath, &lineone, &text);
                // Seems everything ran as expected, we can remove the logs.
                cleanup();
            }
        }
    }

    /// Child side: serve one connection and append `linetwo` via the received fd.
    fn run_child<S: AsRef<Path>>(sockpath: &S, linetwo: &str) {
        if let Err(e) = run_fd_receiver(sockpath, linetwo) {
            printfile(&format!("Child failed: {}\n", e));
            panic!("{}", e);
        }
    }

    /// Parent side: send the file's fd to `child`, wait for it, then verify the file.
    fn run_parent<S: AsRef<Path> + Debug>(
        child: ::libc::pid_t,
        sockpath: &S,
        fpath: &S,
        lineone: &str,
        text: &str,
    ) {
        // Give the child time to bind the listening socket.
        thread::sleep(time::Duration::new(1, 0));
        if let Err(e) = run_fd_sender(sockpath, fpath, lineone) {
            // The child is still blocked waiting for a client; do not leave it behind.
            unsafe {
                ::libc::kill(child, ::libc::SIGKILL);
            }
            wait_for(child);
            panic!("{}", e);
        }
        // Wait for the child to finish writing rather than sleeping and hoping.
        let status = wait_for(child);
        assert!(status == 0,
                "child process failed (wait status {}), see {}", status, CHILD_LOG);
        let mut f = fs::File::open(fpath).expect(&format!("cannot open {:?}", fpath));
        let mut readstr = String::new();
        let bytes = f.read_to_string(&mut readstr).unwrap();
        assert!(bytes == text.len(), "Resulting data was not of the expected size.");
        assert!(readstr == text, "Resulting data differs from expectations.");
    }

    /// Reaps `pid` and returns its raw wait status (0 means it exited with code 0).
    fn wait_for(pid: ::libc::pid_t) -> ::libc::c_int {
        let mut status: ::libc::c_int = 0;
        let reaped = unsafe { ::libc::waitpid(pid, &mut status, 0) };
        assert!(reaped == pid, "waitpid failed");
        status
    }

    /// Removes the child log from a previous run; a missing file is not an error.
    fn cleanup() {
        let _ = remove_file(CHILD_LOG);
    }

    /// Appends `text` to the child log.
    fn printfile(text: &str) {
        let mut f = fs::OpenOptions::new().append(true).create(true).open(CHILD_LOG)
                                          .expect("printfile failed");
        let written = f.write_all(text.as_bytes());
        assert!(written.is_ok());
    }

    /// Accepts one connection, receives an fd, and writes `text` through it.
    fn run_fd_receiver<S: AsRef<Path>>(sockpath: &S, text: &str) -> Result<bool, String> {
        let listener = UnixListener::bind(sockpath).map_err(|e| format!("IO Error: {}", e))?;
        printfile("Started server\n");
        let (stream, _) = listener.accept().map_err(|e| format!("IO Error: {}", e))?;
        printfile("Accepted client\n");
        match receive(&stream) {
            Ok(FdImplementor::File(mut res)) => {
                printfile("Writing into file\n");
                res.write_all(text.as_bytes())
                   .map_err(|_| "Could not write second data line.")?;
                Ok(true)
            }
            Err(e) => Err(e),
            _ => Err("Did not get the expected FdImplementor type.".to_owned()),
        }
    }

    /// Creates `fpath`, writes `text`, and sends the open file to the server at `sockpath`.
    fn run_fd_sender<S: AsRef<Path> + Debug>(sockpath: &S, fpath: &S, text: &str) -> Result<bool, String> {
        let mut f = fs::File::create(fpath)
            .map_err(|e| format!("Could not create data file {:?}: {}", fpath, e))?;
        f.write_all(text.as_bytes())
            .map_err(|e| format!("Could not write first data line: {}", e))?;
        let stream = UnixStream::connect(sockpath)
            .map_err(|e| format!("cannot connect to unix socket {:?}: {}", sockpath, e))?;
        send(&stream, FdImplementor::File(f))?;
        Ok(true)
    }
}