wpa_psk/
lib.rs

1//! Compute the WPA-PSK of a Wi-Fi SSID and passphrase.
2//!
3//! # Example
4//!
//! Compute the WPA-PSK of a validated SSID and passphrase:
6//!
7//! ```
8//! # use wpa_psk::{Ssid, Passphrase, wpa_psk, bytes_to_hex};
9//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
10//! let ssid = Ssid::try_from("home")?;
11//! let passphrase = Passphrase::try_from("0123-4567-89")?;
12//! let psk = wpa_psk(&ssid, &passphrase);
13//! assert_eq!(bytes_to_hex(&psk), "150c047b6fad724512a17fa431687048ee503d14c1ea87681d4f241beb04f5ee");
14//! # Ok(())
15//! # }
16//! ```
17//!
18//! Compute the WPA-PSK of possibly invalid raw bytes:
19//!
20//! ```
21//! # use wpa_psk::{wpa_psk_unchecked, bytes_to_hex};
22//! let ssid = "bar".as_bytes();
23//! let passphrase = "2short".as_bytes();
24//! let psk = wpa_psk_unchecked(&ssid, &passphrase);
25//! assert_eq!(bytes_to_hex(&psk), "cb5de4e4d23b2ab0bf5b9ba0fe8132c1e2af3bb52298ec801af8ad520cea3437");
26//! ```
27
28#![forbid(unsafe_code)]
29#![deny(missing_docs)]
30
31use std::{error::Error, fmt::Display};
32
33use pbkdf2::pbkdf2_hmac;
34use sha1::Sha1;
35
/// A validated Wi-Fi SSID: a borrowed slice of between 1 and 32 arbitrary bytes.
///
/// Values are obtained through the `TryFrom<&[u8]>` and `TryFrom<&str>`
/// conversions, which reject any length outside that range.
#[derive(Debug)]
pub struct Ssid<'a>(&'a [u8]);
39
40impl<'a> TryFrom<&'a [u8]> for Ssid<'a> {
41    type Error = ValidateSsidError;
42
43    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
44        if value.is_empty() {
45            Err(ValidateSsidError::TooShort)
46        } else if value.len() > 32 {
47            Err(ValidateSsidError::TooLong)
48        } else {
49            Ok(Ssid(value))
50        }
51    }
52}
53
54impl<'a> TryFrom<&'a str> for Ssid<'a> {
55    type Error = ValidateSsidError;
56
57    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
58        Self::try_from(value.as_bytes())
59    }
60}
61
/// Reasons an SSID can be rejected during validation.
#[derive(Debug, PartialEq, Eq)]
pub enum ValidateSsidError {
    /// The SSID is empty.
    TooShort,
    /// The SSID is longer than 32 bytes.
    TooLong,
}

impl Error for ValidateSsidError {}

impl Display for ValidateSsidError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TooShort => f.write_str("SSID must have at least one byte"),
            Self::TooLong => f.write_str("SSID must have at most 32 bytes"),
        }
    }
}
82
/// A validated WPA passphrase: a borrowed slice of 8 to 63 printable ASCII
/// characters (bytes 32 through 126).
///
/// Values are obtained through the `TryFrom<&[u8]>` and `TryFrom<&str>`
/// conversions, which enforce both the length and the character range.
#[derive(Debug)]
pub struct Passphrase<'a>(&'a [u8]);
86
87impl<'a> TryFrom<&'a [u8]> for Passphrase<'a> {
88    type Error = ValidatePassphraseError;
89
90    fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
91        if value.len() < 8 {
92            Err(ValidatePassphraseError::TooShort)
93        } else if value.len() > 63 {
94            Err(ValidatePassphraseError::TooLong)
95        } else if value.iter().any(|i| !matches!(i, 32u8..=126)) {
96            Err(ValidatePassphraseError::InvalidByte)
97        } else {
98            Ok(Passphrase(value))
99        }
100    }
101}
102
103impl<'a> TryFrom<&'a str> for Passphrase<'a> {
104    type Error = ValidatePassphraseError;
105
106    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
107        Self::try_from(value.as_bytes())
108    }
109}
110
111impl Display for Passphrase<'_> {
112    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113        write!(f, "{}", std::str::from_utf8(self.0).unwrap())
114    }
115}
116
/// Reasons a passphrase can be rejected during validation.
#[derive(Debug, PartialEq, Eq)]
pub enum ValidatePassphraseError {
    /// Passphrase is shorter than 8 bytes.
    TooShort,
    /// Passphrase is longer than 63 bytes.
    TooLong,
    /// Passphrase contains an invalid byte, i.e. one outside printable ASCII
    /// (`0x20` through `0x7E`).
    InvalidByte,
}

