curl_parser/
parser.rs

1use crate::{ParsedRequest, error::*};
2use base64::{Engine, engine::general_purpose::STANDARD};
3use http::{
4    HeaderValue, Method,
5    header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, HeaderName},
6};
7use minijinja::Environment;
8use pest::Parser as _;
9use pest_derive::Parser;
10use serde::Serialize;
11use snafu::ResultExt;
12use std::str::FromStr;
13use std::sync::OnceLock;
14
/// Pest-generated parser for curl command lines.
///
/// The grammar is read from `src/curl.pest`; the derive also generates the
/// `Rule` enum that the rest of this module matches on.
#[derive(Debug, Parser)]
#[grammar = "src/curl.pest"]
pub struct CurlParser;
18
19fn parse_input(input: &str) -> Result<ParsedRequest> {
20    let pairs = CurlParser::parse(Rule::input, input).context(ParseRuleSnafu)?;
21    let mut parsed = ParsedRequest::default();
22    for pair in pairs {
23        match pair.as_rule() {
24            Rule::method => {
25                let method = pair.as_str().parse().context(ParseMethodSnafu)?;
26                parsed.method = method;
27            }
28            Rule::url => {
29                let url = pair.into_inner().as_str();
30
31                // if empty scheme set curl defaults to HTTP
32                #[cfg(feature = "uri")]
33                let url = if url.contains("://") {
34                    url.parse().context(ParseUrlSnafu)?
35                } else {
36                    // Pre-allocate with known prefix length
37                    let mut full_url = String::with_capacity(7 + url.len()); // "http://" + url
38                    full_url.push_str("http://");
39                    full_url.push_str(url);
40                    full_url.parse().context(ParseUrlSnafu)?
41                };
42                #[cfg(not(feature = "uri"))]
43                let url = if url.contains("://") {
44                    url.to_string()
45                } else {
46                    // Pre-allocate with known prefix length
47                    let mut full_url = String::with_capacity(8 + url.len()); // "http://" + url + "/"
48                    full_url.push_str("http://");
49                    full_url.push_str(url);
50                    full_url.push('/');
51                    full_url
52                };
53
54                parsed.url = url;
55            }
56            Rule::location => {
57                let s = pair
58                    .into_inner()
59                    .next()
60                    .expect("location string must be present")
61                    .as_str();
62                #[cfg(feature = "uri")]
63                let location = s.parse().context(ParseUrlSnafu)?;
64                #[cfg(not(feature = "uri"))]
65                let location = s.to_string();
66                parsed.url = location;
67            }
68            Rule::header => {
69                let s = pair
70                    .into_inner()
71                    .next()
72                    .expect("header string must be present")
73                    .as_str();
74
75                // Use split_once for better performance
76                if let Some((name, value)) = s.split_once(':') {
77                    let header_value = unescape_string(value.trim());
78                    parsed.headers.insert(
79                        HeaderName::from_str(name.trim()).context(ParseHeaderNameSnafu)?,
80                        HeaderValue::from_str(&header_value).context(ParseHeaderValueSnafu)?,
81                    );
82                } else {
83                    // Fallback for malformed headers (should be rare)
84                    let mut kv = s.splitn(2, ':');
85                    let name = kv.next().expect("key must present").trim();
86                    let value = kv.next().expect("value must present").trim();
87                    let header_value = unescape_string(value);
88                    parsed.headers.insert(
89                        HeaderName::from_str(name).context(ParseHeaderNameSnafu)?,
90                        HeaderValue::from_str(&header_value).context(ParseHeaderValueSnafu)?,
91                    );
92                }
93            }
94            Rule::auth => {
95                let s = pair
96                    .into_inner()
97                    .next()
98                    .expect("header string must be present")
99                    .as_str();
100                let encoded = STANDARD.encode(s.as_bytes());
101                // Pre-allocate with known prefix length
102                let mut basic_auth = String::with_capacity(6 + encoded.len()); // "Basic " + encoded
103                basic_auth.push_str("Basic ");
104                basic_auth.push_str(&encoded);
105                parsed.headers.insert(
106                    AUTHORIZATION,
107                    basic_auth.parse().context(ParseHeaderValueSnafu)?,
108                );
109            }
110            Rule::body => {
111                let s = pair.as_str().trim();
112                let s = remove_quote(s);
113                parsed.body.push(s.into());
114            }
115            Rule::ssl_verify_option => {
116                parsed.insecure = true;
117            }
118            Rule::EOI => break,
119            _ => unreachable!("Unexpected rule: {:?}", pair.as_rule()),
120        }
121    }
122
123    if parsed.headers.get(CONTENT_TYPE).is_none() && !parsed.body.is_empty() {
124        parsed.headers.insert(
125            CONTENT_TYPE,
126            HeaderValue::from_static("application/x-www-form-urlencoded"),
127        );
128    }
129    if parsed.headers.get(ACCEPT).is_none() {
130        parsed
131            .headers
132            .insert(ACCEPT, HeaderValue::from_static("*/*"));
133    }
134    if !parsed.body.is_empty() && parsed.method == Method::GET {
135        parsed.method = Method::POST
136    }
137    Ok(parsed)
138}
139
// Process-wide minijinja environment, created once and then reused so that a
// new `Environment` is not built for every `ParsedRequest::load` call.
static TEMPLATE_ENV: OnceLock<Environment<'static>> = OnceLock::new();

