curl_parser/
parser.rs

1use crate::{error::*, ParsedRequest};
2use base64::{engine::general_purpose::STANDARD, Engine};
3use http::{
4    header::{HeaderName, ACCEPT, AUTHORIZATION, CONTENT_TYPE},
5    HeaderValue, Method,
6};
7use minijinja::Environment;
8use pest::Parser as _;
9use pest_derive::Parser;
10use serde::Serialize;
11use snafu::ResultExt;
12use std::str::FromStr;
13
14#[derive(Debug, Parser)]
15#[grammar = "src/curl.pest"]
16pub struct CurlParser;
17
18fn parse_input(input: &str) -> Result<ParsedRequest> {
19    let pairs = CurlParser::parse(Rule::input, input).context(ParseRuleSnafu)?;
20    let mut parsed = ParsedRequest::default();
21    for pair in pairs {
22        match pair.as_rule() {
23            Rule::method => {
24                let method = pair.as_str().parse().context(ParseMethodSnafu)?;
25                parsed.method = method;
26            }
27            Rule::url => {
28                let url = pair.into_inner().as_str();
29
30                // if empty scheme set curl defaults to HTTP
31                #[cfg(feature = "uri")]
32                let url = if url.contains("://") {
33                    url.parse().context(ParseUrlSnafu)?
34                } else {
35                    format!("http://{url}").parse().context(ParseUrlSnafu)?
36                };
37                #[cfg(not(feature = "uri"))]
38                let url = if url.contains("://") {
39                    url.to_string()
40                } else {
41                    format!("http://{url}/")
42                };
43
44                parsed.url = url;
45            }
46            Rule::location => {
47                let s = pair
48                    .into_inner()
49                    .next()
50                    .expect("location string must be present")
51                    .as_str();
52                #[cfg(feature = "uri")]
53                let location = s.parse().context(ParseUrlSnafu)?;
54                #[cfg(not(feature = "uri"))]
55                let location = s.to_string();
56                parsed.url = location;
57            }
58            Rule::header => {
59                let s = pair
60                    .into_inner()
61                    .next()
62                    .expect("header string must be present")
63                    .as_str();
64                let mut kv = s.splitn(2, ':');
65                let name = kv.next().expect("key must present").trim();
66                let value = kv.next().expect("value must present").trim();
67                parsed.headers.insert(
68                    HeaderName::from_str(name).context(ParseHeaderNameSnafu)?,
69                    HeaderValue::from_str(value).context(ParseHeaderValueSnafu)?,
70                );
71            }
72            Rule::auth => {
73                let s = pair
74                    .into_inner()
75                    .next()
76                    .expect("header string must be present")
77                    .as_str();
78                let basic_auth = format!("Basic {}", STANDARD.encode(s.as_bytes()));
79                parsed.headers.insert(
80                    AUTHORIZATION,
81                    basic_auth.parse().context(ParseHeaderValueSnafu)?,
82                );
83            }
84            Rule::body => {
85                let s = pair.as_str().trim();
86                let s = remove_quote(s);
87                parsed.body.push(s.into());
88            }
89            Rule::ssl_verify_option => {
90                parsed.insecure = true;
91            }
92            Rule::EOI => break,
93            _ => unreachable!("Unexpected rule: {:?}", pair.as_rule()),
94        }
95    }
96
97    if parsed.headers.get(CONTENT_TYPE).is_none() && !parsed.body.is_empty() {
98        parsed.headers.insert(
99            CONTENT_TYPE,
100            HeaderValue::from_static("application/x-www-form-urlencoded"),
101        );
102    }
103    if parsed.headers.get(ACCEPT).is_none() {
104        parsed
105            .headers
106            .insert(ACCEPT, HeaderValue::from_static("*/*"));
107    }
108    if !parsed.body.is_empty() && parsed.method == Method::GET {
109        parsed.method = Method::POST
110    }
111    Ok(parsed)
112}
113
114impl ParsedRequest {
115    pub fn load(input: &str, context: impl Serialize) -> Result<Self> {
116        let env = Environment::new();
117        let input = env.render_str(input, context).context(RenderSnafu)?;
118        parse_input(&input)
119    }
120
121    pub fn body(&self) -> Option<String> {
122        if self.body.is_empty() {
123            return None;
124        }
125
126        match self.headers.get(CONTENT_TYPE) {
127            Some(content_type) if content_type == "application/x-www-form-urlencoded" => {
128                Some(self.form_urlencoded())
129            }
130            Some(content_type) if content_type == "application/json" => self.body.last().cloned(),
131            v => unimplemented!("Unsupported content type: {:?}", v),
132        }
133    }
134
135    fn form_urlencoded(&self) -> String {
136        let mut encoded = form_urlencoded::Serializer::new(String::new());
137        for item in &self.body {
138            let mut kv = item.splitn(2, '=');
139            let key = kv.next().expect("key must present");
140            let value = kv.next().expect("value must present");
141            encoded.append_pair(remove_quote(key), remove_quote(value));
142        }
143        encoded.finish()
144    }
145}
146
147impl FromStr for ParsedRequest {
148    type Err = Error;
149
150    fn from_str(s: &str) -> Result<Self> {
151        parse_input(s)
152    }
153}
154
#[cfg(feature = "reqwest")]
impl TryFrom<&ParsedRequest> for reqwest::RequestBuilder {
    type Error = reqwest::Error;

