ocpi 0.3.5

Unofficial, in-progress OCPI implementation
Documentation
use crate::{types, Error};

/// Owned per-request metadata: the two tracing identifiers plus the
/// credentials token the request was made with.
///
/// When built by the `warp`/`axum` extractors in this module, the fields are
/// taken from the `x-request-id`, `x-correlation-id` and `authorization`
/// headers respectively.
#[derive(Clone, Debug)]
pub struct Context {
    /// Identifier of this individual request.
    pub request_id: String,
    /// Identifier carried alongside the request id for correlating requests.
    pub correlation_id: String,
    /// Credentials token associated with the request.
    pub credentials_token: types::CredentialsToken,
}

/// Pairs the outcome of an operation with the [`Context`] it belongs to.
pub struct ContextResult<T> {
    /// Outcome of the operation, expressed with the crate's own `Result` alias.
    pub result: crate::Result<T>,
    /// The context that accompanies `result`.
    pub context: Context,
}

/// Conversion of a value into a [`ContextResult`] by attaching a [`Context`].
pub trait IntoContextResult<T> {
    /// Consumes `self` and `context`, returning them bundled as a
    /// [`ContextResult`].
    fn with_context(self, context: Context) -> ContextResult<T>;
}

impl<T, E> IntoContextResult<T> for Result<T, E>
where
    Error: From<E>,
{
    fn with_context(self, context: Context) -> ContextResult<T> {
        ContextResult {
            result: self.map_err(Error::from),
            context,
        }
    }
}

/// Borrowed view of a request context.
///
/// Holds only references, so it is `Copy` and cheap to pass around. It is
/// produced by [`Context::extend`] (request/correlation ids of an existing
/// context combined with a different token) or by [`Context::as_extended`]
/// (all three fields borrowed from the same context).
#[derive(Clone, Copy, Debug)]
pub struct ExtendedContext<'a> {
    /// Borrowed request identifier.
    pub request_id: &'a str,
    /// Borrowed correlation identifier.
    pub correlation_id: &'a str,
    /// Borrowed credentials token; not necessarily the token of the
    /// [`Context`] the ids were borrowed from.
    pub credentials_token: &'a types::CredentialsToken,
}

impl Context {
    /// Builds an [`ExtendedContext`] that reuses this context's request and
    /// correlation ids but carries `token` instead of this context's own
    /// credentials token.
    ///
    /// Mainly used in callbacks and similar.
    pub fn extend<'a>(&'a self, token: &'a types::CredentialsToken) -> ExtendedContext<'a> {
        let Self {
            request_id,
            correlation_id,
            ..
        } = self;

        ExtendedContext {
            request_id: request_id.as_str(),
            correlation_id: correlation_id.as_str(),
            credentials_token: token,
        }
    }

    /// Borrows every field of this context, own credentials token included,
    /// as an [`ExtendedContext`].
    pub fn as_extended(&self) -> ExtendedContext<'_> {
        self.extend(&self.credentials_token)
    }
}

/// `warp` integration: a filter that extracts a [`Context`] from request headers.
#[cfg(feature = "warp")]
pub mod warp_extensions {
    use super::Context;
    use http::HeaderValue;
    use warp::{filters::header, reject::Rejection, Filter};

    /// Returns a filter that yields the [`Context`] of the incoming request.
    ///
    /// The `x-request-id`, `x-correlation-id` and `authorization` headers are
    /// read as optional values and handed to the shared `extract_context`
    /// helper; any error it returns is turned into a custom warp rejection.
    pub fn context() -> impl Filter<Extract = (Context,), Error = Rejection> + Clone {
        let request_id = header::optional::<HeaderValue>("x-request-id");
        let correlation_id = header::optional::<HeaderValue>("x-correlation-id");
        let authorization = header::optional::<HeaderValue>("authorization");

        request_id.and(correlation_id).and(authorization).and_then(
            |request_id: Option<HeaderValue>,
             correlation_id: Option<HeaderValue>,
             authorization: Option<HeaderValue>| async move {
                let extracted = super::extensions::extract_context(
                    request_id.as_ref(),
                    correlation_id.as_ref(),
                    authorization.as_ref(),
                );

                extracted.map_err(warp::reject::custom)
            },
        )
    }
}

