// Source: embedded_svc/wifi.rs

use core::fmt::Debug;
use core::mem;

#[cfg(feature = "alloc")]
extern crate alloc;

use enumset::*;

#[cfg(feature = "use_serde")]
use serde::{Deserialize, Serialize};

#[cfg(feature = "use_strum")]
use strum_macros::{Display, EnumIter, EnumMessage, EnumString, FromRepr, VariantNames};

#[cfg(feature = "use_numenum")]
use num_enum::TryFromPrimitive;

18#[derive(EnumSetType, Debug, PartialOrd)]
19#[cfg_attr(feature = "defmt", derive(defmt::Format))]
20#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
21#[cfg_attr(
22    feature = "use_strum",
23    derive(EnumString, Display, EnumMessage, EnumIter, VariantNames, FromRepr)
24)]
25#[cfg_attr(feature = "use_numenum", derive(TryFromPrimitive))]
26#[cfg_attr(feature = "use_numenum", repr(u8))]
27#[derive(Default)]
28pub enum AuthMethod {
29    #[cfg_attr(feature = "use_strum", strum(serialize = "none", message = "None"))]
30    None,
31    #[cfg_attr(feature = "use_strum", strum(serialize = "wep", message = "WEP"))]
32    WEP,
33    #[cfg_attr(feature = "use_strum", strum(serialize = "wpa", message = "WPA"))]
34    WPA,
35    #[cfg_attr(
36        feature = "use_strum",
37        strum(serialize = "wpa2personal", message = "WPA2 Personal")
38    )]
39    #[default]
40    WPA2Personal,
41    #[cfg_attr(
42        feature = "use_strum",
43        strum(serialize = "wpawpa2personal", message = "WPA & WPA2 Personal")
44    )]
45    WPAWPA2Personal,
46    #[cfg_attr(
47        feature = "use_strum",
48        strum(serialize = "wpa2enterprise", message = "WPA2 Enterprise")
49    )]
50    WPA2Enterprise,
51    #[cfg_attr(
52        feature = "use_strum",
53        strum(serialize = "wpa3personal", message = "WPA3 Personal")
54    )]
55    WPA3Personal,
56    #[cfg_attr(
57        feature = "use_strum",
58        strum(serialize = "wpa2wpa3personal", message = "WPA2 & WPA3 Personal")
59    )]
60    WPA2WPA3Personal,
61    #[cfg_attr(
62        feature = "use_strum",
63        strum(serialize = "wapipersonal", message = "WAPI Personal")
64    )]
65    WAPIPersonal,
66}
67
68#[derive(EnumSetType, Debug, PartialOrd)]
69#[cfg_attr(feature = "defmt", derive(defmt::Format))]
70#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
71#[cfg_attr(
72    feature = "use_strum",
73    derive(EnumString, Display, EnumMessage, EnumIter, VariantNames, FromRepr)
74)]
75#[cfg_attr(feature = "use_numenum", derive(TryFromPrimitive))]
76#[cfg_attr(feature = "use_numenum", repr(u8))]
77#[derive(Default)]
78pub enum Protocol {
79    #[cfg_attr(
80        feature = "use_strum",
81        strum(serialize = "p802d11b", message = "802.11B")
82    )]
83    P802D11B,
84    #[cfg_attr(
85        feature = "use_strum",
86        strum(serialize = "p802d11bg", message = "802.11BG")
87    )]
88    P802D11BG,
89    #[cfg_attr(
90        feature = "use_strum",
91        strum(serialize = "p802d11bgn", message = "802.11BGN")
92    )]
93    #[default]
94    P802D11BGN,
95    #[cfg_attr(
96        feature = "use_strum",
97        strum(serialize = "p802d11bgnlr", message = "802.11BGNLR")
98    )]
99    P802D11BGNLR,
100    #[cfg_attr(
101        feature = "use_strum",
102        strum(serialize = "p802d11lr", message = "802.11LR")
103    )]
104    P802D11LR,
105}
106
107#[derive(EnumSetType, Debug, PartialOrd)]
108#[cfg_attr(feature = "defmt", derive(defmt::Format))]
109#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
110#[cfg_attr(
111    feature = "use_strum",
112    derive(EnumString, Display, EnumMessage, EnumIter, VariantNames, FromRepr)
113)]
114#[cfg_attr(feature = "use_numenum", derive(TryFromPrimitive))]
115#[cfg_attr(feature = "use_numenum", repr(u8))]
116#[derive(Default)]
117pub enum SecondaryChannel {
118    // TODO: Need to extend that for 5GHz
119    #[cfg_attr(feature = "use_strum", strum(serialize = "none", message = "None"))]
120    #[default]
121    None,
122    #[cfg_attr(feature = "use_strum", strum(serialize = "above", message = "Above"))]
123    Above,
124    #[cfg_attr(feature = "use_strum", strum(serialize = "below", message = "Below"))]
125    Below,
126}
127
128#[derive(Clone, Debug, Default, PartialEq, Eq)]
129#[cfg_attr(feature = "defmt", derive(defmt::Format))]
130#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
131pub struct AccessPointInfo {
132    pub ssid: heapless::String<32>,
133    pub bssid: [u8; 6],
134    pub channel: u8,
135    pub secondary_channel: SecondaryChannel,
136    pub signal_strength: i8,
137    #[cfg_attr(feature = "defmt", defmt(Debug2Format))]
138    pub protocols: EnumSet<Protocol>,
139    pub auth_method: Option<AuthMethod>,
140}
141
142#[derive(Clone, Debug, PartialEq, Eq)]
143#[cfg_attr(feature = "defmt", derive(defmt::Format))]
144#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
145pub struct AccessPointConfiguration {
146    pub ssid: heapless::String<32>,
147    pub ssid_hidden: bool,
148    pub channel: u8,
149    pub secondary_channel: Option<u8>,
150    #[cfg_attr(feature = "defmt", defmt(Debug2Format))]
151    pub protocols: EnumSet<Protocol>,
152    pub auth_method: AuthMethod,
153    pub password: heapless::String<64>,
154    pub max_connections: u16,
155}
156
157impl Default for AccessPointConfiguration {
158    fn default() -> Self {
159        Self {
160            ssid: "iot-device".try_into().unwrap(),
161            ssid_hidden: false,
162            channel: 1,
163            secondary_channel: None,
164            protocols: Protocol::P802D11B | Protocol::P802D11BG | Protocol::P802D11BGN,
165            auth_method: AuthMethod::None,
166            password: heapless::String::new(),
167            max_connections: 255,
168        }
169    }
170}
171
172/// Configuration for wifi in STA mode
173#[derive(Clone, PartialEq, Eq)]
174#[cfg_attr(feature = "defmt", derive(defmt::Format))]
175#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
176pub struct ClientConfiguration {
177    /// SSID of the target AP
178    pub ssid: heapless::String<32>,
179    /// BSSID of the target AP
180    pub bssid: Option<[u8; 6]>,
181    //pub protocol: Protocol,
182    #[cfg_attr(feature = "use_serde", serde(default))]
183    pub auth_method: AuthMethod,
184    #[cfg_attr(feature = "use_serde", serde(default))]
185    pub password: heapless::String<64>,
186    /// The expected Channel of the target AP
187    ///
188    /// Connecting might be quicker when the client starts its scan
189    /// at the channel the target AP is on.
190    pub channel: Option<u8>,
191    /// The scan method to use when searching for the target AP
192    #[cfg_attr(feature = "use_serde", serde(default))]
193    pub scan_method: ScanMethod,
194    /// Protected Management Frame configuration
195    #[cfg_attr(feature = "use_serde", serde(default))]
196    pub pmf_cfg: PmfConfiguration,
197}
198
199impl Debug for ClientConfiguration {
200    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
201        f.debug_struct("ClientConfiguration")
202            .field("ssid", &self.ssid)
203            .field("bssid", &self.bssid)
204            .field("auth_method", &self.auth_method)
205            .field("channel", &self.channel)
206            .field("scan_method", &self.scan_method)
207            .field("pmf_cfg", &self.pmf_cfg)
208            .finish()
209    }
210}
211
212impl Default for ClientConfiguration {
213    fn default() -> Self {
214        ClientConfiguration {
215            ssid: heapless::String::new(),
216            bssid: None,
217            auth_method: Default::default(),
218            password: heapless::String::new(),
219            channel: None,
220            scan_method: ScanMethod::default(),
221            pmf_cfg: PmfConfiguration::default(),
222        }
223    }
224}
225
/// Protected Management Frame (PMF) configuration.
///
/// Defaults to [`PmfConfiguration::NotCapable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "use_strum",
    derive(EnumString, Display, EnumMessage, EnumIter, VariantNames)
)]
pub enum PmfConfiguration {
    /// PMF support is not advertised (default).
    #[default]
    #[cfg_attr(
        feature = "use_strum",
        strum(
            serialize = "not_capable",
            serialize = "pmf_disabled",
            to_string = "PMF Disabled",
            message = "Don't advertise PMF capabilities",
        )
    )]
    NotCapable,
    /// PMF support is advertised; `required` states whether PMF is mandatory.
    #[cfg_attr(
        feature = "use_strum",
        strum(
            serialize = "capable",
            to_string = "PMF enabled",
            message = "Advertise PMF capabilities",
        )
    )]
    Capable { required: bool },
}
258impl PmfConfiguration {
259    /// PMF configuration with PMF strictly required
260    pub fn new_required() -> Self {
261        PmfConfiguration::Capable { required: true }
262    }
263    /// PMF configuration with PMF optional but available
264    pub fn new_pmf_optional() -> Self {
265        PmfConfiguration::Capable { required: true }
266    }
267}
268
269/// The scan method to use when connecting to an AP
270#[derive(Debug, Clone, PartialEq, Eq)]
271#[cfg_attr(feature = "defmt", derive(defmt::Format))]
272#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
273#[cfg_attr(
274    feature = "use_strum",
275    derive(EnumString, Display, EnumMessage, EnumIter, VariantNames)
276)]
277#[non_exhaustive]
278pub enum ScanMethod {
279    /// Scan every channel and connect according to [ScanSortMethod] (default)
280    #[cfg_attr(
281        feature = "use_strum",
282        strum(
283            serialize = "complete_scan",
284            to_string = "Complete Scan",
285            message = "Do a complete scan",
286            detailed_message = "Scan all APs and sort by a criteria"
287        )
288    )]
289    CompleteScan(ScanSortMethod),
290    /// Connect to the first found AP and stop scanning
291    #[cfg_attr(
292        feature = "use_strum",
293        strum(
294            serialize = "fast_scan",
295            to_string = "Fast Scan",
296            message = "Do a fast scan",
297            detailed_message = "Connect to the first matching AP"
298        )
299    )]
300    FastScan,
301}
302impl Default for ScanMethod {
303    fn default() -> Self {
304        Self::CompleteScan(ScanSortMethod::default())
305    }
306}
307
/// Criterion used to pick the AP after a [`ScanMethod::CompleteScan`].
///
/// Defaults to [`ScanSortMethod::Signal`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
#[cfg_attr(
    feature = "use_strum",
    derive(EnumString, Display, EnumMessage, EnumIter, VariantNames)
)]
#[non_exhaustive]
pub enum ScanSortMethod {
    /// Sort by signal strength (default).
    #[default]
    #[cfg_attr(
        feature = "use_strum",
        strum(
            serialize = "signal_strength",
            serialize = "signal",
            to_string = "Signal Strength",
            message = "Sort by signal strength"
        )
    )]
    Signal,
    /// Sort by security.
    #[cfg_attr(
        feature = "use_strum",
        strum(serialize = "security", to_string = "Security", message = "Sort by security")
    )]
    Security,
}

