Skip to main content

dingtalk_stream/transport/
http.rs

//! HTTP client wrapper (built on `reqwest`) for DingTalk OpenAPI calls, file downloads and media uploads.
2
3use crate::error::{Error, Result};
4use reqwest::header::{HeaderMap, HeaderValue};
5use url::form_urlencoded;
6
/// Default TCP connect timeout used by `HttpClient::new`, in seconds.
const DEFAULT_CONNECT_TIMEOUT_SECS: u64 = 10;
/// Default whole-request timeout used by `HttpClient::new`, in seconds.
const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 30;

/// Base URL of the DingTalk OpenAPI, used unless the
/// `DINGTALK_OPENAPI_ENDPOINT` environment variable overrides it.
const DEFAULT_OPENAPI_ENDPOINT: &str = "https://api.dingtalk.com";
/// Base URL that hosts the `/media/upload` file-upload API.
const UPLOAD_ENDPOINT: &str = "https://oapi.dingtalk.com";
15
16/// HTTP 客户端
17#[derive(Clone)]
18pub struct HttpClient {
19    client: reqwest::Client,
20    openapi_endpoint: String,
21}
22
23impl HttpClient {
24    /// 创建新的 HTTP 客户端(使用默认超时:连接 10s,请求 30s)
25    pub fn new() -> Self {
26        Self::with_timeout(DEFAULT_CONNECT_TIMEOUT_SECS, DEFAULT_REQUEST_TIMEOUT_SECS)
27    }
28
29    /// 创建带自定义超时的 HTTP 客户端
30    pub fn with_timeout(connect_timeout_secs: u64, request_timeout_secs: u64) -> Self {
31        let openapi_endpoint = std::env::var("DINGTALK_OPENAPI_ENDPOINT")
32            .unwrap_or_else(|_| DEFAULT_OPENAPI_ENDPOINT.to_owned());
33        let client = reqwest::Client::builder()
34            .connect_timeout(std::time::Duration::from_secs(connect_timeout_secs))
35            .timeout(std::time::Duration::from_secs(request_timeout_secs))
36            .pool_idle_timeout(std::time::Duration::from_secs(90))
37            .tcp_keepalive(std::time::Duration::from_secs(60))
38            .build()
39            .expect("Failed to build HTTP client");
40        Self {
41            client,
42            openapi_endpoint,
43        }
44    }
45
46    /// 获取 OpenAPI 端点
47    pub fn openapi_endpoint(&self) -> &str {
48        &self.openapi_endpoint
49    }
50
51    /// 获取 User-Agent 字符串
52    fn user_agent() -> String {
53        format!(
54            "DingTalkStream/1.0 SDK/{} Rust/{}",
55            env!("CARGO_PKG_VERSION"),
56            rustc_version()
57        )
58    }
59
60    /// 构建带 access_token 的请求头
61    fn build_headers(access_token: Option<&str>) -> HeaderMap {
62        let mut headers = HeaderMap::new();
63        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
64        headers.insert("Accept", HeaderValue::from_static("*/*"));
65        headers.insert(
66            "User-Agent",
67            HeaderValue::from_str(&Self::user_agent())
68                .unwrap_or_else(|_| HeaderValue::from_static("DingTalkStream/1.0")),
69        );
70        if let Some(token) = access_token {
71            if let Ok(val) = HeaderValue::from_str(token) {
72                headers.insert("x-acs-dingtalk-access-token", val);
73            }
74        }
75        headers
76    }
77
78    /// POST JSON 请求,返回解析后的 JSON
79    pub async fn post_json(
80        &self,
81        url: &str,
82        body: &serde_json::Value,
83        access_token: Option<&str>,
84    ) -> Result<serde_json::Value> {
85        let resp = self
86            .client
87            .post(url)
88            .headers(Self::build_headers(access_token))
89            .json(body)
90            .send()
91            .await?;
92
93        let status = resp.status();
94        let text = resp.text().await?;
95
96        if !status.is_success() {
97            return Err(Error::Connection(format!(
98                "POST {} failed with status {}: {}",
99                url, status, text
100            )));
101        }
102
103        serde_json::from_str(&text).map_err(Error::Json)
104    }
105
106    /// PUT JSON 请求,返回解析后的 JSON
107    pub async fn put_json(
108        &self,
109        url: &str,
110        body: &serde_json::Value,
111        access_token: Option<&str>,
112    ) -> Result<serde_json::Value> {
113        let resp = self
114            .client
115            .put(url)
116            .headers(Self::build_headers(access_token))
117            .json(body)
118            .send()
119            .await?;
120
121        let status = resp.status();
122        let text = resp.text().await?;
123
124        if !status.is_success() {
125            return Err(Error::Connection(format!(
126                "PUT {} failed with status {}: {}",
127                url, status, text
128            )));
129        }
130
131        serde_json::from_str(&text).or_else(|_| Ok(serde_json::Value::Null))
132    }
133
134    /// POST JSON 请求,返回原始 (status_code, body_text)
135    pub async fn post_json_raw(
136        &self,
137        url: &str,
138        body: &serde_json::Value,
139    ) -> Result<(u16, String)> {
140        let resp = self
141            .client
142            .post(url)
143            .header("Content-Type", "application/json")
144            .json(body)
145            .send()
146            .await?;
147
148        let status = resp.status().as_u16();
149        let text = resp.text().await?;
150        Ok((status, text))
151    }
152
153    /// GET 请求,返回字节内容
154    pub async fn get_bytes(&self, url: &str) -> Result<Vec<u8>> {
155        let resp = self
156            .client
157            .get(url)
158            .timeout(std::time::Duration::from_secs(300))
159            .send()
160            .await?;
161        resp.error_for_status_ref()
162            .map_err(|e| Error::Http(e.without_url()))?;
163        let bytes = resp.bytes().await?;
164        Ok(bytes.to_vec())
165    }
166
167    /// GET 请求,返回字节内容(带大小限制)
168    ///
169    /// 先检查 `Content-Length` 响应头,超过 `max_size` 则直接拒绝;
170    /// 下载过程中累计检查已读字节数,超限则中止。
171    pub async fn get_bytes_with_limit(&self, url: &str, max_size: u64) -> Result<Vec<u8>> {
172        use futures_util::StreamExt;
173
174        let resp = self
175            .client
176            .get(url)
177            .timeout(std::time::Duration::from_secs(300))
178            .send()
179            .await?;
180        resp.error_for_status_ref()
181            .map_err(|e| Error::Http(e.without_url()))?;
182
183        if let Some(len) = resp.content_length() {
184            if len > max_size {
185                return Err(Error::Handler(format!(
186                    "file too large: {len} bytes (limit: {max_size})"
187                )));
188            }
189        }
190
191        let mut stream = resp.bytes_stream();
192        let mut buf = Vec::new();
193        while let Some(chunk) = stream.next().await {
194            let chunk = chunk?;
195            buf.extend_from_slice(&chunk);
196            if buf.len() as u64 > max_size {
197                return Err(Error::Handler(format!(
198                    "download exceeded limit: {} bytes (limit: {max_size})",
199                    buf.len()
200                )));
201            }
202        }
203        Ok(buf)
204    }
205
206    /// 上传文件到钉钉
207    pub async fn upload_file(
208        &self,
209        access_token: &str,
210        content: &[u8],
211        filetype: &str,
212        filename: &str,
213        mimetype: &str,
214    ) -> Result<String> {
215        let encoded_token: String = form_urlencoded::Serializer::new(String::new())
216            .append_pair("access_token", access_token)
217            .finish();
218        let url = format!("{}/media/upload?{}", UPLOAD_ENDPOINT, encoded_token);
219
220        let part = reqwest::multipart::Part::bytes(content.to_vec())
221            .file_name(filename.to_owned())
222            .mime_str(mimetype)
223            .map_err(|e| Error::Handler(format!("invalid mime type: {e}")))?;
224
225        let form = reqwest::multipart::Form::new()
226            .text("type", filetype.to_owned())
227            .part("media", part);
228
229        let resp = self
230            .client
231            .post(&url)
232            .timeout(std::time::Duration::from_secs(300))
233            .multipart(form)
234            .send()
235            .await?;
236
237        let status = resp.status();
238        let text = resp.text().await?;
239
240        if status.as_u16() == 401 {
241            return Err(Error::Auth("upload returned 401".to_owned()));
242        }
243
244        if !status.is_success() {
245            return Err(Error::Connection(format!(
246                "upload failed with status {}: {}",
247                status, text
248            )));
249        }
250
251        let json: serde_json::Value = serde_json::from_str(&text)?;
252        json.get("media_id")
253            .and_then(|v| v.as_str())
254            .map(String::from)
255            .ok_or_else(|| Error::Handler(format!("upload failed, response: {text}")))
256    }
257
258    /// 发送原始 POST 请求(用于 open_connection)
259    pub async fn post_raw(&self, url: &str, body: &serde_json::Value) -> Result<serde_json::Value> {
260        let headers = Self::build_headers(None);
261        let resp = self
262            .client
263            .post(url)
264            .headers(headers)
265            .json(body)
266            .send()
267            .await?;
268
269        let status = resp.status();
270        let text = resp.text().await?;
271
272        if !status.is_success() {
273            return Err(Error::Connection(format!(
274                "POST {} failed with status {}: {}",
275                url, status, text
276            )));
277        }
278
279        serde_json::from_str(&text).map_err(Error::Json)
280    }
281}
282
283impl Default for HttpClient {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
/// Returns the `rust-version` (minimum supported Rust version) declared in
/// this crate's `Cargo.toml`, captured at compile time.
///
/// This is *not* the version of the compiler that built the crate. Returns
/// `"unknown"` when the manifest declares no `rust-version`, or when the
/// crate is compiled outside Cargo and the variable is absent.
fn rustc_version() -> &'static str {
    // `option_env!` instead of `env!` so the build does not fail when Cargo
    // does not provide the variable.
    match option_env!("CARGO_PKG_RUST_VERSION") {
        Some(version) if !version.is_empty() => version,
        _ => "unknown",
    }
}