1use std::convert::Into;
12use std::fmt;
13use std::io;
14
15use encoding::all::ISO_8859_1;
16use encoding::{DecoderTrap, Encoding};
17
18use crate::information_element_ids::InformationElementId;
19use crate::unpack::{unpack_vec, LittleUnpack};
20use netlink_rust::{ConvertFrom, Error};
21
22pub struct RawInformationElement<'a> {
35 pub identifier: u8,
36 pub data: &'a [u8],
37}
38
39impl<'a> RawInformationElement<'a> {
40 pub fn parse(data: &'a [u8]) -> Result<RawInformationElement<'a>, Error> {
42 if data.len() < 2 {
43 return Err(io::Error::new(io::ErrorKind::InvalidData, "").into());
44 }
45 let identifier = u8::unpack_unchecked(data);
46 let length = u8::unpack_unchecked(&data[1..]);
47 let length = length as usize;
48 if data.len() < length {
49 return Err(io::Error::new(io::ErrorKind::InvalidData, "").into());
50 }
51 Ok(RawInformationElement {
52 identifier,
53 data: &data[2..(length + 2)],
54 })
55 }
56 pub fn ie_id(&self) -> Option<InformationElementId> {
58 InformationElementId::convert_from(self.identifier)
59 }
60}
61
62pub struct InformationElements<'a> {
64 pub elements: Vec<RawInformationElement<'a>>,
65}
66
67impl<'a> InformationElements<'a> {
68 pub fn parse(data: &'a [u8]) -> InformationElements<'a> {
69 let mut elements = vec![];
70 let mut slice = data;
71 while let Ok(ie) = RawInformationElement::parse(slice) {
72 slice = &slice[(ie.data.len() + 2)..];
73 elements.push(ie);
74 }
75 InformationElements { elements }
76 }
77}
78
79pub struct Ssid {
83 pub ssid: String,
84}
85
86impl Ssid {
87 pub fn parse(data: &[u8]) -> Result<Ssid, Error> {
93 let ssid = String::from_utf8(data.to_vec()).unwrap_or_else(|_|
95 ISO_8859_1.decode(data, DecoderTrap::Strict)
97 .unwrap_or_default());
98 let ssid = ssid.trim_end_matches('\0').to_string();
99 Ok(Ssid { ssid })
100 }
101}
102
103impl fmt::Display for Ssid {
104 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105 write!(f, "{}", self.ssid)
106 }
107}
108
/// Cipher suite selector used in RSN elements.
///
/// Selectors whose low three bytes (as read by `From<u32>`) equal
/// `0x00ac0f00` are standard suites identified by the high byte; any other
/// selector is kept verbatim as `Vendor`.
#[derive(Debug, PartialEq, Clone)]
pub enum CipherSuite {
    /// Suite type 0: use the group cipher suite.
    UseGroupCipherSuite,
    /// Suite type 1: WEP-40.
    WiredEquivalentPrivacy40,
    /// Suite type 2: TKIP.
    TemporalKeyIntegrityProtocol,
    /// Suite type 4: CCMP.
    CounterModeCbcMacProtocol,
    /// Suite type 5: WEP-104.
    WiredEquivalentPrivacy104,
    /// Suite type 6: BIP.
    BroadcastIntegrityProtocol,
    /// Suite type 7: group addressed traffic not allowed.
    GroupAddressedTrafficNotAllowed,
    /// Standard selector with a suite type not listed above.
    Reserved(u8),
    /// Non-standard selector, stored as the full 32-bit value.
    Vendor(u32),
}

impl From<u32> for CipherSuite {
    /// Decodes a 32-bit selector: low 24 bits are the OUI, high byte the type.
    fn from(v: u32) -> Self {
        const STANDARD_SUITE_OUI: u32 = 0x00ac_0f00;
        if v & 0x00ff_ffff != STANDARD_SUITE_OUI {
            return Self::Vendor(v);
        }
        let suite_type = (v >> 24) as u8;
        match suite_type {
            0 => Self::UseGroupCipherSuite,
            1 => Self::WiredEquivalentPrivacy40,
            2 => Self::TemporalKeyIntegrityProtocol,
            4 => Self::CounterModeCbcMacProtocol,
            5 => Self::WiredEquivalentPrivacy104,
            6 => Self::BroadcastIntegrityProtocol,
            7 => Self::GroupAddressedTrafficNotAllowed,
            other => Self::Reserved(other),
        }
    }
}

