1use std::{net::SocketAddr, sync::Arc};
15
16use myko::server::CellServerCtx;
17use tokio::{io::AsyncWriteExt, net::TcpStream};
18
19use crate::{mcp, mcp::dispatch::ServerInfo, ws_handler::WsHandler};
20
/// Size limit, in bytes, applied to the request head (request line plus header section)
/// while [`read_request_head`] is reading it.
const MAX_HEADER_BYTES: usize = 8_192;

/// Number of header slots handed to `httparse`; a request with more headers than this
/// fails to parse.
const MAX_HEADERS: usize = 64;
26
/// Parsed HTTP/1.x request line and header section, as produced by [`read_request_head`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpRequestHead {
    /// Request method exactly as sent (e.g. `GET`); empty if the parser yielded none.
    pub method: String,
    /// Request target exactly as sent, including any `?query` suffix.
    pub path: String,
    /// Header fields in the order received. Names keep their original case; values are
    /// decoded lossily as UTF-8.
    pub headers: Vec<(String, String)>,
    /// Bytes read from the socket after the blank line that ends the header section.
    pub leftover_body: Vec<u8>,
}

impl HttpRequestHead {
    /// Returns the value of the first header whose name equals `name` (ASCII
    /// case-insensitive), or `None` if no such header was sent.
    pub fn header(&self, name: &str) -> Option<&str> {
        self.headers
            .iter()
            .find(|(k, _)| k.eq_ignore_ascii_case(name))
            .map(|(_, v)| v.as_str())
    }

    /// Returns `true` if the request asks for a WebSocket upgrade: some `Upgrade` header
    /// lists the `websocket` protocol AND some `Connection` header lists `upgrade`.
    ///
    /// Both fields are comma-separated lists in HTTP and may be repeated, so values such
    /// as `Upgrade: h2c, websocket` or `Connection: keep-alive, Upgrade` are accepted.
    /// Comparisons are ASCII case-insensitive.
    pub fn is_websocket_upgrade(&self) -> bool {
        self.any_list_item("Upgrade", |item| item.eq_ignore_ascii_case("websocket"))
            && self.any_list_item("Connection", |item| item.eq_ignore_ascii_case("upgrade"))
    }

    /// Returns `true` if any `Accept` header lists the `text/event-stream` media type.
    ///
    /// Media-type parameters (e.g. `;q=0.9`) are ignored, and every `Accept` header is
    /// considered, not just the first.
    pub fn wants_event_stream(&self) -> bool {
        self.any_list_item("Accept", |item| {
            let media = item.split(';').next().unwrap_or("").trim();
            media.eq_ignore_ascii_case("text/event-stream")
        })
    }

