1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
use std::{fmt::Debug, io, marker::PhantomData, mem::MaybeUninit};

use bytes::{BufMut, BytesMut};
use netlink_packet_core::{
    NetlinkBuffer,
    NetlinkDeserializable,
    NetlinkMessage,
    NetlinkSerializable,
};
use tokio_util::codec::{Decoder, Encoder};

/// Codec used to read and write netlink messages from/to a byte buffer.
///
/// `T` is the message type the codec is tied to. The `Decoder` and `Encoder`
/// implementations in this file are provided for `T = NetlinkMessage<_>`.
/// The codec itself holds no data: `T` is only recorded through `PhantomData`.
pub struct NetlinkCodec<T> {
    // Zero-sized marker that ties the codec to its message type `T`.
    phantom: PhantomData<T>,
}

impl<T> Default for NetlinkCodec<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T> NetlinkCodec<T> {
    pub fn new() -> Self {
        NetlinkCodec {
            phantom: PhantomData,
        }
    }
}

// FIXME: it seems that for audit, we're receiving malformed packets.
// See https://github.com/mozilla/libaudit-go/issues/24
impl<T> Decoder for NetlinkCodec<NetlinkMessage<T>>
where
    T: NetlinkDeserializable<T> + Debug + Eq + PartialEq + Clone,
{
    type Item = NetlinkMessage<T>;
    type Error = io::Error;

    /// Extracts the next netlink message from `src`.
    ///
    /// Returns:
    /// - `Ok(Some(message))` for the first frame in `src` that deserializes
    ///   successfully; the bytes of that frame (and of any undecodable frame
    ///   skipped before it) are removed from `src`.
    /// - `Ok(None)` when `src` is empty, or when the netlink header at the
    ///   front of `src` does not pass `NetlinkBuffer::new_checked`. In the
    ///   latter case the whole buffer is discarded.
    ///
    /// This implementation never returns `Err`: decoding failures are logged
    /// and the offending bytes are dropped.
    ///
    /// With the `workaround-audit-bug` feature enabled, a frame whose header
    /// length is at most 16 bytes short of the buffer length is assumed to be
    /// a truncated audit packet: the whole buffer is consumed and the length
    /// field in the header is rewritten before deserializing.
    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
        debug!("NetlinkCodec: decoding next message");

        loop {
            // Nothing left to read: report that no message is available.
            if src.is_empty() {
                trace!("buffer is empty");
                src.clear();
                return Ok(None);
            }

            // Work out how many bytes the next frame spans. The header view
            // borrows `src`, so only a plain `usize` leaves this `match`; that
            // way `src` is free to be mutated afterwards.
            let frame_len: Option<usize> = match NetlinkBuffer::new_checked(src.as_ref()) {
                #[cfg(not(feature = "workaround-audit-bug"))]
                Ok(header) => Some(header.length() as usize),
                #[cfg(feature = "workaround-audit-bug")]
                Ok(header) => {
                    let extra_bytes = src.as_ref().len() as isize - header.length() as isize;
                    if extra_bytes <= 16 {
                        // Audit messages are sometimes truncated because the
                        // length given in the header is too small. To work
                        // around this, we tweak the length. We've noticed two
                        // occurrences of truncated packets:
                        //
                        // - the length of the header is not included (see also:
                        //   https://github.com/mozilla/libaudit-go/issues/24)
                        // - some rule messages have some padding for alignment (see
                        //   https://github.com/linux-audit/audit-userspace/issues/78) which is
                        //   not taken into account in the buffer length.
                        warn!("found what looks like a truncated audit packet");
                        Some(src.as_ref().len())
                    } else {
                        Some(header.length() as usize)
                    }
                }
                Err(e) => {
                    // We either received a truncated packet, or the packet is
                    // malformed (invalid length field). In both cases we can't
                    // decode the datagram, and we cannot find the start of the
                    // next one (if any). The only solution is to clear the
                    // buffer and potentially lose some datagrams.
                    error!("failed to decode datagram: {:?}: {:#x?}.", e, src.as_ref());
                    None
                }
            };

            let len = match frame_len {
                Some(n) => n,
                None => {
                    error!("clearing the whole socket buffer. Datagrams may have been lost");
                    src.clear();
                    return Ok(None);
                }
            };

            #[cfg(feature = "workaround-audit-bug")]
            let bytes = {
                let mut frame = src.split_to(len);
                {
                    let mut header = NetlinkBuffer::new(frame.as_mut());
                    // If the frame contains more bytes than what the header says the length is,
                    // we ran into a malformed packet (see comment above), and we set the "right"
                    // length ourselves so that parsing does not fail.
                    //
                    // How do we know that's the right length? Due to an implementation detail and
                    // to the fact that netlink is a datagram protocol.
                    //
                    // - our implementation of Stream always calls the codec with at most 1 message
                    //   in the buffer, so we know the extra bytes do not belong to another message.
                    // - because netlink is a datagram protocol, we receive entire messages, so we
                    //   know that if those extra bytes do not belong to another message, they
                    //   belong to this one.
                    let declared = header.length();
                    if len != declared as usize {
                        warn!("setting packet length to {} instead of {}", len, declared);
                        header.set_length(len as u32);
                    }
                }
                frame
            };
            #[cfg(not(feature = "workaround-audit-bug"))]
            let bytes = src.split_to(len);

            match NetlinkMessage::<T>::deserialize(&bytes) {
                Ok(packet) => {
                    trace!("<<< {:?}", packet);
                    return Ok(Some(packet));
                }
                Err(e) => {
                    // Keep looping: there may be more datagrams in the buffer.
                    error!("failed to decode packet {:#x?}: {}", &bytes, e);
                }
            }
        }
    }
}