impl From<CipherSuite> for u32 {
    /// Encodes a suite back into its 32-bit selector (inverse of `From<u32>`).
    fn from(v: CipherSuite) -> Self {
        const STANDARD_SUITE_OUI: u32 = 0x00ac_0f00;
        let suite_type: u8 = match v {
            CipherSuite::Vendor(raw) => return raw,
            CipherSuite::UseGroupCipherSuite => 0,
            CipherSuite::WiredEquivalentPrivacy40 => 1,
            CipherSuite::TemporalKeyIntegrityProtocol => 2,
            CipherSuite::CounterModeCbcMacProtocol => 4,
            CipherSuite::WiredEquivalentPrivacy104 => 5,
            CipherSuite::BroadcastIntegrityProtocol => 6,
            CipherSuite::GroupAddressedTrafficNotAllowed => 7,
            CipherSuite::Reserved(t) => t,
        };
        STANDARD_SUITE_OUI | (u32::from(suite_type) << 24)
    }
}

impl fmt::Display for CipherSuite {
    /// Writes the common short name of the suite.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::UseGroupCipherSuite => "GroupCipher",
            Self::WiredEquivalentPrivacy40 => "WEP40",
            Self::TemporalKeyIntegrityProtocol => "TKIP",
            Self::CounterModeCbcMacProtocol => "CCMP",
            Self::WiredEquivalentPrivacy104 => "WEP104",
            Self::BroadcastIntegrityProtocol => "BIP",
            Self::GroupAddressedTrafficNotAllowed => "GroupAddressedTrafficNotAllowed",
            Self::Reserved(t) => return write!(f, "Reserved {:02x}", t),
            Self::Vendor(raw) => return write!(f, "Vendor {:08x}", raw),
        };
        f.write_str(label)
    }
}
188
/// Authentication and key management (AKM) suite selector used in RSN elements.
///
/// Selectors whose low three bytes (as read by `From<u32>`) equal
/// `0x00ac0f00` are standard suites identified by the high byte; any other
/// selector is kept verbatim as `Vendor`.
#[derive(Debug, PartialEq, Clone)]
pub enum AuthenticationKeyManagement {
    /// Suite type 1: PMKSA-based authentication.
    PairwiseMasterKeySecurityAssociation,
    /// Suite type 2: pre-shared key.
    PreSharedKey,
    /// Suite type 3: fast-transition PMKSA.
    FastTransitionPMKSA,
    /// Suite type 4: fast-transition pre-shared key.
    FastTransitionPreSharedKey,
    /// Suite type 5: PMKSA with SHA-256.
    PMKSASha256,
    /// Suite type 6: pre-shared key with SHA-256.
    PreSharedKeySha256,
    /// Suite type 7: tunneled direct link setup.
    TunneledDirectLinkSetup,
    /// Suite type 8: simultaneous authentication of equals.
    SimultaneousAuthenticationOfEquals,
    /// Suite type 9: fast-transition SAE.
    FastTransitionSAE,
    /// Standard selector with a suite type not listed above (including 0).
    Reserved(u8),
    /// Non-standard selector, stored as the full 32-bit value.
    Vendor(u32),
}

impl From<u32> for AuthenticationKeyManagement {
    /// Decodes a 32-bit selector: low 24 bits are the OUI, high byte the type.
    fn from(v: u32) -> Self {
        const STANDARD_SUITE_OUI: u32 = 0x00ac_0f00;
        if v & 0x00ff_ffff != STANDARD_SUITE_OUI {
            return Self::Vendor(v);
        }
        match (v >> 24) as u8 {
            1 => Self::PairwiseMasterKeySecurityAssociation,
            2 => Self::PreSharedKey,
            3 => Self::FastTransitionPMKSA,
            4 => Self::FastTransitionPreSharedKey,
            5 => Self::PMKSASha256,
            6 => Self::PreSharedKeySha256,
            7 => Self::TunneledDirectLinkSetup,
            8 => Self::SimultaneousAuthenticationOfEquals,
            9 => Self::FastTransitionSAE,
            other => Self::Reserved(other),
        }
    }
}

