use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use crate::{Error, JwtConfig, Session, SessionToken, UserId, error::SessionError};
use super::provider::SessionProvider;
/// Stateless [`SessionProvider`] whose sessions live entirely inside signed JWTs.
///
/// Nothing is stored server-side, so an issued token cannot be revoked early;
/// it stays valid until its expiry.
pub struct JwtSessionProvider {
    /// Settings used both to sign new tokens and to verify presented ones.
    config: JwtConfig,
}

impl JwtSessionProvider {
    /// Creates a provider that signs and verifies tokens using `config`.
    pub fn new(config: JwtConfig) -> Self {
        JwtSessionProvider { config }
    }
}
#[async_trait]
impl SessionProvider for JwtSessionProvider {
async fn create_session(
&self,
user_id: &UserId,
user_agent: Option<String>,
ip_address: Option<String>,
duration: Duration,
) -> Result<Session, Error> {
let now = Utc::now();
let expires_at = now + duration;
let session = Session::builder()
.user_id(user_id.clone())
.user_agent(user_agent)
.ip_address(ip_address)
.created_at(now)
.updated_at(now)
.expires_at(expires_at)
.build()?;
let claims =
session.to_jwt_claims(self.config.issuer.clone(), self.config.include_metadata);
let jwt_token = SessionToken::new_jwt(&claims, &self.config)?;
Ok(Session {
token: jwt_token,
..session
})
}
async fn get_session(&self, token: &SessionToken) -> Result<Session, Error> {
let claims = match token.verify_jwt(&self.config) {
Ok(claims) => claims,
Err(Error::Session(SessionError::InvalidToken(msg))) => {
if msg.contains("ExpiredSignature") {
return Err(Error::Session(SessionError::Expired));
}
return Err(Error::Session(SessionError::InvalidToken(msg)));
}
Err(e) => return Err(e),
};
let now = Utc::now();
let exp = DateTime::from_timestamp(claims.exp, 0).unwrap_or(now);
if now > exp {
return Err(Error::Session(SessionError::Expired));
}
let session = Session::from_jwt_claims(token.clone(), &claims);
Ok(session)
}
async fn delete_session(&self, _token: &SessionToken) -> Result<(), Error> {
Ok(())
}
async fn cleanup_expired_sessions(&self) -> Result<(), Error> {
Ok(())
}
async fn delete_sessions_for_user(&self, _user_id: &UserId) -> Result<(), Error> {
tracing::warn!(
"JwtSessionProvider doesn't support revoking all sessions for a user; tokens will remain valid until they expire"
);
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed HS256 key shared by every test in this module; never use in production.
    const TEST_HS256_SECRET: &[u8] = b"test_secret_key_for_hs256_jwt_tokens_not_for_production_use";

    /// Baseline HS256 configuration built from the shared test secret.
    fn hs256_config() -> JwtConfig {
        JwtConfig::new_hs256(TEST_HS256_SECRET.to_vec())
    }

    /// A session issued by the provider round-trips through its own token,
    /// including the user agent and IP address.
    #[tokio::test]
    async fn test_jwt_session_provider_create_and_get() {
        let provider = JwtSessionProvider::new(
            hs256_config().with_issuer("test-issuer").with_metadata(true),
        );
        let user_id = UserId::new_random();
        let agent = Some("test-agent".to_string());
        let ip = Some("127.0.0.1".to_string());
        let lifetime = Duration::hours(1);

        let created = provider
            .create_session(&user_id, agent.clone(), ip.clone(), lifetime)
            .await
            .unwrap();
        assert_eq!(created.user_id, user_id);
        assert_eq!(created.user_agent, agent);
        assert_eq!(created.ip_address, ip);

        let fetched = provider.get_session(&created.token).await.unwrap();
        assert_eq!(fetched.user_id, user_id);
        assert_eq!(fetched.user_agent, agent);
        assert_eq!(fetched.ip_address, ip);
    }

    /// A correctly signed token whose expiry is in the past is reported as
    /// `SessionError::Expired`.
    #[tokio::test]
    async fn test_jwt_session_provider_expired_session() {
        let config = hs256_config();
        let provider = JwtSessionProvider::new(config.clone());

        let stale = Session::builder()
            .user_id(UserId::new_random())
            .expires_at(Utc::now() - Duration::minutes(5))
            .build()
            .unwrap();
        let stale_claims = stale.to_jwt_claims(None, false);
        let token = SessionToken::new_jwt(&stale_claims, &config).unwrap();

        let outcome = provider.get_session(&token).await;
        assert!(matches!(outcome, Err(Error::Session(SessionError::Expired))));
    }

    /// A string that is not a valid JWT is rejected as an invalid token.
    #[tokio::test]
    async fn test_jwt_session_provider_invalid_token() {
        let provider = JwtSessionProvider::new(hs256_config());
        let garbage = SessionToken::Jwt("invalid.jwt.token".to_string());

        let outcome = provider.get_session(&garbage).await;
        assert!(matches!(
            outcome,
            Err(Error::Session(SessionError::InvalidToken(_)))
        ));
    }
}