use crate::ids::types::{IdsError, IdsResult, IdsUri, SecurityProfile};
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{
decode, encode, Algorithm, DecodingKey, EncodingKey, Header, TokenData, Validation,
};
use ring::rand::SystemRandom;
use ring::signature::{Ed25519KeyPair, KeyPair};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Client for talking to a DAPS (the IDS token service) over HTTP.
///
/// Bundles the service base URL, optional connector credentials used to sign
/// client assertions, and a shared slot for DAPS key material.
pub struct DapsClient {
/// Base URL of the DAPS; endpoint paths such as `/token` are appended to it.
daps_url: String,
/// Connector credentials used for signing client assertions, if any were supplied.
credentials: Option<DapsCredentials>,
/// Shared slot for DAPS key bytes; populated by `fetch_daps_public_key`.
daps_public_key: Arc<RwLock<Option<Vec<u8>>>>,
}
/// Identity of a connector together with its Ed25519 key pair.
pub struct DapsCredentials {
/// Identifier of the connector these credentials were created for.
connector_id: IdsUri,
/// Ed25519 key pair, either freshly generated or parsed from PKCS#8 bytes.
key_pair: Ed25519KeyPair,
}
impl DapsCredentials {
    /// Creates credentials for `connector_id` backed by a newly generated
    /// Ed25519 key pair.
    ///
    /// # Errors
    /// Returns [`IdsError::InternalError`] when the key pair cannot be
    /// generated or the generated PKCS#8 document cannot be parsed.
    pub fn new(connector_id: IdsUri) -> IdsResult<Self> {
        let document = Ed25519KeyPair::generate_pkcs8(&SystemRandom::new())
            .map_err(|e| IdsError::InternalError(format!("Failed to generate key pair: {}", e)))?;
        // Parsing is shared with the public PKCS#8 constructor.
        Self::from_pkcs8(connector_id, document.as_ref())
    }

    /// Creates credentials for `connector_id` from an existing PKCS#8-encoded
    /// Ed25519 key pair.
    ///
    /// # Errors
    /// Returns [`IdsError::InternalError`] when `pkcs8_bytes` cannot be parsed.
    pub fn from_pkcs8(connector_id: IdsUri, pkcs8_bytes: &[u8]) -> IdsResult<Self> {
        match Ed25519KeyPair::from_pkcs8(pkcs8_bytes) {
            Ok(key_pair) => Ok(Self {
                connector_id,
                key_pair,
            }),
            Err(e) => Err(IdsError::InternalError(format!(
                "Failed to parse key pair: {}",
                e
            ))),
        }
    }