impl From<AuthenticationKeyManagement> for u32 {
    /// Encodes a suite back into its 32-bit selector (inverse of `From<u32>`).
    fn from(v: AuthenticationKeyManagement) -> Self {
        const STANDARD_SUITE_OUI: u32 = 0x00ac_0f00;
        let suite_type: u8 = match v {
            AuthenticationKeyManagement::Vendor(raw) => return raw,
            AuthenticationKeyManagement::PairwiseMasterKeySecurityAssociation => 1,
            AuthenticationKeyManagement::PreSharedKey => 2,
            AuthenticationKeyManagement::FastTransitionPMKSA => 3,
            AuthenticationKeyManagement::FastTransitionPreSharedKey => 4,
            AuthenticationKeyManagement::PMKSASha256 => 5,
            AuthenticationKeyManagement::PreSharedKeySha256 => 6,
            AuthenticationKeyManagement::TunneledDirectLinkSetup => 7,
            AuthenticationKeyManagement::SimultaneousAuthenticationOfEquals => 8,
            AuthenticationKeyManagement::FastTransitionSAE => 9,
            AuthenticationKeyManagement::Reserved(t) => t,
        };
        STANDARD_SUITE_OUI | (u32::from(suite_type) << 24)
    }
}

impl fmt::Display for AuthenticationKeyManagement {
    /// Writes the common short name of the AKM suite.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::PairwiseMasterKeySecurityAssociation => "PMKSA",
            Self::PreSharedKey => "PSK",
            Self::FastTransitionPMKSA => "FTPMKSA",
            Self::FastTransitionPreSharedKey => "FTPSK",
            Self::PMKSASha256 => "PMKSA_SHA256",
            Self::PreSharedKeySha256 => "PSK_SHA256",
            Self::TunneledDirectLinkSetup => "TDLS",
            Self::SimultaneousAuthenticationOfEquals => "SAE",
            Self::FastTransitionSAE => "FTSAE",
            Self::Reserved(t) => return write!(f, "Reserved {:x}", t),
            Self::Vendor(raw) => return write!(f, "Vendor {:x}", raw),
        };
        f.write_str(label)
    }
}
276
277bitflags! {
278 #[derive(Clone, Copy, Debug, PartialEq, PartialOrd)]
280 pub struct RsnCapabilities: u16 {
281 const PREAUTHENTICATION = 0x0001;
283 const NO_PAIRWISE = 0x0002;
285 const PMF_REQUIRED = 0x0040;
287 const PMF_CAPABLE = 0x0080;
289 const PEER_KEY_ENABLED = 0x0200;
291 const SPP_AMSDU_CAPABLE = 0x0400;
293 const SPP_AMSDU_REQUIRED = 0x0800;
295 const PBAC = 0x1000;
297 const EXTENDED_KEY_ID = 0x2000;
300 }
301}
302
/// Protected Management Frames setting derived from the RSN capability bits.
#[derive(Debug, PartialEq)]
pub enum ProtectedManagementFramesMode {
    /// Neither the "capable" nor the "required" bit is set.
    Disabled,
    /// Protected management frames are supported but not required.
    Capable,
    /// Protected management frames are required.
    Required,
}

