1use std::{cell::Cell, io};
4
5use crate::{fl, scrypt, Callbacks, DecryptError, Decryptor, EncryptError, IdentityFile};
6
7enum IdentityState<R: io::Read, C: Callbacks> {
9 Encrypted {
10 decryptor: Decryptor<R>,
11 max_work_factor: Option<u8>,
12 callbacks: C,
13 },
14 Decrypted(IdentityFile<C>),
15
16 Poisoned(Option<DecryptError>),
21}
22
23impl<R: io::Read, C: Callbacks> Default for IdentityState<R, C> {
24 fn default() -> Self {
25 Self::Poisoned(None)
26 }
27}
28
29impl<R: io::Read, C: Callbacks> IdentityState<R, C> {
30 fn decrypt(self, filename: Option<&str>) -> Result<(IdentityFile<C>, bool), DecryptError> {
35 match self {
36 Self::Encrypted {
37 decryptor,
38 max_work_factor,
39 callbacks,
40 } => {
41 let passphrase = match callbacks.request_passphrase(&fl!(
42 "encrypted-passphrase-prompt",
43 filename = filename.unwrap_or_default()
44 )) {
45 Some(passphrase) => passphrase,
46 None => Err(DecryptError::KeyDecryptionFailed)?,
47 };
48
49 let mut identity = scrypt::Identity::new(passphrase);
50 if let Some(max_work_factor) = max_work_factor {
51 identity.set_max_work_factor(max_work_factor);
52 }
53
54 decryptor
55 .decrypt(Some(&identity as _).into_iter())
56 .map_err(|e| {
57 if matches!(e, DecryptError::DecryptionFailed) {
58 DecryptError::KeyDecryptionFailed
59 } else {
60 e
61 }
62 })
63 .and_then(|stream| {
64 let file = IdentityFile::from_buffer(io::BufReader::new(stream))?
65 .with_callbacks(callbacks);
66 Ok((file, true))
67 })
68 }
69 Self::Decrypted(identity_file) => Ok((identity_file, false)),
70 Self::Poisoned(e) => Err(e.unwrap()),
72 }
73 }
74}
75
76pub struct Identity<R: io::Read, C: Callbacks> {
78 state: Cell<IdentityState<R, C>>,
79 filename: Option<String>,
80}
81
82impl<R: io::Read, C: Callbacks> Identity<R, C> {
83 pub fn from_buffer(
90 data: R,
91 filename: Option<String>,
92 callbacks: C,
93 max_work_factor: Option<u8>,
94 ) -> Result<Option<Self>, DecryptError> {
95 let decryptor = Decryptor::new(data)?;
96 Ok(decryptor.is_scrypt().then_some(Identity {
97 state: Cell::new(IdentityState::Encrypted {
98 decryptor,
99 max_work_factor,
100 callbacks,
101 }),
102 filename,
103 }))
104 }
105
106 pub fn recipients(&self) -> Result<Vec<Box<dyn crate::Recipient + Send>>, EncryptError> {
111 match self.state.take().decrypt(self.filename.as_deref()) {
112 Ok((identity_file, _)) => {
113 let recipients = identity_file.to_recipients();
114 self.state.set(IdentityState::Decrypted(identity_file));
115 recipients
116 }
117 Err(e) => {
118 self.state.set(IdentityState::Poisoned(Some(e.clone())));
119 Err(EncryptError::EncryptedIdentities(e))
120 }
121 }
122 }
123
124 fn unwrap_stanzas_base<F>(
136 &self,
137 filter: F,
138 ) -> Option<Result<age_core::format::FileKey, DecryptError>>
139 where
140 F: Fn(
141 Result<Box<dyn crate::Identity>, DecryptError>,
142 ) -> Option<Result<age_core::format::FileKey, DecryptError>>,
143 {
144 match self.state.take().decrypt(self.filename.as_deref()) {
145 Ok((identity_file, requested_passphrase)) => {
146 let result = identity_file.to_identities().find_map(filter);
147
148 if requested_passphrase && result.is_none() {
151 identity_file.callbacks.display_message(&fl!(
152 "encrypted-warn-no-match",
153 filename = self.filename.as_deref().unwrap_or_default()
154 ));
155 }
156
157 self.state.set(IdentityState::Decrypted(identity_file));
158 result
159 }
160 Err(e) => {
161 self.state.set(IdentityState::Poisoned(Some(e.clone())));
162 Some(Err(e))
163 }
164 }
165 }
166}
167
168impl<R: io::Read, C: Callbacks> crate::Identity for Identity<R, C> {
169 fn unwrap_stanza(
170 &self,
171 stanza: &age_core::format::Stanza,
172 ) -> Option<Result<age_core::format::FileKey, DecryptError>> {
173 self.unwrap_stanzas_base(|identity| match identity {
174 Ok(i) => i.unwrap_stanza(stanza),
175 Err(e) => Some(Err(e)),
176 })
177 }
178
179 fn unwrap_stanzas(
180 &self,
181 stanzas: &[age_core::format::Stanza],
182 ) -> Option<Result<age_core::format::FileKey, DecryptError>> {
183 self.unwrap_stanzas_base(|identity| match identity {
184 Ok(i) => i.unwrap_stanzas(stanzas),
185 Err(e) => Some(Err(e)),
186 })
187 }
188}
189
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use age_core::secrecy::{ExposeSecret, SecretString};

    use super::Identity;
    use crate::{x25519, Callbacks, DecryptError, Identity as _, Recipient as _};

    #[cfg(feature = "armor")]
    use crate::armor::ArmoredReader;

    // Passphrase that unlocks `TEST_ENCRYPTED_IDENTITY` in `round_trip`.
    const TEST_ENCRYPTED_IDENTITY_PASSPHRASE: &str = "foobar";

    // Armored, passphrase-encrypted identity file used as the test vector.
    const TEST_ENCRYPTED_IDENTITY: &str = "-----BEGIN AGE ENCRYPTED FILE-----
