arkflow_plugin/input/
http.rs1use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder, NoopAck};
6use arkflow_core::{Error, MessageBatch};
7use async_trait::async_trait;
8use axum::{extract::State, http::StatusCode, routing::post, Router};
9use serde::{Deserialize, Serialize};
10use std::collections::VecDeque;
11use std::net::SocketAddr;
12use std::sync::atomic::{AtomicBool, Ordering};
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HttpInputConfig {
19 pub address: String,
21 pub path: String,
23 pub cors_enabled: Option<bool>,
25}
26
27pub struct HttpInput {
29 config: HttpInputConfig,
30 queue: Arc<Mutex<VecDeque<MessageBatch>>>,
31 server_handle: Arc<Mutex<Option<tokio::task::JoinHandle<Result<(), Error>>>>>,
32 connected: AtomicBool,
33}
34
35type AppState = Arc<Mutex<VecDeque<MessageBatch>>>;
36
37impl HttpInput {
38 pub fn new(config: HttpInputConfig) -> Result<Self, Error> {
39 Ok(Self {
40 config,
41 queue: Arc::new(Mutex::new(VecDeque::new())),
42 server_handle: Arc::new(Mutex::new(None)),
43 connected: AtomicBool::new(false),
44 })
45 }
46
47 async fn handle_request(
48 State(state): State<AppState>,
49 body: axum::extract::Json<serde_json::Value>,
50 ) -> StatusCode {
51 let msg = match MessageBatch::from_json(&body.0) {
52 Ok(msg) => msg,
53 Err(_) => return StatusCode::BAD_REQUEST,
54 };
55
56 let mut queue = state.lock().await;
57 queue.push_back(msg);
58 StatusCode::OK
59 }
60}
61
62#[async_trait]
63impl Input for HttpInput {
64 async fn connect(&self) -> Result<(), Error> {
65 if self.connected.load(Ordering::SeqCst) {
66 return Ok(());
67 }
68
69 let queue = self.queue.clone();
70 let path = self.config.path.clone();
71 let address = self.config.address.clone();
72
73 let app = Router::new()
74 .route(&path, post(Self::handle_request))
75 .with_state(queue);
76
77 let addr: SocketAddr = address
78 .parse()
79 .map_err(|e| Error::Config(format!("Invalid address {}: {}", address, e)))?;
80
81 let server_handle = tokio::spawn(async move {
82 axum::Server::bind(&addr)
83 .serve(app.into_make_service())
84 .await
85 .map_err(|e| Error::Connection(format!("HTTP server error: {}", e)))
86 });
87
88 let server_handle_arc = self.server_handle.clone();
89 let mut server_handle_arc_mutex = server_handle_arc.lock().await;
90 *server_handle_arc_mutex = Some(server_handle);
91 self.connected.store(true, Ordering::SeqCst);
92
93 Ok(())
94 }
95
96 async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
97 if !self.connected.load(Ordering::SeqCst) {
98 return Err(Error::Connection("The input is not connected".to_string()));
99 }
100
101 let msg_option;
103 {
104 let mut queue = self.queue.lock().await;
105 msg_option = queue.pop_front();
106 }
107
108 if let Some(msg) = msg_option {
109 Ok((msg, Arc::new(NoopAck)))
110 } else {
111 tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
113 Err(Error::Process("The queue is empty".to_string()))
114 }
115 }
116
117 async fn close(&self) -> Result<(), Error> {
118 let mut server_handle_guard = self.server_handle.lock().await;
119 if let Some(handle) = server_handle_guard.take() {
120 handle.abort();
121 }
122
123 self.connected.store(false, Ordering::SeqCst);
124 Ok(())
125 }
126}
127
128pub(crate) struct HttpInputBuilder;
129impl InputBuilder for HttpInputBuilder {
130 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
131 if config.is_none() {
132 return Err(Error::Config(
133 "Http input configuration is missing".to_string(),
134 ));
135 }
136
137 let config: HttpInputConfig = serde_json::from_value(config.clone().unwrap())?;
138 Ok(Arc::new(HttpInput::new(config)?))
139 }
140}
141
142pub fn init() {
143 register_input_builder("http", Arc::new(HttpInputBuilder));
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::Client;
    use serde_json::json;

    /// Config listening on `address`, serving `/test`, with CORS disabled.
    fn config_for(address: &str) -> HttpInputConfig {
        HttpInputConfig {
            address: address.to_string(),
            path: "/test".to_string(),
            cors_enabled: Some(false),
        }
    }

    #[tokio::test]
    async fn test_http_input_new() {
        let input = HttpInput::new(config_for("127.0.0.1:0"));
        assert!(input.is_ok());
    }

    #[tokio::test]
    async fn test_http_input_connect() {
        let input = HttpInput::new(config_for("127.0.0.1:0")).unwrap();
        assert!(input.connect().await.is_ok());
        // Connecting an already connected input must succeed as a no-op.
        assert!(input.connect().await.is_ok());
        assert!(input.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_http_input_read_without_connect() {
        let input = HttpInput::new(config_for("127.0.0.1:0")).unwrap();
        let result = input.read().await;
        assert!(
            matches!(result, Err(Error::Connection(_))),
            "Expected Connection error"
        );
    }

    #[tokio::test]
    async fn test_http_input_read_empty_queue() {
        let input = HttpInput::new(config_for("127.0.0.1:0")).unwrap();
        assert!(input.connect().await.is_ok());

        let result = input.read().await;
        assert!(
            matches!(result, Err(Error::Process(_))),
            "Expected Processing error"
        );

        assert!(input.close().await.is_ok());
    }

    #[tokio::test]
    async fn test_http_input_invalid_address() {
        let input = HttpInput::new(config_for("invalid-address")).unwrap();
        let result = input.connect().await;
        assert!(
            matches!(result, Err(Error::Config(_))),
            "Expected Config error"
        );
    }

    #[tokio::test]
    async fn test_http_input_receive_message() {
        // Ask the OS for a free port, then release it so the input can bind it.
        let port = {
            let probe = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
            probe.local_addr().unwrap().port()
        };
        let config = config_for(&format!("127.0.0.1:{}", port));
        let url = format!("http://127.0.0.1:{}{}", port, config.path);

        let input = HttpInput::new(config).unwrap();
        assert!(input.connect().await.is_ok());

        // Give the spawned server task a moment to start accepting connections.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        let test_message = json!({"data": "test message"});
        let response = Client::new().post(url).json(&test_message).send().await;
        assert!(
            response.is_ok(),
            "HTTP request failed: {:?}",
            response.err()
        );
        let response = response.unwrap();
        assert!(
            response.status().is_success(),
            "HTTP response status is not success: {}",
            response.status()
        );

        let read_result = input.read().await;
        assert!(
            read_result.is_ok(),
            "Failed to read message: {:?}",
            read_result.err()
        );
        let (msg, ack) = read_result.unwrap();
        assert_eq!(msg.as_string().unwrap(), vec![test_message.to_string()]);
        ack.ack().await;

        assert!(input.close().await.is_ok());
    }
}