use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use http::{header, Method, Request, StatusCode};
use rustauth_core::api::{core_auth_async_endpoints, ApiErrorResponse, AuthRouter};
use rustauth_core::context::{create_auth_context, create_auth_context_with_adapter};
use rustauth_core::cookies::Cookie;
use rustauth_core::db::{
AdapterFuture, Create, DbAdapter, DbRecord, DbValue, FindOne, MemoryAdapter, Session, User,
Where,
};
use rustauth_core::error::RustAuthError;
use rustauth_core::options::{AdvancedOptions, RustAuthOptions};
use rustauth_core::plugin::{AuthPlugin, PluginPasswordValidationRejection};
use serde_json::Value;
use time::OffsetDateTime;
/// Adapter type used by the helpers in this file; currently an alias for `MemoryAdapter`.
type RouteAdapter = MemoryAdapter;
/// Boxed, pinned, `Send` future that resolves to `()` and may borrow for `'a`.
type UnitFuture<'a> = Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
/// Seeding helpers for inserting user, account and session rows through an adapter.
///
/// `insert_user` and `insert_session` return a `UnitFuture`, so their signatures
/// leave no way to report a failure to the caller; `insert_account` returns an
/// `AdapterFuture` instead.
trait RouteAdapterSeed {
/// Inserts `user`; the returned future resolves to `()`.
fn insert_user(&self, user: User) -> UnitFuture<'_>;
/// Inserts an already-built account `record`.
fn insert_account(&self, record: DbRecord) -> AdapterFuture<'_, ()>;
/// Inserts `session`; the returned future resolves to `()`.
fn insert_session(&self, session: Session) -> UnitFuture<'_>;
}
/// Seeds rows by issuing `create` queries against the adapter.
impl RouteAdapterSeed for RouteAdapter {
    /// Creates a `"user"` row from `user`. The result of `create` is
    /// intentionally discarded because the return type cannot carry an error.
    fn insert_user(&self, user: User) -> UnitFuture<'_> {
        Box::pin(async move {
            let query = create_query("user", user_record(user));
            let _ = self.create(query).await;
        })
    }

    /// Creates an `"account"` row from `record`, propagating any `create` error.
    fn insert_account(&self, record: DbRecord) -> AdapterFuture<'_, ()> {
        Box::pin(async move {
            let query = create_query("account", record);
            self.create(query).await?;
            Ok(())
        })
    }

    /// Creates a `"session"` row from `session`. The result of `create` is
    /// intentionally discarded because the return type cannot carry an error.
    fn insert_session(&self, session: Session) -> UnitFuture<'_> {
        Box::pin(async move {
            let query = create_query("session", session_record(session));
            let _ = self.create(query).await;
        })
    }
}
// Submodules of this module. `account_tokens` and `social_oauth` are only
// compiled when the `oauth` feature is enabled.
#[cfg(feature = "oauth")]
mod account_tokens;
mod change_email;
mod change_password;
mod delete_user;
mod delete_user_callback;
mod email_verification;
mod error_page;
mod get_session;
mod http_json_serde;
mod list_accounts;
mod list_sessions;
mod openapi;
mod password_validators;
mod request_password_reset;
mod reset_password;
mod revoke_other_sessions;
mod revoke_session;
mod revoke_sessions;
mod session_ip_metadata;
mod set_password;
mod sign_in_email;
mod sign_out;
mod sign_up_email;
#[cfg(feature = "oauth")]
mod social_oauth;
mod unlink_account;
mod update_session;
mod update_user;
mod verify_password;
use crate::common::{fast_hash_password, real_password_options, with_test_defaults};
/// Builds an `AuthRouter` over `adapter` starting from `RustAuthOptions::default()`.
///
/// Delegates to `router_with_options`; any error it returns is passed through.
fn router(adapter: Arc<RouteAdapter>) -> Result<AuthRouter, RustAuthError> {
    let defaults = RustAuthOptions::default();
    router_with_options(adapter, defaults)
}
/// Builds an `AuthRouter` whose `password` options come from
/// `real_password_options()`; every other option keeps its default.
///
/// Delegates to `router_with_options`; any error it returns is passed through.
fn router_with_real_password(adapter: Arc<RouteAdapter>) -> Result<AuthRouter, RustAuthError> {
    let password = real_password_options();
    let options = RustAuthOptions {
        password,
        ..RustAuthOptions::default()
    };
    router_with_options(adapter, options)
}
/// Builds an `AuthRouter` over `adapter` from caller-supplied `options`.
///
/// The `secret` field is replaced with `secret()` and the `advanced` field is
/// replaced wholesale by `AdvancedOptions::default()` with the CSRF and origin
/// checks disabled, so any `advanced` value passed in `options` is discarded
/// (use `router_with_advanced` to keep custom advanced settings). The result is
/// passed through `with_test_defaults` before the context is created.
///
/// # Errors
///
/// Returns the error from `create_auth_context_with_adapter` or from
/// `AuthRouter::with_async_endpoints`.
fn router_with_options(
    adapter: Arc<RouteAdapter>,
    options: RustAuthOptions,
) -> Result<AuthRouter, RustAuthError> {
    let advanced = AdvancedOptions {
        disable_csrf_check: true,
        disable_origin_check: true,
        ..AdvancedOptions::default()
    };
    let options = with_test_defaults(RustAuthOptions {
        secret: Some(secret().to_owned()),
        advanced,
        ..options
    });
    // `adapter` is owned and not used again, so it is moved rather than cloned.
    let context = create_auth_context_with_adapter(options, adapter)?;
    AuthRouter::with_async_endpoints(context, Vec::new(), core_auth_async_endpoints())
}
/// Builds an `AuthRouter` over `adapter` from `options` without passing them
/// through `with_test_defaults`.
///
/// The `secret` field is replaced with `secret()` and `advanced` is replaced by
/// `AdvancedOptions::default()` with the CSRF and origin checks disabled. Any
/// `mode` other than `DeploymentMode::Production` is forced to
/// `DeploymentMode::Development`; `Production` is left untouched.
///
/// # Errors
///
/// Returns the error from `create_auth_context_with_adapter` or from
/// `AuthRouter::with_async_endpoints`.
fn router_with_bare_options(
    adapter: Arc<RouteAdapter>,
    options: RustAuthOptions,
) -> Result<AuthRouter, RustAuthError> {
    use rustauth_core::options::DeploymentMode;

    let mut options = RustAuthOptions {
        secret: Some(secret().to_owned()),
        advanced: AdvancedOptions {
            disable_csrf_check: true,
            disable_origin_check: true,
            ..AdvancedOptions::default()
        },
        ..options
    };
    if options.mode != DeploymentMode::Production {
        options.mode = DeploymentMode::Development;
    }
    // `adapter` is owned and not used again, so it is moved rather than cloned.
    let context = create_auth_context_with_adapter(options, adapter)?;
    AuthRouter::with_async_endpoints(context, Vec::new(), core_auth_async_endpoints())
}
/// Builds an `AuthRouter` over `adapter` using caller-supplied `options` and
/// `advanced` settings.
///
/// `advanced` is kept as given except that `disable_csrf_check` and
/// `disable_origin_check` are forced to `true`. It replaces `options.advanced`,
/// `secret` is replaced with `secret()`, and the result is passed through
/// `with_test_defaults` before the context is created.
///
/// # Errors
///
/// Returns the error from `create_auth_context_with_adapter` or from
/// `AuthRouter::with_async_endpoints`.
fn router_with_advanced(
    adapter: Arc<RouteAdapter>,
    options: RustAuthOptions,
    mut advanced: AdvancedOptions,
) -> Result<AuthRouter, RustAuthError> {
    advanced.disable_csrf_check = true;
    advanced.disable_origin_check = true;
    let options = with_test_defaults(RustAuthOptions {
        secret: Some(secret().to_owned()),
        advanced,
        ..options
    });
    // `adapter` is owned and not used again, so it is moved rather than cloned.
    let context = create_auth_context_with_adapter(options, adapter)?;
    AuthRouter::with_async_endpoints(context, Vec::new(), core_auth_async_endpoints())
}
/// Builds a request for `path` on `http://localhost:3000` with `body` as the
/// raw payload bytes.
///
/// A `Content-Type: application/json` header is added only when `body` is
/// non-empty, and a `Cookie` header only when `cookie` is `Some`.
///
/// # Errors
///
/// Returns the `http::Error` reported by the request builder.
fn json_request(
    method: Method,
    path: &str,
    body: &str,
    cookie: Option<&str>,
) -> Result<Request<Vec<u8>>, http::Error> {
    let uri = format!("http://localhost:3000{path}");
    let builder = Request::builder().method(method).uri(uri);
    let builder = if body.is_empty() {
        builder
    } else {
        builder.header(header::CONTENT_TYPE, "application/json")
    };
    let builder = match cookie {
        Some(value) => builder.header(header::COOKIE, value),
        None => builder,
    };
    builder.body(Vec::from(body.as_bytes()))
}
/// Returns the fixed secret string shared by the helpers in this file.
fn secret() -> &'static str {
    const TEST_SECRET: &str = "test-secret-123456789012345678901234";
    TEST_SECRET
}
/// Produces a `Cookie` request-header value for the session `token`.
///
/// A context is created from default options (with `secret()` and
/// `with_test_defaults` applied), `set_session_cookie` is called with default
/// `SessionCookieOptions`, and the resulting cookies are joined by
/// `cookie_header`.
///
/// # Errors
///
/// Returns the error from `create_auth_context` or `set_session_cookie`.
fn signed_session_cookie(token: &str) -> Result<String, RustAuthError> {
    use rustauth_core::cookies::{set_session_cookie, SessionCookieOptions};

    let options = with_test_defaults(RustAuthOptions {
        secret: Some(secret().to_owned()),
        ..RustAuthOptions::default()
    });
    let context = create_auth_context(options)?;
    let cookie_options = SessionCookieOptions::default();
    let cookies = set_session_cookie(
        &context.auth_cookies,
        &context.secret,
        token,
        cookie_options,
    )?;
    Ok(cookie_header(&cookies))
}
/// Produces a `Cookie` request-header value for the session `token`, with
/// `dont_remember` set to `true` in the `SessionCookieOptions`.
///
/// Otherwise follows the same steps as `signed_session_cookie`: create a
/// context from default options, call `set_session_cookie`, and join the
/// resulting cookies with `cookie_header`.
///
/// # Errors
///
/// Returns the error from `create_auth_context` or `set_session_cookie`.
fn signed_dont_remember_session_cookie(token: &str) -> Result<String, RustAuthError> {
    use rustauth_core::cookies::{set_session_cookie, SessionCookieOptions};

    let options = with_test_defaults(RustAuthOptions {
        secret: Some(secret().to_owned()),
        ..RustAuthOptions::default()
    });
    let context = create_auth_context(options)?;
    let cookie_options = SessionCookieOptions {
        dont_remember: true,
        ..SessionCookieOptions::default()
    };
    let cookies = set_session_cookie(
        &context.auth_cookies,
        &context.secret,
        token,
        cookie_options,
    )?;
    Ok(cookie_header(&cookies))
}
/// Formats `cookies` as `name=value` pairs joined by `"; "`.
///
/// An empty slice yields an empty string.
fn cookie_header(cookies: &[Cookie]) -> String {
    let mut pairs = Vec::with_capacity(cookies.len());
    for cookie in cookies {
        pairs.push(format!("{}={}", cookie.name, cookie.value));
    }
    pairs.join("; ")
}
/// Collects every `Set-Cookie` header value of `response` as an owned string.
///
/// Values for which `to_str()` fails are skipped.
fn set_cookie_values(response: &http::Response<Vec<u8>>) -> Vec<String> {
    let mut values = Vec::new();
    for raw in response.headers().get_all(header::SET_COOKIE).iter() {
        if let Ok(text) = raw.to_str() {
            values.push(text.to_owned());
        }
    }
    values
}
/// Returns the fixture user `user_1` ("Ada", `ada@example.com`, email verified)
/// with both timestamps set to `now` and all optional fields `None`.
fn user(now: OffsetDateTime) -> User {
    let (id, name, email) = ("user_1", "Ada", "ada@example.com");
    User {
        id: String::from(id),
        name: String::from(name),
        email: String::from(email),
        email_verified: true,
        image: None,
        username: None,
        display_username: None,
        created_at: now,
        updated_at: now,
    }
}
/// Returns the fixture session `session_1` for `user_1` with token `token_1`,
/// expiring at `expires_at`, with both timestamps set to `now` and no IP
/// address or user agent.
fn session(now: OffsetDateTime, expires_at: OffsetDateTime) -> Session {
    let (id, user_id, token) = ("session_1", "user_1", "token_1");
    Session {
        id: String::from(id),
        user_id: String::from(user_id),
        expires_at,
        token: String::from(token),
        ip_address: None,
        user_agent: None,
        created_at: now,
        updated_at: now,
    }
}
/// Converts `user` into a `DbRecord` keyed by column name.
///
/// Writes `id`, `name`, `email`, `email_verified`, `image`, `created_at` and
/// `updated_at`; a missing `image` is stored as `DbValue::Null`. The
/// `username` and `display_username` fields are not written.
fn user_record(user: User) -> DbRecord {
    let image = match user.image {
        Some(url) => DbValue::String(url),
        None => DbValue::Null,
    };
    // Inserted in this order to match the column order used before.
    let columns = [
        ("id", DbValue::String(user.id)),
        ("name", DbValue::String(user.name)),
        ("email", DbValue::String(user.email)),
        ("email_verified", DbValue::Boolean(user.email_verified)),
        ("image", image),
        ("created_at", DbValue::Timestamp(user.created_at)),
        ("updated_at", DbValue::Timestamp(user.updated_at)),
    ];
    let mut record = DbRecord::new();
    for (column, value) in columns {
        record.insert(column.to_owned(), value);
    }
    record
}
/// Builds a `Create` query for `model`, adding every `(field, value)` pair of
/// `record` through `Create::data` in the record's iteration order.
fn create_query(model: &str, record: DbRecord) -> Create {
    let mut query = Create::new(model);
    for (field, value) in record {
        query = query.data(field, value);
    }
    query
}
/// Runs a `find_one` on `model` with a single `Where` clause built from
/// `field` and the string `value`.
///
/// Returns whatever `find_one` returns: the optional record, or its error.
async fn record_by_string(
    adapter: &RouteAdapter,
    model: &str,
    field: &str,
    value: &str,
) -> Result<Option<DbRecord>, RustAuthError> {
    let filter = Where::new(field, DbValue::String(value.to_owned()));
    let query = FindOne::new(model).where_clause(filter);
    adapter.find_one(query).await
}
/// Reports whether `record_by_string` finds a record in `model` for the given
/// `field` and string `value`.
///
/// # Errors
///
/// Propagates the error from `record_by_string`.
async fn contains_record_string(
    adapter: &RouteAdapter,
    model: &str,
    field: &str,
    value: &str,
) -> Result<bool, RustAuthError> {
    let found = record_by_string(adapter, model, field, value).await?;
    Ok(found.is_some())
}
/// Converts `session` into a `DbRecord` keyed by column name.
///
/// Writes `id`, `user_id`, `expires_at`, `token`, `ip_address`, `user_agent`,
/// `created_at` and `updated_at`. The optional `ip_address` and `user_agent`
/// are stored as strings when present and as `DbValue::Null` when absent, so
/// a session seeded with request metadata keeps it (previously both columns
/// were always written as `Null`).
fn session_record(session: Session) -> DbRecord {
    let mut record = DbRecord::new();
    record.insert("id".to_owned(), DbValue::String(session.id));
    record.insert("user_id".to_owned(), DbValue::String(session.user_id));
    record.insert(
        "expires_at".to_owned(),
        DbValue::Timestamp(session.expires_at),
    );
    record.insert("token".to_owned(), DbValue::String(session.token));
    // Mirror how `user_record` maps its optional `image` field.
    record.insert(
        "ip_address".to_owned(),
        session.ip_address.map_or(DbValue::Null, DbValue::String),
    );
    record.insert(
        "user_agent".to_owned(),
        session.user_agent.map_or(DbValue::Null, DbValue::String),
    );
    record.insert(
        "created_at".to_owned(),
        DbValue::Timestamp(session.created_at),
    );
    record.insert(
        "updated_at".to_owned(),
        DbValue::Timestamp(session.updated_at),
    );
    record
}
/// Builds the account record `account_1` for provider `"credential"`, using
/// `user_id` as both the account id and the user id, with `password` set to
/// `password_hash`.
fn credential_account_record(user_id: &str, password_hash: &str, now: OffsetDateTime) -> DbRecord {
    let hash = DbValue::String(String::from(password_hash));
    let mut account = linked_account_record("account_1", "credential", user_id, user_id, None, now);
    // Overwrites the `Null` password written by `linked_account_record`.
    account.insert(String::from("password"), hash);
    account
}
/// Builds an account `DbRecord` for the given ids and provider.
///
/// All token columns, their expiry columns and `password` are written as
/// `DbValue::Null`; `scope` is a string when `Some` and `Null` otherwise;
/// `created_at` and `updated_at` are both `now`.
fn linked_account_record(
    id: &str,
    provider_id: &str,
    account_id: &str,
    user_id: &str,
    scope: Option<&str>,
    now: OffsetDateTime,
) -> DbRecord {
    let text = |value: &str| DbValue::String(value.to_owned());
    let scope_value = match scope {
        Some(scope) => text(scope),
        None => DbValue::Null,
    };
    // Inserted in this order to match the column order used before.
    let columns = [
        ("id", text(id)),
        ("provider_id", text(provider_id)),
        ("account_id", text(account_id)),
        ("user_id", text(user_id)),
        ("access_token", DbValue::Null),
        ("refresh_token", DbValue::Null),
        ("id_token", DbValue::Null),
        ("access_token_expires_at", DbValue::Null),
        ("refresh_token_expires_at", DbValue::Null),
        ("scope", scope_value),
        ("password", DbValue::Null),
        ("created_at", DbValue::Timestamp(now)),
        ("updated_at", DbValue::Timestamp(now)),
    ];
    let mut record = DbRecord::new();
    for (column, value) in columns {
        record.insert(column.to_owned(), value);
    }
    record
}
/// Borrows the string stored under `field` in `record`.
///
/// # Errors
///
/// Returns `RustAuthError::Adapter` when `field` is absent or holds anything
/// other than `DbValue::String`.
fn string_field<'a>(record: &'a DbRecord, field: &str) -> Result<&'a str, RustAuthError> {
    if let Some(DbValue::String(text)) = record.get(field) {
        return Ok(text);
    }
    Err(RustAuthError::Adapter(format!(
        "missing string field `{field}`"
    )))
}
/// Builds a plugin named `"password-validator"` whose password validator
/// rejects with a `PASSWORD_COMPROMISED` bad-request rejection when the
/// validated input's `path` equals `path`, and accepts every other input.
fn rejecting_password_plugin(path: &'static str) -> AuthPlugin {
    let plugin = AuthPlugin::new("password-validator");
    plugin.with_password_validator(move |_context, input| {
        Box::pin(async move {
            if input.path == path {
                Err(PluginPasswordValidationRejection::bad_request(
                    "PASSWORD_COMPROMISED",
                    "compromised",
                ))
            } else {
                Ok(())
            }
        })
    })
}