use indexmap::IndexMap;
use openapiv3::{
Parameter, ParameterData, ParameterSchemaOrContent, ReferenceOr, Schema, SchemaData,
SchemaKind, StringType, Type,
};
fn string_param_data(name: &str, required: bool) -> ParameterData {
ParameterData {
name: name.to_string(),
description: None,
required,
deprecated: None,
format: ParameterSchemaOrContent::Schema(ReferenceOr::Item(Schema {
schema_data: SchemaData::default(),
schema_kind: SchemaKind::Type(Type::String(StringType::default())),
})),
example: None,
examples: IndexMap::new(),
explode: None,
extensions: IndexMap::new(),
}
}
pub fn extract_query_params(url: &str) -> Vec<Parameter> {
let query_str = match url.split_once('?') {
Some((_, q)) => q,
None => return Vec::new(),
};
let query_str = query_str.split('#').next().unwrap_or(query_str);
let mut seen = std::collections::HashSet::new();
let mut params = Vec::new();
for pair in query_str.split('&') {
let key = match pair.split_once('=') {
Some((k, _)) => k,
None => pair,
};
let key = urlencoding_decode(key);
if key.is_empty() || !seen.insert(key.clone()) {
continue;
}
params.push(Parameter::Query {
parameter_data: string_param_data(&key, false),
allow_reserved: false,
style: Default::default(),
allow_empty_value: None,
});
}
params
}
fn urlencoding_decode(input: &str) -> String {
let mut bytes = Vec::with_capacity(input.len());
let mut iter = input.bytes();
while let Some(b) = iter.next() {
if b == b'+' {
bytes.push(b' ');
} else if b == b'%' {
let hi = iter.next().and_then(hex_val);
let lo = iter.next().and_then(hex_val);
if let (Some(h), Some(l)) = (hi, lo) {
bytes.push(h << 4 | l);
} else {
bytes.push(b'%');
}
} else {
bytes.push(b);
}
}
String::from_utf8_lossy(&bytes).into_owned()
}
fn hex_val(b: u8) -> Option<u8> {
match b {
b'0'..=b'9' => Some(b - b'0'),
b'a'..=b'f' => Some(b - b'a' + 10),
b'A'..=b'F' => Some(b - b'A' + 10),
_ => None,
}
}
pub fn extract_path_params(template: &str) -> Vec<Parameter> {
let mut params = Vec::new();
let mut rest = template;
while let Some(start) = rest.find('{') {
if let Some(end) = rest[start..].find('}') {
let name = &rest[start + 1..start + end];
if !name.is_empty() {
params.push(Parameter::Path {
parameter_data: string_param_data(name, true),
style: Default::default(),
});
}
rest = &rest[start + end + 1..];
} else {
break;
}
}
params
}
const DEFAULT_EXCLUDE_HEADERS: &[&str] = &[
"host",
"content-length",
"content-type",
"accept",
"accept-encoding",
"accept-language",
"connection",
"user-agent",
"cookie",
"authorization",
"cache-control",
"pragma",
"te",
"transfer-encoding",
"upgrade",
];
pub fn extract_header_params(headers: &[(String, String)], exclude: &[String]) -> Vec<Parameter> {
let exclude_lower: Vec<String> = exclude.iter().map(|h| h.to_lowercase()).collect();
let mut seen = std::collections::HashSet::new();
let mut params = Vec::new();
for (name, _value) in headers {
let lower = name.to_lowercase();
if DEFAULT_EXCLUDE_HEADERS.contains(&lower.as_str()) {
continue;
}
if exclude_lower.contains(&lower) {
continue;
}
if !seen.insert(lower) {
continue;
}
params.push(Parameter::Header {
parameter_data: string_param_data(name, false),
style: Default::default(),
});
}
params
}
pub fn endpoint_name(method: &str, path: &str) -> String {
let mut parts: Vec<String> = Vec::new();
parts.push(method.to_lowercase());
for segment in path.split('/') {
if segment.is_empty() {
continue;
}
let seg = segment.trim_start_matches('{').trim_end_matches('}');
if seg.is_empty() {
continue;
}
let mut chars = seg.chars();
if let Some(first) = chars.next() {
let capitalized: String = first.to_uppercase().chain(chars).collect();
parts.push(capitalized);
}
}
parts.concat()
}
#[cfg(test)]
#[allow(clippy::indexing_slicing)]
mod tests {
use super::*;
#[test]
fn query_params_basic() {
let params = extract_query_params("https://example.com/api?page=1&limit=10");
assert_eq!(params.len(), 2);
let names: Vec<&str> = params
.iter()
.map(|p| p.parameter_data_ref().name.as_str())
.collect();
assert_eq!(names, vec!["page", "limit"]);
for p in ¶ms {
assert!(!p.parameter_data_ref().required);
}
}
#[test]
fn query_params_empty() {
let params = extract_query_params("https://example.com/api");
assert!(params.is_empty());
}
#[test]
fn query_params_no_value() {
let params = extract_query_params("https://example.com/api?debug");
assert_eq!(params.len(), 1);
assert_eq!(params[0].parameter_data_ref().name, "debug");
}
#[test]
fn query_params_dedup() {
let params = extract_query_params("https://example.com/api?a=1&a=2&b=3");
let names: Vec<&str> = params
.iter()
.map(|p| p.parameter_data_ref().name.as_str())
.collect();
assert_eq!(names, vec!["a", "b"]);
}
#[test]
fn query_params_with_fragment() {
let params = extract_query_params("https://example.com/api?x=1#section");
assert_eq!(params.len(), 1);
assert_eq!(params[0].parameter_data_ref().name, "x");
}
#[test]
fn query_params_encoded() {
let params = extract_query_params("https://example.com/api?user%20name=foo");
assert_eq!(params.len(), 1);
assert_eq!(params[0].parameter_data_ref().name, "user name");
}
#[test]
fn path_params_single() {
let params = extract_path_params("/users/{id}");
assert_eq!(params.len(), 1);
assert_eq!(params[0].parameter_data_ref().name, "id");
assert!(params[0].parameter_data_ref().required);
}
#[test]
fn path_params_multiple() {
let params = extract_path_params("/users/{user_id}/posts/{post_id}");
assert_eq!(params.len(), 2);
let names: Vec<&str> = params
.iter()
.map(|p| p.parameter_data_ref().name.as_str())
.collect();
assert_eq!(names, vec!["user_id", "post_id"]);
for p in ¶ms {
assert!(p.parameter_data_ref().required);
}
}
#[test]
fn path_params_none() {
let params = extract_path_params("/users");
assert!(params.is_empty());
}
#[test]
fn path_params_empty_braces() {
let params = extract_path_params("/users/{}");
assert!(params.is_empty());
}
#[test]
fn header_params_basic() {
let headers = vec![
("X-Request-Id".to_string(), "abc123".to_string()),
("X-Custom".to_string(), "val".to_string()),
];
let params = extract_header_params(&headers, &[]);
assert_eq!(params.len(), 2);
let names: Vec<&str> = params
.iter()
.map(|p| p.parameter_data_ref().name.as_str())
.collect();
assert_eq!(names, vec!["X-Request-Id", "X-Custom"]);
}
#[test]
fn header_params_excludes_default() {
let headers = vec![
("Host".to_string(), "example.com".to_string()),
("Content-Length".to_string(), "42".to_string()),
("X-Custom".to_string(), "val".to_string()),
];
let params = extract_header_params(&headers, &[]);
assert_eq!(params.len(), 1);
assert_eq!(params[0].parameter_data_ref().name, "X-Custom");
}
#[test]
fn header_params_custom_exclude() {
let headers = vec![
("X-Request-Id".to_string(), "abc".to_string()),
("X-Internal".to_string(), "secret".to_string()),
];
let exclude = vec!["X-Internal".to_string()];
let params = extract_header_params(&headers, &exclude);
assert_eq!(params.len(), 1);
assert_eq!(params[0].parameter_data_ref().name, "X-Request-Id");
}
#[test]
fn header_params_case_insensitive_exclude() {
let headers = vec![("host".to_string(), "example.com".to_string())];
let params = extract_header_params(&headers, &[]);
assert!(params.is_empty());
}
#[test]
fn header_params_dedup() {
let headers = vec![
("X-Dup".to_string(), "val1".to_string()),
("x-dup".to_string(), "val2".to_string()),
];
let params = extract_header_params(&headers, &[]);
assert_eq!(params.len(), 1);
}
#[test]
fn endpoint_name_basic() {
assert_eq!(
endpoint_name("GET", "/api/v1/users/{id}"),
"getApiV1UsersId"
);
}
#[test]
fn endpoint_name_post() {
assert_eq!(endpoint_name("POST", "/api/users"), "postApiUsers");
}
#[test]
fn endpoint_name_root() {
assert_eq!(endpoint_name("GET", "/"), "get");
}
#[test]
fn endpoint_name_nested_params() {
assert_eq!(
endpoint_name("DELETE", "/orgs/{org}/repos/{repo}"),
"deleteOrgsOrgReposRepo"
);
}
#[test]
fn urlencoding_utf8_roundtrip() {
assert_eq!(urlencoding_decode("%C3%A9"), "é");
assert_eq!(urlencoding_decode("%E4%B8%AD"), "中");
assert_eq!(urlencoding_decode("%F0%9F%A6%80"), "🦀");
}
#[test]
fn urlencoding_rejects_overlong() {
let decoded = urlencoding_decode("%C0%80");
assert_ne!(decoded, "\0");
assert!(decoded.is_char_boundary(0));
}
#[test]
fn urlencoding_preserves_ascii() {
assert_eq!(urlencoding_decode("hello+world%21"), "hello world!");
}
#[test]
fn urlencoding_malformed_percent() {
let decoded = urlencoding_decode("%ZZ");
assert_eq!(decoded, "%");
let decoded2 = urlencoding_decode("%C");
assert_eq!(decoded2, "%");
assert_eq!(urlencoding_decode("100%"), "100%");
}
}