1use crate::config::HttpConfig;
8use crate::error::HttpError;
9use crate::model::types::AuthToken;
10use base64::Engine;
11use hmac::{Hmac, Mac};
12use pretty_simple_display::{DebugPretty, DisplaySimple};
13use reqwest::Client;
14use serde::{Deserialize, Serialize};
15use sha2::Sha256;
16use std::time::{Duration, SystemTime, UNIX_EPOCH};
17use tracing::{debug, error};
18use urlencoding;
19
20type HmacSha256 = Hmac<Sha256>;
21
22#[derive(DebugPretty, DisplaySimple, Clone, Serialize, Deserialize)]
24pub struct AuthRequest {
25 pub grant_type: String,
27 pub client_id: String,
29 pub client_secret: String,
31 pub scope: Option<String>,
33}
34
35#[derive(DebugPretty, DisplaySimple, Clone, Serialize, Deserialize)]
37pub struct ApiKeyAuth {
38 pub key: String,
40 pub secret: String,
42 pub timestamp: u64,
44 pub nonce: String,
46}
47
48#[derive(Debug, Clone)]
50pub struct AuthManager {
51 client: Client,
52 config: HttpConfig,
53 token: Option<AuthToken>,
54 token_expires_at: Option<SystemTime>,
55}
56
57impl AuthManager {
58 pub fn new(client: Client, config: HttpConfig) -> Self {
60 Self {
61 client,
62 config,
63 token: None,
64 token_expires_at: None,
65 }
66 }
67
68 pub async fn authenticate_oauth2(&mut self) -> Result<AuthToken, HttpError> {
70 let credentials = match self.config.credentials.clone() {
71 Some(creds) => match creds.is_valid() {
72 true => creds,
73 false => {
74 return Err(HttpError::AuthenticationFailed(
75 "Invalid credentials for OAuth2".to_string(),
76 ));
77 }
78 },
79 None => {
80 return Err(HttpError::AuthenticationFailed(
81 "No credentials configured".to_string(),
82 ));
83 }
84 };
85 let (client_id, client_secret) = credentials.get_client_credentials()?;
86 let url = format!(
88 "{}/public/auth?grant_type=client_credentials&client_id={}&client_secret={}",
89 self.config.base_url,
90 urlencoding::encode(client_id.as_str()),
91 urlencoding::encode(client_secret.as_str())
92 );
93
94 debug!("Authentication URL: {}", url);
96
97 let response = self
98 .client
99 .get(&url)
100 .header("Content-Type", "application/json")
101 .send()
102 .await
103 .map_err(|e| HttpError::NetworkError(e.to_string()))?;
104
105 if !response.status().is_success() {
106 let error_text = response
107 .text()
108 .await
109 .unwrap_or_else(|_| "Unknown error".to_string());
110 return Err(HttpError::AuthenticationFailed(format!(
111 "OAuth2 authentication failed: {}",
112 error_text
113 )));
114 }
115
116 let json_response: serde_json::Value = response
118 .json()
119 .await
120 .map_err(|e| HttpError::InvalidResponse(e.to_string()))?;
121
122 if let Some(error) = json_response.get("error") {
124 let _code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
125 let _message = error
126 .get("message")
127 .and_then(|m| m.as_str())
128 .unwrap_or("Unknown error");
129 return Err(HttpError::AuthenticationFailed(format!(
130 "OAuth2 authentication failed: {}",
131 json_response
132 )));
133 }
134
135 let result = json_response
137 .get("result")
138 .ok_or_else(|| HttpError::InvalidResponse("No result in response".to_string()))?;
139
140 let token: AuthToken = serde_json::from_value(result.clone())
141 .map_err(|e| HttpError::InvalidResponse(format!("Failed to parse token: {}", e)))?;
142
143 let expires_at = SystemTime::now() + Duration::from_secs(token.expires_in);
145
146 self.token = Some(token.clone());
147 self.token_expires_at = Some(expires_at);
148
149 Ok(token)
150 }
151
152 pub fn generate_api_key_signature(
154 &self,
155 api_secret: &str,
156 timestamp: u64,
157 nonce: &str,
158 method: &str,
159 uri: &str,
160 body: &str,
161 ) -> Result<String, HttpError> {
162 let data = format!(
163 "{}{}{}{}{}",
164 timestamp,
165 nonce,
166 method.to_uppercase(),
167 uri,
168 body
169 );
170
171 let mut mac = HmacSha256::new_from_slice(api_secret.as_bytes())
172 .map_err(|e| HttpError::AuthenticationFailed(format!("Invalid API secret: {}", e)))?;
173
174 mac.update(data.as_bytes());
175 let result = mac.finalize();
176
177 Ok(base64::engine::general_purpose::STANDARD.encode(result.into_bytes()))
178 }
179
180 pub fn get_token(&self) -> Option<&AuthToken> {
182 if !self.is_token_expired() {
183 self.token.as_ref()
184 } else {
185 None
186 }
187 }
188
189 fn is_token_expired(&self) -> bool {
191 match self.token_expires_at {
192 Some(expires_at) => {
193 let buffer = Duration::from_secs(60);
195 SystemTime::now() + buffer >= expires_at
196 }
197 None => true,
198 }
199 }
200
201 fn is_token_valid(&self) -> bool {
213 match self.token {
214 Some(_) => !self.is_token_expired(),
215 None => false,
216 }
217 }
218
219 pub async fn get_authorization_header(&mut self) -> Option<String> {
221 match self.is_token_valid() {
222 true => {
223 let token = self.token.as_ref().unwrap();
224 Some(format!("{} {}", token.token_type, token.access_token))
225 }
226 false => match self.config.credentials.as_ref() {
227 Some(credentials) => match credentials.is_valid() {
228 true => match self.authenticate_oauth2().await {
229 Ok(token) => Some(format!("{} {}", token.token_type, token.access_token)),
230 Err(e) => {
231 error!("Failed to authenticate: {}", e);
232 None
233 }
234 },
235 false => None,
236 },
237 None => None,
238 },
239 }
240 }
241
242 pub fn generate_nonce() -> String {
244 use rand::Rng;
245 let mut rng = rand::rng();
246 let chars: String = (0..16)
247 .map(|_| {
248 let idx = rng.random_range(0..62);
249 match idx {
250 0..=25 => (b'a' + idx) as char,
251 26..=51 => (b'A' + (idx - 26)) as char,
252 _ => (b'0' + (idx - 52)) as char,
253 }
254 })
255 .collect();
256 chars
257 }
258
259 pub fn get_timestamp() -> u64 {
261 SystemTime::now()
262 .duration_since(UNIX_EPOCH)
263 .unwrap_or_default()
264 .as_millis() as u64
265 }
266
267 pub fn update_token(&mut self, token: AuthToken) {
294 self.token_expires_at = Some(SystemTime::now() + Duration::from_secs(token.expires_in));
295 self.token = Some(token);
296 }
297}
298
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain field construction round-trips every value.
    #[test]
    fn test_auth_request_creation() {
        let auth_request = AuthRequest {
            grant_type: "client_credentials".to_string(),
            client_id: "test_client".to_string(),
            client_secret: "test_secret".to_string(),
            scope: Some("read write".to_string()),
        };

        assert_eq!(auth_request.grant_type, "client_credentials");
        assert_eq!(auth_request.client_id, "test_client");
        assert_eq!(auth_request.client_secret, "test_secret");
        assert_eq!(auth_request.scope.as_deref(), Some("read write"));
    }

    /// Nonces have the fixed length, use only `[a-zA-Z0-9]`, and differ between calls.
    #[test]
    fn test_nonce_generation() {
        let nonce1 = AuthManager::generate_nonce();
        let nonce2 = AuthManager::generate_nonce();

        assert_eq!(nonce1.len(), 16);
        assert_eq!(nonce2.len(), 16);
        assert!(nonce1.chars().all(|c| c.is_ascii_alphanumeric()));
        assert!(nonce2.chars().all(|c| c.is_ascii_alphanumeric()));
        // 62^16 possible values make an accidental collision practically impossible.
        assert_ne!(nonce1, nonce2);
    }

    /// Millisecond timestamps strictly increase across a 1 ms sleep.
    #[test]
    fn test_timestamp_generation() {
        let timestamp1 = AuthManager::get_timestamp();
        std::thread::sleep(std::time::Duration::from_millis(1));
        let timestamp2 = AuthManager::get_timestamp();

        assert!(timestamp2 > timestamp1);
    }
}