apictl/
request.rs

1use std::collections::HashMap;
2
3use crate::{Applicator, List, Response, ResponseError};
4
5use serde::{Deserialize, Serialize};
6use thiserror::Error;
7
8/// Implement List for Requests.
9impl List for HashMap<String, Request> {
10    fn headers(&self) -> Vec<String> {
11        vec![
12            "Name".into(),
13            "Method".into(),
14            "URL".into(),
15            "Description".into(),
16        ]
17    }
18
19    fn values(&self) -> Vec<Vec<String>> {
20        self.iter()
21            .map(|(n, r)| {
22                vec![
23                    n.clone(),
24                    r.method.clone(),
25                    r.url.clone(),
26                    r.description.clone(),
27                ]
28            })
29            .collect()
30    }
31}
32
33/// RequestError is the error type for requests.
34#[derive(Error, Debug)]
35pub enum RequestError {
36    #[error("http error: {0}")]
37    Http(reqwest::Error),
38
39    #[error("io error: {0}")]
40    Io(std::io::Error),
41
42    #[error("response parse error: {0}")]
43    Parse(ResponseError),
44
45    #[error("unsupported method: {0}")]
46    UnsupportedMethod(String),
47}
48
49/// Result is the result type for requests.
50type Result<T> = std::result::Result<T, RequestError>;
51
52/// Requests from the configuration.
53#[derive(Clone, Debug, Serialize, Deserialize)]
54pub struct Request {
55    pub description: String,
56    pub tags: Vec<String>,
57    pub url: String,
58    #[serde(default = "default_method")]
59    pub method: String,
60    #[serde(default)]
61    pub headers: HashMap<String, String>,
62    #[serde(default)]
63    pub query_parameters: HashMap<String, String>,
64    #[serde(default)]
65    pub body: Body,
66}
67
/// Serde fallback for `Request::method`: a request that omits `method` is
/// sent as `GET`.
fn default_method() -> String {
    String::from("GET")
}
71
72impl Request {
73    /// Apply the configuration and context to the request. All parts
74    /// of the request are replaced with the response values and
75    /// contexts.
76    pub fn apply(&mut self, app: &Applicator) {
77        self.url = app.apply(&self.url);
78        self.method = app.apply(&self.method);
79        for value in self.headers.values_mut() {
80            *value = app.apply(value);
81        }
82        for value in self.query_parameters.values_mut() {
83            *value = app.apply(value);
84        }
85        match &mut self.body {
86            Body::None => {}
87            Body::Form { data } => {
88                for value in data.values_mut() {
89                    *value = app.apply(value);
90                }
91            }
92            Body::Raw { from } => match from {
93                RawBody::File { path } => {
94                    *path = app.apply(path);
95                }
96                RawBody::Text { data } => {
97                    *data = app.apply(data);
98                }
99            },
100            Body::MultiPart { data } => {
101                for value in data.values_mut() {
102                    match value {
103                        MultiPartField::Text { data } => {
104                            *data = app.apply(data);
105                        }
106                        MultiPartField::File { path } => {
107                            *path = app.apply(path);
108                        }
109                    }
110                }
111            }
112        }
113    }
114
115    /// Perform the request and return it's response.
116    pub async fn request(&self) -> Result<Response> {
117        use reqwest::Client;
118
119        let mut builder = match self.method.as_str() {
120            "GET" => Client::new().get(&self.url),
121            "POST" => Client::new().post(&self.url),
122            "PUT" => Client::new().put(&self.url),
123            "DELETE" => Client::new().delete(&self.url),
124            _ => return Err(RequestError::UnsupportedMethod(self.method.clone())),
125        };
126
127        for (key, value) in self.headers.iter() {
128            builder = builder.header(key, value);
129        }
130
131        builder = builder.query(&self.query_parameters);
132
133        match &self.body {
134            Body::None => {}
135            Body::Form { data } => {
136                builder = builder.form(data);
137            }
138            Body::Raw { from } => match from {
139                RawBody::File { path } => {
140                    builder =
141                        builder.body(std::fs::read_to_string(path).map_err(RequestError::Io)?);
142                }
143                RawBody::Text { data } => {
144                    builder = builder.body(data.clone());
145                }
146            },
147            Body::MultiPart { data } => {
148                let mut form = reqwest::multipart::Form::new();
149                for (key, value) in data.iter() {
150                    match value {
151                        MultiPartField::Text { data } => {
152                            form = form.text(key.clone(), data.clone());
153                        }
154                        MultiPartField::File { path } => {
155                            let mut part = reqwest::multipart::Part::stream(
156                                tokio::fs::File::open(path)
157                                    .await
158                                    .map_err(RequestError::Io)?,
159                            );
160                            part = part.file_name(path.clone());
161                            form = form.part(key.clone(), part);
162                        }
163                    }
164                }
165                builder = builder.multipart(form);
166            }
167        }
168
169        Response::from(builder.send().await.map_err(RequestError::Http)?)
170            .await
171            .map_err(RequestError::Parse)
172    }
173}
174
175#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
176#[serde(tag = "type", rename_all = "lowercase")]
177pub enum Body {
178    #[default]
179    None,
180    Form {
181        data: HashMap<String, String>,
182    },
183    Raw {
184        from: RawBody,
185    },
186    MultiPart {
187        data: HashMap<String, MultiPartField>,
188    },
189}
190
191#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
192#[serde(tag = "type", rename_all = "lowercase")]
193pub enum RawBody {
194    File { path: String },
195    Text { data: String },
196}
197
198#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
199#[serde(tag = "type", rename_all = "lowercase")]
200pub enum MultiPartField {
201    File { path: String },
202    Text { data: String },
203}
204
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned `String` map from borrowed key/value pairs.
    fn map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(k, v)| (k.to_owned(), v.to_owned()))
            .collect()
    }

    #[test]
    fn deserialize() {
        let request = r#"
tags: [post, form]
description: post using key/value pairs
url: https://api.example.com/endpoint1
method: POST
headers:
  Authorization: Bearer your-token
body:
  type: form
  data:
    key1: value1
    key2: value2
"#;

        let request: Request = serde_yaml::from_str(request).unwrap();

        assert_eq!(request.description, "post using key/value pairs");
        assert_eq!(request.tags, vec!["post", "form"]);
        assert_eq!(request.url, "https://api.example.com/endpoint1");
        assert_eq!(request.method, "POST");
        assert_eq!(
            request.headers,
            map(&[("Authorization", "Bearer your-token")])
        );
        assert!(request.query_parameters.is_empty());
        assert_eq!(
            request.body,
            Body::Form {
                data: map(&[("key1", "value1"), ("key2", "value2")])
            }
        );
    }

    #[test]
    fn deserialize_defaults() {
        let request = r#"
tags: []
description: minimal request
url: https://api.example.com/health
"#;

        let request: Request = serde_yaml::from_str(request).unwrap();

        assert_eq!(request.method, "GET");
        assert!(request.headers.is_empty());
        assert!(request.query_parameters.is_empty());
        assert_eq!(request.body, Body::None);
    }

    #[test]
    fn deserialize_raw_text() {
        let request = r#"
tags: []
description: raw text body
url: https://api.example.com/echo
method: PUT
body:
  type: raw
  from:
    type: text
    data: hello
"#;

        let request: Request = serde_yaml::from_str(request).unwrap();

        assert_eq!(
            request.body,
            Body::Raw {
                from: RawBody::Text {
                    data: "hello".to_string()
                }
            }
        );
    }

    #[test]
    fn deserialize_multipart() {
        let request = r#"
tags: [upload]
description: multipart upload
url: https://api.example.com/upload
method: POST
body:
  type: multipart
  data:
    upload:
      type: file
      path: ./upload.bin
    note:
      type: text
      data: hello
"#;

        let request: Request = serde_yaml::from_str(request).unwrap();

        let expected: HashMap<String, MultiPartField> = vec![
            (
                "upload".to_string(),
                MultiPartField::File {
                    path: "./upload.bin".to_string(),
                },
            ),
            (
                "note".to_string(),
                MultiPartField::Text {
                    data: "hello".to_string(),
                },
            ),
        ]
        .into_iter()
        .collect();
        assert_eq!(request.body, Body::MultiPart { data: expected });
    }

    #[test]
    fn apply() {
        let request = r#"
tags: [post, form]
description: post using key/value pairs
url: "${base_url}/endpoint1"
method: POST
headers:
  Authorization: "Bearer ${token}"
body:
  type: form
  data:
    key1: "${value1}"
    key2: value2
"#;

        let mut request: Request = serde_yaml::from_str(request).unwrap();
        let context = map(&[
            ("base_url", "https://api.example.com"),
            ("token", "your-token"),
            ("value1", "value1"),
        ]);

        let app = Applicator::new(context, HashMap::new());
        request.apply(&app);

        assert_eq!(request.description, "post using key/value pairs");
        assert_eq!(request.tags, vec!["post", "form"]);
        assert_eq!(request.url, "https://api.example.com/endpoint1");
        assert_eq!(request.method, "POST");
        // Header values must be templated too, not just present.
        assert_eq!(
            request.headers,
            map(&[("Authorization", "Bearer your-token")])
        );
        assert_eq!(
            request.body,
            Body::Form {
                data: map(&[("key1", "value1"), ("key2", "value2")])
            }
        );
    }

    #[test]
    fn apply_query_and_multipart() {
        let request = r#"
tags: [upload]
description: upload a file
url: "${base_url}/upload"
method: POST
query_parameters:
  page: "${page}"
body:
  type: multipart
  data:
    upload:
      type: file
      path: "${dir}/upload.bin"
    note:
      type: text
      data: "${note}"
"#;

        let mut request: Request = serde_yaml::from_str(request).unwrap();
        let context = map(&[
            ("base_url", "https://api.example.com"),
            ("page", "2"),
            ("dir", "/tmp"),
            ("note", "hello"),
        ]);

        let app = Applicator::new(context, HashMap::new());
        request.apply(&app);

        assert_eq!(request.url, "https://api.example.com/upload");
        assert_eq!(request.query_parameters, map(&[("page", "2")]));

        let expected: HashMap<String, MultiPartField> = vec![
            (
                "upload".to_string(),
                MultiPartField::File {
                    path: "/tmp/upload.bin".to_string(),
                },
            ),
            (
                "note".to_string(),
                MultiPartField::Text {
                    data: "hello".to_string(),
                },
            ),
        ]
        .into_iter()
        .collect();
        assert_eq!(request.body, Body::MultiPart { data: expected });
    }
}