/// `axum` integration: lets handlers take a [`Context`] as an extractor argument.
#[cfg(feature = "axum")]
pub mod axum_extensions {
    use crate::{Context, Error};
    use axum::extract::{FromRequest, RequestParts};

    #[async_trait::async_trait]
    impl<B> FromRequest<B> for Context
    where
        B: Send, // required by `async_trait`
    {
        type Rejection = Error;

        /// Looks up the `x-request-id`, `x-correlation-id` and `authorization`
        /// headers and delegates validation to the shared `extract_context`
        /// helper, whose error becomes the rejection.
        async fn from_request(req: &mut RequestParts<B>) -> Result<Self, Self::Rejection> {
            let header_map = req.headers();

            let request_id = header_map.get("x-request-id");
            let correlation_id = header_map.get("x-correlation-id");
            let authorization = header_map.get("authorization");

            super::extensions::extract_context(request_id, correlation_id, authorization)
        }
    }
}

/// Framework-agnostic header parsing shared by the `axum` and `warp`
/// integrations.
#[cfg(any(feature = "axum", feature = "warp"))]
mod extensions {

    use crate::{types, Context, Error};
    use http::header::HeaderValue;

    /// Scheme prefix that must introduce the value of the `authorization` header.
    const TOKEN_PREFIX: &str = "Token ";

    /// Builds a [`Context`] from the three raw header values.
    ///
    /// The `authorization` value must have the form `Token <base64>`, where the
    /// base64 payload decodes to UTF-8 text accepted by
    /// `types::CredentialsToken::try_from`.
    ///
    /// # Errors
    ///
    /// * `Error::client_generic` if any of the three headers is absent or
    ///   cannot be read as a string.
    /// * `Error::unauthorized` if the `authorization` value does not start
    ///   with `Token `, is not valid base64, does not decode to UTF-8, or is
    ///   rejected by `CredentialsToken`.
    pub fn extract_context(
        request_id_header: Option<&HeaderValue>,
        correlation_id_header: Option<&HeaderValue>,
        authorization_header: Option<&HeaderValue>,
    ) -> Result<Context, Error> {
        let request_id = extract_str(request_id_header, "x-request-id")?.to_owned();
        let correlation_id = extract_str(correlation_id_header, "x-correlation-id")?.to_owned();
        let authorization = extract_str(authorization_header, "authorization")?;

        // `strip_prefix` removes the scheme exactly once. `trim_start_matches`
        // would strip it repeatedly and so accept `Token Token <b64>`.
        let encoded_token = authorization.strip_prefix(TOKEN_PREFIX).ok_or_else(|| {
            Error::unauthorized("Malformed Authorization header. Must start with `Token `")
        })?;

        let token_bytes = base64::decode(encoded_token)
            .map_err(|_| Error::unauthorized("Authorization header contained invalid base64"))?;

        let token = String::from_utf8(token_bytes).map_err(|_| {
            Error::unauthorized("Decoded Authorization token contained Invalid UTF-8")
        })?;

        let credentials_token = types::CredentialsToken::try_from(token).map_err(|err| {
            Error::unauthorized(format!("Invalid CredentialsToken provided: {err}"))
        })?;

        Ok(Context {
            request_id,
            correlation_id,
            credentials_token,
        })
    }