YWdlLWVuY3J5cHRpb24ub3JnL3YxCi0+IHNjcnlwdCBza2I4R0t6L2NLT2s4cGlI
TTRGRjFRIDEwCnVodTdORmZjcCtjRmdnYU54bm8rZEJ5NWlrVHZLY1hyRzZEN2JE
ZVpwWnMKLS0tIEZTcDlSL3oyRC9NQ3JZa3ZvUzNaNlk4bnhBSUdJRTFrMmE4QzMr
UVNETlkK34fdtpwZz+qQaGuirGHEdodVe4JvnSG3ANQpWhkDcsRzoe/+OuHXNdnv
zhBhaKdthstzGXbd2yJbLrTH1A3YbWO+/3zTIZENzKU9XbibLLQ4M/TXwKMzoObY
oiMf5/+8GiQVREtHmm24wsc/479cVwnGVTdH7DL+wANmyf6S9Vc2FYQmXjLDxsJ0
LMF6Cpgcg09C2gg4pcb4TFUWmDuxnZrfggrptOtyzC8O8aRuKPZqCGnzoWNOWl86
fOrxrKTj7xCdNS3+OrCdnBC8Z9cKDxjCGWW3fkjLsYha0Jo=
-----END AGE ENCRYPTED FILE-----
";

    // Recipient that the test wraps the file key to.
    // NOTE(review): presumably corresponds to an identity stored inside
    // `TEST_ENCRYPTED_IDENTITY`, since `round_trip` expects unwrapping to
    // succeed — confirm against the test vector.
    const TEST_RECIPIENT: &str = "age1ysxuaeqlk7xd8uqsh8lsnfwt9jzzjlqf49ruhpjrrj5yatlcuf7qke4pqe";

    // Callbacks stub that hands out one pre-programmed passphrase response.
    // The outer `Option` tracks whether the response has been consumed; the
    // inner `Option` is the response itself (`None` models no passphrase
    // being provided).
    #[derive(Clone)]
    struct MockCallbacks(Arc<Mutex<Option<Option<&'static str>>>>);

    impl MockCallbacks {
        // Creates a mock whose single passphrase request returns `passphrase`.
        fn new(passphrase: Option<&'static str>) -> Self {
            MockCallbacks(Arc::new(Mutex::new(Some(passphrase))))
        }
    }

    impl Callbacks for MockCallbacks {
        // Not expected to be called by these tests.
        fn display_message(&self, _: &str) {
            unimplemented!()
        }

        // Not expected to be called by these tests.
        fn confirm(&self, _: &str, _: &str, _: Option<&str>) -> Option<bool> {
            unimplemented!()
        }

        // Not expected to be called by these tests.
        fn request_public_string(&self, _: &str) -> Option<String> {
            unimplemented!()
        }

        // Returns the programmed response. Panics if called a second time,
        // which lets the test detect unwanted repeat passphrase prompts.
        fn request_passphrase(&self, _: &str) -> Option<SecretString> {
            self.0
                .lock()
                .unwrap()
                .take()
                .expect("passphrase is only input once")
                .to_owned()
                .map(SecretString::from)
        }
    }

    // Exercises wrong-passphrase, no-passphrase and correct-passphrase paths.
    #[test]
    #[cfg(feature = "armor")]
    fn round_trip() {
        use age_core::format::FileKey;

        // Wrap a known file key to the test recipient.
        let pk: x25519::Recipient = TEST_RECIPIENT.parse().unwrap();
        let file_key = FileKey::new(Box::new([12; 16]));
        let (wrapped, labels) = pk.wrap_file_key(&file_key).unwrap();
        assert!(labels.is_empty());

        {
            // Unwrapping with the wrong passphrase fails.
            let buf = ArmoredReader::new(TEST_ENCRYPTED_IDENTITY.as_bytes());
            let identity = Identity::from_buffer(
                buf,
                None,
                MockCallbacks::new(Some("wrong passphrase")),
                None,
            )
            .unwrap()
            .unwrap();

            if let Err(e) = identity.unwrap_stanzas(&wrapped).unwrap() {
                assert!(matches!(e, DecryptError::KeyDecryptionFailed));
            } else {
                panic!("Should have failed");
            }
        }

        {
            // Providing no passphrase at all fails with the same error.
            let buf = ArmoredReader::new(TEST_ENCRYPTED_IDENTITY.as_bytes());
            let identity = Identity::from_buffer(buf, None, MockCallbacks::new(None), None)
                .unwrap()
                .unwrap();

            if let Err(e) = identity.unwrap_stanzas(&wrapped).unwrap() {
                assert!(matches!(e, DecryptError::KeyDecryptionFailed));
            } else {
                panic!("Should have failed");
            }
        }

        // Unwrapping with the correct passphrase recovers the file key.
        let buf = ArmoredReader::new(TEST_ENCRYPTED_IDENTITY.as_bytes());
        let identity = Identity::from_buffer(
            buf,
            None,
            MockCallbacks::new(Some(TEST_ENCRYPTED_IDENTITY_PASSPHRASE)),
            None,
        )
        .unwrap()
        .unwrap();
        let unwrapped = identity.unwrap_stanzas(&wrapped);
        assert_eq!(
            unwrapped.unwrap().unwrap().expose_secret(),
            file_key.expose_secret()
        );

        // A second call must reuse the cached decrypted state: the mock would
        // panic if the passphrase were requested again.
        identity.unwrap_stanzas(&wrapped);
    }
}