341#[derive(EnumSetType, Debug, PartialOrd)]
342#[cfg_attr(feature = "defmt", derive(defmt::Format))]
343#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
344#[cfg_attr(
345    feature = "use_strum",
346    derive(EnumString, Display, EnumMessage, EnumIter, VariantNames, FromRepr)
347)]
348#[cfg_attr(feature = "use_numenum", derive(TryFromPrimitive))]
349#[cfg_attr(feature = "use_numenum", repr(u8))]
350pub enum Capability {
351    #[cfg_attr(feature = "use_strum", strum(serialize = "client", message = "Client"))]
352    Client,
353    #[cfg_attr(
354        feature = "use_strum",
355        strum(serialize = "ap", message = "Access Point")
356    )]
357    AccessPoint,
358    #[cfg_attr(
359        feature = "use_strum",
360        strum(serialize = "mixed", message = "Client & Access Point")
361    )]
362    Mixed,
363}
364
365#[derive(Clone, Debug, PartialEq, Eq)]
366#[cfg_attr(feature = "defmt", derive(defmt::Format))]
367#[cfg_attr(feature = "use_serde", derive(Serialize, Deserialize))]
368#[derive(Default)]
369pub enum Configuration {
370    #[default]
371    None,
372    Client(ClientConfiguration),
373    AccessPoint(AccessPointConfiguration),
374    Mixed(ClientConfiguration, AccessPointConfiguration),
375}
376
377impl Configuration {
378    pub fn as_client_conf_ref(&self) -> Option<&ClientConfiguration> {
379        match self {
380            Self::Client(client_conf) | Self::Mixed(client_conf, _) => Some(client_conf),
381            _ => None,
382        }
383    }
384
385    pub fn as_ap_conf_ref(&self) -> Option<&AccessPointConfiguration> {
386        match self {
387            Self::AccessPoint(ap_conf) | Self::Mixed(_, ap_conf) => Some(ap_conf),
388            _ => None,
389        }
390    }
391
392    pub fn as_client_conf_mut(&mut self) -> &mut ClientConfiguration {
393        match self {
394            Self::Client(client_conf) => client_conf,
395            Self::Mixed(_, _) => {
396                let prev = mem::replace(self, Self::None);
397                match prev {
398                    Self::Mixed(client_conf, _) => {
399                        *self = Self::Client(client_conf);
400                        self.as_client_conf_mut()
401                    }
402                    _ => unreachable!(),
403                }
404            }
405            _ => {
406                *self = Self::Client(Default::default());
407                self.as_client_conf_mut()
408            }
409        }
410    }
411
412    pub fn as_ap_conf_mut(&mut self) -> &mut AccessPointConfiguration {
413        match self {
414            Self::AccessPoint(ap_conf) => ap_conf,
415            Self::Mixed(_, _) => {
416                let prev = mem::replace(self, Self::None);
417                match prev {
418                    Self::Mixed(_, ap_conf) => {
419                        *self = Self::AccessPoint(ap_conf);
420                        self.as_ap_conf_mut()
421                    }
422                    _ => unreachable!(),
423                }
424            }
425            _ => {
426                *self = Self::AccessPoint(Default::default());
427                self.as_ap_conf_mut()
428            }
429        }
430    }
431
432    pub fn as_mixed_conf_mut(
433        &mut self,
434    ) -> (&mut ClientConfiguration, &mut AccessPointConfiguration) {
435        match self {
436            Self::Mixed(client_conf, ref mut ap_conf) => (client_conf, ap_conf),
437            Self::AccessPoint(_) => {
438                let prev = mem::replace(self, Self::None);
439                match prev {
440                    Self::AccessPoint(ap_conf) => {
441                        *self = Self::Mixed(Default::default(), ap_conf);
442                        self.as_mixed_conf_mut()
443                    }
444                    _ => unreachable!(),
445                }
446            }
447            Self::Client(_) => {
448                let prev = mem::replace(self, Self::None);
449                match prev {
450                    Self::Client(client_conf) => {
451                        *self = Self::Mixed(client_conf, Default::default());
452                        self.as_mixed_conf_mut()
453                    }
454                    _ => unreachable!(),
455                }
456            }
457            _ => {
458                *self = Self::Mixed(Default::default(), Default::default());
459                self.as_mixed_conf_mut()
460            }
461        }
462    }
463}
464
465pub trait Wifi {
466    type Error: Debug;
467
468    fn get_capabilities(&self) -> Result<EnumSet<Capability>, Self::Error>;
469
470    fn get_configuration(&self) -> Result<Configuration, Self::Error>;
471
472    fn set_configuration(&mut self, conf: &Configuration) -> Result<(), Self::Error>;
473
474    fn start(&mut self) -> Result<(), Self::Error>;
475    fn stop(&mut self) -> Result<(), Self::Error>;
476
477    fn connect(&mut self) -> Result<(), Self::Error>;
478    fn disconnect(&mut self) -> Result<(), Self::Error>;
479
480    fn is_started(&self) -> Result<bool, Self::Error>;
481    fn is_connected(&self) -> Result<bool, Self::Error>;
482
483    fn scan_n<const N: usize>(
484        &mut self,
485    ) -> Result<(heapless::Vec<AccessPointInfo, N>, usize), Self::Error>;
486
487    #[cfg(feature = "alloc")]
488    fn scan(&mut self) -> Result<alloc::vec::Vec<AccessPointInfo>, Self::Error>;
489}
490
491impl<W> Wifi for &mut W
492where
493    W: Wifi,
494{
495    type Error = W::Error;
496
497    fn get_capabilities(&self) -> Result<EnumSet<Capability>, Self::Error> {
498        (**self).get_capabilities()
499    }
500
501    fn get_configuration(&self) -> Result<Configuration, Self::Error> {
502        (**self).get_configuration()
503    }
504
505    fn set_configuration(&mut self, conf: &Configuration) -> Result<(), Self::Error> {
506        (*self).set_configuration(conf)
507    }
508
509    fn start(&mut self) -> Result<(), Self::Error> {
510        (*self).start()
511    }
512
513    fn stop(&mut self) -> Result<(), Self::Error> {
514        (*self).stop()
515    }
516
517    fn connect(&mut self) -> Result<(), Self::Error> {
518        (*self).connect()
519    }
520
521    fn disconnect(&mut self) -> Result<(), Self::Error> {
522        (*self).disconnect()
523    }
524
525    fn is_started(&self) -> Result<bool, Self::Error> {
526        (**self).is_started()
527    }
528
529    fn is_connected(&self) -> Result<bool, Self::Error> {
530        (**self).is_connected()
531    }
532
533    fn scan_n<const N: usize>(
534        &mut self,
535    ) -> Result<(heapless::Vec<AccessPointInfo, N>, usize), Self::Error> {
536        (*self).scan_n()
537    }
538
539    #[cfg(feature = "alloc")]
540    fn scan(&mut self) -> Result<alloc::vec::Vec<AccessPointInfo>, Self::Error> {
541        (*self).scan()
542    }
543}
544
545pub mod asynch {
546    use super::*;
547
548    pub trait Wifi {
549        type Error: Debug;
550
551        async fn get_capabilities(&self) -> Result<EnumSet<Capability>, Self::Error>;
552
553        async fn get_configuration(&self) -> Result<Configuration, Self::Error>;
554
555        async fn set_configuration(&mut self, conf: &Configuration) -> Result<(), Self::Error>;
556
557        async fn start(&mut self) -> Result<(), Self::Error>;
558        async fn stop(&mut self) -> Result<(), Self::Error>;
559
560        async fn connect(&mut self) -> Result<(), Self::Error>;
561        async fn disconnect(&mut self) -> Result<(), Self::Error>;
562
563        async fn is_started(&self) -> Result<bool, Self::Error>;
564        async fn is_connected(&self) -> Result<bool, Self::Error>;
565
566        async fn scan_n<const N: usize>(
567            &mut self,
568        ) -> Result<(heapless::Vec<AccessPointInfo, N>, usize), Self::Error>;
569
570        #[cfg(feature = "alloc")]
571        async fn scan(&mut self) -> Result<alloc::vec::Vec<AccessPointInfo>, Self::Error>;
572    }
573
574    impl<W> Wifi for &mut W
575    where
576        W: Wifi,
577    {
578        type Error = W::Error;
579
580        async fn get_capabilities(&self) -> Result<EnumSet<Capability>, Self::Error> {
581            (**self).get_capabilities().await
582        }
583
584        async fn get_configuration(&self) -> Result<Configuration, Self::Error> {
585            (**self).get_configuration().await
586        }
587
588        async fn set_configuration(&mut self, conf: &Configuration) -> Result<(), Self::Error> {
589            (**self).set_configuration(conf).await
590        }
591
592        async fn start(&mut self) -> Result<(), Self::Error> {
593            (**self).start().await
594        }
595
596        async fn stop(&mut self) -> Result<(), Self::Error> {
597            (**self).stop().await
598        }
599
600        async fn connect(&mut self) -> Result<(), Self::Error> {
601            (**self).connect().await
602        }
603
604        async fn disconnect(&mut self) -> Result<(), Self::Error> {
605            (**self).disconnect().await
606        }
607
608        async fn is_started(&self) -> Result<bool, Self::Error> {
609            (**self).is_started().await
610        }
611
612        async fn is_connected(&self) -> Result<bool, Self::Error> {
613            (**self).is_connected().await
614        }
615
616        async fn scan_n<const N: usize>(
617            &mut self,
618        ) -> Result<(heapless::Vec<AccessPointInfo, N>, usize), Self::Error> {
619            (**self).scan_n::<N>().await
620        }
621
622        #[cfg(feature = "alloc")]
623        async fn scan(&mut self) -> Result<alloc::vec::Vec<AccessPointInfo>, Self::Error> {
624            (**self).scan().await
625        }
626    }
627}