cipherstash-client 0.34.1-alpha.3

The official CipherStash SDK
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
pub mod user_token;

mod auth0;
mod okta;

use std::{
    path::Path,
    time::{Duration, SystemTime, UNIX_EPOCH},
};

use async_mutex::Mutex as AsyncMutex;
use async_trait::async_trait;
use auth0::Auth0UserCredentials;
use miette::Diagnostic;
use okta::OktaUserCredentials;
use serde::Deserialize;
use thiserror::Error;
use tracing::{debug, error};
use url::Url;

use crate::{
    config::idp_provider::IdpProvider,
    credentials::{
        token_store::TokenStore, AutoRefreshable, ClearTokenError, Credentials, GetTokenError,
        TokenExpiry,
    },
};

pub use user_token::UserToken;

// The `offline_access` scope is used for requesting a refresh token.
// The `cipherstash:admin` scope is used by self-hosted CTS to allow access to
// management endpoints.
pub const DEFAULT_REQUESTED_SCOPES: &str = "offline_access cipherstash:admin";

/// Polling details deserialized from an identity provider response.
///
/// Within this module, [`prompt_user`] reads `verification_uri_complete` and
/// `user_code` to guide the user through authentication.
#[derive(Deserialize)]
pub(crate) struct PollingInfo {
    /// Code printed in the terminal so the user can compare it with the one
    /// shown in their browser.
    pub user_code: String,
    // NOTE(review): not read in this file; presumably sent back to the IdP by the
    // auth0/okta modules while polling for the token — confirm there.
    pub device_code: String,
    /// Link that is opened in the browser (and printed) to start authentication.
    pub verification_uri_complete: String,
}

/// Token payload deserialized from an identity provider response.
///
/// Converted into a [`UserToken`] through the `From` implementation in this module.
#[derive(Deserialize)]
pub(crate) struct AccessTokenResponse {
    /// Refresh token, copied verbatim into the resulting [`UserToken`].
    pub refresh_token: String,
    /// Access token, copied verbatim into the resulting [`UserToken`].
    pub access_token: String,
    /// Token lifetime relative to the moment of conversion; it is added to the
    /// current Unix time (in seconds) to produce the token's `expiry`.
    pub expires_in: u64,
}

impl From<AccessTokenResponse> for UserToken {
    /// Converts a token response into a [`UserToken`].
    ///
    /// The access and refresh tokens are moved over unchanged. The relative
    /// `expires_in` lifetime is turned into an absolute `expiry` by adding it to
    /// the current Unix time in seconds.
    fn from(value: AccessTokenResponse) -> Self {
        Self {
            access_token: value.access_token,
            refresh_token: value.refresh_token,
            // `expires_in` comes from a remote response, so an absurdly large
            // value must not overflow: a plain `+` would panic in debug builds
            // and wrap to an already-expired timestamp in release builds.
            expiry: now_secs().saturating_add(value.expires_in),
        }
    }
}

/// Returns the current system time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
pub fn now_secs() -> u64 {
    let now = SystemTime::now();
    let since_epoch = now
        .duration_since(UNIX_EPOCH)
        .expect("Expected system time to be greater than UNIX_EPOCH");

    since_epoch.as_secs()
}

/// Show a prompt in the terminal and open a browser (if available) with an authentication link for generating
/// an access token.
///
/// The link in `polling_info.verification_uri_complete` is opened with the
/// system browser; if that fails, a notice is printed and the user is expected
/// to follow the link printed below it. The user code is then shown inside an
/// ASCII frame so it can be compared with the code displayed in the browser.
pub(crate) fn prompt_user(polling_info: &PollingInfo) {
    if open::that(&polling_info.verification_uri_complete).is_err() {
        println!(
            "Failed to open web browser. Please manually click the link in the following message."
        );
    }

    let user_code = &polling_info.user_code;
    // Size the frame by characters, not bytes: `str::len` counts UTF-8 bytes and
    // would draw a frame that is too wide for any non-ASCII code.
    let code_width = user_code.chars().count();
    let border = "-".repeat(code_width);
    let blank = " ".repeat(code_width);

    println!();
    println!("### ACTION REQUIRED ###");
    println!();
    println!(
        "Visit {} to complete authentication by following the below steps:",
        polling_info.verification_uri_complete
    );
    println!();
    println!("1. Verify that this code matches the code in your browser");
    println!();
    println!("             +------{}------+", border);
    println!("             |      {}      |", blank);
    println!("             |      {user_code}      |");
    println!("             |      {}      |", blank);
    println!("             +------{}------+", border);
    println!();
    println!("2. If the codes match, click on the confirm button in the browser");
    println!();
    println!("Waiting for authentication...");
}

