1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
//! # wecom-agent
//!
//! `wecom-agent`封装了企业微信API的消息发送功能。
//!
//! ## 使用方法
//! ```rust
//! use wecom_agent::{
//!     message::{MessageBuilder, Text},
//!     MsgSendResponse, WecomAgent,
//! };
//! async fn example() {
//!     let content = Text::new("Hello from Wandering AI!".to_string());
//!     let msg = MessageBuilder::default()
//!         .to_users(vec!["robin", "tom"])
//!         .from_agent(42)
//!         .build(content)
//!         .expect("Message should be built");
//!     let handle = tokio::spawn(async move {
//!         let wecom_agent = WecomAgent::new("your_corpid", "your_secret");
//!         let response = wecom_agent.send(msg).await;
//!     });
//! }
//! ```

mod error;
pub mod message;

use log::{debug, info, warn};
use serde::{Deserialize, Serialize};
use std::error::Error as StdError;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;

/// Cached WeCom access token plus the bookkeeping needed to decide when to refresh it.
#[derive(Debug)]
struct AccessToken {
    /// The token string; `None` until the first successful update.
    value: Option<String>,
    /// When `value` was last stored (`UNIX_EPOCH` for a default-constructed token).
    timestamp: SystemTime,
    /// How long the token remains valid after `timestamp`.
    lifetime: Duration,
}

impl AccessToken {
    /// Returns the token string, or `None` if no token has been stored yet.
    pub fn value(&self) -> Option<&String> {
        self.value.as_ref()
    }

    /// Replaces the stored token and records when it was obtained and how long it lives.
    pub fn update(&mut self, token: &str, timestamp: SystemTime, lifetime: Duration) {
        self.value = Some(token.to_owned());
        self.timestamp = timestamp;
        self.lifetime = lifetime;
    }

    /// Returns `true` once at least `lifetime` has elapsed since `timestamp`.
    ///
    /// A `timestamp` in the future (e.g. the system clock moved backwards) counts as not expired.
    pub fn expired(&self) -> bool {
        match SystemTime::now().duration_since(self.timestamp) {
            Ok(elapsed) => elapsed >= self.lifetime,
            Err(_) => false,
        }
    }

    /// Returns `true` if the token is still valid but expires within the next `n` seconds.
    ///
    /// An already-expired token returns `false`; combine with [`AccessToken::expired`] when
    /// needed. A `timestamp` in the future also returns `false`.
    pub fn expire_in(&self, n: u64) -> bool {
        match SystemTime::now().duration_since(self.timestamp) {
            // The remaining lifetime is `lifetime - elapsed`. The guard excludes expired tokens
            // and keeps the `Duration` subtraction from underflowing, which would panic.
            Ok(elapsed) if elapsed < self.lifetime => {
                self.lifetime - elapsed < Duration::from_secs(n)
            }
            _ => false,
        }
    }

    /// Returns the moment the token was last updated.
    pub fn timestamp(&self) -> SystemTime {
        self.timestamp
    }
}

impl Default for AccessToken {
    /// An empty token stamped at `UNIX_EPOCH`, so it reads as expired. The 7200 s lifetime is a
    /// placeholder until an update stores the server-provided lifetime.
    fn default() -> Self {
        Self {
            value: None,
            timestamp: UNIX_EPOCH,
            lifetime: Duration::from_secs(7200),
        }
    }
}

/// Lightweight wrapper around the WeCom (企业微信) application-message API.
///
/// The access token is cached behind a [`RwLock`] and refreshed on demand by
/// [`WecomAgent::send`].
#[derive(Debug)]
pub struct WecomAgent {
    corp_id: String,
    secret: String,
    access_token: RwLock<AccessToken>,
    client: reqwest::Client,
}

impl WecomAgent {
    /// `errcode` values meaning the server no longer accepts the access token used for a
    /// request: 40014 (invalid access_token) and 42001 (access_token expired).
    const TOKEN_REJECTED_CODES: [i64; 2] = [40014, 42001];

    /// Cool-down, in seconds, for the token refreshes that `send` triggers by itself.
    const SEND_REFRESH_BACKOFF_SECS: u64 = 10;

    /// Creates an agent. No access token is fetched here; it is fetched lazily by
    /// [`WecomAgent::send`] or explicitly via [`WecomAgent::update_token`].
    pub fn new(corp_id: &str, secret: &str) -> Self {
        Self {
            corp_id: String::from(corp_id),
            secret: String::from(secret),
            access_token: RwLock::new(AccessToken::default()),
            client: reqwest::Client::new(),
        }
    }

    /// Fetches a new access token and stores it.
    ///
    /// WeCom applies risk control to high-frequency API calls, so `backoff_seconds` acts as a
    /// cool-down: if the token was last updated fewer than `backoff_seconds` seconds ago, no
    /// request is sent.
    ///
    /// # Errors
    ///
    /// - an error with code `-9` when called inside the cool-down window;
    /// - a system-time error if the stored update timestamp lies in the future;
    /// - any transport or JSON-decoding error from `reqwest`;
    /// - the server's `errcode`/`errmsg` when `gettoken` reports a non-zero `errcode`.
    pub async fn update_token(
        &self,
        backoff_seconds: u64,
    ) -> Result<(), Box<dyn StdError + Send + Sync>> {
        let mut access_token = self.access_token.write().await;
        self.fetch_token_into(&mut access_token, backoff_seconds).await
    }

