use std::{
cmp::Ordering,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use axum::{
body::Body,
http::{
HeaderMap, Request,
header::{ACCEPT_LANGUAGE, CONTENT_LANGUAGE, HeaderValue, VARY},
},
response::Response,
};
use tower::{Layer, Service};
/// Heap-allocated, type-erased response future used as the
/// [`Service::Future`] of [`LocaleMiddleware`]; `E` is the inner service's error.
type BoxResponseFuture<E> = Pin<Box<dyn Future<Output = Result<Response, E>> + Send + 'static>>;
/// [`Layer`] that wraps a service in [`LocaleMiddleware`].
///
/// `default_language` is the fallback used when neither the URL path nor the
/// `Accept-Language` request header yields a language.
#[derive(Clone, Debug)]
pub struct LocaleMiddlewareLayer {
    /// Fallback language; it is lower-cased before being activated.
    pub default_language: String,
}

impl Default for LocaleMiddlewareLayer {
    /// Uses `"en"` as the fallback language.
    fn default() -> Self {
        Self {
            default_language: "en".to_string(),
        }
    }
}
impl<S> Layer<S> for LocaleMiddlewareLayer {
    type Service = LocaleMiddleware<S>;

    /// Wraps `inner`, giving the new middleware its own copy of the fallback language.
    fn layer(&self, inner: S) -> Self::Service {
        let default_language = self.default_language.clone();
        LocaleMiddleware {
            default_language,
            inner,
        }
    }
}
/// Tower service that resolves a language for each request, activates it via
/// `crate::i18n` around the wrapped service, and tags the response with
/// `Content-Language` and `Vary: Accept-Language`.
///
/// Built by [`LocaleMiddlewareLayer`].
#[derive(Clone, Debug)]
pub struct LocaleMiddleware<S> {
    /// The wrapped service.
    inner: S,
    /// Language used when neither the path nor `Accept-Language` yields one.
    default_language: String,
}
/// Drop guard that calls `crate::i18n::deactivate()` when it goes out of scope.
///
/// Because the call happens in `Drop`, the language is reset on early returns
/// and during unwinding as well as on the normal path. Bind it to a named
/// local (`let _guard = LocaleGuard;`). A bare `LocaleGuard;` statement would
/// deactivate immediately.
#[must_use = "the language is deactivated as soon as the guard is dropped"]
struct LocaleGuard;

impl Drop for LocaleGuard {
    fn drop(&mut self) {
        crate::i18n::deactivate();
    }
}
/// Adds the field name `value` to the `Vary` header.
///
/// Nothing is added if `value` is already listed (compared ASCII
/// case-insensitively). Nothing is added if any entry is `*`, which already
/// covers every request field and must stand alone.
///
/// Entries from *every* `Vary` line are merged into a single comma-separated
/// value. Lines that are not visible ASCII cannot be read as text and are
/// dropped.
///
/// # Panics
///
/// Panics if `value` is not a valid header value. The only caller in this
/// file passes the literal `"Accept-Language"`.
fn append_vary(headers: &mut HeaderMap, value: &str) {
    let mut entries: Vec<&str> = Vec::new();
    // `get_all` rather than `get`: a response may carry several `Vary` lines,
    // and the `insert` below replaces all of them.
    for header in headers.get_all(VARY) {
        let Ok(text) = header.to_str() else {
            continue;
        };
        for entry in text.split(',').map(str::trim).filter(|entry| !entry.is_empty()) {
            if entry == "*" || entry.eq_ignore_ascii_case(value) {
                return;
            }
            entries.push(entry);
        }
    }
    entries.push(value);
    let merged =
        HeaderValue::from_str(&entries.join(", ")).expect("vary header value should be valid");
    headers.insert(VARY, merged);
}
/// Returns the lower-cased first path segment if it is exactly two ASCII
/// letters (e.g. `/es/about` gives `es`), otherwise `None`.
///
/// Leading slashes are skipped. Any two-letter segment qualifies, so a route
/// such as `/go` is also read as a language.
#[must_use]
fn detect_language_from_path(path: &str) -> Option<String> {
    let first_segment = path.trim_start_matches('/').split('/').next()?;
    let is_two_letter_code =
        first_segment.len() == 2 && first_segment.bytes().all(|byte| byte.is_ascii_alphabetic());
    is_two_letter_code.then(|| first_segment.to_ascii_lowercase())
}
/// Parses an `Accept-Language` weight parameter such as `q=0.8`.
///
/// The parameter name matches in either case (`q=` or `Q=`). Finite and
/// infinite values are clamped to the valid `0.0..=1.0` range. `NaN`,
/// unparsable numbers and any other parameter yield `None`.
#[must_use]
fn parse_quality(parameter: &str) -> Option<f32> {
    let parameter = parameter.trim();
    let value = parameter
        .strip_prefix("q=")
        .or_else(|| parameter.strip_prefix("Q="))?;
    let quality: f32 = value.parse().ok()?;
    if quality.is_nan() {
        None
    } else {
        Some(quality.clamp(0.0, 1.0))
    }
}

/// Whether `range` is a basic language range such as `en` or `pt-BR`.
///
/// Per RFC 4647, that is a 1-8 letter primary subtag followed by optional
/// `-`-separated 1-8 character alphanumeric subtags. The `*` wildcard and
/// the empty string are rejected.
#[must_use]
fn is_language_range(range: &str) -> bool {
    fn subtag_ok(subtag: &str, allowed: fn(&u8) -> bool) -> bool {
        (1..=8).contains(&subtag.len()) && subtag.bytes().all(|byte| allowed(&byte))
    }
    let mut subtags = range.split('-');
    subtags
        .next()
        .is_some_and(|primary| subtag_ok(primary, u8::is_ascii_alphabetic))
        && subtags.all(|subtag| subtag_ok(subtag, u8::is_ascii_alphanumeric))
}

/// Picks the preferred language from an `Accept-Language` header value.
///
/// Entries are ranked by their `q` weight, which defaults to `1.0`. Among
/// equal weights, the entry listed first wins.
///
/// Some entries are ignored, so arbitrary client text is never returned:
/// - entries with `q=0`, which means "not acceptable";
/// - the `*` wildcard;
/// - anything that is not a syntactically valid language range.
///
/// The result is lower-cased. Returns `None` when no entry qualifies.
#[must_use]
fn detect_language_from_header(header: &str) -> Option<String> {
    header
        .split(',')
        .filter_map(|entry| {
            let mut pieces = entry.split(';');
            let language = pieces.next()?.trim();
            if !is_language_range(language) {
                return None;
            }
            let quality = pieces.find_map(parse_quality).unwrap_or(1.0_f32);
            (quality > 0.0).then_some((language, quality))
        })
        .reduce(|best, candidate| match candidate.1.partial_cmp(&best.1) {
            // Only a strictly higher weight displaces the current pick, so
            // ties go to the entry listed first (`max_by` would keep the last).
            Some(Ordering::Greater) => candidate,
            _ => best,
        })
        .map(|(language, _)| language.to_ascii_lowercase())
}
/// Resolves the language to activate for `request`.
///
/// Sources are tried in this order:
/// 1. the URL path prefix (`detect_language_from_path`);
/// 2. the `Accept-Language` header (`detect_language_from_header`), if it is
///    valid visible ASCII;
/// 3. `default_language`, lower-cased.
#[must_use]
fn detect_language(request: &Request<Body>, default_language: &str) -> String {
    if let Some(language) = detect_language_from_path(request.uri().path()) {
        return language;
    }
    let from_header = request
        .headers()
        .get(ACCEPT_LANGUAGE)
        .and_then(|value| value.to_str().ok())
        .and_then(detect_language_from_header);
    match from_header {
        Some(language) => language,
        None => default_language.to_ascii_lowercase(),
    }
}
impl<S> Service<Request<Body>> for LocaleMiddleware<S>
where
S: Service<Request<Body>, Response = Response> + Send + 'static,
S::Future: Send + 'static,
{
type Response = Response;
type Error = S::Error;
type Future = BoxResponseFuture<Self::Error>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, request: Request<Body>) -> Self::Future {
let detected_language = detect_language(&request, &self.default_language);
crate::i18n::activate(&detected_language);
let guard = LocaleGuard;
let future = self.inner.call(request);
Box::pin(async move {
let _guard = guard;
let mut response = future.await?;
response.headers_mut().insert(
CONTENT_LANGUAGE,
HeaderValue::from_str(&detected_language)
.expect("content-language header should be valid"),
);
append_vary(response.headers_mut(), "Accept-Language");
Ok(response)
})
}
}
#[cfg(test)]
mod tests {
    use std::convert::Infallible;

