Skip to main content

construct/channels/
dingtalk.rs

1use super::traits::{Channel, ChannelMessage, SendMessage};
2use async_trait::async_trait;
3use futures_util::{SinkExt, StreamExt};
4use std::collections::HashMap;
5use std::sync::Arc;
6use tokio::sync::RwLock;
7use tokio_tungstenite::tungstenite::Message;
8use uuid::Uuid;
9
/// Stream Mode topic carrying chatbot message callbacks; subscribed to (as a
/// `CALLBACK` subscription) whenever a gateway connection is registered.
const DINGTALK_BOT_CALLBACK_TOPIC: &str = "/v1.0/im/bot/messages/get";
11
12/// DingTalk channel — connects via Stream Mode WebSocket for real-time messages.
13/// Replies are sent through per-message session webhook URLs.
14pub struct DingTalkChannel {
15    client_id: String,
16    client_secret: String,
17    allowed_users: Vec<String>,
18    /// Per-chat session webhooks for sending replies (chatID -> webhook URL).
19    /// DingTalk provides a unique webhook URL with each incoming message.
20    session_webhooks: Arc<RwLock<HashMap<String, String>>>,
21    /// Per-channel proxy URL override.
22    proxy_url: Option<String>,
23}
24
25/// Response from DingTalk gateway connection registration.
26#[derive(serde::Deserialize)]
27struct GatewayResponse {
28    endpoint: String,
29    ticket: String,
30}
31
32impl DingTalkChannel {
33    pub fn new(client_id: String, client_secret: String, allowed_users: Vec<String>) -> Self {
34        Self {
35            client_id,
36            client_secret,
37            allowed_users,
38            session_webhooks: Arc::new(RwLock::new(HashMap::new())),
39            proxy_url: None,
40        }
41    }
42
43    /// Set a per-channel proxy URL that overrides the global proxy config.
44    pub fn with_proxy_url(mut self, proxy_url: Option<String>) -> Self {
45        self.proxy_url = proxy_url;
46        self
47    }
48
49    fn http_client(&self) -> reqwest::Client {
50        crate::config::build_channel_proxy_client("channel.dingtalk", self.proxy_url.as_deref())
51    }
52
53    fn is_user_allowed(&self, user_id: &str) -> bool {
54        self.allowed_users.iter().any(|u| u == "*" || u == user_id)
55    }
56
57    fn parse_stream_data(frame: &serde_json::Value) -> Option<serde_json::Value> {
58        match frame.get("data") {
59            Some(serde_json::Value::String(raw)) => serde_json::from_str(raw).ok(),
60            Some(serde_json::Value::Object(_)) => frame.get("data").cloned(),
61            _ => None,
62        }
63    }
64
65    fn resolve_chat_id(data: &serde_json::Value, sender_id: &str) -> String {
66        let is_private_chat = data
67            .get("conversationType")
68            .and_then(|value| {
69                value
70                    .as_str()
71                    .map(|v| v == "1")
72                    .or_else(|| value.as_i64().map(|v| v == 1))
73            })
74            .unwrap_or(true);
75
76        if is_private_chat {
77            sender_id.to_string()
78        } else {
79            data.get("conversationId")
80                .and_then(|c| c.as_str())
81                .unwrap_or(sender_id)
82                .to_string()
83        }
84    }
85
86    /// Register a connection with DingTalk's gateway to get a WebSocket endpoint.
87    async fn register_connection(&self) -> anyhow::Result<GatewayResponse> {
88        let body = serde_json::json!({
89            "clientId": self.client_id,
90            "clientSecret": self.client_secret,
91            "subscriptions": [
92                {
93                    "type": "CALLBACK",
94                    "topic": DINGTALK_BOT_CALLBACK_TOPIC,
95                }
96            ],
97        });
98
99        let resp = self
100            .http_client()
101            .post("https://api.dingtalk.com/v1.0/gateway/connections/open")
102            .json(&body)
103            .send()
104            .await?;
105
106        if !resp.status().is_success() {
107            let status = resp.status();
108            let err = resp.text().await.unwrap_or_default();
109            anyhow::bail!("DingTalk gateway registration failed ({status}): {err}");
110        }
111
112        let gw: GatewayResponse = resp.json().await?;
113        Ok(gw)
114    }
115}
116
117#[async_trait]
118impl Channel for DingTalkChannel {
119    fn name(&self) -> &str {
120        "dingtalk"
121    }
122
123    async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
124        let webhooks = self.session_webhooks.read().await;
125        let webhook_url = webhooks.get(&message.recipient).ok_or_else(|| {
126            anyhow::anyhow!(
127                "No session webhook found for chat {}. \
128                 The user must send a message first to establish a session.",
129                message.recipient
130            )
131        })?;
132
133        let title = message.subject.as_deref().unwrap_or("Construct");
134        let body = serde_json::json!({
135            "msgtype": "markdown",
136            "markdown": {
137                "title": title,
138                "text": message.content,
139            }
140        });
141
142        let resp = self
143            .http_client()
144            .post(webhook_url)
145            .json(&body)
146            .send()
147            .await?;
148
149        if !resp.status().is_success() {
150            let status = resp.status();
151            let err = resp.text().await.unwrap_or_default();
152            anyhow::bail!("DingTalk webhook reply failed ({status}): {err}");
153        }
154
155        Ok(())
156    }
157
158    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
159        tracing::info!("DingTalk: registering gateway connection...");
160
161        let gw = self.register_connection().await?;
162        let ws_url = format!("{}?ticket={}", gw.endpoint, gw.ticket);
163
164        tracing::info!("DingTalk: connecting to stream WebSocket...");
165        let (ws_stream, _) = crate::config::ws_connect_with_proxy(
166            &ws_url,
167            "channel.dingtalk",
168            self.proxy_url.as_deref(),
169        )
170        .await?;
171        let (mut write, mut read) = ws_stream.split();
172
173        tracing::info!("DingTalk: connected and listening for messages...");
174
175        while let Some(msg) = read.next().await {
176            let msg = match msg {
177                Ok(Message::Text(t)) => t,
178                Ok(Message::Close(_)) => break,
179                Err(e) => {
180                    tracing::warn!("DingTalk WebSocket error: {e}");
181                    break;
182                }
183                _ => continue,
184            };
185
186            let frame: serde_json::Value = match serde_json::from_str(msg.as_ref()) {
187                Ok(v) => v,
188                Err(_) => continue,
189            };
190
191            let frame_type = frame.get("type").and_then(|t| t.as_str()).unwrap_or("");
192
193            match frame_type {
194                "SYSTEM" => {
195                    // Respond to system pings to keep the connection alive
196                    let message_id = frame
197                        .get("headers")
198                        .and_then(|h| h.get("messageId"))
199                        .and_then(|m| m.as_str())
200                        .unwrap_or("");
201
202                    let pong = serde_json::json!({
203                        "code": 200,
204                        "headers": {
205                            "contentType": "application/json",
206                            "messageId": message_id,
207                        },
208                        "message": "OK",
209                        "data": "",
210                    });
211
212                    if let Err(e) = write.send(Message::Text(pong.to_string().into())).await {
213                        tracing::warn!("DingTalk: failed to send pong: {e}");
214                        break;
215                    }
216                }
217                "EVENT" | "CALLBACK" => {
218                    // Parse the chatbot callback data from the frame.
219                    let data = match Self::parse_stream_data(&frame) {
220                        Some(v) => v,
221                        None => {
222                            tracing::debug!("DingTalk: frame has no parseable data payload");
223                            continue;
224                        }
225                    };
226
227                    // Extract message content
228                    let content = data
229                        .get("text")
230                        .and_then(|t| t.get("content"))
231                        .and_then(|c| c.as_str())
232                        .unwrap_or("")
233                        .trim();
234
235                    if content.is_empty() {
236                        continue;
237                    }
238
239                    let sender_id = data
240                        .get("senderStaffId")
241                        .and_then(|s| s.as_str())
242                        .unwrap_or("unknown");
243
244                    if !self.is_user_allowed(sender_id) {
245                        tracing::warn!(
246                            "DingTalk: ignoring message from unauthorized user: {sender_id}"
247                        );
248                        continue;
249                    }
250
251                    // Private chat uses sender ID, group chat uses conversation ID.
252                    let chat_id = Self::resolve_chat_id(&data, sender_id);
253
254                    // Store session webhook for later replies
255                    if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) {
256                        let webhook = webhook.to_string();
257                        let mut webhooks = self.session_webhooks.write().await;
258                        // Use both keys so reply routing works for both group and private flows.
259                        webhooks.insert(chat_id.clone(), webhook.clone());
260                        webhooks.insert(sender_id.to_string(), webhook);
261                    }
262
263                    // Acknowledge the event
264                    let message_id = frame
265                        .get("headers")
266                        .and_then(|h| h.get("messageId"))
267                        .and_then(|m| m.as_str())
268                        .unwrap_or("");
269
270                    let ack = serde_json::json!({
271                        "code": 200,
272                        "headers": {
273                            "contentType": "application/json",
274                            "messageId": message_id,
275                        },
276                        "message": "OK",
277                        "data": "",
278                    });
279                    let _ = write.send(Message::Text(ack.to_string().into())).await;
280
281                    let channel_msg = ChannelMessage {
282                        id: Uuid::new_v4().to_string(),
283                        sender: sender_id.to_string(),
284                        reply_target: chat_id,
285                        content: content.to_string(),
286                        channel: "dingtalk".to_string(),
287                        timestamp: std::time::SystemTime::now()
288                            .duration_since(std::time::UNIX_EPOCH)
289                            .unwrap_or_default()
290                            .as_secs(),
291                        thread_ts: None,
292                        interruption_scope_id: None,
293                        attachments: vec![],
294                    };
295
296                    if tx.send(channel_msg).await.is_err() {
297                        tracing::warn!("DingTalk: message channel closed");
298                        break;
299                    }
300                }
301                _ => {}
302            }
303        }
304
305        anyhow::bail!("DingTalk WebSocket stream ended")
306    }
307
308    async fn health_check(&self) -> bool {
309        self.register_connection().await.is_ok()
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_name() {
        let ch = DingTalkChannel::new("id".into(), "secret".into(), vec![]);
        assert_eq!(ch.name(), "dingtalk");
    }

    #[test]
    fn test_user_allowed_wildcard() {
        let ch = DingTalkChannel::new("id".into(), "secret".into(), vec!["*".into()]);
        assert!(ch.is_user_allowed("anyone"));
    }

    #[test]
    fn test_user_allowed_specific() {
        let ch = DingTalkChannel::new("id".into(), "secret".into(), vec!["user123".into()]);
        assert!(ch.is_user_allowed("user123"));
        assert!(!ch.is_user_allowed("other"));
    }

    #[test]
    fn test_user_denied_empty() {
        let ch = DingTalkChannel::new("id".into(), "secret".into(), vec![]);
        assert!(!ch.is_user_allowed("anyone"));
    }

    #[test]
    fn test_config_serde() {
        let toml_str = r#"
client_id = "app_id_123"
client_secret = "secret_456"
allowed_users = ["user1", "*"]
"#;
        let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(config.client_id, "app_id_123");
        assert_eq!(config.client_secret, "secret_456");
        assert_eq!(config.allowed_users, vec!["user1", "*"]);
    }

    #[test]
    fn test_config_serde_defaults() {
        let toml_str = r#"
client_id = "id"
client_secret = "secret"
"#;
        let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap();
        assert!(config.allowed_users.is_empty());
    }

    #[test]
    fn parse_stream_data_supports_string_payload() {
        let frame = serde_json::json!({
            "data": "{\"text\":{\"content\":\"hello\"}}"
        });
        let parsed = DingTalkChannel::parse_stream_data(&frame).unwrap();
        assert_eq!(
            parsed.get("text").and_then(|v| v.get("content")),
            Some(&serde_json::json!("hello"))
        );
    }

    #[test]
    fn parse_stream_data_supports_object_payload() {
        let frame = serde_json::json!({
            "data": {"text": {"content": "hello"}}
        });
        let parsed = DingTalkChannel::parse_stream_data(&frame).unwrap();
        assert_eq!(
            parsed.get("text").and_then(|v| v.get("content")),
            Some(&serde_json::json!("hello"))
        );
    }

    #[test]
    fn parse_stream_data_rejects_missing_malformed_and_scalar_payloads() {
        let missing = serde_json::json!({});
        assert!(DingTalkChannel::parse_stream_data(&missing).is_none());

        let malformed = serde_json::json!({ "data": "not json" });
        assert!(DingTalkChannel::parse_stream_data(&malformed).is_none());

        let scalar = serde_json::json!({ "data": 42 });
        assert!(DingTalkChannel::parse_stream_data(&scalar).is_none());
    }

    #[test]
    fn resolve_chat_id_handles_numeric_group_conversation_type() {
        let data = serde_json::json!({
            "conversationType": 2,
            "conversationId": "cid-group",
        });
        let chat_id = DingTalkChannel::resolve_chat_id(&data, "staff-1");
        assert_eq!(chat_id, "cid-group");
    }

    #[test]
    fn resolve_chat_id_handles_string_group_conversation_type() {
        let data = serde_json::json!({
            "conversationType": "2",
            "conversationId": "cid-group",
        });
        assert_eq!(DingTalkChannel::resolve_chat_id(&data, "staff-1"), "cid-group");
    }

    #[test]
    fn resolve_chat_id_uses_sender_for_private_chat() {
        let data = serde_json::json!({
            "conversationType": "1",
            "conversationId": "cid-private",
        });
        assert_eq!(DingTalkChannel::resolve_chat_id(&data, "staff-1"), "staff-1");
    }

    #[test]
    fn resolve_chat_id_defaults_to_sender_when_type_missing() {
        let data = serde_json::json!({ "conversationId": "cid-unknown" });
        assert_eq!(DingTalkChannel::resolve_chat_id(&data, "staff-1"), "staff-1");
    }

    #[test]
    fn resolve_chat_id_falls_back_to_sender_without_conversation_id() {
        let data = serde_json::json!({ "conversationType": 2 });
        assert_eq!(DingTalkChannel::resolve_chat_id(&data, "staff-1"), "staff-1");
    }
}