garudust_platforms/
webhook.rs1use std::pin::Pin;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use axum::{extract::State, routing::post, Json, Router};
6use futures::Stream;
7use garudust_core::{
8 error::PlatformError,
9 net_guard,
10 platform::{MessageHandler, PlatformAdapter},
11 types::{ChannelId, InboundMessage, OutboundMessage},
12};
13use serde::{Deserialize, Serialize};
14
15#[derive(Deserialize)]
16struct WebhookPayload {
17 text: String,
18 callback_url: String,
20 #[serde(default)]
21 user_id: String,
22 #[serde(default)]
23 user_name: String,
24 #[serde(default)]
25 session_key: String,
26}
27
28#[derive(Serialize)]
29struct CallbackPayload {
30 text: String,
31}
32
33async fn handle_webhook(
34 State(handler): State<Arc<dyn MessageHandler>>,
35 Json(payload): Json<WebhookPayload>,
36) -> axum::http::StatusCode {
37 let session_key = if payload.session_key.is_empty() {
38 format!("webhook:{}", payload.callback_url)
39 } else {
40 payload.session_key.clone()
41 };
42
43 let inbound = InboundMessage {
44 channel: ChannelId {
45 platform: "webhook".into(),
46 chat_id: payload.callback_url,
48 thread_id: None,
49 },
50 user_id: payload.user_id,
51 user_name: payload.user_name,
52 text: payload.text,
53 session_key,
54 is_group: false,
55 };
56
57 match handler.handle(inbound).await {
58 Ok(()) => axum::http::StatusCode::ACCEPTED,
59 Err(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
60 }
61}
62
/// Platform adapter that receives messages over an HTTP webhook and replies by POSTing to
/// the callback URL supplied with each message.
pub struct WebhookAdapter {
    /// TCP port the webhook server binds on `0.0.0.0`.
    port: u16,
}

impl WebhookAdapter {
    /// Creates an adapter that will listen on `port` once started.
    pub fn new(port: u16) -> Self {
        WebhookAdapter { port }
    }
}
72
73#[async_trait]
74impl PlatformAdapter for WebhookAdapter {
75 fn name(&self) -> &'static str {
76 "webhook"
77 }
78
79 async fn start(&self, handler: Arc<dyn MessageHandler>) -> Result<(), PlatformError> {
80 let port = self.port;
81 let router = Router::new()
82 .route("/webhook", post(handle_webhook))
83 .with_state(handler);
84
85 let listener = tokio::net::TcpListener::bind(format!("0.0.0.0:{port}"))
86 .await
87 .map_err(|e| PlatformError::Connection(e.to_string()))?;
88
89 tracing::info!("webhook adapter listening on 0.0.0.0:{port}");
90 tokio::spawn(async move {
91 if let Err(e) = axum::serve(listener, router).await {
92 tracing::error!("webhook server error: {e}");
93 }
94 });
95 Ok(())
96 }
97
98 async fn send_message(
99 &self,
100 channel: &ChannelId,
101 message: OutboundMessage,
102 ) -> Result<(), PlatformError> {
103 net_guard::is_safe_url(&channel.chat_id).map_err(|e| PlatformError::Send(e.to_string()))?;
104
105 let client = reqwest::Client::new();
106 client
107 .post(&channel.chat_id)
108 .json(&CallbackPayload { text: message.text })
109 .send()
110 .await
111 .map_err(|e| PlatformError::Send(e.to_string()))?;
112 Ok(())
113 }
114
115 async fn send_stream(
116 &self,
117 channel: &ChannelId,
118 mut stream: Pin<Box<dyn Stream<Item = String> + Send>>,
119 ) -> Result<(), PlatformError> {
120 use futures::StreamExt;
121 let mut buf = String::new();
122 while let Some(chunk) = stream.next().await {
123 buf.push_str(&chunk);
124 }
125 self.send_message(channel, OutboundMessage::text(buf)).await
126 }
127}
128
#[cfg(test)]
mod tests {
    use garudust_core::net_guard;

    /// Local copy of the session-key rule in `handle_webhook`: an explicit key wins,
    /// otherwise the key is `webhook:<callback_url>`.
    fn fallback_key(session_key: &str, callback_url: &str) -> String {
        match session_key {
            "" => format!("webhook:{callback_url}"),
            explicit => explicit.to_string(),
        }
    }

    #[test]
    fn send_message_rejects_private_callback_url() {
        // 192.168.0.0/16 is a private (RFC 1918) range; the SSRF guard must refuse it.
        assert!(
            net_guard::is_safe_url("http://192.168.1.1/callback").is_err(),
            "private callback URL must be blocked"
        );
    }

    #[test]
    fn session_key_falls_back_to_callback_url() {
        let derived = fallback_key("", "https://example.com/reply");
        assert_eq!(derived, "webhook:https://example.com/reply");
    }

    #[test]
    fn session_key_used_when_provided() {
        let derived = fallback_key("my-custom-key", "https://example.com/reply");
        assert_eq!(derived, "my-custom-key");
    }
}