1use crate::UtilsError;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10use std::sync::{Arc, Mutex};
11use std::time::{Duration, Instant};
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct ApiConfig {
16 pub base_url: String,
17 pub timeout: Duration,
18 pub max_retries: u32,
19 pub retry_delay: Duration,
20 pub headers: HashMap<String, String>,
21 pub authentication: Option<Authentication>,
22 pub user_agent: String,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub enum Authentication {
27 Bearer(String),
28 ApiKey { key: String, header: String },
29 Basic { username: String, password: String },
30 Custom { headers: HashMap<String, String> },
31}
32
33impl Default for ApiConfig {
34 fn default() -> Self {
35 Self {
36 base_url: "https://api.example.com".to_string(),
37 timeout: Duration::from_secs(30),
38 max_retries: 3,
39 retry_delay: Duration::from_millis(1000),
40 headers: HashMap::new(),
41 authentication: None,
42 user_agent: "sklears-utils/1.0".to_string(),
43 }
44 }
45}
46
47impl ApiConfig {
48 pub fn new(base_url: String) -> Self {
49 Self {
50 base_url,
51 ..Default::default()
52 }
53 }
54
55 pub fn with_timeout(mut self, timeout: Duration) -> Self {
56 self.timeout = timeout;
57 self
58 }
59
60 pub fn with_retries(mut self, max_retries: u32, retry_delay: Duration) -> Self {
61 self.max_retries = max_retries;
62 self.retry_delay = retry_delay;
63 self
64 }
65
66 pub fn with_authentication(mut self, auth: Authentication) -> Self {
67 self.authentication = Some(auth);
68 self
69 }
70
71 pub fn with_header(mut self, key: String, value: String) -> Self {
72 self.headers.insert(key, value);
73 self
74 }
75
76 pub fn with_user_agent(mut self, user_agent: String) -> Self {
77 self.user_agent = user_agent;
78 self
79 }
80}
81
82#[derive(thiserror::Error, Debug, Clone)]
84pub enum ApiError {
85 #[error("HTTP error {status}: {message}")]
86 HttpError { status: u16, message: String },
87 #[error("Network error: {0}")]
88 NetworkError(String),
89 #[error("Timeout error: request took longer than {0:?}")]
90 TimeoutError(Duration),
91 #[error("Serialization error: {0}")]
92 SerializationError(String),
93 #[error("Authentication error: {0}")]
94 AuthenticationError(String),
95 #[error("Rate limit exceeded: {retry_after:?}")]
96 RateLimitError { retry_after: Option<Duration> },
97 #[error("Invalid request: {0}")]
98 InvalidRequest(String),
99}
100
101impl From<ApiError> for UtilsError {
102 fn from(err: ApiError) -> Self {
103 UtilsError::InvalidParameter(err.to_string())
104 }
105}
106
/// HTTP request methods understood by this module.
///
/// The type is `Copy` and hashable, so it can be used directly as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum HttpMethod {
    /// `GET`
    Get,
    /// `POST`
    Post,
    /// `PUT`
    Put,
    /// `DELETE`
    Delete,
    /// `PATCH`
    Patch,
    /// `HEAD`
    Head,
    /// `OPTIONS`
    Options,
}
118
119impl fmt::Display for HttpMethod {
120 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
121 match self {
122 HttpMethod::Get => write!(f, "GET"),
123 HttpMethod::Post => write!(f, "POST"),
124 HttpMethod::Put => write!(f, "PUT"),
125 HttpMethod::Delete => write!(f, "DELETE"),
126 HttpMethod::Patch => write!(f, "PATCH"),
127 HttpMethod::Head => write!(f, "HEAD"),
128 HttpMethod::Options => write!(f, "OPTIONS"),
129 }
130 }
131}
132
133#[derive(Debug, Clone)]
135pub struct ApiRequest {
136 pub method: HttpMethod,
137 pub url: String,
138 pub headers: HashMap<String, String>,
139 pub body: Option<Vec<u8>>,
140 pub query_params: HashMap<String, String>,
141}
142
143impl ApiRequest {
144 pub fn new(method: HttpMethod, url: String) -> Self {
145 Self {
146 method,
147 url,
148 headers: HashMap::new(),
149 body: None,
150 query_params: HashMap::new(),
151 }
152 }
153
154 pub fn with_header(mut self, key: String, value: String) -> Self {
155 self.headers.insert(key, value);
156 self
157 }
158
159 pub fn with_body<T: Serialize>(mut self, body: &T) -> Result<Self, ApiError> {
160 let serialized =
161 serde_json::to_vec(body).map_err(|e| ApiError::SerializationError(e.to_string()))?;
162 self.body = Some(serialized);
163 self.headers
164 .insert("Content-Type".to_string(), "application/json".to_string());
165 Ok(self)
166 }
167
168 pub fn with_json_body(mut self, json: String) -> Self {
169 self.body = Some(json.into_bytes());
170 self.headers
171 .insert("Content-Type".to_string(), "application/json".to_string());
172 self
173 }
174
175 pub fn with_query_param(mut self, key: String, value: String) -> Self {
176 self.query_params.insert(key, value);
177 self
178 }
179
180 pub fn with_query_params(mut self, params: HashMap<String, String>) -> Self {
181 self.query_params.extend(params);
182 self
183 }
184
185 pub fn build_url(&self) -> String {
186 if self.query_params.is_empty() {
187 return self.url.clone();
188 }
189
190 let query_string: String = self
191 .query_params
192 .iter()
193 .map(|(k, v)| format!("{k}={v}"))
194 .collect::<Vec<_>>()
195 .join("&");
196
197 if self.url.contains('?') {
198 format!("{}&{}", self.url, query_string)
199 } else {
200 format!("{}?{}", self.url, query_string)
201 }
202 }
203}
204
/// The outcome of a request as reported by an [`ApiClient`].
#[derive(Debug, Clone)]
pub struct ApiResponse {
    /// HTTP status code.
    pub status_code: u16,
    /// Response headers keyed by header name.
    pub headers: HashMap<String, String>,
    /// Raw response body.
    pub body: Vec<u8>,
    /// Duration value supplied by whoever constructed the response.
    pub execution_time: Duration,
}
213
214impl ApiResponse {
215 pub fn new(status_code: u16, body: Vec<u8>, execution_time: Duration) -> Self {
216 Self {
217 status_code,
218 headers: HashMap::new(),
219 body,
220 execution_time,
221 }
222 }
223
224 pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
225 self.headers = headers;
226 self
227 }
228
229 pub fn is_success(&self) -> bool {
230 self.status_code >= 200 && self.status_code < 300
231 }
232
233 pub fn text(&self) -> Result<String, ApiError> {
234 String::from_utf8(self.body.clone())
235 .map_err(|e| ApiError::SerializationError(e.to_string()))
236 }
237
238 pub fn json<T: for<'de> Deserialize<'de>>(&self) -> Result<T, ApiError> {
239 serde_json::from_slice(&self.body).map_err(|e| ApiError::SerializationError(e.to_string()))
240 }
241
242 pub fn get_header(&self, name: &str) -> Option<&String> {
243 self.headers.get(name)
244 }
245}
246
247pub trait ApiClient {
249 fn execute(&self, request: ApiRequest) -> Result<ApiResponse, ApiError>;
250}
251
252pub struct MockApiClient {
254 responses: Arc<Mutex<Vec<ApiResponse>>>,
255 current_index: Arc<Mutex<usize>>,
256}
257
258impl Default for MockApiClient {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
264impl MockApiClient {
265 pub fn new() -> Self {
266 Self {
267 responses: Arc::new(Mutex::new(Vec::new())),
268 current_index: Arc::new(Mutex::new(0)),
269 }
270 }
271
272 pub fn add_response(&self, response: ApiResponse) {
273 if let Ok(mut responses) = self.responses.lock() {
274 responses.push(response);
275 }
276 }
277
278 pub fn reset(&self) {
279 if let Ok(mut responses) = self.responses.lock() {
280 responses.clear();
281 }
282 if let Ok(mut index) = self.current_index.lock() {
283 *index = 0;
284 }
285 }
286}
287
288impl ApiClient for MockApiClient {
289 fn execute(&self, _request: ApiRequest) -> Result<ApiResponse, ApiError> {
290 let responses = self
291 .responses
292 .lock()
293 .map_err(|_| ApiError::NetworkError("Failed to lock responses".to_string()))?;
294
295 let mut index = self
296 .current_index
297 .lock()
298 .map_err(|_| ApiError::NetworkError("Failed to lock index".to_string()))?;
299
300 if *index >= responses.len() {
301 return Err(ApiError::NetworkError(
302 "No more mock responses available".to_string(),
303 ));
304 }
305
306 let response = responses[*index].clone();
307 *index += 1;
308 Ok(response)
309 }
310}
311
312pub struct RequestBuilder {
314 method: HttpMethod,
315 url: String,
316 headers: HashMap<String, String>,
317 query_params: HashMap<String, String>,
318 body: Option<Vec<u8>>,
319}
320
321impl RequestBuilder {
322 pub fn new(method: HttpMethod, url: String) -> Self {
323 Self {
324 method,
325 url,
326 headers: HashMap::new(),
327 query_params: HashMap::new(),
328 body: None,
329 }
330 }
331
332 pub fn get(url: String) -> Self {
333 Self::new(HttpMethod::Get, url)
334 }
335
336 pub fn post(url: String) -> Self {
337 Self::new(HttpMethod::Post, url)
338 }
339
340 pub fn put(url: String) -> Self {
341 Self::new(HttpMethod::Put, url)
342 }
343
344 pub fn delete(url: String) -> Self {
345 Self::new(HttpMethod::Delete, url)
346 }
347
348 pub fn header(mut self, key: String, value: String) -> Self {
349 self.headers.insert(key, value);
350 self
351 }
352
353 pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
354 self.headers.extend(headers);
355 self
356 }
357
358 pub fn query(mut self, key: String, value: String) -> Self {
359 self.query_params.insert(key, value);
360 self
361 }
362
363 pub fn json<T: Serialize + ?Sized>(mut self, body: &T) -> Result<Self, ApiError> {
364 let serialized =
365 serde_json::to_vec(body).map_err(|e| ApiError::SerializationError(e.to_string()))?;
366 self.body = Some(serialized);
367 self.headers
368 .insert("Content-Type".to_string(), "application/json".to_string());
369 Ok(self)
370 }
371
372 pub fn text(mut self, body: String) -> Self {
373 self.body = Some(body.into_bytes());
374 self.headers
375 .insert("Content-Type".to_string(), "text/plain".to_string());
376 self
377 }
378
379 pub fn build(self) -> ApiRequest {
380 ApiRequest {
381 method: self.method,
382 url: self.url,
383 headers: self.headers,
384 body: self.body,
385 query_params: self.query_params,
386 }
387 }
388}
389
390pub struct ApiService {
392 client: Box<dyn ApiClient + Send + Sync>,
393 config: ApiConfig,
394 metrics: Arc<Mutex<ApiMetrics>>,
395}
396
397impl ApiService {
398 pub fn new(client: Box<dyn ApiClient + Send + Sync>, config: ApiConfig) -> Self {
399 Self {
400 client,
401 config,
402 metrics: Arc::new(Mutex::new(ApiMetrics::default())),
403 }
404 }
405
406 pub fn with_mock() -> Self {
407 Self::new(Box::new(MockApiClient::new()), ApiConfig::default())
408 }
409
410 pub fn get(&self, endpoint: &str) -> RequestBuilder {
411 let url = format!(
412 "{}/{}",
413 self.config.base_url.trim_end_matches('/'),
414 endpoint.trim_start_matches('/')
415 );
416 RequestBuilder::get(url).headers(self.build_default_headers())
417 }
418
419 pub fn post(&self, endpoint: &str) -> RequestBuilder {
420 let url = format!(
421 "{}/{}",
422 self.config.base_url.trim_end_matches('/'),
423 endpoint.trim_start_matches('/')
424 );
425 RequestBuilder::post(url).headers(self.build_default_headers())
426 }
427
428 pub fn put(&self, endpoint: &str) -> RequestBuilder {
429 let url = format!(
430 "{}/{}",
431 self.config.base_url.trim_end_matches('/'),
432 endpoint.trim_start_matches('/')
433 );
434 RequestBuilder::put(url).headers(self.build_default_headers())
435 }
436
437 pub fn delete(&self, endpoint: &str) -> RequestBuilder {
438 let url = format!(
439 "{}/{}",
440 self.config.base_url.trim_end_matches('/'),
441 endpoint.trim_start_matches('/')
442 );
443 RequestBuilder::delete(url).headers(self.build_default_headers())
444 }
445
446 pub fn execute(&self, request: ApiRequest) -> Result<ApiResponse, ApiError> {
447 let start_time = Instant::now();
448
449 let request = self.apply_authentication(request)?;
451
452 let mut last_error = None;
454 for attempt in 0..=self.config.max_retries {
455 if attempt > 0 {
456 std::thread::sleep(self.config.retry_delay);
457 }
458
459 match self.client.execute(request.clone()) {
460 Ok(response) => {
461 let execution_time = start_time.elapsed();
462 self.record_metrics(&request, &response, execution_time);
463 return Ok(response);
464 }
465 Err(e) => {
466 last_error = Some(e);
467 if !self.should_retry(last_error.as_ref().unwrap()) {
468 break;
469 }
470 }
471 }
472 }
473
474 Err(last_error.unwrap())
475 }
476
477 pub fn get_metrics(&self) -> Option<ApiMetrics> {
478 self.metrics.lock().ok().map(|m| m.clone())
479 }
480
481 pub fn reset_metrics(&self) {
482 if let Ok(mut metrics) = self.metrics.lock() {
483 *metrics = ApiMetrics::default();
484 }
485 }
486
487 fn build_default_headers(&self) -> HashMap<String, String> {
488 let mut headers = self.config.headers.clone();
489 headers.insert("User-Agent".to_string(), self.config.user_agent.clone());
490 headers
491 }
492
493 fn apply_authentication(&self, mut request: ApiRequest) -> Result<ApiRequest, ApiError> {
494 if let Some(auth) = &self.config.authentication {
495 match auth {
496 Authentication::Bearer(token) => {
497 request =
498 request.with_header("Authorization".to_string(), format!("Bearer {token}"));
499 }
500 Authentication::ApiKey { key, header } => {
501 request = request.with_header(header.clone(), key.clone());
502 }
503 Authentication::Basic { username, password } => {
504 let credentials = base64::encode(format!("{username}:{password}"));
505 request = request
506 .with_header("Authorization".to_string(), format!("Basic {credentials}"));
507 }
508 Authentication::Custom { headers } => {
509 for (key, value) in headers {
510 request = request.with_header(key.clone(), value.clone());
511 }
512 }
513 }
514 }
515 Ok(request)
516 }
517
518 fn should_retry(&self, error: &ApiError) -> bool {
519 match error {
520 ApiError::NetworkError(_) => true,
521 ApiError::TimeoutError(_) => true,
522 ApiError::HttpError { status, .. } => {
523 *status >= 500 || *status == 429
525 }
526 _ => false,
527 }
528 }
529
530 fn record_metrics(
531 &self,
532 request: &ApiRequest,
533 response: &ApiResponse,
534 execution_time: Duration,
535 ) {
536 if let Ok(mut metrics) = self.metrics.lock() {
537 metrics.total_requests += 1;
538 if response.is_success() {
539 metrics.successful_requests += 1;
540 } else {
541 metrics.failed_requests += 1;
542 }
543 metrics.total_execution_time += execution_time;
544 metrics.average_response_time = Duration::from_nanos(
545 (metrics.total_execution_time.as_nanos() / metrics.total_requests as u128) as u64,
546 );
547
548 let method_stats = metrics
549 .method_stats
550 .entry(request.method)
551 .or_insert(MethodStats::default());
552 method_stats.requests += 1;
553 method_stats.total_time += execution_time;
554 method_stats.average_time = Duration::from_nanos(
555 (method_stats.total_time.as_nanos() / method_stats.requests as u128) as u64,
556 );
557 }
558 }
559}
560
561#[derive(Debug, Clone, Default)]
563pub struct ApiMetrics {
564 pub total_requests: u64,
565 pub successful_requests: u64,
566 pub failed_requests: u64,
567 pub total_execution_time: Duration,
568 pub average_response_time: Duration,
569 pub method_stats: HashMap<HttpMethod, MethodStats>,
570}
571
/// Request statistics for a single HTTP method.
#[derive(Debug, Clone, Default)]
pub struct MethodStats {
    /// Number of recorded requests using this method.
    pub requests: u64,
    /// Sum of their execution times.
    pub total_time: Duration,
    /// `total_time` divided by `requests`.
    pub average_time: Duration,
}
578
579impl ApiMetrics {
580 pub fn success_rate(&self) -> f64 {
581 if self.total_requests == 0 {
582 0.0
583 } else {
584 self.successful_requests as f64 / self.total_requests as f64
585 }
586 }
587}
588
589impl fmt::Display for ApiMetrics {
590 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
591 writeln!(f, "API Metrics:")?;
592 writeln!(f, " Total Requests: {}", self.total_requests)?;
593 writeln!(f, " Success Rate: {:.2}%", self.success_rate() * 100.0)?;
594 writeln!(
595 f,
596 " Average Response Time: {:?}",
597 self.average_response_time
598 )?;
599
600 if !self.method_stats.is_empty() {
601 writeln!(f, " Method Statistics:")?;
602 for (method, stats) in &self.method_stats {
603 writeln!(
604 f,
605 " {}: {} requests, avg {:?}",
606 method, stats.requests, stats.average_time
607 )?;
608 }
609 }
610
611 Ok(())
612 }
613}
614
/// Namespace for request helpers aimed at common machine-learning endpoints
/// (prediction, training, model status, health checks).
pub struct MLApiPatterns;
617
618impl MLApiPatterns {
619 pub fn prediction_request<T: Serialize>(
621 service: &ApiService,
622 endpoint: &str,
623 features: &T,
624 ) -> Result<RequestBuilder, ApiError> {
625 service.post(endpoint).json(features)
626 }
627
628 pub fn batch_prediction_request<T: Serialize>(
630 service: &ApiService,
631 endpoint: &str,
632 batch_features: &[T],
633 ) -> Result<RequestBuilder, ApiError> {
634 service.post(endpoint).json(batch_features)
635 }
636
637 pub fn training_request<T: Serialize>(
639 service: &ApiService,
640 endpoint: &str,
641 training_data: &T,
642 model_config: &HashMap<String, serde_json::Value>,
643 ) -> Result<RequestBuilder, ApiError> {
644 let payload = serde_json::json!({
645 "data": training_data,
646 "config": model_config
647 });
648 service.post(endpoint).json(&payload)
649 }
650
651 pub fn model_status_request(service: &ApiService, model_id: &str) -> RequestBuilder {
653 service.get(&format!("models/{model_id}/status"))
654 }
655
656 pub fn health_check_request(service: &ApiService) -> RequestBuilder {
658 service.get("health")
659 }
660}
661
/// Minimal, dependency-free Base64 encoder (standard alphabet, `=` padding).
mod base64 {
    /// The standard Base64 alphabet from RFC 4648, indexed by 6-bit value.
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    /// Encodes the UTF-8 bytes of `input` as padded Base64.
    ///
    /// An empty input yields an empty string.
    pub fn encode(input: String) -> String {
        let bytes = input.as_bytes();
        let mut encoded = String::with_capacity((bytes.len() + 2) / 3 * 4);

        for chunk in bytes.chunks(3) {
            // Pack up to three bytes into a 24-bit group; missing bytes are
            // treated as zero and masked out by the padding below.
            let first = u32::from(chunk[0]);
            let second = chunk.get(1).map_or(0, |byte| u32::from(*byte));
            let third = chunk.get(2).map_or(0, |byte| u32::from(*byte));
            let group = (first << 16) | (second << 8) | third;

            encoded.push(symbol(group >> 18));
            encoded.push(symbol(group >> 12));
            encoded.push(if chunk.len() > 1 {
                symbol(group >> 6)
            } else {
                '='
            });
            encoded.push(if chunk.len() > 2 { symbol(group) } else { '=' });
        }

        encoded
    }