impl fmt::Display for ProtectedManagementFramesMode {
    /// Writes the variant name.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let label = match self {
            Self::Disabled => "Disabled",
            Self::Capable => "Capable",
            Self::Required => "Required",
        };
        f.write_str(label)
    }
}
323
324#[derive(Debug)]
326pub struct RobustSecurityNetwork {
327 version: u16,
329 cipher_suite: CipherSuite,
331 pub ciphers: Vec<CipherSuite>,
333 pub akms: Vec<AuthenticationKeyManagement>,
335 capabilities: RsnCapabilities,
337 ptksa_counters: u8,
339 gtksa_counters: u8,
341}
342
343impl RobustSecurityNetwork {
344 pub fn parse(data: &[u8]) -> Result<RobustSecurityNetwork, Error> {
346 if data.len() > 8 {
347 let version = u16::unpack_unchecked(data);
348 let value = u32::unpack_unchecked(&data[2..]);
349 let suite = CipherSuite::from(value);
350 let count = u16::unpack_unchecked(&data[6..]);
351 let (used, values) = unpack_vec::<u32>(&data[8..], count as usize)?;
352 let mut offset = 8 + used;
353 let ciphers = values.into_iter().map(CipherSuite::from).collect();
354 let (used, count) = u16::unpack_with_size(&data[offset..])?;
355 offset += used;
356 let (used, values) = unpack_vec::<u32>(&data[offset..], count as usize)?;
357 offset += used;
358 let akms = values
359 .into_iter()
360 .map(AuthenticationKeyManagement::from)
361 .collect();
362 let (_used, count) = u16::unpack_with_size(&data[offset..])?;
363 let ptksa_counters = match count & 0x000c {
364 0x0004 => 2,
365 0x0008 => 4,
366 0x000c => 16,
367 _ => 1,
368 };
369 let gtksa_counters = match count & 0x0030 {
370 0x0010 => 2,
371 0x0020 => 4,
372 0x0030 => 16,
373 _ => 1,
374 };
375 return Ok(RobustSecurityNetwork {
376 version,
377 cipher_suite: suite,
378 ciphers,
379 akms,
380 capabilities: RsnCapabilities::from_bits_truncate(count),
381 ptksa_counters,
382 gtksa_counters,
383 });
384 }
385 Err(io::Error::new(io::ErrorKind::InvalidData, "Invalid RSN element").into())
386 }
387 pub fn pmf_mode(&self) -> ProtectedManagementFramesMode {
389 if self.capabilities.intersects(RsnCapabilities::PMF_REQUIRED) {
390 return ProtectedManagementFramesMode::Required;
391 } else if self.capabilities.intersects(RsnCapabilities::PMF_CAPABLE) {
392 return ProtectedManagementFramesMode::Capable;
393 }
394 ProtectedManagementFramesMode::Disabled
395 }
396
397 pub fn version(&self) -> u16 {
399 self.version
400 }
401
402 pub fn pairwise_transient_key_security_association_replay_counters(&self) -> u8 {
404 self.ptksa_counters
405 }
406
407 pub fn group_temporal_key_security_association_replay_counters(&self) -> u8 {
409 self.gtksa_counters
410 }
411}
412
413impl fmt::Display for RobustSecurityNetwork {
414 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
415 write!(
416 f,
417 "Cipher Suite {} Protected Management Frames {}",
418 self.cipher_suite,
419 self.pmf_mode()
420 )
421 }
422}
423
/// Decoded High Throughput (HT) Operation element.
pub struct HighThroughputOperation {
    /// Channel width in MHz (20 or 40) according to the STA Channel Width bit.
    pub width: u32,
    /// Primary channel number (first byte of the payload).
    pub primary_channel: u8,
    /// Channel number of the secondary 20 MHz channel, or 0 when there is
    /// none (or it would fall outside the `u8` channel-number range).
    pub secondary_channel: u8,
}

impl HighThroughputOperation {
    /// Parses the 22-byte payload of an HT Operation element.
    ///
    /// Byte 0 is the primary channel. In byte 1, bits 0-1 are the Secondary
    /// Channel Offset (1 = above, 3 = below, otherwise none) and bit 2 selects
    /// 40 MHz operation.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error unless `data` is exactly 22 bytes long.
    pub fn parse(data: &[u8]) -> Result<HighThroughputOperation, Error> {
        if data.len() != 22 {
            return Err(
                io::Error::new(io::ErrorKind::InvalidData, "Invalid HT element").into(),
            );
        }
        let primary_channel = data[0];
        // Adjacent 20 MHz channels are 4 channel numbers apart (5 MHz channel
        // spacing). Checked arithmetic keeps malformed input from
        // overflowing; out-of-range results are reported as "no secondary".
        let secondary_channel = match data[1] & 0x03 {
            1 => primary_channel.checked_add(4).unwrap_or(0),
            3 => primary_channel.checked_sub(4).unwrap_or(0),
            _ => 0,
        };
        let width = if data[1] & 0x04 == 0 { 20 } else { 40 };
        Ok(HighThroughputOperation {
            width,
            primary_channel,
            secondary_channel,
        })
    }
}

