Skip to main content

aranya_capi_core/
cstr.rs

1use core::{
2    cmp,
3    ffi::c_char,
4    fmt::{self, Write},
5    mem::MaybeUninit,
6    ptr,
7};
8
9use buggy::{Bug, BugExt as _};
10
/// The error returned by [`write_c_str`].
///
/// [`write_c_str`] produces [`BufferTooSmall`][Self::BufferTooSmall] when
/// the destination cannot hold the formatted text plus its null
/// terminator, and [`Bug`][Self::Bug] when the formatting step itself
/// reports a failure.
#[derive(Clone, Debug, Eq, PartialEq, thiserror::Error)]
pub enum WriteCStrError {
    /// An internal bug was discovered.
    ///
    /// The `#[from]` attribute lets `?` convert a [`Bug`] into this
    /// variant, and `#[error(transparent)]` defers the error message to
    /// the wrapped [`Bug`].
    #[error(transparent)]
    Bug(#[from] Bug),
    /// The provided buffer is too small.
    ///
    /// When [`write_c_str`] returns this variant, its `nw` argument holds
    /// the buffer length that would have been sufficient.
    #[error("buffer is too small")]
    BufferTooSmall,
}
21
22/// Writes `src` as a null-terminated C string to `dst`.
23///
24/// If `dst` is long enough to fit the entirety of `src`, it
25/// updates `n` with the number of bytes written, less the null
26/// terminator and returns `Ok(())`.
27///
28/// Otherwise, if `dst` is not long enough to contain the
29/// entirety of `src`, it updates `n` to the number of bytes
30/// needed to fit the entirety of `src` and returns
31/// [`Err(WriteCStrError::BufferTooSmall)`][WriteCStrError::BufferTooSmall].
32pub fn write_c_str<T: fmt::Display>(
33    dst: &mut [MaybeUninit<c_char>],
34    src: &T,
35    nw: &mut usize,
36) -> Result<(), WriteCStrError> {
37    let mut w = CStrWriter::new(dst, nw);
38    write!(&mut w, "{src:}").assume("`write!` to `Writer` should not fail")?;
39    w.finish().map_err(|()| WriteCStrError::BufferTooSmall)
40}
41
/// Implements [`Write`] for a fixed-size C string buffer.
///
/// The final element of `dst` is reserved for the null terminator, so at
/// most `dst.len() - 1` bytes of text are ever copied. A chunk that does
/// not fit is counted but not copied, which lets `finish` report how large
/// the buffer would have needed to be.
struct CStrWriter<'a> {
    /// Destination buffer, including the slot for the null terminator.
    dst: &'a mut [MaybeUninit<c_char>],
    /// Number of bytes written so far or, once a chunk has failed to fit,
    /// the number of bytes that would have been written.
    nw: &'a mut usize,
}

impl<'a> CStrWriter<'a> {
    /// Creates a writer over `dst` and resets the byte count in `nw` to
    /// zero.
    fn new(dst: &'a mut [MaybeUninit<c_char>], nw: &'a mut usize) -> Self {
        *nw = 0;
        Self { dst, nw }
    }

    /// Appends the bytes of `s` after the bytes written so far.
    ///
    /// The byte count always advances by `s.len()` (saturating). The bytes
    /// are only copied when the whole chunk fits in front of the reserved
    /// terminator slot. When it does not fit, the bytes copied so far are
    /// null-terminated so that `dst` never holds an unterminated string,
    /// even though the caller is going to be told the buffer is too small.
    fn write(&mut self, s: &str) {
        // TODO(eric): what if `s` contains a null byte?
        // (Interior null bytes are currently copied verbatim.)
        let src = s.as_bytes();
        if src.is_empty() {
            return;
        }

        let start = *self.nw;
        let end = start.saturating_add(src.len());
        // The count advances whether or not this chunk fits.
        *self.nw = end;

        let Some(dst) = self
            .dst
            .split_last_mut() // chop off the null terminator.
            .and_then(|(_, body)| body.get_mut(start..end))
        else {
            // `dst` isn't large enough for this chunk. On the first such
            // chunk `start` is still in bounds, so terminate the prefix
            // copied so far. On later chunks `start` is already past the
            // end of `dst` and nothing is touched.
            if let Some(slot) = self.dst.get_mut(start) {
                slot.write(0);
            }
            return;
        };

        // SAFETY: `c_char` is either `i8` or `u8`, both of which have the
        // same size and alignment as `u8`, and `MaybeUninit<T>` has the
        // same layout as `T`. Every initialized byte is a valid
        // `MaybeUninit<c_char>`, the cast preserves the slice length, and
        // the result is only read while `src` is borrowed.
        let src = unsafe { &*(ptr::from_ref::<[u8]>(src) as *const [MaybeUninit<c_char>]) };
        dst.copy_from_slice(src);
    }