    /// Returns the raw bytes of the public half of the key pair.
    pub fn public_key(&self) -> &[u8] {
        let public_half = self.key_pair.public_key();
        public_half.as_ref()
    }
}
impl DapsClient {
/// Creates a client for the DAPS at `daps_url` without connector credentials
/// and with an empty key slot.
pub fn new(daps_url: impl Into<String>) -> Self {
    let daps_url = daps_url.into();
    let daps_public_key = Arc::new(RwLock::new(None));
    Self {
        daps_url,
        credentials: None,
        daps_public_key,
    }
}
/// Creates a client for the DAPS at `daps_url` that owns the given connector
/// `credentials`; the key slot starts out empty.
pub fn with_credentials(daps_url: impl Into<String>, credentials: DapsCredentials) -> Self {
    let daps_url = daps_url.into();
    let daps_public_key = Arc::new(RwLock::new(None));
    Self {
        daps_url,
        credentials: Some(credentials),
        daps_public_key,
    }
}
/// Returns the DAPS base URL this client was constructed with.
pub fn daps_url(&self) -> &str {
    self.daps_url.as_str()
}
/// Requests an access token from `{daps_url}/token` using the OAuth 2.0
/// client-credentials grant authenticated by a JWT client assertion.
///
/// The grant parameters are sent as an `application/x-www-form-urlencoded`
/// body, which is the encoding OAuth 2.0 token endpoints expect
/// (RFC 6749 §4.4.2, RFC 7523 §2.2).
///
/// # Errors
/// Returns [`IdsError::DapsAuthFailed`] when the assertion cannot be created,
/// the HTTP request fails, the DAPS answers with a non-success status, the
/// response body cannot be parsed, or `expires_in` is too large to be turned
/// into a timestamp.
pub async fn get_token(&self, connector_id: &IdsUri) -> IdsResult<DapsToken> {
    // Largest number of seconds `chrono::Duration::seconds` accepts without panicking.
    const MAX_LIFETIME_SECS: u64 = (i64::MAX / 1_000) as u64;

    let client_assertion = self.create_client_assertion(connector_id)?;
    let form_params = [
        ("grant_type", "client_credentials"),
        (
            "client_assertion_type",
            "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
        ),
        ("client_assertion", client_assertion.as_str()),
        ("scope", "idsc:IDS_CONNECTOR_ATTRIBUTES_ALL"),
    ];

    let response = reqwest::Client::new()
        .post(format!("{}/token", self.daps_url))
        .form(&form_params)
        .send()
        .await
        .map_err(|e| IdsError::DapsAuthFailed(format!("DAPS request failed: {}", e)))?;

    let status = response.status();
    if !status.is_success() {
        // Best effort: the body is only used to enrich the error message.
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Unknown error".to_string());
        return Err(IdsError::DapsAuthFailed(format!(
            "DAPS returned {}: {}",
            status, error_text
        )));
    }

    let token_response: DapsTokenResponse = response.json().await.map_err(|e| {
        IdsError::DapsAuthFailed(format!("Failed to parse DAPS response: {}", e))
    })?;

    // `expires_in` is server-controlled: reject values that would wrap when
    // cast to i64 or overflow the duration/timestamp arithmetic.
    let expires_at = Some(token_response.expires_in)
        .filter(|secs| *secs <= MAX_LIFETIME_SECS)
        .and_then(|secs| Utc::now().checked_add_signed(Duration::seconds(secs as i64)))
        .ok_or_else(|| {
            IdsError::DapsAuthFailed(format!(
                "DAPS returned an unusable expires_in: {}",
                token_response.expires_in
            ))
        })?;

    Ok(DapsToken {
        access_token: token_response.access_token,
        token_type: token_response.token_type,
        expires_at,
        scope: token_response
            .scope
            .split_whitespace()
            .map(String::from)
            .collect(),
    })
}
fn create_client_assertion(&self, connector_id: &IdsUri) -> IdsResult<String> {
let now = Utc::now();
let exp = now + Duration::minutes(5);
let claims = ClientAssertionClaims {
iss: connector_id.as_str().to_string(),
sub: connector_id.as_str().to_string(),
aud: self.daps_url.clone(),
jti: Uuid::new_v4().to_string(),
iat: now.timestamp(),
exp: exp.timestamp(),
nbf: now.timestamp(),
};
if let Some(ref credentials) = self.credentials {
let header = Header {
alg: Algorithm::EdDSA,
..Default::default()
};
let encoding_key = EncodingKey::from_ed_der(credentials.key_pair.public_key().as_ref());
encode(&header, &claims, &encoding_key).map_err(|e| {
IdsError::DapsAuthFailed(format!("Failed to create client assertion: {}", e))
})
} else {
let header = Header {
alg: Algorithm::HS256,
..Default::default()
};
let encoding_key = EncodingKey::from_secret(b"development-only-key");
encode(&header, &claims, &encoding_key).map_err(|e| {
IdsError::DapsAuthFailed(format!("Failed to create client assertion: {}", e))
})
}
}
/// Validates `token` using [`TokenValidationOptions::default`].
///
/// Thin convenience wrapper around `validate_token_with_options`; see that
/// method for the checks performed and the errors returned.
pub fn validate_token(&self, token: &str) -> IdsResult<DapsTokenClaims> {
    let default_options = TokenValidationOptions::default();
    self.validate_token_with_options(token, &default_options)
}
pub fn validate_token_with_options(
&self,
token: &str,
options: &TokenValidationOptions,
) -> IdsResult<DapsTokenClaims> {
let mut validation = Validation::new(Algorithm::RS256);
validation.validate_exp = options.check_expiration;
validation.validate_nbf = true;
if let Some(ref expected_issuer) = options.expected_issuer {
validation.set_issuer(&[expected_issuer.as_str()]);
} else {
validation.set_issuer(&[&self.daps_url]);
}
if let Some(ref expected_audience) = options.expected_audience {
validation.set_audience(&[expected_audience.as_str()]);
}
let token_data: TokenData<DapsTokenClaims> = if options.skip_signature_verification {
let mut unsafe_validation = Validation::new(Algorithm::RS256);
#[allow(deprecated)]
unsafe_validation.insecure_disable_signature_validation();
unsafe_validation.validate_exp = options.check_expiration;
decode(
token,
&DecodingKey::from_secret(b"unused"),
&unsafe_validation,
)
.map_err(|e| IdsError::InvalidToken(format!("Failed to decode token: {}", e)))?
} else {
let public_key = options.daps_public_key.as_ref().ok_or_else(|| {
IdsError::InvalidToken("DAPS public key required for validation".to_string())
})?;
let decoding_key = DecodingKey::from_rsa_pem(public_key)
.map_err(|e| IdsError::InvalidToken(format!("Invalid DAPS public key: {}", e)))?;
decode(token, &decoding_key, &validation)
.map_err(|e| IdsError::InvalidToken(format!("Token validation failed: {}", e)))?
};
let claims = token_data.claims;
if options.check_expiration {
let now = Utc::now().timestamp();
if claims.exp < now {
return Err(IdsError::InvalidToken("Token has expired".to_string()));
}
}
if options.validate_issuer && claims.iss != self.daps_url {
return Err(IdsError::InvalidToken(format!(
"Invalid issuer: expected {}, got {}",
self.daps_url, claims.iss
)));
}
Ok(claims)
}
/// Decodes the claims of `token` without verifying its signature and without
/// checking expiry.
///
/// The returned claims are unauthenticated and must not be trusted.
///
/// # Errors
/// Returns [`IdsError::InvalidToken`] when the token cannot be decoded into
/// [`DapsTokenClaims`].
pub fn decode_token_unverified(&self, token: &str) -> IdsResult<DapsTokenClaims> {
    // The key is never consulted because signature validation is disabled.
    let placeholder_key = DecodingKey::from_secret(b"unused");
    let mut lenient = Validation::new(Algorithm::RS256);
    lenient.validate_exp = false;
    #[allow(deprecated)]
    lenient.insecure_disable_signature_validation();

    decode(token, &placeholder_key, &lenient)
        .map(|decoded: TokenData<DapsTokenClaims>| decoded.claims)
        .map_err(|e| IdsError::InvalidToken(format!("Failed to decode token: {}", e)))
}
/// Downloads the DAPS key set from `{daps_url}/.well-known/jwks.json`,
/// selects one key, caches its JSON serialization on the client, and returns
/// those bytes.
///
/// Token validation in this client is RS256-only, so the first entry whose
/// `kty` is `"RSA"` is preferred. If the set has no such entry, the first key
/// is returned as before.
///
/// # Errors
/// Returns [`IdsError::DapsAuthFailed`] when the request fails, the DAPS
/// answers with a non-success status, the body is not valid JSON, the key set
/// is empty or missing, or the selected key cannot be serialized.
pub async fn fetch_daps_public_key(&self) -> IdsResult<Vec<u8>> {
    let response = reqwest::Client::new()
        .get(format!("{}/.well-known/jwks.json", self.daps_url))
        .send()
        .await
        .map_err(|e| {
            IdsError::DapsAuthFailed(format!("Failed to fetch DAPS public key: {}", e))
        })?;
    if !response.status().is_success() {
        return Err(IdsError::DapsAuthFailed(
            "Failed to fetch DAPS public key".to_string(),
        ));
    }
    let jwks: serde_json::Value = response.json().await.map_err(|e| {
        IdsError::DapsAuthFailed(format!("Failed to parse JWKS response: {}", e))
    })?;

    let key = jwks
        .get("keys")
        .and_then(serde_json::Value::as_array)
        .and_then(|keys| {
            keys.iter()
                .find(|candidate| {
                    candidate.get("kty").and_then(serde_json::Value::as_str) == Some("RSA")
                })
                .or_else(|| keys.first())
        })
        .ok_or_else(|| IdsError::DapsAuthFailed("No keys in JWKS".to_string()))?;

    let key_bytes = serde_json::to_vec(key)
        .map_err(|e| IdsError::DapsAuthFailed(format!("Failed to serialize key: {}", e)))?;
    // The write guard is a temporary, so the lock is released right after the store.
    *self.daps_public_key.write().await = Some(key_bytes.clone());
    Ok(key_bytes)
}
}
/// Switches and expected values that control how a DAPS token is validated.
#[derive(Debug, Clone)]
pub struct TokenValidationOptions {
    /// Whether the token's expiry is enforced.
    pub check_expiration: bool,
    /// Whether the token's issuer claim is enforced.
    pub validate_issuer: bool,
    /// Issuer value to expect instead of the client's own DAPS URL.
    pub expected_issuer: Option<String>,
    /// Audience value the token must carry, if any.
    pub expected_audience: Option<String>,
    /// Disables signature verification entirely; insecure.
    pub skip_signature_verification: bool,
    /// Key bytes used to verify the token signature.
    pub daps_public_key: Option<Vec<u8>>,
}