    use axum::http::{Request, StatusCode, header};
    use tower::{ServiceExt, service_fn};

    use super::*;

    /// Inner service that reports, through an `x-active-language` response
    /// header, which language is active while it runs.
    fn build_service() -> impl Service<
        Request<Body>,
        Response = Response,
        Error = Infallible,
        Future = impl Future<Output = Result<Response, Infallible>> + Send,
    > + Clone {
        service_fn(|_request: Request<Body>| async move {
            let current_language = crate::i18n::get_language();
            let response = Response::builder()
                .status(StatusCode::OK)
                .header("x-active-language", current_language)
                .body(Body::from("ok"))
                .expect("response should build");
            Ok::<_, Infallible>(response)
        })
    }

    /// Builds a layer with the given fallback language.
    fn layer_with_default(default_language: &str) -> LocaleMiddlewareLayer {
        LocaleMiddlewareLayer {
            default_language: default_language.to_string(),
        }
    }

    /// Builds an empty-bodied request for `uri`, optionally carrying an
    /// `Accept-Language` header.
    fn build_request(uri: &str, accept_language: Option<&str>) -> Request<Body> {
        let builder = Request::builder().uri(uri);
        let builder = match accept_language {
            Some(value) => builder.header(header::ACCEPT_LANGUAGE, value),
            None => builder,
        };
        builder.body(Body::empty()).expect("request should build")
    }

