1use crate::config::HttpConfig;
8use crate::error::HttpError;
9use crate::model::types::AuthToken;
10use crate::time_compat::{SystemTime, UNIX_EPOCH};
11use base64::Engine;
12use hmac::KeyInit;
13use hmac::{Hmac, Mac};
14use pretty_simple_display::{DebugPretty, DisplaySimple};
15use reqwest::Client;
16use serde::{Deserialize, Serialize};
17use sha2::Sha256;
18use std::time::Duration;
19use tracing::{debug, error};
20use urlencoding;
21
/// HMAC instantiated with SHA-256; used by `AuthManager::generate_api_key_signature`
/// to sign API-key requests.
type HmacSha256 = Hmac<Sha256>;
23
/// Parameters of an OAuth2-style authentication request.
///
/// NOTE(review): the derived display/debug implementations presumably render every
/// field, including `client_secret` — avoid logging values of this type; confirm
/// against the `pretty_simple_display` derive documentation.
#[derive(DebugPretty, DisplaySimple, Clone, Serialize, Deserialize)]
pub struct AuthRequest {
    /// Grant type of the request (the tests in this file use `client_credentials`).
    pub grant_type: String,
    /// Client identifier.
    pub client_id: String,
    /// Client secret paired with `client_id`.
    pub client_secret: String,
    /// Optional scope string (the tests in this file use `"read write"`).
    pub scope: Option<String>,
}
36
/// Material describing an API-key authenticated request.
///
/// NOTE(review): the derived display/debug implementations presumably render
/// `secret` as well — avoid logging values of this type; confirm against the
/// `pretty_simple_display` derive documentation.
#[derive(DebugPretty, DisplaySimple, Clone, Serialize, Deserialize)]
pub struct ApiKeyAuth {
    /// API key identifier.
    pub key: String,
    /// API secret paired with `key`.
    pub secret: String,
    /// Request timestamp.
    /// NOTE(review): presumably milliseconds since the Unix epoch, matching
    /// `AuthManager::get_timestamp` — confirm against callers.
    pub timestamp: u64,
    /// Per-request nonce.
    /// NOTE(review): presumably produced by `AuthManager::generate_nonce` — confirm.
    pub nonce: String,
}
49
/// Obtains, caches and hands out authentication tokens for the HTTP client.
#[derive(Debug, Clone)]
pub struct AuthManager {
    /// HTTP client used to send the authentication request.
    client: Client,
    /// Client configuration; supplies `base_url` and the optional credentials.
    config: HttpConfig,
    /// Most recently stored token, if any.
    token: Option<AuthToken>,
    /// Instant at which `token` expires (`now + expires_in` seconds at store time).
    token_expires_at: Option<SystemTime>,
}
58
59impl AuthManager {
60 pub fn new(client: Client, config: HttpConfig) -> Self {
62 Self {
63 client,
64 config,
65 token: None,
66 token_expires_at: None,
67 }
68 }
69
70 pub async fn authenticate_oauth2(&mut self) -> Result<AuthToken, HttpError> {
72 let credentials = match self.config.credentials.clone() {
73 Some(creds) => match creds.is_valid() {
74 true => creds,
75 false => {
76 return Err(HttpError::AuthenticationFailed(
77 "Invalid credentials for OAuth2".to_string(),
78 ));
79 }
80 },
81 None => {
82 return Err(HttpError::AuthenticationFailed(
83 "No credentials configured".to_string(),
84 ));
85 }
86 };
87 let (client_id, client_secret) = credentials.get_client_credentials()?;
88 let url = format!(
90 "{}/public/auth?grant_type=client_credentials&client_id={}&client_secret={}",
91 self.config.base_url,
92 urlencoding::encode(client_id.as_str()),
93 urlencoding::encode(client_secret.as_str())
94 );
95
96 debug!("Authentication URL: {}", url);
98
99 let response = self
100 .client
101 .get(&url)
102 .header("Content-Type", "application/json")
103 .send()
104 .await
105 .map_err(|e| HttpError::NetworkError(e.to_string()))?;
106
107 if !response.status().is_success() {
108 let error_text = response
109 .text()
110 .await
111 .unwrap_or_else(|_| "Unknown error".to_string());
112 return Err(HttpError::AuthenticationFailed(format!(
113 "OAuth2 authentication failed: {}",
114 error_text
115 )));
116 }
117
118 let json_response: serde_json::Value = response
120 .json()
121 .await
122 .map_err(|e| HttpError::InvalidResponse(e.to_string()))?;
123
124 if let Some(error) = json_response.get("error") {
126 let _code = error.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
127 let _message = error
128 .get("message")
129 .and_then(|m| m.as_str())
130 .unwrap_or("Unknown error");
131 return Err(HttpError::AuthenticationFailed(format!(
132 "OAuth2 authentication failed: {}",
133 json_response
134 )));
135 }
136
137 let result = json_response
139 .get("result")
140 .ok_or_else(|| HttpError::InvalidResponse("No result in response".to_string()))?;
141
142 let token: AuthToken = serde_json::from_value(result.clone())
143 .map_err(|e| HttpError::InvalidResponse(format!("Failed to parse token: {}", e)))?;
144
145 let expires_at = SystemTime::now() + Duration::from_secs(token.expires_in);
147
148 self.token = Some(token.clone());
149 self.token_expires_at = Some(expires_at);
150
151 Ok(token)
152 }
153
154 pub fn generate_api_key_signature(
156 &self,
157 api_secret: &str,
158 timestamp: u64,
159 nonce: &str,
160 method: &str,
161 uri: &str,
162 body: &str,
163 ) -> Result<String, HttpError> {
164 let data = format!(
165 "{}{}{}{}{}",
166 timestamp,
167 nonce,
168 method.to_uppercase(),
169 uri,
170 body
171 );
172
173 let mut mac = HmacSha256::new_from_slice(api_secret.as_bytes())
174 .map_err(|e| HttpError::AuthenticationFailed(format!("Invalid API secret: {}", e)))?;
175
176 mac.update(data.as_bytes());
177 let result = mac.finalize();
178
179 Ok(base64::engine::general_purpose::STANDARD.encode(result.into_bytes()))
180 }
181
182 pub fn get_token(&self) -> Option<&AuthToken> {
184 if !self.is_token_expired() {
185 self.token.as_ref()
186 } else {
187 None
188 }
189 }
190
191 fn is_token_expired(&self) -> bool {
193 match self.token_expires_at {
194 Some(expires_at) => {
195 let buffer = Duration::from_secs(60);
197 SystemTime::now() + buffer >= expires_at
198 }
199 None => true,
200 }
201 }
202
203 fn is_token_valid(&self) -> bool {
215 match self.token {
216 Some(_) => !self.is_token_expired(),
217 None => false,
218 }
219 }
220
221 pub async fn get_authorization_header(&mut self) -> Option<String> {
223 match self.is_token_valid() {
224 true => {
225 let token = self.token.as_ref().unwrap();
226 Some(format!("{} {}", token.token_type, token.access_token))
227 }
228 false => match self.config.credentials.as_ref() {
229 Some(credentials) => match credentials.is_valid() {
230 true => match self.authenticate_oauth2().await {
231 Ok(token) => Some(format!("{} {}", token.token_type, token.access_token)),
232 Err(e) => {
233 error!("Failed to authenticate: {}", e);
234 None
235 }
236 },
237 false => None,
238 },
239 None => None,
240 },
241 }
242 }
243
244 pub fn generate_nonce() -> String {
246 use rand::RngExt;
247 let mut rng = rand::rng();
248 let chars: String = (0..16)
249 .map(|_| {
250 let idx = rng.random_range(0..62);
251 match idx {
252 0..=25 => (b'a' + idx) as char,
253 26..=51 => (b'A' + (idx - 26)) as char,
254 _ => (b'0' + (idx - 52)) as char,
255 }
256 })
257 .collect();
258 chars
259 }
260
261 pub fn get_timestamp() -> u64 {
263 SystemTime::now()
264 .duration_since(UNIX_EPOCH)
265 .unwrap_or_default()
266 .as_millis() as u64
267 }
268
269 pub fn update_token(&mut self, token: AuthToken) {
296 self.token_expires_at = Some(SystemTime::now() + Duration::from_secs(token.expires_in));
297 self.token = Some(token);
298 }
299}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Fields given to the struct literal can be read back unchanged.
    #[test]
    fn test_auth_request_creation() {
        let request = AuthRequest {
            grant_type: String::from("client_credentials"),
            client_id: String::from("test_client"),
            client_secret: String::from("test_secret"),
            scope: Some(String::from("read write")),
        };

        assert_eq!(request.grant_type, "client_credentials");
        assert_eq!(request.client_id, "test_client");
    }

    /// Nonces are 16 characters long and two consecutive ones differ.
    #[test]
    fn test_nonce_generation() {
        let first = AuthManager::generate_nonce();
        let second = AuthManager::generate_nonce();

        for nonce in [&first, &second] {
            assert_eq!(nonce.len(), 16);
        }
        assert_ne!(first, second);
    }

    /// A timestamp taken after a 1 ms pause is strictly larger than the earlier one.
    #[test]
    fn test_timestamp_generation() {
        let earlier = AuthManager::get_timestamp();
        std::thread::sleep(std::time::Duration::from_millis(1));
        let later = AuthManager::get_timestamp();

        assert!(later > earlier);
    }
}