impl Error for ValidatePassphraseError {}

impl Display for ValidatePassphraseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            ValidatePassphraseError::TooShort => "passphrase must have at least 8 bytes",
            ValidatePassphraseError::TooLong => "passphrase must have at most 63 bytes",
            ValidatePassphraseError::InvalidByte => {
                "passphrase must consist of printable ASCII characters"
            }
        };
        f.write_str(msg)
    }
}
142
143/// Returns the WPA-PSK of the given SSID and passphrase.
144pub fn wpa_psk(ssid: &Ssid, passphrase: &Passphrase) -> [u8; 32] {
145    wpa_psk_unchecked(ssid.0, passphrase.0)
146}
147
148/// Unchecked WPA-PSK.
149/// See [`wpa_psk`].
150pub fn wpa_psk_unchecked(ssid: &[u8], passphrase: &[u8]) -> [u8; 32] {
151    let mut buf = [0u8; 32];
152    pbkdf2_hmac::<Sha1>(passphrase, ssid, 4096, &mut buf);
153    buf
154}
155
/// Returns the lowercase hexadecimal representation of the given bytes.
///
/// Each byte becomes exactly two characters, so the result is
/// `2 * bytes.len()` characters long. An empty slice yields an empty string.
pub fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    // The final length is known up front, so the string never reallocates.
    // Pushing chars directly also avoids the ignored `fmt::Result` of `write!`.
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn special_characters() {
        let ssid = Ssid::try_from("123abcABC.,-").unwrap();
        let passphrase = Passphrase::try_from("456defDEF *<:D").unwrap();
        assert_eq!(
            bytes_to_hex(&wpa_psk(&ssid, &passphrase)),
            "8a366e5bc51cd5d8fbbeffacc5f1af23fac30e3ac93cdcc368fafbbf63a1085c"
        );
    }

    /// Test vector from IEEE 802.11i, Annex H.4 (SSID "IEEE", passphrase "password").
    #[test]
    fn ieee_802_11i_test_vector() {
        let ssid = Ssid::try_from("IEEE").unwrap();
        let passphrase = Passphrase::try_from("password").unwrap();
        assert_eq!(
            bytes_to_hex(&wpa_psk(&ssid, &passphrase)),
            "f42c6fc52df0ebef9ebb4b90b38a5f902e83fe1b135a70e23aed762e9710a12e"
        );
    }

    #[test]
    fn unchecked_matches_checked() {
        let ssid = Ssid::try_from("IEEE").unwrap();
        let passphrase = Passphrase::try_from("password").unwrap();
        assert_eq!(
            wpa_psk(&ssid, &passphrase),
            wpa_psk_unchecked(b"IEEE", b"password")
        );
    }

    #[test]
    fn ssid_length_bounds() {
        assert_eq!(Ssid::try_from("").unwrap_err(), ValidateSsidError::TooShort);
        assert!(Ssid::try_from("a").is_ok());
        assert!(Ssid::try_from("a".repeat(32).as_str()).is_ok());
        assert_eq!(
            Ssid::try_from("a".repeat(33).as_str()).unwrap_err(),
            ValidateSsidError::TooLong
        );
    }

    #[test]
    fn passphrase_too_short() {
        assert_eq!(
            Passphrase::try_from("foobar").unwrap_err(),
            ValidatePassphraseError::TooShort
        );
    }

    #[test]
    fn passphrase_length_bounds() {
        assert!(Passphrase::try_from("a".repeat(8).as_str()).is_ok());
        assert!(Passphrase::try_from("a".repeat(63).as_str()).is_ok());
        assert_eq!(
            Passphrase::try_from("a".repeat(64).as_str()).unwrap_err(),
            ValidatePassphraseError::TooLong
        );
    }

    #[test]
    fn passphrase_invalid_byte() {
        // Tab (0x09), DEL (0x7F) and non-ASCII UTF-8 all lie outside 0x20..=0x7E.
        assert_eq!(
            Passphrase::try_from("password\t").unwrap_err(),
            ValidatePassphraseError::InvalidByte
        );
        assert_eq!(
            Passphrase::try_from(&[0x7f_u8; 8][..]).unwrap_err(),
            ValidatePassphraseError::InvalidByte
        );
        assert_eq!(
            Passphrase::try_from("pässwörd").unwrap_err(),
            ValidatePassphraseError::InvalidByte
        );
    }

    #[test]
    fn display_passphrase() {
        assert_eq!(
            format!("{}", Passphrase::try_from("foobarbuzz").unwrap()),
            "foobarbuzz"
        );
    }
}