use async_trait::async_trait;
use hyper::header::{ACCEPT_LANGUAGE, COOKIE};
use reinhardt_http::{Handler, Middleware, Request, Response, Result};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// Request header that `LocaleMiddleware` sets to the detected locale before
/// forwarding the request to the next handler.
pub const LOCALE_HEADER: &str = "X-Locale";
/// Default cookie name consulted for a stored locale preference; used by
/// `LocaleConfig::new` and `LocaleConfig::with_locales`.
pub const LOCALE_COOKIE_NAME: &str = "django_language";
/// Configuration for `LocaleMiddleware`: which locales may be selected and
/// where the middleware looks for them.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocaleConfig {
/// Locale used when neither the URL path, the cookie, nor `Accept-Language`
/// yields a supported locale.
pub default_locale: String,
/// Locales the middleware may select; candidates from the URL path, cookie
/// and `Accept-Language` header are only accepted if they match an entry here.
pub supported_locales: Vec<String>,
/// When `true`, the first URL path segment (e.g. `/ja/...`) is checked for a
/// supported locale before any other source.
pub check_url_path: bool,
/// Name of the cookie that may carry a stored locale preference.
pub cookie_name: String,
}
impl LocaleConfig {
pub fn new() -> Self {
Self {
default_locale: "en".to_string(),
supported_locales: vec!["en".to_string()],
check_url_path: false,
cookie_name: LOCALE_COOKIE_NAME.to_string(),
}
}
pub fn with_locales(default: String, supported: Vec<String>) -> Self {
Self {
default_locale: default,
supported_locales: supported,
check_url_path: false,
cookie_name: LOCALE_COOKIE_NAME.to_string(),
}
}
}
impl Default for LocaleConfig {
fn default() -> Self {
Self::new()
}
}
/// Middleware that determines the locale for each request and passes it to
/// the next handler in the `X-Locale` request header.
///
/// Sources are tried in order: first URL path segment (only when
/// `check_url_path` is enabled), the locale cookie, the `Accept-Language`
/// header, and finally `default_locale`.
pub struct LocaleMiddleware {
config: LocaleConfig,
}
impl LocaleMiddleware {
    /// Creates a middleware using [`LocaleConfig::default`].
    pub fn new() -> Self {
        Self {
            config: LocaleConfig::default(),
        }
    }

    /// Creates a middleware with an explicit configuration.
    pub fn with_config(config: LocaleConfig) -> Self {
        Self { config }
    }

    /// Returns the configured entry that equals `candidate` exactly
    /// (case-sensitive), if any.
    fn supported_exact(&self, candidate: &str) -> Option<String> {
        self.config
            .supported_locales
            .iter()
            .find(|locale| locale.as_str() == candidate)
            .cloned()
    }

    /// Returns the configured entry that matches `tag` ignoring ASCII case.
    /// Language tags are case-insensitive, so `en-US` matches a configured
    /// `en-us`; the configured spelling is returned.
    fn supported_ignore_case(&self, tag: &str) -> Option<String> {
        self.config
            .supported_locales
            .iter()
            .find(|locale| locale.eq_ignore_ascii_case(tag))
            .cloned()
    }

    /// Maps one language tag to a supported locale: the full tag first
    /// (so a configured `pt-br` can match `pt-BR`), then its primary subtag
    /// (so `ja-JP` falls back to `ja`).
    fn match_language_tag(&self, tag: &str) -> Option<String> {
        self.supported_ignore_case(tag).or_else(|| {
            tag.split_once('-')
                .and_then(|(primary, _)| self.supported_ignore_case(primary))
        })
    }

    /// Parses one `language-range *( OWS ";" OWS parameter )` entry of an
    /// `Accept-Language` header into `(tag, weight)`.
    ///
    /// Returns `None` for empty entries, for a `q` value that is not a number
    /// in `0.0..=1.0` (this also rejects `NaN`/`inf`, which `f32::from_str`
    /// accepts), and for `q=0`, which means "not acceptable".
    fn parse_language_range(spec: &str) -> Option<(&str, f32)> {
        let mut parts = spec.split(';');
        let tag = parts.next()?.trim();
        if tag.is_empty() {
            return None;
        }
        let mut quality = 1.0_f32;
        for param in parts {
            if let Some((key, value)) = param.split_once('=')
                && key.trim().eq_ignore_ascii_case("q")
            {
                quality = value
                    .trim()
                    .parse::<f32>()
                    .ok()
                    .filter(|q| (0.0..=1.0).contains(q))?;
            }
        }
        (quality > 0.0).then_some((tag, quality))
    }

