garudust_platforms/
webhook.rs1use std::pin::Pin;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use axum::{extract::State, routing::post, Json, Router};
6use futures::Stream;
7use garudust_core::{
8 error::PlatformError,
9 net_guard,
10 platform::{MessageHandler, PlatformAdapter},
11 types::{ChannelId, InboundMessage, OutboundMessage},
12};
13use serde::{Deserialize, Serialize};
14
15#[derive(Deserialize)]
16struct WebhookPayload {
17 text: String,
18 callback_url: String,
20 #[serde(default)]
21 user_id: String,
22 #[serde(default)]
23 user_name: String,
24 #[serde(default)]
25 session_key: String,
26}
27
28#[derive(Serialize)]
29struct CallbackPayload {
30 text: String,
31}
32
33async fn handle_webhook(
34 State(handler): State<Arc<dyn MessageHandler>>,
35 Json(payload): Json<WebhookPayload>,
36) -> axum::http::StatusCode {
37 let session_key = if payload.session_key.is_empty() {
38 format!("webhook:{}", payload.callback_url)
39 } else {
40 payload.session_key.clone()
41 };
42
43 let inbound = InboundMessage {
44 channel: ChannelId {
45 platform: "webhook".into(),
46 chat_id: payload.callback_url,
48 thread_id: None,
49 },
50 user_id: payload.user_id,
51 user_name: payload.user_name,
52 text: payload.text,
53 session_key,
54 is_group: false,
55 bot_mentioned: None,
56 };
57
58 match handler.handle(inbound).await {
59 Ok(()) => axum::http::StatusCode::ACCEPTED,
60 Err(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
61 }
62}
63
/// Platform adapter that receives messages over an HTTP webhook.
#[derive(Debug)]
pub struct WebhookAdapter {
    /// TCP port the webhook server binds to.
    port: u16,
    /// Route path the webhook endpoint is mounted at.
    webhook_path: String,
}
68
69impl WebhookAdapter {
70 pub fn new(port: u16, webhook_path: String) -> Self {
71 Self { port, webhook_path }
72 }
73}
74
75#[async_trait]
76impl PlatformAdapter for WebhookAdapter {
77 fn name(&self) -> &'static str {
78 "webhook"
79 }
80
81 async fn start(&self, handler: Arc<dyn MessageHandler>) -> Result<(), PlatformError> {
82 let port = self.port;
83 let path = self.webhook_path.clone();
84 let router = Router::new()
85 .route(&self.webhook_path, post(handle_webhook))
86 .with_state(handler);
87
88 let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{port}"))
89 .await
90 .map_err(|e| PlatformError::Connection(e.to_string()))?;
91
92 tracing::info!("webhook adapter listening on 0.0.0.0:{port}{path}");
93 tokio::spawn(async move {
94 if let Err(e) = axum::serve(listener, router).await {
95 tracing::error!("webhook server error: {e}");
96 }
97 });
98 Ok(())
99 }
100
101 async fn send_message(
102 &self,
103 channel: &ChannelId,
104 message: OutboundMessage,
105 ) -> Result<(), PlatformError> {
106 net_guard::is_safe_url(&channel.chat_id).map_err(|e| PlatformError::Send(e.to_string()))?;
107
108 let client = reqwest::Client::new();
109 client
110 .post(&channel.chat_id)
111 .json(&CallbackPayload { text: message.text })
112 .send()
113 .await
114 .map_err(|e| PlatformError::Send(e.to_string()))?;
115 Ok(())
116 }
117
118 async fn send_stream(
119 &self,
120 channel: &ChannelId,
121 mut stream: Pin<Box<dyn Stream<Item = String> + Send>>,
122 ) -> Result<(), PlatformError> {
123 use futures::StreamExt;
124 let mut buf = String::new();
125 while let Some(chunk) = stream.next().await {
126 buf.push_str(&chunk);
127 }
128 self.send_message(channel, OutboundMessage::text(buf)).await
129 }
130}
131
#[cfg(test)]
mod tests {
    use garudust_core::net_guard;

    /// Mirrors the session-key selection performed by the webhook handler:
    /// an empty key falls back to `webhook:<callback_url>`.
    fn pick_session_key(session_key: &str, callback_url: &str) -> String {
        match session_key {
            "" => format!("webhook:{callback_url}"),
            explicit => explicit.to_string(),
        }
    }

    /// A callback URL on a private address must not pass the URL guard.
    #[test]
    fn send_message_rejects_private_callback_url() {
        let outcome = net_guard::is_safe_url("http://192.168.1.1/callback");
        assert!(outcome.is_err(), "private callback URL must be blocked");
    }

    /// With no explicit key, the key is derived from the callback URL.
    #[test]
    fn session_key_falls_back_to_callback_url() {
        let key = pick_session_key("", "https://example.com/reply");
        assert_eq!(key, "webhook:https://example.com/reply");
    }

    /// An explicit key is used verbatim.
    #[test]
    fn session_key_used_when_provided() {
        let key = pick_session_key("my-custom-key", "https://example.com/reply");
        assert_eq!(key, "my-custom-key");
    }
}