1#![deny(unsafe_code, rust_2018_idioms, clippy::unwrap_used)]
63#![warn(rust_2024_compatibility, clippy::pedantic)]
64#![allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
65
66mod cache;
67pub mod cert;
68pub mod config;
69mod http_client;
70
/// Value of a DAT's `aud` claim.
///
/// Thanks to `#[serde(untagged)]` it deserializes from either a single JSON
/// string (`Single`) or a JSON array of strings (`Multiple`).
#[derive(Debug, serde::Deserialize, Clone)]
#[serde(untagged)]
// The payloads are only filled in by deserialization; nothing in the code
// visible here reads them, so the lint is silenced.
#[allow(dead_code)]
enum Audience {
    /// `"aud": "..."`
    Single(String),
    /// `"aud": ["...", ...]`
    Multiple(Vec<String>),
}
79
/// Claims of the self-signed client-assertion JWT that
/// [`DapsClient::request_dat`] sends to the DAPS token endpoint.
///
/// Fields are serialized in declaration order into the signed JWT payload.
#[derive(Debug, serde::Serialize, Clone)]
struct TokenClaims {
    /// Serialized as `@type`.
    #[serde(rename = "@type")]
    type_: String,
    /// Serialized as `@context`.
    #[serde(rename = "@context")]
    context_: String,
    /// Issuer; `request_dat` sets it to the connector's SKI:AKI identifier.
    iss: String,
    /// Subject; `request_dat` sets it to the connector's SKI:AKI identifier.
    sub: String,
    /// `request_dat` sets it to the connector's SKI:AKI identifier.
    id: String,
    /// Unique token id; `request_dat` uses a hyphenated UUIDv7.
    jti: String,
    /// Audience; `request_dat` uses the configured `scope`.
    aud: String,
    /// Issued-at, seconds since the Unix epoch (`Utc::now().timestamp()`).
    iat: i64,
    /// Expiry, seconds since the Unix epoch.
    exp: i64,
    /// Not-before, seconds since the Unix epoch.
    nbf: i64,
}
95
/// Response payload of the DAPS token endpoint.
///
/// NOTE(review): presumably this is what `http_client`'s `request_token`
/// deserializes into (its `access_token` is what `request_dat` returns) —
/// confirm in the `http_client` module.
#[derive(Debug, serde::Deserialize, Clone)]
pub struct TokenResponse {
    /// The issued access token.
    pub access_token: String,
    /// Token type reported by the endpoint.
    pub token_type: String,
    /// Token lifetime — presumably seconds per OAuth 2.0 convention; confirm.
    pub expires_in: u64,
    /// Granted scope; `None` when the field is absent from the response.
    pub scope: Option<String>,
}
104
105#[derive(Debug, serde::Deserialize, Clone)]
107#[serde(rename_all = "camelCase")]
108#[allow(dead_code)]
109pub struct DatClaims {
110 #[serde(rename = "@type")]
111 type_: String,
112 #[serde(rename = "@context")]
113 context_: String,
114 referring_connector: String,
115 security_profile: String,
116 #[serde(rename = "iat")]
117 issued_at: i64,
118 #[serde(rename = "exp")]
119 expires_at: i64,
120 #[serde(rename = "nbf")]
121 not_before: i64,
122 #[serde(rename = "sub")]
123 subject: String,
124 #[serde(rename = "aud")]
125 audience: Audience,
126 #[serde(rename = "iss")]
127 issuer: String,
128 #[serde(rename = "jti")]
129 jwt_id: String,
130}
131
/// Errors returned by [`DapsClient`] operations.
#[derive(thiserror::Error, Debug)]
pub enum DapsError {
    /// A request to the DAPS (token endpoint or JWKS endpoint) failed.
    #[error("http client error: {0}")]
    DapsHttpClient(#[from] http_client::DapsHttpClientError),
    /// No key of the DAPS JWKS validated the token.
    #[error("jwt error")]
    InvalidToken,
    /// Storing a freshly fetched JWKS in the certificate cache failed.
    #[error("cache error: {0}")]
    CacheError(#[from] cache::CertificatesCacheError),
}
141
/// [`DapsClient`] using the `http_client::reqwest_client` transport.
pub type ReqwestDapsClient = DapsClient<http_client::reqwest_client::ReqwestDapsClient>;
144
/// Client that requests DATs from a DAPS and validates DATs against the
/// DAPS's published JSON Web Key Set.
///
/// `C` is the HTTP transport; the methods are available when
/// `C: http_client::DapsClientRequest`.
pub struct DapsClient<C> {
    /// HTTP transport used for the token and JWKS requests.
    client: C,
    /// This connector's SKI:AKI identifier (used as `iss`/`sub`/`id` of the
    /// client assertion and as the expected `sub` when validating).
    sub: String,
    /// URL of the JWKS endpoint.
    certs_url: String,
    /// URL of the token endpoint.
    token_url: String,
    /// Audience (`aud`) put into the client-assertion JWT.
    scope: String,
    /// RSA private key that signs the client assertion.
    encoding_key: jsonwebtoken::EncodingKey,
    /// UUIDv7 context used to generate `jti` values.
    uuid_context: uuid::ContextV7,
    /// Cache of the JWKS, created with the configured TTL.
    certs_cache: cache::CertificatesCache,
}
165
166impl<C> DapsClient<C>
167where
168 C: http_client::DapsClientRequest,
169{
170 #[must_use]
172 pub fn new(config: &config::DapsConfig<'_>) -> Self {
173 let (ski_aki, private_key) = cert::ski_aki_and_private_key_from_file(
175 config.private_key.as_ref(),
176 config.private_key_password.as_deref().unwrap_or(""),
177 )
178 .expect("Reading SKI:AKI failed");
179
180 let encoding_key = jsonwebtoken::EncodingKey::from_rsa_der(private_key.as_ref());
182
183 Self {
184 client: C::default(),
185 sub: ski_aki.to_string(),
186 scope: config.scope.to_string(),
187 certs_url: config.certs_url.to_string(),
188 token_url: config.token_url.to_string(),
189 encoding_key,
190 uuid_context: uuid::ContextV7::new(),
191 certs_cache: cache::CertificatesCache::new(std::time::Duration::from_secs(
192 config.certs_cache_ttl,
193 )),
194 }
195 }
196
197 pub async fn validate_dat(
199 &self,
200 token: &str,
201 ) -> Result<jsonwebtoken::TokenData<DatClaims>, DapsError> {
202 let jwks = self.get_certs().await?;
204
205 let mut validation = jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256);
207 validation.sub = Some(self.sub.to_string());
208 validation.set_audience(&["idsc:IDS_CONNECTORS_ALL"]);
209 validation.set_required_spec_claims(&["exp", "nbf", "aud", "iss", "sub"]);
210
211 let validation_results: Vec<jsonwebtoken::TokenData<_>> = jwks
213 .keys
214 .iter()
215 .filter_map(|jwk| {
216 if let Ok(jwk) = jsonwebtoken::DecodingKey::from_jwk(jwk) {
217 let result = jsonwebtoken::decode(token, &jwk, &validation);
218 tracing::debug!("Validation result: {:?}", result);
219 result.ok()
220 } else {
221 None
222 }
223 })
224 .collect();
225
226 validation_results
228 .first()
229 .ok_or(DapsError::InvalidToken)
230 .cloned()
231 }
232
233 pub async fn request_dat(&self) -> Result<String, DapsError> {
235 let now = chrono::Utc::now();
237 let now_secs = now.timestamp();
238 let now_subsec_nanos = now.timestamp_subsec_nanos();
239 #[allow(clippy::cast_sign_loss)]
240 let uuid_timestamp =
241 uuid::Timestamp::from_unix(&self.uuid_context, now_secs as u64, now_subsec_nanos);
242
243 let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256);
245 let claims = TokenClaims {
246 context_: "https://w3id.org/idsa/contexts/context.jsonld".to_string(),
247 type_: "ids:DatRequestToken".to_string(),
248 jti: uuid::Uuid::new_v7(uuid_timestamp).hyphenated().to_string(),
249 iss: self.sub.to_string(),
250 sub: self.sub.to_string(),
251 id: self.sub.to_string(),
252 aud: self.scope.to_string(),
253 iat: now_secs,
254 exp: now_secs + 3600,
255 nbf: now_secs,
256 };
257 let token = jsonwebtoken::encode(&header, &claims, &self.encoding_key)
259 .expect("Token signing failed. There must be something wrong with the private key.");
260
261 tracing::debug!("Issued TokenRequest (requestDAT): {}", token);
262
263 let response = self
264 .client
265 .request_token(
266 self.token_url.as_ref(),
267 &[
268 ("grant_type", "client_credentials"),
269 ("scope", "idsc:IDS_CONNECTOR_ATTRIBUTES_ALL"),
270 (
271 "client_assertion_type",
272 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
273 ),
274 ("client_assertion", &token),
275 ],
276 )
277 .await?;
278
279 Ok(response.access_token)
280 }
281
282 pub async fn get_jwks(&self) -> Result<jsonwebtoken::jwk::JwkSet, DapsError> {
284 self.get_certs().await
285 }
286
287 async fn update_cert_cache(&self) -> Result<jsonwebtoken::jwk::JwkSet, DapsError> {
289 let jwks = self.client.get_certs(self.certs_url.as_ref()).await?;
290 self.certs_cache
291 .update(jwks.clone())
292 .await
293 .map_err(DapsError::from)
294 }
295
296 async fn get_certs(&self) -> Result<jsonwebtoken::jwk::JwkSet, DapsError> {
298 tracing::debug!("Checking cache...");
299
300 match self.certs_cache.get().await {
301 Ok(cert) => {
302 tracing::debug!("Cache is up-to-date");
303 Ok(cert)
304 }
305 Err(cache::CertificatesCacheError::Outdated) => {
306 tracing::info!("Cache is outdated, updating...");
307 self.update_cert_cache().await
308 }
309 Err(cache::CertificatesCacheError::Empty) => {
310 tracing::info!("Cache is empty, updating...");
311 self.update_cert_cache().await
312 }
313 }
314 }
315}
316
#[cfg(test)]
mod test {
    use super::*;

