// openauth-core 0.0.4
//
// Core types and primitives for OpenAuth.
// Documentation
use std::collections::HashMap;

use openauth_core::auth::session::{GetSessionInput, SessionAuth};
use openauth_core::context::{create_auth_context, AuthContext};
use openauth_core::cookies::{
    get_cookies, parse_cookies, set_session_cookie, Cookie, SessionCookieOptions,
};
use openauth_core::db::{
    run_transaction_without_native_support, AdapterFuture, Count, Create, DbAdapter, DbRecord,
    DbValue, Delete, DeleteMany, FindMany, FindOne, Session, TransactionCallback, Update,
    UpdateMany, User, Where, WhereOperator,
};
use openauth_core::error::OpenAuthError;
use openauth_core::options::{CookieCacheOptions, OpenAuthOptions, SessionOptions};
use time::{Duration, OffsetDateTime};
use tokio::sync::Mutex;

/// In-memory adapter used by the session-auth tests in this file.
///
/// Records live in plain hash maps behind async mutexes, and every `update`
/// and `delete` query the adapter receives is logged so tests can assert on
/// how often the code under test touched the store.
#[derive(Default)]
struct SessionAuthAdapter {
    /// User records, keyed by user id (see `insert_user`).
    users: Mutex<HashMap<String, DbRecord>>,
    /// Session records, keyed by session token (see `insert_session`).
    sessions: Mutex<HashMap<String, DbRecord>>,
    /// Every `Update` query passed to `update`, in call order.
    updates: Mutex<Vec<Update>>,
    /// Every `Delete` query passed to `delete`, in call order.
    deletes: Mutex<Vec<Delete>>,
}

impl SessionAuthAdapter {
    /// Seeds the in-memory user table with `user`, keyed by the user's id.
    ///
    /// An existing entry with the same id is replaced.
    async fn insert_user(&self, user: User) {
        let key = user.id.clone();
        let record = user_record(user);
        let mut users = self.users.lock().await;
        users.insert(key, record);
    }

    /// Seeds the in-memory session table with `session`, keyed by its token.
    ///
    /// An existing entry with the same token is replaced.
    async fn insert_session(&self, session: Session) {
        let key = session.token.clone();
        let record = session_record(session);
        let mut sessions = self.sessions.lock().await;
        sessions.insert(key, record);
    }
}

impl DbAdapter for SessionAuthAdapter {
    /// Identifier reported for this test adapter.
    fn id(&self) -> &str {
        "session-auth-memory"
    }

    /// Always fails: the tests in this file never expect a record to be created.
    fn create<'a>(&'a self, _query: Create) -> AdapterFuture<'a, DbRecord> {
        Box::pin(async {
            let message = "create should not be called".to_owned();
            Err(OpenAuthError::Adapter(message))
        })
    }

    /// Looks up a single record.
    ///
    /// * model `"session"` — requires a string equality filter on `token`.
    /// * model `"user"` — requires a string equality filter on `id`.
    ///
    /// Any other model, or a missing filter, yields `OpenAuthError::Adapter`.
    fn find_one<'a>(&'a self, query: FindOne) -> AdapterFuture<'a, Option<DbRecord>> {
        Box::pin(async move {
            let model = query.model.as_str();
            if model == "session" {
                let token = string_filter(&query.where_clauses, "token")?;
                let sessions = self.sessions.lock().await;
                Ok(sessions.get(token).cloned())
            } else if model == "user" {
                let id = string_filter(&query.where_clauses, "id")?;
                let users = self.users.lock().await;
                Ok(users.get(id).cloned())
            } else {
                Err(OpenAuthError::Adapter(format!(
                    "unexpected find_one model `{model}`"
                )))
            }
        })
    }

    /// Stub: always reports no matching records.
    fn find_many<'a>(&'a self, _query: FindMany) -> AdapterFuture<'a, Vec<DbRecord>> {
        Box::pin(async { Ok(Vec::new()) })
    }

    /// Stub: always reports a count of zero.
    fn count<'a>(&'a self, _query: Count) -> AdapterFuture<'a, u64> {
        Box::pin(async { Ok(0) })
    }

    /// Logs the query, then merges `query.data` into the session whose token
    /// matches the `token` filter.
    ///
    /// Returns the updated record, or `None` when no session has that token.
    /// The query is logged before the filter is validated, so a query that
    /// fails validation still shows up in `updates`.
    fn update<'a>(&'a self, query: Update) -> AdapterFuture<'a, Option<DbRecord>> {
        Box::pin(async move {
            {
                let mut log = self.updates.lock().await;
                log.push(query.clone());
            }
            let token = string_filter(&query.where_clauses, "token")?;
            let mut sessions = self.sessions.lock().await;
            let updated = match sessions.get_mut(token) {
                Some(record) => {
                    for (column, value) in query.data {
                        record.insert(column, value);
                    }
                    Some(record.clone())
                }
                None => None,
            };
            Ok(updated)
        })
    }

    /// Stub: always reports zero affected records.
    fn update_many<'a>(&'a self, _query: UpdateMany) -> AdapterFuture<'a, u64> {
        Box::pin(async { Ok(0) })
    }

    /// Logs the query, then removes the session whose token matches the
    /// `token` filter (a no-op when no such session exists).
    fn delete<'a>(&'a self, query: Delete) -> AdapterFuture<'a, ()> {
        Box::pin(async move {
            {
                let mut log = self.deletes.lock().await;
                log.push(query.clone());
            }
            let token = string_filter(&query.where_clauses, "token")?;
            let mut sessions = self.sessions.lock().await;
            sessions.remove(token);
            Ok(())
        })
    }

    /// Stub: always reports zero deleted records.
    fn delete_many<'a>(&'a self, _query: DeleteMany) -> AdapterFuture<'a, u64> {
        Box::pin(async { Ok(0) })
    }

    /// Delegates to `run_transaction_without_native_support`.
    fn transaction<'a>(&'a self, callback: TransactionCallback<'a>) -> AdapterFuture<'a, ()> {
        run_transaction_without_native_support(self, callback)
    }
}