impl Default for TokenValidationOptions {
    /// Strict settings: expiry and issuer are checked, the signature is
    /// verified, and no overrides or key material are preset.
    fn default() -> Self {
        Self {
            // Enabled checks.
            check_expiration: true,
            validate_issuer: true,
            // Signature verification stays on.
            skip_signature_verification: false,
            // Nothing preset by the caller.
            expected_issuer: None,
            expected_audience: None,
            daps_public_key: None,
        }
    }
}

impl TokenValidationOptions {
    /// Relaxed settings for development: no expiry check, no issuer check,
    /// and no signature verification. Everything else matches the defaults.
    pub fn development() -> Self {
        Self {
            skip_signature_verification: true,
            validate_issuer: false,
            check_expiration: false,
            ..Self::default()
        }
    }
}
/// Claim set of the JWT client assertion built by `create_client_assertion`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct ClientAssertionClaims {
/// Issuer; set to the connector identifier.
iss: String,
/// Subject; set to the connector identifier.
sub: String,
/// Audience; set to the DAPS URL.
aud: String,
/// Unique token identifier (a random UUID).
jti: String,
/// Issued-at time as a Unix timestamp in seconds.
iat: i64,
/// Expiry time as a Unix timestamp in seconds.
exp: i64,
/// Not-before time as a Unix timestamp in seconds.
nbf: i64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DapsTokenResponse {
pub access_token: String,
pub token_type: String,
pub expires_in: u64,
pub scope: String,
}
/// Access token obtained from the DAPS, with its expiry resolved to an
/// absolute point in time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DapsToken {
/// The raw access token string.
pub access_token: String,
/// Token type as reported by the DAPS.
pub token_type: String,
/// Instant (UTC) at which the token expires.
pub expires_at: DateTime<Utc>,
/// Granted scopes, one entry per scope.
pub scope: Vec<String>,
}
impl DapsToken {
    /// Returns `true` once the current time is strictly past `expires_at`.
    pub fn is_expired(&self) -> bool {
        self.expires_at < Utc::now()
    }

