use super::loader::LocaleStore;
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tower::{Layer, Service};
/// Per-request internationalisation state.
///
/// A fresh instance is built for every request by `I18nService::call` and made
/// available to that request's future through the `CURRENT_I18N` task-local.
pub struct I18nContext {
    /// Locale chosen for this request by `detect_locale` (query `lang`, then
    /// `Accept-Language`, then the `APP_LOCALE` env var, then the default).
    pub locale: String,
    /// The layer's configured default locale.
    /// NOTE(review): presumably used as a fallback when `locale` has no
    /// translations in `store` — confirm against the lookup code.
    pub default_locale: String,
    /// Shared translation store; a cheap handle cloned with `Arc::clone` per request.
    pub store: LocaleStore,
}
tokio::task_local! {
    /// I18n context for the request currently being handled.
    ///
    /// Installed by `I18nService::call` via `CURRENT_I18N.scope(..)`, so it is only
    /// set while that request's future is being polled. Outside such a scope,
    /// tokio's `with` panics and `try_with` returns an error.
    pub static CURRENT_I18N: Arc<I18nContext>;
}
/// Tower [`Layer`] that wraps a service in [`I18nService`], giving every request
/// an [`I18nContext`] in the `CURRENT_I18N` task-local.
#[derive(Clone)]
pub struct I18nLayer {
    /// Translation store shared (via `Arc::clone`) with every service this layer builds.
    store: LocaleStore,
    /// Configured fallback locale. Note that `detect_locale` prefers the
    /// `APP_LOCALE` environment variable over this value when it is set.
    default_locale: String,
}
impl I18nLayer {
    /// Creates a layer around an already-loaded translation store.
    ///
    /// `default_locale` is the fallback handed to every request's locale detection.
    pub fn new(store: LocaleStore, default_locale: impl Into<String>) -> Self {
        let default_locale: String = default_locale.into();
        Self {
            store,
            default_locale,
        }
    }

    /// Loads translations from `dir` (via `loader::load_dir`) and builds a layer.
    ///
    /// The default locale is also passed to the loader.
    pub fn with_dir(dir: impl AsRef<std::path::Path>, default_locale: impl Into<String>) -> Self {
        let default_locale: String = default_locale.into();
        let loaded = super::loader::load_dir(dir, &default_locale);
        Self::new(loaded, default_locale)
    }
}
impl<S: Clone> Layer<S> for I18nLayer {
    type Service = I18nService<S>;

    /// Wraps `inner`. The new service shares this layer's store (refcount bump
    /// only) and gets its own copy of the default locale.
    fn layer(&self, inner: S) -> Self::Service {
        let shared_store = Arc::clone(&self.store);
        let fallback = self.default_locale.clone();
        I18nService {
            inner,
            store: shared_store,
            default_locale: fallback,
        }
    }
}
/// Service produced by [`I18nLayer`]: detects each request's locale and runs the
/// inner service's future with an [`I18nContext`] installed in `CURRENT_I18N`.
#[derive(Clone)]
pub struct I18nService<S> {
    /// The wrapped service.
    inner: S,
    /// Shared translation store, copied (via `Arc::clone`) into each request's context.
    store: LocaleStore,
    /// Fallback locale passed to `detect_locale` and copied into each context.
    default_locale: String,
}
impl<S, ReqBody> Service<http::Request<ReqBody>> for I18nService<S>
where
    S: Service<http::Request<ReqBody>> + Clone + Send + 'static,
    S::Future: Send + 'static,
    ReqBody: Send + 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

    /// Readiness is delegated entirely to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Resolves the request's locale, then runs the inner service's future inside
    /// a `CURRENT_I18N` scope so downstream code can read the [`I18nContext`].
    fn call(&mut self, req: http::Request<ReqBody>) -> Self::Future {
        let query = req.uri().query();
        let accept_language = req
            .headers()
            .get("accept-language")
            .and_then(|value| value.to_str().ok());
        let detected = detect_locale(query, accept_language, &self.default_locale);

        let context = Arc::new(I18nContext {
            locale: detected,
            default_locale: self.default_locale.clone(),
            store: Arc::clone(&self.store),
        });

        // Only the returned future is polled inside the task-local scope; any
        // work the inner service performs synchronously in `call` itself runs
        // before the scope is entered and cannot observe CURRENT_I18N.
        let response_future = self.inner.call(req);
        Box::pin(CURRENT_I18N.scope(context, response_future))
    }
}
/// Resolves the locale to use for a request.
///
/// Sources are consulted in priority order:
/// 1. the first usable `lang` query parameter (e.g. `?lang=pt-BR` → `"pt"`);
/// 2. the most preferred usable `Accept-Language` entry, honouring `q` weights
///    (RFC 9110 §12.5.4). `*` wildcards and `q=0` entries are skipped, and among
///    equal weights the earliest entry wins;
/// 3. the `APP_LOCALE` environment variable, trimmed, if set and non-blank;
/// 4. `default`.
///
/// For sources 1 and 2, only the primary language subtag is kept, lowercased.
/// Candidates whose primary subtag is not 1–8 ASCII letters (the BCP 47 shape)
/// are ignored, so arbitrary client-supplied text (e.g. `?lang=../../etc`)
/// never becomes the locale.
///
/// `APP_LOCALE` is read on every call and, when set, takes precedence over
/// `default`. Its value is returned as-is apart from trimming.
fn detect_locale(query: Option<&str>, accept_lang: Option<&str>, default: &str) -> String {
    query
        .and_then(locale_from_query)
        .or_else(|| accept_lang.and_then(locale_from_accept_language))
        .or_else(locale_from_env)
        .unwrap_or_else(|| default.to_string())
}

