1use crate::credential::Credential;
4use crate::error::{Error, Result};
5use crate::handlers::callback::CallbackHandler;
6use crate::handlers::chatbot::{AsyncChatbotHandler, ChatbotReplier, async_raw_process};
7use crate::handlers::event::EventHandler;
8use crate::handlers::system::{DefaultSystemHandler, SystemHandler};
9use crate::messages::frames::{AckMessage, StreamMessage, SystemMessage};
10use crate::transport::http::HttpClient;
11use crate::transport::token::TokenManager;
12use futures_util::{SinkExt, StreamExt};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tokio_tungstenite::connect_async;
16use tokio_tungstenite::tungstenite::Message;
17use url::form_urlencoded;
18
19enum CallbackEntry {
21 Sync(Arc<dyn CallbackHandler>),
22 Async(Arc<dyn AsyncChatbotHandler>),
23}
24
25pub struct DingTalkStreamClient {
27 credential: Credential,
28 event_handler: Option<Arc<dyn EventHandler>>,
29 callback_handlers: HashMap<String, CallbackEntry>,
30 system_handler: Arc<dyn SystemHandler>,
31 http_client: HttpClient,
32 token_manager: Arc<TokenManager>,
33 is_event_required: bool,
34 pre_started: bool,
35}
36
37impl DingTalkStreamClient {
38 pub fn builder(credential: Credential) -> ClientBuilder {
40 ClientBuilder::new(credential)
41 }
42
43 pub async fn get_access_token(&self) -> Result<String> {
45 self.token_manager.get_access_token().await
46 }
47
48 pub async fn reset_access_token(&self) {
50 self.token_manager.reset().await;
51 }
52
53 pub fn chatbot_replier(&self) -> ChatbotReplier {
55 ChatbotReplier::new(
56 self.http_client.clone(),
57 Arc::clone(&self.token_manager),
58 self.credential.client_id.clone(),
59 )
60 }
61
62 pub async fn upload_to_dingtalk(
64 &self,
65 image_content: &[u8],
66 filetype: &str,
67 filename: &str,
68 mimetype: &str,
69 ) -> Result<String> {
70 let access_token = self.token_manager.get_access_token().await?;
71 let result = self
72 .http_client
73 .upload_file(&access_token, image_content, filetype, filename, mimetype)
74 .await;
75
76 if let Err(Error::Auth(_)) = &result {
77 self.token_manager.reset().await;
78 }
79
80 result
81 }
82
83 pub async fn start(&mut self) -> Result<()> {
85 self.pre_start();
86
87 loop {
88 match self.run_once().await {
89 Ok(()) => {
90 tracing::info!("connection closed, reconnecting in 3s...");
91 }
92 Err(e) => {
93 tracing::error!(error = %e, "connection error, reconnecting in 10s...");
94 tokio::time::sleep(std::time::Duration::from_secs(10)).await;
95 continue;
96 }
97 }
98 tokio::time::sleep(std::time::Duration::from_secs(3)).await;
99 }
100 }
101
102 pub fn start_forever(&mut self) -> Result<()> {
104 let rt = tokio::runtime::Runtime::new()
105 .map_err(|e| Error::Connection(format!("failed to create runtime: {e}")))?;
106 rt.block_on(self.start())
107 }
108
109 fn pre_start(&mut self) {
110 if self.pre_started {
111 return;
112 }
113 self.pre_started = true;
114
115 if let Some(ref handler) = self.event_handler {
116 handler.pre_start();
117 }
118 self.system_handler.pre_start();
119 for entry in self.callback_handlers.values() {
120 match entry {
121 CallbackEntry::Sync(h) => h.pre_start(),
122 CallbackEntry::Async(h) => h.pre_start(),
123 }
124 }
125 }
126
127 async fn run_once(&self) -> Result<()> {
129 let connection = self.open_connection().await?;
130
131 let endpoint = connection
132 .get("endpoint")
133 .and_then(|v| v.as_str())
134 .ok_or_else(|| Error::Connection("endpoint not found".to_owned()))?;
135 let ticket = connection
136 .get("ticket")
137 .and_then(|v| v.as_str())
138 .ok_or_else(|| Error::Connection("ticket not found".to_owned()))?;
139
140 let encoded_ticket: String = form_urlencoded::Serializer::new(String::new())
141 .append_pair("ticket", ticket)
142 .finish();
143 let uri = format!("{}?{}", endpoint, encoded_ticket);
144
145 tracing::info!(endpoint = %endpoint, "connecting to WebSocket");
146
147 let (ws_stream, _) =
148 tokio::time::timeout(std::time::Duration::from_secs(30), connect_async(&uri))
149 .await
150 .map_err(|_| Error::Connection("WebSocket connect timeout".to_string()))?
151 .map_err(Error::WebSocket)?;
152 let (write, read) = ws_stream.split();
153
154 let write = Arc::new(tokio::sync::Mutex::new(write));
155 let write_keepalive = Arc::clone(&write);
156
157 let keepalive_handle = tokio::spawn(async move {
159 loop {
160 tokio::time::sleep(std::time::Duration::from_secs(60)).await;
161 let mut w = write_keepalive.lock().await;
162 if w.send(Message::Ping(Vec::new().into())).await.is_err() {
163 break;
164 }
165 }
166 });
167
168 let mut read = read;
170 while let Some(msg_result) = read.next().await {
171 match msg_result {
172 Ok(Message::Text(text)) => {
173 let route_result = self.route_message(&text).await;
174 match route_result {
175 Ok((ack_opt, should_disconnect)) => {
176 if let Some(ack) = ack_opt {
177 let ack_json = serde_json::to_string(&ack).unwrap_or_default();
178 let mut w = write.lock().await;
179 if let Err(e) = w.send(Message::Text(ack_json.into())).await {
180 tracing::error!(error = %e, "failed to send ack");
181 break;
182 }
183 }
184 if should_disconnect {
185 tracing::info!("received disconnect, closing connection");
186 let mut w = write.lock().await;
187 let _ = w.close().await;
188 break;
189 }
190 }
191 Err(e) => {
192 tracing::error!(error = %e, "route message failed");
193 }
194 }
195 }
196 Ok(Message::Pong(_)) => {}
197 Ok(Message::Close(_)) => {
198 tracing::info!("WebSocket closed by server");
199 break;
200 }
201 Err(e) => {
202 tracing::error!(error = %e, "WebSocket read error");
203 break;
204 }
205 _ => {}
206 }
207 }
208
209 keepalive_handle.abort();
210 Ok(())
211 }
212
213 async fn route_message(&self, raw: &str) -> Result<(Option<AckMessage>, bool)> {
215 let msg: StreamMessage = serde_json::from_str(raw)?;
216 let mut should_disconnect = false;
217
218 let ack = match msg {
219 StreamMessage::System(body) => {
220 let ack = self.system_handler.raw_process(&body).await;
221 if body.headers.topic.as_deref() == Some(SystemMessage::TOPIC_DISCONNECT) {
222 should_disconnect = true;
223 tracing::info!(
224 topic = ?body.headers.topic,
225 "received disconnect"
226 );
227 } else {
228 tracing::warn!(
229 topic = ?body.headers.topic,
230 "unknown system message topic"
231 );
232 }
233 Some(ack)
234 }
235 StreamMessage::Event(body) => {
236 if let Some(ref handler) = self.event_handler {
237 Some(handler.raw_process(&body).await)
238 } else {
239 tracing::warn!("no event handler registered");
240 None
241 }
242 }
243 StreamMessage::Callback(body) => {
244 let topic = body.headers.topic.as_deref().unwrap_or("");
245 if let Some(entry) = self.callback_handlers.get(topic) {
246 match entry {
247 CallbackEntry::Sync(handler) => Some(handler.raw_process(&body).await),
248 CallbackEntry::Async(handler) => {
249 Some(async_raw_process(Arc::clone(handler), body).await)
250 }
251 }
252 } else {
253 tracing::warn!(topic = %topic, "unknown callback topic");
254 None
255 }
256 }
257 };
258
259 Ok((ack, should_disconnect))
260 }
261
262 async fn open_connection(&self) -> Result<serde_json::Value> {
264 let url = format!(
265 "{}/v1.0/gateway/connections/open",
266 self.http_client.openapi_endpoint()
267 );
268
269 tracing::info!(url = %url, "opening connection");
270
271 let mut topics: Vec<serde_json::Value> = Vec::new();
272 if self.is_event_required {
273 topics.push(serde_json::json!({"type": "EVENT", "topic": "*"}));
274 }
275 for topic in self.callback_handlers.keys() {
276 topics.push(serde_json::json!({"type": "CALLBACK", "topic": topic}));
277 }
278
279 let body = serde_json::json!({
280 "clientId": self.credential.client_id,
281 "clientSecret": self.credential.client_secret,
282 "subscriptions": topics,
283 "ua": format!("dingtalk-sdk-rust/v{}-union", env!("CARGO_PKG_VERSION")),
284 "localIp": get_host_ip(),
285 });
286
287 self.http_client.post_raw(&url, &body).await
288 }
289}
290
291pub struct ClientBuilder {
293 credential: Credential,
294 event_handler: Option<Arc<dyn EventHandler>>,
295 callback_handlers: HashMap<String, CallbackEntry>,
296 system_handler: Option<Arc<dyn SystemHandler>>,
297 connect_timeout_secs: Option<u64>,
298 request_timeout_secs: Option<u64>,
299}
300
301impl ClientBuilder {
302 pub fn new(credential: Credential) -> Self {
304 Self {
305 credential,
306 event_handler: None,
307 callback_handlers: HashMap::new(),
308 system_handler: None,
309 connect_timeout_secs: None,
310 request_timeout_secs: None,
311 }
312 }
313
314 pub fn register_all_event_handler(mut self, handler: impl EventHandler + 'static) -> Self {
316 self.event_handler = Some(Arc::new(handler));
317 self
318 }
319
320 pub fn register_callback_handler(
322 mut self,
323 topic: &str,
324 handler: impl CallbackHandler + 'static,
325 ) -> Self {
326 self.callback_handlers
327 .insert(topic.to_owned(), CallbackEntry::Sync(Arc::new(handler)));
328 self
329 }
330
331 pub fn register_async_chatbot_handler(
333 mut self,
334 topic: &str,
335 handler: impl AsyncChatbotHandler + 'static,
336 ) -> Self {
337 self.callback_handlers
338 .insert(topic.to_owned(), CallbackEntry::Async(Arc::new(handler)));
339 self
340 }
341
342 pub fn register_system_handler(mut self, handler: impl SystemHandler + 'static) -> Self {
344 self.system_handler = Some(Arc::new(handler));
345 self
346 }
347
348 pub fn connect_timeout_secs(mut self, secs: u64) -> Self {
350 self.connect_timeout_secs = Some(secs);
351 self
352 }
353
354 pub fn request_timeout_secs(mut self, secs: u64) -> Self {
356 self.request_timeout_secs = Some(secs);
357 self
358 }
359
360 pub fn build(self) -> DingTalkStreamClient {
362 let http_client = match (self.connect_timeout_secs, self.request_timeout_secs) {
363 (None, None) => HttpClient::new(),
364 (ct, rt) => HttpClient::with_timeout(ct.unwrap_or(10), rt.unwrap_or(30)),
365 };
366 let token_manager = Arc::new(TokenManager::new(
367 self.credential.clone(),
368 http_client.clone(),
369 ));
370
371 let is_event_required = self.event_handler.is_some();
372
373 DingTalkStreamClient {
374 credential: self.credential,
375 event_handler: self.event_handler,
376 callback_handlers: self.callback_handlers,
377 system_handler: self
378 .system_handler
379 .unwrap_or_else(|| Arc::new(DefaultSystemHandler)),
380 http_client,
381 token_manager,
382 is_event_required,
383 pre_started: false,
384 }
385 }
386}
387
/// Returns the local IP address of a UDP socket connected towards
/// `8.8.8.8:80`, formatted as a string.
///
/// Returns an empty string when the socket cannot be bound or connected, or
/// when its local address cannot be read.
fn get_host_ip() -> String {
    use std::net::UdpSocket;

    // Connecting a UDP socket only fixes its peer address; afterwards the
    // socket's local address is the one the OS chose for that destination.
    let probe = || -> std::io::Result<std::net::IpAddr> {
        let socket = UdpSocket::bind("0.0.0.0:0")?;
        socket.connect("8.8.8.8:80")?;
        Ok(socket.local_addr()?.ip())
    };

    match probe() {
        Ok(ip) => ip.to_string(),
        Err(_) => String::new(),
    }
}