use anyhow::{Result, anyhow};
use chrono::{Duration, Utc};
use jsonwebtoken::{Algorithm, EncodingKey, Header, encode};
use rand::distr::Alphanumeric;
use rand::{Rng, rng};
use serde::{Deserialize, Serialize};
use crate::models::JwtClaims;
use systemprompt_identifiers::{SessionId, UserId};
use systemprompt_models::Config;
use systemprompt_models::auth::{
AuthenticatedUser, JwtAudience, Permission, RateLimitTier, TokenType, UserType,
};
/// Caller-supplied options that shape the claims of a user JWT built by
/// `generate_jwt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtConfig {
/// Permissions copied verbatim into the token's `scope` claim.
pub permissions: Vec<Permission>,
/// Base audiences for the `aud` claim; `resource`, when set, is appended to
/// this list by `generate_jwt`.
pub audience: Vec<JwtAudience>,
/// Token lifetime in hours. `generate_jwt` falls back to 24 when this is
/// `None` and rejects values outside `1..=8760`.
pub expires_in_hours: Option<i64>,
/// Optional resource added to the audience list as `JwtAudience::Resource`.
/// Left out of the serialized form entirely when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub resource: Option<String>,
}
/// Borrowed signing material shared by every token generator in this file.
#[derive(Debug, Clone)]
pub struct JwtSigningParams<'a> {
/// Shared secret; its UTF-8 bytes are used as the HS256 signing key.
pub secret: &'a str,
/// Value written to the token's `iss` claim.
pub issuer: &'a str,
}
impl Default for JwtConfig {
    /// Builds the baseline configuration: the `User` permission, the standard
    /// audience list, a 24-hour lifetime and no extra resource audience.
    fn default() -> Self {
        let permissions = vec![Permission::User];
        let audience = JwtAudience::standard();
        Self {
            permissions,
            audience,
            expires_in_hours: Some(24),
            resource: None,
        }
    }
}
/// Returns `"{prefix}_"` followed by 32 random ASCII alphanumeric characters
/// drawn from the generator returned by `rand::rng()`.
pub fn generate_secure_token(prefix: &str) -> String {
    let random_part: String = rng()
        .sample_iter(Alphanumeric)
        .take(32)
        .map(char::from)
        .collect();
    format!("{prefix}_{random_part}")
}
/// Builds and signs an HS256 bearer JWT for an authenticated user.
///
/// * `user` - supplies `sub`, `username`, `email`, `user_type`, `roles` and the
///   rate-limit tier derived from the user type.
/// * `config` - permissions (`scope`), audiences (`aud`, plus the optional
///   resource audience) and the lifetime in hours (24 when unset).
/// * `jti` - unique token identifier, stored as-is.
/// * `session_id` - written to the `session_id` claim.
/// * `signing` - HS256 secret and the `iss` value.
///
/// # Errors
///
/// Returns an error when the lifetime is outside `1..=8760` hours, when the
/// expiry instant cannot be represented, or when encoding/signing fails.
pub fn generate_jwt(
    user: &AuthenticatedUser,
    config: JwtConfig,
    jti: String,
    session_id: &SessionId,
    signing: &JwtSigningParams<'_>,
) -> Result<String> {
    let expires_in_hours = config.expires_in_hours.unwrap_or(24);
    if !(1..=8760).contains(&expires_in_hours) {
        return Err(anyhow!(
            "Invalid token expiry: {expires_in_hours} hours. Must be between 1 and 8760 (1 year)"
        ));
    }

    // Read the clock once so `exp - iat` is exactly the configured lifetime;
    // two separate reads can straddle a second boundary.
    let issued_at = Utc::now();
    let expiration = issued_at
        .checked_add_signed(Duration::hours(expires_in_hours))
        .ok_or_else(|| anyhow!("Failed to calculate token expiration"))?
        .timestamp();
    let now = issued_at.timestamp();

    let user_type = user.user_type();

    // `config` is owned, so its fields are moved out instead of cloned.
    let mut audience = config.audience;
    if let Some(resource) = config.resource {
        audience.push(JwtAudience::Resource(resource));
    }

    let claims = JwtClaims {
        sub: user.id.to_string(),
        iat: now,
        exp: expiration,
        iss: signing.issuer.to_string(),
        aud: audience,
        jti,
        scope: config.permissions,
        username: user.username.clone(),
        email: user.email.clone(),
        user_type,
        roles: user.roles().to_vec(),
        client_id: None,
        token_type: TokenType::Bearer,
        auth_time: now,
        session_id: Some(session_id.to_string()),
        rate_limit_tier: Some(user_type.rate_tier()),
    };

    let key = EncodingKey::from_secret(signing.secret.as_bytes());
    let token = encode(&Header::new(Algorithm::HS256), &claims, &key)?;
    Ok(token)
}
/// Returns `"secret_"` followed by 64 random ASCII alphanumeric characters
/// drawn from the generator returned by `rand::rng()`.
pub fn generate_client_secret() -> String {
    let random_part: String = rng()
        .sample_iter(Alphanumeric)
        .take(64)
        .map(char::from)
        .collect();
    format!("secret_{random_part}")
}
/// Returns a freshly generated version-4 UUID in its string form, for use as
/// a token identifier (`jti`).
pub fn generate_access_token_jti() -> String {
    let token_id = uuid::Uuid::new_v4();
    token_id.to_string()
}
pub fn hash_client_secret(secret: &str) -> Result<String> {
use bcrypt::{DEFAULT_COST, hash};
Ok(hash(secret, DEFAULT_COST)?)
}
pub fn verify_client_secret(secret: &str, hash: &str) -> Result<bool> {
use bcrypt::verify;
Ok(verify(secret, hash)?)
}
/// Issues an anonymous-session JWT whose lifetime is the configured
/// `jwt_access_token_expiration`, delegating to
/// `generate_anonymous_jwt_with_expiry`.
///
/// # Errors
///
/// Fails when the global `Config` cannot be obtained or when token generation
/// fails.
pub fn generate_anonymous_jwt(
    user_id: &UserId,
    session_id: &SessionId,
    client_id: &systemprompt_identifiers::ClientId,
    signing: &JwtSigningParams<'_>,
) -> Result<String> {
    let configured_lifetime = Config::get()?.jwt_access_token_expiration;
    generate_anonymous_jwt_with_expiry(
        user_id,
        session_id,
        client_id,
        signing,
        configured_lifetime,
    )
}
/// Builds and signs an HS256 bearer JWT for an anonymous session.
///
/// The user id is used for `sub`, `username` and `email`; the token carries the
/// `Anonymous` permission, the `"anonymous"` role, the standard audience list
/// and a freshly generated v4 UUID as `jti`.
///
/// `expires_in_seconds` is applied with second precision: `exp` is exactly
/// `iat + expires_in_seconds`. No range validation is performed, so a zero or
/// negative value yields an `exp` at or before `iat`.
///
/// # Errors
///
/// Returns an error when `iat + expires_in_seconds` overflows `i64` or when
/// encoding/signing fails.
pub fn generate_anonymous_jwt_with_expiry(
    user_id: &UserId,
    session_id: &SessionId,
    client_id: &systemprompt_identifiers::ClientId,
    signing: &JwtSigningParams<'_>,
    expires_in_seconds: i64,
) -> Result<String> {
    // Single clock read: `iat`, `auth_time` and `exp` share one reference instant.
    let now = Utc::now().timestamp();
    // Add the lifetime in seconds directly. Converting to whole hours first
    // truncates any sub-hour remainder (e.g. 900 s becomes 0 h, so `exp == iat`).
    // Checked arithmetic keeps huge inputs an error instead of a panic.
    let expiration = now
        .checked_add(expires_in_seconds)
        .ok_or_else(|| anyhow!("Failed to calculate token expiration"))?;

    let subject = user_id.to_string();
    let claims = JwtClaims {
        sub: subject.clone(),
        iat: now,
        exp: expiration,
        iss: signing.issuer.to_string(),
        aud: JwtAudience::standard(),
        jti: uuid::Uuid::new_v4().to_string(),
        scope: vec![Permission::Anonymous],
        username: subject.clone(),
        email: subject,
        user_type: UserType::Anon,
        roles: vec!["anonymous".to_string()],
        client_id: Some(client_id.to_string()),
        token_type: TokenType::Bearer,
        auth_time: now,
        session_id: Some(session_id.to_string()),
        rate_limit_tier: Some(RateLimitTier::Anon),
    };

    let key = EncodingKey::from_secret(signing.secret.as_bytes());
    let token = encode(&Header::new(Algorithm::HS256), &claims, &key)?;
    Ok(token)
}
/// Issues an admin JWT whose lifetime is the configured
/// `jwt_access_token_expiration`, delegating to
/// `generate_admin_jwt_with_expiry`.
///
/// # Errors
///
/// Fails when the global `Config` cannot be obtained or when token generation
/// fails.
pub fn generate_admin_jwt(
    user_id: &UserId,
    session_id: &SessionId,
    email: &str,
    client_id: &systemprompt_identifiers::ClientId,
    signing: &JwtSigningParams<'_>,
) -> Result<String> {
    let configured_lifetime = Config::get()?.jwt_access_token_expiration;
    generate_admin_jwt_with_expiry(
        user_id,
        session_id,
        email,
        client_id,
        signing,
        configured_lifetime,
    )
}
/// Builds and signs an HS256 bearer JWT for an administrator.
///
/// `email` is used for both the `username` and `email` claims; the token
/// carries the `Admin` permission, the `"admin"` and `"user"` roles, the
/// standard audience list and a freshly generated v4 UUID as `jti`.
///
/// `expires_in_seconds` is applied with second precision: `exp` is exactly
/// `iat + expires_in_seconds`. No range validation is performed, so a zero or
/// negative value yields an `exp` at or before `iat`.
///
/// # Errors
///
/// Returns an error when `iat + expires_in_seconds` overflows `i64` or when
/// encoding/signing fails.
// NOTE(review): six parameters is under clippy's default threshold; this allow
// is presumably needed only if the project lowers it — confirm before removing.
#[allow(clippy::too_many_arguments)]
pub fn generate_admin_jwt_with_expiry(
    user_id: &UserId,
    session_id: &SessionId,
    email: &str,
    client_id: &systemprompt_identifiers::ClientId,
    signing: &JwtSigningParams<'_>,
    expires_in_seconds: i64,
) -> Result<String> {
    // Single clock read: `iat`, `auth_time` and `exp` share one reference instant.
    let now = Utc::now().timestamp();
    // Add the lifetime in seconds directly. Converting to whole hours first
    // truncates any sub-hour remainder (e.g. 900 s becomes 0 h, so `exp == iat`).
    // Checked arithmetic keeps huge inputs an error instead of a panic.
    let expiration = now
        .checked_add(expires_in_seconds)
        .ok_or_else(|| anyhow!("Failed to calculate token expiration"))?;

    let claims = JwtClaims {
        sub: user_id.to_string(),
        iat: now,
        exp: expiration,
        iss: signing.issuer.to_string(),
        aud: JwtAudience::standard(),
        jti: uuid::Uuid::new_v4().to_string(),
        scope: vec![Permission::Admin],
        username: email.to_string(),
        email: email.to_string(),
        user_type: UserType::Admin,
        roles: vec!["admin".to_string(), "user".to_string()],
        client_id: Some(client_id.to_string()),
        token_type: TokenType::Bearer,
        auth_time: now,
        session_id: Some(session_id.to_string()),
        rate_limit_tier: Some(RateLimitTier::Admin),
    };

    let key = EncodingKey::from_secret(signing.secret.as_bytes());
    let token = encode(&Header::new(Algorithm::HS256), &claims, &key)?;
    Ok(token)
}