use std::sync::Arc;
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Json, Response},
routing::{get, post},
Form, Router,
};
use base64::{engine::general_purpose::STANDARD as BASE64, Engine as _};
use serde::{Deserialize, Serialize};
use tower_http::cors::{Any, CorsLayer};
use altcha::{
create_challenge, verify_server_signature, verify_solution, CreateChallengeOptions, Payload,
ServerSignaturePayload, ServerSignatureVerificationData, VerifySolutionOptions,
};
use rand::Rng;
/// State shared by every request handler.
///
/// Both secrets sit behind an `Arc`, so the per-request clone that axum's
/// `State` extractor performs only bumps reference counts.
#[derive(Clone)]
struct AppState {
    /// Secret passed as `hmac_signature_secret` when creating challenges and
    /// as the verification secret when checking submitted payloads.
    hmac_secret: std::sync::Arc<String>,
    /// Secret passed as `hmac_key_signature_secret` to challenge creation and
    /// client-solution verification.
    hmac_key_secret: std::sync::Arc<String>,
}
/// Handles `GET /challenge`: creates a new ALTCHA challenge and returns it as JSON.
///
/// The challenge uses the `"PBKDF2/SHA-256"` algorithm with cost 5000. Its
/// counter is drawn uniformly from `5_000..=10_000`, and it expires 600 seconds
/// after creation (Unix seconds). Both configured secrets are forwarded so the
/// challenge is signed.
///
/// If `create_challenge` fails, the response is `500 Internal Server Error`
/// with the error text in the body.
async fn get_challenge(State(state): State<AppState>) -> Response {
    // Randomised counter, sampled before the timestamp (same order as before).
    let secret_counter = rand::thread_rng().gen_range(5_000..=10_000);

    // Panics only if the system clock reports a time before the Unix epoch.
    let now_unix_secs = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
        .as_secs();
    let expiry = now_unix_secs + 600;

    let options = CreateChallengeOptions {
        algorithm: "PBKDF2/SHA-256".to_string(),
        cost: 5_000,
        counter: Some(secret_counter),
        expires_at: Some(expiry),
        hmac_signature_secret: Some(String::clone(&state.hmac_secret)),
        hmac_key_signature_secret: Some(String::clone(&state.hmac_key_secret)),
        ..Default::default()
    };

    let challenge = match create_challenge(options) {
        Ok(challenge) => challenge,
        Err(err) => {
            let message = format!("failed to create challenge: {err}");
            return (StatusCode::INTERNAL_SERVER_ERROR, message).into_response();
        }
    };
    Json(challenge).into_response()
}
/// URL-encoded form body accepted by `POST /submit`.
#[derive(Deserialize)]
struct SubmitForm {
    /// Base64-encoded JSON ALTCHA payload, as submitted in the `altcha` form field.
    altcha: String,
}
/// Verification outcome echoed back to the client.
///
/// This single shape covers both client-solution and server-signature
/// payloads. Field declaration order determines JSON key order, so it is
/// intentionally left unchanged.
#[derive(Serialize)]
struct AltchaResult {
    /// Whether the verifier accepted the payload.
    verified: bool,
    /// Whether the verifier reported the challenge as expired.
    expired: bool,
    /// Signature-check flag. The client-solution verifier reports it as an
    /// `Option`; for server signatures it is always `Some`.
    invalid_signature: Option<bool>,
    /// Solution-check flag. Optional for the same reason as `invalid_signature`.
    invalid_solution: Option<bool>,
    /// Time value reported by the verifier (units not visible here).
    time: f64,
    /// Extra data from server-signature verification; omitted from JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    verification_data: Option<ServerSignatureVerificationData>,
}
/// JSON envelope returned by `POST /submit`.
///
/// `ok` is always present. `altcha` and `error` are serialized only when
/// they are `Some`.
#[derive(Serialize)]
struct SubmitResponse {
    /// `true` when verification succeeded.
    ok: bool,
    /// Verification details; set by the success constructor.
    #[serde(skip_serializing_if = "Option::is_none")]
    altcha: Option<AltchaResult>,
    /// Human-readable failure message; set by the error constructor.
    #[serde(skip_serializing_if = "Option::is_none")]
    error: Option<String>,
}
impl SubmitResponse {
    /// Builds a successful response that carries the verification details.
    fn ok(altcha: AltchaResult) -> Self {
        Self {
            ok: true,
            altcha: Some(altcha),
            error: None,
        }
    }

