use axum::{
Json, Router,
extract::State,
http::StatusCode,
response::{Html, IntoResponse, Response},
routing::{get, post},
};
use passki::{
AttestationConveyancePreference, AuthenticationChallenge, AuthenticationCredential,
AuthenticationState, ClientData, Passki, RegistrationChallenge, RegistrationCredential,
RegistrationState, ResidentKeyRequirement, StoredPasskey, UserVerificationRequirement,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::sync::{Arc, Mutex};
use tower_http::cors::CorsLayer;
use tower_http::trace::TraceLayer;
use uuid::Uuid;
/// Error type returned by every handler in this file.
///
/// It carries only a human-readable message. It is always rendered as
/// `400 Bad Request` with that message as the plain-text body.
struct AppError(String);

impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let AppError(message) = self;
        (StatusCode::BAD_REQUEST, message).into_response()
    }
}

/// Lets `?` convert any displayable error into an [`AppError`] by capturing its message.
///
/// Because this impl is blanket over `Display`, `AppError` itself must never implement
/// `Display`. Otherwise it would overlap with the reflexive `From<AppError> for AppError`.
impl<T> From<T> for AppError
where
    T: std::fmt::Display,
{
    fn from(err: T) -> Self {
        Self(err.to_string())
    }
}

/// Handler result: a JSON body on success, an [`AppError`] otherwise.
type AppResult<T> = Result<Json<T>, AppError>;
/// In-memory state shared by all handlers.
///
/// Every field is an `Arc`, so cloning a `Store` only bumps reference counts. All clones
/// see the same maps.
///
/// NOTE(review): in the handlers shown, pending entries are removed only by the matching
/// `*_finish` call. Abandoned ceremonies therefore stay in memory indefinitely.
#[derive(Default, Clone)]
struct Store {
    /// Registered accounts, keyed by username.
    users: Arc<Mutex<HashMap<String, User>>>,
    /// Registration ceremonies awaiting `register_finish`, keyed by challenge string.
    pending_registrations: Arc<Mutex<HashMap<String, RegistrationState>>>,
    /// Authentication ceremonies awaiting `auth_finish`, keyed by challenge string.
    pending_authentications: Arc<Mutex<HashMap<String, AuthenticationState>>>,
}
/// A registered account and the passkeys bound to it.
#[derive(Clone)]
// `username` and `display_name` are written at registration time but never read by this
// demo. They are kept so the record mirrors a real account. Only `dead_code`
// (field-never-read) is silenced, not the whole `unused` lint group.
#[allow(dead_code)]
struct User {
    /// Account identifier; its bytes are the WebAuthn user handle.
    id: Uuid,
    /// Unique account name (also the key in `Store::users`).
    username: String,
    /// Name shown by the authenticator UI.
    display_name: String,
    /// All passkeys registered for this account.
    passkeys: Vec<StoredPasskey>,
}
/// JSON body of `POST /register/start`.
#[derive(Deserialize)]
struct RegisterStartRequest {
    /// Account to register a passkey for. `register_start` passes it to passki twice
    /// (presumably as both user name and display name).
    username: String,
}

/// JSON body of `POST /register/finish`: the browser's attestation result.
#[derive(Deserialize)]
struct RegisterFinishRequest {
    /// Credential id as sent by the client (encoding presumably base64 — confirm with the front-end).
    credential_id: String,
    /// Credential public key as sent by the client.
    public_key: String,
    /// Base64-encoded clientDataJSON (decoded with `ClientData::from_base64`).
    client_data_json: String,
}

/// JSON body of `POST /auth/start`.
#[derive(Default, Deserialize)]
struct AuthStartRequest {
    /// Optional account to sign in as. When absent, `auth_start` passes an empty
    /// passkey list to passki.
    #[serde(default)]
    username: Option<String>,
}

