Skip to main content

qm_keycloak/
session.rs

1use keycloak::KeycloakError;
2use keycloak::KeycloakTokenSupplier;
3use std::{sync::Arc, time::Duration};
4use tokio::runtime::Builder;
5use tokio::sync::RwLock;
6use tokio::task::LocalSet;
7
8/// Errors for Keycloak session operations.
9#[derive(Debug, Clone)]
10pub enum KeycloakSessionError {
11    /// Request failure.
12    ReqwestFailure(Arc<reqwest::Error>),
13    /// HTTP failure with status and text.
14    HttpFailure {
15        /// HTTP status code.
16        status: u16,
17        /// Response text.
18        text: Arc<str>,
19    },
20    /// Decode failure.
21    Decode(Arc<serde_json::Error>),
22}
23
24impl From<reqwest::Error> for KeycloakSessionError {
25    fn from(value: reqwest::Error) -> Self {
26        KeycloakSessionError::ReqwestFailure(Arc::new(value))
27    }
28}
29
30impl std::error::Error for KeycloakSessionError {}
31impl std::fmt::Display for KeycloakSessionError {
32    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33        match self {
34            KeycloakSessionError::HttpFailure { text, .. } => {
35                writeln!(f, "keycloak error: {}", text.as_ref())
36            }
37            KeycloakSessionError::ReqwestFailure(e) => e.fmt(f),
38            KeycloakSessionError::Decode(e) => e.fmt(f),
39        }
40    }
41}
42
43async fn error(response: reqwest::Response) -> Result<reqwest::Response, KeycloakSessionError> {
44    if !response.status().is_success() {
45        let status = response.status();
46        let text = response.text().await;
47        return match text {
48            Ok(text) => Err(KeycloakSessionError::HttpFailure {
49                status: status.as_u16(),
50                text: Arc::from(text),
51            }),
52            Err(e) => Err(KeycloakSessionError::ReqwestFailure(Arc::new(e))),
53        };
54    }
55
56    Ok(response)
57}
58
59/// Parsed access token from Keycloak (equivalent to KeycloakAccessTokenResponse).
60#[derive(Debug, serde::Deserialize, serde::Serialize)]
61pub struct ParsedAccessToken {
62    /// Expiration time (unix seconds).
63    exp: usize,
64    /// Issued at time (unix seconds).
65    iat: usize,
66    /// JWT ID.
67    jti: Option<String>,
68    /// Issuer.
69    iss: Option<String>,
70    /// Subject (user ID).
71    sub: Option<String>,
72    /// Token type.
73    typ: Option<String>,
74    /// Authorized party (client ID).
75    azp: Option<String>,
76    /// Nonce.
77    nonce: Option<String>,
78    /// Session state.
79    session_state: Option<String>,
80    /// Authentication context class reference.
81    acr: Option<String>,
82    /// Allowed actions.
83    allowed: Option<Vec<String>>,
84    /// Scope.
85    scope: Option<String>,
86    /// Session ID.
87    sid: Option<String>,
88    /// Whether email is verified.
89    #[serde(default)]
90    email_verified: bool,
91    /// Preferred username.
92    preferred_username: Option<String>,
93}
94
95/// Session token from Keycloak.
96#[derive(Debug, serde::Deserialize, serde::Serialize)]
97pub struct KeycloakSessionToken {
98    /// Access token.
99    access_token: Arc<str>,
100    /// Time until expiration.
101    expires_in: usize,
102    /// Not before policy.
103    #[serde(rename = "not-before-policy")]
104    not_before_policy: Option<usize>,
105    /// Time until refresh token expires.
106    refresh_expires_in: Option<usize>,
107    /// Refresh token.
108    refresh_token: Arc<str>,
109    /// Scope.
110    scope: String,
111    /// Session state.
112    session_state: Option<String>,
113    /// Token type.
114    token_type: String,
115    /// Parsed access token.
116    #[serde(skip)]
117    parsed_access_token: Option<ParsedAccessToken>,
118    /// Client token (type + access_token).
119    #[serde(skip)]
120    client_token: Option<Arc<str>>,
121}
122
123impl KeycloakSessionToken {
124    fn parse_access_token(mut token: Self) -> Self {
125        use base64::engine::{general_purpose::STANDARD_NO_PAD, Engine};
126        if let Some(parsed_access_token) = token
127            .access_token
128            .split('.')
129            .nth(1)
130            .and_then(|s| {
131                STANDARD_NO_PAD
132                    .decode(s)
133                    .map_err(|e| {
134                        tracing::error!("{e:#?}");
135                        e
136                    })
137                    .ok()
138            })
139            .and_then(|b| {
140                serde_json::from_slice::<ParsedAccessToken>(&b)
141                    .map_err(|e| {
142                        tracing::error!("{e:#?}");
143                        e
144                    })
145                    .ok()
146            })
147        {
148            token.parsed_access_token = Some(parsed_access_token);
149        }
150        token.client_token = Some(Arc::from(format!(
151            "{} {}",
152            &token.token_type, &token.access_token
153        )));
154        token
155    }
156}
157
158struct KeycloakSessionClientInner {
159    url: Arc<str>,
160    realm: Arc<str>,
161    client_id: Arc<str>,
162    client: reqwest::Client,
163}
164
165#[derive(Clone)]
166/// Keycloak session client.
167pub struct KeycloakSessionClient {
168    inner: Arc<KeycloakSessionClientInner>,
169}
170
171impl KeycloakSessionClient {
172    /// Creates a new KeycloakSessionClient.
173    pub fn new<T>(url: T, realm: T, client_id: T) -> Self
174    where
175        T: Into<String>,
176    {
177        Self {
178            inner: Arc::new(KeycloakSessionClientInner {
179                url: Arc::from(url.into()),
180                realm: Arc::from(realm.into()),
181                client_id: Arc::from(client_id.into()),
182                client: reqwest::Client::default(),
183            }),
184        }
185    }
186
187    async fn acquire(
188        &self,
189        username: &str,
190        password: &str,
191    ) -> Result<KeycloakSessionToken, KeycloakSessionError> {
192        let url = self.inner.url.as_ref();
193        let realm = self.inner.realm.as_ref();
194        let client_id = self.inner.client_id.as_ref();
195        let result = error(
196            self.inner
197                .client
198                .post(format!(
199                    "{url}/realms/{realm}/protocol/openid-connect/token",
200                ))
201                .form(&serde_json::json!({
202                    "username": username,
203                    "password": password,
204                    "client_id": client_id,
205                    "grant_type": "password"
206                }))
207                .send()
208                .await?,
209        )
210        .await?
211        .json::<serde_json::Value>()
212        .await?;
213        tracing::debug!(
214            "Acquire result: {}",
215            serde_json::to_string_pretty(&result).unwrap()
216        );
217        serde_json::from_value(result).map_err(|err| KeycloakSessionError::Decode(Arc::new(err)))
218    }
219
220    async fn acquire_with_secret(
221        &self,
222        secret: &str,
223    ) -> Result<KeycloakSessionToken, KeycloakSessionError> {
224        let url = self.inner.url.as_ref();
225        let realm = self.inner.realm.as_ref();
226        let client_id = self.inner.client_id.as_ref();
227
228        // curl \
229        // -d "client_id=R09219E08" \
230        // -d "client_secret=wBdk1Z3GXm2YXRrtbgcEMLrVsbL8jjwn" \
231        // -d "grant_type=client_credentials" \
232        // "https://id.shapth.homenet/realms/shapth/protocol/openid-connect/token"
233        let result = error(
234            self.inner
235                .client
236                .post(format!(
237                    "{url}/realms/{realm}/protocol/openid-connect/token",
238                ))
239                .form(&serde_json::json!({
240                    "client_id": client_id,
241                    "client_secret": secret,
242                    "grant_type": "client_credentials"
243                }))
244                .send()
245                .await?,
246        )
247        .await?
248        .json::<serde_json::Value>()
249        .await?;
250        tracing::debug!(
251            "Acquire result: {}",
252            serde_json::to_string_pretty(&result).unwrap()
253        );
254        serde_json::from_value(result).map_err(|err| KeycloakSessionError::Decode(Arc::new(err)))
255    }
256
257    async fn refresh(
258        &self,
259        refresh_token: &str,
260    ) -> Result<KeycloakSessionToken, KeycloakSessionError> {
261        let url = self.inner.url.as_ref();
262        let realm = self.inner.realm.as_ref();
263        let client_id = self.inner.client_id.as_ref();
264        let result = error(
265            self.inner
266                .client
267                .post(format!(
268                    "{url}/realms/{realm}/protocol/openid-connect/token",
269                ))
270                .form(&serde_json::json!({
271                    "grant_type": "refresh_token",
272                    "refresh_token": refresh_token,
273                    "client_id": client_id,
274                }))
275                .send()
276                .await?,
277        )
278        .await?
279        .json::<serde_json::Value>()
280        .await?;
281        tracing::debug!(
282            "Refresh result: {}",
283            serde_json::to_string_pretty(&result).unwrap()
284        );
285        serde_json::from_value(result).map_err(|err| KeycloakSessionError::Decode(Arc::new(err)))
286    }
287}
288
289async fn try_refresh(
290    keycloak: &KeycloakSessionClient,
291    refresh_token: &str,
292    username: &str,
293    password: &str,
294) -> Result<KeycloakSessionToken, KeycloakSessionError> {
295    tracing::debug!("refresh session for user {username}");
296    match keycloak.refresh(refresh_token).await {
297        Ok(token) => Ok(KeycloakSessionToken::parse_access_token(token)),
298        Err(err) => {
299            if let KeycloakSessionError::HttpFailure { status, .. } = &err {
300                if *status == 400 {
301                    tracing::error!(
302                        "refresh token expired try to acquire new token with credentials"
303                    );
304                    tracing::error!("{:#?}", err);
305                    keycloak
306                        .acquire(username, password)
307                        .await
308                        .map(KeycloakSessionToken::parse_access_token)
309                } else {
310                    Err(err)
311                }
312            } else {
313                Err(err)
314            }
315        }
316    }
317}
318
319async fn try_refresh_with_secret(
320    keycloak: &KeycloakSessionClient,
321    refresh_token: &str,
322    secret: &str,
323) -> Result<KeycloakSessionToken, KeycloakSessionError> {
324    tracing::debug!("refresh session for api client");
325    match keycloak.refresh(refresh_token).await {
326        Ok(token) => Ok(KeycloakSessionToken::parse_access_token(token)),
327        Err(err) => {
328            if let KeycloakSessionError::HttpFailure { status, .. } = &err {
329                if *status == 400 {
330                    tracing::error!(
331                        "refresh token expired try to acquire new token with credentials"
332                    );
333                    tracing::error!("{:#?}", err);
334                    keycloak
335                        .acquire_with_secret(secret)
336                        .await
337                        .map(KeycloakSessionToken::parse_access_token)
338                } else {
339                    Err(err)
340                }
341            } else {
342                Err(err)
343            }
344        }
345    }
346}
347
348struct KeycloakSessionInner {
349    username: Arc<str>,
350    password: Arc<str>,
351    token: RwLock<KeycloakSessionToken>,
352    stop_tx: tokio::sync::watch::Sender<bool>,
353}
354
355#[derive(Clone)]
356/// Keycloak session for user authentication.
357pub struct KeycloakSession {
358    inner: Arc<KeycloakSessionInner>,
359}
360
361impl Drop for KeycloakSession {
362    fn drop(&mut self) {
363        self.inner.stop_tx.send(false).ok();
364    }
365}
366
367impl KeycloakSession {
368    /// Creates a new Keycloak session.
369    pub async fn new(
370        keycloak: KeycloakSessionClient,
371        username: &str,
372        password: &str,
373        refresh_enabled: bool,
374    ) -> anyhow::Result<Self> {
375        let token = keycloak
376            .acquire(username, password)
377            .await
378            .map(KeycloakSessionToken::parse_access_token)?;
379        let username: Arc<str> = Arc::from(username.to_string());
380        let password: Arc<str> = Arc::from(password.to_string());
381        let (stop_tx, stop_signal) = tokio::sync::watch::channel(true);
382        let result = KeycloakSession {
383            inner: Arc::new(KeycloakSessionInner {
384                username,
385                password,
386                token: RwLock::new(token),
387                stop_tx,
388            }),
389        };
390        if refresh_enabled {
391            let keycloak = keycloak.clone();
392            let session = result.clone();
393            std::thread::spawn(move || {
394                let rt = Builder::new_current_thread().enable_all().build().unwrap();
395                let local = LocalSet::new();
396                local.spawn_local(async move {
397                    let username = &session.inner.username;
398                    let password = &session.inner.password;
399                    loop {
400                        let (expires_in, refresh_expires_in) = async {
401                            let r = session.inner.token.read().await;
402                            (r.expires_in, r.refresh_expires_in)
403                        }
404                        .await;
405                        tracing::debug!("{expires_in} -> {refresh_expires_in:#?}");
406                        let refresh_future = async {
407                            tokio::time::sleep(Duration::from_secs(
408                                expires_in
409                                    .checked_sub(30)
410                                    .ok_or(anyhow::anyhow!("unable to calculate refresh timeout"))?
411                                    as u64,
412                            ))
413                            .await;
414                            let next_token = async {
415                                try_refresh(
416                                    &keycloak,
417                                    &session.inner.token.read().await.refresh_token,
418                                    username,
419                                    password,
420                                )
421                                .await
422                            }
423                            .await;
424                            match next_token {
425                                Ok(next_token) => {
426                                    *session.inner.token.write().await = next_token;
427                                }
428                                Err(err) => {
429                                    tracing::error!("{err:#?}");
430                                    std::process::exit(1)
431                                }
432                            }
433                            anyhow::Ok(true)
434                        };
435                        let stop_future = async {
436                            let mut stop_signal = stop_signal.clone();
437                            stop_signal.changed().await?;
438                            let result = *stop_signal.borrow_and_update();
439                            anyhow::Ok(result)
440                        };
441                        tokio::select! {
442                            result = refresh_future => {
443                                match result {
444                                    Ok(_) => {},
445                                    Err(_) => {
446                                        tracing::debug!("acquire new session");
447                                        match keycloak
448                                            .acquire(username, password)
449                                            .await
450                                            .map(KeycloakSessionToken::parse_access_token) {
451                                            Ok(next_token) => {
452                                                *session.inner.token.write().await = next_token;
453                                            },
454                                            Err(err) => {
455                                                tracing::error!("{err:#?}");
456                                                std::process::exit(1)
457                                            }
458                                        }
459                                    }
460                                }
461                            }
462                            is_logged_in = stop_future => {
463                                if !is_logged_in.unwrap_or(false) {
464                                    break
465                                }
466                            }
467                        }
468                    }
469                    tracing::debug!("session ends for user {username}");
470                    anyhow::Ok(())
471                });
472                rt.block_on(local);
473            });
474        }
475        Ok(result)
476    }
477
478    /// Stops the session.
479    pub fn stop(&self) -> anyhow::Result<()> {
480        tracing::debug!("stop session for {}", self.inner.username);
481        self.inner.stop_tx.send(false)?;
482        Ok(())
483    }
484
485    /// Gets the access token.
486    pub async fn access_token(&self) -> Arc<str> {
487        self.inner.token.read().await.access_token.clone()
488    }
489
490    /// Gets the token.
491    pub async fn token(&self) -> Arc<str> {
492        self.inner
493            .token
494            .read()
495            .await
496            .client_token
497            .as_ref()
498            .unwrap()
499            .clone()
500    }
501}
502
503#[async_trait::async_trait]
504impl KeycloakTokenSupplier for KeycloakSession {
505    async fn get(&self, _url: &str) -> Result<String, KeycloakError> {
506        Ok(self.inner.token.read().await.access_token.to_string())
507    }
508}
509
510struct KeycloakApiClientSessionInner {
511    secret: Arc<str>,
512    token: RwLock<KeycloakSessionToken>,
513    stop_tx: tokio::sync::watch::Sender<bool>,
514}
515
516#[derive(Clone)]
517/// Keycloak API client session for service accounts.
518pub struct KeycloakApiClientSession {
519    inner: Arc<KeycloakApiClientSessionInner>,
520}
521
522impl Drop for KeycloakApiClientSession {
523    fn drop(&mut self) {
524        self.inner.stop_tx.send(false).ok();
525    }
526}
527
528impl KeycloakApiClientSession {
529    /// Creates a new KeycloakApiClientSession.
530    pub async fn new(
531        keycloak: KeycloakSessionClient,
532        secret: &str,
533        refresh_enabled: bool,
534    ) -> anyhow::Result<Self> {
535        let token = keycloak
536            .acquire_with_secret(secret)
537            .await
538            .map(KeycloakSessionToken::parse_access_token)?;
539        let secret: Arc<str> = Arc::from(secret.to_string());
540        let (stop_tx, stop_signal) = tokio::sync::watch::channel(true);
541        let result = KeycloakApiClientSession {
542            inner: Arc::new(KeycloakApiClientSessionInner {
543                secret,
544                token: RwLock::new(token),
545                stop_tx,
546            }),
547        };
548        if refresh_enabled {
549            let keycloak = keycloak.clone();
550            let session = result.clone();
551            std::thread::spawn(move || {
552                let rt = Builder::new_current_thread().enable_all().build().unwrap();
553                let local = LocalSet::new();
554                local.spawn_local(async move {
555                    let secret = &session.inner.secret;
556                    loop {
557                        let expires_in = session.inner.token.read().await.expires_in;
558                        let refresh_future = async {
559                            tokio::time::sleep(Duration::from_secs(
560                                expires_in
561                                    .checked_sub(30)
562                                    .ok_or(anyhow::anyhow!("unable to calculate refresh timeout"))?
563                                    as u64,
564                            ))
565                            .await;
566                            let next_token = async {
567                                try_refresh_with_secret(
568                                    &keycloak,
569                                    &session.inner.token.read().await.refresh_token,
570                                    secret,
571                                )
572                                .await
573                            }
574                            .await;
575                            match next_token {
576                                Ok(next_token) => {
577                                    *session.inner.token.write().await = next_token;
578                                }
579                                Err(err) => {
580                                    tracing::error!("{err:#?}");
581                                    std::process::exit(1)
582                                }
583                            }
584                            anyhow::Ok(true)
585                        };
586                        let stop_future = async {
587                            let mut stop_signal = stop_signal.clone();
588                            stop_signal.changed().await?;
589                            let result = *stop_signal.borrow_and_update();
590                            anyhow::Ok(result)
591                        };
592                        tokio::select! {
593                            result = refresh_future => {
594                                match result {
595                                    Ok(_) => {},
596                                    Err(_) => {
597                                        tracing::debug!("acquire new session");
598                                        match keycloak
599                                            .acquire_with_secret(secret)
600                                            .await
601                                            .map(KeycloakSessionToken::parse_access_token) {
602                                            Ok(next_token) => {
603                                                *session.inner.token.write().await = next_token;
604                                            },
605                                            Err(err) => {
606                                                tracing::error!("{err:#?}");
607                                                std::process::exit(1)
608                                            }
609                                        }
610                                    }
611                                }
612                            }
613                            is_logged_in = stop_future => {
614                                if !is_logged_in.unwrap_or(false) {
615                                    break
616                                }
617                            }
618                        }
619                    }
620                    tracing::debug!("session ends for api client");
621                    anyhow::Ok(())
622                });
623                rt.block_on(local);
624            });
625        }
626        Ok(result)
627    }
628
629    /// Stops the session.
630    pub fn stop(&self) -> anyhow::Result<()> {
631        tracing::debug!("stop session for {}", self.inner.secret);
632        self.inner.stop_tx.send(false)?;
633        Ok(())
634    }
635
636    /// Gets the access token.
637    pub async fn access_token(&self) -> Arc<str> {
638        self.inner.token.read().await.access_token.clone()
639    }
640
641    /// Gets the token.
642    pub async fn token(&self) -> Arc<str> {
643        self.inner
644            .token
645            .read()
646            .await
647            .client_token
648            .as_ref()
649            .unwrap()
650            .clone()
651    }
652}
653
654#[async_trait::async_trait]
655impl KeycloakTokenSupplier for KeycloakApiClientSession {
656    async fn get(&self, _url: &str) -> Result<String, KeycloakError> {
657        Ok(self.inner.token.read().await.access_token.to_string())
658    }
659}