1use std::{net::SocketAddr, sync::Arc};
15
16use myko::server::CellServerCtx;
17use tokio::{io::AsyncWriteExt, net::TcpStream};
18
19use crate::{mcp, ws_handler::WsHandler};
20
/// Upper bound, in bytes, on the request head that `read_request_head`
/// will buffer before rejecting the request.
const MAX_HEADER_BYTES: usize = 8 * 1024;

/// Capacity of the header array handed to `httparse`, i.e. the most header
/// fields a single request may carry.
const MAX_HEADERS: usize = 64;
26
/// Request line and headers of an incoming HTTP/1.x request, as produced by
/// `read_request_head`.
#[derive(Debug, Clone)]
pub struct HttpRequestHead {
    /// Request method exactly as sent (e.g. `GET`, `POST`).
    pub method: String,
    /// Request target exactly as sent, including any `?query` suffix.
    pub path: String,
    /// Header `(name, value)` pairs in the order they were received.
    pub headers: Vec<(String, String)>,
    /// Bytes that were read from the socket after the `\r\n\r\n` header
    /// terminator (the start of a body, if the client sent one).
    pub leftover_body: Vec<u8>,
}

impl HttpRequestHead {
    /// Returns the value of the first header whose name equals `name`
    /// (ASCII case-insensitive), or `None` if no such header is present.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find(|(k, _)| k.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_str())
    }

    /// Iterates over the trimmed, non-empty elements of the comma-separated
    /// lists carried by *every* header named `name` (ASCII case-insensitive).
    ///
    /// Repeated list-valued fields are equivalent to one field whose values
    /// are joined with commas (RFC 9110 §5.3), so all occurrences are read,
    /// not just the first.
    fn header_list(&self, name: &'static str) -> impl Iterator<Item = &str> + '_ {
        self.headers
            .iter()
            .filter(move |(k, _)| k.eq_ignore_ascii_case(name))
            .flat_map(|(_, v)| v.split(','))
            .map(str::trim)
            .filter(|item| !item.is_empty())
    }

    /// Returns `true` if the request asks to switch to WebSocket.
    ///
    /// Both conditions must hold:
    /// - some `Upgrade` list element is `websocket`;
    /// - some `Connection` list element is `upgrade`.
    ///
    /// Comparisons are ASCII case-insensitive, and all occurrences of each
    /// header are considered.
    pub fn is_websocket_upgrade(&self) -> bool {
        let upgrade = self
            .header_list("Upgrade")
            .any(|protocol| protocol.eq_ignore_ascii_case("websocket"));
        let connection_has_upgrade = self
            .header_list("Connection")
            .any(|option| option.eq_ignore_ascii_case("upgrade"));
        upgrade && connection_has_upgrade
    }

    /// Returns `true` if some `Accept` media range is `text/event-stream`
    /// (ignoring parameters), unless that range carries a zero quality weight
    /// (`q=0`), which marks it as "not acceptable" (RFC 9110 §12.4.2).
    pub fn wants_event_stream(&self) -> bool {
        /// `true` for a `q=<weight>` parameter whose weight parses to zero.
        fn is_zero_qvalue(param: &str) -> bool {
            match param.split_once('=') {
                Some((key, weight)) if key.trim().eq_ignore_ascii_case("q") => {
                    matches!(weight.trim().parse::<f32>(), Ok(q) if q <= 0.0)
                }
                _ => false,
            }
        }

        self.header_list("Accept").any(|range| {
            let mut parts = range.split(';').map(str::trim);
            let media = parts.next().unwrap_or("");
            media.eq_ignore_ascii_case("text/event-stream") && !parts.any(is_zero_qvalue)
        })
    }
}
77
78pub async fn read_request_head(stream: &mut TcpStream) -> std::io::Result<Option<HttpRequestHead>> {
82 use tokio::io::AsyncReadExt;
83
84 let mut buffer = Vec::with_capacity(1024);
85 let mut chunk = [0u8; 1024];
86
87 let header_end = loop {
88 if buffer.len() > MAX_HEADER_BYTES {
89 return Err(std::io::Error::new(
90 std::io::ErrorKind::InvalidData,
91 "HTTP header section exceeded 8 KB",
92 ));
93 }
94 let n = stream.read(&mut chunk).await?;
95 if n == 0 {
96 return Ok(None);
97 }
98 buffer.extend_from_slice(&chunk[..n]);
99
100 if let Some(idx) = find_header_terminator(&buffer) {
101 break idx;
102 }
103 };
104
105 let mut headers_buf = [httparse::EMPTY_HEADER; MAX_HEADERS];
106 let mut req = httparse::Request::new(&mut headers_buf);
107 let status = req
108 .parse(&buffer[..header_end])
109 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
110 if !status.is_complete() {
111 return Err(std::io::Error::new(
112 std::io::ErrorKind::InvalidData,
113 "incomplete HTTP request",
114 ));
115 }
116
117 let method = req.method.unwrap_or("").to_string();
118 let path = req.path.unwrap_or("").to_string();
119 let headers = req
120 .headers
121 .iter()
122 .map(|h| {
123 (
124 h.name.to_string(),
125 String::from_utf8_lossy(h.value).into_owned(),
126 )
127 })
128 .collect();
129
130 let leftover_body = buffer[header_end..].to_vec();
131
132 Ok(Some(HttpRequestHead {
133 method,
134 path,
135 headers,
136 leftover_body,
137 }))
138}
139
/// Locates the blank line (`\r\n\r\n`) that ends an HTTP header section.
///
/// Returns the index just past the first terminator in `buf`, which is where
/// any body bytes begin. Returns `None` if `buf` does not yet contain a
/// complete terminator.
fn find_header_terminator(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buf.windows(TERMINATOR.len())
        .enumerate()
        .find_map(|(start, window)| (window == TERMINATOR).then_some(start + TERMINATOR.len()))
}
143
144pub async fn route_connection(
146 mut stream: TcpStream,
147 addr: SocketAddr,
148 ctx: Arc<CellServerCtx>,
149) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
150 let head = match read_request_head(&mut stream).await {
151 Ok(Some(h)) => h,
152 Ok(None) => return Ok(()),
153 Err(e) => {
154 log::debug!("HTTP parse error from {}: {}", addr, e);
155 let _ = write_status(&mut stream, 400, "Bad Request").await;
156 shutdown_cleanly(stream).await;
157 return Ok(());
158 }
159 };
160
161 log::trace!("router accept {} {} from {}", head.method, head.path, addr,);
162
163 let path = head.path.split('?').next().unwrap_or(&head.path);
164
165 match (head.method.as_str(), path) {
166 ("GET", p) if p == "/myko" || p.starts_with("/myko?") => {
167 if !head.is_websocket_upgrade() {
168 let _ = write_status(&mut stream, 426, "Upgrade Required").await;
169 shutdown_cleanly(stream).await;
170 return Ok(());
171 }
172 handle_ws_upgrade(stream, addr, ctx, head, WsTarget::Myko).await
177 }
178 ("GET", "/myko/mcp") if head.is_websocket_upgrade() => {
179 handle_ws_upgrade(stream, addr, ctx, head, WsTarget::Mcp).await
180 }
181 ("GET", "/myko/mcp") if head.wants_event_stream() => {
182 mcp::http::handle_sse(stream, ctx, head).await
183 }
184 ("POST", "/myko/mcp") => mcp::http::handle_post(stream, ctx, head).await,
185 ("GET", "/myko/mcp") => {
186 let body = b"{\"status\":\"ok\",\"endpoint\":\"/myko/mcp\",\"transports\":[\"POST\",\"WebSocket\",\"SSE\"]}";
189 let result = write_full(
190 &mut stream,
191 200,
192 "OK",
193 &[("Content-Type", "application/json")],
194 body,
195 )
196 .await;
197 shutdown_cleanly(stream).await;
198 result.map_err(Into::into)
199 }
200 _ => {
201 let _ = write_status(&mut stream, 404, "Not Found").await;
202 shutdown_cleanly(stream).await;
203 Ok(())
204 }
205 }
206}
207
/// Chooses which WebSocket handler `handle_ws_upgrade` passes an accepted
/// upgrade request to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WsTarget {
    /// Upgrade on `/myko`, handled by `mcp::ws::handle_myko_ws_upgrade`.
    Myko,
    /// Upgrade on `/myko/mcp`, handled by `mcp::ws::handle_mcp_ws_upgrade`.
    Mcp,
}
212
213async fn handle_ws_upgrade(
214 stream: TcpStream,
215 addr: SocketAddr,
216 ctx: Arc<CellServerCtx>,
217 head: HttpRequestHead,
218 target: WsTarget,
219) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
220 if !head.leftover_body.is_empty() {
223 log::warn!(
224 "Rejecting WS upgrade from {} with {} leftover body bytes",
225 addr,
226 head.leftover_body.len()
227 );
228 let mut stream = stream;
229 let _ = write_status(&mut stream, 400, "Bad Request").await;
230 shutdown_cleanly(stream).await;
231 return Ok(());
232 }
233
234 match target {
235 WsTarget::Myko => mcp::ws::handle_myko_ws_upgrade(stream, addr, ctx, head).await,
236 WsTarget::Mcp => mcp::ws::handle_mcp_ws_upgrade(stream, ctx, head).await,
237 }
238}
239
240pub async fn write_status(stream: &mut TcpStream, code: u16, reason: &str) -> std::io::Result<()> {
242 write_full(stream, code, reason, &[("Content-Length", "0")], b"").await
243}
244
245pub async fn write_full(
247 stream: &mut TcpStream,
248 code: u16,
249 reason: &str,
250 extra_headers: &[(&str, &str)],
251 body: &[u8],
252) -> std::io::Result<()> {
253 let mut head = format!("HTTP/1.1 {} {}\r\n", code, reason);
254 head.push_str("Connection: close\r\n");
255 if !extra_headers
256 .iter()
257 .any(|(k, _)| k.eq_ignore_ascii_case("Content-Length"))
258 {
259 head.push_str(&format!("Content-Length: {}\r\n", body.len()));
260 }
261 for (k, v) in extra_headers {
262 head.push_str(&format!("{}: {}\r\n", k, v));
263 }
264 head.push_str("\r\n");
265 stream.write_all(head.as_bytes()).await?;
266 if !body.is_empty() {
267 stream.write_all(body).await?;
268 }
269 stream.flush().await?;
270 Ok(())
271}
272
273pub async fn shutdown_cleanly(mut stream: TcpStream) {
286 use tokio::io::AsyncReadExt;
287
288 let _ = stream.shutdown().await;
289 let mut buf = [0u8; 1024];
290 let _ = tokio::time::timeout(std::time::Duration::from_millis(250), async {
291 loop {
292 match stream.read(&mut buf).await {
293 Ok(0) | Err(_) => return,
294 Ok(_) => continue,
295 }
296 }
297 })
298 .await;
299}
300
// NOTE(review): this function is never called. It only constructs a
// `WsHandler`, which appears to be a way of keeping the `WsHandler` import
// referenced in this module. `#[allow(dead_code)]` silences the unused-function
// warning. If nothing else in this module needs `WsHandler`, consider removing
// both this shim and the import — confirm against the rest of the file.
#[allow(dead_code)]
fn _ws_handler_in_scope() -> WsHandler {
    WsHandler
}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GET /myko/mcp` request head with the given headers and no
    /// leftover body bytes.
    fn request_with(pairs: Vec<(&str, &str)>) -> HttpRequestHead {
        HttpRequestHead {
            method: String::from("GET"),
            path: String::from("/myko/mcp"),
            headers: pairs
                .into_iter()
                .map(|(name, value)| (name.to_owned(), value.to_owned()))
                .collect(),
            leftover_body: vec![],
        }
    }

    #[test]
    fn header_lookup_is_case_insensitive() {
        let req = request_with(vec![("Content-Type", "application/json")]);
        for spelling in ["content-type", "CONTENT-TYPE"] {
            assert_eq!(req.header(spelling), Some("application/json"));
        }
    }

    #[test]
    fn websocket_upgrade_requires_both_headers() {
        let both = request_with(vec![("Upgrade", "websocket"), ("Connection", "Upgrade")]);
        let only_upgrade = request_with(vec![("Upgrade", "websocket")]);
        let only_connection = request_with(vec![("Connection", "Upgrade")]);

        assert!(both.is_websocket_upgrade());
        assert!(!only_upgrade.is_websocket_upgrade());
        assert!(!only_connection.is_websocket_upgrade());
    }

    #[test]
    fn connection_header_accepts_lists() {
        let req = request_with(vec![
            ("Upgrade", "websocket"),
            ("Connection", "keep-alive, Upgrade"),
        ]);
        assert!(req.is_websocket_upgrade());
    }

    #[test]
    fn accept_header_detects_sse() {
        let cases = [
            ("text/event-stream", true),
            ("text/html", false),
            ("text/html, text/event-stream;q=0.9", true),
        ];
        for (accept, expected) in cases {
            let req = request_with(vec![("Accept", accept)]);
            assert_eq!(req.wants_event_stream(), expected, "Accept: {}", accept);
        }
    }

    #[test]
    fn header_terminator_is_found() {
        let raw: &[u8] = b"GET / HTTP/1.1\r\nHost: x\r\n\r\nbody";
        let end = find_header_terminator(raw).expect("terminator must be found");
        assert_eq!(&raw[end..], b"body");
        assert!(find_header_terminator(b"no terminator here").is_none());
    }
}