jetstream_wireformat/
lib.rs

1#![doc(html_logo_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png")]
2#![doc(
3    html_favicon_url = "https://raw.githubusercontent.com/sevki/jetstream/main/logo/JetStream.png"
4)]
5#![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
6// Copyright (c) 2024, Sevki <s@sevki.io>
7// Copyright 2018 The ChromiumOS Authors
8// Use of this source code is governed by a BSD-style license that can be
9// found in the LICENSE file.
10pub use jetstream_macros::JetStreamWireFormat;
11
12use {
13    bytes::Buf,
14    std::{
15        ffi::{CStr, CString, OsStr},
16        fmt,
17        io::{self, ErrorKind, Read, Write},
18        mem,
19        ops::{Deref, DerefMut},
20        string::String,
21        vec::Vec,
22    },
23    zerocopy::LittleEndian,
24};
25pub mod wire_format_extensions;
26
/// Types that can be serialized to, and deserialized from, the 9P-style wire format used by
/// this crate.
///
/// `byte_size` reports how many bytes `encode` writes for a value, so callers can size message
/// buffers before encoding.
pub trait WireFormat: Sized + Send {
    /// Number of bytes needed to write the complete encoding of `self`.
    fn byte_size(&self) -> u32;

    /// Serializes `self` into `writer`.
    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()>;

    /// Deserializes one value of this type from `reader`.
    fn decode<R: Read>(reader: &mut R) -> io::Result<Self>;
}
38
/// A 9P protocol string.
///
/// Every `P9String` produced by [`P9String::new`] is:
/// - valid UTF-8,
/// - at most 65535 (`u16::MAX`) bytes long, so its length fits the wire-format prefix, and
/// - free of interior NUL bytes.
///
/// The bytes are kept in a [`CString`], i.e. followed by a terminating NUL, so the string can be
/// passed directly to C functions that expect a NUL-terminated string.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct P9String {
    cstr: CString,
}

impl P9String {
    /// Validates `string_bytes` and wraps them in a `P9String`.
    ///
    /// # Errors
    ///
    /// Returns an [`io::Error`] of kind [`ErrorKind::InvalidInput`] if the bytes are longer than
    /// `u16::MAX`, are not valid UTF-8, or contain an interior NUL byte.
    pub fn new(string_bytes: impl Into<Vec<u8>>) -> io::Result<Self> {
        let string_bytes: Vec<u8> = string_bytes.into();

        // The wire format prefixes strings with a u16 byte count.
        if string_bytes.len() > usize::from(u16::MAX) {
            return Err(io::Error::new(
                ErrorKind::InvalidInput,
                "string is too long",
            ));
        }

        // 9P strings must be valid UTF-8.
        std::str::from_utf8(&string_bytes)
            .map_err(|e| io::Error::new(ErrorKind::InvalidInput, e))?;

        // `CString::new` rejects interior NUL bytes, which C code would treat as the end of
        // the string.
        let cstr =
            CString::new(string_bytes).map_err(|e| io::Error::new(ErrorKind::InvalidInput, e))?;

        Ok(P9String { cstr })
    }

    /// Returns the length in bytes, not counting the NUL terminator.
    pub fn len(&self) -> usize {
        self.cstr.as_bytes().len()
    }

    /// Returns `true` if the string contains no bytes.
    pub fn is_empty(&self) -> bool {
        self.cstr.as_bytes().is_empty()
    }

    /// Borrows the string as a [`CStr`], which includes the NUL terminator.
    pub fn as_c_str(&self) -> &CStr {
        self.cstr.as_c_str()
    }

    /// Returns the string's bytes without the NUL terminator.
    pub fn as_bytes(&self) -> &[u8] {
        self.cstr.as_bytes()
    }

