Skip to main content

comdirect_rest_api/oauth2/
session.rs

1use chrono::Utc;
2use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tokio::sync::mpsc;
8use tracing::{error, info};
9use uuid::Uuid;
10
11use crate::oauth2::config::ComdirectConfig;
12use crate::oauth2::errors::AuthError;
13use crate::oauth2::types::{AuthResult, TokenResponse};
14
15/// A session with the Comdirect API, managing tokens and automatic background refreshes.
16pub struct Session {
17    pub(crate) client: reqwest::Client,
18    pub(crate) state: Arc<SessionState>,
19    pub(crate) refresh_tx: mpsc::UnboundedSender<u64>,
20}
21
22/// Internal shared state for token management, allowing background refreshes.
23pub(crate) struct SessionState {
24    pub(crate) access_token: RwLock<String>,
25    pub(crate) refresh_token: RwLock<String>,
26    pub(crate) session_id: RwLock<String>,
27}
28
29fn timestamp() -> String {
30    Utc::now().format("%Y%m%d%H%M%S%6f").to_string()
31}
32
33impl Session {
34    /// Initializes a new authenticated session with Comdirect.
35    ///
36    /// This method orchestrates the complete OAuth2 handshake. Depending on whether a
37    /// `refresh_token` is provided, it will either attempt to resume an existing session
38    /// or perform a full "Interactive" login sequence.
39    ///
40    /// # Authentication Flow
41    /// 1. **Resumption**: If `refresh_token` is `Some`, it attempts a `refresh_token` grant.
42    ///    - If successful, it proceeds to obtain a new session ID and sets up the auto-refresh worker.
43    ///    - If unsuccessful (e.g., token expired), it returns an error instead of falling back.
44    /// 2. **Full Authentication**: If no token is provided or resumption fails:
45    ///    - Performs a `password` grant using `user` and `password` from the config.
46    ///    - Triggers a **Push-TAN** challenge for session validation.
47    ///    - Performs a `cd_secondary` grant to finalize the session.
48    ///
49    /// # Background Worker
50    /// Upon successful creation, a background task is spawned that monitors token expiration.
51    /// It proactively refreshes the access token 60 seconds before it expires (or at 10% of
52    /// its life if shorter than 120s) to ensure uninterrupted API access.
53    ///
54    /// # Token Persistence
55    /// To persist authentication across application restarts, provide an `on_refresh_token`
56    /// callback in the [`ComdirectConfig`]. This callback is triggered whenever a new
57    /// refresh token is successfully obtained (both during initial login and background refreshes).
58    ///
59    /// # Example
60    /// ```rust,no_run
61    /// use comdirect_rest_api::oauth2::{ComdirectConfig, Session};
62    /// use std::sync::Arc;
63    ///
64    /// #[tokio::main]
65    /// async fn main() -> Result<(), Box<dyn std::error::Error>> {
66    ///     let config = ComdirectConfig {
67    ///         user: "1234567".to_string(),
68    ///         password: "your_password".to_string(),
69    ///         client_id: "your_client_id".to_string(),
70    ///         client_secret: "your_client_secret".to_string(),
71    ///         on_awaits_user_confirm: Arc::new(|| Box::pin(async {
72    ///             println!("Please confirm the TAN on your mobile device.");
73    ///         })),
74    ///         on_refresh_token: Some(Arc::new(|new_token| {
75    ///             println!("New refresh token received: {}", new_token);
76    ///             // Save this token to a database or secure file for next time
77    ///         })),
78    ///     };
79    ///
80    ///     // Try to resume with a previously saved token
81    ///     let saved_token = Some("...token from database...".to_string());
82    ///     let session = Session::new(&config, saved_token).await?;
83    ///
84    ///     println!("Session established! ID: {}", session.session_id().await);
85    ///     Ok(())
86    /// }
87    /// ```
88    pub async fn new(
89        config: &ComdirectConfig,
90        refresh_token: Option<String>,
91    ) -> Result<Self, AuthError> {
92        let client = reqwest::Client::new();
93
94        info!("Starting Comdirect Authentication...");
95
96        let auth_result = if let Some(rt) = refresh_token {
97            info!("Trying to authenticate with provided refresh token...");
98            Self::fetch_tokens_refresh(&client, config, &rt).await?
99        } else {
100            Self::fetch_tokens_full(&client, config).await?
101        };
102
103        Ok(Self::create_session(auth_result, config.clone(), client))
104    }
105
106    /// Internal helper to initialize the session struct and spawn the background refresh worker.
107    fn create_session(auth: AuthResult, config: ComdirectConfig, client: reqwest::Client) -> Self {
108        let (tx, mut rx) = mpsc::unbounded_channel::<u64>();
109
110        let state = Arc::new(SessionState {
111            access_token: RwLock::new(auth.access_token),
112            refresh_token: RwLock::new(auth.refresh_token),
113            session_id: RwLock::new(auth.session_id),
114        });
115
116        let state_clone = Arc::clone(&state);
117        let client_clone = client.clone();
118        let config_clone = config.clone();
119        let tx_clone = tx.clone();
120
121        tokio::spawn(async move {
122            let mut expires_in = auth.expires_in;
123            loop {
124                // Refresh 60 seconds before expiry, or at 90% of life if shorter than 60s
125                let buffer = if expires_in > 120 {
126                    60
127                } else {
128                    expires_in / 10
129                };
130                let wait_secs = expires_in.saturating_sub(buffer);
131
132                tokio::select! {
133                    Some(new_expires) = rx.recv() => {
134                        expires_in = new_expires;
135                    }
136                    _ = tokio::time::sleep(std::time::Duration::from_secs(wait_secs)) => {
137                        info!("Auto-refresh timer triggered (expires in {}s, waiting {}s)", expires_in, wait_secs);
138                        if let Err(e) = Self::perform_refresh_internal(&client_clone, &config_clone, &state_clone, &tx_clone).await {
139                            error!("Auto-refresh failed: {}. Will retry in 10 seconds...", e);
140                            tokio::time::sleep(std::time::Duration::from_secs(10)).await;
141                        }
142                    }
143                }
144            }
145        });
146
147        Session {
148            client,
149            state,
150            refresh_tx: tx,
151        }
152    }
153
154    /// Returns the currently active session ID.
155    pub async fn session_id(&self) -> String {
156        self.state.session_id.read().await.clone()
157    }
158
159    /// Performs OAuth2 refresh flow.
160    async fn fetch_tokens_refresh(
161        client: &reqwest::Client,
162        config: &ComdirectConfig,
163        refresh_token: &str,
164    ) -> Result<AuthResult, AuthError> {
165        let mut form = HashMap::new();
166        form.insert("client_id", config.client_id.as_str());
167        form.insert("client_secret", config.client_secret.as_str());
168        form.insert("grant_type", "refresh_token");
169        form.insert("refresh_token", refresh_token.trim());
170
171        let res = client
172            .post("https://api.comdirect.de/oauth/token")
173            .header(ACCEPT, "application/json")
174            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
175            .form(&form)
176            .send()
177            .await
178            .map_err(|e| AuthError::Network(e.to_string()))?;
179
180        if res.status() == 400 || res.status() == 401 {
181            return Err(AuthError::InvalidCredentials(format!(
182                "Token refresh rejected: {}",
183                res.status()
184            )));
185        }
186
187        if res.status() != 200 {
188            return Err(AuthError::Other(format!(
189                "Refresh token POST failed with status {}",
190                res.status()
191            )));
192        }
193
194        let token_res: TokenResponse = res
195            .json()
196            .await
197            .map_err(|e| AuthError::Other(e.to_string()))?;
198
199        let session_id = Self::get_session_id(client, &token_res.access_token).await?;
200
201        if let Some(callback) = &config.on_refresh_token {
202            callback(token_res.refresh_token.clone());
203        }
204
205        Ok(AuthResult {
206            access_token: token_res.access_token,
207            refresh_token: token_res.refresh_token,
208            expires_in: token_res.expires_in,
209            session_id,
210        })
211    }
212
213    /// Performs the full authentication sequence including password grant and TAN validation.
214    async fn fetch_tokens_full(
215        client: &reqwest::Client,
216        config: &ComdirectConfig,
217    ) -> Result<AuthResult, AuthError> {
218        info!("Performing full authentication...");
219
220        let mut form = HashMap::new();
221        form.insert("client_id", config.client_id.as_str());
222        form.insert("client_secret", config.client_secret.as_str());
223        form.insert("username", config.user.as_str());
224        form.insert("password", config.password.as_str());
225        form.insert("grant_type", "password");
226
227        let res = client
228            .post("https://api.comdirect.de/oauth/token")
229            .header(ACCEPT, "application/json")
230            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
231            .form(&form)
232            .send()
233            .await
234            .map_err(|e| AuthError::Network(e.to_string()))?;
235
236        if res.status() != 200 {
237            return Err(AuthError::InvalidCredentials(format!(
238                "Initial OAuth POST failed with status {}",
239                res.status()
240            )));
241        }
242
243        let token_res: Value = res
244            .json()
245            .await
246            .map_err(|e| AuthError::Other(e.to_string()))?;
247        let access_token = token_res["access_token"]
248            .as_str()
249            .ok_or_else(|| AuthError::Other("Missing access_token".to_string()))?
250            .to_string();
251
252        let session_id = Self::get_session_id(client, &access_token).await?;
253
254        // Validate session (triggers Push-TAN)
255        let req_info = serde_json::json!({
256            "clientRequestId": {
257                "sessionId": session_id,
258                "requestId": timestamp()
259            }
260        })
261        .to_string();
262        let body = serde_json::json!({
263            "identifier": session_id,
264            "sessionTanActive": true,
265            "activated2FA": true
266        })
267        .to_string();
268
269        let url = &format!(
270            "https://api.comdirect.de/api/session/clients/user/v1/sessions/{}/validate",
271            session_id
272        );
273        let res = client
274            .post(url)
275            .header(ACCEPT, "application/json")
276            .header(AUTHORIZATION, format!("Bearer {}", access_token))
277            .header("x-http-request-info", req_info)
278            .header(CONTENT_TYPE, "application/json")
279            .body(body)
280            .send()
281            .await
282            .map_err(|e| AuthError::Network(e.to_string()))?;
283
284        if res.status() != 201 {
285            return Err(AuthError::Other(format!(
286                "Session validation failed with status {}",
287                res.status()
288            )));
289        }
290
291        let once_auth_info_str = res
292            .headers()
293            .get("x-once-authentication-info")
294            .ok_or_else(|| {
295                AuthError::Other("Missing x-once-authentication-info header".to_string())
296            })?
297            .to_str()
298            .map_err(|e| AuthError::Other(e.to_string()))?;
299
300        let once_auth_info: Value = serde_json::from_str(once_auth_info_str)
301            .map_err(|e| AuthError::Other(e.to_string()))?;
302        let challenge_id = once_auth_info["id"]
303            .as_str()
304            .ok_or_else(|| AuthError::Other("Missing challenge id".to_string()))?;
305
306        info!(
307            "Auth challenge (Push-TAN) triggered. Please confirm on your device. Waiting for 1 minute..."
308        );
309
310        (config.on_awaits_user_confirm)().await;
311
312        let tan = "123456";
313
314        let req_info = serde_json::json!({
315            "clientRequestId": {
316                "sessionId": session_id,
317                "requestId": timestamp()
318            }
319        })
320        .to_string();
321        let body = serde_json::json!({
322            "identifier": session_id,
323            "sessionTanActive": true,
324            "activated2FA": true
325        })
326        .to_string();
327
328        let url1 = &format!(
329            "https://api.comdirect.de/api/session/clients/user/v1/sessions/{}",
330            session_id
331        );
332        let res = client
333            .patch(url1)
334            .header(ACCEPT, "application/json")
335            .header(AUTHORIZATION, format!("Bearer {}", access_token))
336            .header("x-http-request-info", req_info)
337            .header(CONTENT_TYPE, "application/json")
338            .header(
339                "x-once-authentication-info",
340                serde_json::json!({"id": challenge_id}).to_string(),
341            )
342            .header("x-once-authentication", tan)
343            .body(body)
344            .send()
345            .await
346            .map_err(|e| AuthError::Network(e.to_string()))?;
347
348        if res.status() != 200 {
349            return Err(AuthError::Other(format!(
350                "Session PATCH (TAN completion) failed with status {}",
351                res.status()
352            )));
353        }
354
355        // Secondary OAuth (final step)
356        let mut form = HashMap::new();
357        form.insert("client_id", config.client_id.as_str());
358        form.insert("client_secret", config.client_secret.as_str());
359        form.insert("grant_type", "cd_secondary");
360        form.insert("token", access_token.as_str());
361
362        let res = client
363            .post("https://api.comdirect.de/oauth/token")
364            .header(ACCEPT, "application/json")
365            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
366            .form(&form)
367            .send()
368            .await
369            .map_err(|e| AuthError::Network(e.to_string()))?;
370
371        if res.status() != 200 {
372            return Err(AuthError::Other(format!(
373                "Secondary OAuth token POST failed with status {}",
374                res.status()
375            )));
376        }
377
378        let token_res: TokenResponse = res
379            .json()
380            .await
381            .map_err(|e| AuthError::Other(e.to_string()))?;
382
383        if let Some(callback) = &config.on_refresh_token {
384            callback(token_res.refresh_token.clone());
385        }
386
387        info!("Authentication sequence completed successfully!");
388
389        Ok(AuthResult {
390            access_token: token_res.access_token,
391            refresh_token: token_res.refresh_token,
392            expires_in: token_res.expires_in,
393            session_id,
394        })
395    }
396
397    /// Obtains the session ID required for API requests.
398    async fn get_session_id(
399        client: &reqwest::Client,
400        access_token: &str,
401    ) -> Result<String, AuthError> {
402        let generated_session_id = Uuid::new_v4().to_string();
403        let req_info = serde_json::json!({
404            "clientRequestId": {
405                "sessionId": generated_session_id,
406                "requestId": timestamp()
407            }
408        })
409        .to_string();
410
411        let res = client
412            .get("https://api.comdirect.de/api/session/clients/user/v1/sessions")
413            .header(ACCEPT, "application/json")
414            .header(AUTHORIZATION, format!("Bearer {}", access_token))
415            .header("x-http-request-info", req_info)
416            .send()
417            .await
418            .map_err(|e| AuthError::Network(e.to_string()))?;
419
420        if res.status() != 200 {
421            return Err(AuthError::Other(format!(
422                "GET sessions failed with status {}",
423                res.status()
424            )));
425        }
426
427        let sessions_res: Value = res
428            .json()
429            .await
430            .map_err(|e| AuthError::Other(e.to_string()))?;
431        sessions_res[0]["identifier"]
432            .as_str()
433            .map(|s| s.to_string())
434            .ok_or_else(|| AuthError::Other("Missing session identifier".to_string()))
435    }
436
437    /// Performs an authorized GET request.
438    ///
439    /// This method only modifies the request with the session ID and timestamp and adds
440    /// the Bearer token. It does NOT automatically refresh the access token if it's expired.
441    pub(crate) async fn get_authorized(&self, url: &str) -> Result<reqwest::Response, AuthError> {
442        let access_token = self.state.access_token.read().await.clone();
443        let session_id = self.state.session_id.read().await.clone();
444
445        let req_info = serde_json::json!({
446            "clientRequestId": {
447                "sessionId": session_id,
448                "requestId": timestamp()
449            }
450        })
451        .to_string();
452
453        let res = self
454            .client
455            .get(url)
456            .header(ACCEPT, "application/json")
457            .header(AUTHORIZATION, format!("Bearer {}", access_token))
458            .header("x-http-request-info", req_info)
459            .send()
460            .await
461            .map_err(|e| AuthError::Network(e.to_string()))?;
462
463        if res.status() == 401 {
464            // Signal the background worker that a refresh is necessary.
465            let _ = self.refresh_tx.send(0);
466        }
467
468        if res.status() != 200 {
469            return Err(AuthError::Other(format!(
470                "Authorized GET to {} failed with status {}",
471                url,
472                res.status()
473            )));
474        }
475        Ok(res)
476    }
477
478    /// Internal logic for token refresh, shared between background task and on-demand refresh.
479    /// This also validates the new token by obtaining a new session ID.
480    async fn perform_refresh_internal(
481        client: &reqwest::Client,
482        config: &ComdirectConfig,
483        state: &SessionState,
484        refresh_tx: &mpsc::UnboundedSender<u64>,
485    ) -> Result<(), AuthError> {
486        let current_refresh = state.refresh_token.read().await.clone();
487
488        let mut form = HashMap::new();
489        form.insert("client_id", config.client_id.as_str());
490        form.insert("client_secret", config.client_secret.as_str());
491        form.insert("grant_type", "refresh_token");
492        form.insert("refresh_token", current_refresh.trim());
493
494        let res = client
495            .post("https://api.comdirect.de/oauth/token")
496            .header(ACCEPT, "application/json")
497            .header(CONTENT_TYPE, "application/x-www-form-urlencoded")
498            .form(&form)
499            .send()
500            .await
501            .map_err(|e| AuthError::Network(e.to_string()))?;
502
503        if res.status() != 200 {
504            return Err(AuthError::InvalidCredentials(format!(
505                "Refresh failed with status {}",
506                res.status()
507            )));
508        }
509
510        let token_res: TokenResponse = res
511            .json()
512            .await
513            .map_err(|e| AuthError::Other(e.to_string()))?;
514
515        // Validate the new token by getting a session ID before committing the changes.
516        let session_id = Self::get_session_id(client, &token_res.access_token).await?;
517
518        let mut token_lock = state.access_token.write().await;
519        *token_lock = token_res.access_token;
520        drop(token_lock);
521
522        let mut refresh_lock = state.refresh_token.write().await;
523        *refresh_lock = token_res.refresh_token.clone();
524        drop(refresh_lock);
525
526        let mut session_lock = state.session_id.write().await;
527        *session_lock = session_id;
528        drop(session_lock);
529
530        if let Some(callback) = &config.on_refresh_token {
531            callback(token_res.refresh_token.clone());
532        }
533
534        // Notify worker of the new expiration time
535        let _ = refresh_tx.send(token_res.expires_in);
536
537        Ok(())
538    }
539}