Skip to main content

dingtalk_stream/
client.rs

1//! DingTalk Stream 客户端,对齐 Python stream.py
2
3use crate::credential::Credential;
4use crate::error::{Error, Result};
5use crate::handlers::callback::CallbackHandler;
6use crate::handlers::chatbot::{AsyncChatbotHandler, ChatbotReplier, async_raw_process};
7use crate::handlers::event::EventHandler;
8use crate::handlers::system::{DefaultSystemHandler, SystemHandler};
9use crate::messages::frames::{AckMessage, StreamMessage, SystemMessage};
10use crate::transport::http::HttpClient;
11use crate::transport::token::TokenManager;
12use futures_util::{SinkExt, StreamExt};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio_tungstenite::connect_async;
16use tokio_tungstenite::tungstenite::Message;
17use url::form_urlencoded;
18
19/// 回调处理器条目
20enum CallbackEntry {
21    Sync(Arc<dyn CallbackHandler>),
22    Async(Arc<dyn AsyncChatbotHandler>),
23}
24
25/// DingTalk Stream 客户端
26pub struct DingTalkStreamClient {
27    credential: Credential,
28    event_handler: Option<Arc<dyn EventHandler>>,
29    callback_handlers: HashMap<String, CallbackEntry>,
30    system_handler: Arc<dyn SystemHandler>,
31    http_client: HttpClient,
32    token_manager: Arc<TokenManager>,
33    is_event_required: bool,
34    pre_started: bool,
35}
36
37impl DingTalkStreamClient {
38    /// 创建 Builder
39    pub fn builder(credential: Credential) -> ClientBuilder {
40        ClientBuilder::new(credential)
41    }
42
43    /// 获取 access_token
44    pub async fn get_access_token(&self) -> Result<String> {
45        self.token_manager.get_access_token().await
46    }
47
48    /// 重置 access_token 缓存
49    pub async fn reset_access_token(&self) {
50        self.token_manager.reset().await;
51    }
52
53    /// 创建 `ChatbotReplier`
54    pub fn chatbot_replier(&self) -> ChatbotReplier {
55        ChatbotReplier::new(
56            self.http_client.clone(),
57            Arc::clone(&self.token_manager),
58            self.credential.client_id.clone(),
59        )
60    }
61
62    /// 上传文件到钉钉
63    pub async fn upload_to_dingtalk(
64        &self,
65        image_content: &[u8],
66        filetype: &str,
67        filename: &str,
68        mimetype: &str,
69    ) -> Result<String> {
70        let access_token = self.token_manager.get_access_token().await?;
71        let result = self
72            .http_client
73            .upload_file(&access_token, image_content, filetype, filename, mimetype)
74            .await;
75
76        if let Err(Error::Auth(_)) = &result {
77            self.token_manager.reset().await;
78        }
79
80        result
81    }
82
83    /// 启动客户端(异步,永久重连循环)
84    pub async fn start(&mut self) -> Result<()> {
85        self.pre_start();
86
87        loop {
88            match self.run_once().await {
89                Ok(()) => {
90                    tracing::info!("connection closed, reconnecting in 3s...");
91                }
92                Err(e) => {
93                    tracing::error!(error = %e, "connection error, reconnecting in 10s...");
94                    tokio::time::sleep(std::time::Duration::from_secs(10)).await;
95                    continue;
96                }
97            }
98            tokio::time::sleep(std::time::Duration::from_secs(3)).await;
99        }
100    }
101
102    /// 同步启动(阻塞当前线程)
103    pub fn start_forever(&mut self) -> Result<()> {
104        let rt = tokio::runtime::Runtime::new()
105            .map_err(|e| Error::Connection(format!("failed to create runtime: {e}")))?;
106        rt.block_on(self.start())
107    }
108
109    fn pre_start(&mut self) {
110        if self.pre_started {
111            return;
112        }
113        self.pre_started = true;
114
115        if let Some(ref handler) = self.event_handler {
116            handler.pre_start();
117        }
118        self.system_handler.pre_start();
119        for entry in self.callback_handlers.values() {
120            match entry {
121                CallbackEntry::Sync(h) => h.pre_start(),
122                CallbackEntry::Async(h) => h.pre_start(),
123            }
124        }
125    }
126
127    /// 单次连接运行
128    async fn run_once(&self) -> Result<()> {
129        let connection = self.open_connection().await?;
130
131        let endpoint = connection
132            .get("endpoint")
133            .and_then(|v| v.as_str())
134            .ok_or_else(|| Error::Connection("endpoint not found".to_owned()))?;
135        let ticket = connection
136            .get("ticket")
137            .and_then(|v| v.as_str())
138            .ok_or_else(|| Error::Connection("ticket not found".to_owned()))?;
139
140        let encoded_ticket: String = form_urlencoded::Serializer::new(String::new())
141            .append_pair("ticket", ticket)
142            .finish();
143        let uri = format!("{}?{}", endpoint, encoded_ticket);
144
145        tracing::info!(endpoint = %endpoint, "connecting to WebSocket");
146
147        let (ws_stream, _) =
148            tokio::time::timeout(std::time::Duration::from_secs(30), connect_async(&uri))
149                .await
150                .map_err(|_| Error::Connection("WebSocket connect timeout".to_string()))?
151                .map_err(Error::WebSocket)?;
152        let (write, read) = ws_stream.split();
153
154        let write = Arc::new(tokio::sync::Mutex::new(write));
155        let write_keepalive = Arc::clone(&write);
156
157        // Keepalive task
158        let keepalive_handle = tokio::spawn(async move {
159            loop {
160                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
161                let mut w = write_keepalive.lock().await;
162                if w.send(Message::Ping(Vec::new().into())).await.is_err() {
163                    break;
164                }
165            }
166        });
167
168        // 消息处理
169        let mut read = read;
170        while let Some(msg_result) = read.next().await {
171            match msg_result {
172                Ok(Message::Text(text)) => {
173                    let route_result = self.route_message(&text).await;
174                    match route_result {
175                        Ok((ack_opt, should_disconnect)) => {
176                            if let Some(ack) = ack_opt {
177                                let ack_json = serde_json::to_string(&ack).unwrap_or_default();
178                                let mut w = write.lock().await;
179                                if let Err(e) = w.send(Message::Text(ack_json.into())).await {
180                                    tracing::error!(error = %e, "failed to send ack");
181                                    break;
182                                }
183                            }
184                            if should_disconnect {
185                                tracing::info!("received disconnect, closing connection");
186                                let mut w = write.lock().await;
187                                let _ = w.close().await;
188                                break;
189                            }
190                        }
191                        Err(e) => {
192                            tracing::error!(error = %e, "route message failed");
193                        }
194                    }
195                }
196                Ok(Message::Pong(_)) => {}
197                Ok(Message::Close(_)) => {
198                    tracing::info!("WebSocket closed by server");
199                    break;
200                }
201                Err(e) => {
202                    tracing::error!(error = %e, "WebSocket read error");
203                    break;
204                }
205                _ => {}
206            }
207        }
208
209        keepalive_handle.abort();
210        Ok(())
211    }
212
213    /// 路由消息
214    async fn route_message(&self, raw: &str) -> Result<(Option<AckMessage>, bool)> {
215        let msg: StreamMessage = serde_json::from_str(raw)?;
216        let mut should_disconnect = false;
217
218        let ack = match msg {
219            StreamMessage::System(body) => {
220                let ack = self.system_handler.raw_process(&body).await;
221                if body.headers.topic.as_deref() == Some(SystemMessage::TOPIC_DISCONNECT) {
222                    should_disconnect = true;
223                    tracing::info!(
224                        topic = ?body.headers.topic,
225                        "received disconnect"
226                    );
227                } else {
228                    tracing::warn!(
229                        topic = ?body.headers.topic,
230                        "unknown system message topic"
231                    );
232                }
233                Some(ack)
234            }
235            StreamMessage::Event(body) => {
236                if let Some(ref handler) = self.event_handler {
237                    Some(handler.raw_process(&body).await)
238                } else {
239                    tracing::warn!("no event handler registered");
240                    None
241                }
242            }
243            StreamMessage::Callback(body) => {
244                let topic = body.headers.topic.as_deref().unwrap_or("");
245                if let Some(entry) = self.callback_handlers.get(topic) {
246                    match entry {
247                        CallbackEntry::Sync(handler) => Some(handler.raw_process(&body).await),
248                        CallbackEntry::Async(handler) => {
249                            Some(async_raw_process(Arc::clone(handler), body).await)
250                        }
251                    }
252                } else {
253                    tracing::warn!(topic = %topic, "unknown callback topic");
254                    None
255                }
256            }
257        };
258
259        Ok((ack, should_disconnect))
260    }
261
262    /// 打开连接
263    async fn open_connection(&self) -> Result<serde_json::Value> {
264        let url = format!(
265            "{}/v1.0/gateway/connections/open",
266            self.http_client.openapi_endpoint()
267        );
268
269        tracing::info!(url = %url, "opening connection");
270
271        let mut topics: Vec<serde_json::Value> = Vec::new();
272        if self.is_event_required {
273            topics.push(serde_json::json!({"type": "EVENT", "topic": "*"}));
274        }
275        for topic in self.callback_handlers.keys() {
276            topics.push(serde_json::json!({"type": "CALLBACK", "topic": topic}));
277        }
278
279        let body = serde_json::json!({
280            "clientId": self.credential.client_id,
281            "clientSecret": self.credential.client_secret,
282            "subscriptions": topics,
283            "ua": format!("dingtalk-sdk-rust/v{}-union", env!("CARGO_PKG_VERSION")),
284            "localIp": get_host_ip(),
285        });
286
287        self.http_client.post_raw(&url, &body).await
288    }
289}
290
291/// 客户端构建器
292pub struct ClientBuilder {
293    credential: Credential,
294    event_handler: Option<Arc<dyn EventHandler>>,
295    callback_handlers: HashMap<String, CallbackEntry>,
296    system_handler: Option<Arc<dyn SystemHandler>>,
297    connect_timeout_secs: Option<u64>,
298    request_timeout_secs: Option<u64>,
299}
300
301impl ClientBuilder {
302    /// 创建新的构建器
303    pub fn new(credential: Credential) -> Self {
304        Self {
305            credential,
306            event_handler: None,
307            callback_handlers: HashMap::new(),
308            system_handler: None,
309            connect_timeout_secs: None,
310            request_timeout_secs: None,
311        }
312    }
313
314    /// 注册事件处理器
315    pub fn register_all_event_handler(mut self, handler: impl EventHandler + 'static) -> Self {
316        self.event_handler = Some(Arc::new(handler));
317        self
318    }
319
320    /// 注册回调处理器
321    pub fn register_callback_handler(
322        mut self,
323        topic: &str,
324        handler: impl CallbackHandler + 'static,
325    ) -> Self {
326        self.callback_handlers
327            .insert(topic.to_owned(), CallbackEntry::Sync(Arc::new(handler)));
328        self
329    }
330
331    /// 注册异步聊天机器人处理器
332    pub fn register_async_chatbot_handler(
333        mut self,
334        topic: &str,
335        handler: impl AsyncChatbotHandler + 'static,
336    ) -> Self {
337        self.callback_handlers
338            .insert(topic.to_owned(), CallbackEntry::Async(Arc::new(handler)));
339        self
340    }
341
342    /// 注册系统消息处理器
343    pub fn register_system_handler(mut self, handler: impl SystemHandler + 'static) -> Self {
344        self.system_handler = Some(Arc::new(handler));
345        self
346    }
347
348    /// 设置 HTTP 连接超时(秒),默认 10s
349    pub fn connect_timeout_secs(mut self, secs: u64) -> Self {
350        self.connect_timeout_secs = Some(secs);
351        self
352    }
353
354    /// 设置 HTTP 请求超时(秒),默认 30s
355    pub fn request_timeout_secs(mut self, secs: u64) -> Self {
356        self.request_timeout_secs = Some(secs);
357        self
358    }
359
360    /// 构建客户端
361    pub fn build(self) -> DingTalkStreamClient {
362        let http_client = match (self.connect_timeout_secs, self.request_timeout_secs) {
363            (None, None) => HttpClient::new(),
364            (ct, rt) => HttpClient::with_timeout(ct.unwrap_or(10), rt.unwrap_or(30)),
365        };
366        let token_manager = Arc::new(TokenManager::new(
367            self.credential.clone(),
368            http_client.clone(),
369        ));
370
371        let is_event_required = self.event_handler.is_some();
372
373        DingTalkStreamClient {
374            credential: self.credential,
375            event_handler: self.event_handler,
376            callback_handlers: self.callback_handlers,
377            system_handler: self
378                .system_handler
379                .unwrap_or_else(|| Arc::new(DefaultSystemHandler)),
380            http_client,
381            token_manager,
382            is_event_required,
383            pre_started: false,
384        }
385    }
386}
387
/// Best-effort lookup of this host's outbound IPv4 address.
///
/// Binds a UDP socket and connects it to `8.8.8.8:80`. Connecting a UDP socket only asks
/// the OS to choose a route; the socket's local address then names the interface that
/// route would use. Any failure yields an empty string.
fn get_host_ip() -> String {
    use std::net::UdpSocket;

    let probe = || -> std::io::Result<std::net::SocketAddr> {
        let socket = UdpSocket::bind("0.0.0.0:0")?;
        socket.connect("8.8.8.8:80")?;
        socket.local_addr()
    };

    match probe() {
        Ok(local) => local.ip().to_string(),
        Err(_) => String::new(),
    }
}