cstr_argument/
lib.rs

1//! A trait for converting function arguments to null terminated strings.
2#![deny(missing_docs)]
3// #![cfg_attr(any(nightly, feature = "nightly"), feature(specialization))]
4// #[macro_use]
5// extern crate cfg_if;
6extern crate memchr;
7
8use std::borrow::Cow;
9use std::error;
10use std::ffi::{CStr, CString};
11use std::fmt;
12use std::result;
13
14use memchr::memchr;
15
/// Error produced by [`CStrArgument::try_into_cstr`] when the input contains a nul byte
/// somewhere before its final byte.
///
/// The error keeps hold of the original input so that it can be recovered with
/// [`NulError::into_inner`], together with the byte offset of the offending nul.
///
/// [`CStrArgument::try_into_cstr`]: trait.CStrArgument.html#tymethod.try_into_cstr
/// [`NulError::into_inner`]: struct.NulError.html#method.into_inner
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct NulError<T> {
    // The input value that failed to convert.
    inner: T,
    // Byte offset of the nul byte that caused the failure.
    pos: usize,
}
25
26impl<T> NulError<T> {
27    /// Returns the position of the null byte in the string.
28    #[inline]
29    pub fn nul_position(&self) -> usize {
30        self.pos
31    }
32
33    /// Returns the original string.
34    #[inline]
35    pub fn into_inner(self) -> T {
36        self.inner
37    }
38}
39
40impl<T> fmt::Display for NulError<T> {
41    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
42        write!(
43            f,
44            "nul byte found before end of provided data at position: {}",
45            self.pos
46        )
47    }
48}
49
50impl<T: fmt::Debug> error::Error for NulError<T> {
51    fn description(&self) -> &str {
52        "nul byte found before end of data"
53    }
54}
55
56type Result<T, S> = result::Result<T, NulError<S>>;
57
/// Conversion of function arguments into nul-terminated C strings, intended for string
/// arguments that are forwarded to C APIs with as few allocations as possible.
///
/// * Borrowed input that already ends in a nul byte (and has no other nul) is wrapped in a
///   `CStr` without allocating.
/// * Input without a trailing nul is turned into a `CString`, which may allocate.
/// * Input with a nul byte anywhere other than the final position is rejected with a
///   `NulError`.
///
/// # Example
///
/// ```no_run
/// use std::os::raw::c_char;
/// use cstr_argument::CStrArgument;
///
/// extern "C" {
///     fn foo(s: *const c_char);
/// }
///
/// fn bar<S: CStrArgument>(s: S) {
///     let s = s.into_cstr();
///     unsafe {
///         foo(s.as_ref().as_ptr())
///     }
/// }
///
/// fn baz() {
///     bar("hello "); // Argument will be converted to a CString requiring an allocation
///     bar("world\0"); // Argument will be converted to a CStr without any allocations
///     bar("!".to_owned()); // Argument will be converted to a CString possibly requiring an
///                          // allocation
/// }
/// ```
pub trait CStrArgument: fmt::Debug + Sized {
    /// Result type of the conversion; depending on the implementor it either borrows from or
    /// owns the nul-terminated string.
    type Output: AsRef<CStr>;

    /// Converts `self` into a nul-terminated string.
    ///
    /// # Errors
    ///
    /// Fails with a `NulError` (which hands back `self`) when the string contains a nul byte
    /// in any position other than the last one.
    fn try_into_cstr(self) -> Result<Self::Output, Self>;

