rsflow/input/
http.rs

1//! HTTP输入组件
2//!
3//! 从HTTP端点接收数据
4
5use std::net::SocketAddr;
6use std::sync::Arc;
7use tokio::sync::Mutex;
8use std::collections::VecDeque;
9use std::sync::atomic::{AtomicBool, Ordering};
10use async_trait::async_trait;
11use serde::{Deserialize, Serialize};
12use axum::{Router, routing::post, extract::State, http::StatusCode};
13
14use crate::{Error, Message, input::Input};
15
16/// HTTP输入配置
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct HttpInputConfig {
19    /// 监听地址
20    pub address: String,
21    /// 路径
22    pub path: String,
23    /// 是否启用CORS
24    pub cors_enabled: Option<bool>,
25}
26
27/// HTTP输入组件
28pub struct HttpInput {
29    config: HttpInputConfig,
30    queue: Arc<Mutex<VecDeque<Message>>>,
31    server_handle: Arc<Mutex<Option<tokio::task::JoinHandle<Result<(), Error>>>>>,
32    connected: AtomicBool,
33}
34
35/// 共享状态
36type AppState = Arc<Mutex<VecDeque<Message>>>;
37
38impl HttpInput {
39    /// 创建一个新的HTTP输入组件
40    pub fn new(config: &HttpInputConfig) -> Result<Self, Error> {
41        Ok(Self {
42            config: config.clone(),
43            queue: Arc::new(Mutex::new(VecDeque::new())),
44            server_handle: Arc::new(Mutex::new(None)),
45            connected: AtomicBool::new(false),
46        })
47    }
48
49    /// 处理HTTP请求
50    async fn handle_request(State(state): State<AppState>, body: axum::extract::Json<serde_json::Value>) -> StatusCode {
51        let msg = match Message::from_json(&body.0) {
52            Ok(msg) => msg,
53            Err(_) => return StatusCode::BAD_REQUEST,
54        };
55
56        let mut queue = state.lock().await;
57        queue.push_back(msg);
58        StatusCode::OK
59    }
60}
61
62#[async_trait]
63impl Input for HttpInput {
64    async fn connect(&self) -> Result<(), Error> {
65        if self.connected.load(Ordering::SeqCst) {
66            return Ok(());
67        }
68
69        let queue = self.queue.clone();
70        let path = self.config.path.clone();
71        let address = self.config.address.clone();
72
73        // 创建HTTP服务器
74        let app = Router::new()
75            .route(&path, post(Self::handle_request))
76            .with_state(queue);
77
78        // 解析地址
79        let addr: SocketAddr = address.parse().map_err(|e| {
80            Error::Config(format!("无效的地址 {}: {}", address, e))
81        })?;
82
83        // 启动服务器
84        let server_handle = tokio::spawn(async move {
85            axum::Server::bind(&addr)
86                .serve(app.into_make_service())
87                .await
88                .map_err(|e| Error::Connection(format!("HTTP服务器错误: {}", e)))
89        });
90
91        let server_handle_arc = self.server_handle.clone();
92        let mut server_handle_arc_mutex = server_handle_arc.lock().await;
93        *server_handle_arc_mutex = Some(server_handle);
94        self.connected.store(true, std::sync::atomic::Ordering::SeqCst);
95
96        Ok(())
97    }
98
99    async fn read(&self) -> Result<Message, Error> {
100        if !self.connected.load(Ordering::SeqCst) {
101            return Err(Error::Connection("输入未连接".to_string()));
102        }
103
104        // 尝试从队列中获取消息
105        let msg_option;
106        {
107            let mut queue = self.queue.lock().await;
108            msg_option = queue.pop_front();
109        }
110
111        if let Some(msg) = msg_option {
112            Ok(msg)
113        } else {
114            // 如果队列为空,则等待一段时间后返回错误
115            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
116            Err(Error::Processing("队列为空".to_string()))
117        }
118    }
119
120    async fn acknowledge(&self, _msg: &Message) -> Result<(), Error> {
121        // HTTP输入不需要确认机制
122        Ok(())
123    }
124
125    async fn close(&self) -> Result<(), Error> {
126        let mut server_handle_guard = self.server_handle.lock().await;
127        if let Some(handle) = server_handle_guard.take() {
128            handle.abort();
129        }
130
131        self.connected.store(false, Ordering::SeqCst);
132        Ok(())
133    }
134}