// coman/cli/request.rs

//! CLI commands for making HTTP requests.
//!
//! This module provides the command-line interface for making HTTP requests,
//! including progress bars, colored output, and interactive prompts.

6use clap::{Args, Subcommand};
7use colored::{ColoredString, Colorize};
8use futures::stream::StreamExt;
9use indicatif::{ProgressBar, ProgressStyle};
10use infer;
11use reqwest::header::HeaderMap;
12use reqwest::multipart::{self, Part};
13use reqwest::{redirect::Policy, ClientBuilder, StatusCode};
14use serde_json::Value;
15use std::fmt;
16use std::io::{self, Write};
17use std::time::Duration;
18
// Arguments shared by every HTTP-method subcommand; flattened into each
// variant of `RequestCommands`.
//
// NOTE: plain `//` comments are used deliberately. clap's derive turns `///`
// doc comments on these fields into `--help` text, so adding them would change
// the CLI's output.
#[derive(Args, Clone, Debug)]
pub struct RequestData {
    // Positional target URL. Outside streaming mode, any `:?` placeholder in it
    // is filled in interactively before the request is sent.
    pub url: String,

    // `-H/--header KEY:VALUE`, validated and split by `RequestData::parse_header`.
    // `num_args = 1..` lets one `-H` take several pairs.
    // NOTE(review): because this option is greedy, a positional URL written after
    // the header values is likely consumed as another header; put the URL first
    // or separate with `--`. Confirm this is intended.
    #[clap(
        short = 'H',
        long = "header",
        value_parser = RequestData::parse_header,
        value_name = "KEY:VALUE",
        num_args = 1..,
        required = false
    )]
    pub headers: Vec<(String, String)>,

    // `-b/--body` request body (empty by default). It is only used when nothing
    // is piped on stdin; `:?` placeholders in it are prompted for.
    #[clap(short, long, default_value = "", required = false)]
    pub body: String,
}
36
37impl RequestData {
38    pub fn parse_header(s: &str) -> Result<(String, String), String> {
39        let parts: Vec<&str> = s.splitn(2, ':').collect();
40        if parts.len() != 2 {
41            return Err(format!("Invalid header format: '{}'. Use KEY:VALUE", s));
42        }
43        Ok((parts[0].trim().to_string(), parts[1].trim().to_string()))
44    }
45}
46
// One subcommand per supported HTTP method. Every variant carries the same
// flattened `RequestData` arguments; the variant alone selects the request
// method that is sent.
//
// NOTE: plain `//` comments are used deliberately. clap's `Subcommand` derive
// turns `///` doc comments on variants into subcommand help text, so adding
// them would change the CLI's `--help` output.
#[derive(Subcommand, Clone, Debug)]
pub enum RequestCommands {
    // HTTP GET.
    Get {
        #[clap(flatten)]
        data: RequestData,
    },
    // HTTP POST.
    Post {
        #[clap(flatten)]
        data: RequestData,
    },
    // HTTP PUT.
    Put {
        #[clap(flatten)]
        data: RequestData,
    },
    // HTTP DELETE.
    Delete {
        #[clap(flatten)]
        data: RequestData,
    },
    // HTTP PATCH.
    Patch {
        #[clap(flatten)]
        data: RequestData,
    },
}
70
71impl fmt::Display for RequestCommands {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        match self {
74            Self::Get { .. } => write!(f, "GET"),
75            Self::Post { .. } => write!(f, "POST"),
76            Self::Put { .. } => write!(f, "PUT"),
77            Self::Delete { .. } => write!(f, "DELETE"),
78            Self::Patch { .. } => write!(f, "PATCH"),
79        }
80    }
81}
82
83impl RequestCommands {
84    pub fn get_data(&self) -> &RequestData {
85        // assuming RequestData is the type of 'data'
86        match self {
87            Self::Get { data }
88            | Self::Post { data }
89            | Self::Put { data }
90            | Self::Delete { data }
91            | Self::Patch { data } => data,
92        }
93    }
94
95    pub fn print_request_method(&self, url: &str, status: StatusCode, elapsed: u128) {
96        println!(
97            "\n[{}] {} - {} ({} ms)\n",
98            self.to_string().bold().bright_yellow(),
99            url.to_string().bold().bright_white(),
100            Self::colorize_status(status),
101            elapsed
102        );
103    }
104
105    fn print_request_headers(headers: &[(String, String)]) {
106        println!("{}", "Request Headers:".to_string().bold().bright_blue());
107        for (key, value) in headers.iter() {
108            println!("  {}: {:?}", key.to_string().bright_white(), value);
109        }
110    }
111
112    fn print_request_body(body: &str) {
113        println!("{}", "Request Body:".to_string().bold().bright_blue());
114        println!("{}", body.italic());
115    }
116
117    async fn print_request_response(
118        response: reqwest::Response,
119        verbose: bool,
120        stream: bool,
121    ) -> Result<String, Box<dyn std::error::Error>> {
122        if verbose && !stream {
123            println!("{}", "Response Headers:".to_string().bold().bright_blue());
124            for (key, value) in response.headers().iter() {
125                println!("  {}: {:?}", key.to_string().bright_white(), value);
126            }
127            println!("\n{}", "Response Body:".to_string().bold().bright_blue());
128        }
129
130        if stream {
131            // Get the stream of bytes
132            let mut stream = response.bytes_stream();
133
134            // Process each chunk as it arrives
135            while let Some(chunk) = stream.next().await {
136                let chunk = chunk?;
137                std::io::stdout().write_all(&chunk)?;
138                std::io::stdout().flush()?;
139            }
140        } else {
141            let body = response.text().await?;
142            //Try parsing the body as JSON
143            if let Ok(json) = serde_json::from_str::<Value>(&body) {
144                let pretty = serde_json::to_string_pretty(&json)?;
145                println!("{}", pretty.green());
146            } else {
147                println!("{}", body.italic());
148            }
149        }
150
151        Ok("".to_string())
152    }
153
154    pub fn colorize_status(status: StatusCode) -> ColoredString {
155        match status.as_u16() {
156            200..=299 => status.to_string().bold().bright_green(),
157            300..=499 => status.to_string().bold().bright_yellow(),
158            500..=599 => status.to_string().bold().bright_red(),
159            _ => status.to_string().white(),
160        }
161    }
162
163    fn prompt_missing_header_data(mut headers: Vec<(String, String)>) -> Vec<(String, String)> {
164        for header in headers.iter_mut() {
165            if header.1.contains(":?") {
166                eprint!(
167                    "Header value for key '{}' is missing data. Please provide the correct value: ",
168                    header.0
169                );
170                io::stdout().flush().ok();
171                let mut new_value = String::new();
172                std::io::stdin()
173                    .read_line(&mut new_value)
174                    .expect("Failed to read header value");
175                header.1 = new_value.trim().to_string();
176            }
177        }
178        headers
179    }
180
181    fn prompt_missing_body_data(mut body: String) -> String {
182        while let Some(idx) = body.find(":?") {
183            eprint!(
184                "Missing data at position {} - {}. Please provide the correct value: ",
185                idx, body
186            );
187            io::stdout().flush().ok();
188            let mut replacement = String::new();
189            std::io::stdin()
190                .read_line(&mut replacement)
191                .expect("Failed to read body placeholder");
192            let replacement = replacement.trim();
193            body.replace_range(idx..idx + 2, replacement);
194        }
195        body
196    }
197
198    pub fn build_header_map(headers: &[(String, String)]) -> HeaderMap {
199        let mut header_map = HeaderMap::new();
200        for (key, value) in headers {
201            if let Ok(header_name) = key.parse::<reqwest::header::HeaderName>() {
202                header_map.insert(header_name, value.parse().unwrap());
203            }
204        }
205        header_map
206    }
207
208    /// Checks if the Vec<u8> is valid UTF-8 (likely text) or not (binary).
209    fn is_text_data(data: &[u8]) -> bool {
210        std::str::from_utf8(data).is_ok()
211    }
212
213    pub async fn execute_request(
214        &self,
215        verbose: bool,
216        stdin_input: Vec<u8>,
217        stream: bool,
218    ) -> Result<(reqwest::Response, u128), Box<dyn std::error::Error>> {
219        let data = self.get_data();
220
221        let current_url = if !stream {
222            Self::prompt_missing_body_data(data.url.clone())
223        } else {
224            data.url.clone()
225        };
226
227        let headers = if !stream {
228            Self::prompt_missing_header_data(data.headers.clone())
229        } else {
230            data.headers.clone()
231        };
232
233        let is_text = Self::is_text_data(&stdin_input);
234        let body = if stdin_input.is_empty() {
235            Self::prompt_missing_body_data(data.body.clone())
236        } else if is_text {
237            // Convert to string for text processing
238            let text = String::from_utf8_lossy(&stdin_input).to_string();
239            Self::prompt_missing_body_data(text)
240        } else {
241            // Binary: skip text prompts, use as-is (but reqwest body will handle bytes)
242            String::new() // Placeholder; we'll use bytes directly in the request
243        };
244
245        let part = if !stream && !stdin_input.is_empty() && !is_text {
246            // Binary data from stdin
247            let kind = infer::get(&stdin_input).ok_or_else(|| {
248                Box::new(std::io::Error::new(
249                    std::io::ErrorKind::InvalidData,
250                    "Unknown file type",
251                ))
252            })?;
253            let mime_type = kind.mime_type(); // e.g., "image/jpeg"
254            let extension = kind.extension();
255            let filename = format!("file.{}", extension);
256            Part::bytes(stdin_input.clone())
257                .file_name(filename) // Mandatory for Spring FilePart
258                .mime_str(mime_type)?
259        } else if !stream && !stdin_input.is_empty() && is_text {
260            // Text data from stdin
261            Part::text(String::from_utf8_lossy(&stdin_input).to_string())
262        } else {
263            // Use body string
264            Part::text(body.clone())
265        };
266
267        if verbose {
268            Self::print_request_headers(&headers);
269            Self::print_request_body(body.as_str());
270        }
271
272        let client = ClientBuilder::new()
273            .redirect(Policy::none())
274            .build()
275            .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
276
277        let headers = Self::build_header_map(&headers);
278
279        let method = match self {
280            Self::Get { .. } => reqwest::Method::GET,
281            Self::Post { .. } => reqwest::Method::POST,
282            Self::Put { .. } => reqwest::Method::PUT,
283            Self::Delete { .. } => reqwest::Method::DELETE,
284            Self::Patch { .. } => reqwest::Method::PATCH,
285        };
286
287        let pb = ProgressBar::new_spinner();
288
289        pb.set_style(
290            ProgressStyle::with_template("{spinner:.green} {elapsed} {msg}")
291                .unwrap()
292                .tick_strings(&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]),
293        );
294
295        pb.enable_steady_tick(Duration::from_millis(80));
296        pb.set_message("Executing Request...");
297
298        let start = std::time::Instant::now();
299
300        let resp = if method == reqwest::Method::GET {
301            client
302                .get(&current_url)
303                .headers(headers)
304                .send()
305                .await
306                .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
307        } else if !stdin_input.is_empty() {
308            if stream {
309                // For streaming binary data
310                client
311                    .request(method, &current_url)
312                    .headers(headers)
313                    .body(stdin_input) // Send as bytes
314                    .send()
315                    .await
316                    .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
317            } else {
318                // For non-streaming binary or text data
319                if is_text {
320                    // Text data
321                    client
322                        .request(method, &current_url)
323                        .headers(headers)
324                        .body(String::from_utf8_lossy(&stdin_input).to_string())
325                        .send()
326                        .await
327                        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
328                } else {
329                    let form = multipart::Form::new().part("file", part);
330                    client
331                        .request(method, &current_url)
332                        .headers(headers)
333                        .multipart(form)
334                        .send()
335                        .await
336                        .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
337                }
338            }
339        } else {
340            client
341                .request(method, &current_url)
342                .headers(headers)
343                .body(body)
344                .send()
345                .await
346                .map_err(|e| Box::new(e) as Box<dyn std::error::Error>)
347        };
348
349        let elapsed = start.elapsed().as_millis();
350
351        match resp {
352            Ok(response) => Ok((response, elapsed)),
353            Err(e) => {
354                pb.finish_with_message("Request failed");
355                Err(e)
356            }
357        }
358    }
359
360    pub async fn run(
361        &self,
362        verbose: bool,
363        stdin_input: Vec<u8>,
364        stream: bool,
365    ) -> Result<String, Box<dyn std::error::Error>> {
366        let response = Self::execute_request(self, verbose, stdin_input, stream).await;
367
368        match response {
369            Ok((resp, elapsed)) => {
370                if verbose && !stream {
371                    println!("{:?}", resp.version());
372                    self.print_request_method(resp.url().as_ref(), resp.status(), elapsed);
373                }
374                Self::print_request_response(resp, verbose, stream).await
375            }
376            Err(err) => {
377                // Provide more detailed error information
378                if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
379                    if reqwest_err.is_timeout() {
380                        eprintln!("Request timed out");
381                    } else if reqwest_err.is_connect() {
382                        eprintln!("Connection error");
383                    } else if reqwest_err.is_redirect() {
384                        eprintln!("Redirect error");
385                    }
386                } else {
387                    eprintln!("Error: {}", err);
388                }
389                Err(err)
390            }
391        }
392    }
393}