    /// Clears any active language, then drives `request` through `layer`
    /// wrapped around [`build_service`].
    async fn respond(layer: LocaleMiddlewareLayer, request: Request<Body>) -> Response {
        crate::i18n::deactivate();
        layer
            .layer(build_service())
            .oneshot(request)
            .await
            .expect("service should respond")
    }

    /// Reads the language reported by [`build_service`].
    fn active_language(response: &Response) -> &HeaderValue {
        response
            .headers()
            .get("x-active-language")
            .expect("active language header should be present")
    }

    #[tokio::test]
    async fn test_locale_from_accept_language() {
        let request = build_request("/articles", Some("fr;q=0.8,en-US;q=0.9"));
        let response = respond(layer_with_default("en"), request).await;
        assert_eq!(active_language(&response), "en-us");
    }

    #[tokio::test]
    async fn test_locale_fallback_to_default() {
        let request = build_request("/articles", None);
        let response = respond(layer_with_default("de"), request).await;
        assert_eq!(active_language(&response), "de");
    }

    #[tokio::test]
    async fn test_locale_sets_content_language() {
        let request = build_request("/es/about", None);
        let response = respond(layer_with_default("en"), request).await;
        let content_language = response
            .headers()
            .get(CONTENT_LANGUAGE)
            .expect("content-language header should be present");
        assert_eq!(content_language, "es");
    }

    #[tokio::test]
    async fn test_locale_vary_header_added() {
        let request = build_request("/articles", None);
        let response = respond(LocaleMiddlewareLayer::default(), request).await;
        let vary = response
            .headers()
            .get(VARY)
            .expect("vary header should be present");
        assert_eq!(vary, "Accept-Language");
    }

    #[tokio::test]
    async fn test_locale_url_prefix_takes_priority() {
        let request = build_request("/de/about", Some("fr;q=1.0"));
        let response = respond(layer_with_default("en"), request).await;
        assert_eq!(active_language(&response), "de");
    }

    #[tokio::test]
    async fn test_locale_deactivates_after_response() {
        let request = build_request("/articles", Some("fr"));
        let _ = respond(layer_with_default("fr"), request).await;
        assert_eq!(crate::i18n::get_language(), "en-us");
    }
}