1use crate::{ParsedRequest, error::*};
2use base64::{Engine, engine::general_purpose::STANDARD};
3use http::{
4 HeaderValue, Method,
5 header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, HeaderName},
6};
7use minijinja::Environment;
8use pest::Parser as _;
9use pest_derive::Parser;
10use serde::Serialize;
11use snafu::ResultExt;
12use std::str::FromStr;
13use std::sync::OnceLock;
14
/// pest parser for curl command lines.
///
/// The grammar lives in `src/curl.pest`; `pest_derive` generates from it the
/// `Rule` enum that `parse_input` matches on (`Rule::input`, `Rule::header`, ...).
#[derive(Debug, Parser)]
#[grammar = "src/curl.pest"]
pub struct CurlParser;
18
19fn parse_input(input: &str) -> Result<ParsedRequest> {
20 let pairs = CurlParser::parse(Rule::input, input).context(ParseRuleSnafu)?;
21 let mut parsed = ParsedRequest::default();
22 for pair in pairs {
23 match pair.as_rule() {
24 Rule::method => {
25 let method = pair.as_str().parse().context(ParseMethodSnafu)?;
26 parsed.method = method;
27 }
28 Rule::url => {
29 let url = pair.into_inner().as_str();
30
31 #[cfg(feature = "uri")]
33 let url = if url.contains("://") {
34 url.parse().context(ParseUrlSnafu)?
35 } else {
36 let mut full_url = String::with_capacity(7 + url.len()); full_url.push_str("http://");
39 full_url.push_str(url);
40 full_url.parse().context(ParseUrlSnafu)?
41 };
42 #[cfg(not(feature = "uri"))]
43 let url = if url.contains("://") {
44 url.to_string()
45 } else {
46 let mut full_url = String::with_capacity(8 + url.len()); full_url.push_str("http://");
49 full_url.push_str(url);
50 full_url.push('/');
51 full_url
52 };
53
54 parsed.url = url;
55 }
56 Rule::location => {
57 let s = pair
58 .into_inner()
59 .next()
60 .expect("location string must be present")
61 .as_str();
62 #[cfg(feature = "uri")]
63 let location = s.parse().context(ParseUrlSnafu)?;
64 #[cfg(not(feature = "uri"))]
65 let location = s.to_string();
66 parsed.url = location;
67 }
68 Rule::header => {
69 let s = pair
70 .into_inner()
71 .next()
72 .expect("header string must be present")
73 .as_str();
74
75 if let Some((name, value)) = s.split_once(':') {
77 let header_value = unescape_string(value.trim());
78 parsed.headers.insert(
79 HeaderName::from_str(name.trim()).context(ParseHeaderNameSnafu)?,
80 HeaderValue::from_str(&header_value).context(ParseHeaderValueSnafu)?,
81 );
82 } else {
83 let mut kv = s.splitn(2, ':');
85 let name = kv.next().expect("key must present").trim();
86 let value = kv.next().expect("value must present").trim();
87 let header_value = unescape_string(value);
88 parsed.headers.insert(
89 HeaderName::from_str(name).context(ParseHeaderNameSnafu)?,
90 HeaderValue::from_str(&header_value).context(ParseHeaderValueSnafu)?,
91 );
92 }
93 }
94 Rule::auth => {
95 let s = pair
96 .into_inner()
97 .next()
98 .expect("header string must be present")
99 .as_str();
100 let encoded = STANDARD.encode(s.as_bytes());
101 let mut basic_auth = String::with_capacity(6 + encoded.len()); basic_auth.push_str("Basic ");
104 basic_auth.push_str(&encoded);
105 parsed.headers.insert(
106 AUTHORIZATION,
107 basic_auth.parse().context(ParseHeaderValueSnafu)?,
108 );
109 }
110 Rule::body => {
111 let s = pair.as_str().trim();
112 let s = remove_quote(s);
113 parsed.body.push(s.into());
114 }
115 Rule::ssl_verify_option => {
116 parsed.insecure = true;
117 }
118 Rule::EOI => break,
119 _ => unreachable!("Unexpected rule: {:?}", pair.as_rule()),
120 }
121 }
122
123 if parsed.headers.get(CONTENT_TYPE).is_none() && !parsed.body.is_empty() {
124 parsed.headers.insert(
125 CONTENT_TYPE,
126 HeaderValue::from_static("application/x-www-form-urlencoded"),
127 );
128 }
129 if parsed.headers.get(ACCEPT).is_none() {
130 parsed
131 .headers
132 .insert(ACCEPT, HeaderValue::from_static("*/*"));
133 }
134 if !parsed.body.is_empty() && parsed.method == Method::GET {
135 parsed.method = Method::POST
136 }
137 Ok(parsed)
138}
139
140static TEMPLATE_ENV: OnceLock<Environment<'static>> = OnceLock::new();
142
143fn get_template_env() -> &'static Environment<'static> {
144 TEMPLATE_ENV.get_or_init(Environment::new)
145}
146
147impl ParsedRequest {
148 pub fn load(input: &str, context: impl Serialize) -> Result<Self> {
149 let env = get_template_env();
150 let input = env.render_str(input, context).context(RenderSnafu)?;
151 parse_input(&input)
152 }
153
154 pub fn body(&self) -> Option<String> {
155 if self.body.is_empty() {
156 return None;
157 }
158
159 match self.headers.get(CONTENT_TYPE) {
160 Some(content_type) if content_type == "application/x-www-form-urlencoded" => {
161 Some(self.form_urlencoded())
162 }
163 Some(content_type) if content_type == "application/json" => self.body.last().cloned(),
164 v => unimplemented!("Unsupported content type: {:?}", v),
165 }
166 }
167
168 fn form_urlencoded(&self) -> String {
169 let mut encoded = form_urlencoded::Serializer::new(String::new());
170 for item in &self.body {
171 if let Some((key, value)) = item.split_once('=') {
173 encoded.append_pair(remove_quote(key), remove_quote(value));
174 } else {
175 let mut kv = item.splitn(2, '=');
177 let key = kv.next().expect("key must present");
178 let value = kv.next().expect("value must present");
179 encoded.append_pair(remove_quote(key), remove_quote(value));
180 }
181 }
182 encoded.finish()
183 }
184}
185
186impl FromStr for ParsedRequest {
187 type Err = Error;
188
189 fn from_str(s: &str) -> Result<Self> {
190 parse_input(s)
191 }
192}
193
#[cfg(feature = "reqwest")]
impl TryFrom<&ParsedRequest> for reqwest::RequestBuilder {
    type Error = reqwest::Error;

