use reqwest::blocking::Response;
use std::path::Path;
pub trait ResponseExt {
fn resolve_filename(&self, fallback_name: &str) -> String;
}
impl ResponseExt for Response {
fn resolve_filename(&self, fallback_name: &str) -> String {
if let Some(filename) = parse_content_disposition(self) {
let sanitized = sanitize_filename(&filename);
if !sanitized.is_empty() {
return sanitized;
}
}
if let Some(filename) = parse_url_filename(self) {
let sanitized = sanitize_filename(&filename);
if !sanitized.is_empty() {
return sanitized;
}
}
let ext = guess_extension(self).unwrap_or_default();
return format!("{fallback_name}{ext}");
}
}
/// Reads the `Content-Disposition` header and extracts a file name from it.
///
/// The extended `filename*=` parameter is preferred; the plain `filename=`
/// parameter is only consulted when the extended one yields nothing. Returns
/// `None` when the header is absent or cannot be read as a string.
fn parse_content_disposition(response: &Response) -> Option<String> {
    let raw = response.headers().get("Content-Disposition")?;
    let disposition = raw.to_str().ok()?;
    extract_filename_star(disposition).or_else(|| extract_filename(disposition))
}
/// Extracts and decodes the extended `filename*=` parameter (RFC 8187 /
/// RFC 5987) of a `Content-Disposition` header value.
///
/// The parameter has the form `charset'language'percent-encoded-value`.
/// Values whose charset is `ISO-8859-1` are decoded as Latin-1; any other
/// charset is decoded as UTF-8. `+` is a literal character here, not a space.
///
/// Returns `None` when the parameter is missing, malformed, not valid in its
/// charset, or decodes to an empty string, so the caller can fall back to the
/// plain `filename=` parameter.
fn extract_filename_star(header: &str) -> Option<String> {
    /// Percent-decodes `encoded` into raw bytes; `None` on a malformed `%XX` escape.
    fn percent_decode_bytes(encoded: &str) -> Option<Vec<u8>> {
        let bytes = encoded.as_bytes();
        let mut out = Vec::with_capacity(bytes.len());
        let mut i = 0;
        while i < bytes.len() {
            if bytes[i] == b'%' {
                let hi = char::from(*bytes.get(i + 1)?).to_digit(16)?;
                let lo = char::from(*bytes.get(i + 2)?).to_digit(16)?;
                // Both digits are < 16, so the combined value always fits in a byte.
                out.push((hi * 16 + lo) as u8);
                i += 3;
            } else {
                out.push(bytes[i]);
                i += 1;
            }
        }
        Some(out)
    }

    let key = "filename*=";
    // ASCII-only lowercasing keeps byte offsets identical to `header`, so a
    // position found in `lower` is always a valid index into `header`.
    let lower = header.to_ascii_lowercase();
    // Only accept the key at a parameter boundary so that e.g. `x-filename*=`
    // is not mistaken for `filename*=`.
    let pos = lower
        .match_indices(key)
        .map(|(i, _)| i)
        .find(|&i| i == 0 || matches!(lower.as_bytes()[i - 1], b';' | b' ' | b'\t'))?;
    let value = header[pos + key.len()..]
        .split(';')
        .next()?
        .trim()
        // Not permitted by the RFC, but some servers quote the extended value.
        .trim_matches('"');

    let mut parts = value.splitn(3, '\'');
    let charset = parts.next()?;
    let _language = parts.next()?;
    let encoded = parts.next()?;

    let bytes = percent_decode_bytes(encoded)?;
    let decoded = if charset.eq_ignore_ascii_case("iso-8859-1") {
        // Every Latin-1 byte maps to the Unicode code point with the same value.
        bytes.into_iter().map(char::from).collect()
    } else {
        String::from_utf8(bytes).ok()?
    };
    (!decoded.is_empty()).then_some(decoded)
}
/// Extracts the plain `filename=` parameter from a `Content-Disposition`
/// header value.
///
/// Both the token form (`filename=report.pdf`) and the quoted-string form
/// (`filename="my report.pdf"`) are accepted. Inside quotes, a backslash
/// escapes the next character (`\"` becomes `"`), and an unterminated quoted
/// string runs to the end of the header. The parameter name is matched
/// case-insensitively and only at a parameter boundary.
///
/// Returns `None` when the parameter is absent or its value is empty.
fn extract_filename(header: &str) -> Option<String> {
    let key = "filename=";
    // ASCII-only lowercasing keeps byte offsets identical to `header`, so a
    // position found in `lower` is always a valid index into `header`.
    let lower = header.to_ascii_lowercase();
    // Only accept the key at a parameter boundary so that e.g. `x-filename=`
    // is not mistaken for `filename=`.
    let pos = lower
        .match_indices(key)
        .map(|(i, _)| i)
        .find(|&i| i == 0 || matches!(lower.as_bytes()[i - 1], b';' | b' ' | b'\t'))?;
    let value = header[pos + key.len()..].trim();

    let filename = match value.strip_prefix('"') {
        Some(quoted) => {
            let mut out = String::with_capacity(quoted.len());
            let mut chars = quoted.chars();
            while let Some(c) = chars.next() {
                match c {
                    '"' => break,
                    '\\' => {
                        if let Some(escaped) = chars.next() {
                            out.push(escaped);
                        }
                    }
                    c => out.push(c),
                }
            }
            out
        }
        None => value.split(';').next().unwrap_or("").trim().to_string(),
    };
    (!filename.is_empty()).then_some(filename)
}
/// Derives a candidate file name from the last non-empty segment of the
/// response URL's path.
///
/// The segment is percent-decoded and only accepted when the decoded text
/// contains a `.`, treated as a hint that it names a file rather than an
/// endpoint such as `/download/42`. Returns `None` otherwise, or when
/// decoding fails.
fn parse_url_filename(response: &Response) -> Option<String> {
    let segment = response
        .url()
        .path()
        .rsplit('/')
        .find(|part| !part.is_empty())?;
    percent_decode(segment)
        .ok()
        .filter(|name| name.contains('.'))
}
/// Guesses a file extension, including the leading `.`, from the response's
/// `Content-Type` header.
///
/// Parameters such as `; charset=utf-8` are ignored. The media type is
/// compared case-insensitively, because media types are case-insensitive, so
/// `Application/PDF` maps the same as `application/pdf`. Returns `None` when
/// the header is missing, cannot be read as a string, or names a type not in
/// the table.
fn guess_extension(response: &Response) -> Option<String> {
    let content_type = response.headers().get("Content-Type")?.to_str().ok()?;
    let mime = content_type.split(';').next()?.trim().to_ascii_lowercase();
    extension_for_mime(&mime).map(str::to_string)
}

