keepass_ng/db/
otp.rs

1use base32;
2use std::time::{Duration, SystemTime, SystemTimeError, UNIX_EPOCH};
3use totp_lite::{Sha1, Sha256, Sha512, totp_custom};
4use url::Url;
5use zeroize::{Zeroize, ZeroizeOnDrop};
6
/// Time step, in seconds, used when an `otpauth://` URI has no `period` parameter.
const DEFAULT_PERIOD: u64 = 30;
/// Code length used when an `otpauth://` URI has no `digits` parameter.
///
/// The Key URI format used by authenticator apps (and KeePassXC) defines the
/// default as 6 digits; a default of 8 produced codes that do not match the
/// issuer's for URIs that leave `digits` out.
const DEFAULT_DIGITS: u32 = 6;
9
10/// Choices of hash algorithm for TOTP
11#[derive(Debug, PartialEq, Eq, Zeroize, ZeroizeOnDrop)]
12pub enum TOTPAlgorithm {
13    Sha1,
14    Sha256,
15    Sha512,
16}
17
18impl std::str::FromStr for TOTPAlgorithm {
19    type Err = TOTPError;
20
21    fn from_str(s: &str) -> Result<Self, Self::Err> {
22        match s.to_uppercase().as_str() {
23            "SHA1" => Ok(TOTPAlgorithm::Sha1),
24            "SHA256" => Ok(TOTPAlgorithm::Sha256),
25            "SHA512" => Ok(TOTPAlgorithm::Sha512),
26            _ => Err(TOTPError::BadAlgorithm(s.to_string())),
27        }
28    }
29}
30
31impl std::fmt::Display for TOTPAlgorithm {
32    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
33        match self {
34            TOTPAlgorithm::Sha1 => write!(f, "SHA1"),
35            TOTPAlgorithm::Sha256 => write!(f, "SHA256"),
36            TOTPAlgorithm::Sha512 => write!(f, "SHA512"),
37        }
38    }
39}
40
41/// Time-based one time password settings
42#[derive(Debug, PartialEq, Eq, Zeroize, ZeroizeOnDrop)]
43pub struct TOTP {
44    pub label: String,
45    secret: Vec<u8>,
46    pub issuer: Option<String>,
47    pub period: u64,
48    pub digits: u32,
49    pub algorithm: TOTPAlgorithm,
50}
51
/// A one-time password produced by [`TOTP::value_at`] or [`TOTP::value_now`].
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and
/// compare generated codes (the `Display` output already exposes the code).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct OTPCode {
    /// The generated code string.
    pub code: String,
    /// Time left before the current step ends (`period - time % period`).
    pub valid_for: Duration,
    /// Length of one time step.
    pub period: Duration,
}
58
59impl std::fmt::Display for OTPCode {
60    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
61        write!(
62            f,
63            "Code: {}, valid for: {}/{}s",
64            self.code,
65            self.valid_for.as_secs(),
66            self.period.as_secs(),
67        )
68    }
69}
70
71/// Errors while processing a TOTP specification
72#[derive(Debug, thiserror::Error)]
73pub enum TOTPError {
74    #[error(transparent)]
75    UrlFormat(#[from] url::ParseError),
76
77    #[error(transparent)]
78    IntFormat(#[from] std::num::ParseIntError),
79
80    #[error("Missing TOTP field: {}", _0)]
81    MissingField(&'static str),
82
83    #[error(transparent)]
84    Time(#[from] SystemTimeError),
85
86    #[error("Base32 decoding error")]
87    Base32,
88
89    #[error("No OTP record found")]
90    NoRecord,
91
92    #[error("Bad URL scheme: '{}'", _0)]
93    BadScheme(String),
94
95    #[error("Bad hash algorithm: '{}'", _0)]
96    BadAlgorithm(String),
97}
98
99impl std::str::FromStr for TOTP {
100    type Err = TOTPError;
101
102    fn from_str(s: &str) -> Result<Self, Self::Err> {
103        let parsed = Url::parse(s)?;
104
105        if parsed.scheme() != "otpauth" {
106            return Err(TOTPError::BadScheme(parsed.scheme().to_string()));
107        }
108        let query_pairs = parsed.query_pairs();
109
110        let label: String = parsed.path().trim_start_matches('/').to_string();
111        let mut secret: Option<String> = None;
112        let mut issuer: Option<String> = None;
113        let mut period: u64 = DEFAULT_PERIOD;
114        let mut digits: u32 = DEFAULT_DIGITS;
115        let mut algorithm: TOTPAlgorithm = TOTPAlgorithm::Sha1;
116
117        for pair in query_pairs {
118            let (k, v) = pair;
119            match k.as_ref() {
120                "secret" => secret = Some(v.to_string()),
121                "issuer" => issuer = Some(v.to_string()),
122                "period" => period = v.parse()?,
123                "digits" => digits = v.parse()?,
124                "algorithm" => algorithm = v.parse()?,
125                _ => {}
126            }
127        }
128
129        let secret = secret.ok_or(TOTPError::MissingField("secret"))?;
130
131        let secret = base32::decode(base32::Alphabet::Rfc4648 { padding: true }, &secret).ok_or(TOTPError::Base32)?;
132
133        Ok(TOTP {
134            label,
135            secret,
136            issuer,
137            period,
138            digits,
139            algorithm,
140        })
141    }
142}
143
144impl std::fmt::Display for TOTP {
145    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
146        write!(
147            f,
148            "otpauth://totp/{}?secret={}&period={}&digits={}&issuer={}&algorithm={:?}",
149            self.label,
150            base32::encode(base32::Alphabet::Rfc4648 { padding: true }, &self.secret),
151            self.period,
152            self.digits,
153            self.issuer.as_deref().unwrap_or(""),
154            self.algorithm
155        )
156    }
157}
158
159impl TOTP {
160    /// Get the one-time code for a specific unix timestamp
161    pub fn value_at(&self, time: u64) -> OTPCode {
162        let code = match self.algorithm {
163            TOTPAlgorithm::Sha1 => totp_custom::<Sha1>(self.period, self.digits, &self.secret, time),
164            TOTPAlgorithm::Sha256 => totp_custom::<Sha256>(self.period, self.digits, &self.secret, time),
165            TOTPAlgorithm::Sha512 => totp_custom::<Sha512>(self.period, self.digits, &self.secret, time),
166        };
167
168        let valid_for = Duration::from_secs(self.period - (time % self.period));
169
170        OTPCode {
171            code,
172            valid_for,
173            period: Duration::from_secs(self.period),
174        }
175    }
176
177    /// Get the current one-time code
178    pub fn value_now(&self) -> Result<OTPCode, SystemTimeError> {
179        let time: u64 = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
180        Ok(self.value_at(time))
181    }
182
183    pub fn get_secret(&self) -> String {
184        base32::encode(base32::Alphabet::Rfc4648 { padding: true }, &self.secret)
185    }
186}
187
#[cfg(test)]
mod kdbx4_otp_tests {
    use super::{TOTP, TOTPAlgorithm, TOTPError};
    use crate::{
        db::{Database, Entry, Group, Node, with_node},
        key::DatabaseKey,
    };
    use std::{fs::File, path::Path};

    #[test]
    fn kdbx4_entry() -> Result<(), Box<dyn std::error::Error>> {
        // Open a KDBX4 fixture containing an entry with an OTP field and check
        // that the raw otpauth URI is returned unchanged.
        let path = Path::new("tests/resources/test_db_kdbx4_with_totp_entry.kdbx");
        let key = DatabaseKey::new().with_password("test");
        let db = Database::open(&mut File::open(path)?, key)?;

        let otp_str = "otpauth://totp/KeePassXC:none?secret=JBSWY3DPEHPK3PXP&period=30&digits=6&issuer=KeePassXC";

        // get an entry on the root node
        let entry = Group::get(&db.root, &["this entry has totp"]).unwrap();
        with_node::<Entry, _, _>(&entry, |e| {
            assert_eq!(e.get_title(), Some("this entry has totp"));
            assert_eq!(e.get_raw_otp_value(), Some(otp_str));
        })
        .unwrap();

        Ok(())
    }

    #[test]
    fn totp_default() -> Result<(), TOTPError> {
        let otp_str = "otpauth://totp/KeePassXC:none?secret=JBSWY3DPEHPK3PXP&period=30&digits=6&issuer=KeePassXC";

        let expected = TOTP {
            label: "KeePassXC:none".to_string(),
            secret: b"Hello!\xDE\xAD\xBE\xEF".to_vec(),
            issuer: Some("KeePassXC".to_string()),
            period: 30,
            digits: 6,
            algorithm: TOTPAlgorithm::Sha1,
        };

        assert_eq!(otp_str.parse::<TOTP>()?, expected);

        Ok(())
    }

    #[test]
    fn totp_get_secret() -> Result<(), TOTPError> {
        let otp_str = "otpauth://totp/KeePassXC:none?secret=JBSWY3DPEHPK3PXP&period=30&digits=6&issuer=KeePassXC";

        let otp = otp_str.parse::<TOTP>()?;

        assert_eq!(otp.get_secret(), "JBSWY3DPEHPK3PXP".to_string());

        Ok(())
    }

    #[test]
    fn totp_sha512() -> Result<(), TOTPError> {
        let otp_str = "otpauth://totp/sha512%20totp:none?secret=GEZDGNBVGY%3D%3D%3D%3D%3D%3D&period=30&digits=6&issuer=sha512%20totp&algorithm=SHA512";

        let expected = TOTP {
            label: "sha512%20totp:none".to_string(),
            secret: b"123456".to_vec(),
            issuer: Some("sha512 totp".to_string()),
            period: 30,
            digits: 6,
            algorithm: TOTPAlgorithm::Sha512,
        };

        assert_eq!(otp_str.parse::<TOTP>()?, expected);

        Ok(())
    }

    #[test]
    fn totp_display_roundtrip() -> Result<(), TOTPError> {
        // Serializing and re-parsing must yield an equal TOTP, including a
        // non-default algorithm and an issuer containing a space.
        let otp_str = "otpauth://totp/sha512%20totp:none?secret=GEZDGNBVGY%3D%3D%3D%3D%3D%3D&period=30&digits=6&issuer=sha512%20totp&algorithm=SHA512";

        let totp = otp_str.parse::<TOTP>()?;
        let reparsed = totp.to_string().parse::<TOTP>()?;

        assert_eq!(reparsed, totp);

        Ok(())
    }

    #[test]
    fn totp_value() {
        let totp = TOTP {
            label: "KeePassXC:none".to_string(),
            secret: b"Hello!\xDE\xAD\xBE\xEF".to_vec(),
            issuer: Some("KeePassXC".to_string()),
            period: 30,
            digits: 6,
            algorithm: TOTPAlgorithm::Sha1,
        };

        assert_eq!(totp.value_at(1234).code, "806863")
    }

    #[test]
    fn totp_bad() {
        assert!(matches!("not a totp string".parse::<TOTP>(), Err(TOTPError::UrlFormat(_))));

        assert!(matches!(
            "http://totp/sha512%20totp:none?secret=GEZDGNBVGY%3D%3D%3D%3D%3D%3D&period=30&digits=6&issuer=sha512%20totp&algorithm=SHA512"
                .parse::<TOTP>(),
            Err(TOTPError::BadScheme(_))
        ));

        assert!(matches!(
            "otpauth://totp/sha512%20totp:none?secret=GEZDGNBVGY%3D%3D%3D%3D%3D%3D&period=30&digits=6&issuer=sha512%20totp&algorithm=SHA123".parse::<TOTP>(),
            Err(TOTPError::BadAlgorithm(_))
        ));

        assert!(matches!(
            "otpauth://missing_fields".parse::<TOTP>(),
            Err(TOTPError::MissingField("secret"))
        ));
    }

    #[test]
    fn totp_minimal() -> Result<(), TOTPError> {
        let otp_str = "otpauth://totp/KeePassXC:none?secret=JBSWY3DPEHPK3PXP&period=30&digits=6";

        let expected = TOTP {
            label: "KeePassXC:none".to_string(),
            secret: b"Hello!\xDE\xAD\xBE\xEF".to_vec(),
            issuer: None,
            period: 30,
            digits: 6,
            algorithm: TOTPAlgorithm::Sha1,
        };

        assert_eq!(otp_str.parse::<TOTP>()?, expected);

        Ok(())
    }
}