    /// Returns `true` if `pred` holds for any trimmed comma-separated element of any header
    /// named `name` (ASCII case-insensitive name match).
    fn any_list_item(&self, name: &str, mut pred: impl FnMut(&str) -> bool) -> bool {
        self.headers
            .iter()
            .filter(|(k, _)| k.eq_ignore_ascii_case(name))
            .flat_map(|(_, v)| v.split(','))
            .any(|item| pred(item.trim()))
    }
}
77
78pub async fn read_request_head(stream: &mut TcpStream) -> std::io::Result<Option<HttpRequestHead>> {
82 use tokio::io::AsyncReadExt;
83
84 let mut buffer = Vec::with_capacity(1024);
85 let mut chunk = [0u8; 1024];
86
87 let header_end = loop {
88 if buffer.len() > MAX_HEADER_BYTES {
89 return Err(std::io::Error::new(
90 std::io::ErrorKind::InvalidData,
91 "HTTP header section exceeded 8 KB",
92 ));
93 }
94 let n = stream.read(&mut chunk).await?;
95 if n == 0 {
96 return Ok(None);
97 }
98 buffer.extend_from_slice(&chunk[..n]);
99
100 if let Some(idx) = find_header_terminator(&buffer) {
101 break idx;
102 }
103 };
104
105 let mut headers_buf = [httparse::EMPTY_HEADER; MAX_HEADERS];
106 let mut req = httparse::Request::new(&mut headers_buf);
107 let status = req
108 .parse(&buffer[..header_end])
109 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
110 if !status.is_complete() {
111 return Err(std::io::Error::new(
112 std::io::ErrorKind::InvalidData,
113 "incomplete HTTP request",
114 ));
115 }
116
117 let method = req.method.unwrap_or("").to_string();
118 let path = req.path.unwrap_or("").to_string();
119 let headers = req
120 .headers
121 .iter()
122 .map(|h| {
123 (
124 h.name.to_string(),
125 String::from_utf8_lossy(h.value).into_owned(),
126 )
127 })
128 .collect();
129
130 let leftover_body = buffer[header_end..].to_vec();
131
132 Ok(Some(HttpRequestHead {
133 method,
134 path,
135 headers,
136 leftover_body,
137 }))
138}
139
/// Locates the blank line (`\r\n\r\n`) that ends an HTTP header section.
///
/// Returns the index just past the first occurrence, which is where the body begins, or
/// `None` if `buf` does not yet contain one.
fn find_header_terminator(buf: &[u8]) -> Option<usize> {
    const TERMINATOR: &[u8] = b"\r\n\r\n";
    buf.windows(TERMINATOR.len())
        .enumerate()
        .find_map(|(start, window)| (window == TERMINATOR).then_some(start + TERMINATOR.len()))
}
143
144pub async fn route_connection(
146 mut stream: TcpStream,
147 addr: SocketAddr,
148 ctx: Arc<CellServerCtx>,
149 server_info: Arc<ServerInfo>,
150) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
151 let head = match read_request_head(&mut stream).await {
152 Ok(Some(h)) => h,
153 Ok(None) => return Ok(()),
154 Err(e) => {
155 log::debug!("HTTP parse error from {}: {}", addr, e);
156 let _ = write_status(&mut stream, 400, "Bad Request").await;
157 shutdown_cleanly(stream).await;
158 return Ok(());
159 }
160 };
161
162 log::trace!("router accept {} {} from {}", head.method, head.path, addr,);
163
164 let path = head.path.split('?').next().unwrap_or(&head.path);
165
166 match (head.method.as_str(), path) {
167 ("GET", p) if p == "/myko" || p.starts_with("/myko?") => {
168 if !head.is_websocket_upgrade() {
169 let _ = write_status(&mut stream, 426, "Upgrade Required").await;
170 shutdown_cleanly(stream).await;
171 return Ok(());
172 }
173 handle_ws_upgrade(stream, addr, ctx, server_info, head, WsTarget::Myko).await
178 }
179 ("GET", "/myko/mcp") if head.is_websocket_upgrade() => {
180 handle_ws_upgrade(stream, addr, ctx, server_info, head, WsTarget::Mcp).await
181 }
182 ("GET", "/myko/mcp") if head.wants_event_stream() => {
183 mcp::http::handle_sse(stream, ctx, head).await
184 }
185 ("POST", "/myko/mcp") => mcp::http::handle_post(stream, ctx, server_info, head).await,
186 ("GET", "/myko/mcp") => {
187 let body = b"{\"status\":\"ok\",\"endpoint\":\"/myko/mcp\",\"transports\":[\"POST\",\"WebSocket\",\"SSE\"]}";
190 let result = write_full(
191 &mut stream,
192 200,
193 "OK",
194 &[("Content-Type", "application/json")],
195 body,
196 )
197 .await;
198 shutdown_cleanly(stream).await;
199 result.map_err(Into::into)
200 }
201 _ => {
202 let _ = write_status(&mut stream, 404, "Not Found").await;
203 shutdown_cleanly(stream).await;
204 Ok(())
205 }
206 }
207}
208
/// Which WebSocket handler an upgrade request should be handed to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WsTarget {
    /// The `/myko` endpoint.
    Myko,
    /// The `/myko/mcp` endpoint.
    Mcp,
}
213
214async fn handle_ws_upgrade(
215 stream: TcpStream,
216 addr: SocketAddr,
217 ctx: Arc<CellServerCtx>,
218 server_info: Arc<ServerInfo>,
219 head: HttpRequestHead,
220 target: WsTarget,
221) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
222 if !head.leftover_body.is_empty() {
225 log::warn!(
226 "Rejecting WS upgrade from {} with {} leftover body bytes",
227 addr,
228 head.leftover_body.len()
229 );
230 let mut stream = stream;
231 let _ = write_status(&mut stream, 400, "Bad Request").await;
232 shutdown_cleanly(stream).await;
233 return Ok(());
234 }
235
236 match target {
237 WsTarget::Myko => mcp::ws::handle_myko_ws_upgrade(stream, addr, ctx, head).await,
238 WsTarget::Mcp => mcp::ws::handle_mcp_ws_upgrade(stream, ctx, server_info, head).await,
239 }
240}
241
242pub async fn write_status(stream: &mut TcpStream, code: u16, reason: &str) -> std::io::Result<()> {
244 write_full(stream, code, reason, &[("Content-Length", "0")], b"").await
245}
246
247pub async fn write_full(
249 stream: &mut TcpStream,
250 code: u16,
251 reason: &str,
252 extra_headers: &[(&str, &str)],
253 body: &[u8],
254) -> std::io::Result<()> {
255 let mut head = format!("HTTP/1.1 {} {}\r\n", code, reason);
256 head.push_str("Connection: close\r\n");
257 if !extra_headers
258 .iter()
259 .any(|(k, _)| k.eq_ignore_ascii_case("Content-Length"))
260 {
261 head.push_str(&format!("Content-Length: {}\r\n", body.len()));
262 }
263 for (k, v) in extra_headers {
264 head.push_str(&format!("{}: {}\r\n", k, v));
265 }
266 head.push_str("\r\n");
267 stream.write_all(head.as_bytes()).await?;
268 if !body.is_empty() {
269 stream.write_all(body).await?;
270 }
271 stream.flush().await?;
272 Ok(())
273}
274
275pub async fn shutdown_cleanly(mut stream: TcpStream) {
288 use tokio::io::AsyncReadExt;
289
290 let _ = stream.shutdown().await;
291 let mut buf = [0u8; 1024];
292 let _ = tokio::time::timeout(std::time::Duration::from_millis(250), async {
293 loop {
294 match stream.read(&mut buf).await {
295 Ok(0) | Err(_) => return,
296 Ok(_) => continue,
297 }
298 }
299 })
300 .await;
301}
302
// Never called. It only constructs a `WsHandler` so that the `ws_handler::WsHandler`
// import is referenced.
// NOTE(review): nothing else in the visible part of this module uses `WsHandler`. If that
// holds for the whole file, consider dropping both this shim and the import.
#[allow(dead_code)] // intentionally unused; see comment above
fn _ws_handler_in_scope() -> WsHandler {
    WsHandler
}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GET /myko/mcp` head with the given headers and no leftover body.
    fn make_head(headers: Vec<(&str, &str)>) -> HttpRequestHead {
        HttpRequestHead {
            method: "GET".to_string(),
            path: "/myko/mcp".to_string(),
            headers: headers
                .into_iter()
                .map(|(k, v)| (k.to_string(), v.to_string()))
                .collect(),
            leftover_body: Vec::new(),
        }
    }

    #[test]
    fn header_lookup_is_case_insensitive() {
        let head = make_head(vec![("Content-Type", "application/json")]);
        assert_eq!(head.header("content-type"), Some("application/json"));
        assert_eq!(head.header("CONTENT-TYPE"), Some("application/json"));
    }

    #[test]
    fn header_lookup_returns_first_duplicate_and_none_when_missing() {
        let head = make_head(vec![("X-Id", "first"), ("x-id", "second")]);
        assert_eq!(head.header("X-ID"), Some("first"));
        assert_eq!(head.header("Missing"), None);
    }

    #[test]
    fn websocket_upgrade_requires_both_headers() {
        let head = make_head(vec![("Upgrade", "websocket"), ("Connection", "Upgrade")]);
        assert!(head.is_websocket_upgrade());

        let head_no_conn = make_head(vec![("Upgrade", "websocket")]);
        assert!(!head_no_conn.is_websocket_upgrade());

        let head_no_upgrade = make_head(vec![("Connection", "Upgrade")]);
        assert!(!head_no_upgrade.is_websocket_upgrade());
    }

    #[test]
    fn websocket_upgrade_values_are_case_insensitive() {
        let head = make_head(vec![("upgrade", "WebSocket"), ("CONNECTION", "upgrade")]);
        assert!(head.is_websocket_upgrade());

        let other_protocol = make_head(vec![("Upgrade", "h2c"), ("Connection", "Upgrade")]);
        assert!(!other_protocol.is_websocket_upgrade());
    }

    #[test]
    fn connection_header_accepts_lists() {
        let head = make_head(vec![
            ("Upgrade", "websocket"),
            ("Connection", "keep-alive, Upgrade"),
        ]);
        assert!(head.is_websocket_upgrade());
    }

    #[test]
    fn accept_header_detects_sse() {
        let head = make_head(vec![("Accept", "text/event-stream")]);
        assert!(head.wants_event_stream());

        let head_html = make_head(vec![("Accept", "text/html")]);
        assert!(!head_html.wants_event_stream());

        let head_mixed = make_head(vec![("Accept", "text/html, text/event-stream;q=0.9")]);
        assert!(head_mixed.wants_event_stream());
    }

    #[test]
    fn missing_or_wildcard_accept_is_not_sse() {
        assert!(!make_head(Vec::new()).wants_event_stream());

        let wildcard = make_head(vec![("Accept", "*/*")]);
        assert!(!wildcard.wants_event_stream());
    }

    #[test]
    fn header_terminator_is_found() {
        let req = b"GET / HTTP/1.1\r\nHost: x\r\n\r\nbody";
        let idx = find_header_terminator(req).expect("terminator must be found");
        assert_eq!(&req[idx..], b"body");
        assert_eq!(find_header_terminator(b"no terminator here"), None);
    }

    #[test]
    fn header_terminator_edge_cases() {
        assert_eq!(find_header_terminator(b""), None);
        assert_eq!(find_header_terminator(b"\r\n\r"), None);
        assert_eq!(find_header_terminator(b"\r\n\r\n"), Some(4));
        // Bare LF line endings are not recognised as a terminator.
        assert_eq!(find_header_terminator(b"a\n\nb"), None);
        // The first terminator wins.
        assert_eq!(find_header_terminator(b"a\r\n\r\nb\r\n\r\n"), Some(5));
    }
}