    /// Returns the time left until expiry; negative once the token has expired.
    pub fn time_until_expiry(&self) -> Duration {
        let now = Utc::now();
        self.expires_at - now
    }

    /// Returns `true` if the token will have expired `duration` from now.
    pub fn expires_within(&self, duration: Duration) -> bool {
        let horizon = Utc::now() + duration;
        horizon > self.expires_at
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DapsTokenClaims {
pub sub: String,
pub iss: String,
#[serde(default)]
pub aud: Vec<String>,
pub exp: i64,
pub iat: i64,
#[serde(default)]
pub nbf: Option<i64>,
#[serde(default)]
pub jti: Option<String>,
#[serde(default)]
pub scope: Vec<String>,
#[serde(rename = "securityProfile", default)]
pub security_profile: String,
#[serde(rename = "@type", default)]
pub connector_type: Option<String>,
#[serde(rename = "extendedGuarantee", default)]
pub extended_guarantee: Option<String>,
#[serde(rename = "transportCertsSha256", default)]
pub transport_certs_sha256: Option<Vec<String>>,
#[serde(rename = "referringConnector", default)]
pub referring_connector: Option<String>,
}
impl DapsTokenClaims {
    /// Returns `true` once the current Unix time is strictly past `exp`.
    pub fn is_expired(&self) -> bool {
        self.exp < Utc::now().timestamp()
    }

    /// Maps the raw `securityProfile` claim to a [`SecurityProfile`].
    ///
    /// Profile names are recognised with or without the `idsc:` prefix; any
    /// other value (including an empty claim) maps to the base profile.
    pub fn get_security_profile(&self) -> SecurityProfile {
        let raw = self.security_profile.as_str();
        // Drop at most one leading `idsc:` so both spellings share one match.
        let profile_name = raw.strip_prefix("idsc:").unwrap_or(raw);
        match profile_name {
            "TRUST_PLUS_SECURITY_PROFILE" => SecurityProfile::TrustPlusSecurityProfile,
            "TRUST_SECURITY_PROFILE" => SecurityProfile::TrustSecurityProfile,
            _ => SecurityProfile::BaseSecurityProfile,
        }
    }

    /// Returns `true` if `required_scope` is one of the token's scopes
    /// (exact string match).
    pub fn has_scope(&self, required_scope: &str) -> bool {
        self.scope
            .iter()
            .map(String::as_str)
            .any(|granted| granted == required_scope)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bearer token with the fixed test access token and the given
    /// expiry and scopes.
    fn bearer_token(expires_at: DateTime<Utc>, scope: Vec<String>) -> DapsToken {
        DapsToken {
            access_token: "test_token".to_string(),
            token_type: "Bearer".to_string(),
            expires_at,
            scope,
        }
    }

    /// The constructor keeps the URL verbatim and the getter hands it back.
    #[test]
    fn test_daps_client_creation() {
        let url = "https://daps.example.org";
        let client = DapsClient::new(url);
        assert_eq!(client.daps_url(), url);
    }

    /// A token valid for one hour is live, survives a 30-minute horizon but
    /// not a two-hour one; a token that expired an hour ago is expired.
    #[test]
    fn test_daps_token_expiration() {
        let live = bearer_token(
            Utc::now() + Duration::hours(1),
            vec!["idsc:IDS_CONNECTOR_ATTRIBUTES_ALL".to_string()],
        );
        assert!(!live.is_expired());
        assert!(!live.expires_within(Duration::minutes(30)));
        assert!(live.expires_within(Duration::hours(2)));

        let stale = bearer_token(Utc::now() - Duration::hours(1), Vec::new());
        assert!(stale.is_expired());
    }

    /// The security profile claim is mapped to its enum value and scope
    /// lookup is an exact match.
    #[test]
    fn test_daps_claims_security_profile() {
        let one_hour_from_now = (Utc::now() + Duration::hours(1)).timestamp();
        let issued_at = Utc::now().timestamp();
        let trust_claims = DapsTokenClaims {
            sub: "urn:ids:connector:example".to_string(),
            iss: "https://daps.example.org".to_string(),
            aud: vec!["idsc:IDS_CONNECTORS_ALL".to_string()],
            exp: one_hour_from_now,
            iat: issued_at,
            nbf: None,
            jti: None,
            scope: vec!["idsc:IDS_CONNECTOR_ATTRIBUTES_ALL".to_string()],
            security_profile: "idsc:TRUST_SECURITY_PROFILE".to_string(),
            connector_type: None,
            extended_guarantee: None,
            transport_certs_sha256: None,
            referring_connector: None,
        };

        assert_eq!(
            trust_claims.get_security_profile(),
            SecurityProfile::TrustSecurityProfile
        );
        assert!(trust_claims.has_scope("idsc:IDS_CONNECTOR_ATTRIBUTES_ALL"));
        assert!(!trust_claims.has_scope("idsc:SOME_OTHER_SCOPE"));
    }

    /// A client without credentials still produces a three-segment JWT.
    #[test]
    fn test_client_assertion_creation() {
        let connector_id = IdsUri::new("urn:ids:connector:test").expect("valid URI");
        let outcome =
            DapsClient::new("https://daps.example.org").create_client_assertion(&connector_id);
        assert!(outcome.is_ok());

        let jwt = outcome.expect("assertion");
        assert_eq!(jwt.split('.').count(), 3);
    }

    /// Defaults are strict; the development preset switches every check off.
    #[test]
    fn test_token_validation_options() {
        let strict = TokenValidationOptions::default();
        assert!(strict.check_expiration);
        assert!(strict.validate_issuer);
        assert!(!strict.skip_signature_verification);

        let relaxed = TokenValidationOptions::development();
        assert!(!relaxed.check_expiration);
        assert!(!relaxed.validate_issuer);
        assert!(relaxed.skip_signature_verification);
    }

    /// Freshly generated credentials expose a non-empty public key.
    #[test]
    fn test_credentials_creation() {
        let connector_id = IdsUri::new("urn:ids:connector:test").expect("valid URI");
        let outcome = DapsCredentials::new(connector_id);
        assert!(outcome.is_ok());

        let generated = outcome.expect("credentials");
        assert!(!generated.public_key().is_empty());
    }
}