use crate::persistence::sql::{
session::SessionRepository,
signup_code::{SignupCodeId, SignupCodeRepository},
uexecutor,
user::{UserEntity, UserRepository},
};
use crate::shared::{HttpError, HttpResult};
use crate::{client_server::AppState, SignupMode};
use axum::{
extract::{Query, State},
http::StatusCode,
http::{header, HeaderValue},
response::IntoResponse,
};
use axum_extra::extract::Host;
use bytes::Bytes;
use pubky_common::capabilities::Capabilities;
use pubky_common::crypto::PublicKey;
use pubky_common::session::SessionInfo;
use std::collections::HashMap;
use tower_cookies::{
cookie::time::{Duration, OffsetDateTime},
cookie::SameSite,
Cookie, Cookies,
};
use crate::shared::user_quota::UserQuota;
/// Lifetime, in days, applied to the session cookie (both its `Max-Age` and
/// its `Expires` attribute) when a session is created.
const SESSION_EXPIRY_DAYS: i64 = 365;
pub async fn signup(
State(state): State<AppState>,
cookies: Cookies,
Host(host): Host,
Query(params): Query<HashMap<String, String>>, body: Bytes,
) -> HttpResult<impl IntoResponse> {
let token = state.verifier.verify(&body)?;
let public_key = token.public_key();
let mut tx = state.sql_db.pool().begin().await?;
match UserRepository::get(public_key, uexecutor!(tx)).await {
Ok(_) => {
return Err(HttpError::new_with_message(
StatusCode::CONFLICT,
"User already exists",
));
}
Err(sqlx::Error::RowNotFound) => {
}
Err(e) => {
return Err(e.into());
}
}
let user = if state.signup_mode == SignupMode::TokenRequired {
let signup_token_param = params
.get("signup_token")
.ok_or(HttpError::new_with_message(
StatusCode::BAD_REQUEST,
"Token required",
))?;
let signup_code_id = SignupCodeId::new(signup_token_param.clone()).map_err(|e| {
HttpError::new_with_message(
StatusCode::BAD_REQUEST,
format!("Invalid signup token format: {}", e),
)
})?;
let code = match SignupCodeRepository::get(&signup_code_id, uexecutor!(tx)).await {
Ok(code) => code,
Err(sqlx::Error::RowNotFound) => {
return Err(HttpError::new_with_message(
StatusCode::UNAUTHORIZED,
"Invalid token",
));
}
Err(e) => {
return Err(e.into());
}
};
if code.used_by.is_some() {
return Err(HttpError::new_with_message(
StatusCode::UNAUTHORIZED,
"Token already used",
));
}
match SignupCodeRepository::mark_as_used(&signup_code_id, public_key, uexecutor!(tx)).await
{
Ok(_) => {}
Err(sqlx::Error::RowNotFound) => {
return Err(HttpError::new_with_message(
StatusCode::UNAUTHORIZED,
"Token already used",
));
}
Err(e) => return Err(e.into()),
}
let limits = code.quota();
state
.user_service
.create_user(public_key, &limits, tx)
.await?
} else {
state
.user_service
.create_user(public_key, &UserQuota::default(), tx)
.await?
};
state.metrics.record_signup();
create_session_and_cookie(&state, cookies, &host, &user, token.capabilities()).await
}
/// Handles a signin request: verifies the auth token carried in `body`,
/// looks up the matching user, and responds with a freshly created session.
///
/// # Errors
///
/// Propagates the error from `state.verifier.verify` when the token is
/// rejected, and the HTTP error produced by `get_or_http_error` when the
/// user lookup fails.
pub async fn signin(
    State(state): State<AppState>,
    cookies: Cookies,
    Host(host): Host,
    body: Bytes,
) -> HttpResult<impl IntoResponse> {
    let token = state.verifier.verify(&body)?;

    // NOTE(review): the meaning of the `false` flag is defined by
    // `get_or_http_error` and is not visible from this file.
    let user = state
        .user_service
        .get_or_http_error(token.public_key(), false)
        .await?;

    create_session_and_cookie(&state, cookies, &host, &user, token.capabilities()).await
}
/// Creates a session for `user`, attaches the session cookie to `cookies`,
/// and builds the response carrying the serialized session info.
///
/// The cookie is named after the user's z32-encoded public key, holds the
/// session secret as its value, and is valid for [`SESSION_EXPIRY_DAYS`] days.
/// The response body is the serialized [`SessionInfo`] with the content type
/// `application/octet-stream`.
///
/// # Errors
///
/// Propagates the error from `SessionRepository::create`.
async fn create_session_and_cookie(
    state: &AppState,
    cookies: Cookies,
    host: &str,
    user: &UserEntity,
    capabilities: &Capabilities,
) -> HttpResult<impl IntoResponse> {
    let secret =
        SessionRepository::create(user.id, capabilities, &mut state.sql_db.pool().into()).await?;

    // Build the cookie: name = user's public key, value = session secret.
    let lifetime = Duration::days(SESSION_EXPIRY_DAYS);
    let mut session_cookie = Cookie::new(user.public_key.z32(), secret.to_string());
    configure_session_cookie(&mut session_cookie, host);
    // Both `Max-Age` and `Expires` are set, describing the same lifetime.
    session_cookie.set_max_age(lifetime);
    session_cookie.set_expires(OffsetDateTime::now_utc() + lifetime);
    cookies.add(session_cookie);

    // The body is the binary-serialized session description.
    let session = SessionInfo::new(&user.public_key, capabilities.clone(), None);
    let mut response = session.serialize().into_response();
    response.headers_mut().insert(
        header::CONTENT_TYPE,
        HeaderValue::from_static("application/octet-stream"),
    );

    Ok(response)
}
/// Applies the common attributes of a session cookie.
///
/// The cookie is always scoped to `/` and marked `HttpOnly`. When `host` is
/// considered secure by [`is_secure`], it is additionally marked `Secure`
/// with `SameSite=None`.
pub(crate) fn configure_session_cookie(cookie: &mut Cookie<'static>, host: &str) {
    cookie.set_path("/");
    cookie.set_http_only(true);

    if !is_secure(host) {
        return;
    }
    cookie.set_secure(true);
    cookie.set_same_site(SameSite::None);
}
/// Decides whether `host` should receive `Secure` cookies.
///
/// Returns `true` when `host` is a z32-encoded public key, or when it parses
/// as a domain name containing at least one `.`. IP addresses, dot-less
/// names, and anything `url::Host::parse` rejects yield `false`.
fn is_secure(host: &str) -> bool {
    PublicKey::try_from_z32(host).is_ok()
        || matches!(
            url::Host::parse(host),
            Ok(url::Host::Domain(domain)) if domain.contains('.')
        )
}
#[cfg(test)]
mod tests {
    use super::*;
    use pubky_common::crypto::Keypair;

    /// Checks `is_secure` against hosts that must be rejected (empty string,
    /// IP addresses, dot-less names, host with port) and hosts that must be
    /// accepted (a public key and a dotted domain).
    #[test]
    fn test_is_secure() {
        let insecure_hosts = [
            "",
            "127.0.0.1",
            "homeserver",
            "testnet",
            "167.86.102.121",
            "[2001:0db8:0000:0000:0000:ff00:0042:8329]",
            "localhost",
            "localhost:23423",
        ];
        for host in insecure_hosts.iter().copied() {
            assert!(!is_secure(host), "{:?} should not be secure", host);
        }

        let pubkey_host = Keypair::random().public_key().z32();
        assert!(is_secure(&pubkey_host));
        assert!(is_secure("example.com"));
    }
}