use crate::{types, Error};
/// Request and correlation identifiers together with a credentials token.
///
/// With the `axum` or `warp` feature enabled, the HTTP extractors in this module build it
/// from the `x-request-id`, `x-correlation-id` and `authorization` request headers.
#[derive(Clone, Debug)]
pub struct Context {
    /// Request identifier (the `x-request-id` header value when extracted from HTTP).
    pub request_id: String,
    /// Correlation identifier (the `x-correlation-id` header value when extracted from HTTP).
    pub correlation_id: String,
    /// Credentials token. When extracted from HTTP, this is the base64-decoded payload of an
    /// `authorization: Token <base64>` header.
    pub credentials_token: types::CredentialsToken,
}
/// A crate-level result bundled with the [`Context`] it was produced under.
///
/// Normally built through [`IntoContextResult::with_context`].
pub struct ContextResult<T> {
    /// The outcome, with any error already converted into the crate [`Error`].
    pub result: crate::Result<T>,
    /// The context attached by `with_context`.
    pub context: Context,
}
/// Conversion into a [`ContextResult`] by attaching a [`Context`].
///
/// An implementation is provided for every `Result<T, E>` whose error type converts into the
/// crate [`Error`].
pub trait IntoContextResult<T> {
    /// Consumes `self` and returns it paired with `context`.
    fn with_context(self, context: Context) -> ContextResult<T>;
}
impl<T, E> IntoContextResult<T> for Result<T, E>
where
Error: From<E>,
{
fn with_context(self, context: Context) -> ContextResult<T> {
ContextResult {
result: self.map_err(Error::from),
context,
}
}
}
/// A cheap, copyable, borrowed view of a [`Context`]'s identifiers paired with a credentials
/// token.
///
/// The token may be the context's own (see `Context::as_extended`) or a different one
/// supplied by the caller (see `Context::extend`).
// `Debug` mirrors the derive on `Context`; `CredentialsToken: Debug` is already required there.
#[derive(Clone, Copy, Debug)]
pub struct ExtendedContext<'a> {
    /// Borrowed request identifier.
    pub request_id: &'a str,
    /// Borrowed correlation identifier.
    pub correlation_id: &'a str,
    /// Borrowed credentials token.
    pub credentials_token: &'a types::CredentialsToken,
}
impl Context {
    /// Borrows this context's request and correlation identifiers, pairing them with the
    /// given `token` instead of the context's own credentials token.
    pub fn extend<'a>(&'a self, token: &'a types::CredentialsToken) -> ExtendedContext<'a> {
        let Self {
            request_id,
            correlation_id,
            ..
        } = self;
        ExtendedContext {
            request_id: request_id.as_str(),
            correlation_id: correlation_id.as_str(),
            credentials_token: token,
        }
    }

    /// Borrows this context as an [`ExtendedContext`] that uses its own credentials token.
    pub fn as_extended(&self) -> ExtendedContext<'_> {
        self.extend(&self.credentials_token)
    }
}
#[cfg(feature = "warp")]
pub mod warp_extensions {
    use super::Context;
    use http::HeaderValue;
    use warp::{filters::header, reject::Rejection, Filter};

    /// Warp filter that extracts a [`Context`] from the `x-request-id`, `x-correlation-id`
    /// and `authorization` request headers.
    ///
    /// The headers are read as optional, so a missing header is reported by the shared
    /// extraction logic rather than by warp. Any extraction error becomes a custom rejection.
    pub fn context() -> impl Filter<Extract = (Context,), Error = Rejection> + Clone {
        let request_id = header::optional::<HeaderValue>("x-request-id");
        let correlation_id = header::optional::<HeaderValue>("x-correlation-id");
        let authorization = header::optional::<HeaderValue>("authorization");

        request_id
            .and(correlation_id)
            .and(authorization)
            .and_then(context_from_headers)
    }

    /// Adapts the shared header extraction to the shape warp's `and_then` expects.
    async fn context_from_headers(
        request_id: Option<HeaderValue>,
        correlation_id: Option<HeaderValue>,
        authorization: Option<HeaderValue>,
    ) -> Result<Context, Rejection> {
        let extracted = super::extensions::extract_context(
            request_id.as_ref(),
            correlation_id.as_ref(),
            authorization.as_ref(),
        );
        extracted.map_err(warp::reject::custom)
    }
}
#[cfg(feature = "axum")]
pub mod axum_extensions {
    use crate::{Context, Error};
    use axum::extract::{FromRequest, RequestParts};

    /// Lets handlers take a [`Context`] argument. It is built from the `x-request-id`,
    /// `x-correlation-id` and `authorization` headers, and extraction failures are returned
    /// as the crate [`Error`].
    #[async_trait::async_trait]
    impl<B> FromRequest<B> for Context
    where
        B: Send,
    {
        type Rejection = Error;

        async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
            let headers = req.headers();
            let request_id = headers.get("x-request-id");
            let correlation_id = headers.get("x-correlation-id");
            let authorization = headers.get("authorization");
            super::extensions::extract_context(request_id, correlation_id, authorization)
        }
    }
}
#[cfg(any(feature = "axum", feature = "warp"))]
mod extensions {
    use crate::{types, Context, Error};
    use http::header::HeaderValue;

    /// Scheme prefix, including the separating space, that must precede the base64-encoded
    /// credentials token in the `authorization` header.
    const TOKEN_PREFIX: &str = "Token ";

