1use crate::auth::{Auth, AuthError, RequestBuilderExt};
2use libpep::data::traits::{HasStructure, Pseudonymizable, Rekeyable, Transcryptable};
3use libpep::factors::{EncryptionContext, PseudonymizationDomain};
4use libpep::keys::distribution::SessionKeyShares;
5use paas_api::config::{PAASConfig, TranscryptorConfig};
6use paas_api::paths::ApiPath;
7use paas_api::sessions::{EndSessionRequest, SessionResponse, StartSessionResponse};
8use paas_api::status::{StatusResponse, VersionInfo};
9use paas_api::transcrypt::{
10 PseudonymizationBatchRequest, PseudonymizationBatchResponse, PseudonymizationRequest,
11 PseudonymizationResponse, RekeyBatchRequest, RekeyBatchResponse, RekeyRequest, RekeyResponse,
12 TranscryptionBatchRequest, TranscryptionBatchResponse, TranscryptionRequest,
13 TranscryptionResponse,
14};
15use serde::Serialize;
16use serde::de::DeserializeOwned;
17use std::sync::Arc;
18
19#[derive(Debug, thiserror::Error)]
20pub enum TranscryptorError {
21 #[error(transparent)]
22 AuthError(#[from] AuthError),
23 #[error(transparent)]
24 NetworkError(#[from] reqwest::Error),
25 #[error("Authentication required")]
26 Unauthorized,
27 #[error("Transcryption not allowed: {0}")]
28 NotAllowed(String),
29 #[error("Invalid or expired session: {0}")]
30 InvalidSession(String),
31 #[error("Bad request: {0}")]
32 BadRequest(String),
33 #[error("Server error: {0}")]
34 ServerError(String),
35
36 #[error("No active session to end")]
37 NoSessionToEnd,
38 #[error(
39 "Client version {client_version} is incompatible with server version {server_version} (min. supported version {server_min_supported_version})"
40 )]
41 IncompatibleClientVersionError {
42 client_version: String,
43 server_version: String,
44 server_min_supported_version: String,
45 },
46 #[error("Inconsistent system name (configured: {configured_name}, responded: {responded_name}")]
47 InconsistentSystemNameError {
48 configured_name: String,
49 responded_name: String,
50 },
51 #[error("Inconsistent system name ({name})")]
52 InvalidSystemNameError { name: String },
53 #[error("Inconsistent configuration (configured: {configured_url}, responded: {responded_url}")]
54 InconsistentUrlError {
55 configured_url: String,
56 responded_url: String,
57 },
58 #[error("URL must use HTTPS protocol, got: {scheme}")]
59 NonHttpsUrlError { scheme: String },
60 #[error("Unexpected encrypted data variant type in response")]
61 UnexpectedVariantType,
62}
63
64#[derive(Debug, serde::Deserialize)]
65pub struct ErrorResponse {
66 pub error: String,
67}
68
69#[derive(Clone)]
70pub struct TranscryptorClient {
72 pub(crate) config: TranscryptorConfig,
73 pub(crate) session_id: Option<EncryptionContext>,
74 pub(crate) sks: Option<SessionKeyShares>,
75 auth: Arc<dyn Auth>,
76}
77impl TranscryptorClient {
78 pub async fn new(
80 config: TranscryptorConfig,
81 auth: Arc<dyn Auth>,
82 ) -> Result<TranscryptorClient, TranscryptorError> {
83 if config.url.scheme() != "https" {
84 return Err(TranscryptorError::NonHttpsUrlError {
85 scheme: config.url.scheme().to_string(),
86 });
87 }
88
89 let mut client = Self {
90 config,
91 auth,
92 session_id: None,
93 sks: None,
94 };
95 client.check_status().await.and(Ok(client))
96 }
97
98 #[doc(hidden)]
101 pub async fn new_allow_http(
102 config: TranscryptorConfig,
103 auth: Arc<dyn Auth>,
104 ) -> Result<TranscryptorClient, TranscryptorError> {
105 let mut client = Self {
106 config,
107 auth,
108 session_id: None,
109 sks: None,
110 };
111 client.check_status().await.and(Ok(client))
112 }
113
114 pub async fn restore(
116 config: TranscryptorConfig,
117 auth: Arc<dyn Auth>,
118 session_id: EncryptionContext,
119 sks: SessionKeyShares,
120 ) -> Result<TranscryptorClient, TranscryptorError> {
121 if config.url.scheme() != "https" {
122 return Err(TranscryptorError::NonHttpsUrlError {
123 scheme: config.url.scheme().to_string(),
124 });
125 }
126
127 let mut client = Self {
128 config,
129 auth,
130 session_id: Some(session_id),
131 sks: Some(sks),
132 };
133 client.check_status().await.and(Ok(client))
134 }
135 pub fn dump(
136 &self,
137 ) -> (
138 TranscryptorConfig,
139 Option<EncryptionContext>,
140 Option<SessionKeyShares>,
141 ) {
142 (self.config.clone(), self.session_id.clone(), self.sks)
143 }
144 fn make_url(&self, path: &str) -> String {
145 format!(
146 "{}{}{}",
147 self.config.url.as_str().trim_end_matches('/'),
148 paas_api::paths::API_BASE,
149 path
150 )
151 }
152 async fn process_response<T>(&self, response: reqwest::Response) -> Result<T, TranscryptorError>
153 where
154 T: serde::de::DeserializeOwned,
155 {
156 if let Err(error) = response.error_for_status_ref() {
157 let status = response.status();
158 let body = response.text().await.unwrap_or_default();
159
160 let error_message = serde_json::from_str::<ErrorResponse>(&body)
161 .map(|r| r.error)
162 .unwrap_or(body);
163
164 let error = match status.as_u16() {
165 401 => TranscryptorError::Unauthorized,
166 403 => TranscryptorError::NotAllowed(error_message),
167 404 => TranscryptorError::InvalidSession(error_message),
168 400 => TranscryptorError::BadRequest(error_message),
169 500..=599 => TranscryptorError::ServerError(error_message),
170 _ => TranscryptorError::NetworkError(error),
171 };
172
173 return Err(error);
174 }
175
176 let data = response.json::<T>().await?;
177 Ok(data)
178 }
179
180 pub async fn check_status(&mut self) -> Result<(), TranscryptorError> {
182 let response = reqwest::Client::new()
183 .get(self.make_url(paas_api::paths::STATUS))
184 .send()
185 .await?;
186
187 let status = self.process_response::<StatusResponse>(response).await?;
188
189 let client_version = VersionInfo::default();
190 if !status.version_info.is_compatible_with(&client_version) {
191 return Err(TranscryptorError::IncompatibleClientVersionError {
192 client_version: client_version.protocol_version.to_string(),
193 server_version: status.version_info.protocol_version.to_string(),
194 server_min_supported_version: status.version_info.min_supported_version.to_string(),
195 });
196 };
197
198 if status.system_id != self.config.system_id {
199 return Err(TranscryptorError::InconsistentSystemNameError {
200 responded_name: status.system_id,
201 configured_name: self.config.system_id.clone(),
202 });
203 }
204
205 Ok(())
206 }
207
208 pub async fn check_config(&mut self) -> Result<PAASConfig, TranscryptorError> {
210 let response = reqwest::Client::new()
211 .get(self.make_url(paas_api::paths::CONFIG))
212 .with_auth(&self.auth)
213 .await?
214 .send()
215 .await?;
216
217 let config = self.process_response::<PAASConfig>(response).await?;
218
219 let ts_config = config
220 .transcryptors
221 .iter()
222 .find(|tc| tc.system_id == self.config.system_id)
223 .ok_or_else(|| TranscryptorError::InvalidSystemNameError {
224 name: self.config.system_id.clone(),
225 })?;
226
227 if ts_config.url != self.config.url {
228 return Err(TranscryptorError::InconsistentUrlError {
229 configured_url: self.config.url.to_string(),
230 responded_url: ts_config.url.to_string(),
231 });
232 }
233
234 Ok(config)
235 }
236
237 pub async fn start_session(
239 &mut self,
240 ) -> Result<(EncryptionContext, SessionKeyShares), TranscryptorError> {
241 let response = reqwest::Client::new()
242 .post(self.make_url(paas_api::paths::SESSIONS_START))
243 .with_auth(&self.auth)
244 .await?
245 .send()
246 .await?;
247
248 let session = self
249 .process_response::<StartSessionResponse>(response)
250 .await?;
251
252 self.session_id = Some(session.session_id.clone());
253 self.sks = Some(session.session_key_shares);
254 Ok((session.session_id, session.session_key_shares))
255 }
256
257 pub async fn get_sessions(&mut self) -> Result<Vec<EncryptionContext>, TranscryptorError> {
259 let response = reqwest::Client::new()
260 .get(self.make_url(paas_api::paths::SESSIONS_GET))
261 .with_auth(&self.auth)
262 .await?
263 .send()
264 .await?;
265
266 let sessions = self.process_response::<SessionResponse>(response).await?;
267 Ok(sessions.sessions)
268 }
269
270 pub async fn end_session(&mut self) -> Result<(), TranscryptorError> {
273 let request = EndSessionRequest {
274 session_id: self
275 .session_id
276 .clone()
277 .ok_or(TranscryptorError::NoSessionToEnd)?,
278 };
279 let response = reqwest::Client::new()
280 .post(self.make_url(paas_api::paths::SESSIONS_END))
281 .with_auth(&self.auth)
282 .await?
283 .json(&request)
284 .send()
285 .await?;
286
287 let _ = self
288 .process_response::<StartSessionResponse>(response)
289 .await?;
290
291 self.session_id = None;
292 self.sks = None;
293
294 Ok(())
295 }
296
297 pub async fn pseudonymize<T>(
299 &self,
300 encrypted: &T,
301 domain_from: &PseudonymizationDomain,
302 domain_to: &PseudonymizationDomain,
303 session_from: &EncryptionContext,
304 session_to: &EncryptionContext,
305 ) -> Result<T, TranscryptorError>
306 where
307 T: Pseudonymizable + DeserializeOwned + Serialize + Clone + ApiPath,
308 {
309 let request = PseudonymizationRequest {
310 encrypted: encrypted.clone(),
311 domain_from: domain_from.clone(),
312 domain_to: domain_to.clone(),
313 session_from: session_from.clone(),
314 session_to: session_to.clone(),
315 };
316 let response = reqwest::Client::new()
317 .post(self.make_url(paas_api::paths::pseudonymize_path::<T>().as_str()))
318 .with_auth(&self.auth)
319 .await?
320 .json(&request)
321 .send()
322 .await?;
323 let pseudo_response = self
324 .process_response::<PseudonymizationResponse<T>>(response)
325 .await?;
326 Ok(pseudo_response.result)
327 }
328
329 pub async fn pseudonymize_batch<T>(
331 &self,
332 encrypted: Vec<T>,
333 domain_from: &PseudonymizationDomain,
334 domain_to: &PseudonymizationDomain,
335 session_from: &EncryptionContext,
336 session_to: &EncryptionContext,
337 ) -> Result<Vec<T>, TranscryptorError>
338 where
339 T: Pseudonymizable + DeserializeOwned + Serialize + Clone + ApiPath + HasStructure,
340 {
341 if encrypted.is_empty() {
342 return Ok(vec![]);
343 }
344
345 let request = PseudonymizationBatchRequest {
346 encrypted,
347 domain_from: domain_from.clone(),
348 domain_to: domain_to.clone(),
349 session_from: session_from.clone(),
350 session_to: session_to.clone(),
351 };
352 let response = reqwest::Client::new()
353 .post(self.make_url(paas_api::paths::pseudonymize_batch_path::<T>().as_str()))
354 .with_auth(&self.auth)
355 .await?
356 .json(&request)
357 .send()
358 .await?;
359 let pseudo_response = self
360 .process_response::<PseudonymizationBatchResponse<T>>(response)
361 .await?;
362 Ok(pseudo_response.result)
363 }
364
365 pub async fn rekey<T>(
367 &self,
368 encrypted: &T,
369 session_from: &EncryptionContext,
370 session_to: &EncryptionContext,
371 ) -> Result<T, TranscryptorError>
372 where
373 T: Rekeyable + DeserializeOwned + Serialize + Clone + ApiPath,
374 {
375 let request = RekeyRequest {
376 encrypted: encrypted.clone(),
377 session_from: session_from.clone(),
378 session_to: session_to.clone(),
379 };
380 let response = reqwest::Client::new()
381 .post(self.make_url(paas_api::paths::rekey_path::<T>().as_str()))
382 .with_auth(&self.auth)
383 .await?
384 .json(&request)
385 .send()
386 .await?;
387 let rekey_response = self.process_response::<RekeyResponse<T>>(response).await?;
388 Ok(rekey_response.result)
389 }
390
391 pub async fn rekey_batch<T>(
393 &self,
394 encrypted: Vec<T>,
395 session_from: &EncryptionContext,
396 session_to: &EncryptionContext,
397 ) -> Result<Vec<T>, TranscryptorError>
398 where
399 T: Rekeyable + DeserializeOwned + Serialize + Clone + ApiPath + HasStructure,
400 {
401 if encrypted.is_empty() {
402 return Ok(vec![]);
403 }
404
405 let request = RekeyBatchRequest {
406 encrypted,
407 session_from: session_from.clone(),
408 session_to: session_to.clone(),
409 };
410 let response = reqwest::Client::new()
411 .post(self.make_url(paas_api::paths::rekey_batch_path::<T>().as_str()))
412 .with_auth(&self.auth)
413 .await?
414 .json(&request)
415 .send()
416 .await?;
417 let rekey_response = self
418 .process_response::<RekeyBatchResponse<T>>(response)
419 .await?;
420 Ok(rekey_response.result)
421 }
422
423 pub async fn transcrypt<T>(
425 &self,
426 encrypted: &T,
427 domain_from: &PseudonymizationDomain,
428 domain_to: &PseudonymizationDomain,
429 session_from: &EncryptionContext,
430 session_to: &EncryptionContext,
431 ) -> Result<T, TranscryptorError>
432 where
433 T: Transcryptable + DeserializeOwned + Serialize + Clone + ApiPath,
434 {
435 let request = TranscryptionRequest {
436 encrypted: encrypted.clone(),
437 domain_from: domain_from.clone(),
438 domain_to: domain_to.clone(),
439 session_from: session_from.clone(),
440 session_to: session_to.clone(),
441 };
442 let response = reqwest::Client::new()
443 .post(self.make_url(paas_api::paths::transcrypt_path::<T>().as_str()))
444 .with_auth(&self.auth)
445 .await?
446 .json(&request)
447 .send()
448 .await?;
449 let transcrypt_response = self
450 .process_response::<TranscryptionResponse<T>>(response)
451 .await?;
452 Ok(transcrypt_response.result)
453 }
454
455 pub async fn transcrypt_batch<T>(
457 &self,
458 encrypted: Vec<T>,
459 domain_from: &PseudonymizationDomain,
460 domain_to: &PseudonymizationDomain,
461 session_from: &EncryptionContext,
462 session_to: &EncryptionContext,
463 ) -> Result<Vec<T>, TranscryptorError>
464 where
465 T: Transcryptable + DeserializeOwned + Serialize + Clone + ApiPath + HasStructure,
466 {
467 if encrypted.is_empty() {
468 return Ok(vec![]);
469 }
470
471 let request = TranscryptionBatchRequest {
472 encrypted,
473 domain_from: domain_from.clone(),
474 domain_to: domain_to.clone(),
475 session_from: session_from.clone(),
476 session_to: session_to.clone(),
477 };
478 let response = reqwest::Client::new()
479 .post(self.make_url(paas_api::paths::transcrypt_batch_path::<T>().as_str()))
480 .with_auth(&self.auth)
481 .await?
482 .json(&request)
483 .send()
484 .await?;
485 let transcrypt_response = self
486 .process_response::<TranscryptionBatchResponse<T>>(response)
487 .await?;
488 Ok(transcrypt_response.result)
489 }
490}
491
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockito::Server;

    use super::*;

    /// [`Auth`] stub that always hands out the same fixed token.
    struct TestAuth;

    #[async_trait]
    impl Auth for TestAuth {
        fn token_type(&self) -> &str {
            "test"
        }

        async fn token(&self) -> Result<String, Box<dyn core::error::Error>> {
            Ok("test".to_owned())
        }
    }

    /// Builds a client for `url` directly, without contacting any server.
    fn test_client(url: &str) -> TranscryptorClient {
        TranscryptorClient {
            config: TranscryptorConfig {
                system_id: "test".to_owned(),
                url: url.parse().expect("valid test URL"),
            },
            session_id: None,
            sks: None,
            auth: Arc::new(TestAuth),
        }
    }

    /// Serves `status` with `body` from a mock server and feeds the response through
    /// `process_response`.
    async fn map_error_response(
        status: usize,
        body: &str,
    ) -> Result<StatusResponse, TranscryptorError> {
        let mut server = Server::new_async().await;
        // Keep the mock handle alive until the response body has been consumed.
        let _mock = server
            .mock("GET", "/")
            .with_status(status)
            .with_header("content-type", "application/json")
            .with_body(body)
            .create_async()
            .await;

        let response = reqwest::get(server.url())
            .await
            .expect("mock server is reachable");
        test_client("https://example.com")
            .process_response::<StatusResponse>(response)
            .await
    }

    #[tokio::test]
    async fn error_mapping() {
        let err_str = "Unknown or expired session: Target session not owned by user";
        let result = map_error_response(
            404,
            r#"{"error":"Unknown or expired session: Target session not owned by user"}"#,
        )
        .await;

        match result.err() {
            Some(TranscryptorError::InvalidSession(err)) => assert_eq!(err, err_str),
            other => panic!("expected InvalidSession, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn status_codes_map_to_error_variants() {
        let body = r#"{"error":"boom"}"#;

        assert!(matches!(
            map_error_response(400, body).await,
            Err(TranscryptorError::BadRequest(m)) if m == "boom"
        ));
        assert!(matches!(
            map_error_response(401, body).await,
            Err(TranscryptorError::Unauthorized)
        ));
        assert!(matches!(
            map_error_response(403, body).await,
            Err(TranscryptorError::NotAllowed(m)) if m == "boom"
        ));
        assert!(matches!(
            map_error_response(503, body).await,
            Err(TranscryptorError::ServerError(m)) if m == "boom"
        ));
        assert!(matches!(
            map_error_response(418, body).await,
            Err(TranscryptorError::NetworkError(_))
        ));
    }

    #[tokio::test]
    async fn non_json_error_body_is_used_verbatim() {
        assert!(matches!(
            map_error_response(500, "plain failure").await,
            Err(TranscryptorError::ServerError(m)) if m == "plain failure"
        ));
    }

    #[tokio::test]
    async fn new_rejects_plain_http() {
        let config = TranscryptorConfig {
            system_id: "test".to_owned(),
            url: "http://example.com".parse().expect("valid test URL"),
        };

        match TranscryptorClient::new(config, Arc::new(TestAuth)).await {
            Err(TranscryptorError::NonHttpsUrlError { scheme }) => assert_eq!(scheme, "http"),
            Err(other) => panic!("expected NonHttpsUrlError, got {other:?}"),
            Ok(_) => panic!("plain-HTTP URL was accepted"),
        }
    }

    #[test]
    fn make_url_strips_trailing_slash() {
        let client = test_client("https://example.com/");
        assert_eq!(
            client.make_url(paas_api::paths::STATUS),
            format!(
                "https://example.com{}{}",
                paas_api::paths::API_BASE,
                paas_api::paths::STATUS
            )
        );
    }
}