/// JSON body of `POST /auth/finish`: the browser's assertion result.
#[derive(Deserialize)]
struct AuthFinishRequest {
    /// Base64-encoded credential id (decoded with `Passki::base64_decode`).
    credential_id: String,
    /// Authenticator data as sent by the client.
    authenticator_data: String,
    /// Base64-encoded clientDataJSON (decoded with `ClientData::from_base64`).
    client_data_json: String,
    /// Assertion signature as sent by the client.
    signature: String,
}
/// JSON reply of the `*/finish` endpoints.
#[derive(Serialize)]
struct ApiResponse {
    /// Outcome flag. Failures are reported through `AppError` instead.
    success: bool,
    /// Human-readable status text.
    message: String,
    /// Signed-in username. Left out of the JSON entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    username: Option<String>,
}
/// Router state handed to every handler via `State<AppState>`.
///
/// Both fields are cheap to clone (`Arc` / `Arc`-backed), which axum requires since it
/// clones the state per request.
#[derive(Clone)]
struct AppState {
    /// Relying-party configuration and WebAuthn ceremony logic.
    passki: Arc<Passki>,
    /// In-memory accounts and pending ceremonies.
    store: Store,
}
async fn index() -> Result<Html<String>, AppError> {
let html = fs::read_to_string("examples/index.html")
.map_err(|e| AppError(e.to_string()))?;
Ok(Html(html))
}
async fn register_start(
State(state): State<AppState>,
Json(req): Json<RegisterStartRequest>,
) -> AppResult<RegistrationChallenge> {
let user_id = Uuid::new_v4().as_bytes().to_vec();
let existing = state.store.users.lock().unwrap()
.get(&req.username).map(|u| u.passkeys.clone());
let (challenge, reg_state) = state.passki.start_passkey_registration(
&user_id,
&req.username, &req.username, 60000, AttestationConveyancePreference::None, ResidentKeyRequirement::Preferred, UserVerificationRequirement::Preferred, existing.as_deref(), )?;
state.store.pending_registrations.lock().unwrap().insert(challenge.challenge.clone(), reg_state);
Ok(Json(challenge))
}
/// Completes a registration ceremony and stores the new passkey.
///
/// The pending state is looked up by the challenge embedded in the client data. Passki
/// verifies the credential, and the resulting passkey is attached to the account named
/// in the pending state. The account is created on its first passkey.
///
/// # Errors
/// Returns an [`AppError`] in these cases:
/// - the client data is malformed;
/// - no registration is pending for its challenge;
/// - verification fails;
/// - the stored user handle is not a valid UUID.
async fn register_finish(
    State(state): State<AppState>,
    Json(req): Json<RegisterFinishRequest>,
) -> AppResult<ApiResponse> {
    let client_data = ClientData::from_base64(&req.client_data_json)?;
    // Removing (rather than reading) the pending state makes each challenge single-use,
    // even when verification below fails.
    let reg_state = state
        .store
        .pending_registrations
        .lock()
        .unwrap()
        .remove(&client_data.challenge)
        .ok_or_else(|| AppError("No pending registration".into()))?;

    let credential = RegistrationCredential {
        credential_id: req.credential_id,
        public_key: req.public_key,
        client_data_json: req.client_data_json,
    };
    let passkey = state
        .passki
        .finish_passkey_registration(&credential, &reg_state)?;

    let user_id_bytes = Passki::base64_decode(&reg_state.user.id)?;
    let user_id = Uuid::from_slice(&user_id_bytes)?;

    // Create the account on first registration, then append. The passkey is moved in
    // instead of cloned.
    let mut users = state.store.users.lock().unwrap();
    users
        .entry(reg_state.user.name.clone())
        .or_insert_with(|| User {
            id: user_id,
            username: reg_state.user.name,
            display_name: reg_state.user.display_name,
            passkeys: Vec::new(),
        })
        .passkeys
        .push(passkey);

    Ok(Json(ApiResponse {
        success: true,
        message: "Registration successful".into(),
        username: None,
    }))
}
/// Begins a passkey authentication ceremony.
///
/// With a username, the challenge is restricted to that account's passkeys. Without one,
/// an empty passkey list is passed to passki. The server-side state is parked in
/// `pending_authentications`, keyed by the challenge string, until `auth_finish`.
///
/// # Errors
/// Returns an [`AppError`] if a username is given but no such account exists.
async fn auth_start(
    State(state): State<AppState>,
    Json(req): Json<AuthStartRequest>,
) -> AppResult<AuthenticationChallenge> {
    let passkeys = match req.username.as_deref() {
        None => Vec::new(),
        Some(name) => {
            let users = state.store.users.lock().unwrap();
            match users.get(name) {
                Some(account) => account.passkeys.clone(),
                None => return Err(AppError("User not found".into())),
            }
        }
    };

    let (challenge, auth_state) = state.passki.start_passkey_authentication(
        &passkeys,
        60000,
        UserVerificationRequirement::Preferred,
    );
    let key = challenge.challenge.clone();
    state
        .store
        .pending_authentications
        .lock()
        .unwrap()
        .insert(key, auth_state);
    Ok(Json(challenge))
}
async fn auth_finish(
State(state): State<AppState>,
Json(req): Json<AuthFinishRequest>,
) -> AppResult<ApiResponse> {
let client_data = ClientData::from_base64(&req.client_data_json)?;
let auth_state = state.store.pending_authentications.lock().unwrap()
.remove(&client_data.challenge)
.ok_or(AppError("No pending authentication".into()))?;
let credential_id = Passki::base64_decode(&req.credential_id)?;
let mut users = state.store.users.lock().unwrap();
let (username, passkey) = users.iter_mut()
.find_map(|(name, user)| {
user.passkeys.iter_mut()
.find(|pk| pk.credential_id == credential_id)
.map(|pk| (name.clone(), pk))
})
.ok_or(AppError("Unknown credential".into()))?;
let credential = AuthenticationCredential {
credential_id: req.credential_id,
authenticator_data: req.authenticator_data,
client_data_json: req.client_data_json,
signature: req.signature,
};
let result = state.passki.finish_passkey_authentication(&credential, &auth_state, passkey)?;
passkey.counter = result.counter;
Ok(Json(ApiResponse {
success: true,
message: format!("Welcome back, {}!", username),
username: Some(username),
}))
}
/// Starts the passkey demo server on port 3000.
///
/// # Panics
/// Panics if the listening socket cannot be bound or the server fails while running.
#[tokio::main]
async fn main() {
    tracing_subscriber::fmt().with_target(false).init();

    // NOTE(review): these look like RP ID, origin and RP display name. The browser origin
    // presumably must match "http://localhost:3000" exactly, so open the demo via
    // `localhost`, not `127.0.0.1`. Confirm against the `Passki::new` docs.
    let state = AppState {
        passki: Arc::new(Passki::new(
            "localhost",
            "http://localhost:3000",
            "Passkeys Demo",
        )),
        store: Store::default(),
    };

    // NOTE(review): `CorsLayer::permissive()` accepts any origin. That is acceptable for a
    // local demo but must be narrowed before exposing these endpoints anywhere else.
    let app = Router::new()
        .route("/", get(index))
        .route("/register/start", post(register_start))
        .route("/register/finish", post(register_finish))
        .route("/auth/start", post(auth_start))
        .route("/auth/finish", post(auth_finish))
        .layer(CorsLayer::permissive())
        .layer(TraceLayer::new_for_http())
        .with_state(state);

    let listener = tokio::net::TcpListener::bind("0.0.0.0:3000")
        .await
        .expect("failed to bind 0.0.0.0:3000 (is the port already in use?)");
    println!("Server starting on http://localhost:3000");
    axum::serve(listener, app)
        .await
        .expect("HTTP server terminated with an error");
}