/// Errors reported while exchanging a refresh token for a new access token.
///
/// Each variant wraps the underlying `reqwest::Error`, which is included in the
/// rendered message.
#[derive(Diagnostic, Error, Debug)]
pub enum RefreshTokenError {
    /// The request to redeem the refresh token failed.
    #[error("Failed to redeem refresh token: {0}")]
    RequestFailed(reqwest::Error),

    /// The response could not be parsed as JSON.
    #[error("Failed to parse json response: {0}")]
    BadResponse(reqwest::Error),
}

/// Errors reported while acquiring a brand-new token via device code authentication.
///
/// Variants prefixed with `DeviceCode` relate to obtaining the device code;
/// variants prefixed with `PollToken` relate to polling for the resulting token.
#[derive(Diagnostic, Error, Debug)]
pub enum NewTokenError {
    /// A URL could not be parsed; converted automatically from `url::ParseError`.
    #[error("Failed to parse Url: {0}")]
    UrlParse(#[from] url::ParseError),
    /// The request for a device code failed.
    #[error("Failed to get device code: {0}")]
    DeviceCodeRequestFailed(reqwest::Error),

    /// The device code response could not be parsed into polling info.
    #[error("Failed to parse polling info json response: {0}")]
    DeviceCodeBadResponse(reqwest::Error),

    /// A polling request for the new token failed.
    #[error("Failed to poll for new token: {0}")]
    PollTokenRequestFailed(reqwest::Error),

    /// The access token response could not be parsed.
    #[error("Failed to parse access token response: {0}")]
    PollTokenBadResponse(reqwest::Error),

    /// The "authentication pending" response could not be parsed.
    #[error("Failed to parse pending auth response: {0}")]
    PollTokenBadPendingResponse(reqwest::Error),

    /// Device code authentication failed; the payload carries the reason text.
    #[error("Device code authentication failed: {0}")]
    PollTokenAuthFailed(String),

    /// The response body carried an error code that was not expected; the
    /// payload carries that code.
    #[error("Unexpected error code in response body: {0}")]
    PollTokenUnexpected(String),
}

/// Credentials for an interactive user, backed by a persisted token store and
/// an identity-provider-specific client.
pub struct UserCredentials {
    /// Token store guarded by an async mutex; the lock is held across refresh
    /// requests so that only one caller exchanges the refresh token at a time.
    token_store: AsyncMutex<TokenStore<UserToken>>,
    /// Identity provider client used to refresh or acquire tokens.
    provider: UserCredentialsProvider,
}

/// Identity-provider-specific credential client selected at construction time
/// from the configured `IdpProvider`.
enum UserCredentialsProvider {
    /// Client for the Auth0 identity provider.
    Auth0(Auth0UserCredentials),
    /// Client for the Okta identity provider.
    Okta(OktaUserCredentials),
}

impl UserCredentials {
    /// Creates user credentials for the given identity provider.
    ///
    /// * `idp_token_path` - location handed to the token store.
    /// * `idp_base_url` - base URL of the identity provider.
    /// * `idp_audience` - audience; only passed on to the Auth0 client.
    /// * `idp_client_id` - client id passed to the provider client.
    /// * `idp_provider` - selects which provider client is built.
    pub fn new(
        idp_token_path: &Path,
        idp_base_url: &Url,
        idp_audience: &str,
        idp_client_id: &str,
        idp_provider: IdpProvider,
    ) -> Self {
        let provider = match idp_provider {
            IdpProvider::Okta => {
                let okta_creds = OktaUserCredentials::new(idp_base_url, idp_client_id);
                UserCredentialsProvider::Okta(okta_creds)
            }

            IdpProvider::Auth0 => {
                let auth0_creds =
                    Auth0UserCredentials::new(idp_base_url, idp_audience, idp_client_id);
                UserCredentialsProvider::Auth0(auth0_creds)
            }
        };
        let token_store = AsyncMutex::new(TokenStore::new(idp_token_path));

        Self {
            token_store,
            provider,
        }
    }

