1use std::io::{self, BufRead, BufReader, Write};
9use std::net::{TcpListener, TcpStream};
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{mpsc, Arc, Mutex};
12use std::thread;
13use std::time::Duration;
14
/// HTTP request method taken from the request line.
///
/// Only the listed methods are recognized; every other token maps to
/// [`Method::Other`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Method {
    Get,
    Post,
    Put,
    Patch,
    Delete,
    Head,
    Options,
    /// Any method token that is not one of the variants above.
    Other,
}

impl Method {
    /// Maps a method token to a [`Method`].
    ///
    /// Matching is exact and case-sensitive (`"get"` is not `Get`). Unknown tokens,
    /// including the empty string and `"OTHER"`, yield [`Method::Other`].
    pub fn parse(s: &str) -> Method {
        // Every variant with a canonical token; `Other` is the fallback, never a match.
        const RECOGNIZED: [Method; 7] = [
            Method::Get,
            Method::Post,
            Method::Put,
            Method::Patch,
            Method::Delete,
            Method::Head,
            Method::Options,
        ];
        RECOGNIZED
            .iter()
            .copied()
            .find(|candidate| candidate.as_str() == s)
            .unwrap_or(Method::Other)
    }

    /// Returns the canonical uppercase token; [`Method::Other`] renders as `"OTHER"`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Get => "GET",
            Self::Post => "POST",
            Self::Put => "PUT",
            Self::Patch => "PATCH",
            Self::Delete => "DELETE",
            Self::Head => "HEAD",
            Self::Options => "OPTIONS",
            Self::Other => "OTHER",
        }
    }
}
55
56#[derive(Clone, Debug)]
58pub struct Request {
59 pub method: Method,
60 pub path: String,
62 pub query: String,
64 pub version: String,
65 pub headers: Vec<(String, String)>,
66 pub body: Vec<u8>,
67}
68
69impl Request {
70 pub fn header(&self, name: &str) -> Option<&str> {
72 self.headers
73 .iter()
74 .find(|(k, _)| k.eq_ignore_ascii_case(name))
75 .map(|(_, v)| v.as_str())
76 }
77
78 pub fn content_type(&self) -> Option<&str> {
80 self.header("content-type")
81 }
82
83 pub fn is_json(&self) -> bool {
85 self.content_type()
86 .map(|ct| ct.contains("application/json"))
87 .unwrap_or(false)
88 }
89
90 pub fn cookie(&self, name: &str) -> Option<String> {
92 let header = self.header("cookie")?;
93 for pair in header.split(';') {
94 let pair = pair.trim();
95 if let Some((k, v)) = pair.split_once('=') {
96 if k.trim() == name {
97 return Some(v.trim().to_string());
98 }
99 }
100 }
101 None
102 }
103}
104
/// Payload of a [`Response`].
pub enum Body {
    /// The complete body, already held in memory.
    Full(Vec<u8>),
    /// A producer that is invoked once with the output writer and writes the body itself.
    Stream(Box<dyn FnOnce(&mut dyn Write) -> io::Result<()> + Send + 'static>),
}

/// An HTTP response built by a handler.
pub struct Response {
    /// Numeric status code, e.g. `200`.
    pub status: u16,
    /// Headers in insertion order; duplicates are kept.
    pub headers: Vec<(String, String)>,
    /// Response payload.
    pub body: Body,
}

impl Response {
    /// Creates a response with `status`, no headers and an empty in-memory body.
    pub fn new(status: u16) -> Response {
        let headers = Vec::new();
        let body = Body::Full(Vec::new());
        Self { status, headers, body }
    }

    /// Appends the header `name: value` (existing headers with the same name are kept).
    pub fn with_header(self, name: &str, value: &str) -> Response {
        let Response {
            status,
            mut headers,
            body,
        } = self;
        headers.push((String::from(name), String::from(value)));
        Response {
            status,
            headers,
            body,
        }
    }

