rl2tp/message/control_message.rs

1use crate::common::{DecodeError, Reader, Writer};
2use crate::message::flags::{Flags, MessageFlagType};
3use crate::message::*;
4use avp::AVP;
5use core::borrow::Borrow;
6
7/// # Summary
8/// A `ControlMessage` is a representation of an L2TP control message which is the primary link control mechanism of the protocol.
9///
10/// # Data members
11/// * `length` - The payload length field.
12/// * `tunnel_id` - The tunnel identifier field.
13/// * `session_id` - The session identifier field.
14/// * `ns` - The NS field.
15/// * `nr` - The NR field.
16/// * `avps` - A collection of Attribute Value Pairs constituting the payload of this message.
17#[derive(Clone, Debug, Eq, PartialEq)]
18pub struct ControlMessage {
19    pub length: u16,
20    pub tunnel_id: u16,
21    pub session_id: u16,
22    pub ns: u16,
23    pub nr: u16,
24    pub avps: Vec<AVP>,
25}
26
27impl ControlMessage {
28    #[inline]
29    pub(crate) fn try_read<T: Borrow<[u8]>>(
30        flags: Flags,
31        validation_options: ValidationOptions,
32        reader: &mut impl Reader<T>,
33    ) -> Result<Self, Vec<DecodeError>> {
34        if let ValidateUnused::Yes = validation_options.unused {
35            if flags.is_prioritized() {
36                return Err(vec![DecodeError::ForbiddenControlMessagePriority]);
37            }
38
39            if flags.has_offset() {
40                return Err(vec![DecodeError::ForbiddenControlMessageOffset]);
41            }
42        }
43
44        if !flags.has_length() {
45            return Err(vec![DecodeError::ControlMessageWithoutLength]);
46        }
47        if !flags.has_ns_nr() {
48            return Err(vec![DecodeError::ControlMessageWithoutNsNr]);
49        }
50
51        const FIXED_LENGTH_MINUS_FLAGS: usize = 10;
52        if reader.len() < FIXED_LENGTH_MINUS_FLAGS {
53            return Err(vec![DecodeError::IncompleteControlMessageHeader]);
54        }
55
56        let length = unsafe { reader.read_u16_be_unchecked() };
57        let tunnel_id = unsafe { reader.read_u16_be_unchecked() };
58        let session_id = unsafe { reader.read_u16_be_unchecked() };
59        let ns = unsafe { reader.read_u16_be_unchecked() };
60        let nr = unsafe { reader.read_u16_be_unchecked() };
61
62        const FIXED_LENGTH: usize = 12;
63        if length as usize > reader.len() + FIXED_LENGTH {
64            return Err(vec![DecodeError::IncompleteControlMessagePayload]);
65        }
66
67        let mut avp_reader = reader.subreader(length as usize - FIXED_LENGTH);
68        let avp_and_err = AVP::try_read_greedy(&mut avp_reader);
69
70        if let Some(first) = avp_and_err.first() {
71            match first {
72                Ok(AVP::MessageType(_)) => (),
73                _ => return Err(vec![DecodeError::ControlMessageTypeNotFirst]),
74            }
75        }
76
77        if avp_and_err.iter().any(|x| {
78            println!("{x:?}");
79            x.is_err()
80        }) {
81            return Err(avp_and_err.into_iter().filter_map(|x| x.err()).collect());
82        }
83
84        let avps = avp_and_err.into_iter().filter_map(|x| x.ok()).collect();
85
86        Ok(ControlMessage {
87            length,
88            tunnel_id,
89            session_id,
90            ns,
91            nr,
92            avps,
93        })
94    }
95
96    #[inline]
97    pub(crate) fn write(&self, protocol_version: u8, writer: &mut impl Writer) {
98        let start_position = writer.len();
99        let flags = Flags::new(
100            MessageFlagType::Control,
101            true,
102            true,
103            false,
104            false,
105            protocol_version,
106        );
107        flags.write(writer);
108
109        // Save length field position
110        let length_position = writer.len();
111
112        // Dummy octets to be overwritten
113        writer.write_bytes(&[0, 0]);
114
115        // Write rest of header
116        writer.write_u16_be(self.tunnel_id);
117        writer.write_u16_be(self.session_id);
118        writer.write_u16_be(self.ns);
119        writer.write_u16_be(self.nr);
120
121        // Write payload
122        for avp in self.avps.iter() {
123            avp.write(writer);
124        }
125
126        // Get total length
127        let end_position = writer.len();
128        let length = end_position - start_position;
129
130        // Overwrite dummy octets
131        assert!(length <= u16::MAX as usize);
132        writer.write_bytes_at(&(length as u16).to_be_bytes(), length_position);
133    }
134}