dingtalk_stream_sdk_rust/
lib.rs1use anyhow::{bail, Result};
38use async_broadcast::{Receiver, Sender};
39use chrono::{DateTime, Duration, Local};
40use down::{ClientDownStream, EventData, RobotRecvMessage};
41use futures::{stream::SplitStream, Future, StreamExt};
42use log::{debug, error, info, trace, warn};
43use native_tls::TlsConnector;
44use reqwest::{header::ACCEPT, ClientBuilder};
45use serde::{Deserialize, Serialize};
46use std::sync::{
47 atomic::{AtomicBool, Ordering},
48 Arc, Mutex, RwLock,
49};
50use tokio::{net::TcpStream, sync::Notify, time::sleep};
51use tokio_tungstenite::{
52 connect_async_tls_with_config,
53 tungstenite::{Error, Message},
54 Connector, MaybeTlsStream, WebSocketStream,
55};
56use up::{EventAckData, Sink};
57
58pub mod down;
59pub mod group;
60pub mod up;
61pub use dingtalk_stream_sdk_rust_macro::action_card;
62
63#[derive(Debug)]
67pub struct Client {
68 pub config: Arc<Mutex<ClientConfig>>,
70 client: reqwest::Client,
71 rx: Receiver<Arc<ClientDownStream>>,
72 tx: Sender<Arc<ClientDownStream>>,
73 on_event_callback: EventCallback,
74 sink: tokio::sync::Mutex<Option<Sink>>,
75 alive: AtomicBool,
76 user_exit: AtomicBool,
77 aborting: Arc<Notify>,
78}
79
80struct EventCallback(RwLock<Box<dyn Fn(EventData) -> EventAckData + Send + Sync>>);
81
82impl std::fmt::Debug for EventCallback {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 f.debug_tuple("EventCallback").finish()
85 }
86}
87
88impl Client {
89 pub fn new(
91 client_id: impl Into<String>,
92 client_secret: impl Into<String>,
93 ) -> Result<Arc<Self>> {
94 let client_id = client_id.into();
95 let client_secret = client_secret.into();
96 let (tx, rx) = async_broadcast::broadcast(32);
97 Ok(Arc::new(Self {
98 config: Arc::new(Mutex::new(ClientConfig {
99 client_id,
100 client_secret,
101 ..Default::default()
102 })),
103 client: ClientBuilder::new()
104 .no_proxy()
105 .danger_accept_invalid_certs(true)
106 .build()?,
107 tx,
108 rx,
109 sink: tokio::sync::Mutex::new(None),
110 on_event_callback: EventCallback(RwLock::new(Box::new(|p| {
111 info!("default event callback, event received: {:?}", p);
112 EventAckData::default()
113 }))),
114 alive: AtomicBool::new(false),
115 user_exit: AtomicBool::new(false),
116 aborting: Arc::new(Notify::new()),
117 }))
118 }
119
120 pub fn ua(self: Arc<Self>, value: impl Into<String>) -> Arc<Self> {
122 self.config.lock().unwrap().ua = value.into();
123 self
124 }
125
126 pub fn keep_alive(self: Arc<Self>, value: i64) -> Arc<Self> {
129 self.config.lock().unwrap().heartbeat_interval = value;
130 self
131 }
132
133 pub fn reconnect(self: Arc<Self>, value: i64) -> Arc<Self> {
136 self.config.lock().unwrap().reconnect_interval = value;
137 self
138 }
139
140 pub fn register_all_event_listener<P>(self: Arc<Self>, on_event_received: P) -> Arc<Self>
143 where
144 P: Fn(EventData) -> EventAckData + Send + Sync + 'static,
145 {
146 *self.on_event_callback.0.write().unwrap() = Box::new(on_event_received);
147 self
148 }
149
150 pub fn register_callback_listener<P, F>(
152 self: Arc<Self>,
153 event_id: impl AsRef<str>,
154 callback: P,
155 ) -> Arc<Self>
156 where
157 P: Fn(Arc<Self>, RobotRecvMessage) -> F + Send + 'static,
158 F: Future<Output = Result<()>> + Send,
159 {
160 let event_id = event_id.as_ref();
161 {
162 let mut config = self.config.lock().unwrap();
163 if !config
164 .subscriptions
165 .iter()
166 .any(|s| s.topic == event_id && s.r#type == "CALLBACK")
167 {
168 config.subscriptions.push(Subscription {
169 topic: event_id.to_owned(),
170 r#type: "CALLBACK".to_owned(),
171 });
172 }
173 }
174
175 tokio::spawn({
176 let mut rx = self.rx.clone();
177 let s = self.clone();
178 async move {
179 while let Ok(msg) = rx.recv().await {
180 match serde_json::from_str(&msg.data) {
181 Ok(msg) => {
182 if let Err(e) = callback(s.clone(), msg).await {
183 error!("callback error: {:?}", e);
184 }
185 }
186 Err(e) => {
187 error!("can not parse data: {:?}", e);
188 }
189 }
190 }
191 }
192 });
193
194 self
195 }
196
197 pub(crate) async fn token(&self) -> Result<String> {
198 let (access_token, token_expires_in) = {
199 let config = self.config.lock().unwrap();
200 (config.access_token.clone(), config.token_expires_in)
201 };
202
203 Ok(if Local::now() > token_expires_in {
204 debug!("token expired, get token again");
205 self.get_token().await?
206 } else {
207 access_token
208 })
209 }
210
211 async fn get_token(&self) -> Result<String> {
212 let url = {
213 let config = self.config.lock().unwrap();
214 debug!("get connect endpoint by config {:#?}", *config);
215 format!(
216 "{GET_TOKEN_URL}?appkey={}&appsecret={}",
217 config.client_id, config.client_secret
218 )
219 };
220 let response = self.client.get(url).send().await?;
221 if !response.status().is_success() {
222 bail!(
223 "get token http error: {} - {}",
224 response.status(),
225 response.text().await?
226 );
227 }
228
229 let token: TokenResponse = response.json().await?;
230 if token.errcode != 0 {
231 bail!(
232 "get token content error: {} - {}",
233 token.errcode,
234 token.errmsg
235 );
236 }
237
238 debug!("get token: {:?}", token);
239 let access_token = token.access_token;
240 let mut config = self.config.lock().unwrap();
241 config.access_token = access_token.clone();
242 config.token_expires_in = Local::now() + Duration::seconds(token.expires_in as i64);
243 Ok(access_token)
244 }
245
246 async fn get_endpoint(&self) -> Result<String> {
247 let token = self.get_token().await?;
248
249 let response = self
250 .client
251 .post(GATEWAY_URL)
252 .json(&*self.config)
253 .header(ACCEPT, "application/json")
254 .header("access-token", token)
255 .send()
256 .await?;
257 if !response.status().is_success() {
258 bail!(
259 "get endpoint http error: {} - {}",
260 response.status(),
261 response.text().await?
262 );
263 }
264
265 let endpoint: EndpointResponse = response.json().await?;
266 debug!("get endpoint: {:?}", endpoint);
267 let EndpointResponse { endpoint, ticket } = endpoint;
268
269 Ok(format!("{endpoint}?ticket={ticket}"))
270 }
271
272 async fn serve(self: &Arc<Self>, url: String) -> Result<()> {
273 let tls_connect = Connector::NativeTls({
274 TlsConnector::builder()
275 .danger_accept_invalid_certs(true)
276 .danger_accept_invalid_hostnames(true)
277 .build()?
278 });
279
280 let (stream, _) =
281 match connect_async_tls_with_config(&url, None, false, Some(tls_connect)).await {
282 Ok(x) => {
283 self.alive.store(true, Ordering::SeqCst);
284 x
285 }
286 Err(e) => {
287 if let Error::Http(ref h) = e {
288 bail!(
289 "connect websocket http error: {} - {}",
290 h.status(),
291 String::from_utf8_lossy(h.body().as_deref().unwrap_or_default())
292 );
293 } else {
294 bail!("connect websocket error: {:?}", e);
295 }
296 }
297 };
298
299 let (sink, stream) = stream.split();
300 *self.sink.lock().await = Some(sink);
301 let heartbeat_interval = self.config.lock().unwrap().heartbeat_interval;
302 if heartbeat_interval > 0 {
303 tokio::spawn({
304 let s = self.clone();
305 let aborting = self.aborting.clone();
306 async move {
307 loop {
308 if !s.alive.load(Ordering::SeqCst) {
309 aborting.notify_one();
310 break;
311 }
312
313 trace!("websocket ping");
314 s.alive.store(false, Ordering::SeqCst);
315 let _ = s.ping().await;
316 sleep(Duration::milliseconds(heartbeat_interval).to_std().unwrap()).await;
318 }
319 }
320 });
321 }
322
323 tokio::select! {
324 _ = self.aborting.notified() => { warn!("server aborting"); }
325 _ = self.process(stream) => { warn!("server error or closed"); }
326 }
327
328 self.alive.store(false, Ordering::SeqCst);
329 Ok(())
330 }
331
332 async fn process(
333 &self,
334 mut stream: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
335 ) -> Result<()> {
336 while let Some(message) = stream.next().await {
337 let message = match message {
338 Ok(m) => m,
339 Err(e) => {
340 error!("recv websocket message error: {:?}", e);
341 break;
342 }
343 };
344
345 match message {
346 Message::Text(t) => {
347 debug!("recv websocket text: {t}");
348 match serde_json::from_str::<ClientDownStream>(&t) {
349 Ok(p) => self.on_down_stream(p).await?,
350 Err(e) => {
351 warn!("parse websocket text error: {:?}", e)
352 }
353 }
354 }
355 Message::Pong(_) => {
356 trace!("websocket pong");
357 self.alive.store(true, Ordering::SeqCst)
358 }
359 Message::Close(c) => {
360 warn!(
361 "Websocket closed: {}",
362 if let Some(c) = c {
363 c.to_string()
364 } else {
365 "Unknown reason".to_owned()
366 }
367 );
368
369 break;
370 }
371
372 _ => {
373 warn!("Unhandled websocket message: {:?}", message)
374 }
375 }
376 }
377
378 Ok(())
379 }
380
381 pub async fn connect(self: Arc<Self>) -> Result<()> {
383 loop {
384 let c = self.clone();
385 let reconnect_interval = c.config.lock().unwrap().reconnect_interval;
386 let url = c.get_endpoint().await?;
387 c.serve(url).await?;
388
389 if reconnect_interval > 0 && !self.user_exit.load(Ordering::SeqCst) {
390 info!("Reconnecting in {} seconds...", reconnect_interval / 1000);
391
392 sleep(Duration::milliseconds(reconnect_interval).to_std().unwrap()).await;
394 debug!("initial reconnecting...");
395 } else {
396 break;
397 }
398 }
399
400 Ok(())
401 }
402
403 pub fn exit(&self) {
404 self.user_exit.store(true, Ordering::SeqCst);
405 self.aborting.notify_waiters();
406 }
407}
408
409#[derive(Deserialize, Debug)]
410struct TokenResponse {
411 errcode: u32,
412 access_token: String,
413 errmsg: String,
414 expires_in: u32,
415}
416
417#[derive(Debug, Deserialize)]
418struct EndpointResponse {
419 endpoint: String,
420 ticket: String,
421}
422
/// Callback topic for messages delivered to a robot (`/im/bot/messages/get`).
pub const TOPIC_ROBOT: &str = "/v1.0/im/bot/messages/get";
/// Callback topic for card instance callbacks (`/card/instances/callback`).
pub const TOPIC_CARD: &str = "/v1.0/card/instances/callback";

/// Gateway endpoint that opens a Stream-mode connection and answers with a websocket
/// endpoint plus a ticket.
const GATEWAY_URL: &str = "https://api.dingtalk.com/v1.0/gateway/connections/open";
/// Endpoint that exchanges the app key and app secret for an access token.
const GET_TOKEN_URL: &str = "https://oapi.dingtalk.com/gettoken";
429
430#[derive(Debug, Serialize)]
432#[serde(rename_all = "camelCase")]
433pub struct ClientConfig {
434 pub client_id: String,
436 pub client_secret: String,
438 pub ua: String,
440 pub subscriptions: Vec<Subscription>,
442 #[serde(skip_serializing)]
443 access_token: String,
444 #[serde(skip_serializing)]
445 token_expires_in: DateTime<Local>,
446 #[serde(skip_serializing)]
447 reconnect_interval: i64,
448 #[serde(skip_serializing)]
449 heartbeat_interval: i64,
450}
451
452impl Default for ClientConfig {
453 fn default() -> Self {
454 Self {
455 client_id: Default::default(),
456 client_secret: Default::default(),
457 ua: Default::default(),
458 subscriptions: vec![
459 Subscription {
460 r#type: "EVENT".to_owned(),
461 topic: "*".to_owned(),
462 },
463 Subscription {
464 r#type: "SYSTEM".to_owned(),
465 topic: "*".to_owned(),
466 },
467 ],
468 access_token: String::new(),
469 token_expires_in: Local::now(),
470 reconnect_interval: 1000,
471 heartbeat_interval: 8000,
472 }
473 }
474}
475
476#[derive(Debug, Serialize)]
478pub struct Subscription {
479 pub r#type: String,
484 pub topic: String,
488}