use async_trait::async_trait;
use better_auth::adapters::DatabaseAdapter;
use better_auth::error::{AuthError, AuthResult};
use better_auth::plugins::EmailPasswordPlugin;
use better_auth::types::*;
use better_auth::{AuthConfig, BetterAuth};
use chrono::{DateTime, Utc};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use uuid::Uuid;
/// In-memory stand-in for a custom ORM backend.
///
/// Every "table" is a `HashMap` behind its own `Arc<Mutex<..>>`. The adapter
/// derives `Default` so that `CustomORMAdapter::default()` yields the same
/// empty state as `CustomORMAdapter::new()`.
#[derive(Default)]
pub struct CustomORMAdapter {
    /// Users, keyed by user id.
    users: Arc<Mutex<HashMap<String, User>>>,
    /// Sessions, keyed by session token.
    sessions: Arc<Mutex<HashMap<String, Session>>>,
    /// Password hashes, keyed by user id.
    credentials: Arc<Mutex<HashMap<String, String>>>,
    /// Secondary index mapping an email address to the owning user id.
    email_index: Arc<Mutex<HashMap<String, String>>>,
}
impl CustomORMAdapter {
pub fn new() -> Self {
Self {
users: Arc::new(Mutex::new(HashMap::new())),
sessions: Arc::new(Mutex::new(HashMap::new())),
credentials: Arc::new(Mutex::new(HashMap::new())),
email_index: Arc::new(Mutex::new(HashMap::new())),
}
}
async fn orm_create_user(&self, user: User) -> Result<User, String> {
let mut users = self.users.lock().unwrap();
let mut email_index = self.email_index.lock().unwrap();
if let Some(ref email) = user.email {
if email_index.contains_key(email) {
return Err("Email already exists".to_string());
}
email_index.insert(email.clone(), user.id.clone());
}
users.insert(user.id.clone(), user.clone());
Ok(user)
}
async fn orm_find_user_by_email(&self, email: &str) -> Result<Option<User>, String> {
let email_index = self.email_index.lock().unwrap();
let users = self.users.lock().unwrap();
if let Some(user_id) = email_index.get(email) {
Ok(users.get(user_id).cloned())
} else {
Ok(None)
}
}
async fn orm_create_session(&self, session: Session) -> Result<Session, String> {
let mut sessions = self.sessions.lock().unwrap();
sessions.insert(session.token.clone(), session.clone());
Ok(session)
}
async fn orm_find_session(&self, token: &str) -> Result<Option<Session>, String> {
let sessions = self.sessions.lock().unwrap();
Ok(sessions.get(token).cloned())
}
}
/// `DatabaseAdapter` implementation backed by the in-memory tables of
/// `CustomORMAdapter`.
///
/// Users, sessions and credentials are stored for real. The OAuth account and
/// verification methods are placeholders: the `create_*` variants panic via
/// `unimplemented!`, the remaining ones return empty results.
///
/// All methods panic if the mutex they need is poisoned.
#[async_trait]
impl DatabaseAdapter for CustomORMAdapter {
    /// Builds a `User` from `create_user` and stores it.
    ///
    /// A random UUID is used when no id is supplied; `email_verified`
    /// defaults to `false` and `metadata` to its `Default` value.
    ///
    /// # Errors
    ///
    /// Returns a database error when the underlying insert is rejected.
    async fn create_user(&self, create_user: CreateUser) -> AuthResult<User> {
        let now = Utc::now();
        let user = User {
            id: create_user
                .id
                .unwrap_or_else(|| Uuid::new_v4().to_string()),
            email: create_user.email,
            name: create_user.name,
            image: create_user.image,
            email_verified: create_user.email_verified.unwrap_or(false),
            created_at: now,
            updated_at: now,
            username: create_user.username,
            display_username: create_user.display_username,
            two_factor_enabled: false,
            role: create_user.role,
            banned: false,
            ban_reason: None,
            ban_expires: None,
            metadata: create_user.metadata.unwrap_or_default(),
        };
        self.orm_create_user(user)
            .await
            .map_err(|e| AuthError::database(e))
    }

    /// Returns a copy of the user stored under `id`, if any.
    async fn get_user_by_id(&self, id: &str) -> AuthResult<Option<User>> {
        let users = self.users.lock().unwrap();
        Ok(users.get(id).cloned())
    }

    /// Returns a copy of the user registered under `email`, if any.
    async fn get_user_by_email(&self, email: &str) -> AuthResult<Option<User>> {
        self.orm_find_user_by_email(email)
            .await
            .map_err(|e| AuthError::database(e))
    }

