Skip to main content

mitm2openapi/
params.rs

1use indexmap::IndexMap;
2use openapiv3::{
3    Parameter, ParameterData, ParameterSchemaOrContent, ReferenceOr, Schema, SchemaData,
4    SchemaKind, StringType, Type,
5};
6
7/// Create a `ParameterData` with a string schema and the given name/required flag.
8fn string_param_data(name: &str, required: bool) -> ParameterData {
9    ParameterData {
10        name: name.to_string(),
11        description: None,
12        required,
13        deprecated: None,
14        format: ParameterSchemaOrContent::Schema(ReferenceOr::Item(Schema {
15            schema_data: SchemaData::default(),
16            schema_kind: SchemaKind::Type(Type::String(StringType::default())),
17        })),
18        example: None,
19        examples: IndexMap::new(),
20        explode: None,
21        extensions: IndexMap::new(),
22    }
23}
24
25/// Extract query parameters from a URL string, returning `openapiv3::Parameter` objects.
26///
27/// Parses the query string (`?key=value&key2=value2`) and creates query parameters.
28/// Each parameter is optional (`required: false`) with a string schema.
29pub fn extract_query_params(url: &str) -> Vec<Parameter> {
30    let query_str = match url.split_once('?') {
31        Some((_, q)) => q,
32        None => return Vec::new(),
33    };
34
35    let query_str = query_str.split('#').next().unwrap_or(query_str);
36
37    let mut seen = std::collections::HashSet::new();
38    let mut params = Vec::new();
39
40    for pair in query_str.split('&') {
41        let key = match pair.split_once('=') {
42            Some((k, _)) => k,
43            None => pair,
44        };
45        let key = urlencoding_decode(key);
46        if key.is_empty() || !seen.insert(key.clone()) {
47            continue;
48        }
49        params.push(Parameter::Query {
50            parameter_data: string_param_data(&key, false),
51            allow_reserved: false,
52            style: Default::default(),
53            allow_empty_value: None,
54        });
55    }
56
57    params
58}
59
60/// Decode percent-encoded strings (minimal implementation, no extra deps).
61fn urlencoding_decode(input: &str) -> String {
62    let mut bytes = Vec::with_capacity(input.len());
63    let mut iter = input.bytes();
64    while let Some(b) = iter.next() {
65        if b == b'+' {
66            bytes.push(b' ');
67        } else if b == b'%' {
68            let hi = iter.next().and_then(hex_val);
69            let lo = iter.next().and_then(hex_val);
70            if let (Some(h), Some(l)) = (hi, lo) {
71                bytes.push(h << 4 | l);
72            } else {
73                bytes.push(b'%');
74            }
75        } else {
76            bytes.push(b);
77        }
78    }
79    String::from_utf8_lossy(&bytes).into_owned()
80}
81
/// Return the numeric value (0–15) of an ASCII hex digit (`0-9`, `a-f`, `A-F`),
/// or `None` for any other byte.
fn hex_val(b: u8) -> Option<u8> {
    // `to_digit(16)` accepts exactly the ASCII hex digits; the result is < 16,
    // so the conversion back to `u8` cannot fail in practice.
    char::from(b)
        .to_digit(16)
        .and_then(|digit| u8::try_from(digit).ok())
}
90
91/// Extract path parameters from a template string like `/users/{id}/posts/{post_id}`.
92///
93/// Returns `Parameter` objects with `in: path, required: true`.
94pub fn extract_path_params(template: &str) -> Vec<Parameter> {
95    let mut params = Vec::new();
96    let mut rest = template;
97
98    while let Some(start) = rest.find('{') {
99        if let Some(end) = rest[start..].find('}') {
100            let name = &rest[start + 1..start + end];
101            if !name.is_empty() {
102                params.push(Parameter::Path {
103                    parameter_data: string_param_data(name, true),
104                    style: Default::default(),
105                });
106            }
107            rest = &rest[start + end + 1..];
108        } else {
109            break;
110        }
111    }
112
113    params
114}
115
/// Headers to exclude by default (case-insensitive).
///
/// Entries must be written in lowercase: `extract_header_params` lowercases each
/// incoming header name and looks that lowercased form up in this list, which is
/// what makes the match case-insensitive.
const DEFAULT_EXCLUDE_HEADERS: &[&str] = &[
    "host",
    "content-length",
    "content-type",
    "accept",
    "accept-encoding",
    "accept-language",
    "connection",
    "user-agent",
    "cookie",
    "authorization",
    "cache-control",
    "pragma",
    "te",
    "transfer-encoding",
    "upgrade",
];
134
135/// Extract headers from request, optionally filtering by exclude list.
136///
137/// Returns `Parameter` objects with `in: header`.
138/// By default, common non-informative headers (Host, Content-Length, etc.) are excluded.
139/// The `exclude` list provides *additional* headers to exclude (case-insensitive).
140pub fn extract_header_params(headers: &[(String, String)], exclude: &[String]) -> Vec<Parameter> {
141    let exclude_lower: Vec<String> = exclude.iter().map(|h| h.to_lowercase()).collect();
142    let mut seen = std::collections::HashSet::new();
143    let mut params = Vec::new();
144
145    for (name, _value) in headers {
146        let lower = name.to_lowercase();
147        if DEFAULT_EXCLUDE_HEADERS.contains(&lower.as_str()) {
148            continue;
149        }
150        if exclude_lower.contains(&lower) {
151            continue;
152        }
153        if !seen.insert(lower) {
154            continue;
155        }
156        params.push(Parameter::Header {
157            parameter_data: string_param_data(name, false),
158            style: Default::default(),
159        });
160    }
161
162    params
163}
164
/// Generate an endpoint name from method + path template.
///
/// E.g., `"GET"`, `"/api/v1/users/{id}"` → `"getApiV1UsersId"`.
///
/// The method is lowercased, then each non-empty path segment is appended with its
/// first character uppercased. Parameter braces are removed wherever they occur in
/// a segment, so `"/files/{name}.json"` contributes `Name.json`. All other
/// characters are kept as they are; segments that are empty once the braces are
/// removed (e.g. `{}`) contribute nothing.
pub fn endpoint_name(method: &str, path: &str) -> String {
    let mut name = method.to_lowercase();

    for segment in path.split('/') {
        // Drop every brace, not only leading `{` / trailing `}`, so a placeholder
        // embedded in a larger segment does not leak `{` or `}` into the name.
        let mut chars = segment.chars().filter(|&c| c != '{' && c != '}');
        if let Some(first) = chars.next() {
            name.extend(first.to_uppercase());
            name.extend(chars);
        }
    }

    name
}
190
#[cfg(test)]
// Tests index directly into result vectors right after asserting their length;
// a panic on a bad index is an acceptable test failure.
#[allow(clippy::indexing_slicing)]
mod tests {
    use super::*;