    /// Turns a parsed curl command into a ready-to-send `reqwest` request.
    ///
    /// A dedicated client is built for the request, with certificate checks
    /// disabled when the command used `-k`. The returned builder carries the
    /// parsed method, URL and headers, plus the encoded body when there is one.
    ///
    /// # Errors
    /// Fails if the `reqwest::Client` cannot be built.
    fn try_from(parsed: &ParsedRequest) -> Result<Self, Self::Error> {
        let payload = parsed.body();

        let client = reqwest::Client::builder()
            .danger_accept_invalid_certs(parsed.insecure)
            .build()?;

        let builder = client
            .request(parsed.method.clone(), parsed.url.to_string())
            .headers(parsed.headers.clone());

        Ok(match payload {
            Some(payload) => builder.body(payload),
            None => builder,
        })
    }
}
217
/// Strips one pair of matching surrounding quotes (`'...'` or `"..."`).
///
/// The input is returned unchanged when it is not wrapped in the same quote
/// character on both ends. A lone quote character counts as unwrapped.
fn remove_quote(s: &str) -> &str {
    ['\'', '"']
        .into_iter()
        .find_map(|quote| s.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(s)
}
232
/// Resolves backslash escapes in a header value.
///
/// `\"`, `\\` and `\/` become the escaped character. `\n`, `\r` and `\t`
/// become newline, carriage return and tab. Any other escape, and a trailing
/// lone backslash, is kept literally.
fn unescape_string(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut rest = s.chars();

    while let Some(c) = rest.next() {
        if c != '\\' {
            out.push(c);
            continue;
        }
        match rest.next() {
            Some(escaped @ ('"' | '\\' | '/')) => out.push(escaped),
            Some('n') => out.push('\n'),
            Some('r') => out.push('\r'),
            Some('t') => out.push('\t'),
            Some(other) => {
                out.push('\\');
                out.push(other);
            }
            None => out.push('\\'),
        }
    }

    out
}
261
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::Result;
    use http::{Method, header::ACCEPT};
    use serde_json::json;

    // Multi-line command with `-X PATCH`, a `-d` body, and a templated bearer
    // token rendered through `ParsedRequest::load`. The URL comes last.
    #[test]
    fn parse_curl_1_should_work() -> Result<()> {
        let input = r#"curl \
        -X PATCH \
        -d '{"visibility":"private"}' \
        -H "Accept: application/vnd.github+json" \
        -H "Authorization: Bearer {{ token }}"\
        -H "X-GitHub-Api-Version: 2022-11-28" \
        https://api.github.com/user/email/visibility "#;
        let parsed = ParsedRequest::load(input, json!({ "token": "abcd1234" }))?;
        assert_eq!(parsed.method, Method::PATCH);
        assert_eq!(
            parsed.url.to_string(),
            "https://api.github.com/user/email/visibility"
        );
        assert_eq!(
            parsed.headers.get(ACCEPT),
            Some(&HeaderValue::from_static("application/vnd.github+json"))
        );
        assert_eq!(parsed.body, vec!["{\"visibility\":\"private\"}"]);

        Ok(())
    }

    // URL given via `-L`, with the body after it. The rendered bearer token
    // must reach the Authorization header.
    #[test]
    fn parse_curl_2_should_work() -> Result<()> {
        let input = r#"curl \
        -X POST \
        -H "Accept: application/vnd.github+json" \
        -H "Authorization: Bearer {{ token }}"\
        -H "X-GitHub-Api-Version: 2022-11-28" \
        -L "https://api.github.com/user/emails" \
        -d '{"emails":["octocat@github.com","mona@github.com","octocat@octocat.org"]}'"#;
        let parsed = ParsedRequest::load(input, json!({ "token": "abcd1234" }))?;
        assert_eq!(parsed.method, Method::POST);
        assert_eq!(parsed.url.to_string(), "https://api.github.com/user/emails");
        assert_eq!(
            parsed.headers.get(AUTHORIZATION),
            Some(&HeaderValue::from_static("Bearer abcd1234"))
        );
        assert_eq!(
            parsed.body,
            vec![r#"{"emails":["octocat@github.com","mona@github.com","octocat@octocat.org"]}"#]
        );
        Ok(())
    }

    // `-u user:pass` becomes a Basic Authorization header. With the `reqwest`
    // feature enabled, the request is also sent over the network to httpbin.org.
    #[tokio::test]
    async fn parse_curl_3_should_work() -> Result<()> {
        let input = r#"curl https://httpbin.org/basic-auth/testuser/testpass \
        -u {{ username }}:{{ password }} \
        -H "Accept: application/json""#;

        let parsed = ParsedRequest::load(
            input,
            json!({ "username": "testuser", "password": "testpass" }),
        )?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(
            parsed.url.to_string(),
            "https://httpbin.org/basic-auth/testuser/testpass"
        );
        assert_eq!(
            parsed.headers.get(AUTHORIZATION),
            Some(&HeaderValue::from_str("Basic dGVzdHVzZXI6dGVzdHBhc3M=")?)
        );

        #[cfg(feature = "reqwest")]
        {
            let req = reqwest::RequestBuilder::try_from(&parsed)?;
            let res = req.send().await?;
            assert_eq!(res.status(), 200);
            let _body = res.text().await?;
        }
        Ok(())
    }

    // Double-quoted URL with a scheme. Performs a network call when the
    // `reqwest` feature is enabled.
    #[tokio::test]
    async fn parse_curl_4_should_work() -> Result<()> {
        let input = r#"curl "https://ifconfig.me/""#;

        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "https://ifconfig.me/");

        #[cfg(feature = "reqwest")]
        {
            let req = reqwest::RequestBuilder::try_from(&parsed)?;
            let res = req.send().await?;
            assert_eq!(res.status(), 200);
            let _body = res.text().await?;
        }
        Ok(())
    }

    // Scheme-less URL: parsing defaults it to `http://` with a root path.
    // Performs a network call when the `reqwest` feature is enabled.
    #[tokio::test]
    async fn parse_curl_5_should_work() -> Result<()> {
        let input = r#"curl 'ifconfig.me'"#;

        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "http://ifconfig.me/");

        #[cfg(feature = "reqwest")]
        {
            let req = reqwest::RequestBuilder::try_from(&parsed)?;
            let res = req.send().await?;
            assert_eq!(res.status(), 200);
            let _body = res.text().await?;
        }
        Ok(())
    }

    // A leading `#` comment line must be accepted, and `-k` sets `insecure`.
    #[tokio::test]
    async fn parse_curl_with_insecure_should_work() -> Result<(), Box<dyn std::error::Error>> {
        let input = r#"#this is good
        curl -k 'https://example.com/'"#;

        let parsed: ParsedRequest = input.parse()?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(parsed.url.to_string(), "https://example.com/");
        assert!(parsed.insecure);
        Ok(())
    }

    // Long options plus a body whose JSON contains `-`/`--` sequences, which
    // must not be mistaken for options. A body also upgrades GET to POST.
    #[tokio::test]
    async fn parse_curl_with_body_should_work() -> Result<()> {
        let input = r#"curl --location https://example.com --header 'Content-Type: application/json' -d '{"-name":"--John"," --age":30}'"#;
        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::POST);
        assert_eq!(parsed.body, vec!["{\"-name\":\"--John\",\" --age\":30}"]);
        Ok(())
    }

    // Backslash-escaped quotes inside a header value are unescaped.
    #[test]
    fn parse_curl_with_escaped_json_in_header() -> Result<()> {
        let input = r#"curl https://api.github.com/repos/owner/repo \
        -H "X-GitHub-Api-Version: 2022-11-28" \
        -H "X-Custom-Metadata: {\"version\":\"1.0.0\",\"client\":\"rust\"}" \
        -H "Accept: application/json""#;
        let parsed = ParsedRequest::from_str(input)?;
        assert_eq!(parsed.method, Method::GET);
        assert_eq!(
            parsed.url.to_string(),
            "https://api.github.com/repos/owner/repo"
        );

        let header_value = parsed.headers.get("X-Custom-Metadata");
        assert!(header_value.is_some());
        let header_str = header_value.unwrap().to_str().unwrap();
        assert_eq!(header_str, r#"{"version":"1.0.0","client":"rust"}"#);

        Ok(())
    }
}