1use base32;
2use std::time::{Duration, SystemTime, SystemTimeError, UNIX_EPOCH};
3use totp_lite::{Sha1, Sha256, Sha512, totp_custom};
4use url::Url;
5use zeroize::{Zeroize, ZeroizeOnDrop};
6
/// Time step, in seconds, used when an `otpauth` URI carries no `period` parameter.
const DEFAULT_PERIOD: u64 = 30;
/// Number of code digits used when an `otpauth` URI carries no `digits` parameter.
///
/// The Key URI format defines 6 as the default. The previous value of 8 made
/// URIs that omit `digits` produce codes that differ from what other
/// authenticators generate for the same URI.
const DEFAULT_DIGITS: u32 = 6;
9
10#[derive(Debug, PartialEq, Eq, Zeroize, ZeroizeOnDrop)]
12pub enum TOTPAlgorithm {
13 Sha1,
14 Sha256,
15 Sha512,
16}
17
18impl std::str::FromStr for TOTPAlgorithm {
19 type Err = TOTPError;
20
21 fn from_str(s: &str) -> Result<Self, Self::Err> {
22 match s.to_uppercase().as_str() {
23 "SHA1" => Ok(TOTPAlgorithm::Sha1),
24 "SHA256" => Ok(TOTPAlgorithm::Sha256),
25 "SHA512" => Ok(TOTPAlgorithm::Sha512),
26 _ => Err(TOTPError::BadAlgorithm(s.to_string())),
27 }
28 }
29}
30
31impl std::fmt::Display for TOTPAlgorithm {
32 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
33 match self {
34 TOTPAlgorithm::Sha1 => write!(f, "SHA1"),
35 TOTPAlgorithm::Sha256 => write!(f, "SHA256"),
36 TOTPAlgorithm::Sha512 => write!(f, "SHA512"),
37 }
38 }
39}
40
41#[derive(Debug, PartialEq, Eq, Zeroize, ZeroizeOnDrop)]
43pub struct TOTP {
44 pub label: String,
45 secret: Vec<u8>,
46 pub issuer: Option<String>,
47 pub period: u64,
48 pub digits: u32,
49 pub algorithm: TOTPAlgorithm,
50}
51
/// A generated one-time code together with its timing information.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so that codes can be logged,
/// copied and compared in tests by callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OTPCode {
    /// The code itself, as text.
    pub code: String,
    /// Time left until the end of the time step the code was generated for.
    pub valid_for: Duration,
    /// Full length of one time step.
    pub period: Duration,
}
58
59impl std::fmt::Display for OTPCode {
60 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
61 write!(
62 f,
63 "Code: {}, valid for: {}/{}s",
64 self.code,
65 self.valid_for.as_secs(),
66 self.period.as_secs(),
67 )
68 }
69}
70
71#[derive(Debug, thiserror::Error)]
73pub enum TOTPError {
74 #[error(transparent)]
75 UrlFormat(#[from] url::ParseError),
76
77 #[error(transparent)]
78 IntFormat(#[from] std::num::ParseIntError),
79
80 #[error("Missing TOTP field: {}", _0)]
81 MissingField(&'static str),
82
83 #[error(transparent)]
84 Time(#[from] SystemTimeError),
85
86 #[error("Base32 decoding error")]
87 Base32,
88
89 #[error("No OTP record found")]
90 NoRecord,
91
92 #[error("Bad URL scheme: '{}'", _0)]
93 BadScheme(String),
94
95 #[error("Bad hash algorithm: '{}'", _0)]
96 BadAlgorithm(String),
97}
98
99impl std::str::FromStr for TOTP {
100 type Err = TOTPError;
101
102 fn from_str(s: &str) -> Result<Self, Self::Err> {
103 let parsed = Url::parse(s)?;
104
105 if parsed.scheme() != "otpauth" {
106 return Err(TOTPError::BadScheme(parsed.scheme().to_string()));
107 }
108 let query_pairs = parsed.query_pairs();
109
110 let label: String = parsed.path().trim_start_matches('/').to_string();
111 let mut secret: Option<String> = None;
112 let mut issuer: Option<String> = None;
113 let mut period: u64 = DEFAULT_PERIOD;
114 let mut digits: u32 = DEFAULT_DIGITS;
115 let mut algorithm: TOTPAlgorithm = TOTPAlgorithm::Sha1;
116
117 for pair in query_pairs {
118 let (k, v) = pair;
119 match k.as_ref() {
120 "secret" => secret = Some(v.to_string()),
121 "issuer" => issuer = Some(v.to_string()),
122 "period" => period = v.parse()?,
123 "digits" => digits = v.parse()?,
124 "algorithm" => algorithm = v.parse()?,
125 _ => {}
126 }
127 }
128
129 let secret = secret.ok_or(TOTPError::MissingField("secret"))?;
130
131 let secret = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, &secret).ok_or(TOTPError::Base32)?;
132
133 Ok(TOTP {
134 label,
135 secret,
136 issuer,
137 period,
138 digits,
139 algorithm,
140 })
141 }
142}
143
144impl std::fmt::Display for TOTP {
145 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
146 write!(
147 f,
148 "otpauth://totp/{}?secret={}&period={}&digits={}&issuer={}&algorithm={:?}",
149 self.label,
150 base32::encode(base32::Alphabet::Rfc4648 { padding: true }, &self.secret),
151 self.period,
152 self.digits,
153 self.issuer.as_deref().unwrap_or(""),
154 self.algorithm
155 )
156 }
157}
158
159impl TOTP {
160 pub fn value_at(&self, time: u64) -> OTPCode {
162 let code = match self.algorithm {
163 TOTPAlgorithm::Sha1 => totp_custom::<Sha1>(self.period, self.digits, &self.secret, time),
164 TOTPAlgorithm::Sha256 => totp_custom::<Sha256>(self.period, self.digits, &self.secret, time),
165 TOTPAlgorithm::Sha512 => totp_custom::<Sha512>(self.period, self.digits, &self.secret, time),
166 };
167
168 let valid_for = Duration::from_secs(self.period - (time % self.period));
169
170 OTPCode {
171 code,
172 valid_for,
173 period: Duration::from_secs(self.period),
174 }
175 }
176
177 pub fn value_now(&self) -> Result<OTPCode, SystemTimeError> {
179 let time: u64 = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
180 Ok(self.value_at(time))
181 }
182
183 pub fn get_secret(&self) -> String {
184 base32::encode(base32::Alphabet::Rfc4648 { padding: true }, &self.secret)
185 }
186}
187
#[cfg(test)]
mod kdbx4_otp_tests {
    use super::{TOTP, TOTPAlgorithm, TOTPError};
    use crate::{
        db::{Database, Entry, Group, Node, with_node},
        key::DatabaseKey,
    };
    use std::{fs::File, path::Path};

    /// URI stored in the KeePassXC-created test database and reused by
    /// several tests below.
    const KEEPASSXC_URI: &str =
        "otpauth://totp/KeePassXC:none?secret=JBSWY3DPEHPK3PXP&period=30&digits=6&issuer=KeePassXC";

    /// Builds the settings the KeePassXC URIs are expected to parse into,
    /// with the given issuer.
    fn keepassxc_totp(issuer: Option<&str>) -> TOTP {
        TOTP {
            label: "KeePassXC:none".to_string(),
            secret: b"Hello!\xDE\xAD\xBE\xEF".to_vec(),
            issuer: issuer.map(String::from),
            period: 30,
            digits: 6,
            algorithm: TOTPAlgorithm::Sha1,
        }
    }

    /// The raw OTP value stored in a KDBX4 entry is returned unchanged.
    #[test]
    fn kdbx4_entry() -> Result<(), Box<dyn std::error::Error>> {
        let key = DatabaseKey::new().with_password("test");
        let mut file = File::open(Path::new(
            "tests/resources/test_db_kdbx4_with_totp_entry.kdbx",
        ))?;
        let db = Database::open(&mut file, key)?;

        let entry = Group::get(&db.root, &["this entry has totp"]).unwrap();
        with_node::<Entry, _, _>(&entry, |e| {
            assert_eq!(e.get_title(), Some("this entry has totp"));
            assert_eq!(e.get_raw_otp_value(), Some(KEEPASSXC_URI));
        })
        .unwrap();

        Ok(())
    }

    /// A URI without an `algorithm` parameter parses with SHA1.
    #[test]
    fn totp_default() -> Result<(), TOTPError> {
        let parsed: TOTP = KEEPASSXC_URI.parse()?;

        assert_eq!(parsed, keepassxc_totp(Some("KeePassXC")));

        Ok(())
    }

    /// The secret is returned in the same base32 form it was given in.
    #[test]
    fn totp_get_secret() -> Result<(), TOTPError> {
        let otp: TOTP = KEEPASSXC_URI.parse()?;

        assert_eq!(otp.get_secret(), "JBSWY3DPEHPK3PXP".to_string());

        Ok(())
    }

    /// Percent-encoded parameters are decoded; the label is kept as written.
    #[test]
    fn totp_sha512() -> Result<(), TOTPError> {
        let parsed: TOTP = "otpauth://totp/sha512%20totp:none?secret=GEZDGNBVGY%3D%3D%3D%3D%3D%3D&period=30&digits=6&issuer=sha512%20totp&algorithm=SHA512".parse()?;

        let expected = TOTP {
            label: "sha512%20totp:none".to_string(),
            secret: b"123456".to_vec(),
            issuer: Some("sha512 totp".to_string()),
            period: 30,
            digits: 6,
            algorithm: TOTPAlgorithm::Sha512,
        };

        assert_eq!(parsed, expected);

        Ok(())
    }

    /// A known secret and timestamp produce the known code.
    #[test]
    fn totp_value() {
        let generated = keepassxc_totp(Some("KeePassXC")).value_at(1234);

        assert_eq!(generated.code, "806863")
    }

    /// Each kind of malformed input is reported with the matching error.
    #[test]
    fn totp_bad() {
        let parse = |uri: &str| uri.parse::<TOTP>();

        // Not a URL at all.
        assert!(matches!(
            parse("not a totp string"),
            Err(TOTPError::UrlFormat(_))
        ));

        // Wrong scheme.
        assert!(matches!(
            parse("http://totp/sha512%20totp:none?secret=GEZDGNBVGY%3D%3D%3D%3D%3D%3D&period=30&digits=6&issuer=sha512%20totp&algorithm=SHA512"),
            Err(TOTPError::BadScheme(_))
        ));

        // Unknown hash algorithm.
        assert!(matches!(
            parse("otpauth://totp/sha512%20totp:none?secret=GEZDGNBVGY%3D%3D%3D%3D%3D%3D&period=30&digits=6&issuer=sha512%20totp&algorithm=SHA123"),
            Err(TOTPError::BadAlgorithm(_))
        ));

        // No secret.
        assert!(matches!(
            parse("otpauth://missing_fields"),
            Err(TOTPError::MissingField("secret"))
        ));
    }

    /// A URI without an issuer parses with `issuer` unset.
    #[test]
    fn totp_minimal() -> Result<(), TOTPError> {
        let parsed: TOTP =
            "otpauth://totp/KeePassXC:none?secret=JBSWY3DPEHPK3PXP&period=30&digits=6".parse()?;

        assert_eq!(parsed, keepassxc_totp(None));

        Ok(())
    }
}