axum-session-manager 0.1.1

Session manager crate for Rust web framework 'Axum'
Documentation
use std::{
    collections::HashMap,
    fmt::Debug,
    marker::PhantomData,
    sync::{Arc, RwLock},
};

use axum::{
    async_trait,
    extract::State,
    response::IntoResponse,
    routing::{get, post},
    Extension, Json, Router,
};
use axum_extra::extract::{cookie::Cookie, CookieJar};
use axum_session_manager::{SessionManage, SessionManagerLayer, UserData, UserState};
use http::StatusCode;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use uuid::Uuid;

const COOKIE_KEY: &str = "test-id";
#[tokio::main]
async fn main() {
    let service = app();

    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000")
        .await
        .unwrap();
    axum::serve(listener, service).await.unwrap();
}

fn app() -> Router {
    let session_pool = SessionPool::new();
    let phantom = PhantomData::default();
    let layer = SessionManagerLayer::new(session_pool.clone(), COOKIE_KEY, phantom);

    Router::new()
        .route("/", get(root))
        .route("/login", post(login))
        .route("/session/data", get(get_session_data))
        .layer(layer)
        .with_state(session_pool)
}

async fn root() -> impl IntoResponse {
    "Hello"
}

async fn get_session_data(
    Extension(user_data): Extension<UserData<Credential>>,
) -> Result<impl IntoResponse, impl IntoResponse> {
    let user_state = user_data.0;
    match user_state {
        UserState::HaveSession(a) => Ok((StatusCode::OK, Json(a))),
        UserState::NoCookie => Err((StatusCode::UNAUTHORIZED, "no cookie")),
        UserState::NoSession => Err((StatusCode::UNAUTHORIZED, "you need login")),
    }
}

async fn login(
    State(session_pool): State<SessionPool>,
    jar: CookieJar,
    Json(payload): Json<Credential>,
) -> Result<impl IntoResponse, StatusCode> {
    let session_id = session_pool
        .add_session(payload)
        .await
        .map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
    let cookie = Cookie::new(COOKIE_KEY, session_id);

    Ok((StatusCode::OK, jar.add(cookie)))
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct Credential {
    id: i32,
    name: String,
    password: String,
}

#[derive(Debug, Clone)]
struct UserPool {
    pool: Arc<RwLock<HashMap<i32, Credential>>>,
}

impl UserPool {
    fn new() -> Self {
        Self {
            pool: Arc::default(),
        }
    }
}

#[derive(Debug, Clone)]
struct SessionPool {
    pool: Arc<RwLock<HashMap<String, Credential>>>,
}

impl SessionPool {
    fn new() -> Self {
        Self {
            pool: Arc::default(),
        }
    }
}

#[derive(Debug, Clone, Error)]
enum ServerError {
    #[error("unexpected error")]
    Unexpect,
}

#[async_trait]
impl SessionManage<Credential> for SessionPool {
    type SessionID = String;
    type UserInfo = Credential;
    type Error = ServerError;

    async fn add_session(&self, session_data: Credential) -> Result<Self::SessionID, Self::Error> {
        let session_id = Uuid::new_v4().to_string();
        {
            let _ = &self
                .pool
                .write()
                .map_err(|_e| ServerError::Unexpect)?
                .insert(session_id.clone(), session_data);
        }
        Ok(session_id)
    }

    async fn verify_session(
        &self,
        session_id: &str,
    ) -> Result<Option<Self::UserInfo>, Self::Error> {
        let data = self.pool.read().map_err(|_e| ServerError::Unexpect)?.to_owned();

        match data.get(session_id) {
            Some(user) => {
                let user = user.to_owned();
                Ok(Some(user))
            }
            None => Ok(None),
        }
    }

    async fn delete_session(&self, session_id: &str) -> Result<(), Self::Error> {
        self.pool
            .write()
            .map_err(|_e| ServerError::Unexpect)?
            .remove(session_id);
        Ok(())
    }
}