impl<T> Encoder for NetlinkCodec<NetlinkMessage<T>>
where
    T: Debug + Eq + PartialEq + Clone + NetlinkSerializable<T>,
{
    type Item = NetlinkMessage<T>;
    type Error = io::Error;

    /// Serializes `msg` and appends it to `buf`.
    ///
    /// Exactly `msg.buffer_len()` bytes are appended. The destination region
    /// is zero-filled before being handed to `NetlinkMessage::serialize`.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` of kind `Other` if `buf.remaining_mut()` reports
    /// less room than the message needs.
    fn encode(&mut self, msg: Self::Item, buf: &mut BytesMut) -> Result<(), Self::Error> {
        let size = msg.buffer_len();
        if buf.remaining_mut() < size {
            return Err(io::Error::new(
                io::ErrorKind::Other,
                format!(
                    "message is {} bytes, but only {} bytes left in the buffer",
                    size,
                    buf.remaining_mut()
                ),
            ));
        }

        // Make sure the spare capacity can hold the whole message. Without
        // this, slicing `bytes_mut()` below panics whenever the buffer has
        // less than `size` bytes of spare capacity. Growing the buffer with
        // `resize()` is not an option here: it would append zero bytes to the
        // data already queued in `buf`.
        //
        // FIXME: we should have a max length for the buffer
        buf.reserve(size);

        // In theory, `serialize()` should not _read_ from the buffer it is
        // given, but since the actual implementation is delegated to users, we
        // cannot guarantee anything. Therefore the destination is zeroed first.
        let spare: &mut [MaybeUninit<u8>] = &mut buf.bytes_mut()[..size];
        for byte in spare.iter_mut() {
            *byte = MaybeUninit::new(0);
        }
        // SAFETY: every element of `spare` was initialized just above, and
        // `MaybeUninit<u8>` has the same layout as `u8`, so the slice can be
        // viewed as `[u8]`.
        let initialized_bytes = unsafe { &mut *(spare as *mut [MaybeUninit<u8>] as *mut [u8]) };

        msg.serialize(initialized_bytes);
        trace!(">>> {:?}", msg);

        // SAFETY: `reserve(size)` guarantees at least `size` bytes of spare
        // capacity, and the first `size` bytes of it were initialized above.
        unsafe { buf.advance_mut(size) };
        Ok(())
    }
}