    /// Replaces the body with the given in-memory bytes.
    pub fn with_body(self, body: impl Into<Vec<u8>>) -> Response {
        Response {
            body: Body::Full(body.into()),
            ..self
        }
    }

    /// Replaces the body with a streaming `producer`, which receives the output
    /// writer and is responsible for writing the whole body.
    pub fn with_stream(
        self,
        producer: impl FnOnce(&mut dyn Write) -> io::Result<()> + Send + 'static,
    ) -> Response {
        Response {
            body: Body::Stream(Box::new(producer)),
            ..self
        }
    }

    /// Returns `true` when the body is a [`Body::Stream`].
    pub fn is_stream(&self) -> bool {
        match self.body {
            Body::Stream(_) => true,
            Body::Full(_) => false,
        }
    }
}
158
/// Writer for streaming producers that flushes after every write, so each chunk
/// is pushed to the underlying writer immediately instead of waiting in a buffer.
pub struct StreamSink<'a> {
    w: &'a mut dyn Write,
}

impl<'a> StreamSink<'a> {
    /// Wraps `w`.
    pub fn new(w: &'a mut dyn Write) -> StreamSink<'a> {
        Self { w }
    }

    /// Writes all of `bytes`, then flushes.
    ///
    /// # Errors
    /// Returns the first write or flush error; the flush is skipped if the write fails.
    pub fn write(&mut self, bytes: &[u8]) -> io::Result<()> {
        let out = &mut *self.w;
        out.write_all(bytes).and_then(|()| out.flush())
    }

    /// Writes the UTF-8 bytes of `s`, then flushes (see [`StreamSink::write`]).
    pub fn write_str(&mut self, s: &str) -> io::Result<()> {
        Self::write(self, s.as_bytes())
    }
}
179
/// Writer for `text/event-stream` (Server-Sent Events) bodies.
///
/// Each public method writes one complete SSE block, terminated by a blank line,
/// and flushes so the block is pushed out immediately.
pub struct SseSink<'a> {
    w: &'a mut dyn Write,
}

