Skip to main content

garudust_platforms/
webhook.rs

1use std::pin::Pin;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use axum::{extract::State, routing::post, Json, Router};
6use futures::Stream;
7use garudust_core::{
8    error::PlatformError,
9    net_guard,
10    platform::{MessageHandler, PlatformAdapter},
11    types::{ChannelId, InboundMessage, OutboundMessage},
12};
13use serde::{Deserialize, Serialize};
14
/// JSON body accepted by the webhook endpoint.
///
/// `text` and `callback_url` are required; the remaining fields fall back to an
/// empty string when omitted (`#[serde(default)]`).
#[derive(Deserialize)]
struct WebhookPayload {
    /// Message text, forwarded as `InboundMessage::text`.
    text: String,
    /// URL to POST the response back to.
    callback_url: String,
    /// Forwarded as `InboundMessage::user_id`; empty if not supplied.
    #[serde(default)]
    user_id: String,
    /// Forwarded as `InboundMessage::user_name`; empty if not supplied.
    #[serde(default)]
    user_name: String,
    /// Conversation key; when empty, `handle_webhook` derives one from `callback_url`.
    #[serde(default)]
    session_key: String,
}

/// JSON body POSTed to the callback URL when replying.
#[derive(Serialize)]
struct CallbackPayload {
    /// Reply text.
    text: String,
}
32
33async fn handle_webhook(
34    State(handler): State<Arc<dyn MessageHandler>>,
35    Json(payload): Json<WebhookPayload>,
36) -> axum::http::StatusCode {
37    let session_key = if payload.session_key.is_empty() {
38        format!("webhook:{}", payload.callback_url)
39    } else {
40        payload.session_key.clone()
41    };
42
43    let inbound = InboundMessage {
44        channel: ChannelId {
45            platform: "webhook".into(),
46            // chat_id holds the callback URL so send_message can POST back
47            chat_id: payload.callback_url,
48            thread_id: None,
49        },
50        user_id: payload.user_id,
51        user_name: payload.user_name,
52        text: payload.text,
53        session_key,
54        is_group: false,
55        bot_mentioned: None,
56    };
57
58    match handler.handle(inbound).await {
59        Ok(()) => axum::http::StatusCode::ACCEPTED,
60        Err(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
61    }
62}
63
/// Platform adapter that receives messages over an HTTP webhook and replies by
/// POSTing to the callback URL supplied with each inbound message.
pub struct WebhookAdapter {
    /// TCP port the HTTP listener binds to.
    port: u16,
    /// Route under which inbound webhook POSTs are accepted.
    webhook_path: String,
}

impl WebhookAdapter {
    /// Creates an adapter that will serve `webhook_path` on `port` once started.
    pub fn new(port: u16, webhook_path: String) -> Self {
        WebhookAdapter {
            webhook_path,
            port,
        }
    }
}
74
75#[async_trait]
76impl PlatformAdapter for WebhookAdapter {
77    fn name(&self) -> &'static str {
78        "webhook"
79    }
80
81    async fn start(&self, handler: Arc<dyn MessageHandler>) -> Result<(), PlatformError> {
82        let port = self.port;
83        let path = self.webhook_path.clone();
84        let router = Router::new()
85            .route(&self.webhook_path, post(handle_webhook))
86            .with_state(handler);
87
88        let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{port}"))
89            .await
90            .map_err(|e| PlatformError::Connection(e.to_string()))?;
91
92        tracing::info!("webhook adapter listening on 0.0.0.0:{port}{path}");
93        tokio::spawn(async move {
94            if let Err(e) = axum::serve(listener, router).await {
95                tracing::error!("webhook server error: {e}");
96            }
97        });
98        Ok(())
99    }
100
101    async fn send_message(
102        &self,
103        channel: &ChannelId,
104        message: OutboundMessage,
105    ) -> Result<(), PlatformError> {
106        net_guard::is_safe_url(&channel.chat_id).map_err(|e| PlatformError::Send(e.to_string()))?;
107
108        let client = reqwest::Client::new();
109        client
110            .post(&channel.chat_id)
111            .json(&CallbackPayload { text: message.text })
112            .send()
113            .await
114            .map_err(|e| PlatformError::Send(e.to_string()))?;
115        Ok(())
116    }
117
118    async fn send_stream(
119        &self,
120        channel: &ChannelId,
121        mut stream: Pin<Box<dyn Stream<Item = String> + Send>>,
122    ) -> Result<(), PlatformError> {
123        use futures::StreamExt;
124        let mut buf = String::new();
125        while let Some(chunk) = stream.next().await {
126            buf.push_str(&chunk);
127        }
128        self.send_message(channel, OutboundMessage::text(buf)).await
129    }
130}
131
#[cfg(test)]
mod tests {
    use garudust_core::{
        error::PlatformError,
        net_guard,
        platform::PlatformAdapter,
        types::{ChannelId, OutboundMessage},
    };

    use super::WebhookAdapter;

    #[test]
    fn net_guard_blocks_private_callback_url() {
        let result = net_guard::is_safe_url("http://192.168.1.1/callback");
        assert!(result.is_err(), "private callback URL must be blocked");
    }

    #[test]
    fn send_message_rejects_private_callback_url() {
        // send_message runs net_guard before building the HTTP client, so a
        // private URL must fail with PlatformError::Send without any request.
        let adapter = WebhookAdapter::new(0, "/webhook".to_string());
        let channel = ChannelId {
            platform: "webhook".into(),
            chat_id: "http://192.168.1.1/callback".to_string(),
            thread_id: None,
        };
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build test runtime");
        let result = runtime.block_on(
            adapter.send_message(&channel, OutboundMessage::text("hello".to_string())),
        );
        assert!(
            matches!(result, Err(PlatformError::Send(_))),
            "private callback URL must be rejected by send_message"
        );
    }

    // NOTE(review): the two tests below duplicate the session-key fallback from
    // handle_webhook rather than calling it; extracting that logic into a helper
    // would let them exercise the real code.

    #[test]
    fn session_key_falls_back_to_callback_url() {
        // Mirrors the logic in handle_webhook: empty session_key → use callback_url.
        let session_key = "";
        let callback_url = "https://example.com/reply";
        let key = if session_key.is_empty() {
            format!("webhook:{callback_url}")
        } else {
            session_key.to_string()
        };
        assert_eq!(key, "webhook:https://example.com/reply");
    }

    #[test]
    fn session_key_used_when_provided() {
        let session_key = "my-custom-key";
        let callback_url = "https://example.com/reply";
        let key = if session_key.is_empty() {
            format!("webhook:{callback_url}")
        } else {
            session_key.to_string()
        };
        assert_eq!(key, "my-custom-key");
    }
}