1use age_core::{
4 format::{is_arbitrary_string, FileKey, Stanza},
5 plugin::{self, BidirSend, Connection},
6 secrecy::SecretString,
7};
8use base64::{prelude::BASE64_STANDARD_NO_PAD, Engine};
9use bech32::FromBase32;
10
11use std::collections::HashSet;
12use std::convert::Infallible;
13use std::io;
14
15use crate::{Callbacks, PLUGIN_IDENTITY_PREFIX, PLUGIN_RECIPIENT_PREFIX};
16
// Command names used by `run_v1` when talking to the client.

/// Command received from the client carrying one encoded recipient as its
/// single metadata argument (and no data).
17const ADD_RECIPIENT: &str = "add-recipient";
/// Command received from the client carrying one encoded identity as its
/// single metadata argument (and no data).
18const ADD_IDENTITY: &str = "add-identity";
/// Command received from the client whose body holds the bytes of one file key.
19const WRAP_FILE_KEY: &str = "wrap-file-key";
/// Optional command received from the client; `run_v1` treats exactly one
/// occurrence as "the client supports labels".
20const EXTENSION_LABELS: &str = "extension-labels";
/// Command sent to the client for each produced stanza, with the index of the
/// file key it belongs to as metadata.
21const RECIPIENT_STANZA: &str = "recipient-stanza";
/// Command sent to the client with the plugin's labels as metadata arguments.
22const LABELS: &str = "labels";
23
/// The interface a recipient plugin implements so that `run_v1` can drive it.
///
/// `run_v1` first feeds the plugin every recipient and identity received from
/// the client, then asks for its labels, and finally asks it to wrap the file
/// keys.
24pub trait RecipientPluginV1 {
    /// Stores a recipient that the file keys should be wrapped to.
    ///
    /// `run_v1` calls this once per received recipient whose Bech32 encoding
    /// decoded successfully. `index` is the position of the recipient in the
    /// received list, `plugin_name` is the human-readable part with the
    /// recipient prefix stripped, and `bytes` is the decoded payload.
    ///
    /// Returning an error causes it to be reported to the client instead of any
    /// file keys being wrapped.
36 fn add_recipient(&mut self, index: usize, plugin_name: &str, bytes: &[u8])
42 -> Result<(), Error>;
43
    /// Stores an identity that the file keys should be wrapped to.
    ///
    /// Mirrors [`RecipientPluginV1::add_recipient`]: `index` is the position of
    /// the identity in the received list, `plugin_name` is the human-readable
    /// part without the identity prefix and its trailing `-`, and `bytes` is
    /// the decoded payload.
44 fn add_identity(&mut self, index: usize, plugin_name: &str, bytes: &[u8]) -> Result<(), Error>;
50
    /// Returns the labels this plugin wants sent to the client.
    ///
    /// `run_v1` calls this after all recipients and identities have been added.
    /// It reports an error to the client if the set contains the empty string,
    /// contains a string rejected by `is_arbitrary_string`, or is non-empty
    /// while the client did not announce label support.
51 fn labels(&mut self) -> HashSet<String>;
80
    /// Wraps each of the given file keys to the stored recipients and
    /// identities.
    ///
    /// `callbacks` lets the plugin interact with the client while wrapping.
    ///
    /// On success the outer `Vec` is enumerated by `run_v1` and each position
    /// is used as the file index of the stanzas it contains. `run_v1` asserts
    /// that every inner `Vec` holds at least as many stanzas as the number of
    /// recipients plus identities that were added. On failure every returned
    /// [`Error`] is sent to the client. The outer `io::Result` is propagated
    /// unchanged by `run_v1`.
81 fn wrap_file_keys(
94 &mut self,
95 file_keys: Vec<FileKey>,
96 callbacks: impl Callbacks<Error>,
97 ) -> io::Result<Result<Vec<Vec<Stanza>>, Vec<Error>>>;
98}
99
100impl RecipientPluginV1 for Infallible {
101 fn add_recipient(&mut self, _: usize, _: &str, _: &[u8]) -> Result<(), Error> {
102 Ok(())
104 }
105
106 fn add_identity(&mut self, _: usize, _: &str, _: &[u8]) -> Result<(), Error> {
107 Ok(())
109 }
110
111 fn labels(&mut self) -> HashSet<String> {
112 HashSet::new()
114 }
115
116 fn wrap_file_keys(
117 &mut self,
118 _: Vec<FileKey>,
119 _: impl Callbacks<Error>,
120 ) -> io::Result<Result<Vec<Vec<Stanza>>, Vec<Error>>> {
121 Ok(Ok(vec![]))
123 }
124}
125
126struct BidirCallbacks<'a, 'b, R: io::Read, W: io::Write>(&'b mut BidirSend<'a, R, W>);
128
129impl<'a, 'b, R: io::Read, W: io::Write> Callbacks<Error> for BidirCallbacks<'a, 'b, R, W> {
130 fn message(&mut self, message: &str) -> plugin::Result<()> {
135 self.0
136 .send("msg", &[], message.as_bytes())
137 .map(|res| res.map(|_| ()))
138 }
139
140 fn confirm(
141 &mut self,
142 message: &str,
143 yes_string: &str,
144 no_string: Option<&str>,
145 ) -> age_core::plugin::Result<bool> {
146 let metadata: Vec<_> = Some(yes_string)
147 .into_iter()
148 .chain(no_string)
149 .map(|s| BASE64_STANDARD_NO_PAD.encode(s))
150 .collect();
151 let metadata: Vec<_> = metadata.iter().map(|s| s.as_str()).collect();
152
153 self.0
154 .send("confirm", &metadata, message.as_bytes())
155 .and_then(|res| match res {
156 Ok(s) => match &s.args[..] {
157 [x] if x == "yes" => Ok(Ok(true)),
158 [x] if x == "no" => Ok(Ok(false)),
159 _ => Err(io::Error::new(
160 io::ErrorKind::InvalidData,
161 "Invalid response to confirm command",
162 )),
163 },
164 Err(e) => Ok(Err(e)),
165 })
166 }
167
168 fn request_public(&mut self, message: &str) -> plugin::Result<String> {
169 self.0
170 .send("request-public", &[], message.as_bytes())
171 .and_then(|res| match res {
172 Ok(s) => String::from_utf8(s.body)
173 .map_err(|_| {
174 io::Error::new(io::ErrorKind::InvalidData, "response is not UTF-8")
175 })
176 .map(Ok),
177 Err(e) => Ok(Err(e)),
178 })
179 }
180
181 fn request_secret(&mut self, message: &str) -> plugin::Result<SecretString> {
185 self.0
186 .send("request-secret", &[], message.as_bytes())
187 .and_then(|res| match res {
188 Ok(s) => String::from_utf8(s.body)
189 .map_err(|_| io::Error::new(io::ErrorKind::InvalidData, "secret is not UTF-8"))
190 .map(|s| Ok(SecretString::from(s))),
191 Err(e) => Ok(Err(e)),
192 })
193 }
194
195 fn error(&mut self, error: Error) -> plugin::Result<()> {
196 error.send(self.0).map(|()| Ok(()))
197 }
198}
199
200pub enum Error {
202 Recipient {
204 index: usize,
206 message: String,
208 },
209 Identity {
211 index: usize,
213 message: String,
215 },
216 Internal {
218 message: String,
220 },
221}
222
223impl Error {
224 fn kind(&self) -> &str {
225 match self {
226 Error::Recipient { .. } => "recipient",
227 Error::Identity { .. } => "identity",
228 Error::Internal { .. } => "internal",
229 }
230 }
231
232 fn message(&self) -> &str {
233 match self {
234 Error::Recipient { message, .. } => message,
235 Error::Identity { message, .. } => message,
236 Error::Internal { message } => message,
237 }
238 }
239
240 fn send<R: io::Read, W: io::Write>(self, phase: &mut BidirSend<R, W>) -> io::Result<()> {
241 let index = match self {
242 Error::Recipient { index, .. } | Error::Identity { index, .. } => {
243 Some(index.to_string())
244 }
245 Error::Internal { .. } => None,
246 };
247
248 let metadata = match &index {
249 Some(index) => vec![self.kind(), index],
250 None => vec![self.kind()],
251 };
252
253 phase
254 .send("error", &metadata, self.message().as_bytes())?
255 .unwrap();
256
257 Ok(())
258 }
259}
260
261pub(crate) fn run_v1<P: RecipientPluginV1>(mut plugin: P) -> io::Result<()> {
263 let mut conn = Connection::accept();
264
265 let ((recipients, identities), file_keys, labels_supported) = {
267 let (recipients, identities, file_keys, labels_supported) = conn.unidir_receive(
268 (ADD_RECIPIENT, |s| match (&s.args[..], &s.body[..]) {
269 ([recipient], []) => Ok(recipient.clone()),
270 _ => Err(Error::Internal {
271 message: format!(
272 "{} command must have exactly one metadata argument and no data",
273 ADD_RECIPIENT
274 ),
275 }),
276 }),
277 (ADD_IDENTITY, |s| match (&s.args[..], &s.body[..]) {
278 ([identity], []) => Ok(identity.clone()),
279 _ => Err(Error::Internal {
280 message: format!(
281 "{} command must have exactly one metadata argument and no data",
282 ADD_IDENTITY
283 ),
284 }),
285 }),
286 (Some(WRAP_FILE_KEY), |s| {
287 FileKey::try_init_with_mut(|file_key| {
289 if s.body.len() == file_key.len() {
290 file_key.copy_from_slice(&s.body);
291 Ok(())
292 } else {
293 Err(Error::Internal {
294 message: "invalid file key length".to_owned(),
295 })
296 }
297 })
298 }),
299 (Some(EXTENSION_LABELS), |_| Ok(())),
300 )?;
301 (
302 match (recipients, identities) {
303 (Ok(r), Ok(i)) if r.is_empty() && i.is_empty() => (
304 Err(vec![Error::Internal {
305 message: format!(
306 "Need at least one {} or {} command",
307 ADD_RECIPIENT, ADD_IDENTITY
308 ),
309 }]),
310 Err(vec![]),
311 ),
312 r => r,
313 },
314 match file_keys.unwrap() {
315 Ok(f) if f.is_empty() => Err(vec![Error::Internal {
316 message: format!("Need at least one {} command", WRAP_FILE_KEY),
317 }]),
318 r => r,
319 },
320 match &labels_supported.unwrap() {
321 Ok(v) if v.is_empty() => Ok(false),
322 Ok(v) if v.len() == 1 => Ok(true),
323 _ => Err(vec![Error::Internal {
324 message: format!("Received more than one {} command", EXTENSION_LABELS),
325 }]),
326 },
327 )
328 };
329
330 fn parse_and_add(
333 items: Result<Vec<String>, Vec<Error>>,
334 plugin_name: impl Fn(&str) -> Option<&str>,
335 error: impl Fn(usize) -> Error,
336 mut adder: impl FnMut(usize, &str, Vec<u8>) -> Result<(), Error>,
337 ) -> Result<usize, Vec<Error>> {
338 items.and_then(|items| {
339 let count = items.len();
340 let errors: Vec<_> = items
341 .into_iter()
342 .enumerate()
343 .map(|(index, item)| {
344 let decoded = bech32::decode(&item).ok();
345 decoded
346 .as_ref()
347 .and_then(|(hrp, data, variant)| match (plugin_name(hrp), variant) {
348 (Some(plugin_name), &bech32::Variant::Bech32) => {
349 Vec::from_base32(data).ok().map(|data| (plugin_name, data))
350 }
351 _ => None,
352 })
353 .ok_or_else(|| error(index))
354 .and_then(|(plugin_name, bytes)| adder(index, plugin_name, bytes))
355 })
356 .filter_map(|res| res.err())
357 .collect();
358
359 if errors.is_empty() {
360 Ok(count)
361 } else {
362 Err(errors)
363 }
364 })
365 }
366 let recipients = parse_and_add(
367 recipients,
368 |hrp| hrp.strip_prefix(PLUGIN_RECIPIENT_PREFIX),
369 |index| Error::Recipient {
370 index,
371 message: "Invalid recipient encoding".to_owned(),
372 },
373 |index, plugin_name, bytes| plugin.add_recipient(index, plugin_name, &bytes),
374 );
375 let identities = parse_and_add(
376 identities,
377 |hrp| {
378 if hrp.starts_with(PLUGIN_IDENTITY_PREFIX) && hrp.ends_with('-') {
379 Some(&hrp[PLUGIN_IDENTITY_PREFIX.len()..hrp.len() - 1])
380 } else {
381 None
382 }
383 },
384 |index| Error::Identity {
385 index,
386 message: "Invalid identity encoding".to_owned(),
387 },
388 |index, plugin_name, bytes| plugin.add_identity(index, plugin_name, &bytes),
389 );
390
391 let required_labels = plugin.labels();
392
393 let labels = match (labels_supported, required_labels.is_empty()) {
394 (Ok(true), _) | (Ok(false), true) => {
395 if required_labels.contains("") {
396 Err(vec![Error::Internal {
397 message: "Plugin tried to use the empty string as a label".into(),
398 }])
399 } else if required_labels.iter().all(is_arbitrary_string) {
400 Ok(required_labels)
401 } else {
402 Err(vec![Error::Internal {
403 message: "Plugin tried to use a label containing an invalid character".into(),
404 }])
405 }
406 }
407 (Ok(false), false) => Err(vec![Error::Internal {
408 message: "Plugin requires labels but client does not support them".into(),
409 }]),
410 (Err(errors), true) => Err(errors),
411 (Err(mut errors), false) => {
412 errors.push(Error::Internal {
413 message: "Plugin requires labels but client does not support them".into(),
414 });
415 Err(errors)
416 }
417 };
418
419 conn.bidir_send(|mut phase| {
421 let (expected_stanzas, file_keys, labels) =
422 match (recipients, identities, file_keys, labels) {
423 (Ok(recipients), Ok(identities), Ok(file_keys), Ok(labels)) => {
424 (recipients + identities, file_keys, labels)
425 }
426 (recipients, identities, file_keys, labels) => {
427 for error in recipients
428 .err()
429 .into_iter()
430 .chain(identities.err())
431 .chain(file_keys.err())
432 .chain(labels.err())
433 .flatten()
434 {
435 error.send(&mut phase)?;
436 }
437 return Ok(());
438 }
439 };
440
441 let labels = labels.iter().map(|s| s.as_str()).collect::<Vec<_>>();
442 let _ = phase.send(LABELS, &labels, &[])?;
446
447 match plugin.wrap_file_keys(file_keys, BidirCallbacks(&mut phase))? {
448 Ok(files) => {
449 for (file_index, stanzas) in files.into_iter().enumerate() {
450 assert!(stanzas.len() >= expected_stanzas);
453
454 for stanza in stanzas {
455 phase
456 .send_stanza(RECIPIENT_STANZA, &[&file_index.to_string()], &stanza)?
457 .unwrap();
458 }
459 }
460 }
461 Err(errors) => {
462 for error in errors {
463 error.send(&mut phase)?;
464 }
465 }
466 }
467
468 Ok(())
469 })
470}