age_core/
plugin.rs

1//! Common structs and constants for the age plugin system.
2//!
3//! These are shared between the client implementation in the `age` crate, and the plugin
4//! implementations built around the `age-plugin` crate.
5
6use rand::{thread_rng, Rng};
7use secrecy::zeroize::Zeroize;
8use std::env;
9use std::fmt;
10use std::io::{self, BufRead, BufReader, Read, Write};
11use std::iter;
12use std::path::Path;
13use std::process::{ChildStdin, ChildStdout, Command, Stdio};
14
15use crate::{
16    format::{grease_the_joint, read, write, Stanza},
17    io::{DebugReader, DebugWriter},
18};
19
20pub const IDENTITY_V1: &str = "identity-v1";
21pub const RECIPIENT_V1: &str = "recipient-v1";
22
23const COMMAND_DONE: &str = "done";
24const RESPONSE_OK: &str = "ok";
25const RESPONSE_FAIL: &str = "fail";
26const RESPONSE_UNSUPPORTED: &str = "unsupported";
27
/// An error within the plugin protocol.
///
/// These correspond to the non-`ok` responses a peer may give to a command.
#[derive(Debug)]
pub enum Error {
    /// The peer replied with `fail`.
    Fail,
    /// The peer replied with `unsupported`.
    Unsupported,
}

impl fmt::Display for Error {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::Fail => "General plugin protocol error",
            Self::Unsupported => "Unsupported command",
        };
        f.write_str(description)
    }
}

impl std::error::Error for Error {}
45
46/// Result type for the plugin protocol.
47///
48/// - The outer error indicates a problem with the IPC transport or state machine; these
49///   should result in the state machine being terminated and the connection closed.
50/// - The inner error indicates an error within the plugin protocol, that the recipient
51///   should explicitly handle.
52pub type Result<T> = io::Result<std::result::Result<T, Error>>;
53
/// Everything collected for one expected command: either all of the validated values,
/// or all of the validation errors.
type CommandOutcome<T, E> = std::result::Result<Vec<T>, Vec<E>>;

