amazon_spapi/client/
client.rs

1use anyhow::Result;
2use reqwest::header::HeaderMap;
3use reqwest::Client;
4use serde::Deserialize;
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::Duration;
8use tokio::sync::Mutex;
9
10use crate::apis::configuration::Configuration;
11use crate::client::{ApiEndpoint, ApiMethod, AuthClient, RateLimiter, SpapiConfig};
12
13// use amazon_spapi_gen::apis::configuration::Configuration;
14
15pub struct SpapiClient {
16    client: Client,
17    auth_client: Arc<Mutex<AuthClient>>,
18    config: SpapiConfig,
19    rate_limiter: RateLimiter,
20}
21
22impl SpapiClient {
23    /// Create a new SP API client with the given configuration
24    pub fn new(config: SpapiConfig) -> Result<Self> {
25        let user_agent = if let Some(ua) = &config.user_agent {
26            ua.clone()
27        } else {
28            // Default user agent if not provided
29            Self::get_user_agent()
30        };
31
32        let client = Client::builder()
33            .timeout(std::time::Duration::from_secs(
34                config.timeout_sec.unwrap_or(30),
35            ))
36            .user_agent(&user_agent)
37            .build()?;
38
39        let auth_client = AuthClient::new(
40            config.client_id.clone(),
41            config.client_secret.clone(),
42            config.refresh_token.clone(),
43            &user_agent,
44        )?;
45
46        // Initialize rate limiter if enabled
47        let rate_limiter = RateLimiter::new();
48
49        Ok(Self {
50            client, //: Client::new(),
51            auth_client: Arc::new(Mutex::new(auth_client)),
52            config,
53            rate_limiter,
54        })
55    }
56
57    pub fn limiter(&self) -> &RateLimiter {
58        &self.rate_limiter
59    }
60
61    /// Get default user agent for the client
62    pub fn get_user_agent() -> String {
63        let platform = format!("{}/{}", std::env::consts::OS, std::env::consts::ARCH);
64        format!(
65            "amazon-spapi/v{} (Language=Rust; Platform={})",
66            env!("CARGO_PKG_VERSION"),
67            platform
68        )
69    }
70
71    /// Get the base URL for the client
72    pub fn get_base_url(&self) -> String {
73        if self.config.sandbox {
74            format!("https://sandbox.sellingpartnerapi-na.amazon.com")
75        } else {
76            format!("https://sellingpartnerapi-na.amazon.com")
77        }
78    }
79
80    /// Get the region for the client
81    pub fn get_marketplace_id(&self) -> &str {
82        &self.config.marketplace_id
83    }
84
85    /// Get access token from the auth client
86    pub async fn get_access_token(&self) -> Result<String> {
87        let mut auth_client = self.auth_client.lock().await;
88        auth_client.get_access_token().await
89    }
90
91    /// Check if the client is in sandbox mode
92    pub fn is_sandbox(&self) -> bool {
93        self.config.sandbox
94    }
95
96    /// Make a request to the SP API
97    pub async fn request(
98        &self,
99        endpoint: &ApiEndpoint,
100        query: Option<Vec<(String, String)>>,
101        header: Option<Vec<(&'static str, String)>>,
102        body: Option<&str>,
103    ) -> Result<String> {
104        // Get access token
105        let access_token = {
106            let mut auth_client = self.auth_client.lock().await;
107            auth_client.get_access_token().await?
108        };
109
110        let full_url = if query.is_none() {
111            format!("{}{}", self.get_base_url(), endpoint.get_path())
112        } else {
113            let query_str = serde_urlencoded::to_string(&query)?;
114            format!(
115                "{}{}?{}",
116                self.get_base_url(),
117                endpoint.get_path(),
118                query_str
119            )
120        };
121
122        log::debug!("Making {} request to: {}", endpoint.method, full_url);
123
124        // Create initial headers
125        let mut headers = HeaderMap::new();
126        headers.insert("Content-Type", "application/json; charset=utf-8".parse()?);
127        headers.insert("host", "sellingpartnerapi-na.amazon.com".parse()?);
128        headers.insert("x-amz-access-token", access_token.parse()?);
129        headers.insert(
130            "x-amz-date",
131            // 时间格式 YYYYMMDDTHHMMSSZ
132            chrono::Utc::now()
133                .format("%Y%m%dT%H%M%SZ")
134                .to_string()
135                .parse()?,
136        );
137        headers.insert("user-agent", Self::get_user_agent().parse()?);
138        if let Some(custom_headers) = header {
139            for (key, value) in custom_headers {
140                headers.insert(key, value.parse()?);
141            }
142        }
143
144        // Build the request
145        let mut request_builder = match endpoint.method {
146            ApiMethod::Get => self.client.get(&full_url),
147            ApiMethod::Post => self.client.post(&full_url),
148            ApiMethod::Put => self.client.put(&full_url),
149            ApiMethod::Delete => self.client.delete(&full_url),
150            ApiMethod::Patch => self.client.patch(&full_url),
151        };
152
153        // Add headers
154        request_builder = request_builder.headers(headers);
155
156        // Add query parameters if provided
157        if let Some(query_params) = query {
158            request_builder = request_builder.query(&query_params);
159        }
160
161        // Add body if provided
162        if let Some(body_content) = body {
163            request_builder = request_builder.body(body_content.to_string());
164        }
165
166        // Apply rate limiting if enabled
167        let limiter = self
168            .rate_limiter
169            .wait(&endpoint.rate_limit_key(), endpoint.rate, endpoint.burst)
170            .await?;
171
172        let response = request_builder.send().await;
173
174        limiter.mark_response().await;
175
176        // // Record the response time for rate limiting
177        // self.rate_limiter
178        //     .record_response(&endpoint.rate_limit_key())
179        //     .await?;
180
181        let response = response?;
182        log::debug!("Response status: {}", response.status());
183
184        let response_status = response.status();
185        if response_status.is_success() {
186            let text = response.text().await?;
187            Ok(text)
188        } else {
189            let error_text = response.text().await?;
190            Err(anyhow::anyhow!(
191                "Request {} failed with status {}: {}",
192                endpoint.get_path(),
193                response_status,
194                error_text
195            ))
196        }
197    }
198
199    /// Upload content to the feed document URL (direct S3 upload)
200    pub async fn upload(&self, url: &str, content: &str, content_type: &str) -> Result<()> {
201        let response = self
202            .client
203            .put(url)
204            .header("Content-Type", content_type)
205            .body(content.to_string())
206            .send()
207            .await?;
208
209        if response.status().is_success() {
210            log::info!("Feed document content uploaded successfully");
211            Ok(())
212        } else {
213            let status = response.status();
214            let error_text = response.text().await?;
215            Err(anyhow::anyhow!(
216                "Failed to upload feed document content: {} - Response: {}",
217                status,
218                error_text
219            ))
220        }
221    }
222
223    /// Download content from a feed document URL
224    pub async fn download(&self, url: &str) -> Result<String> {
225        let response = self.get_http_client().get(url).send().await?;
226
227        if response.status().is_success() {
228            let content = response.text().await?;
229            log::info!("Feed document content downloaded successfully");
230            Ok(content)
231        } else {
232            let status = response.status();
233            let error_text = response.text().await?;
234            Err(anyhow::anyhow!(
235                "Failed to download feed document content: {} - Response: {}",
236                status,
237                error_text
238            ))
239        }
240    }
241
242    /// Check if rate limiting is enabled and get token status
243    pub async fn get_rate_limit_status(&self) -> Result<HashMap<String, (f64, f64, u32)>> {
244        Ok(self.rate_limiter.get_token_status().await?)
245    }
246
247    /// Check if a token is available for a specific endpoint without consuming it
248    pub async fn check_rate_limit_availability(&self, endpoint_id: &String) -> Result<bool> {
249        Ok(self
250            .rate_limiter
251            .check_token_availability(endpoint_id)
252            .await?)
253    }
254
255    /// Refresh the access token if needed
256    pub async fn refresh_access_token_if_needed(&self) -> Result<()> {
257        let mut auth_client = self.auth_client.lock().await;
258        if !auth_client.is_token_valid() {
259            auth_client.refresh_access_token().await?;
260        }
261        Ok(())
262    }
263
264    /// Force refresh the access token
265    pub async fn force_refresh_token(&self) -> Result<()> {
266        let mut auth_client = self.auth_client.lock().await;
267        auth_client.refresh_access_token().await?;
268        Ok(())
269    }
270
271    /// Get access to the underlying HTTP client for direct requests
272    pub fn get_http_client(&self) -> &Client {
273        &self.client
274    }
275
276    /// Create a new configuration for the generated APIs
277    /// This function refreshes the access token and sets up the configuration
278    pub async fn create_configuration(&self) -> Result<Configuration> {
279        let mut headers = reqwest::header::HeaderMap::new();
280        headers.insert("Content-Type", "application/json; charset=utf-8".parse()?);
281        headers.insert("host", "sellingpartnerapi-na.amazon.com".parse()?);
282        headers.insert(
283            "x-amz-access-token",
284            self.get_access_token().await?.parse()?,
285        );
286        headers.insert(
287            "x-amz-date",
288            chrono::Utc::now()
289                .format("%Y%m%dT%H%M%SZ")
290                .to_string()
291                .parse()?,
292        );
293        headers.insert(
294            "user-agent",
295            self.config
296                .user_agent
297                .clone()
298                .unwrap_or_else(|| Self::get_user_agent())
299                .parse()?,
300        );
301
302        let http_client = reqwest::Client::builder()
303            .timeout(std::time::Duration::from_secs(
304                self.config.timeout_sec.unwrap_or(30),
305            ))
306            .default_headers(headers)
307            .build()?;
308
309        let configuration = Configuration {
310            base_path: self.get_base_url(),
311            client: http_client,
312            ..Default::default()
313        };
314        Ok(configuration)
315    }
316
317    pub fn from_json<'a, T>(s: &'a str) -> Result<T>
318    where
319        T: Deserialize<'a>,
320    {
321        serde_json::from_str(s).map_err(|e| anyhow::anyhow!("Failed to parse JSON: {}: {}", e, s))
322    }
323}