    /// Converts `self` into a nul-terminated string, panicking on failure.
    ///
    /// # Panics
    ///
    /// Panics when the string contains a nul byte in any position other than the last one.
    /// Use [`try_into_cstr`](#tymethod.try_into_cstr) to handle that case without panicking.
    fn into_cstr(self) -> Self::Output {
        let converted = Self::try_into_cstr(self);
        converted.expect("string contained an interior null byte")
    }
}
114
115// BUG in rustc (#23341)
116// cfg_if! {
117//     if #[cfg(any(nightly, feature = "nightly"))] {
118//         impl<T> CStrArgument for T where Self: AsRef<CStr> {
119//             default type Output = Self;
120//
121//             default fn try_into_cstr(self) -> Result<Self, Self> {
122//                 Ok(self)
123//             }
124//         }
125//
126//         impl<'a, T> CStrArgument for &'a T where Self: AsRef<str> {
127//             default type Output = Cow<'a, CStr>;
128//
129//             default fn try_into_cstr(self) -> Result<Self::Output, Self> {
130//                 self.as_ref().try_into_cstr()
131//             }
132//         }
133//     } else {
134impl<'a> CStrArgument for CString {
135    type Output = Self;
136
137    #[inline]
138    fn try_into_cstr(self) -> Result<Self, Self> {
139        Ok(self)
140    }
141}
142
143impl<'a> CStrArgument for &'a CString {
144    type Output = &'a CStr;
145
146    #[inline]
147    fn try_into_cstr(self) -> Result<Self::Output, Self> {
148        Ok(self)
149    }
150}
151
152impl<'a> CStrArgument for &'a CStr {
153    type Output = Self;
154
155    #[inline]
156    fn try_into_cstr(self) -> Result<Self, Self> {
157        Ok(self)
158    }
159}
160// }
161// }
162
163impl CStrArgument for String {
164    type Output = CString;
165
166    #[inline]
167    fn try_into_cstr(self) -> Result<Self::Output, Self> {
168        self.into_bytes().try_into_cstr().map_err(|e| NulError {
169            inner: unsafe { String::from_utf8_unchecked(e.inner) },
170            pos: e.pos,
171        })
172    }
173}
174
175impl<'a> CStrArgument for &'a String {
176    type Output = Cow<'a, CStr>;
177
178    #[inline]
179    fn try_into_cstr(self) -> Result<Self::Output, Self> {
180        self.as_bytes().try_into_cstr().map_err(|e| NulError {
181            inner: self,
182            pos: e.pos,
183        })
184    }
185}
186
187impl<'a> CStrArgument for &'a str {
188    type Output = Cow<'a, CStr>;
189
190    #[inline]
191    fn try_into_cstr(self) -> Result<Self::Output, Self> {
192        self.as_bytes().try_into_cstr().map_err(|e| NulError {
193            inner: self,
194            pos: e.pos,
195        })
196    }
197}
198
199impl<'a> CStrArgument for Vec<u8> {
200    type Output = CString;
201
202    fn try_into_cstr(mut self) -> Result<Self::Output, Self> {
203        match memchr(0, &self) {
204            Some(n) if n == (self.len() - 1) => {
205                self.pop();
206                Ok(unsafe { CString::from_vec_unchecked(self) })
207            }
208            Some(n) => Err(NulError {
209                inner: self,
210                pos: n,
211            }),
212            None => Ok(unsafe { CString::from_vec_unchecked(self) }),
213        }
214    }
215}
216
217impl<'a> CStrArgument for &'a Vec<u8> {
218    type Output = Cow<'a, CStr>;
219
220    #[inline]
221    fn try_into_cstr(self) -> Result<Self::Output, Self> {
222        self.as_slice().try_into_cstr().map_err(|e| NulError {
223            inner: self,
224            pos: e.pos,
225        })
226    }
227}
228
229impl<'a> CStrArgument for &'a [u8] {
230    type Output = Cow<'a, CStr>;
231
232    fn try_into_cstr(self) -> Result<Self::Output, Self> {
233        match memchr(0, self) {
234            Some(n) if n == (self.len() - 1) => Ok(Cow::Borrowed(unsafe {
235                CStr::from_bytes_with_nul_unchecked(self)
236            })),
237            Some(n) => Err(NulError {
238                inner: self,
239                pos: n,
240            }),
241            None => Ok(Cow::Owned(unsafe {
242                CString::from_vec_unchecked(self.into())
243            })),
244        }
245    }
246}
247
#[cfg(test)]
mod tests {
    use std::ffi::{CStr, CString};

    use super::{CStrArgument, NulError};

    /// Calls `try_into_cstr` on `t` and passes the outcome to `f`, so each case can inspect
    /// either the converted string or the error.
    fn test<T, F, R>(t: T, f: F) -> R
    where
        T: CStrArgument,
        F: FnOnce(Result<T::Output, NulError<T>>) -> R,
    {
        f(t.try_into_cstr())
    }