/// With a stored user and a session expiring two hours from now,
/// `get_session` must return that session and user, emit a cookie named like
/// the configured session-data cookie, and issue no adapter updates.
#[tokio::test]
async fn get_session_returns_db_session_with_user_and_sets_cookie_cache(
) -> Result<(), Box<dyn std::error::Error>> {
    let context = context()?;
    let adapter = SessionAuthAdapter::default();
    let now = OffsetDateTime::now_utc();
    adapter.insert_user(user(now)).await;
    let stored_session = session(now, now + Duration::hours(2));
    adapter.insert_session(stored_session).await;
    let cookie_header = session_cookie_header(&context, "token_1", false)?;

    let result = SessionAuth::new(&adapter, &context)
        .get_session(GetSessionInput::new(cookie_header))
        .await?
        .ok_or("missing session result")?;

    let session_token = result.session.as_ref().map(|found| found.token.as_str());
    assert_eq!(session_token, Some("token_1"));

    let user_id = result.user.as_ref().map(|found| found.id.as_str());
    assert_eq!(user_id, Some("user_1"));

    let sets_session_data_cookie = result
        .cookies
        .iter()
        .any(|cookie| cookie.name == context.auth_cookies.session_data.name);
    assert!(sets_session_data_cookie);

    let updates = adapter.updates.lock().await;
    assert!(updates.is_empty());
    Ok(())
}

/// With a session expiring ten seconds from now, `get_session` must return a
/// session with a later expiry, emit a cookie named like the configured
/// session-token cookie, and issue exactly one adapter update.
#[tokio::test]
async fn get_session_refreshes_due_session_and_session_cookie(
) -> Result<(), Box<dyn std::error::Error>> {
    let context = context()?;
    let adapter = SessionAuthAdapter::default();
    let now = OffsetDateTime::now_utc();
    let original_expiry = now + Duration::seconds(10);
    adapter.insert_user(user(now)).await;
    adapter.insert_session(session(now, original_expiry)).await;
    let cookie_header = session_cookie_header(&context, "token_1", false)?;

    let result = SessionAuth::new(&adapter, &context)
        .get_session(GetSessionInput::new(cookie_header))
        .await?
        .ok_or("missing session result")?;

    let refreshed = result
        .session
        .as_ref()
        .ok_or("missing refreshed session")?;
    assert!(refreshed.expires_at > original_expiry);

    let sets_session_token_cookie = result
        .cookies
        .iter()
        .any(|cookie| cookie.name == context.auth_cookies.session_token.name);
    assert!(sets_session_token_cookie);

    let update_count = adapter.updates.lock().await.len();
    assert_eq!(update_count, 1);
    Ok(())
}

/// With a validly signed cookie whose token is not in the store,
/// `get_session` must return a result with no session and a session-token
/// cookie carrying `max_age == Some(0)`.
#[tokio::test]
async fn get_session_clears_cookies_when_signed_token_is_missing_from_store(
) -> Result<(), Box<dyn std::error::Error>> {
    let context = context()?;
    // Nothing is seeded: the store stays empty on purpose.
    let adapter = SessionAuthAdapter::default();
    let cookie_header = session_cookie_header(&context, "missing_token", false)?;

    let result = SessionAuth::new(&adapter, &context)
        .get_session(GetSessionInput::new(cookie_header))
        .await?
        .ok_or("missing anonymous result")?;

    assert!(result.session.is_none());

    let expires_session_token_cookie = result.cookies.iter().any(|cookie| {
        cookie.name == context.auth_cookies.session_token.name
            && cookie.attributes.max_age == Some(0)
    });
    assert!(expires_session_token_cookie);
    Ok(())
}