    /// Sends an application message, refreshing the access token when needed.
    ///
    /// Before sending, the token is refreshed if it is missing, expired, or expires within 300
    /// seconds. If the server then rejects the token (errcode 40014 or 42001), the token is
    /// refreshed once and the message is re-sent with the *new* token.
    ///
    /// # Errors
    ///
    /// Returns any error from refreshing the token (see [`WecomAgent::update_token`]) or from
    /// the HTTP request / JSON decoding. Other non-zero `errcode`s are returned as `Ok`; check
    /// them with [`MsgSendResponse::is_error`].
    pub async fn send<T>(&self, msg: T) -> Result<MsgSendResponse, Box<dyn StdError + Send + Sync>>
    where
        T: Serialize,
    {
        let token = self.usable_token().await?;

        debug!("Sending [try 1]...");
        let mut response = self
            .message_request(&token, &msg)
            .send()
            .await?
            .json::<MsgSendResponse>()
            .await?;

        // The server may revoke a token before its nominal expiry.
        if Self::TOKEN_REJECTED_CODES.contains(&response.error_code()) {
            let token = self.replace_rejected_token(&token).await?;

            debug!("Sending [try 2]...");
            response = self
                .message_request(&token, &msg)
                .send()
                .await?
                .json::<MsgSendResponse>()
                .await?;
        }

        debug!("Sending [Done]");
        Ok(response)
    }

    /// Returns `true` when `token` is missing, already expired, or expires within 300 seconds.
    fn token_needs_refresh(token: &AccessToken) -> bool {
        token.value().is_none() || token.expire_in(300) || token.expired()
    }

    /// Returns a copy of a token that is usable right now, refreshing it first if needed.
    async fn usable_token(&self) -> Result<String, Box<dyn StdError + Send + Sync>> {
        // Fast path: only the shared read lock. The guard is dropped at the end of this scope,
        // before the write lock is requested below.
        {
            let access_token = self.access_token.read().await;
            if !Self::token_needs_refresh(&access_token) {
                return Ok(access_token
                    .value()
                    .cloned()
                    .expect("Access token should not be None."));
            }
        }

        let mut access_token = self.access_token.write().await;
        // Re-check under the write lock. A concurrent `send` may have refreshed the token while
        // we waited; refreshing again would only trip the cool-down error.
        if Self::token_needs_refresh(&access_token) {
            warn!("Token invalid. Updating...");
            self.fetch_token_into(&mut access_token, Self::SEND_REFRESH_BACKOFF_SECS)
                .await?;
            info!("Token updated");
        }
        Ok(access_token
            .value()
            .cloned()
            .expect("Access token should not be None."))
    }

    /// Refreshes the token after the server rejected `rejected`, unless a concurrent caller has
    /// already replaced it. Returns the token to retry with.
    async fn replace_rejected_token(
        &self,
        rejected: &str,
    ) -> Result<String, Box<dyn StdError + Send + Sync>> {
        let mut access_token = self.access_token.write().await;
        if access_token.value().map(String::as_str) == Some(rejected) {
            warn!("Token invalid. Updating...");
            self.fetch_token_into(&mut access_token, Self::SEND_REFRESH_BACKOFF_SECS)
                .await?;
        }
        Ok(access_token
            .value()
            .cloned()
            .expect("Access token should not be None."))
    }

    /// Builds the `message/send` request for `msg`, authenticated with `token`.
    fn message_request<T>(&self, token: &str, msg: &T) -> reqwest::RequestBuilder
    where
        T: Serialize,
    {
        self.client
            .post("https://qyapi.weixin.qq.com/cgi-bin/message/send")
            .query(&[("access_token", token)])
            .json(msg)
    }

    /// Refreshes an already write-locked token.
    ///
    /// This enforces the `backoff_seconds` cool-down, calls `gettoken` and stores the result.
    async fn fetch_token_into(
        &self,
        access_token: &mut AccessToken,
        backoff_seconds: u64,
    ) -> Result<(), Box<dyn StdError + Send + Sync>> {
        // WeCom throttles high-frequency calls, so refuse refreshes inside the cool-down.
        let seconds_since_last_update = SystemTime::now()
            .duration_since(access_token.timestamp())?
            .as_secs();
        if seconds_since_last_update < backoff_seconds {
            return Err(Box::new(error::Error::new(
                -9,
                format!("Access token更新过于频繁。上次更新于{seconds_since_last_update}秒前。"),
            )));
        }

        // Reuse the shared client. `query` percent-encodes the credentials.
        let response = self
            .client
            .get("https://qyapi.weixin.qq.com/cgi-bin/gettoken")
            .query(&[
                ("corpid", self.corp_id.as_str()),
                ("corpsecret", self.secret.as_str()),
            ])
            .send()
            .await?
            .json::<AccessTokenResponse>()
            .await?;
        if response.errcode != 0 {
            return Err(Box::new(error::Error::new(
                response.errcode,
                response.errmsg,
            )));
        }

        access_token.update(
            &response.access_token,
            SystemTime::now(),
            Duration::from_secs(response.expires_in),
        );
        Ok(())
    }
}

// 应用消息发送结果
#[derive(Deserialize)]
pub struct MsgSendResponse {
    errcode: i64,
    errmsg: String,
    invaliduser: Option<String>,
    invalidparty: Option<String>,
    invalidtag: Option<String>,
    unlicenseduser: Option<String>,
    msgid: String,
    response_code: Option<String>,
}

impl MsgSendResponse {
    pub fn is_error(&self) -> bool {
        self.errcode != 0
    }

    pub fn error_code(&self) -> i64 {
        self.errcode
    }

    pub fn error_msg(&self) -> &str {
        &self.errmsg
    }
}

// 获取Access Token时的返回结果
// 示例
// {
//     "errcode": 0,
//     "errmsg": "ok",
//     "access_token": "accesstoken000001",
//     "expires_in": 7200
// }
#[derive(Deserialize)]
struct AccessTokenResponse {
    errcode: i64,
    errmsg: String,
    access_token: String,
    expires_in: u64,
}