ella-server 0.1.5

Client-server API for the ella datastore.
Documentation
use std::sync::{Arc, Mutex};

use dashmap::DashMap;
use ella_common::OffsetDateTime;
use ella_engine::{engine::EllaState, EllaConfig};
use hmac::{Hmac, Mac};
use jwt::{RegisteredClaims, SignWithKey, VerifyWithKey};
use sha2::Sha256;
use tonic::service::Interceptor;
use uuid::Uuid;

#[derive(Debug, Clone)]
pub(crate) struct ConnectionState {
    state: Arc<Mutex<EllaState>>,
}

impl ConnectionState {
    pub fn new(state: EllaState) -> Self {
        Self {
            state: Arc::new(Mutex::new(state)),
        }
    }

    pub fn read(&self) -> EllaState {
        self.state.lock().unwrap().clone()
    }

    pub fn set_config(&self, config: EllaConfig) {
        self.state.lock().unwrap().with_config(config);
    }
}

#[derive(Debug)]
pub(crate) struct AuthProvider {
    key: Hmac<Sha256>,
}

impl AuthProvider {
    pub fn from_secret(secret: &[u8]) -> crate::Result<Self> {
        let key = Hmac::new_from_slice(secret).map_err(|_| crate::ServerError::InvalidSecret)?;
        Ok(Self { key })
    }

    fn encode<T: serde::Serialize>(&self, token: &T) -> crate::Result<String> {
        token
            .sign_with_key(&self.key)
            .map_err(|err| crate::ServerError::Token(err.to_string()).into())
    }

    fn extract_payload<T: serde::de::DeserializeOwned>(
        &self,
        request: &tonic::Request<()>,
    ) -> Result<Option<T>, tonic::Status> {
        match request.metadata().get("authorization").map(|m| m.to_str()) {
            Some(Ok(auth)) => match auth.split_once(' ') {
                Some(("Bearer", token)) => {
                    let payload = token.verify_with_key(&self.key).map_err(|err| {
                        tonic::Status::unauthenticated(format!("invalid token: {}", err))
                    })?;
                    Ok(Some(payload))
                }
                _ => Ok(None),
            },
            Some(Err(_)) => Err(tonic::Status::unauthenticated(
                "unable to parse authorization header as ASCII",
            )),
            None => Ok(None),
        }
    }
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
#[serde(transparent)]
pub(crate) struct ConnectionToken(RegisteredClaims);

impl ConnectionToken {
    fn new(user: Option<String>) -> Self {
        let id = Uuid::new_v4().simple().to_string();
        let issued = OffsetDateTime::now_utc().unix_timestamp().unsigned_abs();
        Self(RegisteredClaims {
            issuer: None,
            subject: user,
            audience: None,
            expiration: None,
            not_before: None,
            issued_at: Some(issued),
            json_web_token_id: Some(id),
        })
    }

    fn uuid(&self) -> Result<Uuid, uuid::Error> {
        Uuid::try_parse(self.0.json_web_token_id.as_deref().unwrap_or_default())
    }
}

#[derive(Debug, Clone)]
pub(crate) struct ConnectionManager {
    state: EllaState,
    auth: Arc<AuthProvider>,
    connections: Arc<DashMap<Uuid, ConnectionState>>,
}

impl ConnectionManager {
    pub fn new(auth: Arc<AuthProvider>, state: EllaState) -> Self {
        Self {
            auth,
            state,
            connections: Arc::new(DashMap::new()),
        }
    }

    pub fn handshake(&self) -> crate::Result<String> {
        let conn = ConnectionToken::new(None);
        let token = self.auth.encode(&conn)?;
        let state = ConnectionState::new(self.state.clone());
        self.connections.insert(
            conn.uuid()
                .expect("newly created UUID should always be valid"),
            state,
        );
        Ok(token)
    }
}

impl Interceptor for ConnectionManager {
    fn call(
        &mut self,
        mut request: tonic::Request<()>,
    ) -> Result<tonic::Request<()>, tonic::Status> {
        if let Some(token) = self.auth.extract_payload::<ConnectionToken>(&request)? {
            let key = token.uuid().map_err(|err| {
                tonic::Status::unauthenticated(format!("invalid connection id: {}", err))
            })?;
            let conn = match self.connections.get(&key) {
                Some(conn) => conn.value().clone(),
                None => {
                    return Err(tonic::Status::unauthenticated(
                        "no active connection found for connection id",
                    ))
                }
            };
            request.extensions_mut().insert(conn);
        }

        Ok(request)
    }
}

pub(crate) fn connection<T>(request: &tonic::Request<T>) -> Result<ConnectionState, tonic::Status> {
    request
        .extensions()
        .get()
        .cloned()
        .ok_or_else(|| tonic::Status::unauthenticated("missing connection token"))
}