    /// Authenticate the user interactively to get a new token.
    /// This starts the device code flow and waits for the user to authenticate.
    ///
    /// # Errors
    ///
    /// Returns `GetTokenError::AcquireNewTokenFailed` when the provider cannot
    /// issue a token, and `GetTokenError::PersistTokenError` when the new token
    /// cannot be written to the token store.
    // TODO: This code should probably be in CLI itself so we can use the UI to prompt the user
    pub async fn authenticate_interactively(&self) -> Result<UserToken, GetTokenError> {
        let token = match self.provider.acquire_new_token().await {
            Ok(token) => token,
            Err(err) => return Err(GetTokenError::AcquireNewTokenFailed(Box::new(err))),
        };

        // Persist the freshly issued token (saved to disk by the store).
        let mut store = self.token_store.lock().await;
        if let Err(err) = store.set(&token) {
            return Err(GetTokenError::PersistTokenError(Box::new(err)));
        }

        Ok(token)
    }
}

impl UserCredentialsProvider {
    async fn refresh_access_token(
        &self,
        cached_token: &UserToken,
    ) -> Result<Option<UserToken>, RefreshTokenError> {
        match self {
            Self::Auth0(creds) => creds.refresh_access_token(cached_token).await,
            Self::Okta(creds) => creds.refresh_access_token(cached_token).await,
        }
    }

    async fn acquire_new_token(&self) -> Result<UserToken, NewTokenError> {
        match self {
            Self::Auth0(creds) => creds.acquire_new_token().await,
            Self::Okta(creds) => creds.acquire_new_token().await,
        }
    }
}

#[async_trait]
impl Credentials for UserCredentials {
    type Token = UserToken;

    /// Returns a usable token, refreshing the stored one when it has expired.
    ///
    /// # Errors
    ///
    /// * `GetTokenError::RefreshTokenFailed` - the provider rejected or failed the refresh.
    /// * `GetTokenError::PersistTokenError` - the refreshed token could not be stored.
    /// * `GetTokenError::MissingOrExpired` - no token is stored, or the refresh
    ///   produced no replacement.
    async fn get_token(&self) -> Result<Self::Token, GetTokenError> {
        // Fast path: take the lock only long enough to read the stored token, so
        // concurrent callers with a valid token return without queueing.
        let snapshot = self.token_store.lock().await.get();
        match snapshot {
            Some(token) if !token.is_expired() => return Ok(token),
            _ => {}
        }

        // Slow path: hold the lock for the whole refresh. This serializes the
        // refresh-token exchange so that a single-use refresh token is not
        // consumed by several callers at once (Auth0/Okta rotate refresh tokens,
        // so only one exchange succeeds).
        let mut store = self.token_store.lock().await;

        // Re-read under the lock: another caller may have refreshed while we waited.
        let stale = match store.get() {
            None => return Err(GetTokenError::MissingOrExpired),
            Some(token) if !token.is_expired() => return Ok(token),
            Some(token) => token,
        };

        // Still expired, and we own the lock, so we are the only refresher.
        let refreshed = match self.provider.refresh_access_token(&stale).await {
            Ok(outcome) => outcome,
            Err(e) => {
                error!("Failed to refresh token: {}", e);
                return Err(GetTokenError::RefreshTokenFailed(Box::new(e)));
            }
        };

        match refreshed {
            Some(fresh) => match store.set(&fresh) {
                Ok(_) => Ok(fresh),
                Err(e) => Err(GetTokenError::PersistTokenError(Box::new(e))),
            },
            None => Err(GetTokenError::MissingOrExpired),
        }
    }

