cmd_derive/
lib.rs

1#![forbid(unsafe_code)]
2
3use proc_macro2::{TokenStream, TokenTree, Punct, Group, Ident, Span};
4use proc_macro::{TokenStream as TokenStream1};
5use boolinator::Boolinator;
6use std::iter::Peekable;
7use quote::{quote, quote_spanned, ToTokens};
8use syn::Error;
9
10#[proc_macro]
11pub fn cmd(input: TokenStream1) -> TokenStream1 {
12    parse_cmd(TokenStream::from(input))
13        .map_or_else(
14            |e| e.to_compile_error(),
15            |c| c.into_token_stream(),
16        ).into()
17}
18
19#[derive(Debug, PartialEq)]
20pub(crate) struct Cmd {
21    commands: Vec<Command>,
22}
23
24#[derive(Debug)]
25struct Command {
26    span: Span,
27    terms: Vec<Expr>
28}
29
30impl PartialEq for Command {
31    fn eq(&self, other: &Self) -> bool {
32        self.terms == other.terms
33    }
34}
35
36pub(crate) fn parse_cmd(input: TokenStream) -> syn::Result<Cmd> {
37    let mut tokens = input.into_iter().peekable();
38
39    let mut commands: Vec<Command> = vec![];
40    loop {
41        let command = parse_command(&mut tokens)?;
42        commands.push(command);
43
44        match next_punct_is(&mut tokens, '|') {
45            Ok(_) => { tokens.next(); },
46            Err(NextError::EOF) => break,
47            _ => return Err(
48                Error::new(span_remaining(&mut tokens),
49                    "expected EOF or | after end of command")
50            ),
51        }
52    }
53
54    Ok(Cmd{
55        commands,
56    })
57}
58
59fn parse_command<I>(input: &mut Peekable<I>) -> syn::Result<Command>
60where I: Iterator<Item=TokenTree>
61{
62    let first_term = parse_term(input)?;
63    let mut span = first_term.span();
64    let mut terms = vec![first_term];
65
66    loop {
67        if next_punct_is(input, '|').is_ok() || is_eof(input)  {
68            break;
69        }
70
71        let term = parse_term(input)?;
72        let term_span = term.span();
73        terms.push(term);
74
75        span = span.join(term_span).ok_or(
76            Error::new(span, "internal error: could not join spans")
77        )?;
78    }
79
80    Ok(Command{span, terms})
81}
82
83fn is_eof<I>(input: &mut Peekable<I>) -> bool
84where I: Iterator<Item=TokenTree>
85{
86    input.peek().is_none()
87}
88
89enum Expr {
90    Literal(String, Span),
91    Expr(syn::Expr, Span),
92}
93
94impl std::fmt::Debug for Expr {
95    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
96        match self {
97            Expr::Literal(s, _) => write!(f, "Literal({})", s),
98            Expr::Expr(e, _) => write!(f, "Expr({:?})", e.to_token_stream()),
99        }
100    }
101}
102
103impl PartialEq for Expr {
104    fn eq(&self, other: &Self) -> bool {
105        format!("{:?}", self) == format!("{:?}", other)
106    }
107}
108
109impl Expr {
110    pub fn span(&self) -> Span {
111        match self {
112            Expr::Literal(_, span) => *span,
113            Expr::Expr(_, span) => *span,
114        }
115    }
116}
117
118impl ToTokens for Expr {
119    fn to_tokens(&self, t: &mut TokenStream) {
120        t.extend(match self {
121            Expr::Literal(str, span) => quote_spanned!{*span=> ::cmd_macro::env::into_arg(#str.to_string())?.expand()? },
122            Expr::Expr(e, span) => quote_spanned! {*span=> format!("{}", #e) },
123        })
124    }
125}
126
/// Error reported when a `#` is the last token or is not immediately followed
/// (without whitespace) by another token.
const ALONE_HASH_ERROR: &str = "'#' must not be alone. use '##' if you meant to pass a literal '#' symbol.
To pass two or more, only one extra '#' must be supplied.
Example: `cmd!(echo ####a###)` is equivalent to `$ echo ###a###`";
130
/// Error reported when a `#` is directly followed by a token it cannot introduce,
/// such as a literal or a punctuation character other than `#`.
const UNEXPECTED_HASH_ERROR: &str = "'#' should be followed by an ident, an expression wrapped in any of (), {} or [], or another #
Example:
cmd!(echo ##) // echo #

let hello = \"hello world\";
cmd!(echo #hello) // echo \"hello world\"

let hello = \"hello\";
let world = \"world\";
cmd!(echo #{[hello, world].join(\" \")}) // echo \"hello world\"
";
142
143fn parse_term<I>(input: &mut Peekable<I>) -> syn::Result<Expr>
144where I: Iterator<Item=TokenTree>
145{
146    if next_punct_is(input, '#').is_ok() {
147        let p = punct(input)
148            .expect("internal error: peeked '#' but could not parse punct. please file a bug report");
149        
150        match input.peek() {
151            None => return Err(Error::new(p.span(), ALONE_HASH_ERROR)),
152            Some(tt) => {
153                let span = to_span(tt);
154
155                if !joined(p.span(), span) {
156                    return Err(Error::new(p.span(), ALONE_HASH_ERROR));
157                }
158
159                match tt {
160                    TokenTree::Punct(p) => 
161                        match p.as_char() {
162                            '#' => {}, // Do nothing. Valid term
163                            _ => return Err(Error::new(p.span(), UNEXPECTED_HASH_ERROR)),
164                        },
165                    TokenTree::Ident(_) => {
166                        return parse_ident(ident(input)
167                            .expect("internal error: peeked ident but could not parse ident. please file a bug report"))
168                    }
169                    TokenTree::Group(_) => {
170                        return parse_expr(group(input)
171                            .expect("internal error: peeked group but could not parse group. please file a bug report"))
172                    }
173                    _ => return Err(Error::new(p.span(), UNEXPECTED_HASH_ERROR)),
174                }
175                
176            }
177        }
178    }
179
180    parse_literal(input).map(|(lit, span)| Expr::Literal(lit, span))  
181}
182
183fn parse_literal<I>(input: &mut Peekable<I>) -> syn::Result<(String, Span)>
184where I: Iterator<Item=TokenTree> {
185    // We check for EOF in parse_command so this should not return None
186    // Unless we're parsing a Group like `{}`, which will check the literal before
187    // joining the span
188    let first = match input.next() {
189        None => return Ok((String::new(), Span::call_site())),
190        Some(term) => term,
191    };
192
193    let (mut lit, mut span) = to_string_span(first)?;
194
195    loop {
196        let peek = match input.peek() {
197            Some(tt) => tt,
198            None => break,
199        };
200
201        if !joined(span, peek.span()) {
202            break;
203        }
204
205        let (new_lit, new_span) = to_string_span(input.next()
206            .expect("internal error: peeked token but could not parse token. please file a bug report"))?;
207
208        lit = [lit, new_lit].join("");
209        span = span.join(new_span)
210            .ok_or(Error::new(span, "internal error: could not join spans"))?;
211    }
212
213    Ok((lit, span))
214}
215
216fn to_string_span(tt: TokenTree) -> syn::Result<(String, Span)> {
217    match tt {
218        TokenTree::Group(g) => {
219            let mut tokens = g.stream().into_iter().peekable();
220            // TODO: refactor. This currently gets confused.
221            // cmd!(echo {a | foo }) will run `echo {a}`
222            let (lit, _) = parse_literal(&mut tokens)?;
223            let lit = match g.delimiter() {
224                proc_macro2::Delimiter::None => lit,
225                proc_macro2::Delimiter::Parenthesis => format!("({})", lit),
226                proc_macro2::Delimiter::Brace => format!("{{{}}}", lit),
227                proc_macro2::Delimiter::Bracket => format!("<{}>", lit),
228            };
229            Ok((lit, g.span()))
230        },
231        TokenTree::Literal(l) => Ok((l.to_string(), l.span())),
232        TokenTree::Ident(i) => Ok((i.to_string(), i.span())),
233        TokenTree::Punct(p) => Ok((p.to_string(), p.span())),
234    }
235}
236
237fn parse_ident(i: Ident) -> syn::Result<Expr> {
238    use syn::{PathSegment, PathArguments, ExprPath, Path, token::Colon2, punctuated::Punctuated};
239    let mut p: Punctuated::<PathSegment, Colon2> = Punctuated::new();
240
241    let span = i.span();
242
243    p.push_value(PathSegment{
244        ident: i,
245        arguments: PathArguments::None,
246    });
247
248    Ok(Expr::Expr(
249        syn::Expr::Path(
250            ExprPath {
251                attrs: vec![],
252                qself: None,
253                path: Path {
254                    leading_colon: None,
255                    segments: p,
256                }
257            }
258        ),
259        span,
260    ))
261}
262
263fn parse_expr(g: Group) -> syn::Result<Expr> {
264    let expr = syn::parse2::<syn::Expr>(g.stream())
265        .map_err(|err| Error::new(err.span(), err.to_string()))?;
266    Ok(Expr::Expr(expr, g.span()))
267}
268
/// Why one of the token helpers (`punct`, `ident`, `group`, `next_punct_is`) failed.
#[derive(Debug)]
enum NextError {
    /// The token stream was exhausted.
    EOF,
    /// A token was available but was not the requested kind (or character).
    NotFound,
}
274
275fn to_span(tt: &TokenTree) -> Span {
276    match tt {
277        TokenTree::Literal(l) => l.span(),
278        TokenTree::Group(g) => g.span(),
279        TokenTree::Punct(p) => p.span(),
280        TokenTree::Ident(i) => i.span(),
281    }
282}
283
284fn punct<I>(tokens: &mut I) -> Result<Punct, NextError>
285where I: Iterator<Item=TokenTree>
286{
287    match tokens.next().ok_or(NextError::EOF)? {
288        TokenTree::Punct(p) => Ok(p),
289        _ => Err(NextError::NotFound)
290    }
291}
292
293fn next_punct_is<I>(tokens: &mut Peekable<I>, is: char) -> Result<Span, NextError>
294where I: Iterator<Item=TokenTree>
295{
296    match tokens.peek().ok_or(NextError::EOF)? {
297        TokenTree::Punct(p) => (p.as_char() == is).as_result(p.span(), NextError::NotFound),
298        _ => Err(NextError::NotFound),
299    }
300}
301
302fn ident<I>(tokens: &mut I) -> Result<Ident, NextError>
303where I: Iterator<Item=TokenTree>
304{
305    match tokens.next().ok_or(NextError::EOF)? {
306        TokenTree::Ident(i) => Ok(i),
307        _ => Err(NextError::NotFound),
308    }
309}
310
311fn group<I>(tokens: &mut I) -> Result<Group, NextError>
312where I: Iterator<Item=TokenTree>
313{
314    match tokens.next().ok_or(NextError::EOF)? {
315        TokenTree::Group(g) => Ok(g),
316        _ => Err(NextError::NotFound),
317    }
318}
319
320fn span_remaining<I>(tokens: &mut I) -> Span 
321where I: Iterator<Item=TokenTree>
322{
323    let mut spans = tokens
324        .map(|x| x.span());
325
326	let mut result = match spans.next() {
327        Some(span) => span,
328        None => Span::call_site(),
329    };
330    
331	for span in spans {
332		result = match result.join(span) {
333			None => return Span::from(result),
334			Some(span) => span,
335		}
336    }
337    result
338}
339    
340fn joined(lhs: Span, rhs: Span) -> bool {
341    let a = lhs.end();
342    let b = rhs.start();
343    a == b
344}
345
346impl ToTokens for Cmd {
347    fn to_tokens(&self, tokens: &mut TokenStream) {
348        if self.commands.len() == 0 {
349            tokens.extend(quote!{
350                Err(
351                    std::io::Error::new(
352                        std::io::ErrorKind::Other,
353                        "no command input",
354                    )
355                )
356            });
357            return
358        }
359
360        let mut prev: Option<Ident> = None;
361        let mut stdin = None;
362        let cmds: Vec<TokenStream> = self.commands.iter().zip(0..).map(|(cmd, i)| {
363            
364            let ident = Ident::new(&format!("x{}", i), cmd.span);
365
366            let command = cmd.to_tokens(ident.clone(), stdin.clone(), i == self.commands.len() - 1);
367
368            prev = Some(ident);
369            stdin = prev.clone().map(|p| quote!{
370                #p.stdout.unwrap()
371            });
372
373            command
374        }).collect();
375
376        let prev = prev.unwrap();
377
378        tokens.extend(quote!{
379            || -> ::cmd_macro::Result<std::process::Output> {
380                #(
381                    #cmds;
382                )*
383                Ok(#prev)
384            }()
385        })
386    }
387}
388
389impl Command {
390    fn to_tokens(&self, ident: Ident, stdin: Option<TokenStream>, last: bool) -> TokenStream {
391        let mut terms = (*self.terms).into_iter();
392        let first = match terms.next() {
393            None => return quote!{},
394            Some(term) => term,
395        };
396
397        let command = terms.fold(
398            quote!{ std::process::Command::new(#first) }, // Create the command
399            |command, arg| quote!{ #command.arg(#arg) }, // and add all the args
400        );
401
402        // If a previous command exists, add it's stdin to the stdin
403        // TODO: expand on this in future to support stderr => stdin
404        let command = match stdin {
405            Some(stdin) => quote!{ #command.stdin(#stdin) },
406            None => command,
407        };
408
409        // If this is the last command, get the output
410        // Otherwise, create a pipe for the stdout and spawn the command
411        // TODO: maybe in future always use pipes and return pipes for the user to deal with
412        if last {
413            quote!{ let #ident = #command.output()? }
414        } else {
415            quote!{ let #ident = #command.stdout(std::process::Stdio::piped()).spawn()? }
416        }
417    }
418}
419
#[cfg(test)]
mod tests {
    use crate::{Cmd, Command, parse_cmd, Expr};
    use proc_macro2::{Span, TokenStream};

    fn literal(s: &str) -> Expr {
        Expr::Literal(s.to_string(), Span::call_site())
    }

    #[test]
    fn test_cmd() {
        // Lex from a string rather than using `quote!`. Tokens created by `quote!`
        // all carry `Span::call_site()`, which gives `joined` no real positions to
        // compare. Lexing a string records each token's line and column, so
        // whitespace between tokens is visible to the parser.
        let stream: TokenStream = "ls | grep Cargo".parse().unwrap();
        println!("{:?}", stream);

        let got = parse_cmd(stream).unwrap();
        let expected = Cmd {
            commands: vec![
                Command {
                    span: Span::call_site(),
                    terms: vec![
                        literal("ls")
                    ]
                },
                Command {
                    span: Span::call_site(),
                    terms: vec![
                        literal("grep"),
                        literal("Cargo")
                    ]
                }
            ],
        };

        assert_eq!(got, expected);
    }
}