    /// Applies the fields present in `update` to the user `id` and bumps
    /// `updated_at`.
    ///
    /// When the email changes, the email index is updated together with the
    /// user row so that lookups by email keep working.
    ///
    /// # Errors
    ///
    /// Returns a database error when the user does not exist or when the new
    /// email already belongs to another entry. Nothing is modified in either
    /// case.
    async fn update_user(&self, id: &str, update: UpdateUser) -> AuthResult<User> {
        // Lock order: `users` first, then `email_index` (same as `delete_user`).
        let mut users = self.users.lock().unwrap();
        let mut email_index = self.email_index.lock().unwrap();

        let user = users
            .get_mut(id)
            .ok_or_else(|| AuthError::database("User not found"))?;

        if let Some(new_email) = update.email {
            if user.email.as_deref() != Some(new_email.as_str()) {
                // Validate before touching anything so a rejected update
                // leaves both the row and the index untouched.
                if email_index.contains_key(&new_email) {
                    return Err(AuthError::database("Email already exists"));
                }
                if let Some(old_email) = user.email.take() {
                    email_index.remove(&old_email);
                }
                email_index.insert(new_email.clone(), user.id.clone());
            }
            user.email = Some(new_email);
        }
        if let Some(name) = update.name {
            user.name = name;
        }
        if let Some(image) = update.image {
            user.image = image;
        }
        if let Some(email_verified) = update.email_verified {
            user.email_verified = email_verified;
        }
        user.updated_at = Utc::now();
        Ok(user.clone())
    }

    /// Removes the user `id` and its email index entry.
    ///
    /// Deleting an unknown id is not an error.
    async fn delete_user(&self, id: &str) -> AuthResult<()> {
        // Lock order: `users` first, then `email_index`.
        let mut users = self.users.lock().unwrap();
        let mut email_index = self.email_index.lock().unwrap();
        if let Some(email) = users.remove(id).and_then(|removed| removed.email) {
            email_index.remove(&email);
        }
        Ok(())
    }

    /// Builds a `Session` with a fresh UUID id from `create_session` and
    /// stores it under its token.
    async fn create_session(&self, create_session: CreateSession) -> AuthResult<Session> {
        let now = Utc::now();
        let session = Session {
            id: Uuid::new_v4().to_string(),
            token: create_session.token,
            user_id: create_session.user_id,
            expires_at: create_session.expires_at,
            ip_address: create_session.ip_address,
            user_agent: create_session.user_agent,
            created_at: now,
            updated_at: now,
        };
        self.orm_create_session(session)
            .await
            .map_err(|e| AuthError::database(e))
    }

    /// Returns a copy of the session stored under `token`, if any.
    async fn get_session(&self, token: &str) -> AuthResult<Option<Session>> {
        self.orm_find_session(token)
            .await
            .map_err(|e| AuthError::database(e))
    }

    /// Returns copies of all sessions whose `user_id` equals `user_id`.
    async fn get_user_sessions(&self, user_id: &str) -> AuthResult<Vec<Session>> {
        let sessions = self.sessions.lock().unwrap();
        Ok(sessions
            .values()
            .filter(|session| session.user_id == user_id)
            .cloned()
            .collect())
    }

    /// Sets a new expiry on the session stored under `token` and bumps its
    /// `updated_at`. An unknown token is silently ignored.
    async fn update_session_expiry(
        &self,
        token: &str,
        expires_at: DateTime<Utc>,
    ) -> AuthResult<()> {
        let mut sessions = self.sessions.lock().unwrap();
        if let Some(session) = sessions.get_mut(token) {
            session.expires_at = expires_at;
            session.updated_at = Utc::now();
        }
        Ok(())
    }

    /// Removes the session stored under `token`, if any.
    async fn delete_session(&self, token: &str) -> AuthResult<()> {
        let mut sessions = self.sessions.lock().unwrap();
        sessions.remove(token);
        Ok(())
    }

    /// Removes every session that belongs to `user_id`.
    async fn delete_user_sessions(&self, user_id: &str) -> AuthResult<()> {
        let mut sessions = self.sessions.lock().unwrap();
        sessions.retain(|_, session| session.user_id != user_id);
        Ok(())
    }

    /// Removes every session whose `expires_at` is not strictly in the future
    /// and returns how many were removed.
    async fn delete_expired_sessions(&self) -> AuthResult<usize> {
        let mut sessions = self.sessions.lock().unwrap();
        let now = Utc::now();
        let before = sessions.len();
        sessions.retain(|_, session| session.expires_at > now);
        Ok(before - sessions.len())
    }