    // ---- extract_query_params ----

    #[test]
    fn query_params_basic() {
        let params = extract_query_params("https://example.com/api?page=1&limit=10");
        assert_eq!(params.len(), 2);

        let names: Vec<&str> = params
            .iter()
            .map(|p| p.parameter_data_ref().name.as_str())
            .collect();
        assert_eq!(names, vec!["page", "limit"]);

        for p in &params {
            assert!(!p.parameter_data_ref().required);
        }
    }

    #[test]
    fn query_params_empty() {
        let params = extract_query_params("https://example.com/api");
        assert!(params.is_empty());
    }

    #[test]
    fn query_params_no_value() {
        let params = extract_query_params("https://example.com/api?debug");
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].parameter_data_ref().name, "debug");
    }

    #[test]
    fn query_params_dedup() {
        let params = extract_query_params("https://example.com/api?a=1&a=2&b=3");
        let names: Vec<&str> = params
            .iter()
            .map(|p| p.parameter_data_ref().name.as_str())
            .collect();
        assert_eq!(names, vec!["a", "b"]);
    }

    #[test]
    fn query_params_with_fragment() {
        let params = extract_query_params("https://example.com/api?x=1#section");
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].parameter_data_ref().name, "x");
    }

    #[test]
    fn query_params_encoded() {
        let params = extract_query_params("https://example.com/api?user%20name=foo");
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].parameter_data_ref().name, "user name");
    }

    // ---- extract_path_params ----

    #[test]
    fn path_params_single() {
        let params = extract_path_params("/users/{id}");
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].parameter_data_ref().name, "id");
        assert!(params[0].parameter_data_ref().required);
    }

    #[test]
    fn path_params_multiple() {
        let params = extract_path_params("/users/{user_id}/posts/{post_id}");
        assert_eq!(params.len(), 2);
        let names: Vec<&str> = params
            .iter()
            .map(|p| p.parameter_data_ref().name.as_str())
            .collect();
        assert_eq!(names, vec!["user_id", "post_id"]);
        for p in &params {
            assert!(p.parameter_data_ref().required);
        }
    }

    #[test]
    fn path_params_none() {
        let params = extract_path_params("/users");
        assert!(params.is_empty());
    }

    #[test]
    fn path_params_empty_braces() {
        let params = extract_path_params("/users/{}");
        assert!(params.is_empty());
    }

    // ---- extract_header_params ----

    #[test]
    fn header_params_basic() {
        let headers = vec![
            ("X-Request-Id".to_string(), "abc123".to_string()),
            ("X-Custom".to_string(), "val".to_string()),
        ];
        let params = extract_header_params(&headers, &[]);
        assert_eq!(params.len(), 2);
        let names: Vec<&str> = params
            .iter()
            .map(|p| p.parameter_data_ref().name.as_str())
            .collect();
        assert_eq!(names, vec!["X-Request-Id", "X-Custom"]);
    }

    #[test]
    fn header_params_excludes_default() {
        let headers = vec![
            ("Host".to_string(), "example.com".to_string()),
            ("Content-Length".to_string(), "42".to_string()),
            ("X-Custom".to_string(), "val".to_string()),
        ];
        let params = extract_header_params(&headers, &[]);
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].parameter_data_ref().name, "X-Custom");
    }

    #[test]
    fn header_params_custom_exclude() {
        let headers = vec![
            ("X-Request-Id".to_string(), "abc".to_string()),
            ("X-Internal".to_string(), "secret".to_string()),
        ];
        let exclude = vec!["X-Internal".to_string()];
        let params = extract_header_params(&headers, &exclude);
        assert_eq!(params.len(), 1);
        assert_eq!(params[0].parameter_data_ref().name, "X-Request-Id");
    }

    #[test]
    fn header_params_case_insensitive_exclude() {
        let headers = vec![("host".to_string(), "example.com".to_string())];
        let params = extract_header_params(&headers, &[]);
        assert!(params.is_empty());
    }

    #[test]
    fn header_params_dedup() {
        let headers = vec![
            ("X-Dup".to_string(), "val1".to_string()),
            ("x-dup".to_string(), "val2".to_string()),
        ];
        let params = extract_header_params(&headers, &[]);
        assert_eq!(params.len(), 1);
    }

    // ---- endpoint_name ----

    #[test]
    fn endpoint_name_basic() {
        assert_eq!(
            endpoint_name("GET", "/api/v1/users/{id}"),
            "getApiV1UsersId"
        );
    }

    #[test]
    fn endpoint_name_post() {
        assert_eq!(endpoint_name("POST", "/api/users"), "postApiUsers");
    }

    #[test]
    fn endpoint_name_root() {
        assert_eq!(endpoint_name("GET", "/"), "get");
    }

    #[test]
    fn endpoint_name_nested_params() {
        assert_eq!(
            endpoint_name("DELETE", "/orgs/{org}/repos/{repo}"),
            "deleteOrgsOrgReposRepo"
        );
    }

    // ---- urlencoding_decode ----

    #[test]
    fn urlencoding_utf8_roundtrip() {
        assert_eq!(urlencoding_decode("%C3%A9"), "é");
        assert_eq!(urlencoding_decode("%E4%B8%AD"), "中");
        assert_eq!(urlencoding_decode("%F0%9F%A6%80"), "🦀");
    }

    #[test]
    fn urlencoding_rejects_overlong() {
        // 0xC0 0x80 is an overlong encoding of NUL and is not valid UTF-8; the lossy
        // conversion must replace it instead of producing a NUL character.
        let decoded = urlencoding_decode("%C0%80");
        assert_ne!(decoded, "\0");
        assert!(!decoded.is_empty());
        assert!(decoded.chars().all(|c| c == char::REPLACEMENT_CHARACTER));
    }

    #[test]
    fn urlencoding_preserves_ascii() {
        assert_eq!(urlencoding_decode("hello+world%21"), "hello world!");
    }

    #[test]
    fn urlencoding_malformed_percent() {
        // A malformed escape yields a bare `%`; the bytes read while looking for the
        // two hex digits are consumed.
        let decoded = urlencoding_decode("%ZZ");
        assert_eq!(decoded, "%");
        let decoded2 = urlencoding_decode("%C");
        assert_eq!(decoded2, "%");
        assert_eq!(urlencoding_decode("100%"), "100%");
    }
}