/// Maps a lowercase, parameter-free media type to a conventional extension.
fn extension_for_mime(mime: &str) -> Option<&'static str> {
    let ext = match mime {
        "application/pdf" => ".pdf",
        "application/zip" => ".zip",
        "application/gzip" => ".gz",
        "application/json" => ".json",
        "application/xml" | "text/xml" => ".xml",
        "application/octet-stream" => ".bin",
        "text/plain" => ".txt",
        "text/html" => ".html",
        "text/csv" => ".csv",
        "image/jpeg" => ".jpg",
        "image/png" => ".png",
        "image/gif" => ".gif",
        "image/webp" => ".webp",
        "image/svg+xml" => ".svg",
        "video/mp4" => ".mp4",
        "audio/mpeg" => ".mp3",
        _ => return None,
    };
    Some(ext)
}
/// Reduces an untrusted file name to a single component that is safe to
/// create inside a download directory.
///
/// * Only the final path component is kept, so `../../etc/passwd` becomes
///   `passwd`.
/// * Path separators, characters reserved on Windows (`:*?"<>|`), and control
///   characters are replaced with `_`.
/// * Leading and trailing whitespace and dots are stripped together, which
///   also reduces `.` / `..` to an empty string and avoids accidental
///   dot-files.
/// * Names longer than 255 bytes are shortened on a char boundary, keeping a
///   short extension intact.
///
/// May return an empty string; callers must treat that as "no usable name".
fn sanitize_filename(name: &str) -> String {
    // Per-component name limit on common Unix file systems (e.g. ext4).
    const MAX_FILENAME_BYTES: usize = 255;

    let name = Path::new(name)
        .file_name()
        .and_then(|n| n.to_str())
        .unwrap_or(name);
    let replaced: String = name
        .chars()
        .map(|c| match c {
            '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|' => '_',
            c if c.is_control() => '_',
            c => c,
        })
        .collect();
    // Trim whitespace and dots in one pass: trimming them one after the other
    // would turn " . name . " into " name ", leaving the spaces behind.
    let trimmed = replaced.trim_matches(|c: char| c.is_whitespace() || c == '.');
    truncate_filename(trimmed, MAX_FILENAME_BYTES)
}

/// Shortens `name` to at most `max_bytes` bytes on a char boundary, keeping
/// an extension of up to 16 bytes (including its `.`) intact when it fits.
fn truncate_filename(name: &str, max_bytes: usize) -> String {
    if name.len() <= max_bytes {
        return name.to_string();
    }
    let ext = name
        .rfind('.')
        .map(|dot| &name[dot..])
        .filter(|ext| ext.len() <= 16 && ext.len() < max_bytes)
        .unwrap_or("");
    let stem = &name[..name.len() - ext.len()];
    // `cut < stem.len()` here because `name.len() > max_bytes`.
    let mut cut = max_bytes - ext.len();
    while !stem.is_char_boundary(cut) {
        cut -= 1;
    }
    let stem = stem[..cut].trim_end_matches(|c: char| c.is_whitespace() || c == '.');
    format!("{stem}{ext}")
}
/// Percent-decodes `input` and interprets the resulting bytes as UTF-8.
///
/// Each `%XX` escape must be followed by exactly two ASCII hex digits.
/// Non-escaped bytes, including multi-byte UTF-8 characters and `+`, are
/// copied unchanged. `+` is only a space in form encoding, not in URL paths
/// or RFC 8187 values. Byte sequences that are not valid UTF-8 after decoding
/// are replaced with U+FFFD.
///
/// # Errors
///
/// Returns a [`std::num::ParseIntError`] when a `%` is not followed by two
/// hex digits (including a truncated escape at the end of the input).
fn percent_decode(input: &str) -> Result<String, std::num::ParseIntError> {
    let bytes = input.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b'%' {
            // Validate the digits first: `from_str_radix` alone would accept a
            // sign such as "+5". An empty string makes it return an `Empty`
            // error, which is how malformed or truncated escapes are reported.
            let hex = input
                .get(i + 1..i + 3)
                .filter(|h| h.bytes().all(|b| b.is_ascii_hexdigit()))
                .unwrap_or("");
            decoded.push(u8::from_str_radix(hex, 16)?);
            i += 3;
        } else {
            decoded.push(bytes[i]);
            i += 1;
        }
    }
    Ok(String::from_utf8(decoded)
        .unwrap_or_else(|err| String::from_utf8_lossy(err.as_bytes()).into_owned()))
}