    /// Placeholder: OAuth accounts are not supported by this example.
    ///
    /// # Panics
    ///
    /// Always panics.
    async fn create_account(&self, _account: CreateAccount) -> AuthResult<Account> {
        unimplemented!("OAuth accounts not implemented in this example")
    }

    /// Placeholder: always reports that no account exists.
    async fn get_account(
        &self,
        _provider: &str,
        _provider_account_id: &str,
    ) -> AuthResult<Option<Account>> {
        Ok(None)
    }

    /// Placeholder: always returns an empty list.
    async fn get_user_accounts(&self, _user_id: &str) -> AuthResult<Vec<Account>> {
        Ok(vec![])
    }

    /// Placeholder: does nothing.
    async fn delete_account(&self, _id: &str) -> AuthResult<()> {
        Ok(())
    }

    /// Placeholder: verifications are not supported by this example.
    ///
    /// # Panics
    ///
    /// Always panics.
    async fn create_verification(
        &self,
        _verification: CreateVerification,
    ) -> AuthResult<Verification> {
        unimplemented!("Verifications not implemented in this example")
    }

    /// Placeholder: always reports that no verification exists.
    async fn get_verification(&self, _id: &str) -> AuthResult<Option<Verification>> {
        Ok(None)
    }

    /// Placeholder: always reports that no verification exists.
    async fn get_verification_by_value(&self, _value: &str) -> AuthResult<Option<Verification>> {
        Ok(None)
    }

    /// Placeholder: does nothing.
    async fn delete_verification(&self, _id: &str) -> AuthResult<()> {
        Ok(())
    }

    /// Placeholder: nothing is stored, so nothing is ever removed.
    async fn delete_expired_verifications(&self) -> AuthResult<usize> {
        Ok(0)
    }

    /// Stores `password_hash` for `user_id`, replacing any previous hash.
    async fn create_credential(&self, user_id: String, password_hash: String) -> AuthResult<()> {
        let mut credentials = self.credentials.lock().unwrap();
        credentials.insert(user_id, password_hash);
        Ok(())
    }

    /// Returns the password hash stored for `user_id`, if any.
    async fn get_credential(&self, user_id: &str) -> AuthResult<Option<String>> {
        let credentials = self.credentials.lock().unwrap();
        Ok(credentials.get(user_id).cloned())
    }

    /// Stores `password_hash` for `user_id`; inserts the entry when the user
    /// had no credential yet.
    async fn update_credential(&self, user_id: &str, password_hash: String) -> AuthResult<()> {
        let mut credentials = self.credentials.lock().unwrap();
        credentials.insert(user_id.to_string(), password_hash);
        Ok(())
    }

    /// Removes the password hash stored for `user_id`, if any.
    async fn delete_credential(&self, user_id: &str) -> AuthResult<()> {
        let mut credentials = self.credentials.lock().unwrap();
        credentials.remove(user_id);
        Ok(())
    }
}
#[cfg(feature = "diesel-example")]
mod diesel_adapter {
    //! Sketch of an adapter on top of Diesel's blocking, pooled Postgres
    //! connections. Only compiled with the `diesel-example` feature; the
    //! queries themselves are left as `todo!` stubs.

    use super::*;
    use diesel::PgConnection;
    use diesel::prelude::*;
    use diesel::r2d2::{ConnectionManager, Pool};

    /// Adapter skeleton that owns an r2d2 pool of Postgres connections.
    pub struct DieselAdapter {
        pool: Pool<ConnectionManager<PgConnection>>,
    }

    impl DieselAdapter {
        /// Builds a pool of at most 10 connections for `database_url`.
        ///
        /// # Errors
        ///
        /// Propagates the error reported by the pool builder.
        pub fn new(database_url: &str) -> Result<Self, diesel::r2d2::Error> {
            let pool = Pool::builder()
                .max_size(10)
                .build(ConnectionManager::<PgConnection>::new(database_url))?;
            Ok(Self { pool })
        }
    }

    #[async_trait]
    impl DatabaseAdapter for DieselAdapter {
        /// Checks a connection out of the pool on a blocking worker thread.
        /// The actual insert is not written yet, so the worker hits `todo!`
        /// once a connection has been obtained.
        ///
        /// # Errors
        ///
        /// Returns a database error when no connection can be checked out or
        /// when the blocking task cannot be joined.
        async fn create_user(&self, create_user: CreateUser) -> AuthResult<User> {
            let pool = self.pool.clone();
            let worker = tokio::task::spawn_blocking(move || {
                let mut conn = match pool.get() {
                    Ok(checked_out) => checked_out,
                    Err(pool_err) => return Err(AuthError::database(pool_err.to_string())),
                };
                todo!("Implement actual Diesel query")
            });
            match worker.await {
                Ok(query_result) => query_result,
                Err(join_err) => Err(AuthError::database(join_err.to_string())),
            }
        }