impl fmt::Display for HighThroughputOperation {
    /// Writes primary channel, secondary channel and bandwidth.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(
            f,
            "Primary Channel {} Secondary Channel {} Bandwidth {}",
            self.primary_channel, self.secondary_channel, self.width
        )
    }
}
464
/// Highest VHT-MCS supported for a given number of spatial streams.
///
/// Discriminants equal the 2-bit field value they are decoded from.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum MaxVhtMcs {
    /// MCS 0-7 supported.
    VhtMcs0to7 = 0,
    /// MCS 0-8 supported.
    VhtMcs0to8 = 1,
    /// MCS 0-9 supported.
    VhtMcs0to9 = 2,
    /// This number of spatial streams is not supported.
    NotSupported = 3,
}

impl From<u8> for MaxVhtMcs {
    /// Values 0-2 map to the matching variant; anything else is `NotSupported`.
    fn from(v: u8) -> Self {
        const SUPPORTED: [MaxVhtMcs; 3] = [
            MaxVhtMcs::VhtMcs0to7,
            MaxVhtMcs::VhtMcs0to8,
            MaxVhtMcs::VhtMcs0to9,
        ];
        SUPPORTED
            .get(usize::from(v))
            .copied()
            .unwrap_or(MaxVhtMcs::NotSupported)
    }
}
488
489pub struct VeryHighThroughputOperation {
491 pub width: u32,
493 pub channel: u8,
495 pub secondary_channel: u8,
497 pub max_vht_mcs_ss: [MaxVhtMcs; 8],
499}
500
501impl VeryHighThroughputOperation {
502 pub fn parse(data: &[u8]) -> Result<VeryHighThroughputOperation, Error> {
504 if data.len() == 5 {
505 let width = match data[0] & 0x03 {
506 1 => 80,
507 2 => 160,
508 3 => 80,
509 _ => 40,
510 };
511 let mut max_vht_mcs_ss = [MaxVhtMcs::NotSupported; 8];
512 max_vht_mcs_ss[0] = MaxVhtMcs::from((data[3] & 0b0000_0011) >> 0);
513 max_vht_mcs_ss[1] = MaxVhtMcs::from((data[3] & 0b0000_1100) >> 2);
514 max_vht_mcs_ss[2] = MaxVhtMcs::from((data[3] & 0b0011_0000) >> 4);
515 max_vht_mcs_ss[3] = MaxVhtMcs::from((data[3] & 0b1100_0000) >> 6);
516 max_vht_mcs_ss[4] = MaxVhtMcs::from((data[4] & 0b0000_0011) >> 0);
517 max_vht_mcs_ss[5] = MaxVhtMcs::from((data[4] & 0b0000_1100) >> 2);
518 max_vht_mcs_ss[6] = MaxVhtMcs::from((data[4] & 0b0011_0000) >> 4);
519 max_vht_mcs_ss[7] = MaxVhtMcs::from((data[4] & 0b1100_0000) >> 6);
520 return Ok(VeryHighThroughputOperation {
521 width,
522 channel: data[1],
523 secondary_channel: data[2],
524 max_vht_mcs_ss,
525 });
526 }
527 Err(io::Error::new(io::ErrorKind::InvalidData, "Invalid VHT element").into())
528 }
529}
530
531impl fmt::Display for VeryHighThroughputOperation {
532 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
533 write!(
534 f,
535 "Primary Channel {} Secondary Channel {} Bandwidth {}",
536 self.channel, self.secondary_channel, self.width
537 )
538 }
539}
540
/// Transmission restriction announced together with a channel switch.
pub enum ChannelSwitchMode {
    /// Transmissions are not restricted until the switch.
    NoRestriction = 0,
    /// No transmissions until the switch.
    NoTransmission = 1,
}

impl From<u8> for ChannelSwitchMode {
    /// `1` means `NoTransmission`; every other value is `NoRestriction`.
    fn from(v: u8) -> Self {
        match v {
            1 => ChannelSwitchMode::NoTransmission,
            _ => ChannelSwitchMode::NoRestriction,
        }
    }
}

