use axum::http::header::HeaderMap;
/// Reduces `name` to the characters allowed in an HTTP token (RFC 9110 `tchar`).
///
/// The result is used as the range unit of a `Content-Range` header, and the header grammar
/// requires that unit to be a token. Keeping only `tchar` characters removes CR/LF and other
/// control characters, which prevents header injection. It also removes spaces and separators
/// such as `/` or `:`, which would otherwise add extra fields to the header and make it
/// ambiguous to parse. Non-ASCII characters are dropped as well, so the result may be empty.
fn sanitize_resource_name(name: &str) -> String {
    // Punctuation permitted in a `tchar`, in addition to ASCII letters and digits.
    const TCHAR_PUNCTUATION: &str = "!#$%&'*+-.^_`|~";
    name.chars()
        .filter(|c| c.is_ascii_alphanumeric() || TCHAR_PUNCTUATION.contains(*c))
        .collect()
}
/// Builds a `HeaderMap` holding a single `Content-Range` header that describes a page of results.
///
/// The header value has the form `"<unit> <first>-<last>/<total>"`:
/// - `<unit>` is `resource_name` after `sanitize_resource_name`.
/// - `<first>` and `<last>` are inclusive indices.
///
/// Edge cases:
/// - When the page is empty (`total_count == 0`, `offset >= total_count`, or `limit == 0`),
///   both bounds are reported as `offset`, so the start never exceeds the end.
/// - Otherwise `<last>` is `offset + limit - 1`, clamped to `total_count - 1`. All arithmetic
///   saturates, so values near `u64::MAX` cannot overflow.
/// - If sanitizing `resource_name` leaves nothing (for example, an all-non-ASCII name), the
///   unit falls back to `"items"` so the value never begins with a bare space.
#[must_use]
pub fn calculate_content_range(
    offset: u64,
    limit: u64,
    total_count: u64,
    resource_name: &str,
) -> HeaderMap {
    // An empty page has no meaningful last index. Without the `limit == 0` guard,
    // `offset + 0 - 1` would produce a range whose end precedes its start (e.g. `5-4/100`).
    let max_offset_limit = if total_count == 0 || offset >= total_count || limit == 0 {
        offset
    } else {
        offset
            .saturating_add(limit)
            .saturating_sub(1)
            .min(total_count.saturating_sub(1))
    };

    let safe_name = sanitize_resource_name(resource_name);
    let unit = if safe_name.is_empty() {
        "items"
    } else {
        safe_name.as_str()
    };
    let content_range = format!("{unit} {offset}-{max_offset_limit}/{total_count}");

    let mut headers = HeaderMap::new();
    // The unit is printable ASCII after sanitization, and the rest is decimal digits,
    // '-', '/' and a space. All of these are valid header-value bytes, so parsing cannot fail.
    headers.insert(
        "Content-Range",
        content_range
            .parse()
            .expect("sanitized ASCII unit and decimal range always form a valid header value"),
    );
    headers
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the `Content-Range` value as text, failing the test if it is absent.
    fn content_range_value(headers: &HeaderMap) -> &str {
        headers
            .get("Content-Range")
            .expect("Should return a valid Content-Range header")
            .to_str()
            .expect("Content-Range value should be visible ASCII")
    }

    #[test]
    fn test_content_range_normal() {
        let headers = calculate_content_range(0, 10, 100, "users");
        assert_eq!(content_range_value(&headers), "users 0-9/100");
    }

    #[test]
    fn test_content_range_handles_special_chars_gracefully() {
        let headers = calculate_content_range(0, 10, 100, "users\r\nInjected: evil");
        let value = content_range_value(&headers);
        assert!(!value.contains('\r'), "Should remove carriage returns");
        assert!(!value.contains('\n'), "Should remove newlines");
        // Sanitizing the unit must leave the numeric range untouched.
        assert!(value.ends_with(" 0-9/100"), "Range should be intact in: {value}");
    }

    #[test]
    fn test_content_range_unicode() {
        let headers = calculate_content_range(0, 10, 100, "用户");
        let value = content_range_value(&headers);
        assert!(value.ends_with("0-9/100"), "Range should be intact in: {value}");
    }

    #[test]
    fn test_content_range_offset_exceeds_total() {
        let headers = calculate_content_range(100, 10, 50, "items");
        let value = content_range_value(&headers);
        let (_, range_and_total) = value
            .rsplit_once(' ')
            .expect("unit and range should be separated by a space");
        let (range_part, _) = range_and_total
            .split_once('/')
            .expect("range and total should be separated by '/'");
        let (start, end) = range_part
            .split_once('-')
            .expect("range bounds should be separated by '-'");
        let start: u64 = start.parse().expect("range start should be numeric");
        let end: u64 = end.parse().expect("range end should be numeric");
        assert!(
            start <= end,
            "Range start ({start}) should not exceed end ({end}) in: {value}"
        );
    }

    #[test]
    fn test_content_range_zero_items() {
        let headers = calculate_content_range(0, 10, 0, "users");
        assert_eq!(content_range_value(&headers), "users 0-0/0");
    }

    #[test]
    fn test_content_range_large_numbers() {
        let headers = calculate_content_range(u64::MAX - 100, 10, u64::MAX, "users");
        let expected = format!("users {}-{}/{}", u64::MAX - 100, u64::MAX - 91, u64::MAX);
        assert_eq!(content_range_value(&headers), expected);
    }
}