    /// Returns the first URL path segment if URL-path detection is enabled
    /// and that segment is exactly one of the supported locales.
    fn locale_from_path(&self, path: &str) -> Option<String> {
        if !self.config.check_url_path {
            return None;
        }
        // `split` always yields at least one item; this is the first segment.
        let first_segment = path.trim_start_matches('/').split('/').next()?;
        self.supported_exact(first_segment)
    }

    /// Returns the value of the configured locale cookie if it is exactly one
    /// of the supported locales. All `Cookie` header fields are scanned
    /// (HTTP/2 may split cookies across several fields), and a value wrapped in
    /// double quotes (allowed by RFC 6265) is unquoted first.
    fn locale_from_cookie(&self, request: &Request) -> Option<String> {
        request
            .headers
            .get_all(COOKIE)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|header| header.split(';'))
            .filter_map(|pair| pair.trim().split_once('='))
            .filter(|(name, _)| *name == self.config.cookie_name)
            .find_map(|(_, value)| {
                let value = value
                    .strip_prefix('"')
                    .and_then(|inner| inner.strip_suffix('"'))
                    .unwrap_or(value);
                self.supported_exact(value)
            })
    }

    /// Picks the highest-weighted acceptable language from all
    /// `Accept-Language` header fields that maps to a supported locale.
    fn locale_from_accept_language(&self, request: &Request) -> Option<String> {
        let mut ranges: Vec<(&str, f32)> = request
            .headers
            .get_all(ACCEPT_LANGUAGE)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .flat_map(|header| header.split(','))
            .filter_map(Self::parse_language_range)
            .collect();
        // Weights are validated to lie in (0, 1], and `total_cmp` is a total
        // order, so the sort cannot misbehave on hostile input. The sort is
        // stable, which keeps header order among entries with equal weight.
        ranges.sort_by(|a, b| b.1.total_cmp(&a.1));
        ranges
            .into_iter()
            .find_map(|(tag, _)| self.match_language_tag(tag))
    }

    /// Resolves the request locale: URL path, then cookie, then
    /// `Accept-Language`, falling back to `default_locale`.
    fn detect_locale(&self, request: &Request) -> String {
        self.locale_from_path(request.uri.path())
            .or_else(|| self.locale_from_cookie(request))
            .or_else(|| self.locale_from_accept_language(request))
            .unwrap_or_else(|| self.config.default_locale.clone())
    }
}
impl Default for LocaleMiddleware {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Middleware for LocaleMiddleware {
    /// Detects the locale for `request` and stores it in the [`LOCALE_HEADER`]
    /// request header, replacing any value the client sent. It then forwards
    /// the request to `handler`.
    ///
    /// The detected locale always comes from the configuration. If it is not a
    /// valid header value, the configured `default_locale` is used instead,
    /// and `"en"` only as a last resort.
    async fn process(&self, mut request: Request, handler: Arc<dyn Handler>) -> Result<Response> {
        let locale = self.detect_locale(&request);
        let value = hyper::header::HeaderValue::from_str(&locale)
            .or_else(|_| hyper::header::HeaderValue::from_str(&self.config.default_locale))
            .unwrap_or_else(|_| hyper::header::HeaderValue::from_static("en"));
        request.headers.insert(LOCALE_HEADER, value);
        handler.handle(request).await
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::Bytes;
    use hyper::header::HeaderName;
    use hyper::{HeaderMap, Method, StatusCode, Version};

    /// Handler that echoes the locale header set by the middleware (or
    /// `"unknown"` when absent) as the response body.
    struct TestHandler;

    #[async_trait]
    impl Handler for TestHandler {
        async fn handle(&self, request: Request) -> Result<Response> {
            let locale = match request
                .headers
                .get(LOCALE_HEADER)
                .and_then(|value| value.to_str().ok())
            {
                Some(value) => value.to_string(),
                None => "unknown".to_string(),
            };
            Ok(Response::new(StatusCode::OK).with_body(Bytes::from(locale)))
        }
    }

    /// Config with default `"en"` that supports exactly `supported`.
    fn config_for(supported: &[&str]) -> LocaleConfig {
        LocaleConfig::with_locales(
            "en".to_string(),
            supported.iter().map(|locale| String::from(*locale)).collect(),
        )
    }

    /// Builds a header map from `(name, value)` pairs.
    fn header_map(pairs: &[(HeaderName, &'static str)]) -> HeaderMap {
        let mut headers = HeaderMap::new();
        for (name, value) in pairs {
            headers.insert(name.clone(), value.parse().unwrap());
        }
        headers
    }

    /// Sends a GET request for `uri` through the middleware and returns the
    /// locale observed by the downstream handler.
    async fn observed_locale(
        config: LocaleConfig,
        uri: &'static str,
        headers: HeaderMap,
    ) -> String {
        let middleware = LocaleMiddleware::with_config(config);
        let handler: Arc<dyn Handler> = Arc::new(TestHandler);
        let request = Request::builder()
            .method(Method::GET)
            .uri(uri)
            .version(Version::HTTP_11)
            .headers(headers)
            .body(Bytes::new())
            .build()
            .unwrap();
        let response = middleware.process(request, handler).await.unwrap();
        String::from_utf8(response.body.to_vec()).unwrap()
    }

    #[tokio::test]
    async fn test_default_locale() {
        let locale = observed_locale(LocaleConfig::new(), "/page", HeaderMap::new()).await;
        assert_eq!(locale, "en");
    }

    #[tokio::test]
    async fn test_accept_language_detection() {
        let headers = header_map(&[(ACCEPT_LANGUAGE, "ja,en;q=0.9")]);
        let locale = observed_locale(config_for(&["en", "ja", "fr"]), "/page", headers).await;
        assert_eq!(locale, "ja");
    }

    #[tokio::test]
    async fn test_accept_language_with_quality() {
        let headers = header_map(&[(ACCEPT_LANGUAGE, "fr;q=0.7,ja;q=0.9,en;q=0.8")]);
        let locale = observed_locale(config_for(&["en", "ja", "fr"]), "/page", headers).await;
        assert_eq!(locale, "ja");
    }

    #[tokio::test]
    async fn test_cookie_detection() {
        let headers = header_map(&[(COOKIE, "django_language=fr; other=value")]);
        let locale = observed_locale(config_for(&["en", "ja", "fr"]), "/page", headers).await;
        assert_eq!(locale, "fr");
    }

    #[tokio::test]
    async fn test_cookie_overrides_accept_language() {
        let headers = header_map(&[(ACCEPT_LANGUAGE, "ja"), (COOKIE, "django_language=fr")]);
        let locale = observed_locale(config_for(&["en", "ja", "fr"]), "/page", headers).await;
        assert_eq!(locale, "fr");
    }

    #[tokio::test]
    async fn test_url_path_detection() {
        let mut config = config_for(&["en", "ja", "fr"]);
        config.check_url_path = true;
        let locale = observed_locale(config, "/ja/page/subpage", HeaderMap::new()).await;
        assert_eq!(locale, "ja");
    }

    #[tokio::test]
    async fn test_url_path_overrides_all() {
        let mut config = config_for(&["en", "ja", "fr"]);
        config.check_url_path = true;
        let headers = header_map(&[(ACCEPT_LANGUAGE, "ja"), (COOKIE, "django_language=fr")]);
        let locale = observed_locale(config, "/en/page", headers).await;
        assert_eq!(locale, "en");
    }

    #[tokio::test]
    async fn test_unsupported_locale_fallback() {
        let headers = header_map(&[(ACCEPT_LANGUAGE, "de,fr;q=0.9")]);
        let locale = observed_locale(config_for(&["en", "ja"]), "/page", headers).await;
        assert_eq!(locale, "en");
    }

    #[tokio::test]
    async fn test_accept_language_with_region() {
        let headers = header_map(&[(ACCEPT_LANGUAGE, "ja-JP,en-US;q=0.9")]);
        let locale = observed_locale(config_for(&["en", "ja"]), "/page", headers).await;
        assert_eq!(locale, "ja");
    }

    #[tokio::test]
    async fn test_invalid_cookie_value() {
        let headers = header_map(&[(COOKIE, "django_language=invalid"), (ACCEPT_LANGUAGE, "ja")]);
        let locale = observed_locale(config_for(&["en", "ja"]), "/page", headers).await;
        assert_eq!(locale, "ja");
    }
}