/// Channel Switch Announcement (CSA) element.
pub struct ChannelSwitchAnnouncement {
    /// Whether transmissions must stop until the switch.
    pub switch_mode: ChannelSwitchMode,
    /// Channel number to switch to.
    pub new_channel: u8,
    /// Count value from the element (third payload byte).
    pub switch_count: u8,
}

impl ChannelSwitchAnnouncement {
    /// Parses a CSA payload: switch mode, new channel, switch count (one byte
    /// each).
    ///
    /// IEEE 802.11 defines the CSA payload as exactly 3 bytes. Trailing bytes
    /// are ignored, so 4-byte inputs accepted by earlier versions of this
    /// parser still parse.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error when `data` is shorter than 3 bytes.
    pub fn parse(data: &[u8]) -> Result<Self, Error> {
        // Only three bytes are defined for CSA; requiring four (as the
        // extended variant does) would reject every standard CSA element.
        if data.len() < 3 {
            return Err(
                io::Error::new(io::ErrorKind::InvalidData, "Invalid CSA element").into(),
            );
        }
        Ok(ChannelSwitchAnnouncement {
            switch_mode: ChannelSwitchMode::from(data[0]),
            new_channel: data[1],
            switch_count: data[2],
        })
    }
}

/// Extended Channel Switch Announcement (ECSA) element.
pub struct ExtendedChannelSwitchAnnouncement {
    /// Whether transmissions must stop until the switch.
    pub switch_mode: ChannelSwitchMode,
    /// Operating class to switch to.
    pub new_operating_class: u8,
    /// Channel number to switch to.
    pub new_channel: u8,
    /// Count value from the element (fourth payload byte).
    pub switch_count: u8,
}

impl ExtendedChannelSwitchAnnouncement {
    /// Parses an ECSA payload: switch mode, new operating class, new channel,
    /// switch count (one byte each).
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error unless `data` is exactly 4 bytes long.
    pub fn parse(data: &[u8]) -> Result<ExtendedChannelSwitchAnnouncement, Error> {
        if data.len() != 4 {
            return Err(
                io::Error::new(io::ErrorKind::InvalidData, "Invalid ECSA element").into(),
            );
        }
        Ok(ExtendedChannelSwitchAnnouncement {
            switch_mode: ChannelSwitchMode::from(data[0]),
            new_operating_class: data[1],
            new_channel: data[2],
            switch_count: data[3],
        })
    }
}
608
/// Country element: the regulatory domain advertised by the network.
pub struct Country {
    /// First two bytes of the country string, decoded as text.
    pub alpha2: String,
}

