lazy_sock/
lib.rs

1/* src/lib.rs */
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5use std::sync::Arc;
6use tokio::fs;
7use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
8use tokio::net::{UnixListener, UnixStream};
9use tokio::signal;
10use tokio::sync::RwLock;
11use tokio::time::{Duration, sleep};
12
13mod request;
14mod response;
15mod router;
16
17pub use request::Request;
18pub use response::Response;
19pub use router::{Method, Router};
20
21/// 回调函数类型定义
22pub type HandlerFn = Arc<dyn Fn(Request) -> Response + Send + Sync>;
23
/// Shared, thread-safe sink that receives the server's log messages.
pub type LogCallbackFn = Arc<dyn Fn(&str) + Sync + Send>;

/// Shared, thread-safe sink that receives operator-facing prompts; it has exactly
/// the same shape as [`LogCallbackFn`].
pub type PromptCallbackFn = LogCallbackFn;
29
30/// LazySock 服务器主结构
31pub struct LazySock {
32    socket_path: PathBuf,
33    router: Arc<RwLock<Router>>,
34    log_callback: Option<LogCallbackFn>,
35    prompt_callback: Option<PromptCallbackFn>,
36    cleanup_on_exit: bool,
37}
38
39impl LazySock {
40    /// 创建新的 LazySock 实例
41    pub fn new<P: AsRef<Path>>(socket_path: P) -> Self {
42        Self {
43            socket_path: socket_path.as_ref().to_path_buf(),
44            router: Arc::new(RwLock::new(Router::new())),
45            log_callback: None,
46            prompt_callback: None,
47            cleanup_on_exit: true,
48        }
49    }
50
51    /// 设置日志回调函数
52    pub fn with_log_callback<F>(mut self, callback: F) -> Self
53    where
54        F: Fn(&str) + Send + Sync + 'static,
55    {
56        self.log_callback = Some(Arc::new(callback));
57        self
58    }
59
60    /// 设置提示回调函数
61    pub fn with_prompt_callback<F>(mut self, callback: F) -> Self
62    where
63        F: Fn(&str) + Send + Sync + 'static,
64    {
65        self.prompt_callback = Some(Arc::new(callback));
66        self
67    }
68
69    /// 设置是否在退出时清理socket文件
70    pub fn with_cleanup_on_exit(mut self, cleanup: bool) -> Self {
71        self.cleanup_on_exit = cleanup;
72        self
73    }
74
75    /// 注册路由处理函数
76    pub async fn route<F>(&self, method: Method, path: &str, handler: F)
77    where
78        F: Fn(Request) -> Response + Send + Sync + 'static,
79    {
80        let mut router = self.router.write().await;
81        router.add_route(method, path, Arc::new(handler));
82    }
83
84    /// 启动服务器
85    pub async fn run(self) -> Result<(), Box<dyn std::error::Error>> {
86        // 检查socket文件是否存在
87        if let Err(e) = self.check_and_handle_existing_socket().await {
88            return Err(e);
89        }
90
91        // 创建Unix socket监听器
92        let listener = UnixListener::bind(&self.socket_path)?;
93        self.log(&format!("Server started on socket: {:?}", self.socket_path));
94
95        // 设置信号处理器用于优雅关闭
96        let socket_path_for_cleanup = self.socket_path.clone();
97        let cleanup_on_exit = self.cleanup_on_exit;
98        let mut cleanup_task = tokio::spawn(async move {
99            if let Ok(()) = signal::ctrl_c().await {
100                if cleanup_on_exit {
101                    let _ = fs::remove_file(&socket_path_for_cleanup).await;
102                }
103            }
104        });
105
106        // 主服务循环
107        loop {
108            tokio::select! {
109                result = listener.accept() => {
110                    match result {
111                        Ok((stream, _)) => {
112                            let router = Arc::clone(&self.router);
113                            let log_callback = self.log_callback.clone();
114                            tokio::spawn(async move {
115                                if let Err(e) = handle_connection(stream, router).await {
116                                    if let Some(logger) = log_callback {
117                                        logger(&format!("Error handling connection: {}", e));
118                                    }
119                                }
120                            });
121                        }
122                        Err(e) => {
123                            self.log(&format!("Error accepting connection: {}", e));
124                        }
125                    }
126                }
127                _ = &mut cleanup_task => {
128                    self.log("Server shutting down...");
129                    break;
130                }
131            }
132        }
133
134        Ok(())
135    }
136
137    /// 检查并处理已存在的socket文件
138    async fn check_and_handle_existing_socket(&self) -> Result<(), Box<dyn std::error::Error>> {
139        if self.socket_path.exists() {
140            self.prompt(
141                "Socket file already exists. Will override in 3 seconds... (Ctrl+C to abort now)",
142            );
143
144            // 等待3秒,期间可以被Ctrl+C中断
145            tokio::select! {
146                _ = sleep(Duration::from_secs(3)) => {
147                    fs::remove_file(&self.socket_path).await?;
148                    self.log("Removed existing socket file");
149                }
150                _ = signal::ctrl_c() => {
151                    self.prompt("Aborted by user");
152                    return Err("User aborted".into());
153                }
154            }
155        }
156
157        Ok(())
158    }
159
160    /// 记录日志
161    fn log(&self, message: &str) {
162        if let Some(callback) = &self.log_callback {
163            callback(message);
164        }
165    }
166
167    /// 显示提示信息
168    fn prompt(&self, message: &str) {
169        if let Some(callback) = &self.prompt_callback {
170            callback(message);
171        }
172    }
173}
174
175/// 处理单个连接
176async fn handle_connection(
177    mut stream: UnixStream,
178    router: Arc<RwLock<Router>>,
179) -> Result<(), Box<dyn std::error::Error>> {
180    let mut reader = BufReader::new(&mut stream);
181    let mut request_line = String::new();
182    reader.read_line(&mut request_line).await?;
183
184    // 解析HTTP请求行
185    let parts: Vec<&str> = request_line.trim().split_whitespace().collect();
186    if parts.len() < 2 {
187        return Err("Invalid request line".into());
188    }
189
190    let method = match parts[0] {
191        "GET" => Method::Get,
192        "POST" => Method::Post,
193        "PUT" => Method::Put,
194        "DELETE" => Method::Delete,
195        _ => return Err("Unsupported method".into()),
196    };
197
198    let path = parts[1].to_string();
199
200    // 读取剩余的头部(简单实现,跳过)
201    let headers = HashMap::new();
202    let mut line = String::new();
203    while reader.read_line(&mut line).await? > 0 {
204        if line.trim().is_empty() {
205            break;
206        }
207        // 这里可以解析头部,简单起见暂时跳过
208        line.clear();
209    }
210
211    // 创建请求对象
212    let request = Request::new(method.clone(), path.clone(), headers, Vec::new());
213
214    // 路由处理
215    let router_guard = router.read().await;
216    let response = if let Some(handler) = router_guard.find_handler(&method, &path) {
217        handler(request)
218    } else {
219        Response::not_found("Route not found")
220    };
221
222    // 发送响应
223    let response_data = response.to_http_response();
224    stream.write_all(response_data.as_bytes()).await?;
225    stream.flush().await?;
226
227    Ok(())
228}
229
/// Shorthand for building a [`LazySock`] on `$path` whose log and prompt callbacks
/// both print each message on its own line to standard output.
///
/// Expands to an expression, so further builder calls can be chained onto it.
#[macro_export]
macro_rules! lazy_sock {
    ($path:expr) => {{
        let server = $crate::LazySock::new($path);
        server
            .with_prompt_callback(|msg| println!("{}", msg))
            .with_log_callback(|msg| println!("{}", msg))
    }};
}