1use std::sync::Arc;
2
3use derive_builder::Builder;
4use serde::{Deserialize, Serialize};
5use serde_repr::Serialize_repr;
6use tokio::sync::RwLock;
7
8pub struct Client {
9 agentid: i64,
10 cli: reqwest::Client,
11 token_url: String,
12 token: Arc<RwLock<Option<TokenResponse>>>,
13}
14
15#[derive(Deserialize, Debug)]
16pub struct Response {
17 pub errcode: i32,
18 pub errmsg: String,
19 pub invaliduser: Option<String>,
20 pub invalidparty: Option<String>,
21 pub invalidtag: Option<String>,
22 pub msgid: String,
23 pub response_code: Option<String>,
24}
25
26impl Client {
27 const MAX_TOKEN: i32 = 1000;
28
29 pub async fn new(
30 client_id: &str,
31 client_secret: &str,
32 agentid: i64,
33 ) -> Result<Client, super::Error> {
34 let cli = reqwest::Client::builder()
35 .build()
36 .map_err(|e| super::InnerError::Http(e.to_string()))?;
37
38 let token_url = format!(
39 "https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={}&corpsecret={}",
40 client_id, client_secret
41 );
42 let client = Client {
43 cli,
44 agentid,
45 token_url,
46 token: Default::default(),
47 };
48
49 client.request_token().await?;
50
51 Ok(client)
52 }
53
54 async fn request_token(&self) -> Result<TokenResponse, super::Error> {
55 let token = self
56 .cli
57 .get(&self.token_url)
58 .send()
59 .await
60 .map_err(|e| super::RetryError::Auth(e.to_string()))?
61 .json::<TokenResponse>()
62 .await
63 .map_err(|e| super::RetryError::Auth(e.to_string()))?;
64 if token.errcode != 0 {
65 return Err(super::RetryError::Auth(token.errmsg).into());
66 }
67 self.set_token(token.clone()).await;
68 Ok(token)
69 }
70
71 async fn set_token(&self, mut token: TokenResponse) {
72 let expires_in = chrono::Utc::now().timestamp() as i64 + token.expires_in;
73 token.expires_in = expires_in;
74 *(self.token.write().await) = Some(token);
75 }
76
77 fn valid_token(&self, token: &TokenResponse) -> bool {
81 if token.expires_in >= chrono::Utc::now().timestamp() as i64 {
82 return false;
83 }
84 true
85 }
86
87 pub fn agentid(&self) -> i64 {
92 self.agentid
93 }
94}
95
96#[async_trait::async_trait]
97impl<'b> super::Pusher<'b, Message<'b>, Response> for Client {
98 async fn push(&self, msg: &'b Message) -> Result<Response, crate::Error> {
99 let token = self.token.clone();
100
101 let token = token.read().await;
102
103 let token = match token.clone() {
104 Some(token) => token.clone(),
105 None => match self.request_token().await {
106 Ok(token) => token,
107 Err(_e) => return Err(super::RetryError::Auth("".to_string()).into()),
109 },
110 };
111
112 if self.valid_token(&token) {
113 return Err(super::RetryError::Auth("token expired or invalid".to_string()).into());
114 }
115
116 let text = serde_json::to_string(msg).unwrap();
117
118 let text = text.replace("\"agentid\":0", &format!("\"agentid\":{}", self.agentid));
119
120 let resp = self
121 .cli
122 .post("https://qyapi.weixin.qq.com/cgi-bin/message/send")
123 .query(&[("access_token", &token.access_token)])
124 .header("Content-Type", "application/json;encode=utf-8")
125 .body(text)
126 .send()
127 .await;
128
129 match resp {
130 Ok(resp) => match resp.error_for_status() {
131 Ok(resp) => Ok(resp.json::<Response>().await.unwrap()),
132 Err(e) => Err(super::InnerError::Http(e.to_string()).into()),
133 },
134 Err(e) => Err(super::InnerError::Http(e.to_string()).into()),
135 }
136 }
137}
138
139#[derive(Debug, Deserialize, Clone)]
140struct TokenResponse {
141 errcode: i32,
142 errmsg: String,
143 access_token: String,
144 expires_in: i64,
145}
146
147#[derive(Debug, Serialize_repr, Clone)]
148#[repr(u8)]
149pub enum Bool {
150 False = 0,
151 True = 1,
152}
153
154impl From<Bool> for bool {
155 fn from(b: Bool) -> Self {
156 match b {
157 Bool::False => false,
158 Bool::True => true,
159 }
160 }
161}
162
163impl From<bool> for Bool {
164 fn from(b: bool) -> Self {
165 match b {
166 true => Self::True,
167 false => Self::False,
168 }
169 }
170}
171
172#[derive(Debug, Serialize, Builder)]
173pub struct Message<'a> {
174 #[serde(flatten)]
175 pub to: To<'a>,
176 pub msgtype: MsgType,
177 #[builder(default)]
178 pub agentid: i64,
179 #[builder(default)]
180 #[serde(skip_serializing_if = "Option::is_none")]
181 pub safe: Option<Bool>,
182 #[builder(default)]
183 #[serde(skip_serializing_if = "Option::is_none")]
184 pub enable_duplicate_check: Option<Bool>,
185 #[builder(default)]
186 #[serde(skip_serializing_if = "Option::is_none")]
187 pub duplicate_check_interval: Option<Bool>,
188 #[builder(default)]
189 #[serde(skip_serializing_if = "Option::is_none")]
190 pub enable_id_trans: Option<Bool>,
191 #[serde(flatten, skip_serializing_if = "Option::is_none")]
192 #[builder(setter(custom), default)]
193 pub inner: Option<InnerMesssage<'a>>,
194}
195
196impl<'a> MessageBuilder<'a> {
197 pub fn inner(&mut self, value: InnerMesssage<'a>) -> &mut Self {
198 match value {
199 InnerMesssage::Text { .. } => self.msgtype = Some(MsgType::Text),
200 InnerMesssage::Markdown { .. } => self.msgtype = Some(MsgType::Markdown),
201 InnerMesssage::Textcard { .. } => self.msgtype = Some(MsgType::Textcard),
202 }
203 self.inner = Some(Some(value));
204 self
205 }
206}
207
208#[derive(Debug, Serialize, Clone)]
209#[serde(rename_all = "snake_case")]
210pub enum To<'a> {
211 Touser(String),
212 Toparty(&'a str),
213 Totag(&'a str),
214}
215
216#[derive(Debug, Serialize, Clone, Default)]
217#[serde(rename_all = "lowercase")]
218pub enum MsgType {
219 #[default]
220 Text,
221 Markdown,
222 Textcard,
223}
224
225#[derive(Debug, Serialize, Clone)]
226#[serde(rename_all = "snake_case")]
227pub enum InnerMesssage<'a> {
228 Text {
229 content: &'a str,
230 },
231 Markdown {
232 content: &'a str,
233 },
234 Textcard {
235 title: &'a str,
236 description: &'a str,
237 url: &'a str,
238 btntxt: Option<&'a str>,
239 },
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Pusher;

    #[test]
    fn test_msg_builder() {
        let msg = MessageBuilder::default()
            .to(To::Totag("()"))
            .agentid(123)
            .inner(InnerMesssage::Markdown { content: "()" })
            .build()
            .expect("all required fields are set");

        let js = serde_json::to_value(&msg).expect("Message serializes to JSON");
        println!("{js}");
        // `to` and `inner` are flattened; `msgtype` is derived from `inner`.
        assert_eq!(js["totag"], "()");
        assert_eq!(js["msgtype"], "markdown");
        assert_eq!(js["agentid"], 123_i64);
        assert_eq!(js["markdown"]["content"], "()");

        // `agentid` has `#[builder(default)]`, so omitting it still builds and
        // leaves it at 0.
        let msg = MessageBuilder::default()
            .to(To::Totag("()"))
            .inner(InnerMesssage::Markdown { content: "()" })
            .build()
            .expect("agentid is optional");
        assert_eq!(msg.agentid, 0);

        // `msgtype` is required and only `inner` fills it in, so omitting the
        // payload must fail.
        let msg = MessageBuilder::default().to(To::Totag("()")).build();
        assert!(msg.is_err());
    }

    #[tokio::test]
    #[ignore = "requires WECOM_CLIENT_ID, WECOM_CLIENT_SECRET, WECOM_AGENT_ID and network access"]
    async fn test_client() {
        let client_id = std::env::var("WECOM_CLIENT_ID").expect("WECOM_CLIENT_ID must be set");
        let client_secret =
            std::env::var("WECOM_CLIENT_SECRET").expect("WECOM_CLIENT_SECRET must be set");
        let agent_id = std::env::var("WECOM_AGENT_ID")
            .expect("WECOM_AGENT_ID must be set")
            .parse::<i64>()
            .expect("WECOM_AGENT_ID must be an integer");

        let cli = Client::new(&client_id, &client_secret, agent_id)
            .await
            .unwrap();

        let msg = MessageBuilder::default()
            .to(To::Touser("nieaowei".to_string()))
            .agentid(cli.agentid())
            .inner(InnerMesssage::Text {
                content: "hello harmongy",
            })
            .build()
            .unwrap();
        let resp = cli.push(&msg).await;
        println!("{resp:?}");

        assert!(resp.is_ok());
    }
}