1use rand::{thread_rng, Rng};
7use secrecy::zeroize::Zeroize;
8use std::env;
9use std::fmt;
10use std::io::{self, BufRead, BufReader, Read, Write};
11use std::iter;
12use std::path::Path;
13use std::process::{ChildStdin, ChildStdout, Command, Stdio};
14
15use crate::{
16 format::{grease_the_joint, read, write, Stanza},
17 io::{DebugReader, DebugWriter},
18};
19
/// State-machine name for the identity half of the plugin protocol
/// (presumably passed as `state_machine` to [`Connection::open`] — confirm against callers).
pub const IDENTITY_V1: &str = "identity-v1";
/// State-machine name for the recipient half of the plugin protocol
/// (presumably passed as `state_machine` to [`Connection::open`] — confirm against callers).
pub const RECIPIENT_V1: &str = "recipient-v1";

/// Tag of the stanza that terminates a phase.
const COMMAND_DONE: &str = "done";

/// Reply tag: the command was handled successfully.
const RESPONSE_OK: &str = "ok";
/// Reply tag: the command was understood but failed.
const RESPONSE_FAIL: &str = "fail";
/// Reply tag: the command is not recognised by the receiver.
const RESPONSE_UNSUPPORTED: &str = "unsupported";
27
/// Protocol-level error reported by the peer in reply to a command.
///
/// Values are comparable and cloneable so callers (and tests) can match on or
/// assert against a specific failure, e.g. `assert_eq!(err, Error::Unsupported)`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Error {
    /// The peer answered with a `fail` stanza.
    Fail,
    /// The peer answered with an `unsupported` stanza.
    Unsupported,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Error::Fail => "General plugin protocol error",
            Error::Unsupported => "Unsupported command",
        };
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
45
/// Result of a request/response exchange.
///
/// The outer `io::Result` carries transport failures (including an unexpected reply tag);
/// the inner result carries a protocol-level [`Error`] reported by the peer.
pub type Result<T> = io::Result<std::result::Result<T, Error>>;

/// Return type of [`Connection::unidir_receive`]: one entry per command slot `a`, `b`, `c`, `d`.
///
/// Each slot is `Ok` with every value its parser produced, or `Err` with every error its
/// parser produced (once an error occurs, later successes for that slot are discarded).
/// Slots `c` and `d` are `None` when no tag was supplied for them.
type UnidirResult<A, B, C, D, E> = io::Result<(
    std::result::Result<Vec<A>, Vec<E>>,
    std::result::Result<Vec<B>, Vec<E>>,
    Option<std::result::Result<Vec<C>, Vec<E>>>,
    Option<std::result::Result<Vec<D>, Vec<E>>>,
)>;
60
/// A line-oriented connection over which age stanzas are exchanged with a peer.
///
/// NOTE: struct fields are dropped in declaration order, so `input` and `output` are
/// released before `_working_dir`. Keep that in mind before reordering fields.
pub struct Connection<R: Read, W: Write> {
    /// Buffered source of incoming stanza text (read line by line in `receive`).
    input: BufReader<R>,
    /// Sink for outgoing stanzas; `send` flushes it after every stanza.
    output: W,
    /// Text read from `input` that has not yet been consumed as a complete stanza.
    buffer: String,
    /// Temporary directory used as the spawned process's working directory (`Some` from
    /// `open`, `None` from `accept`); held so it lives as long as the connection.
    _working_dir: Option<tempfile::TempDir>,
}
68
69impl Connection<DebugReader<ChildStdout>, DebugWriter<ChildStdin>> {
70 pub fn open(binary: &Path, state_machine: &str) -> io::Result<Self> {
76 let working_dir = tempfile::tempdir()?;
77 let debug_enabled = env::var("AGEDEBUG").map(|s| s == "plugin").unwrap_or(false);
78 let process = Command::new(binary.canonicalize()?)
79 .arg(format!("--age-plugin={}", state_machine))
80 .current_dir(working_dir.path())
81 .stdin(Stdio::piped())
82 .stdout(Stdio::piped())
83 .stderr(if debug_enabled {
84 Stdio::inherit()
85 } else {
86 Stdio::null()
87 })
88 .spawn()?;
89 let input = BufReader::new(DebugReader::new(
90 process.stdout.expect("could open stdout"),
91 debug_enabled,
92 ));
93 let output = DebugWriter::new(process.stdin.expect("could open stdin"), debug_enabled);
94 Ok(Connection {
95 input,
96 output,
97 buffer: String::new(),
98 _working_dir: Some(working_dir),
99 })
100 }
101}
102
103impl Connection<io::Stdin, io::Stdout> {
104 pub fn accept() -> Self {
106 Connection {
107 input: BufReader::new(io::stdin()),
108 output: io::stdout(),
109 buffer: String::new(),
110 _working_dir: None,
111 }
112 }
113}
114
impl<R: Read, W: Write> Connection<R, W> {
    /// Writes one stanza with tag `command`, arguments `metadata`, and body `data`, then
    /// flushes `output`.
    ///
    /// Serialization errors other than `GenError::IoError` are converted into
    /// `io::ErrorKind::Other` errors carrying the error's `Display` text.
    fn send<S: AsRef<str>>(
        &mut self,
        command: &str,
        metadata: &[S],
        data: &[u8],
    ) -> io::Result<()> {
        use cookie_factory::GenError;

        cookie_factory::gen_simple(write::age_stanza(command, metadata, data), &mut self.output)
            .map_err(|e| match e {
                GenError::IoError(e) => e,
                e => io::Error::new(io::ErrorKind::Other, format!("{}", e)),
            })
            .and_then(|w| w.flush())
    }

