batata_client/auth/login.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use parking_lot::RwLock;
6use tracing::{debug, warn};
7
8use crate::auth::{AccessToken, Credentials};
9use crate::error::{BatataError, Result};
10use crate::remote::ServerAddress;
11
/// Fallback access-token lifetime, in milliseconds, used when the login response
/// carries no `tokenTtl`: 5 hours (matching the Nacos default).
const DEFAULT_TOKEN_TTL_MS: i64 = 5 * 60 * 60 * 1000;
14
15/// Auth manager for handling authentication
16pub struct AuthManager {
17    credentials: Credentials,
18    token: Arc<RwLock<AccessToken>>,
19    server_addresses: Vec<ServerAddress>,
20}
21
22impl AuthManager {
23    /// Create a new auth manager
24    pub fn new(credentials: Credentials, server_addresses: Vec<ServerAddress>) -> Self {
25        Self {
26            credentials,
27            token: Arc::new(RwLock::new(AccessToken::default())),
28            server_addresses,
29        }
30    }
31
32    /// Check if authentication is required
33    pub fn is_auth_enabled(&self) -> bool {
34        self.credentials.is_configured()
35    }
36
37    /// Get current valid token, refreshing if necessary
38    pub async fn get_token(&self) -> Result<Option<String>> {
39        if !self.is_auth_enabled() {
40            return Ok(None);
41        }
42
43        // Check if token is still valid
44        {
45            let token = self.token.read();
46            if token.is_valid() {
47                return Ok(Some(token.token.clone()));
48            }
49        }
50
51        // Refresh token
52        self.refresh_token().await?;
53
54        let token = self.token.read();
55        if token.is_valid() {
56            Ok(Some(token.token.clone()))
57        } else {
58            Err(BatataError::AuthError {
59                message: "Failed to obtain valid token".to_string(),
60            })
61        }
62    }
63
64    /// Refresh the access token
65    pub async fn refresh_token(&self) -> Result<()> {
66        if !self.credentials.is_configured() {
67            return Ok(());
68        }
69
70        for server in &self.server_addresses {
71            match self.login_to_server(server).await {
72                Ok(token) => {
73                    *self.token.write() = token;
74                    debug!("Token refreshed successfully from {}", server.address());
75                    return Ok(());
76                }
77                Err(e) => {
78                    warn!("Failed to login to {}: {}", server.address(), e);
79                    continue;
80                }
81            }
82        }
83
84        Err(BatataError::AuthError {
85            message: "Failed to login to any server".to_string(),
86        })
87    }
88
89    /// Login to a specific server
90    async fn login_to_server(&self, server: &ServerAddress) -> Result<AccessToken> {
91        let url = format!(
92            "http://{}:{}/nacos/v1/auth/login",
93            server.host(),
94            server.port()
95        );
96
97        let client = reqwest::Client::builder()
98            .timeout(Duration::from_secs(5))
99            .build()
100            .map_err(|e| BatataError::connection_error(format!("Failed to create HTTP client: {}", e)))?;
101
102        let mut params = HashMap::new();
103
104        if let (Some(username), Some(password)) = (&self.credentials.username, &self.credentials.password) {
105            params.insert("username".to_string(), username.clone());
106            params.insert("password".to_string(), password.clone());
107        } else if self.credentials.has_ak_sk_auth() {
108            // For AK/SK auth, we might need different endpoint or headers
109            // This is a simplified version
110            if let Some(sig) = self.credentials.generate_signature(&server.address()) {
111                params.insert("accessKey".to_string(), sig.access_key.clone());
112                // Note: actual implementation may vary based on Nacos version
113            }
114        }
115
116        let response = client
117            .post(&url)
118            .form(&params)
119            .send()
120            .await
121            .map_err(|e| BatataError::connection_error(format!("Login request failed: {}", e)))?;
122
123        if !response.status().is_success() {
124            let status = response.status();
125            let body = response.text().await.unwrap_or_default();
126            return Err(BatataError::AuthError {
127                message: format!("Login failed with status {}: {}", status, body),
128            });
129        }
130
131        let body: serde_json::Value = response
132            .json()
133            .await
134            .map_err(|e| BatataError::AuthError {
135                message: format!("Failed to parse login response: {}", e),
136            })?;
137
138        let token = body["accessToken"]
139            .as_str()
140            .ok_or_else(|| BatataError::AuthError {
141                message: "No accessToken in response".to_string(),
142            })?
143            .to_string();
144
145        let ttl = body["tokenTtl"]
146            .as_i64()
147            .unwrap_or(DEFAULT_TOKEN_TTL_MS / 1000) * 1000;
148
149        let global_admin = body["globalAdmin"].as_bool().unwrap_or(false);
150
151        Ok(AccessToken {
152            token,
153            expire_time: crate::common::current_time_millis() + ttl,
154            global_admin,
155        })
156    }
157
158    /// Get auth headers to include in requests
159    pub async fn get_auth_headers(&self) -> Result<HashMap<String, String>> {
160        let mut headers = HashMap::new();
161
162        if !self.is_auth_enabled() {
163            return Ok(headers);
164        }
165
166        // Get token
167        if let Some(token) = self.get_token().await? {
168            headers.insert("accessToken".to_string(), token);
169        }
170
171        // Add AK/SK signature if configured
172        if self.credentials.has_ak_sk_auth()
173            && let Some(sig) = self.credentials.generate_signature("+")
174        {
175            headers.insert("ak".to_string(), sig.access_key);
176            headers.insert("data".to_string(), sig.timestamp);
177            headers.insert("signature".to_string(), sig.signature);
178        }
179
180        Ok(headers)
181    }
182}
183
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty credentials must leave authentication disabled.
    #[test]
    fn test_auth_manager_no_auth() {
        let no_servers: Vec<ServerAddress> = Vec::new();
        assert!(!AuthManager::new(Credentials::new(), no_servers).is_auth_enabled());
    }

    /// Username/password credentials must enable authentication.
    #[test]
    fn test_auth_manager_with_credentials() {
        let manager = AuthManager::new(
            Credentials::with_username_password("admin", "password"),
            vec![ServerAddress::new("localhost", 8848)],
        );
        assert!(manager.is_auth_enabled());
    }
}