    /// Returns a raw pointer to the string's storage.
    ///
    /// The string bytes are always followed by a NUL terminator (`'\0'`), so the pointer can be
    /// passed directly to C functions that expect a NUL-terminated string. The pointer is only
    /// valid while `self` is alive.
    ///
    /// The pointee type is [`std::ffi::c_char`], exactly the type [`CString::as_ptr`] returns,
    /// so no `libc` types are required for this signature.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn as_ptr(&self) -> *const std::ffi::c_char {
        self.cstr.as_ptr()
    }
}
95
96impl PartialEq<&str> for P9String {
97    fn eq(&self, other: &&str) -> bool {
98        self.cstr.as_bytes() == other.as_bytes()
99    }
100}
101
102impl TryFrom<&OsStr> for P9String {
103    type Error = io::Error;
104
105    fn try_from(value: &OsStr) -> io::Result<Self> {
106        let string_bytes = value.as_encoded_bytes();
107        Self::new(string_bytes)
108    }
109}
110
111// The 9P protocol requires that strings are UTF-8 encoded.  The wire format is a u16
112// count |N|, encoded in little endian, followed by |N| bytes of UTF-8 data.
113impl WireFormat for P9String {
114    fn byte_size(&self) -> u32 {
115        (mem::size_of::<u16>() + self.len()) as u32
116    }
117
118    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
119        (self.len() as u16).encode(writer)?;
120        writer.write_all(self.cstr.as_bytes())
121    }
122
123    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
124        let len: u16 = WireFormat::decode(reader)?;
125        let mut string_bytes = vec![0u8; usize::from(len)];
126        reader.read_exact(&mut string_bytes)?;
127        Self::new(string_bytes)
128    }
129}
130
131// This doesn't really _need_ to be a macro but unfortunately there is no trait bound to
132// express "can be casted to another type", which means we can't write `T as u8` in a trait
133// based implementation.  So instead we have this macro, which is implemented for all the
134// stable unsigned types with the added benefit of not being implemented for the signed
135// types which are not allowed by the protocol.
136macro_rules! uint_wire_format_impl {
137    ($Ty:ty) => {
138        impl WireFormat for $Ty {
139            fn byte_size(&self) -> u32 {
140                mem::size_of::<$Ty>() as u32
141            }
142
143            fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
144                writer.write_all(&self.to_le_bytes())
145            }
146
147            fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
148                let mut buf = [0; mem::size_of::<$Ty>()];
149                reader.read_exact(&mut buf)?;
150                paste::expr! {
151                    let num: zerocopy::[<$Ty:snake:upper>]<LittleEndian> =  zerocopy::byteorder::[<$Ty:snake:upper>]::from_bytes(buf);
152                    Ok(num.get())
153                }
154            }
155        }
156    };
157}
158// unsigned integers
159uint_wire_format_impl!(u16);
160uint_wire_format_impl!(u32);
161uint_wire_format_impl!(u64);
162uint_wire_format_impl!(u128);
163// signed integers
164uint_wire_format_impl!(i16);
165uint_wire_format_impl!(i32);
166uint_wire_format_impl!(i64);
167uint_wire_format_impl!(i128);
168
169macro_rules! float_wire_format_impl {
170    ($Ty:ty) => {
171        impl WireFormat for $Ty {
172            fn byte_size(&self) -> u32 {
173                mem::size_of::<$Ty>() as u32
174            }
175
176            fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
177                paste::expr! {
178                    writer.write_all(&self.to_le_bytes())
179                }
180            }
181
182            fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
183                let mut buf = [0; mem::size_of::<$Ty>()];
184                reader.read_exact(&mut buf)?;
185                paste::expr! {
186                    let num: zerocopy::[<$Ty:snake:upper>]<LittleEndian> =  zerocopy::byteorder::[<$Ty:snake:upper>]::from_bytes(buf);
187                    Ok(num.get())
188                }
189            }
190        }
191    };
192}
193
194float_wire_format_impl!(f32);
195float_wire_format_impl!(f64);
196
197impl WireFormat for u8 {
198    fn byte_size(&self) -> u32 {
199        1
200    }
201
202    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
203        writer.write_all(&[*self])
204    }
205
206    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
207        let mut byte = [0u8; 1];
208        reader.read_exact(&mut byte)?;
209        Ok(byte[0])
210    }
211}
212
213impl WireFormat for usize {
214    fn byte_size(&self) -> u32 {
215        mem::size_of::<usize>() as u32
216    }
217
218    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
219        writer.write_all(&self.to_le_bytes())
220    }
221
222    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
223        let mut buf = [0; mem::size_of::<usize>()];
224        reader.read_exact(&mut buf)?;
225        Ok(usize::from_le_bytes(buf))
226    }
227}
228
229impl WireFormat for isize {
230    fn byte_size(&self) -> u32 {
231        mem::size_of::<isize>() as u32
232    }
233
234    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
235        writer.write_all(&self.to_le_bytes())
236    }
237
238    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
239        let mut buf = [0; mem::size_of::<isize>()];
240        reader.read_exact(&mut buf)?;
241        Ok(isize::from_le_bytes(buf))
242    }
243}
244
245// The 9P protocol requires that strings are UTF-8 encoded.  The wire format is a u16
246// count |N|, encoded in little endian, followed by |N| bytes of UTF-8 data.
247impl WireFormat for String {
248    fn byte_size(&self) -> u32 {
249        (mem::size_of::<u16>() + self.len()) as u32
250    }
251
252    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
253        if self.len() > u16::MAX as usize {
254            return Err(io::Error::new(
255                ErrorKind::InvalidInput,
256                "string is too long",
257            ));
258        }
259
260        (self.len() as u16).encode(writer)?;
261        writer.write_all(self.as_bytes())
262    }
263
264    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
265        let len: u16 = WireFormat::decode(reader)?;
266        let mut result = String::with_capacity(len as usize);
267        reader.take(len as u64).read_to_string(&mut result)?;
268        Ok(result)
269    }
270}
271
272// The wire format for repeated types is similar to that of strings: a little endian
273// encoded u16 |N|, followed by |N| instances of the given type.
274impl<T: WireFormat> WireFormat for Vec<T> {
275    fn byte_size(&self) -> u32 {
276        mem::size_of::<u16>() as u32 + self.iter().map(|elem| elem.byte_size()).sum::<u32>()
277    }
278
279    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
280        if self.len() > u16::MAX as usize {
281            return Err(std::io::Error::new(
282                std::io::ErrorKind::InvalidInput,
283                "too many elements in vector",
284            ));
285        }
286
287        (self.len() as u16).encode(writer)?;
288        for elem in self {
289            elem.encode(writer)?;
290        }
291
292        Ok(())
293    }
294
295    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
296        let len: u16 = WireFormat::decode(reader)?;
297        let mut result = Vec::with_capacity(len as usize);
298
299        for _ in 0..len {
300            result.push(WireFormat::decode(reader)?);
301        }
302
303        Ok(result)
304    }
305}
306
/// An arbitrary-length byte payload, typically the body of `Rread` and `Twrite` messages.
///
/// Unlike `Vec<u8>`, whose element count is encoded as a `u16`, `Data` encodes its byte count as
/// a `u32`, so payloads longer than 65535 bytes can be represented on the wire.
#[derive(PartialEq, Eq, Clone)]
#[repr(transparent)]
#[cfg_attr(feature = "testing", derive(serde::Serialize, serde::Deserialize))]
pub struct Data(pub Vec<u8>);
314
// Largest byte count (32 MiB) accepted when decoding `Data`. The server's maximum message size
// is expected to stop oversized reads earlier; this limit exists so that a malicious peer cannot
// make us allocate an enormous buffer just by sending a forged length prefix.
const MAX_DATA_LENGTH: u32 = 32 << 20;
319
320impl fmt::Debug for Data {
321    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
322        // There may be a lot of data and we don't want to spew it all out in a trace.  Instead
323        // just print out the number of bytes in the buffer.
324        write!(f, "Data({} bytes)", self.len())
325    }
326}
327
328// Implement Deref and DerefMut so that we don't have to use self.0 everywhere.
329impl Deref for Data {
330    type Target = Vec<u8>;
331    fn deref(&self) -> &Self::Target {
332        &self.0
333    }
334}
335impl DerefMut for Data {
336    fn deref_mut(&mut self) -> &mut Self::Target {
337        &mut self.0
338    }
339}
340
341// Same as Vec<u8> except that it encodes the length as a u32 instead of a u16.
342impl WireFormat for Data {
343    fn byte_size(&self) -> u32 {
344        mem::size_of::<u32>() as u32 + self.iter().map(|elem| elem.byte_size()).sum::<u32>()
345    }
346
347    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
348        if self.len() > u32::MAX as usize {
349            return Err(std::io::Error::new(
350                std::io::ErrorKind::InvalidInput,
351                "data is too large",
352            ));
353        }
354        (self.len() as u32).encode(writer)?;
355        writer.write_all(self)
356    }
357
358    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
359        let len: u32 = WireFormat::decode(reader)?;
360        if len > MAX_DATA_LENGTH {
361            return Err(std::io::Error::new(
362                std::io::ErrorKind::InvalidData,
363                format!("data length ({} bytes) is too large", len),
364            ));
365        }
366
367        let mut buf = Vec::with_capacity(len as usize);
368        reader.take(len as u64).read_to_end(&mut buf)?;
369
370        if buf.len() == len as usize {
371            Ok(Data(buf))
372        } else {
373            Err(io::Error::new(
374                std::io::ErrorKind::UnexpectedEof,
375                format!(
376                    "unexpected end of data: want: {} bytes, got: {} bytes",
377                    len,
378                    buf.len()
379                ),
380            ))
381        }
382    }
383}
384
385impl<T> WireFormat for Option<T>
386where
387    T: WireFormat,
388{
389    fn byte_size(&self) -> u32 {
390        1 + match self {
391            None => 0,
392            Some(value) => value.byte_size(),
393        }
394    }
395
396    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
397        match self {
398            None => WireFormat::encode(&0u8, writer),
399            Some(value) => {
400                WireFormat::encode(&1u8, writer)?;
401                WireFormat::encode(value, writer)
402            }
403        }
404    }
405
406    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
407        let tag: u8 = WireFormat::decode(reader)?;
408        match tag {
409            0 => Ok(None),
410            1 => Ok(Some(WireFormat::decode(reader)?)),
411            _ => {
412                Err(io::Error::new(
413                    io::ErrorKind::InvalidData,
414                    format!("Invalid Option tag: {}", tag),
415                ))
416            }
417        }
418    }
419}
420
421impl WireFormat for () {
422    fn byte_size(&self) -> u32 {
423        0
424    }
425
426    fn encode<W: Write>(&self, _writer: &mut W) -> io::Result<()> {
427        Ok(())
428    }
429
430    fn decode<R: Read>(_reader: &mut R) -> io::Result<Self> {
431        Ok(())
432    }
433}
434
435impl WireFormat for bool {
436    fn byte_size(&self) -> u32 {
437        1
438    }
439
440    fn encode<W: Write>(&self, writer: &mut W) -> io::Result<()> {
441        writer.write_all(&[*self as u8])
442    }
443
444    fn decode<R: Read>(reader: &mut R) -> io::Result<Self> {
445        let mut byte = [0u8; 1];
446        reader.read_exact(&mut byte)?;
447        match byte[0] {
448            0 => Ok(false),
449            1 => Ok(true),
450            _ => {
451                Err(io::Error::new(
452                    io::ErrorKind::InvalidData,
453                    "invalid byte for bool",
454                ))
455            }
456        }
457    }
458}
459
460impl io::Read for Data {
461    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
462        self.0.reader().read(buf)
463    }
464}