    /// Sends `stanza` wrapped inside an outer `command` stanza.
    ///
    /// The outer stanza's arguments are `metadata`, then the inner stanza's tag, then the
    /// inner stanza's arguments; the outer body is the inner stanza's body.
    fn send_stanza<S: AsRef<str>>(
        &mut self,
        command: &str,
        metadata: &[S],
        stanza: &Stanza,
    ) -> io::Result<()> {
        let metadata: Vec<_> = metadata
            .iter()
            .map(|s| s.as_ref())
            .chain(iter::once(stanza.tag.as_str()))
            .chain(stanza.args.iter().map(|s| s.as_str()))
            .collect();

        self.send(command, &metadata, &stanza.body)
    }

    /// Reads the next complete stanza.
    ///
    /// Lines from `input` are appended to `buffer` until `read::age_stanza` parses a full
    /// stanza from its front; any text after that stanza stays in `buffer` for the next call.
    ///
    /// # Errors
    /// - `UnexpectedEof` ("incomplete response") if `input` ends before a stanza completes.
    /// - `InvalidData` ("invalid response") if the buffered text does not parse.
    /// - Any error returned while reading `input`.
    fn receive(&mut self) -> io::Result<Stanza> {
        let (stanza, consumed) = loop {
            match read::age_stanza(self.buffer.as_bytes()) {
                // `consumed` = byte length of the parsed stanza at the front of `buffer`.
                Ok((remainder, r)) => break (r.into(), self.buffer.len() - remainder.len()),
                // Not enough text yet: pull in one more line and retry the parse.
                Err(nom::Err::Incomplete(_)) => {
                    if self.input.read_line(&mut self.buffer)? == 0 {
                        return Err(io::Error::new(
                            io::ErrorKind::UnexpectedEof,
                            "incomplete response",
                        ));
                    };
                }
                // NOTE(review): on this path `buffer` is returned without being zeroized.
                Err(_) => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "invalid response",
                    ));
                }
            }
        };

        // Keep only the unparsed tail, and wipe the consumed stanza text before its
        // allocation is dropped (presumably because stanza bodies can carry secrets).
        let remainder = self.buffer.split_off(consumed);
        self.buffer.zeroize();
        self.buffer = remainder;

        Ok(stanza)
    }

    /// Returns an iterator over zero, one, or two grease stanzas.
    ///
    /// Each of the two slots independently yields `grease_the_joint()` with probability
    /// 5/100. The `&mut self` receiver is not used by the body.
    fn grease_gun(&mut self) -> impl Iterator<Item = Stanza> {
        let mut rng = thread_rng();
        (0..2).filter_map(move |_| {
            if rng.gen_range(0..100) < 5 {
                Some(grease_the_joint())
            } else {
                None
            }
        })
    }

    /// Sends the `done` stanza (no arguments, empty body) that ends a phase.
    fn done(&mut self) -> io::Result<()> {
        self.send::<&str>(COMMAND_DONE, &[], &[])
    }

    /// Runs a one-way sending phase.
    ///
    /// `phase_steps` sends commands through a [`UnidirSend`] view of this connection; any
    /// grease stanzas and then `done` are sent afterwards. No replies are read.
    pub fn unidir_send<P: FnOnce(UnidirSend<R, W>) -> io::Result<()>>(
        &mut self,
        phase_steps: P,
    ) -> io::Result<()> {
        phase_steps(UnidirSend(self))?;
        for grease in self.grease_gun() {
            self.send(&grease.tag, &grease.args, &grease.body)?;
        }
        self.done()
    }

    /// Runs a one-way receiving phase, reading stanzas until `done`.
    ///
    /// Each stanza is dispatched by tag to the parser of the first matching slot
    /// (`command_a`, then `command_b`, then `command_c`/`command_d` if their tags are
    /// `Some`). Stanzas matching no slot (e.g. grease) are ignored. Per slot, parsed values
    /// are collected; once a parser returns an error, that slot switches to collecting
    /// errors only. See [`UnidirResult`] for the returned shape.
    ///
    /// # Errors
    /// Propagates any I/O error from `receive`, including EOF before `done`.
    pub fn unidir_receive<A, B, C, D, E, F, G, H, I>(
        &mut self,
        command_a: (&str, F),
        command_b: (&str, G),
        command_c: (Option<&str>, H),
        command_d: (Option<&str>, I),
    ) -> UnidirResult<A, B, C, D, E>
    where
        F: Fn(Stanza) -> std::result::Result<A, E>,
        G: Fn(Stanza) -> std::result::Result<B, E>,
        H: Fn(Stanza) -> std::result::Result<C, E>,
        I: Fn(Stanza) -> std::result::Result<D, E>,
    {
        let mut res_a = Ok(vec![]);
        let mut res_b = Ok(vec![]);
        let mut res_c = Ok(vec![]);
        let mut res_d = Ok(vec![]);

        // Stop at the `done` stanza; errors are let through so `?` below can return them.
        for stanza in iter::repeat_with(|| self.receive()).take_while(|res| match res {
            Ok(stanza) => stanza.tag != COMMAND_DONE,
            _ => true,
        }) {
            let stanza = stanza?;

            // Records one parse outcome: successes are kept only while no error has been
            // seen for this slot; errors replace (first) or extend the error list.
            fn validate<T, E>(
                val: std::result::Result<T, E>,
                res: &mut std::result::Result<Vec<T>, Vec<E>>,
            ) {
                match val {
                    Ok(a) => {
                        if let Ok(stanzas) = res {
                            stanzas.push(a)
                        }
                    }
                    Err(e) => match res {
                        Ok(_) => *res = Err(vec![e]),
                        Err(errors) => errors.push(e),
                    },
                }
            }

            if stanza.tag.as_str() == command_a.0 {
                validate(command_a.1(stanza), &mut res_a)
            } else if stanza.tag.as_str() == command_b.0 {
                validate(command_b.1(stanza), &mut res_b)
            } else {
                if let Some(tag) = command_c.0 {
                    if stanza.tag.as_str() == tag {
                        validate(command_c.1(stanza), &mut res_c);
                        continue;
                    }
                }
                if let Some(tag) = command_d.0 {
                    if stanza.tag.as_str() == tag {
                        validate(command_d.1(stanza), &mut res_d);
                        continue;
                    }
                }
            }
        }

        // Optional slots are reported only if a tag was supplied for them.
        Ok((
            res_a,
            res_b,
            command_c.0.map(|_| res_c),
            command_d.0.map(|_| res_d),
        ))
    }

    /// Runs a request/response sending phase.
    ///
    /// `phase_steps` sends commands through a [`BidirSend`] view of this connection. Then
    /// any grease stanzas are sent, each followed by reading one reply that is discarded
    /// without inspection, and finally `done` is sent.
    pub fn bidir_send<P: FnOnce(BidirSend<R, W>) -> io::Result<()>>(
        &mut self,
        phase_steps: P,
    ) -> io::Result<()> {
        phase_steps(BidirSend(self))?;
        for grease in self.grease_gun() {
            self.send(&grease.tag, &grease.args, &grease.body)?;
            self.receive()?;
        }
        self.done()
    }

    /// Runs a request/response receiving phase, reading stanzas until `done`.
    ///
    /// A stanza whose tag is listed in `commands` is passed to `handler` together with a
    /// [`Reply`] handle; the I/O result wrapped in the returned [`Response`] is propagated.
    /// Any other tag is answered with an `unsupported` stanza.
    pub fn bidir_receive<H>(&mut self, commands: &[&str], mut handler: H) -> io::Result<()>
    where
        H: FnMut(Stanza, Reply<R, W>) -> Response,
    {
        loop {
            let stanza = self.receive()?;
            match stanza.tag.as_str() {
                COMMAND_DONE => break Ok(()),
                t if commands.contains(&t) => handler(stanza, Reply(self)).0?,
                _ => self.send::<&str>(RESPONSE_UNSUPPORTED, &[], &[])?,
            }
        }
    }
}
310
311pub struct UnidirSend<'a, R: Read, W: Write>(&'a mut Connection<R, W>);
315
316impl<'a, R: Read, W: Write> UnidirSend<'a, R, W> {
317 pub fn send(&mut self, command: &str, metadata: &[&str], data: &[u8]) -> io::Result<()> {
319 for grease in self.0.grease_gun() {
320 self.0.send(&grease.tag, &grease.args, &grease.body)?;
321 }
322 self.0.send(command, metadata, data)
323 }
324
325 pub fn send_stanza(
327 &mut self,
328 command: &str,
329 metadata: &[&str],
330 stanza: &Stanza,
331 ) -> io::Result<()> {
332 for grease in self.0.grease_gun() {
333 self.0.send(&grease.tag, &grease.args, &grease.body)?;
334 }
335 self.0.send_stanza(command, metadata, stanza)
336 }
337}
338
339pub struct BidirSend<'a, R: Read, W: Write>(&'a mut Connection<R, W>);
343
344impl<'a, R: Read, W: Write> BidirSend<'a, R, W> {
345 pub fn send(&mut self, command: &str, metadata: &[&str], data: &[u8]) -> Result<Stanza> {
347 for grease in self.0.grease_gun() {
348 self.0.send(&grease.tag, &grease.args, &grease.body)?;
349 self.0.receive()?;
350 }
351 self.0.send(command, metadata, data)?;
352 let s = self.0.receive()?;
353 match s.tag.as_ref() {
354 RESPONSE_OK => Ok(Ok(s)),
355 RESPONSE_FAIL => Ok(Err(Error::Fail)),
356 RESPONSE_UNSUPPORTED => Ok(Err(Error::Unsupported)),
357 tag => Err(io::Error::new(
358 io::ErrorKind::InvalidData,
359 format!("unexpected response: {}", tag),
360 )),
361 }
362 }
363
364 pub fn send_stanza(
366 &mut self,
367 command: &str,
368 metadata: &[&str],
369 stanza: &Stanza,
370 ) -> Result<Stanza> {
371 for grease in self.0.grease_gun() {
372 self.0.send(&grease.tag, &grease.args, &grease.body)?;
373 self.0.receive()?;
374 }
375 self.0.send_stanza(command, metadata, stanza)?;
376 let s = self.0.receive()?;
377 match s.tag.as_ref() {
378 RESPONSE_OK => Ok(Ok(s)),
379 RESPONSE_FAIL => Ok(Err(Error::Fail)),
380 RESPONSE_UNSUPPORTED => Ok(Err(Error::Unsupported)),
381 tag => Err(io::Error::new(
382 io::ErrorKind::InvalidData,
383 format!("unexpected response: {}", tag),
384 )),
385 }
386 }
387}
388
389pub struct Reply<'a, R: Read, W: Write>(&'a mut Connection<R, W>);
391
392impl<'a, R: Read, W: Write> Reply<'a, R, W> {
393 pub fn ok(self, data: Option<&[u8]>) -> Response {
395 Response(
396 self.0
397 .send::<&str>(RESPONSE_OK, &[], data.unwrap_or_default()),
398 )
399 }
400
401 pub fn ok_with_metadata<S: AsRef<str>>(self, metadata: &[S], data: Option<&[u8]>) -> Response {
403 Response(self.0.send(RESPONSE_OK, metadata, data.unwrap_or_default()))
404 }
405
406 pub fn fail(self) -> Response {
408 Response(self.0.send::<&str>(RESPONSE_FAIL, &[], &[]))
409 }
410}
411
/// Outcome of answering a command through [`Reply`].
///
/// Wraps the I/O result of writing the reply; `bidir_receive` unwraps it with `?`, so a
/// failed write ends the receiving phase with that error.
pub struct Response(io::Result<()>);
414
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Mutex};

    use super::*;

    /// In-memory byte queue shared between one writer end and one reader end.
    pub struct Pipe(Vec<u8>);

    impl Pipe {
        /// Creates an empty pipe, ready to be shared between both ends.
        pub fn new() -> Arc<Mutex<Self>> {
            Arc::new(Mutex::new(Self(Vec::new())))
        }
    }

    /// Read end of a [`Pipe`].
    pub struct PipeReader {
        pipe: Arc<Mutex<Pipe>>,
    }

    impl PipeReader {
        pub fn new(pipe: Arc<Mutex<Pipe>>) -> Self {
            Self { pipe }
        }
    }

    impl Read for PipeReader {
        /// Moves as many queued bytes as fit into `buf`; an empty pipe reports `WouldBlock`.
        fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
            let mut guard = self.pipe.lock().unwrap();
            let queued = &mut guard.0;
            if queued.is_empty() {
                return Err(io::Error::new(io::ErrorKind::WouldBlock, ""));
            }
            let n = queued.len().min(buf.len());
            buf[..n].copy_from_slice(&queued[..n]);
            queued.drain(..n);
            Ok(n)
        }
    }

    /// Write end of a [`Pipe`].
    pub struct PipeWriter {
        pipe: Arc<Mutex<Pipe>>,
    }

    impl PipeWriter {
        pub fn new(pipe: Arc<Mutex<Pipe>>) -> Self {
            Self { pipe }
        }
    }

    impl Write for PipeWriter {
        /// Appends all of `buf` to the shared queue.
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            self.pipe.lock().unwrap().0.extend_from_slice(buf);
            Ok(buf.len())
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    /// A one-way phase sent by the client is received intact by the plugin side.
    #[test]
    fn mock_plugin() {
        let client_to_plugin = Pipe::new();
        let plugin_to_client = Pipe::new();

        let mut client_conn = Connection {
            input: BufReader::new(PipeReader::new(Arc::clone(&plugin_to_client))),
            output: PipeWriter::new(Arc::clone(&client_to_plugin)),
            buffer: String::new(),
            _working_dir: None,
        };
        let mut plugin_conn = Connection {
            input: BufReader::new(PipeReader::new(client_to_plugin)),
            output: PipeWriter::new(plugin_to_client),
            buffer: String::new(),
            _working_dir: None,
        };

        client_conn
            .unidir_send(|mut phase| phase.send("test", &["foo"], b"bar"))
            .unwrap();
        let received = plugin_conn
            .unidir_receive::<_, (), (), (), _, _, _, _, _>(
                ("test", Ok),
                ("other", |_| Err(())),
                (None, |_| Ok(())),
                (None, |_| Ok(())),
            )
            .unwrap();

        let expected = Stanza {
            tag: "test".to_owned(),
            args: vec!["foo".to_owned()],
            body: b"bar".to_vec(),
        };
        assert_eq!(received, (Ok(vec![expected]), Ok(vec![]), None, None));
    }
}