/// Returns the shared template environment, creating it with
/// `Environment::new` the first time it is requested.
fn get_template_env() -> &'static Environment<'static> {
    TEMPLATE_ENV.get_or_init(Environment::new)
}
146
147impl ParsedRequest {
148    pub fn load(input: &str, context: impl Serialize) -> Result<Self> {
149        let env = get_template_env();
150        let input = env.render_str(input, context).context(RenderSnafu)?;
151        parse_input(&input)
152    }
153
154    pub fn body(&self) -> Option<String> {
155        if self.body.is_empty() {
156            return None;
157        }
158
159        match self.headers.get(CONTENT_TYPE) {
160            Some(content_type) if content_type == "application/x-www-form-urlencoded" => {
161                Some(self.form_urlencoded())
162            }
163            Some(content_type) if content_type == "application/json" => self.body.last().cloned(),
164            v => unimplemented!("Unsupported content type: {:?}", v),
165        }
166    }
167
168    fn form_urlencoded(&self) -> String {
169        let mut encoded = form_urlencoded::Serializer::new(String::new());
170        for item in &self.body {
171            // Use split_once for better performance
172            if let Some((key, value)) = item.split_once('=') {
173                encoded.append_pair(remove_quote(key), remove_quote(value));
174            } else {
175                // Fallback for malformed form data (should be rare)
176                let mut kv = item.splitn(2, '=');
177                let key = kv.next().expect("key must present");
178                let value = kv.next().expect("value must present");
179                encoded.append_pair(remove_quote(key), remove_quote(value));
180            }
181        }
182        encoded.finish()
183    }
184}
185
186impl FromStr for ParsedRequest {
187    type Err = Error;
188
189    fn from_str(s: &str) -> Result<Self> {
190        parse_input(s)
191    }
192}
193
/// Builds a ready-to-send `reqwest` request from a parsed curl command.
#[cfg(feature = "reqwest")]
impl TryFrom<&ParsedRequest> for reqwest::RequestBuilder {
    type Error = reqwest::Error;

