Skip to main content

garudust_platforms/
webhook.rs

1use std::pin::Pin;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use axum::{extract::State, routing::post, Json, Router};
6use futures::Stream;
7use garudust_core::{
8    error::PlatformError,
9    net_guard,
10    platform::{MessageHandler, PlatformAdapter},
11    types::{ChannelId, InboundMessage, OutboundMessage},
12};
13use serde::{Deserialize, Serialize};
14
/// JSON body accepted by `POST /webhook`.
#[derive(Deserialize)]
struct WebhookPayload {
    /// Message text forwarded to the handler. Required.
    text: String,
    /// URL to POST the response back to.
    /// Required; also used to derive the session key when `session_key` is empty.
    callback_url: String,
    /// Sender identifier; deserializes to an empty string when omitted.
    #[serde(default)]
    user_id: String,
    /// Sender display name; deserializes to an empty string when omitted.
    #[serde(default)]
    user_name: String,
    /// Conversation key; when omitted/empty, `handle_webhook` falls back to
    /// `webhook:<callback_url>`.
    #[serde(default)]
    session_key: String,
}
27
/// JSON body POSTed to the caller's callback URL, i.e. `{"text": "..."}`.
#[derive(Serialize)]
struct CallbackPayload {
    /// Reply text produced for the inbound message.
    text: String,
}
32
33async fn handle_webhook(
34    State(handler): State<Arc<dyn MessageHandler>>,
35    Json(payload): Json<WebhookPayload>,
36) -> axum::http::StatusCode {
37    let session_key = if payload.session_key.is_empty() {
38        format!("webhook:{}", payload.callback_url)
39    } else {
40        payload.session_key.clone()
41    };
42
43    let inbound = InboundMessage {
44        channel: ChannelId {
45            platform: "webhook".into(),
46            // chat_id holds the callback URL so send_message can POST back
47            chat_id: payload.callback_url,
48            thread_id: None,
49        },
50        user_id: payload.user_id,
51        user_name: payload.user_name,
52        text: payload.text,
53        session_key,
54        is_group: false,
55    };
56
57    match handler.handle(inbound).await {
58        Ok(()) => axum::http::StatusCode::ACCEPTED,
59        Err(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
60    }
61}
62
/// Platform adapter that receives messages on `POST /webhook` and delivers
/// replies by POSTing to the per-message callback URL.
#[derive(Debug, Clone)]
pub struct WebhookAdapter {
    /// TCP port the HTTP server binds on `0.0.0.0` when started.
    port: u16,
}
66
67impl WebhookAdapter {
68    pub fn new(port: u16) -> Self {
69        Self { port }
70    }
71}
72
73#[async_trait]
74impl PlatformAdapter for WebhookAdapter {
75    fn name(&self) -> &'static str {
76        "webhook"
77    }
78
79    async fn start(&self, handler: Arc<dyn MessageHandler>) -> Result<(), PlatformError> {
80        let port = self.port;
81        let router = Router::new()
82            .route("/webhook", post(handle_webhook))
83            .with_state(handler);
84
85        let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{port}"))
86            .await
87            .map_err(|e| PlatformError::Connection(e.to_string()))?;
88
89        tracing::info!("webhook adapter listening on 0.0.0.0:{port}");
90        tokio::spawn(async move {
91            if let Err(e) = axum::serve(listener, router).await {
92                tracing::error!("webhook server error: {e}");
93            }
94        });
95        Ok(())
96    }
97
98    async fn send_message(
99        &self,
100        channel: &ChannelId,
101        message: OutboundMessage,
102    ) -> Result<(), PlatformError> {
103        net_guard::is_safe_url(&channel.chat_id).map_err(|e| PlatformError::Send(e.to_string()))?;
104
105        let client = reqwest::Client::new();
106        client
107            .post(&channel.chat_id)
108            .json(&CallbackPayload { text: message.text })
109            .send()
110            .await
111            .map_err(|e| PlatformError::Send(e.to_string()))?;
112        Ok(())
113    }
114
115    async fn send_stream(
116        &self,
117        channel: &ChannelId,
118        mut stream: Pin<Box<dyn Stream<Item = String> + Send>>,
119    ) -> Result<(), PlatformError> {
120        use futures::StreamExt;
121        let mut buf = String::new();
122        while let Some(chunk) = stream.next().await {
123            buf.push_str(&chunk);
124        }
125        self.send_message(channel, OutboundMessage::text(buf)).await
126    }
127}
128
#[cfg(test)]
mod tests {
    use super::WebhookAdapter;
    use garudust_core::{
        error::PlatformError,
        net_guard,
        platform::PlatformAdapter,
        types::{ChannelId, OutboundMessage},
    };

    #[test]
    fn net_guard_rejects_private_callback_url() {
        let result = net_guard::is_safe_url("http://192.168.1.1/callback");
        assert!(result.is_err(), "private callback URL must be blocked");
    }

    #[test]
    fn send_message_rejects_private_callback_url() {
        // send_message runs net_guard before creating the HTTP client, so the
        // future completes without network I/O and without a tokio runtime.
        let adapter = WebhookAdapter::new(0);
        let channel = ChannelId {
            platform: "webhook".into(),
            chat_id: "http://192.168.1.1/callback".into(),
            thread_id: None,
        };
        let result = futures::executor::block_on(
            adapter.send_message(&channel, OutboundMessage::text(String::from("hello"))),
        );
        assert!(
            matches!(result, Err(PlatformError::Send(_))),
            "private callback URL must be rejected by send_message"
        );
    }

    // The two tests below restate the session-key rule from handle_webhook
    // (empty session_key -> "webhook:<callback_url>") rather than calling it,
    // because driving the handler needs a MessageHandler implementation.
    // Keep them in sync with handle_webhook if the rule changes.

    #[test]
    fn session_key_falls_back_to_callback_url() {
        let session_key = "";
        let callback_url = "https://example.com/reply";
        let key = if session_key.is_empty() {
            format!("webhook:{callback_url}")
        } else {
            session_key.to_string()
        };
        assert_eq!(key, "webhook:https://example.com/reply");
    }

    #[test]
    fn session_key_used_when_provided() {
        let session_key = "my-custom-key";
        let callback_url = "https://example.com/reply";
        let key = if session_key.is_empty() {
            format!("webhook:{callback_url}")
        } else {
            session_key.to_string()
        };
        assert_eq!(key, "my-custom-key");
    }
}