    /// Clears the stored token, wrapping any store failure in `ClearTokenError`.
    async fn clear_token(&self) -> Result<(), ClearTokenError> {
        let mut store = self.token_store.lock().await;
        match store.clear() {
            Ok(()) => Ok(()),
            Err(e) => Err(ClearTokenError(Box::new(e))),
        }
    }
}

#[async_trait]
impl AutoRefreshable for UserCredentials {
    /// Refreshes the stored token when it is due and returns how long to wait
    /// before the next refresh attempt.
    ///
    /// Failures are logged rather than returned; in every case where no fresh
    /// token was stored, the minimum refresh interval is returned.
    async fn refresh(&self) -> Duration {
        // Fast path: take the lock only long enough to see whether a refresh is due.
        let snapshot = self.token_store.lock().await.get();
        if let Some(token) = snapshot.as_ref() {
            if !token.should_refresh() {
                debug!(target: "console_credentials", "Access token is still new");
                return token.refresh_interval();
            }
        }

        // Slow path: keep the lock for the whole refresh so this serializes with
        // get_token(). Auth0/Okta rotate refresh tokens (single-use), so
        // concurrent exchanges between the background loop and foreground
        // callers cause auth failures.
        let mut store = self.token_store.lock().await;

        // Re-read under the lock: get_token() or an earlier refresh() may have
        // already done the work.
        let stale = match store.get() {
            Some(token) => token,
            None => return Self::Token::min_refresh_interval(),
        };

        if !stale.should_refresh() {
            debug!(target: "console_credentials", "Access token already refreshed by another caller");
            return stale.refresh_interval();
        }

        debug!(target: "console_credentials", "Access token close to expiry, refreshing");
        match self.provider.refresh_access_token(&stale).await {
            Ok(Some(fresh)) => match store.set(&fresh) {
                Ok(_) => {
                    debug!(target: "console_credentials", "Access token refreshed and saved to disk");
                    return fresh.refresh_interval();
                }
                Err(err) => {
                    tracing::warn!(
                        target: "console_credentials",
                        error = %err,
                        "Failed to persist refreshed token"
                    );
                }
            },
            Ok(None) => {
                tracing::warn!(
                    target: "console_credentials",
                    "Token refresh returned no new token"
                );
            }
            Err(err) => {
                tracing::warn!(
                    target: "console_credentials",
                    error = %err,
                    "Failed to refresh user token"
                );
            }
        }

        Self::Token::min_refresh_interval()
    }
}

/// Tests covering how `UserCredentials` behaves when several callers need a
/// refreshed token at the same time.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::idp_provider::IdpProvider;
    use crate::credentials::test_utils::CountingState;
    use std::sync::Arc;

    /// Simulates Auth0 /oauth/token refresh endpoint with latency
    ///
    /// `state.enter()` / `state.exit()` bracket the artificial delay so the
    /// test can read back the peak and total counts afterwards.
    async fn slow_refresh(
        axum::extract::State(state): axum::extract::State<CountingState>,
    ) -> axum::Json<serde_json::Value> {
        state.enter();
        // Keep the request "in flight" long enough for overlapping calls to be observable.
        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        state.exit();
        // AccessTokenResponse format (snake_case, not camelCase)
        axum::Json(serde_json::json!({
            "refresh_token": "new-refresh",
            "access_token": "new-access",
            "expires_in": 3600u64
        }))
    }

    /// Refresh-token exchange must be serialized because Auth0/Okta rotate
    /// refresh tokens (single-use). Only one caller should hit the IdP;
    /// others wait and pick up the refreshed token via double-check.
    #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
    async fn test_user_credentials_serializes_refresh_token_exchange() {
        // `state` is moved into the router; `stats` is the clone read by the assertions.
        let state = CountingState::new();
        let stats = state.clone();

        let app = axum::Router::new()
            .route("/oauth/token", axum::routing::post(slow_refresh))
            .with_state(state);

        // Bind to port 0 so the OS picks a free port for the mock IdP.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });

        let tmp = tempfile::TempDir::new().unwrap();
        let token_path = tmp.path().join("idp_token.json");
        let base_url = Url::parse(&format!("http://{addr}")).unwrap();

        // Seed an expired token on disk so get_token() enters the refresh path.
        let expired_token = UserToken::new_from_raw("test-refresh", "expired-access", 0);
        std::fs::write(&token_path, serde_json::to_string(&expired_token).unwrap()).unwrap();

        let creds = Arc::new(UserCredentials::new(
            &token_path,
            &base_url,
            "test-audience",
            "test-client-id",
            IdpProvider::Auth0,
        ));

        // Launch five concurrent get_token() calls against the same credentials.
        let mut handles = vec![];
        for _ in 0..5 {
            let creds = creds.clone();
            handles.push(tokio::spawn(
                async move { creds.get_token().await.unwrap() },
            ));
        }

        // Every task must complete without panicking (each unwraps its token).
        for h in handles {
            h.await.unwrap();
        }

        let peak = stats.peak();
        let total = stats.total();
        assert_eq!(
            peak, 1,
            "Expected serialized refresh but peak concurrency was {peak}. \
             Concurrent refresh-token exchange can cause auth failures \
             with IdPs that rotate refresh tokens.",
        );
        assert_eq!(
            total, 1,
            "Expected exactly 1 refresh request but got {total}. \
             Double-check pattern should let waiters use the refreshed token.",
        );
    }
}