    /// Creates a fresh client (certificate checks are disabled when
    /// `req.insecure` is set), then copies method, URL, headers and, when
    /// present, the body onto a new request builder.
    fn try_from(req: &ParsedRequest) -> Result<Self, Self::Error> {
        // Resolve the payload before the client is built, as callers may rely
        // on that ordering.
        let payload = req.body();

        let builder = reqwest::Client::builder()
            .danger_accept_invalid_certs(req.insecure)
            .build()?
            .request(req.method.clone(), req.url.to_string())
            .headers(req.headers.clone());

        Ok(match payload {
            Some(data) => builder.body(data),
            None => builder,
        })
    }
}
217
/// Strips one pair of matching surrounding quotes (`'…'` or `"…"`) from `s`.
///
/// The input is returned unchanged when it is not wrapped in a matching pair,
/// including when it consists of a single quote character only.
fn remove_quote(s: &str) -> &str {
    for quote in ['\'', '"'] {
        let unwrapped = s
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote));
        if let Some(inner) = unwrapped {
            return inner;
        }
    }
    s
}
232
/// Resolves backslash escapes in `s`.
///
/// `\"`, `\\` and `\/` yield the escaped character itself; `\n`, `\r` and `\t`
/// yield the matching control character. Any other escape, as well as a lone
/// trailing backslash, is copied through unchanged.
fn unescape_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s.chars();

    while let Some(current) = rest.next() {
        if current != '\\' {
            out.push(current);
            continue;
        }

        match rest.next() {
            Some(literal @ ('"' | '\\' | '/')) => out.push(literal),
            Some('n') => out.push('\n'),
            Some('r') => out.push('\r'),
            Some('t') => out.push('\t'),
            // Unknown escape: keep the backslash and the character after it.
            Some(other) => {
                out.push('\\');
                out.push(other);
            }
            // Backslash at the very end of the input.
            None => out.push('\\'),
        }
    }

    out
}
261
// Tests for the curl parser. Some of them additionally send the parsed request
// over the network when the `reqwest` feature is enabled.
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use http::{Method, header::ACCEPT};
    use serde_json::json;

    // Template rendering plus an explicit method, a JSON body and several headers.
    #[test]
    fn parse_curl_1_should_work() -> Result<()> {
        let input = r#"curl \
          -X PATCH \
          -d '{"visibility":"private"}' \
          -H "Accept: application/vnd.github+json" \
          -H "Authorization: Bearer {{ token }}"\
          -H "X-GitHub-Api-Version: 2022-11-28" \
          https://api.github.com/user/email/visibility "#;
        let parsed = ParsedRequest::load(input, json!({ "token": "abcd1234" }))?;
        assert_eq!(parsed.method, Method::PATCH);
        assert_eq!(
            parsed.url.to_string(),
            "https://api.github.com/user/email/visibility"
        );
        assert_eq!(
            parsed.headers.get(ACCEPT),
            Some(&HeaderValue::from_static("application/vnd.github+json"))
        );
        assert_eq!(parsed.body, vec!["{\"visibility\":\"private\"}"]);

        Ok(())
    }

    // URL supplied through `-L`, with the `{{ token }}` placeholder rendered
    // into the Authorization header.
    #[test]
    fn parse_curl_2_should_work() -> Result<()> {
        let input = r#"curl \
        -X POST \
        -H "Accept: application/vnd.github+json" \
        -H "Authorization: Bearer {{ token }}"\
        -H "X-GitHub-Api-Version: 2022-11-28" \
        -L "https://api.github.com/user/emails" \
        -d '{"emails":["octocat@github.com","mona@github.com","octocat@octocat.org"]}'"#;
        let parsed = ParsedRequest::load(input, json!({ "token": "abcd1234" }))?;
        assert_eq!(parsed.method, Method::POST);
        assert_eq!(parsed.url.to_string(), "https://api.github.com/user/emails");
        assert_eq!(
            parsed.headers.get(AUTHORIZATION),
            Some(&HeaderValue::from_static("Bearer abcd1234"))
        );
        assert_eq!(
            parsed.body,
            vec![r#"{"emails":["octocat@github.com","mona@github.com","octocat@octocat.org"]}"#]
        );
        Ok(())
    }

    // `-u user:pass` must become a Basic Authorization header. With the
    // `reqwest` feature the request is also sent to httpbin.org.
    #[tokio::test]
    async fn parse_curl_3_should_work() -> Result<()> {
        let input = r#"curl https://httpbin.org/basic-auth/testuser/testpass \
        -u {{ username }}:{{ password }} \
        -H "Accept: application/json""#;

        let parsed = ParsedRequest::load(
            input,
            json!({ "username": "testuser", "password": "testpass" }),
        )?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(
            parsed.url.to_string(),
            "https://httpbin.org/basic-auth/testuser/testpass"
        );
        assert_eq!(
            parsed.headers.get(AUTHORIZATION),
            Some(&HeaderValue::from_str("Basic dGVzdHVzZXI6dGVzdHBhc3M=")?)
        );

        #[cfg(feature = "reqwest")]
        {
            let req = reqwest::RequestBuilder::try_from(&parsed)?;
            let res = req.send().await?;
            assert_eq!(res.status(), 200);
            let _body = res.text().await?;
        }
        Ok(())
    }

    // A bare double-quoted URL parsed through `FromStr` (no template rendering).
    // With the `reqwest` feature the request is also sent to ifconfig.me.
    #[tokio::test]
    async fn parse_curl_4_should_work() -> Result<()> {
        let input = r#"curl "https://ifconfig.me/""#;

        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "https://ifconfig.me/");

        #[cfg(feature = "reqwest")]
        {
            let req = reqwest::RequestBuilder::try_from(&parsed)?;
            let res = req.send().await?;
            assert_eq!(res.status(), 200);
            let _body = res.text().await?;
        }
        Ok(())
    }

    // A single-quoted URL without a scheme must default to `http://`.
    // With the `reqwest` feature the request is also sent to ifconfig.me.
    #[tokio::test]
    async fn parse_curl_5_should_work() -> Result<()> {
        let input = r#"curl 'ifconfig.me'"#;

        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "http://ifconfig.me/");

        #[cfg(feature = "reqwest")]
        {
            let req = reqwest::RequestBuilder::try_from(&parsed)?;
            let res = req.send().await?;
            assert_eq!(res.status(), 200);
            let _body = res.text().await?;
        }
        Ok(())
    }

    // `-k` must set the insecure flag; the input also starts with a `#` line
    // that the parser is expected to accept.
    #[tokio::test]
    async fn parse_curl_with_insecure_should_work() -> Result<(), Box<dyn std::error::Error>> {
        let input = r#"#this is good
        curl -k 'https://example.com/'"#;

        let parsed: ParsedRequest = input.parse()?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "https://example.com/");
        assert!(parsed.insecure);
        Ok(())
    }

    // A body that itself contains `-` and `--` must not be mistaken for
    // options, and the presence of a body must switch the method to POST.
    #[tokio::test]
    async fn parse_curl_with_body_should_work() -> Result<()> {
        let input = r#"curl --location https://example.com --header 'Content-Type: application/json' -d '{"-name":"--John"," --age":30}'"#;
        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::POST);
        assert_eq!(parsed.body, vec!["{\"-name\":\"--John\",\" --age\":30}"]);
        Ok(())
    }

    // Backslash-escaped quotes inside a double-quoted header value must be
    // unescaped in the stored header.
    #[test]
    fn parse_curl_with_escaped_json_in_header() -> Result<()> {
        let input = r#"curl https://api.github.com/repos/owner/repo \
            -H "X-GitHub-Api-Version: 2022-11-28" \
            -H "X-Custom-Metadata: {\"version\":\"1.0.0\",\"client\":\"rust\"}" \
            -H "Accept: application/json""#;
        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(
            parsed.url.to_string(),
            "https://api.github.com/repos/owner/repo"
        );

        // Verify the header with escaped JSON is parsed correctly
        let header_value = parsed.headers.get("X-Custom-Metadata");
        assert!(header_value.is_some());
        let header_str = header_value.unwrap().to_str().unwrap();
        assert_eq!(header_str, r#"{"version":"1.0.0","client":"rust"}"#);

        Ok(())
    }
}