/// With a stored session and a validly signed cookie for it, `sign_out` must
/// succeed, remove the session through exactly one delete query, and emit a
/// session-token cookie carrying `max_age == Some(0)`.
#[tokio::test]
async fn sign_out_deletes_session_and_expires_session_cookies(
) -> Result<(), Box<dyn std::error::Error>> {
    let context = context()?;
    let adapter = SessionAuthAdapter::default();
    let now = OffsetDateTime::now_utc();
    let stored_session = session(now, now + Duration::hours(1));
    adapter.insert_session(stored_session).await;
    let cookie_header = session_cookie_header(&context, "token_1", false)?;

    let result = SessionAuth::new(&adapter, &context)
        .sign_out(cookie_header)
        .await?;

    assert!(result.success);

    let store_is_empty = adapter.sessions.lock().await.is_empty();
    assert!(store_is_empty);

    let delete_count = adapter.deletes.lock().await.len();
    assert_eq!(delete_count, 1);

    let expires_session_token_cookie = result.cookies.iter().any(|cookie| {
        cookie.name == context.auth_cookies.session_token.name
            && cookie.attributes.max_age == Some(0)
    });
    assert!(expires_session_token_cookie);
    Ok(())
}

/// With a session cookie carrying a bad signature (`token_1.bad`), `sign_out`
/// must still succeed and emit an expiring session-token cookie, while the
/// stored session stays in place and no delete query is issued.
#[tokio::test]
async fn sign_out_ignores_forged_session_cookie_but_still_expires_cookies(
) -> Result<(), Box<dyn std::error::Error>> {
    let context = context()?;
    let adapter = SessionAuthAdapter::default();
    let now = OffsetDateTime::now_utc();
    let stored_session = session(now, now + Duration::hours(1));
    adapter.insert_session(stored_session).await;
    // Built by hand rather than via `session_cookie_header`, so the value is
    // not produced by the signing helper.
    let forged_header = format!("{}=token_1.bad", context.auth_cookies.session_token.name);

    let result = SessionAuth::new(&adapter, &context)
        .sign_out(forged_header)
        .await?;

    assert!(result.success);

    let remaining_sessions = adapter.sessions.lock().await.len();
    assert_eq!(remaining_sessions, 1);

    let no_deletes_issued = adapter.deletes.lock().await.is_empty();
    assert!(no_deletes_issued);

    let expires_session_token_cookie = result.cookies.iter().any(|cookie| {
        cookie.name == context.auth_cookies.session_token.name
            && cookie.attributes.max_age == Some(0)
    });
    assert!(expires_session_token_cookie);
    Ok(())
}

/// Builds the `AuthContext` shared by every test in this file.
///
/// Configuration: a fixed secret, `expires_in` of `60 * 60`, `update_age` of
/// `1`, and the cookie cache enabled with a `max_age` of `60 * 5`. All other
/// options keep their defaults.
fn context() -> Result<AuthContext, OpenAuthError> {
    let cookie_cache = CookieCacheOptions {
        enabled: true,
        max_age: Some(60 * 5),
        ..CookieCacheOptions::default()
    };
    let session_options = SessionOptions {
        expires_in: Some(60 * 60),
        update_age: Some(1),
        cookie_cache,
        ..SessionOptions::default()
    };
    let options = OpenAuthOptions {
        secret: Some("test-secret-123456789012345678901234".to_owned()),
        session: session_options,
        ..OpenAuthOptions::default()
    };
    create_auth_context(options)
}

/// Produces a request `Cookie` header value for `token`.
///
/// Calls `set_session_cookie` with the context's cookie configuration and
/// secret, then flattens the resulting cookies into `name=value` pairs.
/// `dont_remember` is forwarded through `SessionCookieOptions`.
fn session_cookie_header(
    context: &AuthContext,
    token: &str,
    dont_remember: bool,
) -> Result<String, OpenAuthError> {
    let auth_cookies = get_cookies(&context.options)?;
    let cookie_options = SessionCookieOptions {
        dont_remember,
        ..SessionCookieOptions::default()
    };
    let issued = set_session_cookie(&auth_cookies, &context.secret, token, cookie_options)?;
    Ok(cookie_header(&issued))
}

