arkflow_plugin/input/
http.rs

1//! HTTP input component
2//!
3//! Receive data from HTTP endpoints
4
5use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
6use arkflow_core::{Error, MessageBatch};
7use async_trait::async_trait;
8use axum::{extract::State, http::StatusCode, routing::post, Router};
9use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11use std::net::SocketAddr;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
16/// HTTP input configuration
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HttpInputConfig {
19    /// Listening address
20    pub address: String,
21    /// Path
22    pub path: String,
23    /// Whether CORS is enabled
24    pub cors_enabled: Option<bool>,
25}
26
27/// HTTP input component
28pub struct HttpInput {
29    config: HttpInputConfig,
30    queue: Arc<Mutex<VecDeque<MessageBatch>>>,
31    server_handle: Arc<Mutex<Option<tokio::task::JoinHandle<Result<(), Error>>>>>,
32    connected: AtomicBool,
33}
34
35type AppState = Arc<Mutex<VecDeque<MessageBatch>>>;
36
37impl HttpInput {
38    pub fn new(config: HttpInputConfig) -> Result<Self, Error> {
39        Ok(Self {
40            config,
41            queue: Arc::new(Mutex::new(VecDeque::new())),
42            server_handle: Arc::new(Mutex::new(None)),
43            connected: AtomicBool::new(false),
44        })
45    }
46
47    async fn handle_request(
48        State(state): State<AppState>,
49        body: axum::extract::Json<serde_json::Value>,
50    ) -> StatusCode {
51        let msg = match MessageBatch::from_json(&body.0) {
52            Ok(msg) => msg,
53            Err(_) => return StatusCode::BAD_REQUEST,
54        };
55
56        let mut queue = state.lock().await;
57        queue.push_back(msg);
58        StatusCode::OK
59    }
60}
61
62#[async_trait]
63impl Input for HttpInput {
64    async fn connect(&self) -> Result<(), Error> {
65        if self.connected.load(Ordering::SeqCst) {
66            return Ok(());
67        }
68
69        let queue = self.queue.clone();
70        let path = self.config.path.clone();
71        let address = self.config.address.clone();
72
73        let app = Router::new()
74            .route(&path, post(Self::handle_request))
75            .with_state(queue);
76
77        let addr: SocketAddr = address
78            .parse()
79            .map_err(|e| Error::Config(format!("Invalid address {}: {}", address, e)))?;
80
81        let server_handle = tokio::spawn(async move {
82            axum::Server::bind(&addr)
83                .serve(app.into_make_service())
84                .await
85                .map_err(|e| Error::Connection(format!("HTTP server error: {}", e)))
86        });
87
88        let server_handle_arc = self.server_handle.clone();
89        let mut server_handle_arc_mutex = server_handle_arc.lock().await;
90        *server_handle_arc_mutex = Some(server_handle);
91        self.connected.store(true, Ordering::SeqCst);
92
93        Ok(())
94    }
95
96    async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
97        if !self.connected.load(Ordering::SeqCst) {
98            return Err(Error::Connection("The input is not connected".to_string()));
99        }
100
101        // Try to get a message from the queue
102        let msg_option;
103        {
104            let mut queue = self.queue.lock().await;
105            msg_option = queue.pop_front();
106        }
107
108        if let Some(msg) = msg_option {
109            Ok((msg, Arc::new(NoopAck)))
110        } else {
111            // If the queue is empty, an error is returned after waiting for a while
112            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
113            Err(Error::Process("The queue is empty".to_string()))
114        }
115    }
116
117    async fn close(&self) -> Result<(), Error> {
118        let mut server_handle_guard = self.server_handle.lock().await;
119        if let Some(handle) = server_handle_guard.take() {
120            handle.abort();
121        }
122
123        self.connected.store(false, Ordering::SeqCst);
124        Ok(())
125    }
126}
127
128pub(crate) struct HttpInputBuilder;
129impl InputBuilder for HttpInputBuilder {
130    fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
131        if config.is_none() {
132            return Err(Error::Config(
133                "Http input configuration is missing".to_string(),
134            ));
135        }
136
137        let config: HttpInputConfig = serde_json::from_value(config.clone().unwrap())?;
138        Ok(Arc::new(HttpInput::new(config)?))
139    }
140}
141
142pub fn init() {
143    register_input_builder("http", Arc::new(HttpInputBuilder));
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::Client;
    use serde_json::json;

    /// Builds a configuration listening on `address`, serving the `/test`
    /// route with CORS disabled.
    fn config_for(address: &str) -> HttpInputConfig {
        HttpInputConfig {
            address: address.to_string(),
            path: "/test".to_string(),
            cors_enabled: Some(false),
        }
    }

    /// Constructing the input from a well-formed configuration succeeds.
    #[tokio::test]
    async fn test_http_input_new() {
        // Port 0 lets the OS pick a random free port.
        let input = HttpInput::new(config_for("127.0.0.1:0"));
        assert!(input.is_ok());
    }

    /// Connecting succeeds, and connecting a second time is accepted as well.
    #[tokio::test]
    async fn test_http_input_connect() {
        let input = HttpInput::new(config_for("127.0.0.1:0")).unwrap();

        let first = input.connect().await;
        assert!(first.is_ok());

        // Repeated connection on an already connected input
        let second = input.connect().await;
        assert!(second.is_ok());

        // Shut the server down again
        assert!(input.close().await.is_ok());
    }

    /// Reading before `connect` yields a connection error.
    #[tokio::test]
    async fn test_http_input_read_without_connect() {
        let input = HttpInput::new(config_for("127.0.0.1:0")).unwrap();

        let result = input.read().await;
        assert!(result.is_err());
        assert!(
            matches!(result, Err(Error::Connection(_))),
            "Expected Connection error"
        );
    }

    /// Reading from a connected input with nothing queued yields a process error.
    #[tokio::test]
    async fn test_http_input_read_empty_queue() {
        let input = HttpInput::new(config_for("127.0.0.1:0")).unwrap();
        assert!(input.connect().await.is_ok());

        // Nothing has been posted, so the queue is empty
        let result = input.read().await;
        assert!(result.is_err());
        assert!(
            matches!(result, Err(Error::Process(_))),
            "Expected Processing error"
        );

        // Shut the server down again
        assert!(input.close().await.is_ok());
    }

    /// An address that cannot be parsed is reported as a configuration error.
    #[tokio::test]
    async fn test_http_input_invalid_address() {
        let input = HttpInput::new(config_for("invalid-address")).unwrap();

        let result = input.connect().await;
        assert!(result.is_err());
        assert!(
            matches!(result, Err(Error::Config(_))),
            "Expected Config error"
        );
    }

    /// A JSON document posted to the endpoint can be read back from the input.
    #[tokio::test]
    async fn test_http_input_receive_message() {
        // Bind a throwaway listener to port 0 to learn a free port; the
        // listener is dropped at the end of the block so the HTTP server can
        // claim that port.
        let port = {
            let probe = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            probe.local_addr().unwrap().port()
        };

        // Start the HTTP input on the port found above
        let config = config_for(&format!("127.0.0.1:{}", port));
        let url = format!("http://127.0.0.1:{}{}", port, config.path);
        let input = HttpInput::new(config).unwrap();
        assert!(input.connect().await.is_ok());

        // Give the server task time to start
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        // Post a JSON message to the endpoint
        let test_message = json!({"data": "test message"});
        let client = Client::new();
        let response = client.post(url).json(&test_message).send().await;

        assert!(
            response.is_ok(),
            "HTTP request failed: {:?}",
            response.err()
        );
        let response = response.unwrap();
        let status = response.status();
        assert!(
            status.is_success(),
            "HTTP response status is not success: {}",
            status
        );

        // The posted message must now be available from the input
        let read_result = input.read().await;
        assert!(
            read_result.is_ok(),
            "Failed to read message: {:?}",
            read_result.err()
        );

        let (msg, ack) = read_result.unwrap();
        let content = msg.as_string().unwrap();
        assert_eq!(content, vec![test_message.to_string()]);
        ack.ack().await;

        // Shut the server down again
        assert!(input.close().await.is_ok());
    }
}
295}