impl<'a> SseSink<'a> {
    /// Wraps `w`.
    pub fn new(w: &'a mut dyn Write) -> SseSink<'a> {
        SseSink { w }
    }

    /// Sends an unnamed event whose data is `data`.
    ///
    /// Multi-line data is sent as one `data:` line per line. `\r\n`, `\r` and `\n`
    /// are all treated as line breaks, because SSE parsers recognize all three.
    pub fn data(&mut self, data: &str) -> io::Result<()> {
        self.write_lines("data", data)?;
        self.end_block()
    }

    /// Sends an event named `event` carrying `data` (split into lines as in [`SseSink::data`]).
    ///
    /// Line breaks are removed from `event`. Otherwise they would end the
    /// `event:` field early, and the remainder would be parsed as extra fields.
    pub fn event(&mut self, event: &str, data: &str) -> io::Result<()> {
        writeln!(self.w, "event: {}", event.replace(['\r', '\n'], ""))?;
        self.write_lines("data", data)?;
        self.end_block()
    }

    /// Sends a comment block (ignored by clients; often used as a keep-alive).
    ///
    /// Every line of `text` is prefixed with `:`. Otherwise a line break in
    /// `text` could smuggle in a real `data:` field and dispatch an event.
    pub fn comment(&mut self, text: &str) -> io::Result<()> {
        self.write_lines("", text)?;
        self.end_block()
    }

    /// Sends a `retry:` field asking the client to wait `millis` ms before reconnecting.
    pub fn retry(&mut self, millis: u64) -> io::Result<()> {
        writeln!(self.w, "retry: {}", millis)?;
        self.end_block()
    }

    /// Writes `value` as one `field: line` per line, after normalizing CRLF and
    /// lone CR to LF. An empty `value` still produces a single `field: ` line.
    fn write_lines(&mut self, field: &str, value: &str) -> io::Result<()> {
        let normalized = value.replace("\r\n", "\n").replace('\r', "\n");
        for line in normalized.split('\n') {
            writeln!(self.w, "{}: {}", field, line)?;
        }
        Ok(())
    }

    /// Terminates the current block with a blank line and flushes.
    fn end_block(&mut self) -> io::Result<()> {
        self.w.write_all(b"\n")?;
        self.w.flush()
    }
}
223
224pub fn parse_request<R: BufRead>(reader: &mut R) -> io::Result<Option<Request>> {
227 let mut request_line = String::new();
228 if reader.read_line(&mut request_line)? == 0 {
229 return Ok(None);
230 }
231 let mut parts = request_line.trim_end().split_whitespace();
232 let method = Method::parse(parts.next().unwrap_or(""));
233 let target = parts.next().unwrap_or("/").to_string();
234 let version = parts.next().unwrap_or("HTTP/1.1").to_string();
235
236 let (path, query) = match target.split_once('?') {
237 Some((p, q)) => (p.to_string(), q.to_string()),
238 None => (target, String::new()),
239 };
240
241 let mut headers = Vec::new();
242 let mut content_length = 0usize;
243 loop {
244 let mut line = String::new();
245 if reader.read_line(&mut line)? == 0 {
246 break;
247 }
248 let line = line.trim_end_matches(['\r', '\n']);
249 if line.is_empty() {
250 break;
251 }
252 if let Some((k, v)) = line.split_once(':') {
253 let k = k.trim().to_string();
254 let v = v.trim().to_string();
255 if k.eq_ignore_ascii_case("content-length") {
256 content_length = v.parse().unwrap_or(0);
257 }
258 headers.push((k, v));
259 }
260 }
261
262 let mut body = vec![0u8; content_length];
263 if content_length > 0 {
264 reader.read_exact(&mut body)?;
265 }
266
267 Ok(Some(Request {
268 method,
269 path,
270 query,
271 version,
272 headers,
273 body,
274 }))
275}
276
277pub fn write_response<W: Write>(w: &mut W, resp: Response) -> io::Result<()> {
281 let reason = status_reason(resp.status);
282 let mut head = format!("HTTP/1.1 {} {}\r\n", resp.status, reason);
283 let mut has_content_type = false;
284 for (k, v) in &resp.headers {
285 if k.eq_ignore_ascii_case("content-type") {
286 has_content_type = true;
287 }
288 head.push_str(&format!("{}: {}\r\n", k, v));
289 }
290 if !has_content_type {
291 head.push_str("content-type: text/plain; charset=utf-8\r\n");
292 }
293
294 match resp.body {
295 Body::Full(bytes) => {
296 head.push_str(&format!("content-length: {}\r\n", bytes.len()));
297 head.push_str("connection: close\r\n\r\n");
298 w.write_all(head.as_bytes())?;
299 w.write_all(&bytes)?;
300 w.flush()
301 }
302 Body::Stream(producer) => {
303 head.push_str("connection: close\r\n\r\n");
305 w.write_all(head.as_bytes())?;
306 w.flush()?;
307 producer(w)
308 }
309 }
310}
311
/// Returns the reason phrase written after `status` on the response status line.
///
/// Common codes get their standard phrase. Any other code gets a generic phrase
/// for its class (1xx–5xx). Values outside 100–599, which are not valid HTTP
/// status codes, get `"Unknown"`.
pub fn status_reason(status: u16) -> &'static str {
    match status {
        100 => "Continue",
        101 => "Switching Protocols",
        200 => "OK",
        201 => "Created",
        202 => "Accepted",
        204 => "No Content",
        206 => "Partial Content",
        301 => "Moved Permanently",
        302 => "Found",
        303 => "See Other",
        304 => "Not Modified",
        307 => "Temporary Redirect",
        308 => "Permanent Redirect",
        400 => "Bad Request",
        401 => "Unauthorized",
        403 => "Forbidden",
        404 => "Not Found",
        405 => "Method Not Allowed",
        406 => "Not Acceptable",
        408 => "Request Timeout",
        409 => "Conflict",
        410 => "Gone",
        411 => "Length Required",
        412 => "Precondition Failed",
        413 => "Content Too Large",
        414 => "URI Too Long",
        415 => "Unsupported Media Type",
        422 => "Unprocessable Entity",
        429 => "Too Many Requests",
        500 => "Internal Server Error",
        501 => "Not Implemented",
        502 => "Bad Gateway",
        503 => "Service Unavailable",
        504 => "Gateway Timeout",
        // Class fallbacks for codes without a dedicated phrase above.
        100..=199 => "Informational",
        200..=299 => "OK",
        300..=399 => "Redirect",
        400..=499 => "Client Error",
        500..=599 => "Server Error",
        _ => "Unknown",
    }
}
345
346pub fn serve<H>(addr: &str, workers: usize, handler: H) -> io::Result<()>
349where
350 H: Fn(Request) -> Response + Send + Sync + 'static,
351{
352 let listener = TcpListener::bind(addr)?;
353 let handler = Arc::new(handler);
354 let pool = ThreadPool::new(workers.max(1));
355
356 for stream in listener.incoming() {
357 let stream = match stream {
358 Ok(s) => s,
359 Err(_) => continue,
360 };
361 let handler = Arc::clone(&handler);
362 pool.execute(move || {
363 let _ = handle_connection(stream, &*handler);
364 });
365 }
366 Ok(())
367}
368
369pub fn serve_until<H>(
375 addr: &str,
376 workers: usize,
377 handler: H,
378 shutdown: Arc<AtomicBool>,
379) -> io::Result<()>
380where
381 H: Fn(Request) -> Response + Send + Sync + 'static,
382{
383 let listener = TcpListener::bind(addr)?;
384 listener.set_nonblocking(true)?;
385 let handler = Arc::new(handler);
386 let pool = ThreadPool::new(workers.max(1));
387
388 while !shutdown.load(Ordering::Relaxed) {
389 match listener.accept() {
390 Ok((stream, _addr)) => {
391 let _ = stream.set_nonblocking(false);
393 let handler = Arc::clone(&handler);
394 pool.execute(move || {
395 let _ = handle_connection(stream, &*handler);
396 });
397 }
398 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
399 thread::sleep(Duration::from_millis(50));
400 }
401 Err(_) => {}
402 }
403 }
404
405 drop(pool);
408 Ok(())
409}
410
411fn handle_connection<H>(stream: TcpStream, handler: &H) -> io::Result<()>
412where
413 H: Fn(Request) -> Response,
414{
415 let mut reader = BufReader::new(stream.try_clone()?);
416 if let Some(req) = parse_request(&mut reader)? {
417 let resp = handler(req);
418 let mut writer = stream;
419 write_response(&mut writer, resp)?;
420 }
421 Ok(())
422}
423
/// A unit of work submitted to a [`ThreadPool`].
type Job = Box<dyn FnOnce() + Send + 'static>;