/// Renders `cookies` as a single header value: `name=value` pairs separated
/// by `"; "`. Cookie attributes are not included. An empty slice yields an
/// empty string.
fn cookie_header(cookies: &[Cookie]) -> String {
    let mut header = String::new();
    for (index, cookie) in cookies.iter().enumerate() {
        // Separator goes between pairs only, never before the first one.
        if index > 0 {
            header.push_str("; ");
        }
        header.push_str(&format!("{}={}", cookie.name, cookie.value));
    }
    header
}

/// Fixture user `user_1` ("Ada", verified e-mail, no image or username),
/// with both timestamps set to `now`.
fn user(now: OffsetDateTime) -> User {
    let id = String::from("user_1");
    let name = String::from("Ada");
    let email = String::from("ada@example.com");
    User {
        id,
        name,
        email,
        email_verified: true,
        image: None,
        username: None,
        display_username: None,
        created_at: now,
        updated_at: now,
    }
}

/// Fixture session `session_1` with token `token_1`, owned by `user_1`.
///
/// `expires_at` is taken from the caller; both timestamps are set to `now`.
fn session(now: OffsetDateTime, expires_at: OffsetDateTime) -> Session {
    let ip_address = Some(String::from("192.0.2.1"));
    let user_agent = Some(String::from("test-agent"));
    Session {
        id: String::from("session_1"),
        user_id: String::from("user_1"),
        expires_at,
        token: String::from("token_1"),
        ip_address,
        user_agent,
        created_at: now,
        updated_at: now,
    }
}

/// Converts a `User` into the `DbRecord` stored by the test adapter.
///
/// A missing image is stored as `DbValue::Null`.
// NOTE(review): `username` and `display_username` are not written to the
// record — confirm that is intended for the schema under test.
fn user_record(user: User) -> DbRecord {
    let image = match user.image {
        Some(url) => DbValue::String(url),
        None => DbValue::Null,
    };
    // Columns are listed in the order they are inserted.
    let columns = [
        ("id", DbValue::String(user.id)),
        ("name", DbValue::String(user.name)),
        ("email", DbValue::String(user.email)),
        ("email_verified", DbValue::Boolean(user.email_verified)),
        ("image", image),
        ("created_at", DbValue::Timestamp(user.created_at)),
        ("updated_at", DbValue::Timestamp(user.updated_at)),
    ];

    let mut record = DbRecord::new();
    for (column, value) in columns {
        record.insert(column.to_owned(), value);
    }
    record
}

/// Converts a `Session` into the `DbRecord` stored by the test adapter.
///
/// A missing `ip_address` or `user_agent` is stored as `DbValue::Null`.
fn session_record(session: Session) -> DbRecord {
    let ip_address = session.ip_address.map_or(DbValue::Null, DbValue::String);
    let user_agent = session.user_agent.map_or(DbValue::Null, DbValue::String);
    // Columns are listed in the order they are inserted.
    let columns = [
        ("id", DbValue::String(session.id)),
        ("user_id", DbValue::String(session.user_id)),
        ("expires_at", DbValue::Timestamp(session.expires_at)),
        ("token", DbValue::String(session.token)),
        ("ip_address", ip_address),
        ("user_agent", user_agent),
        ("created_at", DbValue::Timestamp(session.created_at)),
        ("updated_at", DbValue::Timestamp(session.updated_at)),
    ];

    let mut record = DbRecord::new();
    for (column, value) in columns {
        record.insert(column.to_owned(), value);
    }
    record
}

/// Returns the string value of the first clause that filters `field` with
/// `WhereOperator::Eq` against a `DbValue::String`.
///
/// Clauses on other fields, with other operators, or with non-string values
/// are skipped.
///
/// # Errors
///
/// `OpenAuthError::Adapter("missing {field} filter")` when no clause matches.
fn string_filter<'a>(where_clauses: &'a [Where], field: &str) -> Result<&'a str, OpenAuthError> {
    for clause in where_clauses {
        if clause.field.as_str() != field {
            continue;
        }
        if let (WhereOperator::Eq, DbValue::String(text)) = (clause.operator, &clause.value) {
            return Ok(text.as_str());
        }
    }
    Err(OpenAuthError::Adapter(format!("missing {field} filter")))
}

/// Round-trips `cookies` through `cookie_header` and `parse_cookies`, then
/// returns the parsed value stored under `name`, if any.
// NOTE(review): nothing in this file calls this helper, hence the
// `dead_code` allowance — presumably kept for future tests; confirm.
#[allow(dead_code)]
fn parsed_cookie_value(cookies: &[Cookie], name: &str) -> Option<String> {
    let header = cookie_header(cookies);
    let parsed = parse_cookies(&header);
    parsed.get(name).cloned()
}