    /// Writes the null terminator and accounts for it in the byte count.
    ///
    /// Returns `Ok(())` if `dst` is large enough, or `Err(())`
    /// otherwise.
    fn finish(self) -> Result<(), ()> {
        // Write the null terminator after the bytes. When the count is at
        // or past the end of `dst` the clamped index is out of bounds and
        // nothing is written.
        let idx = cmp::min(*self.nw, self.dst.len());
        if let Some(v) = self.dst.get_mut(idx) {
            v.write(0);
        }
        *self.nw = self.nw.saturating_add(1);

        if *self.nw <= self.dst.len() {
            Ok(())
        } else {
            Err(())
        }
    }
}

impl Write for CStrWriter<'_> {
    /// Forwards to the inherent `write`. This never fails: running out of
    /// space is recorded in the byte count and reported by `finish`.
    fn write_str(&mut self, s: &str) -> fmt::Result {
        self.write(s);
        Ok(())
    }
}
105
#[cfg(test)]
mod tests {
    use core::ptr;

    use super::*;

    /// Views a plain byte buffer as the `MaybeUninit<c_char>` slice that
    /// `write_c_str` takes as its destination.
    fn as_c_buf(bytes: &mut [u8]) -> &mut [MaybeUninit<c_char>] {
        // SAFETY: `u8` and `MaybeUninit<c_char>` have
        // the same memory layout.
        unsafe { &mut *(ptr::from_mut::<[u8]>(bytes) as *mut [MaybeUninit<c_char>]) }
    }

    /// Exercises `write_c_str` against an empty buffer, a buffer that is
    /// one byte short, and a buffer with room to spare.
    #[test]
    fn test_write_c_str() {
        for (case, text) in ["", "hello, world"].into_iter().enumerate() {
            let expected = format!("{text}\0");
            let needed = expected.len();
            let mut buf = vec![0u8; needed];

            // An empty destination must fail and report the full size.
            let mut reported = 0xdeadbeef;
            let outcome = write_c_str(&mut [], &text, &mut reported);
            assert_eq!(outcome, Err(WriteCStrError::BufferTooSmall), "#{case}");
            assert_eq!(
                reported, needed,
                "#{case}: did not return large enough size"
            );

            println!("=== after empty");

            // A destination one byte short must fail the same way.
            reported = 0xdeadbeef;
            let outcome = write_c_str(as_c_buf(&mut buf[..needed - 1]), &text, &mut reported);
            assert_eq!(outcome, Err(WriteCStrError::BufferTooSmall), "#{case}");
            assert_eq!(reported, needed, "#{case}: output sizes differ");

            println!("=== after too small");

            // A destination with spare room must succeed. The extra
            // non-zero bytes make sure the null terminator lands right
            // after the text rather than at the end of the buffer.
            reported = 0xdeadbeef;
            buf.extend([1, 2, 3, 4, 5]);
            let outcome = write_c_str(as_c_buf(buf.as_mut_slice()), &text, &mut reported);
            assert_eq!(outcome, Ok(()), "#{case}");
            assert_eq!(reported, needed, "#{case}: output sizes differ");
            assert_eq!(&buf[..reported], expected.as_bytes(), "#{case}");

            println!("=== after too good");
        }
    }
}