/// A fixed set of worker threads that run submitted jobs from a shared FIFO queue.
///
/// Dropping the pool closes the queue. Workers finish every job already queued
/// and then exit, and the drop waits for all of them.
pub struct ThreadPool {
    /// Input side of the job queue; `None` only while the pool is being dropped.
    sender: Option<mpsc::Sender<Job>>,
    workers: Vec<thread::JoinHandle<()>>,
}

impl ThreadPool {
    /// Spawns `size` worker threads.
    ///
    /// A `size` of 0 is treated as 1, because a pool with no workers would
    /// accept jobs and never run them.
    pub fn new(size: usize) -> ThreadPool {
        let size = size.max(1);
        let (sender, receiver) = mpsc::channel::<Job>();
        let receiver = Arc::new(Mutex::new(receiver));
        let mut workers = Vec::with_capacity(size);
        for _ in 0..size {
            let receiver = Arc::clone(&receiver);
            workers.push(thread::spawn(move || loop {
                // The guard is a temporary dropped at the end of this statement, so the
                // lock is held only while waiting for a job, never while running one.
                // Poisoning is therefore not expected, but recovering from it is harmless.
                let job = receiver
                    .lock()
                    .unwrap_or_else(std::sync::PoisonError::into_inner)
                    .recv();
                match job {
                    Ok(job) => {
                        // Contain panics from the job. Otherwise each panicking
                        // job would kill its worker for good, and once every
                        // worker had died, queued jobs would never run.
                        // (With `panic = "abort"` the process aborts regardless.)
                        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(job));
                    }
                    // Every sender is gone and the queue is drained: shut down.
                    Err(_) => break,
                }
            }));
        }
        ThreadPool {
            sender: Some(sender),
            workers,
        }
    }

    /// Queues `f` to run on the next available worker.
    ///
    /// The job is silently dropped if the queue is already closed.
    pub fn execute<F>(&self, f: F)
    where
        F: FnOnce() + Send + 'static,
    {
        if let Some(sender) = &self.sender {
            // `send` fails only when every receiver is gone, i.e. all workers exited.
            let _ = sender.send(Box::new(f));
        }
    }
}