    /// Builds a dedicated `reqwest` client for this request and returns a request builder
    /// with the parsed method, URL, headers and (if any) body.
    ///
    /// The client is built inside each conversion so that `insecure` can control
    /// certificate validation. This also means connection pools are not shared between
    /// conversions.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest` error if the client cannot be built.
    fn try_from(parsed: &ParsedRequest) -> Result<Self, Self::Error> {
        let payload = parsed.body();

        let client = reqwest::Client::builder()
            .danger_accept_invalid_certs(parsed.insecure)
            .build()?;

        let builder = client
            .request(parsed.method.clone(), parsed.url.to_string())
            .headers(parsed.headers.clone());

        Ok(match payload {
            Some(data) => builder.body(data),
            None => builder,
        })
    }
}
178
/// Strips one pair of matching surrounding quotes (`'…'` or `"…"`) from `s`.
///
/// Returns `s` unchanged unless it starts and ends with the same quote character.
/// The string is handled on char boundaries, so empty input, a lone quote character and
/// non-ASCII first or last characters never cause a panic.
fn remove_quote(s: &str) -> &str {
    ['\'', '"']
        .iter()
        .find_map(|&quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
}
186
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use http::{header::ACCEPT, Method};
    use serde_json::json;

    // Templated bearer token, explicit `-X PATCH`, a JSON `-d` body, and the URL given
    // as the last argument.
    #[test]
    fn parse_curl_1_should_work() -> Result<()> {
        let input = r#"curl \
          -X PATCH \
          -d '{"visibility":"private"}' \
          -H "Accept: application/vnd.github+json" \
          -H "Authorization: Bearer {{ token }}"\
          -H "X-GitHub-Api-Version: 2022-11-28" \
          https://api.github.com/user/email/visibility "#;
        let parsed = ParsedRequest::load(input, json!({ "token": "abcd1234" }))?;
        assert_eq!(parsed.method, Method::PATCH);
        assert_eq!(
            parsed.url.to_string(),
            "https://api.github.com/user/email/visibility"
        );
        assert_eq!(
            parsed.headers.get(ACCEPT),
            Some(&HeaderValue::from_static("application/vnd.github+json"))
        );
        assert_eq!(parsed.body, vec!["{\"visibility\":\"private\"}"]);

        Ok(())
    }

    // URL supplied through `-L`, body given after the headers, and the rendered bearer
    // token checked in the Authorization header.
    #[test]
    fn parse_curl_2_should_work() -> Result<()> {
        let input = r#"curl \
        -X POST \
        -H "Accept: application/vnd.github+json" \
        -H "Authorization: Bearer {{ token }}"\
        -H "X-GitHub-Api-Version: 2022-11-28" \
        -L "https://api.github.com/user/emails" \
        -d '{"emails":["octocat@github.com","mona@github.com","octocat@octocat.org"]}'"#;
        let parsed = ParsedRequest::load(input, json!({ "token": "abcd1234" }))?;
        assert_eq!(parsed.method, Method::POST);
        assert_eq!(parsed.url.to_string(), "https://api.github.com/user/emails");
        assert_eq!(
            parsed.headers.get(AUTHORIZATION),
            Some(&HeaderValue::from_static("Bearer abcd1234"))
        );
        assert_eq!(
            parsed.body,
            vec![r#"{"emails":["octocat@github.com","mona@github.com","octocat@octocat.org"]}"#]
        );
        Ok(())
    }

    // `-u key:` becomes a `Basic` Authorization header carrying base64("key:").
    // With the `reqwest` feature this also sends a live request to api.stripe.com, so it
    // needs network access.
    // NOTE(review): hard-coded `sk_test_…` key; presumably Stripe's public documentation
    // test key, so confirm it is not a real secret.
    #[tokio::test]
    async fn parse_curl_3_should_work() -> Result<()> {
        let input = r#"curl https://api.stripe.com/v1/charges \
        -u {{ key }}: \
        -H "Stripe-Version: 2022-11-15""#;

        let parsed =
            ParsedRequest::load(input, json!({ "key": "sk_test_4eC39HqLyjWDarjtT1zdp7dc" }))?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "https://api.stripe.com/v1/charges");
        assert_eq!(
            parsed.headers.get(AUTHORIZATION),
            Some(&HeaderValue::from_static(
                "Basic c2tfdGVzdF80ZUMzOUhxTHlqV0Rhcmp0VDF6ZHA3ZGM6"
            ))
        );

        #[cfg(feature = "reqwest")]
        {
            let req = reqwest::RequestBuilder::try_from(&parsed)?;
            let res = req.send().await?;
            assert_eq!(res.status(), 200);
            let _body = res.text().await?;
        }
        Ok(())
    }

    // Double-quoted URL with an explicit scheme, parsed via `FromStr` (no templating).
    // With the `reqwest` feature this sends a live request to ifconfig.me.
    #[tokio::test]
    async fn parse_curl_4_should_work() -> Result<()> {
        let input = r#"curl "https://ifconfig.me/""#;

        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "https://ifconfig.me/");

        #[cfg(feature = "reqwest")]
        {
            let req = reqwest::RequestBuilder::try_from(&parsed)?;
            let res = req.send().await?;
            assert_eq!(res.status(), 200);
            let _body = res.text().await?;
        }
        Ok(())
    }

    // A scheme-less URL is given the default `http://` scheme and a `/` root path.
    // With the `reqwest` feature this sends a live request to ifconfig.me.
    #[tokio::test]
    async fn parse_curl_5_should_work() -> Result<()> {
        let input = r#"curl 'ifconfig.me'"#;

        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "http://ifconfig.me/");

        #[cfg(feature = "reqwest")]
        {
            let req = reqwest::RequestBuilder::try_from(&parsed)?;
            let res = req.send().await?;
            assert_eq!(res.status(), 200);
            let _body = res.text().await?;
        }
        Ok(())
    }

    // A leading `#` comment line is accepted, and `-k` sets `insecure`.
    #[tokio::test]
    async fn parse_curl_with_insecure_should_work() -> Result<(), Box<dyn std::error::Error>> {
        let input = r#"#this is good
        curl -k 'https://example.com/'"#;

        let parsed: ParsedRequest = input.parse()?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "https://example.com/");
        assert!(parsed.insecure);
        Ok(())
    }

    // Long-form options (`--location`, `--header`). Dashes inside the quoted JSON body must
    // not be mistaken for options, and data switches the implicit GET to POST.
    #[tokio::test]
    async fn parse_curl_with_body_should_work() -> Result<()> {
        let input = r#"curl --location https://example.com --header 'Content-Type: application/json' -d '{"-name":"--John"," --age":30}'"#;
        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::POST);
        assert_eq!(parsed.body, vec!["{\"-name\":\"--John\",\" --age\":30}"]);
        Ok(())
    }
}