1#![forbid(unsafe_code)]
29#![deny(missing_docs)]
30
31use std::{error::Error, fmt::Display};
32
33use pbkdf2::pbkdf2_hmac;
34use sha1::Sha1;
35
/// A validated Wi-Fi network name (SSID) borrowed from the caller.
///
/// Values are created through the `TryFrom<&[u8]>` / `TryFrom<&str>`
/// conversions, which accept any content between 1 and 32 bytes long.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Ssid<'a>(&'a [u8]);
39
40impl<'a> TryFrom<&'a [u8]> for Ssid<'a> {
41 type Error = ValidateSsidError;
42
43 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
44 if value.is_empty() {
45 Err(ValidateSsidError::TooShort)
46 } else if value.len() > 32 {
47 Err(ValidateSsidError::TooLong)
48 } else {
49 Ok(Ssid(value))
50 }
51 }
52}
53
54impl<'a> TryFrom<&'a str> for Ssid<'a> {
55 type Error = ValidateSsidError;
56
57 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
58 Self::try_from(value.as_bytes())
59 }
60}
61
/// Reasons an SSID can be rejected by the `Ssid` conversions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidateSsidError {
    /// The SSID was empty; at least one byte is required.
    TooShort,
    /// The SSID was longer than the 32-byte maximum.
    TooLong,
}
70
71impl Error for ValidateSsidError {}
72
73impl Display for ValidateSsidError {
74 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
75 let msg = match self {
76 ValidateSsidError::TooShort => "SSID must have at least one byte",
77 ValidateSsidError::TooLong => "SSID must have at most 32 bytes",
78 };
79 write!(f, "{msg}")
80 }
81}
82
/// A validated WPA passphrase borrowed from the caller.
///
/// Values are created through the `TryFrom<&[u8]>` / `TryFrom<&str>`
/// conversions, which accept 8 to 63 bytes in the printable ASCII range
/// (`0x20..=0x7E`). Note that the derived `Debug` output includes the raw
/// passphrase bytes.
#[derive(Debug, Clone, Copy)]
pub struct Passphrase<'a>(&'a [u8]);
86
87impl<'a> TryFrom<&'a [u8]> for Passphrase<'a> {
88 type Error = ValidatePassphraseError;
89
90 fn try_from(value: &'a [u8]) -> Result<Self, Self::Error> {
91 if value.len() < 8 {
92 Err(ValidatePassphraseError::TooShort)
93 } else if value.len() > 63 {
94 Err(ValidatePassphraseError::TooLong)
95 } else if value.iter().any(|i| !matches!(i, 32u8..=126)) {
96 Err(ValidatePassphraseError::InvalidByte)
97 } else {
98 Ok(Passphrase(value))
99 }
100 }
101}
102
103impl<'a> TryFrom<&'a str> for Passphrase<'a> {
104 type Error = ValidatePassphraseError;
105
106 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
107 Self::try_from(value.as_bytes())
108 }
109}
110
111impl Display for Passphrase<'_> {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 write!(f, "{}", std::str::from_utf8(self.0).unwrap())
114 }
115}
116
/// Reasons a passphrase can be rejected by the `Passphrase` conversions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ValidatePassphraseError {
    /// The passphrase was shorter than 8 bytes.
    TooShort,
    /// The passphrase was longer than 63 bytes.
    TooLong,
    /// The passphrase contained a byte outside printable ASCII (`0x20..=0x7E`).
    InvalidByte,
}
127
128impl Error for ValidatePassphraseError {}
129
130impl Display for ValidatePassphraseError {
131 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
132 let msg = match self {
133 ValidatePassphraseError::TooShort => "passphrase must have at least 8 bytes",
134 ValidatePassphraseError::TooLong => "passphrase must have at most 63 bytes",
135 ValidatePassphraseError::InvalidByte => {
136 "passphrase must consist of printable ASCII characters"
137 }
138 };
139 write!(f, "{msg}")
140 }
141}
142
143pub fn wpa_psk(ssid: &Ssid, passphrase: &Passphrase) -> [u8; 32] {
145 wpa_psk_unchecked(ssid.0, passphrase.0)
146}
147
148pub fn wpa_psk_unchecked(ssid: &[u8], passphrase: &[u8]) -> [u8; 32] {
151 let mut buf = [0u8; 32];
152 pbkdf2_hmac::<Sha1>(passphrase, ssid, 4096, &mut buf);
153 buf
154}
155
/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
///
/// An empty slice yields an empty string.
#[must_use]
pub fn bytes_to_hex(bytes: &[u8]) -> String {
    use std::fmt::Write;

    // Every byte becomes exactly two ASCII digits, so reserve the final size
    // up front instead of growing the buffer repeatedly.
    let mut hex = String::with_capacity(bytes.len() * 2);
    for byte in bytes {
        // Writing into a `String` cannot fail; the `fmt::Result` is always `Ok`.
        let _ = write!(hex, "{byte:02x}");
    }
    hex
}
164
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn special_characters() {
        let ssid = Ssid::try_from("123abcABC.,-").unwrap();
        let passphrase = Passphrase::try_from("456defDEF *<:D").unwrap();
        assert_eq!(
            bytes_to_hex(&wpa_psk(&ssid, &passphrase)),
            "8a366e5bc51cd5d8fbbeffacc5f1af23fac30e3ac93cdcc368fafbbf63a1085c"
        );
    }

    /// Reference vector from IEEE 802.11i-2004, Annex H.4 (test case 1).
    #[test]
    fn ieee_802_11i_test_vector() {
        let ssid = Ssid::try_from("IEEE").unwrap();
        let passphrase = Passphrase::try_from("password").unwrap();
        assert_eq!(
            bytes_to_hex(&wpa_psk(&ssid, &passphrase)),
            "f42c6fc52df0ebef9ebb4b90b38a5f902e83fe1b135a70e23aed762e9710a12e"
        );
    }

    #[test]
    fn ssid_length_bounds() {
        assert_eq!(Ssid::try_from("").unwrap_err(), ValidateSsidError::TooShort);
        assert!(Ssid::try_from("a").is_ok());
        assert!(Ssid::try_from("a".repeat(32).as_str()).is_ok());
        assert_eq!(
            Ssid::try_from("a".repeat(33).as_str()).unwrap_err(),
            ValidateSsidError::TooLong
        );
    }

    #[test]
    fn passphrase_too_short() {
        assert_eq!(
            Passphrase::try_from("foobar").unwrap_err(),
            ValidatePassphraseError::TooShort
        );
    }

    #[test]
    fn passphrase_too_long() {
        assert!(Passphrase::try_from("a".repeat(63).as_str()).is_ok());
        assert_eq!(
            Passphrase::try_from("a".repeat(64).as_str()).unwrap_err(),
            ValidatePassphraseError::TooLong
        );
    }

    #[test]
    fn passphrase_invalid_byte() {
        assert_eq!(
            Passphrase::try_from("pass\tword").unwrap_err(),
            ValidatePassphraseError::InvalidByte
        );
        // Non-ASCII characters encode to bytes >= 0x80 and are rejected.
        assert_eq!(
            Passphrase::try_from("passwörd").unwrap_err(),
            ValidatePassphraseError::InvalidByte
        );
    }

    #[test]
    fn display_passphrase() {
        assert_eq!(
            format!("{}", Passphrase::try_from("foobarbuzz").unwrap()),
            "foobarbuzz"
        );
    }

    #[test]
    fn hex_encoding() {
        assert_eq!(bytes_to_hex(&[]), "");
        assert_eq!(bytes_to_hex(&[0x00, 0x0f, 0xab, 0xff]), "000fabff");
    }
}