    /// Builds a `400 Bad Request` JSON response carrying `msg` as the error text.
    fn err(msg: impl Into<String>) -> (StatusCode, Json<Self>) {
        let body = Self {
            ok: false,
            altcha: None,
            error: Some(msg.into()),
        };
        (StatusCode::BAD_REQUEST, Json(body))
    }
}
/// The two payload shapes a client may submit, distinguished by their JSON fields.
///
/// With `#[serde(untagged)]`, serde tries each variant in declaration order and
/// keeps the first that deserializes. `ServerSignature` must therefore stay
/// ahead of `Client`; reordering the variants can change which shape matches.
#[derive(Deserialize)]
#[serde(untagged)]
enum AltchaPayload {
    /// Payload carrying a server signature, checked with `verify_server_signature`.
    ServerSignature(ServerSignaturePayload),
    /// Client-solved challenge, checked with `verify_solution`.
    Client(Payload),
}
async fn post_submit(
State(state): State<AppState>,
Form(form): Form<SubmitForm>,
) -> Response {
let secret = (*state.hmac_secret).as_str();
let bytes = match BASE64.decode(&form.altcha) {
Ok(b) => b,
Err(_) => return SubmitResponse::err("altcha: base64 decode failed").into_response(),
};
let altcha_result = match serde_json::from_slice::<AltchaPayload>(&bytes) {
Ok(AltchaPayload::Client(payload)) => {
match verify_solution(VerifySolutionOptions {
hmac_key_signature_secret: Some((*state.hmac_key_secret).clone()),
..VerifySolutionOptions::new(&payload.challenge, &payload.solution, secret)
}) {
Ok(r) => AltchaResult {
verified: r.verified,
expired: r.expired,
invalid_signature: r.invalid_signature,
invalid_solution: r.invalid_solution,
time: r.time,
verification_data: None,
},
Err(err) => {
return SubmitResponse::err(format!("altcha: {err}")).into_response()
}
}
}
Ok(AltchaPayload::ServerSignature(payload)) => {
match verify_server_signature(&payload, secret) {
Ok(r) => AltchaResult {
verified: r.verified,
expired: r.expired,
invalid_signature: Some(r.invalid_signature),
invalid_solution: Some(r.invalid_solution),
time: r.time,
verification_data: r.verification_data,
},
Err(err) => {
return SubmitResponse::err(format!("altcha: {err}")).into_response()
}
}
}
Err(_) => {
return SubmitResponse::err("altcha: unrecognised payload format").into_response()
}
};
if !altcha_result.verified {
let reason = if altcha_result.expired {
"challenge has expired"
} else if altcha_result.invalid_signature == Some(true) {
"invalid signature"
} else {
"invalid solution"
};
return SubmitResponse::err(format!("altcha: {reason}")).into_response();
}
Json(SubmitResponse::ok(altcha_result)).into_response()
}
/// Starts the ALTCHA demo server.
///
/// Configuration is read from the environment, falling back to the previous
/// hard-coded values:
/// - `ALTCHA_HMAC_SECRET` / `ALTCHA_HMAC_KEY_SECRET`: HMAC secrets. An unset
///   or empty variable falls back to an insecure default, and a warning is
///   printed to stderr. Distinct values are advisable.
/// - `LISTEN_ADDR`: socket address to bind (default `0.0.0.0:3000`).
///
/// CORS is fully permissive (any origin, method and header).
///
/// # Panics
/// Panics if the listener cannot bind or the server exits with an error.
#[tokio::main]
async fn main() {
    const DEFAULT_SECRET: &str = "change-me-in-production";
    const DEFAULT_ADDR: &str = "0.0.0.0:3000";

    // Read a secret from the environment, treating empty values as unset, and
    // warn loudly when the insecure built-in default has to be used.
    let secret_from_env = |key: &str| match std::env::var(key) {
        Ok(value) if !value.is_empty() => value,
        _ => {
            eprintln!("warning: {key} is not set; falling back to the insecure default secret");
            DEFAULT_SECRET.to_string()
        }
    };

    let state = AppState {
        hmac_secret: Arc::new(secret_from_env("ALTCHA_HMAC_SECRET")),
        hmac_key_secret: Arc::new(secret_from_env("ALTCHA_HMAC_KEY_SECRET")),
    };

    let cors = CorsLayer::new()
        .allow_origin(Any)
        .allow_methods(Any)
        .allow_headers(Any);

    let app = Router::new()
        .route("/challenge", get(get_challenge))
        .route("/submit", post(post_submit))
        .layer(cors)
        .with_state(state);

    let addr = std::env::var("LISTEN_ADDR").unwrap_or_else(|_| DEFAULT_ADDR.to_string());
    let listener = tokio::net::TcpListener::bind(addr.as_str())
        .await
        .unwrap_or_else(|err| panic!("failed to bind {addr}: {err}"));
    println!("listening on http://{addr}");
    axum::serve(listener, app)
        .await
        .unwrap_or_else(|err| panic!("server error: {err}"));
}