    #[test]
    fn test_basic() {
        let case = "";
        test(case, |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len() + 1);
            assert_ne!(s.as_ptr() as *const u8, case.as_ptr());
        });

        test(case.to_owned(), |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len() + 1);
        });

        test(case.as_bytes(), |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len() + 1);
            assert_ne!(s.as_ptr() as *const u8, case.as_ptr());
        });

        let case = "hello";
        test(case, |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len() + 1);
            assert_ne!(s.as_ptr() as *const u8, case.as_ptr());
        });

        test(case.to_owned(), |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len() + 1);
        });

        test(case.as_bytes(), |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len() + 1);
            assert_ne!(s.as_ptr() as *const u8, case.as_ptr());
        });
    }

    #[test]
    fn test_terminating_null() {
        let case = "\0";
        test(case, |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len());
            assert_eq!(s.as_ptr() as *const u8, case.as_ptr());
        });

        test(case.to_owned(), |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len());
        });

        test(case.as_bytes(), |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len());
            assert_eq!(s.as_ptr() as *const u8, case.as_ptr());
        });

        let case = "hello\0";
        test(case, |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len());
            assert_eq!(s.as_ptr() as *const u8, case.as_ptr());
        });

        test(case.to_owned(), |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len());
        });

        test(case.as_bytes(), |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), case.len());
            assert_eq!(s.as_ptr() as *const u8, case.as_ptr());
        });
    }

    #[test]
    fn test_interior_null() {
        let case = "hello\0world";
        test(case, |s| s.unwrap_err());
        test(case.to_owned(), |s| s.unwrap_err());
        test(case.as_bytes(), |s| s.unwrap_err());
    }

    #[test]
    fn test_interior_and_terminating_null() {
        let case = "\0\0";
        test(case, |s| s.unwrap_err());
        test(case.to_owned(), |s| s.unwrap_err());
        test(case.as_bytes(), |s| s.unwrap_err());

        let case = "hello\0world\0";
        test(case, |s| s.unwrap_err());
        test(case.to_owned(), |s| s.unwrap_err());
        test(case.as_bytes(), |s| s.unwrap_err());

        let case = "hello world\0\0";
        test(case, |s| s.unwrap_err());
        test(case.to_owned(), |s| s.unwrap_err());
        test(case.as_bytes(), |s| s.unwrap_err());
    }

    #[test]
    fn test_nul_position() {
        // The reported position is the offset of the first nul byte in the input.
        let cases = [("hello\0world", 5), ("\0\0", 0), ("hello world\0\0", 11)];
        for &(case, pos) in cases.iter() {
            test(case, |s| assert_eq!(s.unwrap_err().nul_position(), pos));
            test(case.to_owned(), |s| assert_eq!(s.unwrap_err().nul_position(), pos));
            test(case.as_bytes(), |s| assert_eq!(s.unwrap_err().nul_position(), pos));
            test(case.as_bytes().to_vec(), |s| {
                assert_eq!(s.unwrap_err().nul_position(), pos)
            });
        }
    }

    #[test]
    fn test_into_inner() {
        // A failed conversion hands back the original input untouched.
        let case = "hello\0world";
        test(case, |s| assert_eq!(s.unwrap_err().into_inner(), case));
        test(case.to_owned(), |s| assert_eq!(s.unwrap_err().into_inner(), case));
        test(case.as_bytes(), |s| {
            assert_eq!(s.unwrap_err().into_inner(), case.as_bytes())
        });
        test(case.as_bytes().to_vec(), |s| {
            assert_eq!(s.unwrap_err().into_inner(), case.as_bytes().to_vec())
        });
    }

    #[test]
    fn test_reference_and_owned_impls() {
        // `&String` with a trailing nul borrows without allocating.
        let owned = String::from("hello\0");
        test(&owned, |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), owned.len());
            assert_eq!(s.as_ptr() as *const u8, owned.as_ptr());
        });

        // `&Vec<u8>` with a trailing nul borrows without allocating.
        let bytes = b"hello\0".to_vec();
        test(&bytes, |s| {
            let s = s.unwrap();
            assert_eq!(s.to_bytes_with_nul().len(), bytes.len());
            assert_eq!(s.as_ptr() as *const u8, bytes.as_ptr());
        });

        // `Vec<u8>` without a nul becomes a `CString` with the same contents.
        let bytes = b"hello".to_vec();
        test(bytes.clone(), |s| assert_eq!(s.unwrap().as_bytes(), &bytes[..]));

        // `Vec<u8>` with a trailing nul keeps only the contents before it.
        test(b"hello\0".to_vec(), |s| assert_eq!(s.unwrap().as_bytes(), &b"hello"[..]));

        let cstring = CString::new("hello").unwrap();
        test(&cstring, |s| {
            let s: &CStr = s.unwrap();
            assert_eq!(s.as_ptr(), cstring.as_ptr());
        });
        {
            let cstr: &CStr = &cstring;
            test(cstr, |s| assert_eq!(s.unwrap().as_ptr(), cstr.as_ptr()));
        }

        // An owned `CString` is passed through without reallocation.
        let ptr = cstring.as_ptr();
        test(cstring, |s| {
            let out = s.unwrap();
            assert_eq!(out.as_ptr(), ptr);
        });
    }

    #[test]
    #[should_panic(expected = "string contained an interior null byte")]
    fn test_interior_null_panic() {
        "\0\0".into_cstr();
    }
}