        /// Not written yet.
        ///
        /// # Panics
        ///
        /// Always panics via `todo!`.
        async fn get_user_by_email(&self, _email: &str) -> AuthResult<Option<User>> {
            todo!()
        }
    }
}
/// Runs the example end to end: builds a `BetterAuth` instance on top of
/// `CustomORMAdapter`, sends one `/sign-up` and one `/sign-in` request through
/// `handle_request`, and prints the outcome of each step.
///
/// # Errors
///
/// Returns an error when building the auth instance fails or when a
/// successful response body is not valid UTF-8 JSON. A rejected request is
/// only reported on stdout.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("🔧 Custom ORM Adapter Example");
    println!("{}", "=".repeat(50));

    let custom_adapter = CustomORMAdapter::new();
    println!("✅ Custom adapter created");

    let config = AuthConfig::new("your-very-secure-secret-key-at-least-32-chars-long")
        .base_url("http://localhost:3000")
        .password_min_length(8);
    let auth = BetterAuth::new(config)
        .database(custom_adapter)
        .plugin(EmailPasswordPlugin::new().enable_signup(true))
        .build()
        .await?;
    println!("🚀 Better-auth initialized with custom adapter");
    println!("📝 Registered plugins: {:?}", auth.plugin_names());

    println!("\n🧪 Testing custom adapter...");
    let signup_body = serde_json::json!({
        "email": "custom_adapter_user@example.com",
        "password": "test_password_123",
        "name": "Custom Adapter User"
    });
    let mut signup_req = AuthRequest::new(HttpMethod::Post, "/sign-up");
    signup_req.body = Some(signup_body.to_string().into_bytes());
    signup_req
        .headers
        .insert("content-type".to_string(), "application/json".to_string());
    match auth.handle_request(signup_req).await {
        Ok(response) => {
            println!("✅ Registration successful with custom adapter");
            let body_str = String::from_utf8(response.body)?;
            let parsed: serde_json::Value = serde_json::from_str(&body_str)?;
            println!("👤 User: {}", parsed["user"]["email"]);
            println!("🆔 ID: {}", parsed["user"]["id"]);
        }
        Err(e) => {
            println!("❌ Registration failed: {}", e);
        }
    }

    println!("\n🧪 Testing sign in...");
    let signin_body = serde_json::json!({
        "email": "custom_adapter_user@example.com",
        "password": "test_password_123"
    });
    let mut signin_req = AuthRequest::new(HttpMethod::Post, "/sign-in");
    signin_req.body = Some(signin_body.to_string().into_bytes());
    signin_req
        .headers
        .insert("content-type".to_string(), "application/json".to_string());
    match auth.handle_request(signin_req).await {
        Ok(response) => {
            println!("✅ Sign in successful with custom adapter");
            let body_str = String::from_utf8(response.body)?;
            let parsed: serde_json::Value = serde_json::from_str(&body_str)?;
            if let Some(token) = parsed["session_token"].as_str() {
                // Truncate by characters, not bytes: slicing a `str` at a byte
                // offset panics when the offset is not a char boundary.
                let preview: String = token.chars().take(20).collect();
                println!("🎫 Session token: {}...", preview);
            }
        }
        Err(e) => {
            println!("❌ Sign in failed: {}", e);
        }
    }

    println!("\n🎉 Custom adapter example completed!");
    println!("{}", "=".repeat(50));
    println!("\n💡 Key points for implementing custom adapters:");
    println!(" 1. Implement the DatabaseAdapter trait");
    println!(" 2. Map your ORM types to better-auth types");
    println!(" 3. Handle blocking operations with tokio::task::spawn_blocking");
    println!(" 4. Convert ORM errors to AuthError");
    println!(" 5. Implement all required methods (users, sessions, credentials)");
    println!(" 6. OAuth and verification methods can return unimplemented if not needed");
    println!("\n📚 Adapter implementation tips:");
    println!(" • Diesel: Use spawn_blocking for all queries");
    println!(" • SeaORM: Direct async/await support");
    println!(" • MongoDB: Use the async driver");
    println!(" • Redis: Can be used for session storage");
    println!(" • Custom REST API: Use reqwest for HTTP calls");
    Ok(())
}