impl Drop for ThreadPool {
    fn drop(&mut self) {
        // Dropping the only sender makes `recv` fail once the queue drains,
        // which ends each worker loop.
        drop(self.sender.take());
        for worker in self.workers.drain(..) {
            // Nothing useful can be done with a worker's panic during drop.
            let _ = worker.join();
        }
    }
}
474
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `raw` as one request, panicking if parsing fails or no request is present.
    fn parse_one(raw: &str) -> Request {
        let mut rdr = BufReader::new(raw.as_bytes());
        parse_request(&mut rdr)
            .expect("request should parse")
            .expect("input should contain a request")
    }

    /// Serializes `resp` into memory and returns the wire text.
    fn render(resp: Response) -> String {
        let mut wire: Vec<u8> = Vec::new();
        write_response(&mut wire, resp).unwrap();
        String::from_utf8(wire).unwrap()
    }

    #[test]
    fn parses_request_with_body() {
        let req = parse_one(
            "POST /todos?x=1 HTTP/1.1\r\nHost: localhost\r\nContent-Length: 5\r\n\r\nhello",
        );
        assert_eq!(req.method, Method::Post);
        assert_eq!(req.path, "/todos");
        assert_eq!(req.query, "x=1");
        assert_eq!(req.body, b"hello");
        assert_eq!(req.header("host"), Some("localhost"));
    }

    #[test]
    fn writes_response_with_default_content_type() {
        let text = render(Response::new(200).with_body("hi"));
        assert!(text.starts_with("HTTP/1.1 200 OK\r\n"));
        assert!(text.contains("content-length: 2\r\n"));
        assert!(text.ends_with("\r\n\r\nhi"));
    }

    #[test]
    fn request_content_type_and_cookies() {
        let req = parse_one(
            "GET / HTTP/1.1\r\nContent-Type: application/json\r\nCookie: sid=abc; theme=dark\r\n\r\n",
        );
        assert!(req.is_json());
        assert_eq!(req.cookie("sid").as_deref(), Some("abc"));
        assert_eq!(req.cookie("theme").as_deref(), Some("dark"));
        assert_eq!(req.cookie("missing"), None);
    }

    #[test]
    fn streams_without_content_length() {
        let resp = Response::new(200)
            .with_header("content-type", "text/event-stream")
            .with_stream(|out| {
                let mut sse = SseSink::new(out);
                sse.data("one")?;
                sse.data("two")?;
                Ok(())
            });
        let text = render(resp);
        assert!(text.contains("content-type: text/event-stream\r\n"));
        assert!(!text.to_lowercase().contains("content-length"));
        assert!(text.contains("data: one\n\n"));
        assert!(text.contains("data: two\n\n"));
    }
}
531}