netlink_packet_generic/ctrl/nlas/mod.rs

1// SPDX-License-Identifier: MIT
2
3use crate::constants::*;
4use netlink_packet_core::{
5    emit_u16, emit_u32, parse_string, parse_u16, parse_u32, DecodeError,
6    Emitable, ErrorContext, Nla, NlaBuffer, NlasIterator, Parseable,
7};
8use std::mem::size_of_val;
9
10mod mcast;
11mod oppolicy;
12mod ops;
13mod policy;
14
15pub use mcast::*;
16pub use oppolicy::*;
17pub use ops::*;
18pub use policy::*;
19
20#[derive(Clone, Debug, PartialEq, Eq)]
21pub enum GenlCtrlAttrs {
22    FamilyId(u16),
23    FamilyName(String),
24    Version(u32),
25    HdrSize(u32),
26    MaxAttr(u32),
27    Ops(Vec<Vec<OpAttrs>>),
28    McastGroups(Vec<Vec<McastGrpAttrs>>),
29    Policy(PolicyAttr),
30    OpPolicy(OppolicyAttr),
31    Op(u32),
32}
33
34impl Nla for GenlCtrlAttrs {
35    fn value_len(&self) -> usize {
36        use GenlCtrlAttrs::*;
37        match self {
38            FamilyId(v) => size_of_val(v),
39            FamilyName(s) => s.len() + 1,
40            Version(v) => size_of_val(v),
41            HdrSize(v) => size_of_val(v),
42            MaxAttr(v) => size_of_val(v),
43            Ops(nlas) => OpList::from(nlas).as_slice().buffer_len(),
44            McastGroups(nlas) => {
45                McastGroupList::from(nlas).as_slice().buffer_len()
46            }
47            Policy(nla) => nla.buffer_len(),
48            OpPolicy(nla) => nla.buffer_len(),
49            Op(v) => size_of_val(v),
50        }
51    }
52
53    fn kind(&self) -> u16 {
54        use GenlCtrlAttrs::*;
55        match self {
56            FamilyId(_) => CTRL_ATTR_FAMILY_ID,
57            FamilyName(_) => CTRL_ATTR_FAMILY_NAME,
58            Version(_) => CTRL_ATTR_VERSION,
59            HdrSize(_) => CTRL_ATTR_HDRSIZE,
60            MaxAttr(_) => CTRL_ATTR_MAXATTR,
61            Ops(_) => CTRL_ATTR_OPS,
62            McastGroups(_) => CTRL_ATTR_MCAST_GROUPS,
63            Policy(_) => CTRL_ATTR_POLICY,
64            OpPolicy(_) => CTRL_ATTR_OP_POLICY,
65            Op(_) => CTRL_ATTR_OP,
66        }
67    }
68
69    fn emit_value(&self, buffer: &mut [u8]) {
70        use GenlCtrlAttrs::*;
71        match self {
72            FamilyId(v) => emit_u16(buffer, *v).unwrap(),
73            FamilyName(s) => {
74                buffer[..s.len()].copy_from_slice(s.as_bytes());
75                buffer[s.len()] = 0;
76            }
77            Version(v) => emit_u32(buffer, *v).unwrap(),
78            HdrSize(v) => emit_u32(buffer, *v).unwrap(),
79            MaxAttr(v) => emit_u32(buffer, *v).unwrap(),
80            Ops(nlas) => {
81                OpList::from(nlas).as_slice().emit(buffer);
82            }
83            McastGroups(nlas) => {
84                McastGroupList::from(nlas).as_slice().emit(buffer);
85            }
86            Policy(nla) => nla.emit_value(buffer),
87            OpPolicy(nla) => nla.emit_value(buffer),
88            Op(v) => emit_u32(buffer, *v).unwrap(),
89        }
90    }
91}
92
93impl<'a, T: AsRef<[u8]> + ?Sized> Parseable<NlaBuffer<&'a T>>
94    for GenlCtrlAttrs
95{
96    fn parse(buf: &NlaBuffer<&'a T>) -> Result<Self, DecodeError> {
97        let payload = buf.value();
98        Ok(match buf.kind() {
99            CTRL_ATTR_FAMILY_ID => Self::FamilyId(
100                parse_u16(payload)
101                    .context("invalid CTRL_ATTR_FAMILY_ID value")?,
102            ),
103            CTRL_ATTR_FAMILY_NAME => Self::FamilyName(
104                parse_string(payload)
105                    .context("invalid CTRL_ATTR_FAMILY_NAME value")?,
106            ),
107            CTRL_ATTR_VERSION => Self::Version(
108                parse_u32(payload)
109                    .context("invalid CTRL_ATTR_VERSION value")?,
110            ),
111            CTRL_ATTR_HDRSIZE => Self::HdrSize(
112                parse_u32(payload)
113                    .context("invalid CTRL_ATTR_HDRSIZE value")?,
114            ),
115            CTRL_ATTR_MAXATTR => Self::MaxAttr(
116                parse_u32(payload)
117                    .context("invalid CTRL_ATTR_MAXATTR value")?,
118            ),
119            CTRL_ATTR_OPS => {
120                let ops = NlasIterator::new(payload)
121                    .map(|nlas| {
122                        nlas.and_then(|nlas| {
123                            NlasIterator::new(nlas.value())
124                                .map(|nla| {
125                                    nla.and_then(|nla| OpAttrs::parse(&nla))
126                                })
127                                .collect::<Result<Vec<_>, _>>()
128                        })
129                    })
130                    .collect::<Result<Vec<Vec<_>>, _>>()
131                    .context("failed to parse CTRL_ATTR_OPS")?;
132                Self::Ops(ops)
133            }
134            CTRL_ATTR_MCAST_GROUPS => {
135                let groups = NlasIterator::new(payload)
136                    .map(|nlas| {
137                        nlas.and_then(|nlas| {
138                            NlasIterator::new(nlas.value())
139                                .map(|nla| {
140                                    nla.and_then(|nla| {
141                                        McastGrpAttrs::parse(&nla)
142                                    })
143                                })
144                                .collect::<Result<Vec<_>, _>>()
145                        })
146                    })
147                    .collect::<Result<Vec<Vec<_>>, _>>()
148                    .context("failed to parse CTRL_ATTR_MCAST_GROUPS")?;
149                Self::McastGroups(groups)
150            }
151            CTRL_ATTR_POLICY => Self::Policy(
152                PolicyAttr::parse(&NlaBuffer::new(payload))
153                    .context("failed to parse CTRL_ATTR_POLICY")?,
154            ),
155            CTRL_ATTR_OP_POLICY => Self::OpPolicy(
156                OppolicyAttr::parse(&NlaBuffer::new(payload))
157                    .context("failed to parse CTRL_ATTR_OP_POLICY")?,
158            ),
159            CTRL_ATTR_OP => Self::Op(parse_u32(payload)?),
160            kind => {
161                return Err(DecodeError::from(format!(
162                    "Unknown NLA type: {kind}"
163                )))
164            }
165        })
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn mcast_groups_parse() {
        let mcast_bytes: [u8; 24] = [
            24, 0, // Netlink header length
            7, 0, // Netlink header kind (Mcast groups)
            20, 0, // Mcast group nested NLA length
            1, 0, // Mcast group kind
            8, 0, // Id length
            2, 0, // Id kind
            1, 0, 0, 0, // Id
            8, 0, // Name length
            1, 0, // Name kind
            b't', b'e', b's', b't', // Name
        ];
        let nla_buffer = NlaBuffer::new_checked(&mcast_bytes[..])
            .expect("Failed to create NlaBuffer");
        let result_attr = GenlCtrlAttrs::parse(&nla_buffer)
            .expect("Failed to parse encoded McastGroups");
        let expected_attr = GenlCtrlAttrs::McastGroups(vec![vec![
            McastGrpAttrs::Id(1),
            McastGrpAttrs::Name("test".to_string()),
        ]]);
        assert_eq!(expected_attr, result_attr);
    }

    #[test]
    fn mcast_groups_emit() {
        let mcast_attr = GenlCtrlAttrs::McastGroups(vec![
            vec![
                McastGrpAttrs::Id(7),
                McastGrpAttrs::Name("group1".to_string()),
            ],
            vec![
                McastGrpAttrs::Id(8),
                McastGrpAttrs::Name("group2".to_string()),
            ],
        ]);
        let expected_bytes: [u8; 52] = [
            52, 0, // Netlink header length
            7, 0, // Netlink header kind (Mcast groups)
            24, 0, // Mcast group nested NLA length
            1, 0, // Mcast group kind (index 1)
            8, 0, // Id length
            2, 0, // Id kind
            7, 0, 0, 0, // Id
            11, 0, // Name length
            1, 0, // Name kind
            b'g', b'r', b'o', b'u', b'p', b'1', 0, // Name
            0, // mcast group padding
            24, 0, // Mcast group nested NLA length
            2, 0, // Mcast group kind (index 2)
            8, 0, // Id length
            2, 0, // Id kind
            8, 0, 0, 0, // Id
            11, 0, // Name length
            1, 0, // Name kind
            b'g', b'r', b'o', b'u', b'p', b'2', 0, // Name
            0, // padding
        ];
        let mut buf = vec![0u8; 100];
        mcast_attr.emit(&mut buf);

        assert_eq!(&expected_bytes[..], &buf[..expected_bytes.len()]);
    }

    #[test]
    fn ops_parse() {
        let ops_bytes: [u8; 24] = [
            24, 0, // Netlink header length
            6, 0, // Netlink header kind (Ops)
            20, 0, // Op nested NLA length
            0, 0, // Op kind
            8, 0, // Id length
            1, 0, // Id kind
            1, 0, 0, 0, // Id
            8, 0, // Flags length
            2, 0, // Flags kind
            123, 0, 0, 0, // Flags
        ];
        let nla_buffer = NlaBuffer::new_checked(&ops_bytes[..])
            .expect("Failed to create NlaBuffer");
        let result_attr = GenlCtrlAttrs::parse(&nla_buffer)
            .expect("Failed to parse encoded Ops");
        let expected_attr =
            GenlCtrlAttrs::Ops(vec![vec![OpAttrs::Id(1), OpAttrs::Flags(123)]]);
        assert_eq!(expected_attr, result_attr);
    }

    #[test]
    fn ops_emit() {
        let ops = GenlCtrlAttrs::Ops(vec![
            vec![OpAttrs::Id(1), OpAttrs::Flags(11)],
            vec![OpAttrs::Id(3), OpAttrs::Flags(33)],
        ]);
        let expected_bytes: [u8; 44] = [
            44, 0, // Netlink header length
            6, 0, // Netlink header kind (Ops)
            20, 0, // Op nested NLA length
            1, 0, // Op kind
            8, 0, // Id length
            1, 0, // Id kind
            1, 0, 0, 0, // Id
            8, 0, // Flags length
            2, 0, // Flags kind
            11, 0, 0, 0, // Flags
            20, 0, // Op nested NLA length
            2, 0, // Op kind
            8, 0, // Id length
            1, 0, // Id kind
            3, 0, 0, 0, // Id
            8, 0, // Flags length
            2, 0, // Flags kind
            33, 0, 0, 0, // Flags
        ];
        let mut buf = vec![0u8; 100];
        ops.emit(&mut buf);

        assert_eq!(&expected_bytes[..], &buf[..expected_bytes.len()]);
    }

    /// Every scalar attribute must survive an emit -> parse round trip.
    #[test]
    fn scalar_attrs_round_trip() {
        let attrs = vec![
            GenlCtrlAttrs::FamilyId(0x10),
            GenlCtrlAttrs::FamilyName("nlctrl".to_string()),
            GenlCtrlAttrs::Version(2),
            GenlCtrlAttrs::HdrSize(0),
            GenlCtrlAttrs::MaxAttr(10),
            GenlCtrlAttrs::Op(3),
        ];
        for attr in attrs {
            let mut buf = vec![0u8; attr.buffer_len()];
            attr.emit(&mut buf);
            let nla_buffer = NlaBuffer::new_checked(&buf[..])
                .expect("Failed to create NlaBuffer");
            let parsed = GenlCtrlAttrs::parse(&nla_buffer)
                .expect("Failed to parse emitted attribute");
            assert_eq!(attr, parsed);
        }
    }
}