/// Returns the primary subtag of the first `lang=<tag>` parameter that is well-formed.
///
/// Values are not percent-decoded; encoded values therefore fail validation and
/// are skipped rather than returned verbatim.
fn locale_from_query(query: &str) -> Option<String> {
    query
        .split('&')
        .filter_map(|pair| pair.split_once('='))
        .filter(|(key, _)| *key == "lang")
        .find_map(|(_, value)| primary_language_subtag(value))
}

/// Returns the primary subtag of the highest-weighted usable `Accept-Language` entry.
fn locale_from_accept_language(header: &str) -> Option<String> {
    let mut best: Option<(f32, String)> = None;
    for entry in header.split(',') {
        let mut parts = entry.split(';');
        let tag = parts.next().unwrap_or("").trim();
        // `*` means "any language" and carries no concrete locale.
        if tag == "*" {
            continue;
        }
        // Missing `q` means 1.0. A malformed or non-finite `q` is treated as 0,
        // i.e. "not acceptable".
        let weight = parts
            .filter_map(|param| param.split_once('='))
            .find(|(name, _)| name.trim().eq_ignore_ascii_case("q"))
            .map_or(1.0, |(_, value)| {
                value
                    .trim()
                    .parse::<f32>()
                    .ok()
                    .filter(|q| q.is_finite())
                    .map_or(0.0, |q| q.min(1.0))
            });
        if weight <= 0.0 {
            continue;
        }
        if let Some(primary) = primary_language_subtag(tag) {
            // Strict `>` keeps the earliest entry among equal weights.
            let is_better = match &best {
                Some((best_weight, _)) => weight > *best_weight,
                None => true,
            };
            if is_better {
                best = Some((weight, primary));
            }
        }
    }
    best.map(|(_, primary)| primary)
}

/// Reads `APP_LOCALE`, treating an unset, non-UTF-8, or blank value as absent.
fn locale_from_env() -> Option<String> {
    std::env::var("APP_LOCALE")
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
}

/// Extracts and lowercases the primary subtag of a language tag (`en-US`/`en_US` → `en`).
///
/// Returns `None` unless the subtag is 1–8 ASCII letters.
fn primary_language_subtag(tag: &str) -> Option<String> {
    let primary = tag
        .trim()
        .split(|c: char| c == '-' || c == '_')
        .next()
        .unwrap_or("");
    let well_formed =
        (1..=8).contains(&primary.len()) && primary.bytes().all(|b| b.is_ascii_alphabetic());
    if well_formed {
        Some(primary.to_ascii_lowercase())
    } else {
        None
    }
}