    /// Maps the low six bits of `sextet` to its alphabet character.
    fn symbol(sextet: u32) -> char {
        char::from(ALPHABET[(sextet & 0x3f) as usize])
    }
}
670
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builder methods on `ApiConfig` store the values they are given.
    #[test]
    fn test_api_config() {
        let base = "https://api.example.com".to_string();
        let configured = ApiConfig::new(base)
            .with_timeout(Duration::from_secs(60))
            .with_retries(5, Duration::from_millis(500))
            .with_header("Custom-Header".to_string(), "value".to_string());

        assert_eq!(configured.base_url, "https://api.example.com");
        assert_eq!(configured.timeout, Duration::from_secs(60));
        assert_eq!(configured.max_retries, 5);

        let expected_header = "value".to_string();
        assert_eq!(
            configured.headers.get("Custom-Header"),
            Some(&expected_header)
        );
    }

    /// A built request carries method, URL, headers and query parameters.
    #[test]
    fn test_request_builder() {
        let builder = RequestBuilder::get("https://api.example.com/test".to_string());
        let built = builder
            .header("Content-Type".to_string(), "application/json".to_string())
            .query("param".to_string(), "value".to_string())
            .build();

        assert_eq!(built.method, HttpMethod::Get);
        assert_eq!(built.url, "https://api.example.com/test");

        let content_type = "application/json".to_string();
        assert_eq!(built.headers.get("Content-Type"), Some(&content_type));

        let param_value = "value".to_string();
        assert_eq!(built.query_params.get("param"), Some(&param_value));

        assert_eq!(
            built.build_url(),
            "https://api.example.com/test?param=value"
        );
    }

    /// `json` stores a body and marks the request as JSON.
    #[test]
    fn test_request_with_json_body() {
        let payload = json!({"name": "test", "value": 42});
        let built = RequestBuilder::post("https://api.example.com/data".to_string())
            .json(&payload)
            .unwrap()
            .build();

        assert_eq!(built.method, HttpMethod::Post);
        assert!(built.body.is_some());

        let content_type = "application/json".to_string();
        assert_eq!(built.headers.get("Content-Type"), Some(&content_type));
    }

    /// Response accessors expose status, timing, text and JSON views.
    #[test]
    fn test_api_response() {
        let raw = b"{\"result\": \"success\"}".to_vec();
        let reply = ApiResponse::new(200, raw, Duration::from_millis(100));

        assert!(reply.is_success());
        assert_eq!(reply.status_code, 200);
        assert_eq!(reply.execution_time, Duration::from_millis(100));

        let as_text = reply.text().unwrap();
        assert_eq!(as_text, "{\"result\": \"success\"}");

        let as_json: serde_json::Value = reply.json().unwrap();
        assert_eq!(as_json["result"], "success");
    }

    /// The mock client replays a queued response.
    #[test]
    fn test_mock_api_client() {
        let mock = MockApiClient::new();

        let canned = ApiResponse::new(
            200,
            b"{\"data\": \"test\"}".to_vec(),
            Duration::from_millis(50),
        );
        mock.add_response(canned);

        let outgoing = ApiRequest::new(HttpMethod::Get, "https://api.example.com/test".to_string());
        let reply = mock.execute(outgoing).unwrap();

        assert_eq!(reply.status_code, 200);
        assert_eq!(reply.text().unwrap(), "{\"data\": \"test\"}");
    }

    /// The service produces builders for endpoints under its base URL.
    #[test]
    fn test_api_service() {
        let service = ApiService::with_mock();

        let built = service.get("test").build();

        assert_eq!(built.method, HttpMethod::Get);
        assert!(built.url.contains("test"));
    }

    /// The prediction helper builds a POST with a body.
    #[test]
    fn test_ml_api_patterns() {
        let service = ApiService::with_mock();

        let features = json!({"feature1": 1.0, "feature2": 2.0});
        let built = MLApiPatterns::prediction_request(&service, "predict", &features)
            .unwrap()
            .build();

        assert_eq!(built.method, HttpMethod::Post);
        assert!(built.url.contains("predict"));
        assert!(built.body.is_some());
    }

    /// `with_authentication` stores the scheme it is given.
    #[test]
    fn test_authentication() {
        let bearer = Authentication::Bearer("test-token".to_string());
        let configured =
            ApiConfig::new("https://api.example.com".to_string()).with_authentication(bearer);

        match configured.authentication {
            Some(Authentication::Bearer(token)) => assert_eq!(token, "test-token"),
            _ => panic!("Expected Bearer authentication"),
        }
    }

    /// Success rate and the rendered summary reflect the counters.
    #[test]
    fn test_api_metrics() {
        let metrics = ApiMetrics {
            total_requests: 10,
            successful_requests: 8,
            failed_requests: 2,
            ..ApiMetrics::default()
        };

        assert_eq!(metrics.success_rate(), 0.8);

        let rendered = metrics.to_string();
        assert!(rendered.contains("Total Requests: 10"));
        assert!(rendered.contains("Success Rate: 80.00%"));
    }

    /// Multiple query parameters are all present in the built URL.
    #[test]
    fn test_query_param_building() {
        let outgoing = ApiRequest::new(HttpMethod::Get, "https://api.example.com".to_string())
            .with_query_param("param1".to_string(), "value1".to_string())
            .with_query_param("param2".to_string(), "value2".to_string());

        let full_url = outgoing.build_url();
        for fragment in ["param1=value1", "param2=value2", "?", "&"].iter() {
            assert!(full_url.contains(fragment));
        }
    }
}