    /// Returns the string content of `header_value`.
    ///
    /// `header` is the header's name and is only used to build error messages.
    ///
    /// # Errors
    ///
    /// `Error::client_generic` if the header is missing or
    /// `HeaderValue::to_str` fails on it.
    fn extract_str<'a>(
        header_value: Option<&'a HeaderValue>,
        header: &str,
    ) -> Result<&'a str, Error> {
        let value = header_value
            .ok_or_else(|| Error::client_generic(format!("Missing header `{header}`")))?;

        value.to_str().map_err(|err| {
            Error::client_generic(format!("Header `{header}` contains non-ASCII: {err}"))
        })
    }
}

/// Tests for the shared header extraction. They run whenever `extensions` is
/// compiled, i.e. with either the `axum` or the `warp` feature enabled.
#[cfg(all(test, any(feature = "axum", feature = "warp")))]
mod tests {
    use super::extensions::extract_context;
    use http::HeaderValue;

    /// Builds an `authorization` header value carrying `raw_token` base64-encoded.
    fn token_header(raw_token: &str) -> HeaderValue {
        let b64_auth = base64::encode(raw_token);
        HeaderValue::from_str(&format!("Token {b64_auth}")).expect("Authorization header")
    }

    #[test]
    fn test_extract_context() {
        let request_id = HeaderValue::from_static("123");
        let correlation_id = HeaderValue::from_static("456");
        let authorization = token_header("IMATOKEN");

        let ctx = extract_context(
            Some(&request_id),
            Some(&correlation_id),
            Some(&authorization),
        )
        .expect("Extracting context");

        assert_eq!(ctx.request_id, "123");
        assert_eq!(ctx.correlation_id, "456");
        assert_eq!(ctx.credentials_token, "IMATOKEN");
    }

    #[test]
    fn test_extract_context_rejects_missing_headers() {
        let request_id = HeaderValue::from_static("123");
        let correlation_id = HeaderValue::from_static("456");
        let authorization = token_header("IMATOKEN");

        // Each header is mandatory; dropping any single one must fail.
        assert!(extract_context(None, Some(&correlation_id), Some(&authorization)).is_err());
        assert!(extract_context(Some(&request_id), None, Some(&authorization)).is_err());
        assert!(extract_context(Some(&request_id), Some(&correlation_id), None).is_err());
    }

    #[test]
    fn test_extract_context_rejects_malformed_authorization() {
        let request_id = HeaderValue::from_static("123");
        let correlation_id = HeaderValue::from_static("456");

        // Wrong scheme: the value does not start with `Token `.
        let wrong_scheme = HeaderValue::from_static("Bearer SU1BVE9LRU4=");
        assert!(
            extract_context(Some(&request_id), Some(&correlation_id), Some(&wrong_scheme))
                .is_err()
        );

        // Right scheme, but the payload is not base64.
        let bad_payload = HeaderValue::from_static("Token !!!");
        assert!(
            extract_context(Some(&request_id), Some(&correlation_id), Some(&bad_payload))
                .is_err()
        );
    }
}

/// Builds a [`Context`] for use in unit tests.
///
/// Request and correlation ids are drawn from a process-wide counter, so every
/// id handed out during a test run is distinct. The credentials token is the
/// fixed string `TOKENBYTESTCTX`.
///
/// # Panics
///
/// Panics if `TOKENBYTESTCTX` does not parse as a credentials token.
#[cfg(test)]
pub fn test_ctx() -> Context {
    use std::sync::atomic::{AtomicUsize, Ordering};

    // `AtomicUsize::new` is a `const fn`, so the counter can be an ordinary
    // static; no lazy-initialisation wrapper is needed.
    static COUNTER: AtomicUsize = AtomicUsize::new(1);

    // Only uniqueness matters here, not ordering relative to other memory
    // operations, hence `Relaxed`.
    let req_id = COUNTER.fetch_add(1, Ordering::Relaxed);
    let corr_id = COUNTER.fetch_add(1, Ordering::Relaxed);

    Context {
        request_id: req_id.to_string(),
        correlation_id: corr_id.to_string(),
        credentials_token: "TOKENBYTESTCTX"
            .parse()
            .expect("hard-coded test token must parse as a CredentialsToken"),
    }
}