Skip to main content

fluss/client/
credentials.rs

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

18use crate::client::metadata::Metadata;
19use crate::error::{Error, Result};
20use crate::rpc::RpcClient;
21use crate::rpc::message::GetSecurityTokenRequest;
22use log::{debug, info, warn};
23use parking_lot::RwLock;
24use serde::Deserialize;
25use std::collections::HashMap;
26use std::sync::Arc;
27use std::time::{Duration, SystemTime, UNIX_EPOCH};
28use tokio::sync::{oneshot, watch};
29use tokio::task::JoinHandle;
30
/// Number of seconds in one day; the day-granularity limits below are built from it.
const SECONDS_PER_DAY: u64 = 24 * 60 * 60;

/// Fraction of the token lifetime after which a refresh is scheduled (0.8 = 80%).
const DEFAULT_TOKEN_RENEWAL_RATIO: f64 = 0.8;
/// Back-off applied before retrying after a failed token fetch.
const DEFAULT_RENEWAL_RETRY_BACKOFF: Duration = Duration::from_secs(30);
/// Shortest delay ever scheduled between two refreshes.
const MIN_RENEWAL_DELAY: Duration = Duration::from_secs(1);
/// Longest delay ever scheduled between two refreshes (7 days); bounding it
/// prevents overflow in the delay arithmetic and guarantees a periodic refresh.
const MAX_RENEWAL_DELAY: Duration = Duration::from_secs(7 * SECONDS_PER_DAY);
/// Refresh interval used for tokens that carry no expiration time (7 days).
const DEFAULT_NON_EXPIRING_REFRESH_INTERVAL: Duration = Duration::from_secs(7 * SECONDS_PER_DAY);
41
/// Receiving half of the `watch` channel over which credential properties are published.
///
/// The observed value distinguishes two states:
/// - `None` = not yet fetched, should wait
/// - `Some(HashMap)` = fetched (may be empty if no auth needed)
pub type CredentialsReceiver = watch::Receiver<Option<HashMap<String, String>>>;
46
47#[derive(Debug, Deserialize)]
48struct Credentials {
49    access_key_id: String,
50    access_key_secret: String,
51    security_token: Option<String>,
52}
53
/// Translates a Hadoop-style filesystem configuration key into its OpenDAL counterpart.
///
/// Returns `Some((opendal_key, needs_inversion))` for recognised keys and `None` for
/// keys without a mapping. `needs_inversion` is `true` when the boolean value has to be
/// flipped (path_style_access -> enable_virtual_host_style).
fn convert_hadoop_key_to_opendal(hadoop_key: &str) -> Option<(String, bool)> {
    let (opendal_key, needs_inversion) = match hadoop_key {
        // Endpoint and region carry over unchanged for both S3 and OSS.
        "fs.s3a.endpoint" | "fs.oss.endpoint" => ("endpoint", false),
        "fs.s3a.endpoint.region" | "fs.oss.region" => ("region", false),
        // Path-style access is the logical opposite of virtual-host style.
        "fs.s3a.path.style.access" => ("enable_virtual_host_style", true),
        // `fs.s3a.connection.ssl.enabled` has no mapping and falls through together
        // with every unrecognised key.
        _ => return None,
    };
    Some((opendal_key.to_string(), needs_inversion))
}
69
70/// Build remote filesystem props from credentials and additional info
71fn build_remote_fs_props(
72    credentials: &Credentials,
73    addition_infos: &HashMap<String, String>,
74) -> HashMap<String, String> {
75    let mut props = HashMap::new();
76
77    props.insert(
78        "access_key_id".to_string(),
79        credentials.access_key_id.clone(),
80    );
81
82    // S3 specific configurations
83    props.insert(
84        "secret_access_key".to_string(),
85        credentials.access_key_secret.clone(),
86    );
87
88    // OSS specific configurations, todo: consider refactor it
89    // to handle different conversion for different scheme in different method
90    props.insert(
91        "access_key_secret".to_string(),
92        credentials.access_key_secret.clone(),
93    );
94
95    if let Some(token) = &credentials.security_token {
96        props.insert("security_token".to_string(), token.clone());
97    }
98
99    for (key, value) in addition_infos {
100        if let Some((opendal_key, transform)) = convert_hadoop_key_to_opendal(key) {
101            let final_value = if transform {
102                // Invert boolean value (path_style_access -> enable_virtual_host_style)
103                if value == "true" {
104                    "false".to_string()
105                } else {
106                    "true".to_string()
107                }
108            } else {
109                value.clone()
110            };
111            props.insert(opendal_key, final_value);
112        }
113    }
114
115    props
116}
117
/// Manager for security tokens that refreshes tokens in a background task.
///
/// This follows the pattern from Java's `DefaultSecurityTokenManager`, where
/// a background thread periodically refreshes tokens based on their expiration time.
///
/// Uses a `tokio::sync::watch` channel to broadcast token updates to consumers.
/// Consumers can subscribe by calling `subscribe()` to get a receiver.
///
/// The channel value is `Option<HashMap>`:
/// - `None` = not yet fetched, consumers should wait
/// - `Some(HashMap)` = fetched (may be empty if no auth needed)
///
/// The refresh task is launched by `start()` and signalled to stop by `stop()`;
/// dropping the manager calls `stop()` as well.
///
/// # Example
/// ```ignore
/// let manager = SecurityTokenManager::new(rpc_client, metadata);
/// let credentials_rx = manager.subscribe();
/// manager.start();
///
/// // Consumer can get latest credentials via:
/// let props = credentials_rx.borrow().clone();
/// ```
pub struct SecurityTokenManager {
    /// RPC client used to open the connection over which tokens are requested
    rpc_client: Arc<RpcClient>,
    /// Cluster metadata used to pick an available server for the token request
    metadata: Arc<Metadata>,
    /// Fraction of the remaining token lifetime to wait before renewing
    token_renewal_ratio: f64,
    /// Delay before retrying after a failed token fetch
    renewal_retry_backoff: Duration,
    /// Watch channel sender for broadcasting token updates
    credentials_tx: watch::Sender<Option<HashMap<String, String>>>,
    /// Watch channel receiver (kept to allow cloning for new subscribers)
    credentials_rx: watch::Receiver<Option<HashMap<String, String>>>,
    /// Handle to the background refresh task (`None` while no task is registered)
    task_handle: RwLock<Option<JoinHandle<()>>>,
    /// Sender to signal shutdown (`None` until started, and again after `stop()`)
    shutdown_tx: RwLock<Option<oneshot::Sender<()>>>,
}
153
154impl SecurityTokenManager {
155    pub fn new(rpc_client: Arc<RpcClient>, metadata: Arc<Metadata>) -> Self {
156        let (credentials_tx, credentials_rx) = watch::channel(None);
157        Self {
158            rpc_client,
159            metadata,
160            token_renewal_ratio: DEFAULT_TOKEN_RENEWAL_RATIO,
161            renewal_retry_backoff: DEFAULT_RENEWAL_RETRY_BACKOFF,
162            credentials_tx,
163            credentials_rx,
164            task_handle: RwLock::new(None),
165            shutdown_tx: RwLock::new(None),
166        }
167    }
168
169    /// Subscribe to credential updates.
170    /// Returns a receiver that always contains the latest credentials.
171    /// Consumers can call `receiver.borrow()` to get the current value.
172    pub fn subscribe(&self) -> CredentialsReceiver {
173        self.credentials_rx.clone()
174    }
175
176    /// Start the background token refresh task.
177    /// This should be called once after creating the manager.
178    pub fn start(&self) {
179        if self.task_handle.read().is_some() {
180            warn!("SecurityTokenManager is already started");
181            return;
182        }
183
184        let (shutdown_tx, shutdown_rx) = oneshot::channel();
185        *self.shutdown_tx.write() = Some(shutdown_tx);
186
187        let rpc_client = Arc::clone(&self.rpc_client);
188        let metadata = Arc::clone(&self.metadata);
189        let token_renewal_ratio = self.token_renewal_ratio;
190        let renewal_retry_backoff = self.renewal_retry_backoff;
191        let credentials_tx = self.credentials_tx.clone();
192
193        let handle = tokio::spawn(async move {
194            Self::token_refresh_loop(
195                rpc_client,
196                metadata,
197                token_renewal_ratio,
198                renewal_retry_backoff,
199                credentials_tx,
200                shutdown_rx,
201            )
202            .await;
203        });
204
205        *self.task_handle.write() = Some(handle);
206        info!("SecurityTokenManager started");
207    }
208
209    /// Stop the background token refresh task.
210    pub fn stop(&self) {
211        if let Some(tx) = self.shutdown_tx.write().take() {
212            let _ = tx.send(());
213        }
214        // Take and drop the task handle so the task can finish gracefully
215        let _ = self.task_handle.write().take();
216        info!("SecurityTokenManager stopped");
217    }
218
219    /// Background task that periodically refreshes tokens.
220    async fn token_refresh_loop(
221        rpc_client: Arc<RpcClient>,
222        metadata: Arc<Metadata>,
223        token_renewal_ratio: f64,
224        renewal_retry_backoff: Duration,
225        credentials_tx: watch::Sender<Option<HashMap<String, String>>>,
226        mut shutdown_rx: oneshot::Receiver<()>,
227    ) {
228        info!("Starting token refresh loop");
229
230        loop {
231            // Fetch token and send to channel
232            let result = Self::fetch_token(&rpc_client, &metadata).await;
233
234            let next_delay = match result {
235                Ok((props, expiration_time)) => {
236                    // Send credentials via watch channel (Some indicates fetched)
237                    if let Err(e) = credentials_tx.send(Some(props)) {
238                        debug!("No active subscribers for credentials update: {e:?}");
239                    }
240
241                    // Calculate next renewal delay based on expiration time
242                    if let Some(exp_time) = expiration_time {
243                        Self::calculate_renewal_delay(exp_time, token_renewal_ratio)
244                    } else {
245                        // No expiration time - token never expires, use long refresh interval
246                        info!(
247                            "Token has no expiration time (never expires), next refresh in {DEFAULT_NON_EXPIRING_REFRESH_INTERVAL:?}"
248                        );
249                        DEFAULT_NON_EXPIRING_REFRESH_INTERVAL
250                    }
251                }
252                Err(e) => {
253                    warn!(
254                        "Failed to obtain security token: {e:?}, will retry in {renewal_retry_backoff:?}"
255                    );
256                    renewal_retry_backoff
257                }
258            };
259
260            debug!("Next token refresh in {next_delay:?}");
261
262            // Wait for either the delay to elapse or shutdown signal
263            tokio::select! {
264                _ = tokio::time::sleep(next_delay) => {
265                    // Continue to next iteration to refresh
266                }
267                _ = &mut shutdown_rx => {
268                     info!("Token refresh loop received shutdown signal");
269                    break;
270                }
271            }
272        }
273    }
274
275    /// Fetch token from server.
276    /// Returns the props and expiration time if available.
277    async fn fetch_token(
278        rpc_client: &Arc<RpcClient>,
279        metadata: &Arc<Metadata>,
280    ) -> Result<(HashMap<String, String>, Option<i64>)> {
281        let cluster = metadata.get_cluster();
282        let server_node =
283            cluster
284                .get_one_available_server()
285                .ok_or_else(|| Error::UnexpectedError {
286                    message: "No tablet server available for token refresh".to_string(),
287                    source: None,
288                })?;
289
290        let conn = rpc_client.get_connection(server_node).await?;
291        let request = GetSecurityTokenRequest::new();
292        let response = conn.request(request).await?;
293
294        // The token may be empty if remote filesystem doesn't require authentication
295        if response.token.is_empty() {
296            info!("Empty token received, remote filesystem may not require authentication");
297            return Ok((HashMap::new(), response.expiration_time));
298        }
299
300        let credentials: Credentials =
301            serde_json::from_slice(&response.token).map_err(|e| Error::JsonSerdeError {
302                message: format!("Error when parsing token from server: {e}"),
303            })?;
304
305        let mut addition_infos = HashMap::new();
306        for kv in &response.addition_info {
307            addition_infos.insert(kv.key.clone(), kv.value.clone());
308        }
309
310        let props = build_remote_fs_props(&credentials, &addition_infos);
311        debug!("Security token fetched successfully");
312
313        Ok((props, response.expiration_time))
314    }
315
316    /// Calculate the delay before next token renewal.
317    /// Uses the renewal ratio to refresh before actual expiration.
318    /// Caps the delay to MAX_RENEWAL_DELAY to prevent overflow and ensure periodic refresh.
319    fn calculate_renewal_delay(expiration_time: i64, renewal_ratio: f64) -> Duration {
320        let now = SystemTime::now()
321            .duration_since(UNIX_EPOCH)
322            .unwrap()
323            .as_millis() as i64;
324
325        let time_until_expiry = expiration_time - now;
326        if time_until_expiry <= 0 {
327            // Token already expired, refresh immediately
328            return MIN_RENEWAL_DELAY;
329        }
330
331        // Cap time_until_expiry to prevent overflow when casting to f64 and back
332        let max_delay_ms = MAX_RENEWAL_DELAY.as_millis() as i64;
333        let capped_time = time_until_expiry.min(max_delay_ms);
334
335        let delay_ms = (capped_time as f64 * renewal_ratio) as u64;
336        let delay = Duration::from_millis(delay_ms);
337
338        debug!(
339            "Calculated renewal delay: {delay:?} (expiration: {expiration_time}, now: {now}, ratio: {renewal_ratio})"
340        );
341
342        delay.clamp(MIN_RENEWAL_DELAY, MAX_RENEWAL_DELAY)
343    }
344}
345
346impl Drop for SecurityTokenManager {
347    fn drop(&mut self) {
348        self.stop();
349    }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time in milliseconds since the Unix epoch.
    fn now_millis() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is after the Unix epoch")
            .as_millis() as i64
    }

    #[test]
    fn convert_hadoop_key_to_opendal_maps_known_keys() {
        // (hadoop key, expected OpenDAL key, value must be inverted)
        let mapped = [
            ("fs.s3a.endpoint", "endpoint", false),
            ("fs.s3a.endpoint.region", "region", false),
            ("fs.s3a.path.style.access", "enable_virtual_host_style", true),
            ("fs.oss.endpoint", "endpoint", false),
            ("fs.oss.region", "region", false),
        ];
        for (hadoop_key, expected_key, expected_invert) in mapped {
            let (key, invert) = convert_hadoop_key_to_opendal(hadoop_key).expect("key");
            assert_eq!(key, expected_key, "mapping for {hadoop_key}");
            assert_eq!(invert, expected_invert, "inversion flag for {hadoop_key}");
        }

        // Keys without an OpenDAL counterpart, including unknown ones.
        for unmapped in ["fs.s3a.connection.ssl.enabled", "unknown.key"] {
            assert!(
                convert_hadoop_key_to_opendal(unmapped).is_none(),
                "{unmapped} should not be mapped"
            );
        }
    }

    #[test]
    fn calculate_renewal_delay_returns_correct_delay() {
        // Token expires in 1 hour
        let expiration = now_millis() + 3600 * 1000;
        let delay = SecurityTokenManager::calculate_renewal_delay(expiration, 0.8);

        // Should be approximately 48 minutes (80% of 1 hour)
        let expected_min = Duration::from_secs(2800); // ~46.7 minutes
        let expected_max = Duration::from_secs(2900); // ~48.3 minutes
        assert!(
            delay >= expected_min && delay <= expected_max,
            "Expected delay between {expected_min:?} and {expected_max:?}, got {delay:?}"
        );
    }

    #[test]
    fn calculate_renewal_delay_handles_expired_token() {
        // Token already expired
        let expiration = now_millis() - 1000;
        let delay = SecurityTokenManager::calculate_renewal_delay(expiration, 0.8);

        // Should return minimum delay
        assert_eq!(delay, MIN_RENEWAL_DELAY);
    }

    #[test]
    fn build_remote_fs_props_includes_all_fields() {
        let credentials = Credentials {
            access_key_id: "ak".to_string(),
            access_key_secret: "sk".to_string(),
            security_token: Some("token".to_string()),
        };
        let addition_infos =
            HashMap::from([("fs.s3a.path.style.access".to_string(), "true".to_string())]);

        let props = build_remote_fs_props(&credentials, &addition_infos);
        let prop = |key: &str| props.get(key).map(String::as_str);

        assert_eq!(prop("access_key_id"), Some("ak"));
        // The secret is published under both the S3 and the OSS spelling.
        assert_eq!(prop("secret_access_key"), Some("sk"));
        assert_eq!(prop("access_key_secret"), Some("sk"));
        assert_eq!(prop("security_token"), Some("token"));
        // `path.style.access = true` is inverted into virtual-host style = false.
        assert_eq!(prop("enable_virtual_host_style"), Some("false"));
    }
}