/// Result of a unidirectional receive phase.
///
/// The outer [`io::Result`] reports transport failures. On success the tuple holds one
/// outcome per expected command, in order; the last two are `Option`s because those
/// commands are optional.
type UnidirResult<A, B, C, D, E> = io::Result<(
    CommandOutcome<A, E>,
    CommandOutcome<B, E>,
    Option<CommandOutcome<C, E>>,
    Option<CommandOutcome<D, E>>,
)>;
60
61/// A connection to a plugin binary.
62pub struct Connection<R: Read, W: Write> {
63    input: BufReader<R>,
64    output: W,
65    buffer: String,
66    _working_dir: Option<tempfile::TempDir>,
67}
68
69impl Connection<DebugReader<ChildStdout>, DebugWriter<ChildStdin>> {
70    /// Starts a plugin binary with the given state machine.
71    ///
72    /// If the `AGEDEBUG` environment variable is set to `plugin`, then all messages sent
73    /// to and from the plugin, as well as anything the plugin prints to its `stderr`,
74    /// will be printed to the `stderr` of the parent process.
75    pub fn open(binary: &Path, state_machine: &str) -> io::Result<Self> {
76        let working_dir = tempfile::tempdir()?;
77        let debug_enabled = env::var("AGEDEBUG").map(|s| s == "plugin").unwrap_or(false);
78        let process = Command::new(binary.canonicalize()?)
79            .arg(format!("--age-plugin={}", state_machine))
80            .current_dir(working_dir.path())
81            .stdin(Stdio::piped())
82            .stdout(Stdio::piped())
83            .stderr(if debug_enabled {
84                Stdio::inherit()
85            } else {
86                Stdio::null()
87            })
88            .spawn()?;
89        let input = BufReader::new(DebugReader::new(
90            process.stdout.expect("could open stdout"),
91            debug_enabled,
92        ));
93        let output = DebugWriter::new(process.stdin.expect("could open stdin"), debug_enabled);
94        Ok(Connection {
95            input,
96            output,
97            buffer: String::new(),
98            _working_dir: Some(working_dir),
99        })
100    }
101}
102
103impl Connection<io::Stdin, io::Stdout> {
104    /// Initialise a connection from an age client.
105    pub fn accept() -> Self {
106        Connection {
107            input: BufReader::new(io::stdin()),
108            output: io::stdout(),
109            buffer: String::new(),
110            _working_dir: None,
111        }
112    }
113}
114
115impl<R: Read, W: Write> Connection<R, W> {
116    fn send<S: AsRef<str>>(
117        &mut self,
118        command: &str,
119        metadata: &[S],
120        data: &[u8],
121    ) -> io::Result<()> {
122        use cookie_factory::GenError;
123
124        cookie_factory::gen_simple(write::age_stanza(command, metadata, data), &mut self.output)
125            .map_err(|e| match e {
126                GenError::IoError(e) => e,
127                e => io::Error::new(io::ErrorKind::Other, format!("{}", e)),
128            })
129            .and_then(|w| w.flush())
130    }
131
132    fn send_stanza<S: AsRef<str>>(
133        &mut self,
134        command: &str,
135        metadata: &[S],
136        stanza: &Stanza,
137    ) -> io::Result<()> {
138        let metadata: Vec<_> = metadata
139            .iter()
140            .map(|s| s.as_ref())
141            .chain(iter::once(stanza.tag.as_str()))
142            .chain(stanza.args.iter().map(|s| s.as_str()))
143            .collect();
144
145        self.send(command, &metadata, &stanza.body)
146    }
147
148    fn receive(&mut self) -> io::Result<Stanza> {
149        let (stanza, consumed) = loop {
150            match read::age_stanza(self.buffer.as_bytes()) {
151                Ok((remainder, r)) => break (r.into(), self.buffer.len() - remainder.len()),
152                Err(nom::Err::Incomplete(_)) => {
153                    if self.input.read_line(&mut self.buffer)? == 0 {
154                        return Err(io::Error::new(
155                            io::ErrorKind::UnexpectedEof,
156                            "incomplete response",
157                        ));
158                    };
159                }
160                Err(_) => {
161                    return Err(io::Error::new(
162                        io::ErrorKind::InvalidData,
163                        "invalid response",
164                    ));
165                }
166            }
167        };
168
169        // We are finished with any prior response.
170        let remainder = self.buffer.split_off(consumed);
171        self.buffer.zeroize();
172        self.buffer = remainder;
173
174        Ok(stanza)
175    }
176
177    fn grease_gun(&mut self) -> impl Iterator<Item = Stanza> {
178        // Add 5% grease
179        let mut rng = thread_rng();
180        (0..2).filter_map(move |_| {
181            if rng.gen_range(0..100) < 5 {
182                Some(grease_the_joint())
183            } else {
184                None
185            }
186        })
187    }
188
189    fn done(&mut self) -> io::Result<()> {
190        self.send::<&str>(COMMAND_DONE, &[], &[])
191    }
192
193    /// Runs a unidirectional phase as the controller.
194    pub fn unidir_send<P: FnOnce(UnidirSend<R, W>) -> io::Result<()>>(
195        &mut self,
196        phase_steps: P,
197    ) -> io::Result<()> {
198        phase_steps(UnidirSend(self))?;
199        for grease in self.grease_gun() {
200            self.send(&grease.tag, &grease.args, &grease.body)?;
201        }
202        self.done()
203    }
204
205    /// Runs a unidirectional phase as the recipient.
206    ///
207    /// # Arguments
208    ///
209    /// `command_a`, `command_b`, and (optionally) `command_c` and `command_d` are the
210    /// known commands that are expected to be received. All other received commands
211    /// (including grease) will be ignored.
212    pub fn unidir_receive<A, B, C, D, E, F, G, H, I>(
213        &mut self,
214        command_a: (&str, F),
215        command_b: (&str, G),
216        command_c: (Option<&str>, H),
217        command_d: (Option<&str>, I),
218    ) -> UnidirResult<A, B, C, D, E>
219    where
220        F: Fn(Stanza) -> std::result::Result<A, E>,
221        G: Fn(Stanza) -> std::result::Result<B, E>,
222        H: Fn(Stanza) -> std::result::Result<C, E>,
223        I: Fn(Stanza) -> std::result::Result<D, E>,
224    {
225        let mut res_a = Ok(vec![]);
226        let mut res_b = Ok(vec![]);
227        let mut res_c = Ok(vec![]);
228        let mut res_d = Ok(vec![]);
229
230        for stanza in iter::repeat_with(|| self.receive()).take_while(|res| match res {
231            Ok(stanza) => stanza.tag != COMMAND_DONE,
232            _ => true,
233        }) {
234            let stanza = stanza?;
235
236            fn validate<T, E>(
237                val: std::result::Result<T, E>,
238                res: &mut std::result::Result<Vec<T>, Vec<E>>,
239            ) {
240                // Structurally validate the stanza against this command.
241                match val {
242                    Ok(a) => {
243                        if let Ok(stanzas) = res {
244                            stanzas.push(a)
245                        }
246                    }
247                    Err(e) => match res {
248                        Ok(_) => *res = Err(vec![e]),
249                        Err(errors) => errors.push(e),
250                    },
251                }
252            }
253
254            if stanza.tag.as_str() == command_a.0 {
255                validate(command_a.1(stanza), &mut res_a)
256            } else if stanza.tag.as_str() == command_b.0 {
257                validate(command_b.1(stanza), &mut res_b)
258            } else {
259                if let Some(tag) = command_c.0 {
260                    if stanza.tag.as_str() == tag {
261                        validate(command_c.1(stanza), &mut res_c);
262                        continue;
263                    }
264                }
265                if let Some(tag) = command_d.0 {
266                    if stanza.tag.as_str() == tag {
267                        validate(command_d.1(stanza), &mut res_d);
268                        continue;
269                    }
270                }
271            }
272        }
273
274        Ok((
275            res_a,
276            res_b,
277            command_c.0.map(|_| res_c),
278            command_d.0.map(|_| res_d),
279        ))
280    }
281
282    /// Runs a bidirectional phase as the controller.
283    pub fn bidir_send<P: FnOnce(BidirSend<R, W>) -> io::Result<()>>(
284        &mut self,
285        phase_steps: P,
286    ) -> io::Result<()> {
287        phase_steps(BidirSend(self))?;
288        for grease in self.grease_gun() {
289            self.send(&grease.tag, &grease.args, &grease.body)?;
290            self.receive()?;
291        }
292        self.done()
293    }
294
295    /// Runs a bidirectional phase as the recipient.
296    pub fn bidir_receive<H>(&mut self, commands: &[&str], mut handler: H) -> io::Result<()>
297    where
298        H: FnMut(Stanza, Reply<R, W>) -> Response,
299    {
300        loop {
301            let stanza = self.receive()?;
302            match stanza.tag.as_str() {
303                COMMAND_DONE => break Ok(()),
304                t if commands.contains(&t) => handler(stanza, Reply(self)).0?,
305                _ => self.send::<&str>(RESPONSE_UNSUPPORTED, &[], &[])?,
306            }
307        }
308    }
309}
310
311/// Actions that a controller may take during a unidirectional phase.
312///
313/// Grease is applied automatically.
314pub struct UnidirSend<'a, R: Read, W: Write>(&'a mut Connection<R, W>);
315
316impl<'a, R: Read, W: Write> UnidirSend<'a, R, W> {
317    /// Send a command.
318    pub fn send(&mut self, command: &str, metadata: &[&str], data: &[u8]) -> io::Result<()> {
319        for grease in self.0.grease_gun() {
320            self.0.send(&grease.tag, &grease.args, &grease.body)?;
321        }
322        self.0.send(command, metadata, data)
323    }
324
325    /// Send an entire stanza.
326    pub fn send_stanza(
327        &mut self,
328        command: &str,
329        metadata: &[&str],
330        stanza: &Stanza,
331    ) -> io::Result<()> {
332        for grease in self.0.grease_gun() {
333            self.0.send(&grease.tag, &grease.args, &grease.body)?;
334        }
335        self.0.send_stanza(command, metadata, stanza)
336    }
337}
338
339/// Actions that a controller may take during a bidirectional phase.
340///
341/// Grease is applied automatically.
342pub struct BidirSend<'a, R: Read, W: Write>(&'a mut Connection<R, W>);
343
344impl<'a, R: Read, W: Write> BidirSend<'a, R, W> {
345    /// Send a command and receive a response.
346    pub fn send(&mut self, command: &str, metadata: &[&str], data: &[u8]) -> Result<Stanza> {
347        for grease in self.0.grease_gun() {
348            self.0.send(&grease.tag, &grease.args, &grease.body)?;
349            self.0.receive()?;
350        }
351        self.0.send(command, metadata, data)?;
352        let s = self.0.receive()?;
353        match s.tag.as_ref() {
354            RESPONSE_OK => Ok(Ok(s)),
355            RESPONSE_FAIL => Ok(Err(Error::Fail)),
356            RESPONSE_UNSUPPORTED => Ok(Err(Error::Unsupported)),
357            tag => Err(io::Error::new(
358                io::ErrorKind::InvalidData,
359                format!("unexpected response: {}", tag),
360            )),
361        }
362    }
363
364    /// Send an entire stanza.
365    pub fn send_stanza(
366        &mut self,
367        command: &str,
368        metadata: &[&str],
369        stanza: &Stanza,
370    ) -> Result<Stanza> {
371        for grease in self.0.grease_gun() {
372            self.0.send(&grease.tag, &grease.args, &grease.body)?;
373            self.0.receive()?;
374        }
375        self.0.send_stanza(command, metadata, stanza)?;
376        let s = self.0.receive()?;
377        match s.tag.as_ref() {
378            RESPONSE_OK => Ok(Ok(s)),
379            RESPONSE_FAIL => Ok(Err(Error::Fail)),
380            RESPONSE_UNSUPPORTED => Ok(Err(Error::Unsupported)),
381            tag => Err(io::Error::new(
382                io::ErrorKind::InvalidData,
383                format!("unexpected response: {}", tag),
384            )),
385        }
386    }
387}
388
389/// The possible replies to a bidirectional command.
390pub struct Reply<'a, R: Read, W: Write>(&'a mut Connection<R, W>);
391
392impl<'a, R: Read, W: Write> Reply<'a, R, W> {
393    /// Reply with `ok` and optional data.
394    pub fn ok(self, data: Option<&[u8]>) -> Response {
395        Response(
396            self.0
397                .send::<&str>(RESPONSE_OK, &[], data.unwrap_or_default()),
398        )
399    }
400
401    /// Reply with `ok`, metadata, and optional data.
402    pub fn ok_with_metadata<S: AsRef<str>>(self, metadata: &[S], data: Option<&[u8]>) -> Response {
403        Response(self.0.send(RESPONSE_OK, metadata, data.unwrap_or_default()))
404    }
405
406    /// The command failed (for example, the user failed to respond to an input request).
407    pub fn fail(self) -> Response {
408        Response(self.0.send::<&str>(RESPONSE_FAIL, &[], &[]))
409    }
410}
411
/// A response to a bidirectional command.
///
/// Holds the result of writing the reply to the connection.
pub struct Response(std::result::Result<(), io::Error>);
414
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;

    /// In-memory byte queue shared between a [`PipeWriter`] and a [`PipeReader`].
    pub struct Pipe(Vec<u8>);

    impl Pipe {
        /// Creates an empty pipe, ready to be shared between a writer and a reader.
        pub fn new() -> Arc<Mutex<Self>> {
            let empty = Pipe(Vec::new());
            Arc::new(Mutex::new(empty))
        }
    }

    /// Reading end of a [`Pipe`].
    pub struct PipeReader {
        pipe: Arc<Mutex<Pipe>>,
    }

    impl PipeReader {
        pub fn new(pipe: Arc<Mutex<Pipe>>) -> Self {
            Self { pipe }
        }
    }

    impl Read for PipeReader {
        /// Moves as many queued bytes as fit into `buf`, removing them from the pipe.
        ///
        /// Fails with [`io::ErrorKind::WouldBlock`] when the pipe is empty.
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let mut pipe = self.pipe.lock().unwrap();
            if pipe.0.is_empty() {
                return Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
            }

            let count = buf.len().min(pipe.0.len());
            buf[..count].copy_from_slice(&pipe.0[..count]);
            pipe.0.drain(..count);
            Ok(count)
        }
    }

    /// Writing end of a [`Pipe`].
    pub struct PipeWriter {
        pipe: Arc<Mutex<Pipe>>,
    }

    impl PipeWriter {
        pub fn new(pipe: Arc<Mutex<Pipe>>) -> Self {
            Self { pipe }
        }
    }

    impl Write for PipeWriter {
        /// Appends all of `buf` to the pipe.
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.pipe.lock().unwrap().0.extend_from_slice(buf);
            Ok(buf.len())
        }

        /// Nothing is buffered outside the pipe, so there is nothing to flush.
        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// A command sent by the controller in a unidirectional phase arrives intact at
    /// the recipient, and nothing is collected for the other commands.
    #[test]
    fn mock_plugin() {
        let client_to_plugin = Pipe::new();
        let plugin_to_client = Pipe::new();

        let client_input = BufReader::new(PipeReader::new(plugin_to_client.clone()));
        let client_output = PipeWriter::new(client_to_plugin.clone());
        let mut client_conn = Connection {
            input: client_input,
            output: client_output,
            buffer: String::new(),
            _working_dir: None,
        };

        let plugin_input = BufReader::new(PipeReader::new(client_to_plugin));
        let plugin_output = PipeWriter::new(plugin_to_client);
        let mut plugin_conn = Connection {
            input: plugin_input,
            output: plugin_output,
            buffer: String::new(),
            _working_dir: None,
        };

        client_conn
            .unidir_send(|mut phase| phase.send("test", &["foo"], b"bar"))
            .unwrap();

        let (tests, others, third, fourth) = plugin_conn
            .unidir_receive::<_, (), (), (), _, _, _, _, _>(
                ("test", Ok),
                ("other", |_| Err(())),
                (None, |_| Ok(())),
                (None, |_| Ok(())),
            )
            .unwrap();

        let expected = Stanza {
            tag: "test".to_owned(),
            args: vec!["foo".to_owned()],
            body: b"bar"[..].to_owned(),
        };
        assert_eq!(tests, Ok(vec![expected]));
        assert_eq!(others, Ok(vec![]));
        assert_eq!(third, None);
        assert_eq!(fourth, None);
    }
}