    /// End-to-end test against a DAPS test container (requires Docker).
    ///
    /// Requests a DAT, then validates it three times:
    /// 1. with a fresh cache;
    /// 2. immediately again, within the 1 s cache TTL;
    /// 3. after sleeping past the TTL.
    #[tokio::test]
    async fn integration_test() {
        use testcontainers::runners::AsyncRunner;

        // `try_init` instead of `init`: `init` panics when a global subscriber
        // is already installed (e.g. by another test in the same binary).
        let _ = tracing_subscriber::fmt()
            .with_env_filter(tracing_subscriber::EnvFilter::new("ids_daps=DEBUG"))
            .try_init();

        let image = testcontainers::GenericImage::new("ghcr.io/ids-basecamp/daps", "test");
        let container = image
            .with_exposed_port(4567.into())
            .with_wait_for(testcontainers::core::WaitFor::message_on_stdout(
                "Listening on 0.0.0.0:4567, CTRL+C to stop",
            ))
            .start()
            .await
            .expect("Failed to start DAPS container. Is Docker running?");

        let host = container.get_host().await.expect("Failed to get host");
        let host_port = container
            .get_host_port_ipv4(4567)
            .await
            .expect("Failed to get port");

        let certs_url = format!("http://{host}:{host_port}/jwks.json");
        let token_url = format!("http://{host}:{host_port}/token");

        let config = config::DapsConfigBuilder::default()
            .certs_url(certs_url)
            .token_url(token_url)
            .private_key(std::path::Path::new("./testdata/connector-certificate.p12"))
            .private_key_password(Some(std::borrow::Cow::from("Password1")))
            .scope(std::borrow::Cow::from("idsc:IDS_CONNECTORS_ALL"))
            .certs_cache_ttl(1_u64)
            .build()
            .expect("Failed to build DAPS-Config");

        let client: ReqwestDapsClient = DapsClient::new(&config);

        // `expect` rather than `unwrap`: the crate denies `clippy::unwrap_used`.
        let dat = client
            .request_dat()
            .await
            .expect("Requesting a DAT failed");
        tracing::info!("DAT Token: {:?}", dat);

        // Fresh client: the JWKS is expected to be fetched here.
        let cache1_start = std::time::Instant::now();
        if let Err(err) = client.validate_dat(&dat).await {
            tracing::error!("Validation failed: {:?}", err);
            panic!("Validation failed: {err:?}");
        }
        tracing::debug!("First validation took {:?}", cache1_start.elapsed());

        // Within the TTL: the cached JWKS should be reused.
        let cache2_start = std::time::Instant::now();
        assert!(client.validate_dat(&dat).await.is_ok());
        tracing::debug!("Second validation took {:?}", cache2_start.elapsed());

        // Past the 1 s TTL: the JWKS should be refetched.
        tokio::time::sleep(std::time::Duration::from_secs(2)).await;
        let cache3_start = std::time::Instant::now();
        assert!(client.validate_dat(&dat).await.is_ok());
        tracing::debug!("Third validation took {:?}", cache3_start.elapsed());
    }
}