Skip to main content

comdirect_rest_api/oauth2/
session.rs

1use chrono::Utc;
2use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tokio::sync::mpsc;
8use tracing::{error, info, warn};
9use uuid::Uuid;
10
11use crate::oauth2::config::ComdirectConfig;
12use crate::oauth2::errors::AuthError;
13use crate::oauth2::types::{AuthResult, TokenResponse};
14
15/// A session with the Comdirect API, managing tokens and automatic background refreshes.
16pub struct Session {
17    pub(crate) client: reqwest::Client,
18    pub(crate) state: Arc<SessionState>,
19    pub session_id: String,
20    pub(crate) config: ComdirectConfig,
21    pub(crate) refresh_tx: mpsc::UnboundedSender<u64>,
22}
23
24/// Internal shared state for token management, allowing background refreshes.
25pub(crate) struct SessionState {
26    pub(crate) access_token: RwLock<String>,
27    pub(crate) refresh_token: RwLock<String>,
28}
29
30fn timestamp() -> String {
31    Utc::now().format("%Y%m%d%H%M%S%6f").to_string()
32}
33
34impl Session {
35    /// Initializes a new authenticated session with Comdirect.
36    ///
37    /// This method orchestrates the complete OAuth2 handshake. Depending on whether a
38    /// `refresh_token` is provided, it will either attempt to resume an existing session
39    /// or perform a full "Interactive" login sequence.
40    ///
41    /// # Authentication Flow
42    /// 1. **Resumption**: If `refresh_token` is `Some`, it attempts a `refresh_token` grant.
43    ///    - If successful, it proceeds to obtain a new session ID and sets up the auto-refresh worker.
44    ///    - If unsuccessful (e.g., token expired), it returns an error instead of falling back.
45    /// 2. **Full Authentication**: If no token is provided or resumption fails:
46    ///    - Performs a `password` grant using `user` and `password` from the config.
47    ///    - Triggers a **Push-TAN** challenge for session validation.
48    ///    - Performs a `cd_secondary` grant to finalize the session.
49    ///
50    /// # Background Worker
51    /// Upon successful creation, a background task is spawned that monitors token expiration.
52    /// It proactively refreshes the access token 60 seconds before it expires (or at 10% of
53    /// its life if shorter than 120s) to ensure uninterrupted API access.
54    ///
55    /// # Token Persistence
56    /// To persist authentication across application restarts, provide an `on_refresh_token`
57    /// callback in the [`ComdirectConfig`]. This callback is triggered whenever a new
58    /// refresh token is successfully obtained (both during initial login and background refreshes).
59    ///
60    /// # Example
61    /// ```rust,no_run
62    /// use comdirect_rest_api::oauth2::{ComdirectConfig, Session};
63    /// use std::sync::Arc;
64    ///
65    /// #[tokio::main]
66    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
67    ///     let config = ComdirectConfig {
68    ///         user: "1234567".to_string(),
69    ///         password: "your_password".to_string(),
70    ///         client_id: "your_client_id".to_string(),
71    ///         client_secret: "your_client_secret".to_string(),
72    ///         on_refresh_token: Some(Arc::new(|new_token| {
73    ///             println!("New refresh token received: {}", new_token);
74    ///             // Save this token to a database or secure file for next time
75    ///         })),
76    ///     };
77    ///
78    ///     // Try to resume with a previously saved token
79    ///     let saved_token = Some("...token from database...".to_string());
80    ///     let session = Session::new(&config, saved_token).await?;
81    ///
82    ///     println!("Session established! ID: {}", session.session_id);
83    ///     Ok(())
84    /// }
85    /// ```
86    pub async fn new(
87        config: &ComdirectConfig,
88        refresh_token: Option<String>,
89    ) -> Result<Self, AuthError> {
90        let client = reqwest::Client::new();
91
92        info!("Starting Comdirect Authentication...");
93
94        let auth_result = if let Some(rt) = refresh_token {
95            info!("Trying to authenticate with provided refresh token...");
96            Self::fetch_tokens_refresh(&client, config, &rt).await?
97        } else {
98            Self::fetch_tokens_full(&client, config).await?
99        };
100
101        Ok(Self::create_session(auth_result, config.clone(), client))
102    }
103
104    /// Internal helper to initialize the session struct and spawn the background refresh worker.
105    fn create_session(auth: AuthResult, config: ComdirectConfig, client: reqwest::Client) -> Self {
106        let (tx, mut rx) = mpsc::unbounded_channel::<u64>();
107
108        let state = Arc::new(SessionState {
109            access_token: RwLock::new(auth.access_token),
110            refresh_token: RwLock::new(auth.refresh_token),
111        });
112
113        let state_clone = Arc::clone(&state);
114        let client_clone = client.clone();
115        let config_clone = config.clone();
116        let tx_clone = tx.clone();
117
118        tokio::spawn(async move {
119            let mut expires_in = auth.expires_in;
120            loop {
121                // Refresh 60 seconds before expiry, or at 90% of life if shorter than 60s
122                let buffer = if expires_in > 120 {
123                    60
124                } else {
125                    expires_in / 10
126                };
127                let wait_secs = expires_in.saturating_sub(buffer);
128
129                tokio::select! {
130                    Some(new_expires) = rx.recv() => {
131                        expires_in = new_expires;
132                    }
133                    _ = tokio::time::sleep(std::time::Duration::from_secs(wait_secs)) => {
134                        info!("Auto-refresh timer triggered (expires in {}s, waiting {}s)", expires_in, wait_secs);
135                        if let Err(e) = Self::perform_refresh_internal(&client_clone, &config_clone, &state_clone, &tx_clone).await {
136                            error!("Auto-refresh failed: {}. Will retry in 10 seconds...", e);
137                            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
138                        }
139                    }
140                }
141            }
142        });
143
144        Session {
145            client,
146            state,
147            session_id: auth.session_id,
148            config,
149            refresh_tx: tx,
150        }
151    }
152
153    /// Performs OAuth2 refresh flow.
154    async fn fetch_tokens_refresh(
155        client: &reqwest::Client,
156        config: &ComdirectConfig,
157        refresh_token: &str,
158    ) -> Result<AuthResult, AuthError> {
159        let mut form = HashMap::new();
160        form.insert("client_id", config.client_id.as_str());
161        form.insert("client_secret", config.client_secret.as_str());
162        form.insert("grant_type", "refresh_token");
163        form.insert("refresh_token", refresh_token.trim());
164
165        let res = client
166            .post("https://api.comdirect.de/oauth/token")
167            .header(ACCEPT, "application/json")
168            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
169            .form(&form)
170            .send()
171            .await
172            .map_err(|e| AuthError::Network(e.to_string()))?;
173
174        if res.status() == 400 || res.status() == 401 {
175            return Err(AuthError::InvalidCredentials(format!(
176                "Token refresh rejected: {}",
177                res.status()
178            )));
179        }
180
181        if res.status() != 200 {
182            return Err(AuthError::Other(format!(
183                "Refresh token POST failed with status {}",
184                res.status()
185            )));
186        }
187
188        let token_res: TokenResponse = res
189            .json()
190            .await
191            .map_err(|e| AuthError::Other(e.to_string()))?;
192
193        let session_id = Self::get_session_id(client, &token_res.access_token).await?;
194
195        if let Some(callback) = &config.on_refresh_token {
196            callback(token_res.refresh_token.clone());
197        }
198
199        Ok(AuthResult {
200            access_token: token_res.access_token,
201            refresh_token: token_res.refresh_token,
202            expires_in: token_res.expires_in,
203            session_id,
204        })
205    }
206
207    /// Performs the full authentication sequence including password grant and TAN validation.
208    async fn fetch_tokens_full(
209        client: &reqwest::Client,
210        config: &ComdirectConfig,
211    ) -> Result<AuthResult, AuthError> {
212        info!("Performing full authentication...");
213
214        let mut form = HashMap::new();
215        form.insert("client_id", config.client_id.as_str());
216        form.insert("client_secret", config.client_secret.as_str());
217        form.insert("username", config.user.as_str());
218        form.insert("password", config.password.as_str());
219        form.insert("grant_type", "password");
220
221        let res = client
222            .post("https://api.comdirect.de/oauth/token")
223            .header(ACCEPT, "application/json")
224            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
225            .form(&form)
226            .send()
227            .await
228            .map_err(|e| AuthError::Network(e.to_string()))?;
229
230        if res.status() != 200 {
231            return Err(AuthError::InvalidCredentials(format!(
232                "Initial OAuth POST failed with status {}",
233                res.status()
234            )));
235        }
236
237        let token_res: Value = res
238            .json()
239            .await
240            .map_err(|e| AuthError::Other(e.to_string()))?;
241        let access_token = token_res["access_token"]
242            .as_str()
243            .ok_or_else(|| AuthError::Other("Missing access_token".to_string()))?
244            .to_string();
245
246        let session_id = Self::get_session_id(client, &access_token).await?;
247
248        // Validate session (triggers Push-TAN)
249        let req_info = serde_json::json!({
250            "clientRequestId": {
251                "sessionId": session_id,
252                "requestId": timestamp()
253            }
254        })
255        .to_string();
256        let body = serde_json::json!({
257            "identifier": session_id,
258            "sessionTanActive": true,
259            "activated2FA": true
260        })
261        .to_string();
262
263        let url = &format!(
264            "https://api.comdirect.de/api/session/clients/user/v1/sessions/{}/validate",
265            session_id
266        );
267        let res = client
268            .post(url)
269            .header(ACCEPT, "application/json")
270            .header(AUTHORIZATION, format!("Bearer {}", access_token))
271            .header("x-http-request-info", req_info)
272            .header(CONTENT_TYPE, "application/json")
273            .body(body)
274            .send()
275            .await
276            .map_err(|e| AuthError::Network(e.to_string()))?;
277
278        if res.status() != 201 {
279            return Err(AuthError::Other(format!(
280                "Session validation failed with status {}",
281                res.status()
282            )));
283        }
284
285        let once_auth_info_str = res
286            .headers()
287            .get("x-once-authentication-info")
288            .ok_or_else(|| {
289                AuthError::Other("Missing x-once-authentication-info header".to_string())
290            })?
291            .to_str()
292            .map_err(|e| AuthError::Other(e.to_string()))?;
293
294        let once_auth_info: Value = serde_json::from_str(once_auth_info_str)
295            .map_err(|e| AuthError::Other(e.to_string()))?;
296        let challenge_id = once_auth_info["id"]
297            .as_str()
298            .ok_or_else(|| AuthError::Other("Missing challenge id".to_string()))?;
299
300        info!(
301            "Auth challenge (Push-TAN) triggered. Please confirm on your device. Waiting for 1 minute..."
302        );
303
304        (config.on_awaits_user_confirm)().await;
305
306        let tan = "123456";
307
308        let req_info = serde_json::json!({
309            "clientRequestId": {
310                "sessionId": session_id,
311                "requestId": timestamp()
312            }
313        })
314        .to_string();
315        let body = serde_json::json!({
316            "identifier": session_id,
317            "sessionTanActive": true,
318            "activated2FA": true
319        })
320        .to_string();
321
322        let url1 = &format!(
323            "https://api.comdirect.de/api/session/clients/user/v1/sessions/{}",
324            session_id
325        );
326        let res = client
327            .patch(url1)
328            .header(ACCEPT, "application/json")
329            .header(AUTHORIZATION, format!("Bearer {}", access_token))
330            .header("x-http-request-info", req_info)
331            .header(CONTENT_TYPE, "application/json")
332            .header(
333                "x-once-authentication-info",
334                serde_json::json!({"id": challenge_id}).to_string(),
335            )
336            .header("x-once-authentication", tan)
337            .body(body)
338            .send()
339            .await
340            .map_err(|e| AuthError::Network(e.to_string()))?;
341
342        if res.status() != 200 {
343            return Err(AuthError::Other(format!(
344                "Session PATCH (TAN completion) failed with status {}",
345                res.status()
346            )));
347        }
348
349        // Secondary OAuth (final step)
350        let mut form = HashMap::new();
351        form.insert("client_id", config.client_id.as_str());
352        form.insert("client_secret", config.client_secret.as_str());
353        form.insert("grant_type", "cd_secondary");
354        form.insert("token", access_token.as_str());
355
356        let res = client
357            .post("https://api.comdirect.de/oauth/token")
358            .header(ACCEPT, "application/json")
359            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
360            .form(&form)
361            .send()
362            .await
363            .map_err(|e| AuthError::Network(e.to_string()))?;
364
365        if res.status() != 200 {
366            return Err(AuthError::Other(format!(
367                "Secondary OAuth token POST failed with status {}",
368                res.status()
369            )));
370        }
371
372        let token_res: TokenResponse = res
373            .json()
374            .await
375            .map_err(|e| AuthError::Other(e.to_string()))?;
376
377        if let Some(callback) = &config.on_refresh_token {
378            callback(token_res.refresh_token.clone());
379        }
380
381        info!("Authentication sequence completed successfully!");
382
383        Ok(AuthResult {
384            access_token: token_res.access_token,
385            refresh_token: token_res.refresh_token,
386            expires_in: token_res.expires_in,
387            session_id,
388        })
389    }
390
391    /// Obtains the session ID required for API requests.
392    async fn get_session_id(
393        client: &reqwest::Client,
394        access_token: &str,
395    ) -> Result<String, AuthError> {
396        let generated_session_id = Uuid::new_v4().to_string();
397        let req_info = serde_json::json!({
398            "clientRequestId": {
399                "sessionId": generated_session_id,
400                "requestId": timestamp()
401            }
402        })
403        .to_string();
404
405        let res = client
406            .get("https://api.comdirect.de/api/session/clients/user/v1/sessions")
407            .header(ACCEPT, "application/json")
408            .header(AUTHORIZATION, format!("Bearer {}", access_token))
409            .header("x-http-request-info", req_info)
410            .send()
411            .await
412            .map_err(|e| AuthError::Network(e.to_string()))?;
413
414        if res.status() != 200 {
415            return Err(AuthError::Other(format!(
416                "GET sessions failed with status {}",
417                res.status()
418            )));
419        }
420
421        let sessions_res: Value = res
422            .json()
423            .await
424            .map_err(|e| AuthError::Other(e.to_string()))?;
425        sessions_res[0]["identifier"]
426            .as_str()
427            .map(|s| s.to_string())
428            .ok_or_else(|| AuthError::Other("Missing session identifier".to_string()))
429    }
430
431    /// Performs an authorized GET request, automatically handling token expiration.
432    pub(crate) async fn get_authorized(&self, url: &str) -> Result<reqwest::Response, AuthError> {
433        let mut retry = true;
434        loop {
435            let access_token = self.state.access_token.read().await.clone();
436
437            let req_info = serde_json::json!({
438                "clientRequestId": {
439                    "sessionId": self.session_id,
440                    "requestId": timestamp()
441                }
442            })
443            .to_string();
444
445            let res = self
446                .client
447                .get(url)
448                .header(ACCEPT, "application/json")
449                .header(AUTHORIZATION, format!("Bearer {}", access_token))
450                .header("x-http-request-info", req_info)
451                .send()
452                .await
453                .map_err(|e| AuthError::Network(e.to_string()))?;
454
455            if res.status() == 401 && retry {
456                warn!("Access token expired. Attempting refresh...");
457                retry = false;
458                match self.refresh_token_in_place().await {
459                    Ok(_) => {
460                        info!("Token refreshed successfully. Retrying request...");
461                        continue;
462                    }
463                    Err(e) => {
464                        error!("Token refresh failed during runtime: {}", e);
465                        return Err(e);
466                    }
467                }
468            }
469
470            if res.status() != 200 {
471                return Err(AuthError::Other(format!(
472                    "Authorized GET to {} failed with status {}",
473                    url,
474                    res.status()
475                )));
476            }
477            return Ok(res);
478        }
479    }
480
481    /// Performs a token refresh in-place using the current refresh token.
482    async fn refresh_token_in_place(&self) -> Result<(), AuthError> {
483        Self::perform_refresh_internal(&self.client, &self.config, &self.state, &self.refresh_tx)
484            .await
485    }
486
487    /// Internal logic for token refresh, shared between background task and on-demand refresh.
488    async fn perform_refresh_internal(
489        client: &reqwest::Client,
490        config: &ComdirectConfig,
491        state: &SessionState,
492        refresh_tx: &mpsc::UnboundedSender<u64>,
493    ) -> Result<(), AuthError> {
494        let current_refresh = state.refresh_token.read().await.clone();
495
496        let mut form = HashMap::new();
497        form.insert("client_id", config.client_id.as_str());
498        form.insert("client_secret", config.client_secret.as_str());
499        form.insert("grant_type", "refresh_token");
500        form.insert("refresh_token", current_refresh.trim());
501
502        let res = client
503            .post("https://api.comdirect.de/oauth/token")
504            .header(ACCEPT, "application/json")
505            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
506            .form(&form)
507            .send()
508            .await
509            .map_err(|e| AuthError::Network(e.to_string()))?;
510
511        if res.status() != 200 {
512            return Err(AuthError::InvalidCredentials(format!(
513                "Refresh failed with status {}",
514                res.status()
515            )));
516        }
517
518        let token_res: TokenResponse = res
519            .json()
520            .await
521            .map_err(|e| AuthError::Other(e.to_string()))?;
522
523        let mut token_lock = state.access_token.write().await;
524        *token_lock = token_res.access_token;
525        drop(token_lock);
526
527        let mut refresh_lock = state.refresh_token.write().await;
528        *refresh_lock = token_res.refresh_token.clone();
529        drop(refresh_lock);
530
531        if let Some(callback) = &config.on_refresh_token {
532            callback(token_res.refresh_token.clone());
533        }
534
535        // Notify worker of the new expiration time
536        let _ = refresh_tx.send(token_res.expires_in);
537
538        Ok(())
539    }
540}