1use crate::config::HttpConfig;
8use crate::error::HttpError;
9use crate::model::types::AuthToken;
10use crate::time_compat::{SystemTime, UNIX_EPOCH};
11use base64::Engine;
12use hmac::{Hmac, Mac};
13use pretty_simple_display::{DebugPretty, DisplaySimple};
14use reqwest::Client;
15use serde::{Deserialize, Serialize};
16use sha2::Sha256;
17use std::time::Duration;
18use tracing::{debug, error};
19use urlencoding;
20
21type HmacSha256 = Hmac<Sha256>;
22
/// Fields of an OAuth2 token request.
///
/// The type is serializable/deserializable through serde. Nothing else in the
/// visible part of this file constructs it outside the unit tests.
///
/// NOTE(review): the derived debug/display output presumably includes
/// `client_secret` — confirm before logging values of this type.
#[derive(DebugPretty, DisplaySimple, Clone, Serialize, Deserialize)]
pub struct AuthRequest {
    /// OAuth2 grant type, e.g. `"client_credentials"`.
    pub grant_type: String,
    /// Client identifier.
    pub client_id: String,
    /// Client secret paired with `client_id`.
    pub client_secret: String,
    /// Optional scope string, e.g. `"read write"`.
    pub scope: Option<String>,
}
35
/// Values that make up an API-key style authentication.
///
/// NOTE(review): presumably these correspond to the inputs of
/// `AuthManager::generate_api_key_signature` — confirm against callers.
/// The derived debug/display output presumably includes `secret`; avoid
/// logging values of this type until that is verified.
#[derive(DebugPretty, DisplaySimple, Clone, Serialize, Deserialize)]
pub struct ApiKeyAuth {
    /// API key identifier.
    pub key: String,
    /// API secret paired with `key`.
    pub secret: String,
    /// Request timestamp (presumably milliseconds since the Unix epoch, as
    /// produced by `AuthManager::get_timestamp` — TODO confirm).
    pub timestamp: u64,
    /// Per-request nonce.
    pub nonce: String,
}
48
/// Obtains and caches an OAuth2 token for the configured HTTP endpoint.
#[derive(Debug, Clone)]
pub struct AuthManager {
    /// HTTP client used to send the authentication request.
    client: Client,
    /// Supplies the base URL and the optional credentials.
    config: HttpConfig,
    /// Most recently stored token, if any.
    token: Option<AuthToken>,
    /// Instant at which `token` expires; `None` while no token is stored.
    token_expires_at: Option<SystemTime>,
}
57
58impl AuthManager {
59 pub fn new(client: Client, config: HttpConfig) -> Self {
61 Self {
62 client,
63 config,
64 token: None,
65 token_expires_at: None,
66 }
67 }
68
69 pub async fn authenticate_oauth2(&mut self) -> Result<AuthToken, HttpError> {
71 let credentials = match self.config.credentials.clone() {
72 Some(creds) => match creds.is_valid() {
73 true => creds,
74 false => {
75 return Err(HttpError::AuthenticationFailed(
76 "Invalid credentials for OAuth2".to_string(),
77 ));
78 }
79 },
80 None => {
81 return Err(HttpError::AuthenticationFailed(
82 "No credentials configured".to_string(),
83 ));
84 }
85 };
86 let (client_id, client_secret) = credentials.get_client_credentials()?;
87 let url = format!(
89 "{}/public/auth?grant_type=client_credentials&client_id={}&client_secret={}",
90 self.config.base_url,
91 urlencoding::encode(client_id.as_str()),
92 urlencoding::encode(client_secret.as_str())
93 );
94
95 debug!("Authentication URL: {}", url);
97
98 let response = self
99 .client
100 .get(&url)
101 .header("Content-Type", "application/json")
102 .send()
103 .await
104 .map_err(|e| HttpError::NetworkError(e.to_string()))?;
105
106 if !response.status().is_success() {
107 let error_text = response
108 .text()
109 .await
110 .unwrap_or_else(|_| "Unknown error".to_string());
111 return Err(HttpError::AuthenticationFailed(format!(
112 "OAuth2 authentication failed: {}",
113 error_text
114 )));
115 }
116
117 let json_response: serde_json::Value = response
119 .json()
120 .await
121 .map_err(|e| HttpError::InvalidResponse(e.to_string()))?;
122
123 if let Some(error) = json_response.get("error") {
125 let _code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
126 let _message = error
127 .get("message")
128 .and_then(|m| m.as_str())
129 .unwrap_or("Unknown error");
130 return Err(HttpError::AuthenticationFailed(format!(
131 "OAuth2 authentication failed: {}",
132 json_response
133 )));
134 }
135
136 let result = json_response
138 .get("result")
139 .ok_or_else(|| HttpError::InvalidResponse("No result in response".to_string()))?;
140
141 let token: AuthToken = serde_json::from_value(result.clone())
142 .map_err(|e| HttpError::InvalidResponse(format!("Failed to parse token: {}", e)))?;
143
144 let expires_at = SystemTime::now() + Duration::from_secs(token.expires_in);
146
147 self.token = Some(token.clone());
148 self.token_expires_at = Some(expires_at);
149
150 Ok(token)
151 }
152
153 pub fn generate_api_key_signature(
155 &self,
156 api_secret: &str,
157 timestamp: u64,
158 nonce: &str,
159 method: &str,
160 uri: &str,
161 body: &str,
162 ) -> Result<String, HttpError> {
163 let data = format!(
164 "{}{}{}{}{}",
165 timestamp,
166 nonce,
167 method.to_uppercase(),
168 uri,
169 body
170 );
171
172 let mut mac = HmacSha256::new_from_slice(api_secret.as_bytes())
173 .map_err(|e| HttpError::AuthenticationFailed(format!("Invalid API secret: {}", e)))?;
174
175 mac.update(data.as_bytes());
176 let result = mac.finalize();
177
178 Ok(base64::engine::general_purpose::STANDARD.encode(result.into_bytes()))
179 }
180
181 pub fn get_token(&self) -> Option<&AuthToken> {
183 if !self.is_token_expired() {
184 self.token.as_ref()
185 } else {
186 None
187 }
188 }
189
190 fn is_token_expired(&self) -> bool {
192 match self.token_expires_at {
193 Some(expires_at) => {
194 let buffer = Duration::from_secs(60);
196 SystemTime::now() + buffer >= expires_at
197 }
198 None => true,
199 }
200 }
201
202 fn is_token_valid(&self) -> bool {
214 match self.token {
215 Some(_) => !self.is_token_expired(),
216 None => false,
217 }
218 }
219
220 pub async fn get_authorization_header(&mut self) -> Option<String> {
222 match self.is_token_valid() {
223 true => {
224 let token = self.token.as_ref().unwrap();
225 Some(format!("{} {}", token.token_type, token.access_token))
226 }
227 false => match self.config.credentials.as_ref() {
228 Some(credentials) => match credentials.is_valid() {
229 true => match self.authenticate_oauth2().await {
230 Ok(token) => Some(format!("{} {}", token.token_type, token.access_token)),
231 Err(e) => {
232 error!("Failed to authenticate: {}", e);
233 None
234 }
235 },
236 false => None,
237 },
238 None => None,
239 },
240 }
241 }
242
243 pub fn generate_nonce() -> String {
245 use rand::RngExt;
246 let mut rng = rand::rng();
247 let chars: String = (0..16)
248 .map(|_| {
249 let idx = rng.random_range(0..62);
250 match idx {
251 0..=25 => (b'a' + idx) as char,
252 26..=51 => (b'A' + (idx - 26)) as char,
253 _ => (b'0' + (idx - 52)) as char,
254 }
255 })
256 .collect();
257 chars
258 }
259
260 pub fn get_timestamp() -> u64 {
262 SystemTime::now()
263 .duration_since(UNIX_EPOCH)
264 .unwrap_or_default()
265 .as_millis() as u64
266 }
267
268 pub fn update_token(&mut self, token: AuthToken) {
295 self.token_expires_at = Some(SystemTime::now() + Duration::from_secs(token.expires_in));
296 self.token = Some(token);
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// An `AuthRequest` keeps the values it was built with.
    #[test]
    fn test_auth_request_creation() {
        let grant_type = String::from("client_credentials");
        let client_id = String::from("test_client");
        let client_secret = String::from("test_secret");
        let scope = Some(String::from("read write"));

        let request = AuthRequest {
            grant_type,
            client_id,
            client_secret,
            scope,
        };

        assert_eq!(request.grant_type, "client_credentials");
        assert_eq!(request.client_id, "test_client");
    }

    /// Nonces are 16 characters long and two consecutive ones differ.
    #[test]
    fn test_nonce_generation() {
        let first = AuthManager::generate_nonce();
        let second = AuthManager::generate_nonce();

        assert_eq!(first.len(), 16);
        assert_eq!(second.len(), 16);
        assert_ne!(first, second);
    }

    /// A timestamp taken after a 1 ms pause is strictly larger.
    #[test]
    fn test_timestamp_generation() {
        let earlier = AuthManager::get_timestamp();
        std::thread::sleep(std::time::Duration::from_millis(1));
        let later = AuthManager::get_timestamp();

        assert!(later > earlier);
    }
}