    /// Builds a [`Context`] from raw request header values (framework-agnostic).
    ///
    /// All three headers are required. `authorization` must be exactly `Token <base64>`,
    /// where the base64 payload decodes to UTF-8 text accepted by
    /// `types::CredentialsToken::try_from`.
    ///
    /// # Errors
    ///
    /// - `Error::client_generic` if any header is missing or cannot be read as a string
    ///   (`HeaderValue::to_str` fails).
    /// - `Error::unauthorized` if `authorization` does not start with `Token `, or its payload
    ///   is not valid base64, not valid UTF-8, or not a valid credentials token.
    pub fn extract_context(
        request_id_header: Option<&HeaderValue>,
        correlation_id_header: Option<&HeaderValue>,
        authorization_header: Option<&HeaderValue>,
    ) -> Result<Context, Error> {
        let request_id = extract_str(request_id_header, "x-request-id")?.into();
        let correlation_id = extract_str(correlation_id_header, "x-correlation-id")?.into();
        let authorization = extract_str(authorization_header, "authorization")?;

        // `strip_prefix` removes the scheme exactly once. `trim_start_matches` would strip
        // every leading repetition and wrongly accept e.g. `Token Token <base64>`.
        let encoded = authorization.strip_prefix(TOKEN_PREFIX).ok_or_else(|| {
            Error::unauthorized("Malformed Authorization header. Must start with `Token `")
        })?;

        let decoded = base64::decode(encoded)
            .map_err(|_| Error::unauthorized("Authorization header contained invalid base64"))?;
        let token = String::from_utf8(decoded).map_err(|_| {
            Error::unauthorized("Decoded Authorization token contained invalid UTF-8")
        })?;
        let credentials_token = types::CredentialsToken::try_from(token).map_err(|err| {
            Error::unauthorized(format!("Invalid CredentialsToken provided: {err}"))
        })?;

        Ok(Context {
            request_id,
            correlation_id,
            credentials_token,
        })
    }

    /// Returns the header value as `&str`. `header` is the header name and is only used
    /// in error messages.
    ///
    /// # Errors
    ///
    /// `Error::client_generic` if the header is absent or `HeaderValue::to_str` rejects it.
    fn extract_str<'a>(
        header_value: Option<&'a HeaderValue>,
        header: &str,
    ) -> Result<&'a str, Error> {
        let value = header_value
            .ok_or_else(|| Error::client_generic(format!("Missing header `{header}`")))?;
        value.to_str().map_err(|err| {
            Error::client_generic(format!("Header `{header}` contains non-ASCII: {err}"))
        })
    }
}
#[cfg(all(test, any(feature = "axum", feature = "warp")))]
mod tests {
    // No feature-specific gate on the tests: `extract_context` is compiled whenever either
    // the `axum` or the `warp` feature is enabled, so it is exercised under both.
    use super::extensions::extract_context;
    use http::HeaderValue;

    /// Builds an `authorization` header of the form `Token <base64(token)>`.
    fn token_header(token: &str) -> HeaderValue {
        HeaderValue::from_str(&format!("Token {}", base64::encode(token)))
            .expect("Authorization header")
    }

    #[test]
    fn test_extract_context() {
        let request_id = HeaderValue::from_static("123");
        let correlation_id = HeaderValue::from_static("456");
        let authorization = token_header("IMATOKEN");
        let ctx = extract_context(
            Some(&request_id),
            Some(&correlation_id),
            Some(&authorization),
        )
        .expect("Extracting context");
        assert_eq!(ctx.request_id, "123");
        assert_eq!(ctx.correlation_id, "456");
        assert_eq!(ctx.credentials_token, "IMATOKEN");
    }

    #[test]
    fn test_extract_context_missing_header() {
        let request_id = HeaderValue::from_static("123");
        let correlation_id = HeaderValue::from_static("456");
        let authorization = token_header("IMATOKEN");
        assert!(extract_context(None, Some(&correlation_id), Some(&authorization)).is_err());
        assert!(extract_context(Some(&request_id), None, Some(&authorization)).is_err());
        assert!(extract_context(Some(&request_id), Some(&correlation_id), None).is_err());
    }

    #[test]
    fn test_extract_context_rejects_wrong_scheme() {
        let request_id = HeaderValue::from_static("123");
        let correlation_id = HeaderValue::from_static("456");
        let authorization =
            HeaderValue::from_str(&format!("Bearer {}", base64::encode("IMATOKEN")))
                .expect("Authorization header");
        assert!(extract_context(
            Some(&request_id),
            Some(&correlation_id),
            Some(&authorization),
        )
        .is_err());
    }

    #[test]
    fn test_extract_context_rejects_invalid_base64() {
        let request_id = HeaderValue::from_static("123");
        let correlation_id = HeaderValue::from_static("456");
        let authorization = HeaderValue::from_static("Token not*base64");
        assert!(extract_context(
            Some(&request_id),
            Some(&correlation_id),
            Some(&authorization),
        )
        .is_err());
    }
}
/// Builds a [`Context`] for tests with fresh numeric request/correlation IDs and a fixed
/// credentials token.
///
/// IDs come from a process-wide counter starting at 1. Every ID value is handed out at
/// most once, but concurrent callers may interleave, so a context's two IDs are not
/// guaranteed to be consecutive.
#[cfg(test)]
pub fn test_ctx() -> Context {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // `AtomicUsize::new` is a `const fn`, so a plain static works; lazy initialisation
    // through `OnceCell` is unnecessary.
    static COUNTER: AtomicUsize = AtomicUsize::new(1);

    let req_id = COUNTER.fetch_add(1, Ordering::Relaxed);
    let corr_id = COUNTER.fetch_add(1, Ordering::Relaxed);
    Context {
        request_id: req_id.to_string(),
        correlation_id: corr_id.to_string(),
        credentials_token: "TOKENBYTESTCTX"
            .parse()
            .expect("static test token is a valid CredentialsToken"),
    }
}