use std::net::IpAddr;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use solid_pod_rs::security::rate_limit::{
RateLimitDecision, RateLimitKey, RateLimitSubject, RateLimiter,
};
use crate::jwks::Jwks;
use crate::tokens::{issue_access_token, AccessToken};
use crate::user_store::{User, UserStore};
/// Minimum password length enforced by [`validate_password_length`] and
/// reported back in [`LoginError::PasswordTooShort`].
pub const MIN_PASSWORD_LENGTH: usize = 8;
/// Route label used to build the per-IP [`RateLimitKey`] consulted by [`login`].
pub const RATE_LIMIT_ROUTE: &str = "idp_credentials";
/// Failures reported by [`login`] and [`validate_password_length`].
///
/// Each variant's `Display` text comes from its `#[error]` attribute.
#[derive(Debug, Error)]
pub enum LoginError {
    /// The per-IP limiter refused the attempt; the caller is told to retry
    /// after `retry_after_secs` seconds.
    #[error("rate limited, retry after {retry_after_secs}s")]
    RateLimited { retry_after_secs: u64 },
    /// Returned by [`login`] both for an unknown email and for a password
    /// that fails verification.
    #[error("invalid credentials")]
    InvalidGrant,
    /// The password is shorter than `min_length`.
    #[error("password must be at least {min_length} characters")]
    PasswordTooShort { min_length: usize },
    /// The request itself is malformed (e.g. a required field is empty).
    #[error("invalid request: {0}")]
    InvalidRequest(String),
    /// The user store failed; carries the underlying error's message.
    #[error("user store: {0}")]
    UserStore(String),
    /// Access-token issuance failed; carries the underlying error's message.
    #[error("token issuance: {0}")]
    Token(String),
}
/// Successful result of [`login`]; serde-serializable for the response body.
///
/// Field declaration order is kept stable because it determines the order
/// of fields in serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CredentialsResponse {
    /// The issued access token's JWT string.
    pub access_token: String,
    /// `"DPoP"` when [`login`] was given a DPoP key thumbprint, otherwise `"Bearer"`.
    pub token_type: String,
    /// Token lifetime in seconds, echoed from the requested TTL.
    pub expires_in: u64,
    /// WebID of the authenticated user.
    pub webid: String,
    /// Identifier of the authenticated user's account.
    pub id: String,
}
/// Authenticates a user by email and password and issues an access token.
///
/// The per-IP rate limiter for [`RATE_LIMIT_ROUTE`] is consulted before any
/// other work, so every call counts against the caller's allowance. That
/// includes blank input and wrong passwords.
///
/// When `dpop_jkt` is `Some`, it is forwarded to token issuance and the
/// response reports `token_type = "DPoP"`; otherwise `"Bearer"`. The
/// response's `expires_in` echoes `ttl_secs`.
///
/// # Errors
///
/// * [`LoginError::RateLimited`] — the limiter denied this IP.
/// * [`LoginError::InvalidRequest`] — `email` or `password` is empty.
/// * [`LoginError::InvalidGrant`] — no account matches `email`, or the
///   password did not verify (both cases use this one variant).
/// * [`LoginError::UserStore`] — the store lookup or password check errored.
/// * [`LoginError::Token`] — issuing the access token failed.
// Every argument is an independent dependency or request input; bundling
// them into a struct would only relocate the list.
#[allow(clippy::too_many_arguments)]
pub async fn login(
    email: &str,
    password: &str,
    user_store: &dyn UserStore,
    jwks: &Jwks,
    issuer: &str,
    dpop_jkt: Option<&str>,
    limiter: &dyn RateLimiter,
    ip: IpAddr,
    now: u64,
    ttl_secs: u64,
) -> Result<CredentialsResponse, LoginError> {
    // Throttle first, before touching the user store.
    let limit_key = RateLimitKey {
        route: RATE_LIMIT_ROUTE,
        subject: RateLimitSubject::Ip(ip),
    };
    match limiter.check(&limit_key).await {
        RateLimitDecision::Deny {
            retry_after_secs, ..
        } => {
            return Err(LoginError::RateLimited { retry_after_secs });
        }
        RateLimitDecision::Allow => {}
    }

    let missing_field = email.is_empty() || password.is_empty();
    if missing_field {
        return Err(LoginError::InvalidRequest(String::from(
            "email and password are required",
        )));
    }

    // NOTE(review): an unknown email returns here without calling
    // `verify_password`. If verification is deliberately slow, response
    // timing may reveal which emails have accounts. Consider verifying
    // against a dummy hash; confirm against the UserStore implementation.
    let found: Option<User> = user_store
        .find_by_email(email)
        .await
        .map_err(|e| LoginError::UserStore(e.to_string()))?;
    let user = found.ok_or(LoginError::InvalidGrant)?;

    let password_matches = user_store
        .verify_password(&user, password)
        .await
        .map_err(|e| LoginError::UserStore(e.to_string()))?;
    if !password_matches {
        return Err(LoginError::InvalidGrant);
    }

    let signing_key = jwks.active_key();
    let token: AccessToken = issue_access_token(
        &signing_key,
        issuer,
        &user.webid,
        &user.id,
        "credentials_client",
        "openid webid",
        dpop_jkt,
        now,
        ttl_secs,
    )
    .map_err(|e| LoginError::Token(e.to_string()))?;

    let token_type = match dpop_jkt {
        Some(_) => "DPoP",
        None => "Bearer",
    };

    Ok(CredentialsResponse {
        access_token: token.jwt,
        token_type: token_type.to_owned(),
        expires_in: ttl_secs,
        webid: user.webid,
        id: user.id,
    })
}
/// Checks that `password` is at least [`MIN_PASSWORD_LENGTH`] characters long.
///
/// Length is measured in Unicode scalar values (`char`s), not UTF-8 bytes.
/// This matches the "characters" wording of [`LoginError::PasswordTooShort`].
/// For example, `"éééé"` is 4 characters (8 bytes) and is therefore rejected.
///
/// # Errors
///
/// Returns [`LoginError::PasswordTooShort`] carrying [`MIN_PASSWORD_LENGTH`]
/// when the password has fewer characters than that, including when it is
/// empty.
pub fn validate_password_length(password: &str) -> Result<(), LoginError> {
    // `str::len` counts bytes, which would credit each multi-byte code point
    // with extra length; count chars instead.
    if password.chars().count() < MIN_PASSWORD_LENGTH {
        return Err(LoginError::PasswordTooShort {
            min_length: MIN_PASSWORD_LENGTH,
        });
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;
    use std::time::Duration;
    use solid_pod_rs::security::rate_limit::LruRateLimiter;
    use crate::jwks::Jwks;
    use crate::user_store::InMemoryUserStore;

    /// Issuer URL passed to every login attempt in this module.
    const TEST_ISSUER: &str = "https://pod.example/";
    /// Clock value used by most attempts.
    const TEST_NOW: u64 = 1_700_000_000;
    /// Requested token lifetime for every attempt.
    const TEST_TTL: u64 = 3600;

    /// Store, key set and limiter shared by one test.
    type Fixture = (InMemoryUserStore, Jwks, LruRateLimiter);

    /// Client address for every attempt (203.0.113.0/24 is a documentation range).
    fn ip() -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(203, 0, 113, 1))
    }

    /// Builds a store holding one account (alice@example.com / "hunter2!"),
    /// a freshly generated ES256 key set, and a limiter whose policy for
    /// [`RATE_LIMIT_ROUTE`] is `(10, 60 s)`.
    fn seed() -> Fixture {
        let store = InMemoryUserStore::new();
        store
            .insert_user(
                "acct-1",
                "alice@example.com",
                "https://alice.example/profile#me",
                Some("Alice".into()),
                "hunter2!",
            )
            .unwrap();
        let jwks = Jwks::generate_es256().unwrap();
        let policy = vec![(RATE_LIMIT_ROUTE.to_string(), 10, Duration::from_secs(60))];
        let limiter = LruRateLimiter::with_policy(policy);
        (store, jwks, limiter)
    }

    /// Calls [`login`] against `fixture` with the shared issuer, client IP
    /// and TTL, varying only the inputs a test cares about.
    async fn attempt(
        fixture: &Fixture,
        email: &str,
        password: &str,
        dpop_jkt: Option<&str>,
        now: u64,
    ) -> Result<CredentialsResponse, LoginError> {
        let (store, jwks, limiter) = fixture;
        login(
            email,
            password,
            store,
            jwks,
            TEST_ISSUER,
            dpop_jkt,
            limiter,
            ip(),
            now,
            TEST_TTL,
        )
        .await
    }

    /// Asserts `password` is refused with the 8-character minimum reported.
    fn assert_too_short(password: &str) {
        let err = validate_password_length(password).unwrap_err();
        match err {
            LoginError::PasswordTooShort { min_length } => assert_eq!(min_length, 8),
            other => panic!("expected PasswordTooShort, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn login_succeeds_with_correct_password() {
        let fixture = seed();
        let resp = attempt(&fixture, "alice@example.com", "hunter2!", Some("JKT-OK"), TEST_NOW)
            .await
            .unwrap();
        assert_eq!(resp.token_type, "DPoP");
        assert_eq!(resp.webid, "https://alice.example/profile#me");
        assert_eq!(resp.expires_in, 3600);
        assert!(resp.access_token.contains('.'));
    }

    #[tokio::test]
    async fn login_returns_bearer_when_no_dpop() {
        let fixture = seed();
        let resp = attempt(&fixture, "alice@example.com", "hunter2!", None, TEST_NOW)
            .await
            .unwrap();
        assert_eq!(resp.token_type, "Bearer");
    }

    #[tokio::test]
    async fn login_rejects_wrong_password() {
        let fixture = seed();
        let err = attempt(&fixture, "alice@example.com", "nope", None, TEST_NOW)
            .await
            .unwrap_err();
        assert!(matches!(err, LoginError::InvalidGrant));
    }

    #[tokio::test]
    async fn login_rejects_unknown_user() {
        let fixture = seed();
        let err = attempt(&fixture, "nobody@example.com", "hunter2!", None, TEST_NOW)
            .await
            .unwrap_err();
        assert!(matches!(err, LoginError::InvalidGrant));
    }

    #[tokio::test]
    async fn login_rate_limited_after_ten_attempts() {
        let fixture = seed();
        // Use up the allowance with bad passwords; their outcomes don't matter.
        for _ in 0..10 {
            let _ = attempt(&fixture, "alice@example.com", "wrong", None, TEST_NOW).await;
        }
        // Even the correct password is now refused by the limiter.
        let err = attempt(&fixture, "alice@example.com", "hunter2!", None, TEST_NOW)
            .await
            .unwrap_err();
        match err {
            LoginError::RateLimited { retry_after_secs } => assert!(retry_after_secs >= 1),
            other => panic!("expected RateLimited, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn login_rejects_blank_input() {
        let fixture = seed();
        let err = attempt(&fixture, "", "", None, 0).await.unwrap_err();
        assert!(matches!(err, LoginError::InvalidRequest(_)));
    }

    #[test]
    fn password_too_short_7_chars_rejected() {
        assert_too_short("1234567");
    }

    #[test]
    fn password_exactly_8_chars_accepted() {
        validate_password_length("12345678").unwrap();
    }

    #[test]
    fn password_longer_than_8_chars_accepted() {
        validate_password_length("a]9Kz!#mN@xP").unwrap();
    }

    #[test]
    fn empty_password_rejected() {
        assert_too_short("");
    }

    #[test]
    fn min_password_length_constant_is_8() {
        assert_eq!(MIN_PASSWORD_LENGTH, 8);
    }
}