impl Country {
    /// Parses a Country element payload and extracts its two-letter code.
    ///
    /// The payload must be at least 6 bytes: the 3-byte country string plus
    /// one 3-byte triplet. The bytes come from untrusted over-the-air data,
    /// so invalid UTF-8 is replaced with U+FFFD instead of panicking.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidData` error when `data` is shorter than 6 bytes.
    pub fn parse(data: &[u8]) -> Result<Country, Error> {
        const MIN_LENGTH: usize = 6;
        if data.len() < MIN_LENGTH {
            return Err(
                io::Error::new(io::ErrorKind::InvalidData, "Invalid Country element").into(),
            );
        }
        let alpha2 = String::from_utf8_lossy(&data[..2]).into_owned();
        Ok(Country { alpha2 })
    }
}
626
627pub enum InformationElement<'a> {
629 Ssid(Ssid),
631 Country(Country),
633 ChannelSwitchAnnouncement(ChannelSwitchAnnouncement),
635 RobustSecurityNetwork(RobustSecurityNetwork),
637 ExtendedChannelSwitchAnnouncement(ExtendedChannelSwitchAnnouncement),
639 HighThroughputOperation(HighThroughputOperation),
641 VeryHighThroughputOperation(VeryHighThroughputOperation),
643 Other(RawInformationElement<'a>),
645}
646
647impl<'a> InformationElement<'a> {
648 pub fn parse(data: &'a [u8]) -> Result<InformationElement<'a>, Error> {
650 let raw = RawInformationElement::parse(data)?;
651 if let Some(id) = raw.ie_id() {
652 return Self::from(id, raw.data);
653 } else {
654 return Ok(InformationElement::Other(raw));
655 }
656 }
657
658 pub fn from(id: InformationElementId, data: &'a [u8]) -> Result<InformationElement<'a>, Error> {
660 let ie = match id {
661 InformationElementId::Ssid => {
662 let ie = Ssid::parse(data)?;
663 InformationElement::Ssid(ie)
664 }
665 InformationElementId::Country => {
666 let ie = Country::parse(data)?;
667 InformationElement::Country(ie)
668 }
669 InformationElementId::ChannelSwitchAnnouncement => {
670 let ie = ChannelSwitchAnnouncement::parse(data)?;
671 InformationElement::ChannelSwitchAnnouncement(ie)
672 }
673 InformationElementId::RobustSecurityNetwork => {
674 let ie = RobustSecurityNetwork::parse(data)?;
675 InformationElement::RobustSecurityNetwork(ie)
676 }
677 InformationElementId::ExtendedChannelSwitchAnnouncement => {
678 let ie = ExtendedChannelSwitchAnnouncement::parse(data)?;
679 InformationElement::ExtendedChannelSwitchAnnouncement(ie)
680 }
681 InformationElementId::HighThroughputOperation => {
682 let ie = HighThroughputOperation::parse(data)?;
683 InformationElement::HighThroughputOperation(ie)
684 }
685 InformationElementId::VeryHighThroughputOperation => {
686 let ie = VeryHighThroughputOperation::parse(data)?;
687 InformationElement::VeryHighThroughputOperation(ie)
688 }
689 _ => InformationElement::Other(RawInformationElement {
690 identifier: id.into(),
691 data,
692 }),
693 };
694 Ok(ie)
695 }
696 pub fn identifier(&self) -> Option<InformationElementId> {
698 let id = match *self {
699 InformationElement::Ssid(_) => InformationElementId::Ssid,
700 InformationElement::Country(_) => InformationElementId::Country,
701 InformationElement::ChannelSwitchAnnouncement(_) => {
702 InformationElementId::ChannelSwitchAnnouncement
703 }
704 InformationElement::RobustSecurityNetwork(_) => {
705 InformationElementId::RobustSecurityNetwork
706 }
707 InformationElement::ExtendedChannelSwitchAnnouncement(_) => {
708 InformationElementId::ExtendedChannelSwitchAnnouncement
709 }
710 InformationElement::HighThroughputOperation(_) => {
711 InformationElementId::HighThroughputOperation
712 }
713 InformationElement::VeryHighThroughputOperation(_) => {
714 InformationElementId::VeryHighThroughputOperation
715 }
716 InformationElement::Other(ref ie) => InformationElementId::from(ie.identifier),
717 };
718 Some(id)
719 }
720 pub fn parse_all(data: &'a [u8]) -> Result<Vec<InformationElement<'a>>, Error> {
722 let mut ies = vec![];
723 let mut slice = data;
724 while let Ok(raw) = RawInformationElement::parse(slice) {
725 slice = &slice[raw.data.len() + 2..];
726 let id = InformationElementId::convert_from(raw.identifier);
727 let ie = if let Some(id) = id {
728 Self::from(id, raw.data)?
729 } else {
730 InformationElement::Other(raw)
731 };
732 ies.push(ie);
733 }
734 Ok(ies)
735 }
736}
737
#[cfg(test)]
mod tests {
    use super::*;

    /// A single element splits into its identifier and six-byte payload.
    #[test]
    fn test_parse_ie() {
        let wire = [48, 6, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06];
        let parsed = RawInformationElement::parse(&wire).unwrap();
        let expected_payload = &wire[2..];
        assert_eq!(parsed.identifier, 48u8);
        assert_eq!(parsed.data.len(), 6);
        assert_eq!(parsed.data, expected_payload);
    }

    /// Three back-to-back elements, one of them with an empty payload.
    #[test]
    fn test_parse_ies() {
        let wire = [
            48, 6, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, // element 48, 6-byte payload
            4, 0, // element 4, empty payload
            1, 2, 0x55, 0xaa, // element 1, 2-byte payload
        ];
        let parsed = InformationElements::parse(&wire);
        assert_eq!(parsed.elements.len(), 3);
    }
}