Skip to main content

dingtalk_stream/
client.rs

1//! DingTalk Stream 客户端,对齐 Python stream.py
2
3use crate::credential::Credential;
4use crate::error::{Error, Result};
5use crate::handlers::callback::CallbackHandler;
6use crate::handlers::chatbot::{AsyncChatbotHandler, ChatbotReplier, async_raw_process};
7use crate::handlers::event::EventHandler;
8use crate::handlers::system::{DefaultSystemHandler, SystemHandler};
9use crate::messages::frames::{AckMessage, StreamMessage, SystemMessage};
10use crate::transport::http::HttpClient;
11use crate::transport::token::TokenManager;
12use futures_util::{SinkExt, StreamExt};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio_tungstenite::connect_async;
16use tokio_tungstenite::tungstenite::Message;
17use url::form_urlencoded;
18
19/// 回调处理器条目
20enum CallbackEntry {
21    Sync(Arc<dyn CallbackHandler>),
22    Async(Arc<dyn AsyncChatbotHandler>),
23}
24
25/// DingTalk Stream 客户端
26pub struct DingTalkStreamClient {
27    credential: Credential,
28    event_handler: Option<Arc<dyn EventHandler>>,
29    callback_handlers: HashMap<String, CallbackEntry>,
30    system_handler: Arc<dyn SystemHandler>,
31    http_client: HttpClient,
32    token_manager: Arc<TokenManager>,
33    is_event_required: bool,
34    pre_started: bool,
35}
36
37impl DingTalkStreamClient {
38    /// 创建 Builder
39    pub fn builder(credential: Credential) -> ClientBuilder {
40        ClientBuilder::new(credential)
41    }
42
43    /// 获取 access_token
44    pub async fn get_access_token(&self) -> Result<String> {
45        self.token_manager.get_access_token().await
46    }
47
48    /// 重置 access_token 缓存
49    pub async fn reset_access_token(&self) {
50        self.token_manager.reset().await;
51    }
52
53    /// 创建 `ChatbotReplier`
54    pub fn chatbot_replier(&self) -> ChatbotReplier {
55        ChatbotReplier::new(
56            self.http_client.clone(),
57            Arc::clone(&self.token_manager),
58            self.credential.client_id.clone(),
59        )
60    }
61
62    /// 上传文件到钉钉
63    pub async fn upload_to_dingtalk(
64        &self,
65        image_content: &[u8],
66        filetype: &str,
67        filename: &str,
68        mimetype: &str,
69    ) -> Result<String> {
70        let access_token = self.token_manager.get_access_token().await?;
71        let result = self
72            .http_client
73            .upload_file(&access_token, image_content, filetype, filename, mimetype)
74            .await;
75
76        if let Err(Error::Auth(_)) = &result {
77            self.token_manager.reset().await;
78        }
79
80        result
81    }
82
83    /// 启动客户端(异步,永久重连循环)
84    pub async fn start(&mut self) -> Result<()> {
85        self.pre_start();
86
87        loop {
88            match self.run_once().await {
89                Ok(()) => {
90                    tracing::info!("connection closed, reconnecting in 3s...");
91                }
92                Err(e) => {
93                    tracing::error!(error = %e, "connection error, reconnecting in 10s...");
94                    tokio::time::sleep(std::time::Duration::from_secs(10)).await;
95                    continue;
96                }
97            }
98            tokio::time::sleep(std::time::Duration::from_secs(3)).await;
99        }
100    }
101
102    /// 同步启动(阻塞当前线程)
103    pub fn start_forever(&mut self) -> Result<()> {
104        let rt = tokio::runtime::Runtime::new()
105            .map_err(|e| Error::Connection(format!("failed to create runtime: {e}")))?;
106        rt.block_on(self.start())
107    }
108
109    fn pre_start(&mut self) {
110        if self.pre_started {
111            return;
112        }
113        self.pre_started = true;
114
115        if let Some(ref handler) = self.event_handler {
116            handler.pre_start();
117        }
118        self.system_handler.pre_start();
119        for entry in self.callback_handlers.values() {
120            match entry {
121                CallbackEntry::Sync(h) => h.pre_start(),
122                CallbackEntry::Async(h) => h.pre_start(),
123            }
124        }
125    }
126
127    /// 单次连接运行
128    async fn run_once(&self) -> Result<()> {
129        let connection = self.open_connection().await?;
130
131        let endpoint = connection
132            .get("endpoint")
133            .and_then(|v| v.as_str())
134            .ok_or_else(|| Error::Connection("endpoint not found".to_owned()))?;
135        let ticket = connection
136            .get("ticket")
137            .and_then(|v| v.as_str())
138            .ok_or_else(|| Error::Connection("ticket not found".to_owned()))?;
139
140        let encoded_ticket: String = form_urlencoded::Serializer::new(String::new())
141            .append_pair("ticket", ticket)
142            .finish();
143        let uri = format!("{}?{}", endpoint, encoded_ticket);
144
145        tracing::info!(endpoint = %endpoint, "connecting to WebSocket");
146
147        let (ws_stream, _) = connect_async(&uri).await?;
148        let (write, read) = ws_stream.split();
149
150        let write = Arc::new(tokio::sync::Mutex::new(write));
151        let write_keepalive = Arc::clone(&write);
152
153        // Keepalive task
154        let keepalive_handle = tokio::spawn(async move {
155            loop {
156                tokio::time::sleep(std::time::Duration::from_secs(60)).await;
157                let mut w = write_keepalive.lock().await;
158                if w.send(Message::Ping(Vec::new().into())).await.is_err() {
159                    break;
160                }
161            }
162        });
163
164        // 消息处理
165        let mut read = read;
166        while let Some(msg_result) = read.next().await {
167            match msg_result {
168                Ok(Message::Text(text)) => {
169                    let route_result = self.route_message(&text).await;
170                    match route_result {
171                        Ok((ack_opt, should_disconnect)) => {
172                            if let Some(ack) = ack_opt {
173                                let ack_json = serde_json::to_string(&ack).unwrap_or_default();
174                                let mut w = write.lock().await;
175                                if let Err(e) = w.send(Message::Text(ack_json.into())).await {
176                                    tracing::error!(error = %e, "failed to send ack");
177                                    break;
178                                }
179                            }
180                            if should_disconnect {
181                                tracing::info!("received disconnect, closing connection");
182                                let mut w = write.lock().await;
183                                let _ = w.close().await;
184                                break;
185                            }
186                        }
187                        Err(e) => {
188                            tracing::error!(error = %e, "route message failed");
189                        }
190                    }
191                }
192                Ok(Message::Pong(_)) => {}
193                Ok(Message::Close(_)) => {
194                    tracing::info!("WebSocket closed by server");
195                    break;
196                }
197                Err(e) => {
198                    tracing::error!(error = %e, "WebSocket read error");
199                    break;
200                }
201                _ => {}
202            }
203        }
204
205        keepalive_handle.abort();
206        Ok(())
207    }
208
209    /// 路由消息
210    async fn route_message(&self, raw: &str) -> Result<(Option<AckMessage>, bool)> {
211        let msg: StreamMessage = serde_json::from_str(raw)?;
212        let mut should_disconnect = false;
213
214        let ack = match msg {
215            StreamMessage::System(body) => {
216                let ack = self.system_handler.raw_process(&body).await;
217                if body.headers.topic.as_deref() == Some(SystemMessage::TOPIC_DISCONNECT) {
218                    should_disconnect = true;
219                    tracing::info!(
220                        topic = ?body.headers.topic,
221                        "received disconnect"
222                    );
223                } else {
224                    tracing::warn!(
225                        topic = ?body.headers.topic,
226                        "unknown system message topic"
227                    );
228                }
229                Some(ack)
230            }
231            StreamMessage::Event(body) => {
232                if let Some(ref handler) = self.event_handler {
233                    Some(handler.raw_process(&body).await)
234                } else {
235                    tracing::warn!("no event handler registered");
236                    None
237                }
238            }
239            StreamMessage::Callback(body) => {
240                let topic = body.headers.topic.as_deref().unwrap_or("");
241                if let Some(entry) = self.callback_handlers.get(topic) {
242                    match entry {
243                        CallbackEntry::Sync(handler) => Some(handler.raw_process(&body).await),
244                        CallbackEntry::Async(handler) => {
245                            Some(async_raw_process(Arc::clone(handler), body).await)
246                        }
247                    }
248                } else {
249                    tracing::warn!(topic = %topic, "unknown callback topic");
250                    None
251                }
252            }
253        };
254
255        Ok((ack, should_disconnect))
256    }
257
258    /// 打开连接
259    async fn open_connection(&self) -> Result<serde_json::Value> {
260        let url = format!(
261            "{}/v1.0/gateway/connections/open",
262            self.http_client.openapi_endpoint()
263        );
264
265        tracing::info!(url = %url, "opening connection");
266
267        let mut topics: Vec<serde_json::Value> = Vec::new();
268        if self.is_event_required {
269            topics.push(serde_json::json!({"type": "EVENT", "topic": "*"}));
270        }
271        for topic in self.callback_handlers.keys() {
272            topics.push(serde_json::json!({"type": "CALLBACK", "topic": topic}));
273        }
274
275        let body = serde_json::json!({
276            "clientId": self.credential.client_id,
277            "clientSecret": self.credential.client_secret,
278            "subscriptions": topics,
279            "ua": format!("dingtalk-sdk-rust/v{}-union", env!("CARGO_PKG_VERSION")),
280            "localIp": get_host_ip(),
281        });
282
283        self.http_client.post_raw(&url, &body).await
284    }
285}
286
287/// 客户端构建器
288pub struct ClientBuilder {
289    credential: Credential,
290    event_handler: Option<Arc<dyn EventHandler>>,
291    callback_handlers: HashMap<String, CallbackEntry>,
292    system_handler: Option<Arc<dyn SystemHandler>>,
293}
294
295impl ClientBuilder {
296    /// 创建新的构建器
297    pub fn new(credential: Credential) -> Self {
298        Self {
299            credential,
300            event_handler: None,
301            callback_handlers: HashMap::new(),
302            system_handler: None,
303        }
304    }
305
306    /// 注册事件处理器
307    pub fn register_all_event_handler(mut self, handler: impl EventHandler + 'static) -> Self {
308        self.event_handler = Some(Arc::new(handler));
309        self
310    }
311
312    /// 注册回调处理器
313    pub fn register_callback_handler(
314        mut self,
315        topic: &str,
316        handler: impl CallbackHandler + 'static,
317    ) -> Self {
318        self.callback_handlers
319            .insert(topic.to_owned(), CallbackEntry::Sync(Arc::new(handler)));
320        self
321    }
322
323    /// 注册异步聊天机器人处理器
324    pub fn register_async_chatbot_handler(
325        mut self,
326        topic: &str,
327        handler: impl AsyncChatbotHandler + 'static,
328    ) -> Self {
329        self.callback_handlers
330            .insert(topic.to_owned(), CallbackEntry::Async(Arc::new(handler)));
331        self
332    }
333
334    /// 注册系统消息处理器
335    pub fn register_system_handler(mut self, handler: impl SystemHandler + 'static) -> Self {
336        self.system_handler = Some(Arc::new(handler));
337        self
338    }
339
340    /// 构建客户端
341    pub fn build(self) -> DingTalkStreamClient {
342        let http_client = HttpClient::new();
343        let token_manager = Arc::new(TokenManager::new(
344            self.credential.clone(),
345            http_client.clone(),
346        ));
347
348        let is_event_required = self.event_handler.is_some();
349
350        DingTalkStreamClient {
351            credential: self.credential,
352            event_handler: self.event_handler,
353            callback_handlers: self.callback_handlers,
354            system_handler: self
355                .system_handler
356                .unwrap_or_else(|| Arc::new(DefaultSystemHandler)),
357            http_client,
358            token_manager,
359            is_event_required,
360            pre_started: false,
361        }
362    }
363}
364
/// Best-effort lookup of this host's local IPv4 address.
///
/// Binds an ephemeral UDP socket and `connect`s it to 8.8.8.8:80, then reads
/// the local address the OS picked. This function never sends a datagram;
/// `connect` is used only to make the OS choose a local address. Any failure
/// yields an empty string.
fn get_host_ip() -> String {
    // Each step can fail independently; `?` funnels them into a single error.
    fn probe() -> std::io::Result<std::net::IpAddr> {
        let sock = std::net::UdpSocket::bind("0.0.0.0:0")?;
        sock.connect("8.8.8.8:80")?;
        Ok(sock.local_addr()?.ip())
    }

    match probe() {
        Ok(ip) => ip.to_string(),
        Err(_) => String::new(),
    }
}