Skip to main content

dingtalk_stream/transport/
http.rs

1//! HTTP 客户端封装
2
3use crate::error::{Error, Result};
4use reqwest::header::{HeaderMap, HeaderValue};
5use url::form_urlencoded;
6
/// Base URL of the OpenAPI, used when no endpoint override is configured.
const DEFAULT_OPENAPI_ENDPOINT: &str = "https://api.dingtalk.com";
/// Base URL that file uploads are sent to.
const UPLOAD_ENDPOINT: &str = "https://oapi.dingtalk.com";
11
12/// HTTP 客户端
13#[derive(Clone)]
14pub struct HttpClient {
15    client: reqwest::Client,
16    openapi_endpoint: String,
17}
18
19impl HttpClient {
20    /// 创建新的 HTTP 客户端
21    pub fn new() -> Self {
22        let openapi_endpoint = std::env::var("DINGTALK_OPENAPI_ENDPOINT")
23            .unwrap_or_else(|_| DEFAULT_OPENAPI_ENDPOINT.to_owned());
24
25        Self {
26            client: reqwest::Client::new(),
27            openapi_endpoint,
28        }
29    }
30
31    /// 获取 OpenAPI 端点
32    pub fn openapi_endpoint(&self) -> &str {
33        &self.openapi_endpoint
34    }
35
36    /// 获取 User-Agent 字符串
37    fn user_agent() -> String {
38        format!(
39            "DingTalkStream/1.0 SDK/{} Rust/{}",
40            env!("CARGO_PKG_VERSION"),
41            rustc_version()
42        )
43    }
44
45    /// 构建带 access_token 的请求头
46    fn build_headers(access_token: Option<&str>) -> HeaderMap {
47        let mut headers = HeaderMap::new();
48        headers.insert("Content-Type", HeaderValue::from_static("application/json"));
49        headers.insert("Accept", HeaderValue::from_static("*/*"));
50        headers.insert(
51            "User-Agent",
52            HeaderValue::from_str(&Self::user_agent())
53                .unwrap_or_else(|_| HeaderValue::from_static("DingTalkStream/1.0")),
54        );
55        if let Some(token) = access_token {
56            if let Ok(val) = HeaderValue::from_str(token) {
57                headers.insert("x-acs-dingtalk-access-token", val);
58            }
59        }
60        headers
61    }
62
63    /// POST JSON 请求,返回解析后的 JSON
64    pub async fn post_json(
65        &self,
66        url: &str,
67        body: &serde_json::Value,
68        access_token: Option<&str>,
69    ) -> Result<serde_json::Value> {
70        let resp = self
71            .client
72            .post(url)
73            .headers(Self::build_headers(access_token))
74            .json(body)
75            .send()
76            .await?;
77
78        let status = resp.status();
79        let text = resp.text().await?;
80
81        if !status.is_success() {
82            return Err(Error::Connection(format!(
83                "POST {} failed with status {}: {}",
84                url, status, text
85            )));
86        }
87
88        serde_json::from_str(&text).map_err(Error::Json)
89    }
90
91    /// PUT JSON 请求,返回解析后的 JSON
92    pub async fn put_json(
93        &self,
94        url: &str,
95        body: &serde_json::Value,
96        access_token: Option<&str>,
97    ) -> Result<serde_json::Value> {
98        let resp = self
99            .client
100            .put(url)
101            .headers(Self::build_headers(access_token))
102            .json(body)
103            .send()
104            .await?;
105
106        let status = resp.status();
107        let text = resp.text().await?;
108
109        if !status.is_success() {
110            return Err(Error::Connection(format!(
111                "PUT {} failed with status {}: {}",
112                url, status, text
113            )));
114        }
115
116        serde_json::from_str(&text).or_else(|_| Ok(serde_json::Value::Null))
117    }
118
119    /// POST JSON 请求,返回原始 (status_code, body_text)
120    pub async fn post_json_raw(
121        &self,
122        url: &str,
123        body: &serde_json::Value,
124    ) -> Result<(u16, String)> {
125        let resp = self
126            .client
127            .post(url)
128            .header("Content-Type", "application/json")
129            .json(body)
130            .send()
131            .await?;
132
133        let status = resp.status().as_u16();
134        let text = resp.text().await?;
135        Ok((status, text))
136    }
137
138    /// GET 请求,返回字节内容
139    pub async fn get_bytes(&self, url: &str) -> Result<Vec<u8>> {
140        let resp = self.client.get(url).send().await?;
141        resp.error_for_status_ref()
142            .map_err(|e| Error::Http(e.without_url()))?;
143        let bytes = resp.bytes().await?;
144        Ok(bytes.to_vec())
145    }
146
147    /// GET 请求,返回字节内容(带大小限制)
148    ///
149    /// 先检查 `Content-Length` 响应头,超过 `max_size` 则直接拒绝;
150    /// 下载过程中累计检查已读字节数,超限则中止。
151    pub async fn get_bytes_with_limit(&self, url: &str, max_size: u64) -> Result<Vec<u8>> {
152        use futures_util::StreamExt;
153
154        let resp = self.client.get(url).send().await?;
155        resp.error_for_status_ref()
156            .map_err(|e| Error::Http(e.without_url()))?;
157
158        if let Some(len) = resp.content_length() {
159            if len > max_size {
160                return Err(Error::Handler(format!(
161                    "file too large: {len} bytes (limit: {max_size})"
162                )));
163            }
164        }
165
166        let mut stream = resp.bytes_stream();
167        let mut buf = Vec::new();
168        while let Some(chunk) = stream.next().await {
169            let chunk = chunk?;
170            buf.extend_from_slice(&chunk);
171            if buf.len() as u64 > max_size {
172                return Err(Error::Handler(format!(
173                    "download exceeded limit: {} bytes (limit: {max_size})",
174                    buf.len()
175                )));
176            }
177        }
178        Ok(buf)
179    }
180
181    /// 上传文件到钉钉
182    pub async fn upload_file(
183        &self,
184        access_token: &str,
185        content: &[u8],
186        filetype: &str,
187        filename: &str,
188        mimetype: &str,
189    ) -> Result<String> {
190        let encoded_token: String = form_urlencoded::Serializer::new(String::new())
191            .append_pair("access_token", access_token)
192            .finish();
193        let url = format!("{}/media/upload?{}", UPLOAD_ENDPOINT, encoded_token);
194
195        let part = reqwest::multipart::Part::bytes(content.to_vec())
196            .file_name(filename.to_owned())
197            .mime_str(mimetype)
198            .map_err(|e| Error::Handler(format!("invalid mime type: {e}")))?;
199
200        let form = reqwest::multipart::Form::new()
201            .text("type", filetype.to_owned())
202            .part("media", part);
203
204        let resp = self.client.post(&url).multipart(form).send().await?;
205
206        let status = resp.status();
207        let text = resp.text().await?;
208
209        if status.as_u16() == 401 {
210            return Err(Error::Auth("upload returned 401".to_owned()));
211        }
212
213        if !status.is_success() {
214            return Err(Error::Connection(format!(
215                "upload failed with status {}: {}",
216                status, text
217            )));
218        }
219
220        let json: serde_json::Value = serde_json::from_str(&text)?;
221        json.get("media_id")
222            .and_then(|v| v.as_str())
223            .map(String::from)
224            .ok_or_else(|| Error::Handler(format!("upload failed, response: {text}")))
225    }
226
227    /// 发送原始 POST 请求(用于 open_connection)
228    pub async fn post_raw(&self, url: &str, body: &serde_json::Value) -> Result<serde_json::Value> {
229        let headers = Self::build_headers(None);
230        let resp = self
231            .client
232            .post(url)
233            .headers(headers)
234            .json(body)
235            .send()
236            .await?;
237
238        let status = resp.status();
239        let text = resp.text().await?;
240
241        if !status.is_success() {
242            return Err(Error::Connection(format!(
243                "POST {} failed with status {}: {}",
244                url, status, text
245            )));
246        }
247
248        serde_json::from_str(&text).map_err(Error::Json)
249    }
250}
251
252impl Default for HttpClient {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
/// Returns the Rust version string embedded in the SDK `User-Agent`.
///
/// The value is `CARGO_PKG_RUST_VERSION` as captured at compile time, i.e. the
/// `rust-version` declared in the package manifest. It is not the version of
/// the compiler that built the crate.
///
/// Returns `"unknown"` when the variable is empty (no `rust-version` declared)
/// or was not defined at compile time (a build that does not go through
/// Cargo), so the `User-Agent` never ends in a bare `Rust/`.
fn rustc_version() -> &'static str {
    match option_env!("CARGO_PKG_RUST_VERSION") {
        Some(version) if !version.is_empty() => version,
        _ => "unknown",
    }
}