Skip to main content

potato/
lib.rs

1#[cfg(feature = "acme")]
2pub mod acme;
3pub mod client;
4pub mod global_config;
5pub mod server;
6pub mod utils;
7#[cfg(feature = "webrtc")]
8pub mod webrtc;
9
10pub use client::*;
11pub use global_config::*;
12pub use hipstr;
13pub use inventory;
14pub use potato_macro::*;
15pub use regex;
16pub use rust_embed;
17pub use serde_json;
18pub use server::CorsConfig;
19pub use server::*;
20pub use utils::ai::*;
21pub use utils::refstr::Headers;
22
23#[cfg(all(feature = "jemalloc", not(target_os = "windows")))]
24pub use utils::jemalloc_helper::*;
25
26#[cfg(feature = "webrtc")]
27pub use webrtc::*;
28
29use anyhow::anyhow;
30use chrono::Utc;
31use core::str;
32use hipstr::{LocalHipByt, LocalHipStr};
33use http::Uri;
34use http::uri::Scheme;
35use rust_embed::Embed;
36use sha1::{Digest, Sha1};
37use std::any::{Any, TypeId};
38use std::borrow::Cow;
39use std::cell::RefCell;
40use std::fmt;
41use std::fs::{File, Metadata};
42use std::io::Read;
43use std::net::SocketAddr;
44use std::path::Path;
45use std::str::FromStr;
46use std::sync::{Arc, LazyLock};
47use std::time::UNIX_EPOCH;
48use std::{collections::HashMap, collections::HashSet, future::Future, pin::Pin};
49use strum::Display;
50use thread_local::ThreadLocal;
51use tokio::sync::Mutex;
52use tokio::sync::mpsc::Receiver;
53use utils::bytes::CompressExt;
54use utils::enums::{HttpConnection, HttpContentType};
55use utils::number::HttpCodeExt;
56use utils::refstr::{HeaderItem, HeaderOrHipStr};
57use utils::string::StringExt;
58use utils::tcp_stream::{HttpStream, VecU8Ext};
59
/// Per-request scratch cache for handing values between the pre-processing hooks, the
/// handler, and the post-processing hooks of a single request.
///
/// Entries are keyed by `(name, TypeId)`, so one name can hold one value per type.
#[derive(Debug)]
pub struct OnceCache {
    data: HashMap<(String, TypeId), Box<dyn Any + Send + Sync>>,
}

impl Default for OnceCache {
    fn default() -> Self {
        Self::new()
    }
}

impl OnceCache {
    /// Builds the map key under which a `T` named `name` is stored.
    fn slot<T: Any>(name: &str) -> (String, TypeId) {
        (String::from(name), TypeId::of::<T>())
    }

    /// Creates an empty cache.
    pub fn new() -> Self {
        OnceCache {
            data: HashMap::new(),
        }
    }

    /// Returns a shared reference to the `T` stored under `name`, if any.
    pub fn get<T: Any + Send + Sync + 'static>(&self, name: &str) -> Option<&T> {
        let boxed = self.data.get(&Self::slot::<T>(name))?;
        boxed.downcast_ref::<T>()
    }

    /// Returns a clone of the `T` stored under `name`, or `default` when there is none.
    pub fn get_or_default<T: Any + Send + Sync + Clone + 'static>(
        &self,
        name: &str,
        default: T,
    ) -> T {
        match self.get::<T>(name) {
            Some(found) => found.clone(),
            None => default,
        }
    }

    /// Returns a mutable reference to the `T` stored under `name`, if any.
    pub fn get_mut<T: Any + Send + Sync + 'static>(&mut self, name: &str) -> Option<&mut T> {
        let boxed = self.data.get_mut(&Self::slot::<T>(name))?;
        boxed.downcast_mut::<T>()
    }

    /// Stores `value` under `name`, replacing any previous `T` with that name.
    pub fn set<T: Any + Send + Sync + 'static>(&mut self, name: &str, value: T) {
        let slot = Self::slot::<T>(name);
        self.data.insert(slot, Box::new(value));
    }

    /// Removes the `T` stored under `name` and returns it by value.
    pub fn remove<T: Any + Send + Sync + 'static>(&mut self, name: &str) -> Option<T> {
        let boxed = self.data.remove(&Self::slot::<T>(name))?;
        boxed.downcast::<T>().ok().map(|owned| *owned)
    }

    /// Reports whether a `T` is stored under `name`.
    pub fn contains_key<T: Any + Send + Sync + 'static>(&self, name: &str) -> bool {
        self.data.contains_key(&Self::slot::<T>(name))
    }

    /// Drops every cached entry.
    pub fn clear(&mut self) {
        self.data.clear();
    }

    /// Number of cached entries (across all names and types).
    pub fn len(&self) -> usize {
        self.data.len()
    }

    /// `true` when nothing is cached.
    pub fn is_empty(&self) -> bool {
        self.data.is_empty()
    }
}
140
141use dashmap::DashMap;
142use std::sync::RwLock;
143use std::time::{Duration, Instant};
144
/// Errors produced while resolving a `SessionCache` from a request token.
#[derive(Debug)]
pub enum SessionCacheError {
    /// The token could not be decoded (malformed, bad signature, ...).
    InvalidToken(String),
    /// The token itself has expired.
    TokenExpired,
    /// The session carried by the token has expired.
    SessionExpired,
    /// The request had no `Authorization` header.
    MissingAuthHeader,
    /// Unexpected internal failure.
    InternalError(String),
}

impl std::fmt::Display for SessionCacheError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidToken(msg) => write!(f, "Invalid token: {msg}"),
            Self::TokenExpired => f.write_str("Token has expired"),
            Self::SessionExpired => f.write_str("Session has expired"),
            Self::MissingAuthHeader => f.write_str("Missing Authorization header"),
            Self::InternalError(msg) => write!(f, "Internal error: {msg}"),
        }
    }
}

impl std::error::Error for SessionCacheError {}
173
174// SessionCache 的内部管理器(私有)
175// 注意:JWT 密钥现在统一使用 global_config 中的 SERVER_JWT_SECRET
176static SESSION_CACHE_MANAGER: std::sync::LazyLock<DashMap<i64, (SessionCache, Instant)>> =
177    std::sync::LazyLock::new(DashMap::new);
178
179/// Cookie属性配置
180#[derive(Debug, Clone)]
181pub struct CookieBuilder {
182    /// Cookie名称
183    name: String,
184    /// Cookie值
185    value: String,
186    /// 路径(默认"/")
187    path: Option<String>,
188    /// 域名
189    domain: Option<String>,
190    /// 过期时间(UTC时间戳,秒)
191    expires: Option<i64>,
192    /// 最大存活时间(秒)
193    max_age: Option<i64>,
194    /// 是否仅HTTPS传输
195    secure: bool,
196    /// 是否禁止JavaScript访问
197    http_only: bool,
198    /// SameSite策略: Strict, Lax, None
199    same_site: Option<String>,
200}
201
202impl CookieBuilder {
203    /// 创建新的Cookie
204    pub fn new(name: &str, value: &str) -> Self {
205        Self {
206            name: name.to_string(),
207            value: value.to_string(),
208            path: Some("/".to_string()),
209            domain: None,
210            expires: None,
211            max_age: None,
212            secure: false,
213            http_only: false,
214            same_site: None,
215        }
216    }
217
218    /// 设置路径
219    pub fn path(mut self, path: &str) -> Self {
220        self.path = Some(path.to_string());
221        self
222    }
223
224    /// 设置域名
225    pub fn domain(mut self, domain: &str) -> Self {
226        self.domain = Some(domain.to_string());
227        self
228    }
229
230    /// 设置过期时间(Unix时间戳,秒)
231    pub fn expires(mut self, timestamp: i64) -> Self {
232        self.expires = Some(timestamp);
233        self
234    }
235
236    /// 设置最大存活时间(秒)
237    pub fn max_age(mut self, seconds: i64) -> Self {
238        self.max_age = Some(seconds);
239        self
240    }
241
242    /// 设置Secure标志(仅HTTPS传输)
243    pub fn secure(mut self, secure: bool) -> Self {
244        self.secure = secure;
245        self
246    }
247
248    /// 设置HttpOnly标志(禁止JavaScript访问)
249    pub fn http_only(mut self, http_only: bool) -> Self {
250        self.http_only = http_only;
251        self
252    }
253
254    /// 设置SameSite策略 ("Strict", "Lax", "None")
255    pub fn same_site(mut self, policy: &str) -> Self {
256        self.same_site = Some(policy.to_string());
257        self
258    }
259
260    /// 生成Set-Cookie header值
261    pub fn to_set_cookie_string(&self) -> String {
262        let mut parts = vec![format!("{}={}", self.name, self.value)];
263
264        if let Some(ref path) = self.path {
265            parts.push(format!("Path={}", path));
266        }
267
268        if let Some(ref domain) = self.domain {
269            parts.push(format!("Domain={}", domain));
270        }
271
272        if let Some(expires) = self.expires {
273            // 将Unix时间戳转换为HTTP日期格式
274            let datetime = chrono::DateTime::<chrono::Utc>::from_timestamp(expires, 0);
275            if let Some(dt) = datetime {
276                parts.push(format!(
277                    "Expires={}",
278                    dt.format("%a, %d %b %Y %H:%M:%S GMT")
279                ));
280            }
281        }
282
283        if let Some(max_age) = self.max_age {
284            parts.push(format!("Max-Age={}", max_age));
285        }
286
287        if self.secure {
288            parts.push("Secure".to_string());
289        }
290
291        if self.http_only {
292            parts.push("HttpOnly".to_string());
293        }
294
295        if let Some(ref same_site) = self.same_site {
296            parts.push(format!("SameSite={}", same_site));
297        }
298
299        parts.join("; ")
300    }
301
302    /// 生成删除cookie的Set-Cookie header值
303    pub fn to_delete_cookie_string(&self) -> String {
304        let mut parts = vec![
305            format!("{}=", self.name),
306            "Path=/".to_string(),
307            "Expires=Thu, 01 Jan 1970 00:00:00 GMT".to_string(),
308        ];
309
310        if let Some(ref domain) = self.domain {
311            parts.push(format!("Domain={}", domain));
312        }
313
314        parts.join("; ")
315    }
316}
317
318/// SessionCache数据类型别名
319type SessionCacheData = Arc<RwLock<HashMap<(String, TypeId), Box<dyn Any + Send + Sync>>>>;
320
321/// 会话级缓存,用于同一用户的不同请求间传递参数
322/// 基于Bearer token中的id区分不同Session
323#[derive(Debug, Clone)]
324pub struct SessionCache {
325    data: SessionCacheData,
326    /// 存储从请求中读取的cookies
327    request_cookies: Arc<RwLock<HashMap<String, String>>>,
328    /// 存储需要设置到响应的cookies(包含完整属性)
329    response_cookies: Arc<RwLock<Vec<CookieBuilder>>>,
330}
331
332impl Default for SessionCache {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
338impl SessionCache {
339    pub fn new() -> Self {
340        Self {
341            data: Arc::new(RwLock::new(HashMap::new())),
342            request_cookies: Arc::new(RwLock::new(HashMap::new())),
343            response_cookies: Arc::new(RwLock::new(Vec::new())),
344        }
345    }
346
347    /// 从请求的Cookie header中解析cookies
348    pub fn parse_request_cookies(&mut self, cookie_header: &str) {
349        if let Ok(mut cookies) = self.request_cookies.write() {
350            for pair in cookie_header.split(';') {
351                let pair = pair.trim();
352                if let Some((key, value)) = pair.split_once('=') {
353                    cookies.insert(key.trim().to_string(), value.trim().to_string());
354                }
355            }
356        }
357    }
358
359    /// 获取请求中的cookie值
360    pub fn get_cookie(&self, name: &str) -> Option<String> {
361        let cookies = self.request_cookies.read().ok()?;
362        cookies.get(name).cloned()
363    }
364
365    /// 设置响应cookie(简单版本,仅设置名称和值)
366    pub fn set_cookie(&self, name: &str, value: &str) {
367        if let Ok(mut cookies) = self.response_cookies.write() {
368            cookies.push(CookieBuilder::new(name, value));
369        }
370    }
371
372    /// 设置响应cookie(完整配置版本)
373    pub fn set_cookie_with_builder(&self, cookie: CookieBuilder) {
374        if let Ok(mut cookies) = self.response_cookies.write() {
375            cookies.push(cookie);
376        }
377    }
378
379    /// 移除响应cookie(设置过期时间为过去)
380    pub fn remove_cookie(&self, name: &str) {
381        if let Ok(mut cookies) = self.response_cookies.write() {
382            cookies.push(CookieBuilder::new(name, ""));
383        }
384    }
385
386    /// 移除响应cookie(带域名配置)
387    pub fn remove_cookie_with_domain(&self, name: &str, domain: &str) {
388        if let Ok(mut cookies) = self.response_cookies.write() {
389            cookies.push(CookieBuilder::new(name, "").domain(domain));
390        }
391    }
392
393    /// 将所有待设置的cookies应用到HttpResponse
394    pub fn apply_cookies(&self, response: &mut HttpResponse) {
395        if let Ok(cookies) = self.response_cookies.read() {
396            for cookie in cookies.iter() {
397                let cookie_str = if cookie.value.is_empty() {
398                    // 移除cookie
399                    cookie.to_delete_cookie_string()
400                } else {
401                    // 设置cookie
402                    cookie.to_set_cookie_string()
403                };
404                response.add_header(Cow::Borrowed("Set-Cookie"), Cow::Owned(cookie_str));
405            }
406        }
407    }
408
409    /// 获取值(需要类型实现Clone)
410    pub fn get<T: Any + Send + Sync + Clone + 'static>(&self, name: &str) -> Option<T> {
411        let key = (name.to_string(), TypeId::of::<T>());
412        let data = self.data.read().ok()?;
413        let value = data.get(&key).and_then(|boxed| boxed.downcast_ref::<T>())?;
414        Some(value.clone())
415    }
416
417    /// 获取值并应用函数
418    pub fn with_get<T: Any + Send + Sync + 'static, F, R>(&self, name: &str, f: F) -> Option<R>
419    where
420        F: FnOnce(&T) -> R,
421    {
422        let key = (name.to_string(), TypeId::of::<T>());
423        let data = self.data.read().ok()?;
424        let value = data.get(&key).and_then(|boxed| boxed.downcast_ref::<T>())?;
425        Some(f(value))
426    }
427
428    /// 获取可变引用并应用函数
429    pub fn with_mut<T: Any + Send + Sync + 'static, F, R>(&self, name: &str, f: F) -> Option<R>
430    where
431        F: FnOnce(&mut T) -> R,
432    {
433        let key = (name.to_string(), TypeId::of::<T>());
434        let mut data = self.data.write().ok()?;
435        let value = data
436            .get_mut(&key)
437            .and_then(|boxed| boxed.downcast_mut::<T>())?;
438        Some(f(value))
439    }
440
441    /// 插入或更新值
442    pub fn set<T: Any + Send + Sync + 'static>(&self, name: &str, value: T) {
443        let key = (name.to_string(), TypeId::of::<T>());
444        if let Ok(mut data) = self.data.write() {
445            data.insert(key, Box::new(value));
446        }
447    }
448
449    /// 移除并返回值
450    pub fn remove<T: Any + Send + Sync + 'static>(&self, name: &str) -> Option<T> {
451        let key = (name.to_string(), TypeId::of::<T>());
452        let mut data = self.data.write().ok()?;
453        data.remove(&key)
454            .and_then(|boxed| boxed.downcast::<T>().ok())
455            .map(|boxed| *boxed)
456    }
457
458    // ==================== 静态方法:Session管理 ====================
459
460    /// 设置JWT签发秘钥(与ServerConfig共享同一个密钥)
461    pub async fn set_jwt_secret(secret: &[u8]) {
462        let secret_str = String::from_utf8_lossy(secret);
463        crate::ServerConfig::set_jwt_secret(secret_str).await;
464    }
465
466    /// 获取JWT签发秘钥
467    async fn get_jwt_secret() -> Vec<u8> {
468        crate::ServerConfig::get_jwt_secret().await.into_bytes()
469    }
470
471    /// 签发JWT token
472    /// 参数:
473    /// - user_id: 用户ID
474    /// - ttl: token有效期
475    ///
476    /// 返回: JWT token字符串
477    pub async fn generate_token(
478        user_id: i64,
479        ttl: std::time::Duration,
480    ) -> Result<String, anyhow::Error> {
481        use jsonwebtoken::{EncodingKey, Header, encode};
482        use serde::{Deserialize, Serialize};
483
484        #[derive(Debug, Serialize, Deserialize)]
485        struct SessionClaims {
486            sub: i64,        // user_id
487            exp: usize,      // token过期时间戳
488            iat: usize,      // 签发时间戳
489            sess_exp: usize, // session过期时间戳(固定,不随访问更新)
490        }
491
492        let now = std::time::SystemTime::now()
493            .duration_since(std::time::UNIX_EPOCH)?
494            .as_secs() as usize;
495
496        let claims = SessionClaims {
497            sub: user_id,
498            exp: now + ttl.as_secs() as usize,
499            iat: now,
500            sess_exp: now + ttl.as_secs() as usize,
501        };
502
503        let secret = Self::get_jwt_secret().await;
504        let token = encode(
505            &Header::default(),
506            &claims,
507            &EncodingKey::from_secret(&secret),
508        )?;
509
510        Ok(token)
511    }
512
513    /// 解析JWT token
514    /// 返回: (user_id, session_exp_duration)
515    pub async fn parse_token(token: &str) -> Result<(i64, Duration), SessionCacheError> {
516        use jsonwebtoken::{DecodingKey, Validation, decode};
517        use serde::{Deserialize, Serialize};
518
519        #[derive(Debug, Serialize, Deserialize)]
520        struct SessionClaims {
521            sub: i64,
522            exp: usize,
523            iat: usize,
524            sess_exp: usize,
525        }
526
527        let secret = Self::get_jwt_secret().await;
528        let token_data = decode::<SessionClaims>(
529            token,
530            &DecodingKey::from_secret(&secret),
531            &Validation::default(),
532        )
533        .map_err(|e| SessionCacheError::InvalidToken(format!("Token decode failed: {e}")))?;
534
535        let now = std::time::SystemTime::now()
536            .duration_since(std::time::UNIX_EPOCH)
537            .ok()
538            .ok_or_else(|| {
539                SessionCacheError::InternalError("Failed to get system time".to_string())
540            })?
541            .as_secs() as usize;
542
543        if token_data.claims.exp < now {
544            return Err(SessionCacheError::TokenExpired);
545        }
546
547        // 使用固定的session过期时间,不随访问更新
548        let session_exp = token_data.claims.sess_exp;
549        if session_exp < now {
550            return Err(SessionCacheError::SessionExpired);
551        }
552
553        let remaining_secs = session_exp - now;
554        Ok((
555            token_data.claims.sub,
556            Duration::from_secs(remaining_secs as u64),
557        ))
558    }
559
560    // ==================== 内部缓存管理器 ====================
561
562    /// 获取或创建Session缓存
563    /// 参数:
564    /// - token: JWT token
565    ///
566    /// 返回: SessionCache实例
567    pub async fn from_token(token: &str) -> Result<Self, SessionCacheError> {
568        let (user_id, ttl) = Self::parse_token(token).await?;
569
570        let now = Instant::now();
571        let expires_at = now + ttl;
572
573        // 使用 DashMap 的 entry API 一次性完成检查和创建,避免竞态条件
574        let mut entry = SESSION_CACHE_MANAGER
575            .entry(user_id)
576            .or_insert_with(|| (SessionCache::new(), expires_at));
577
578        // 检查是否过期,如果过期则重新创建
579        if entry.value().1 < now {
580            // 已过期,重新创建
581            *entry.value_mut() = (SessionCache::new(), expires_at);
582        }
583
584        // 返回session的克隆
585        Ok(entry.value().0.clone())
586    }
587
588    /// 使指定用户的session失效(用于登出等场景)
589    pub fn invalidate(user_id: i64) {
590        SESSION_CACHE_MANAGER.remove(&user_id);
591    }
592
593    /// 内部方法:清理过期的session缓存(已废弃,使用后台清理任务替代)
594    #[allow(dead_code)]
595    fn cleanup_expired_internal() {
596        use std::sync::atomic::{AtomicU64, Ordering};
597        static CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
598
599        // 每100次调用触发一次清理
600        if CALL_COUNTER.fetch_add(1, Ordering::Relaxed) % 100 == 0 {
601            let now = Instant::now();
602            SESSION_CACHE_MANAGER.retain(|_, (_, expires_at)| *expires_at > now);
603        }
604    }
605
606    /// 内部方法:提供对session manager的访问,用于后台清理任务
607    /// 注意:这是一个完全私有的内部方法,仅供服务器清理任务使用
608    pub(crate) fn cleanup_expired_sessions() {
609        let now = Instant::now();
610        SESSION_CACHE_MANAGER.retain(|_, (_, expires_at)| *expires_at > now);
611    }
612}
613
/// Outcome of the HTTP conditional-request (preflight) check performed before handling.
#[derive(Debug, PartialEq)]
pub enum PreflightResult {
    /// No precondition stops the request; continue with normal processing.
    Proceed,
    /// Answer with `304 Not Modified`.
    NotModified,
    /// Answer with `412 Precondition Failed`.
    PreconditionFailed,
}
624
/// Reasons an incoming HTTP request could not be parsed. Each variant carries a message.
///
/// The variant names follow HTTP status reason phrases (400, 501, 417, 431, 413);
/// presumably they are mapped to those statuses by the caller — confirm in the server code.
#[derive(Debug)]
pub enum HttpRequestParseError {
    BadRequest(String),
    NotImplemented(String),
    ExpectationFailed(String),
    RequestHeaderFieldsTooLarge(String),
    PayloadTooLarge(String),
}

impl fmt::Display for HttpRequestParseError {
    /// Displays only the carried message, whatever the variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (Self::BadRequest(msg)
        | Self::NotImplemented(msg)
        | Self::ExpectationFailed(msg)
        | Self::RequestHeaderFieldsTooLarge(msg)
        | Self::PayloadTooLarge(msg)) = self;
        f.write_str(msg)
    }
}

impl std::error::Error for HttpRequestParseError {}
647
/// Parses a `Trailer` header value into the set of declared field names, lowercased.
///
/// Blank list elements are ignored; `None` yields an empty set.
fn parse_declared_trailer_names(raw: Option<&str>) -> HashSet<String> {
    let Some(value) = raw else {
        return HashSet::new();
    };
    let mut names = HashSet::new();
    for piece in value.split(',') {
        let name = piece.trim();
        if !name.is_empty() {
            names.insert(name.to_ascii_lowercase());
        }
    }
    names
}
658
/// Returns `true` if `name` must not be accepted as a trailer field.
///
/// RFC 9110 §6.5.1 / RFC 9112: trailers must not carry message-framing, routing or
/// hop-by-hop control fields. Field names are case-insensitive (RFC 9110 §5.1), so the
/// comparison ignores ASCII case and callers need not lowercase `name` first.
fn is_forbidden_trailer_field(name: &str) -> bool {
    const FORBIDDEN: [&str; 10] = [
        "transfer-encoding",
        "content-length",
        "trailer",
        "host",
        "connection",
        "keep-alive",
        "te",
        "upgrade",
        "proxy-authenticate",
        "proxy-authorization",
    ];
    FORBIDDEN
        .iter()
        .any(|forbidden| name.eq_ignore_ascii_case(forbidden))
}
675
676fn parse_trailer_line(line: &[u8]) -> anyhow::Result<(String, String)> {
677    let line = str::from_utf8(line)?.trim();
678    let (name, value) = line
679        .split_once(':')
680        .ok_or_else(|| anyhow!("invalid trailer field line"))?;
681    let name = name.trim();
682    if name.is_empty() {
683        Err(anyhow!("empty trailer field name"))?;
684    }
685    Ok((name.to_string(), value.trim().to_string()))
686}
687
688fn parse_transfer_encoding_tokens(raw: &str) -> anyhow::Result<Vec<String>> {
689    let codings = raw
690        .split(',')
691        .map(|part| part.trim().to_ascii_lowercase())
692        .filter(|part| !part.is_empty())
693        .collect::<Vec<_>>();
694    if codings.is_empty() {
695        Err(anyhow!("empty Transfer-Encoding header"))?;
696    }
697    Ok(codings)
698}
699
/// Error returned by `parse_http_date` when the input is not a recognised HTTP date.
#[derive(Debug)]
pub struct HttpDateParseError;

impl std::fmt::Display for HttpDateParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("failed to parse HTTP date")
    }
}

impl std::error::Error for HttpDateParseError {}
711
712/// Parse HTTP date format to Unix timestamp
713/// Supports RFC 7231 standard HTTP date formats:
714/// - RFC 1123: "Mon, 06 Nov 1994 08:49:37 GMT"
715/// - RFC 850: "Monday, 06-Nov-94 08:49:37 GMT"
716/// - ANSI C asctime(): "Mon Nov  6 08:49:37 1994"
717pub fn parse_http_date(date_str: &str) -> Result<u64, HttpDateParseError> {
718    // Use simple manual parsing method to handle RFC 1123 format
719    // Format: "Fri, 12 Sep 2025 00:00:00 GMT"
720    static DATE_REGEX: std::sync::LazyLock<Result<regex::Regex, regex::Error>> =
721        std::sync::LazyLock::new(|| {
722            regex::Regex::new(r"^\w+, (\d{1,2}) (\w+) (\d{4}) (\d{2}):(\d{2}):(\d{2}) GMT$")
723        });
724
725    if let Some(caps) = DATE_REGEX
726        .as_ref()
727        .ok()
728        .and_then(|re| re.captures(date_str))
729    {
730        let day: u32 = caps[1].parse().map_err(|_| HttpDateParseError)?;
731        let month_str = &caps[2];
732        let year: i32 = caps[3].parse().map_err(|_| HttpDateParseError)?;
733        let hour: u32 = caps[4].parse().map_err(|_| HttpDateParseError)?;
734        let minute: u32 = caps[5].parse().map_err(|_| HttpDateParseError)?;
735        let second: u32 = caps[6].parse().map_err(|_| HttpDateParseError)?;
736
737        let month = match month_str {
738            "Jan" => 1,
739            "Feb" => 2,
740            "Mar" => 3,
741            "Apr" => 4,
742            "May" => 5,
743            "Jun" => 6,
744            "Jul" => 7,
745            "Aug" => 8,
746            "Sep" => 9,
747            "Oct" => 10,
748            "Nov" => 11,
749            "Dec" => 12,
750            _ => return Err(HttpDateParseError),
751        };
752
753        if let Some(dt) = chrono::NaiveDate::from_ymd_opt(year, month, day)
754            .and_then(|d| d.and_hms_opt(hour, minute, second))
755        {
756            let timestamp = dt.and_utc().timestamp() as u64;
757            return Ok(timestamp);
758        }
759    }
760
761    // Try RFC 1123 format
762    if let Ok(dt) = chrono::DateTime::parse_from_str(date_str, "%a, %d %b %Y %H:%M:%S GMT") {
763        let timestamp = dt.timestamp() as u64;
764        return Ok(timestamp);
765    }
766
767    // Try RFC 850 format
768    if let Ok(dt) = chrono::DateTime::parse_from_str(date_str, "%A, %d-%b-%y %H:%M:%S GMT") {
769        let timestamp = dt.timestamp() as u64;
770        return Ok(timestamp);
771    }
772
773    // Try ANSI C asctime() format
774    if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(date_str, "%a %b %e %H:%M:%S %Y") {
775        let timestamp = dt.and_utc().timestamp() as u64;
776        return Ok(timestamp);
777    }
778
779    Err(HttpDateParseError)
780}
781
782static SERVER_STR: LazyLock<String> =
783    LazyLock::new(|| format!("{} {}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")));
784
785type AsyncHttpHandler =
786    fn(&mut HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + Send + '_>>;
787type AsyncNoSendHttpHandler =
788    fn(&mut HttpRequest) -> Pin<Box<dyn Future<Output = HttpResponse> + '_>>;
789type SyncHttpHandler = fn(&mut HttpRequest) -> HttpResponse;
790
791#[derive(Clone, Copy)]
792pub enum HttpHandler {
793    Async(AsyncHttpHandler),
794    AsyncNoSend(AsyncNoSendHttpHandler),
795    Sync(SyncHttpHandler),
796}
797
/// Documentation metadata attached to a registered route.
pub struct RequestHandlerFlagDoc {
    /// Whether the route should be listed in the generated docs (per its name — confirm).
    pub show: bool,
    /// Whether the route is marked as requiring authentication (per its name — confirm).
    pub auth: bool,
    /// Short summary line.
    pub summary: &'static str,
    /// Longer description.
    pub desp: &'static str,
    /// Argument description.
    pub args: &'static str,
    /// Controller name, used to group routes in Swagger.
    pub tag: &'static str,
}

impl RequestHandlerFlagDoc {
    /// `const` constructor, usable in static initialisers.
    pub const fn new(
        show: bool,
        auth: bool,
        summary: &'static str,
        desp: &'static str,
        args: &'static str,
        tag: &'static str,
    ) -> Self {
        Self {
            show,
            auth,
            summary,
            desp,
            args,
            tag,
        }
    }
}
826
827pub struct RequestHandlerFlag {
828    pub method: HttpMethod,
829    pub path: &'static str,
830    pub handler: HttpHandler,
831    pub doc: RequestHandlerFlagDoc,
832}
833
834impl RequestHandlerFlag {
835    pub const fn new(
836        method: HttpMethod,
837        path: &'static str,
838        handler: HttpHandler,
839        doc: RequestHandlerFlagDoc,
840    ) -> Self {
841        RequestHandlerFlag {
842            method,
843            path,
844            handler,
845            doc,
846        }
847    }
848}
849
850inventory::collect!(RequestHandlerFlag);
851
852/// 异步错误处理器类型
853pub type AsyncErrorHandler =
854    fn(&mut HttpRequest, anyhow::Error) -> Pin<Box<dyn Future<Output = HttpResponse> + Send + '_>>;
855
856/// 同步错误处理器类型
857pub type SyncErrorHandler = fn(&mut HttpRequest, anyhow::Error) -> HttpResponse;
858
859/// 错误处理器枚举,支持异步和同步
860#[derive(Clone)]
861pub enum ErrorHandler {
862    Async(AsyncErrorHandler),
863    Sync(SyncErrorHandler),
864}
865
866/// 错误处理器注册标志
867pub struct ErrorHandlerFlag {
868    pub handler: ErrorHandler,
869}
870
871impl ErrorHandlerFlag {
872    pub const fn new(handler: ErrorHandler) -> Self {
873        ErrorHandlerFlag { handler }
874    }
875}
876
877inventory::collect!(ErrorHandlerFlag);
878
/// Which cache fields a controller struct declares.
pub struct ControllerStructFieldInfo {
    /// The controller struct has a `OnceCache` field.
    pub has_once_cache: bool,
    /// The controller struct has a `SessionCache` field.
    pub has_session_cache: bool,
}

/// Marker that carries a controller struct's field information to the `impl` macro.
pub struct ControllerStructFlag {
    /// Name of the controller struct.
    pub struct_name: &'static str,
    /// Field information for that struct.
    pub field_info: ControllerStructFieldInfo,
}

impl ControllerStructFlag {
    /// `const` constructor, usable in static initialisers.
    pub const fn new(struct_name: &'static str, field_info: ControllerStructFieldInfo) -> Self {
        Self {
            struct_name,
            field_info,
        }
    }
}
899
900inventory::collect!(ControllerStructFlag);
901
902#[derive(Clone, Copy, Debug, Display, Eq, Hash, PartialEq)]
903pub enum HttpMethod {
904    GET,
905    PUT,
906    COPY,
907    HEAD,
908    LOCK,
909    MOVE,
910    POST,
911    MKCOL,
912    PATCH,
913    TRACE,
914    DELETE,
915    UNLOCK,
916    CONNECT,
917    OPTIONS,
918    PROPFIND,
919    PROPPATCH,
920}
921
/// Shape of a request-target, mirroring the four forms named in RFC 9112 §3.2.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum HttpRequestTargetForm {
    /// `origin-form`, e.g. `/path?query`.
    Origin,
    /// `absolute-form`, e.g. `http://host/path`.
    Absolute,
    /// `authority-form`, e.g. `host:port`.
    Authority,
    /// `asterisk-form`, i.e. `*`.
    Asterisk,
}

/// Body compression choice: none or gzip.
#[derive(Debug, Eq, PartialEq)]
pub enum CompressMode {
    None,
    Gzip,
}
935
936pub struct Websocket {
937    stream: Arc<Mutex<HttpStream>>,
938}
939
940impl Websocket {
941    pub async fn connect(url: &str, args: Vec<Headers>) -> anyhow::Result<Self> {
942        let mut sess = Session::new();
943        let mut req = sess.new_request(HttpMethod::GET, url).await?;
944        for arg in args.into_iter() {
945            req.apply_header(arg);
946        }
947        req.apply_header(Headers::Connection("Upgrade".to_string()));
948        req.apply_header(Headers::Upgrade("Websocket".to_string()));
949        req.apply_header(Headers::Sec_WebSocket_Version("13".to_string()));
950        req.apply_header(Headers::Sec_WebSocket_Key("VerySecurity".to_string()));
951        let res = sess.do_request(req).await?;
952        if res.http_code != 101 {
953            let body_str = match &res.body {
954                HttpResponseBody::Data(data) => str::from_utf8(&data[..])?,
955                HttpResponseBody::Stream(_) => "stream response",
956            };
957            Err(anyhow!("Server return code[{}]: {body_str}", res.http_code))?;
958        }
959        let stream = sess
960            .sess_impl
961            .ok_or_else(|| anyhow!("session impl is null"))?
962            .stream;
963        let stream = Arc::new(Mutex::new(stream));
964        Ok(Self { stream })
965    }
966
967    async fn recv_impl(&mut self) -> anyhow::Result<WsFrameImpl> {
968        let mut stream = self.stream.lock().await;
969        let buf = {
970            let mut buf = [0u8; 2];
971            stream.read_exact(&mut buf).await?;
972            buf
973        };
974        //let fin = buf[0] & 0b1000_0000 != 0;
975        let opcode = buf[0] & 0b0000_1111;
976        let payload_len = {
977            let payload_len = buf[1] & 0b0111_1111;
978            match payload_len {
979                126 => {
980                    let mut buf = [0u8; 2];
981                    stream.read_exact(&mut buf).await?;
982                    u16::from_be_bytes(buf) as usize
983                }
984                127 => {
985                    let mut buf = [0u8; 8];
986                    stream.read_exact(&mut buf).await?;
987                    u64::from_be_bytes(buf) as usize
988                }
989                _ => payload_len as usize,
990            }
991        };
992        let omask_key = match buf[1] & 0b1000_0000 != 0 {
993            true => {
994                let mut mask_key = [0u8; 4];
995                stream.read_exact(&mut mask_key).await?;
996                Some(mask_key)
997            }
998            false => None,
999        };
1000        let mut payload = vec![0u8; payload_len];
1001        stream.read_exact(&mut payload).await?;
1002        if let Some(mask_key) = omask_key {
1003            for i in 0..payload.len() {
1004                payload[i] ^= mask_key[i % 4];
1005            }
1006        }
1007        match opcode {
1008            0x0 => Ok(WsFrameImpl::PartData(payload)),
1009            0x1 => Ok(WsFrameImpl::Text(payload)),
1010            0x2 => Ok(WsFrameImpl::Binary(payload)),
1011            0x8 => Ok(WsFrameImpl::Close),
1012            0x9 => Ok(WsFrameImpl::Ping),
1013            0xA => Ok(WsFrameImpl::Pong),
1014            _ => Err(anyhow::Error::msg("unsupported opcode")),
1015        }
1016    }
1017
1018    pub async fn recv(&mut self) -> anyhow::Result<WsFrame> {
1019        let mut tmp = vec![];
1020        loop {
1021            let timeout = ServerConfig::get_ws_ping_duration().await;
1022            match tokio::time::timeout(timeout, self.recv_impl()).await {
1023                Ok(ws_frame) => match ws_frame? {
1024                    WsFrameImpl::Close => return Err(anyhow::Error::msg("close frame")),
1025                    WsFrameImpl::Ping => self.send_impl(WsFrameImpl::Pong).await?,
1026                    WsFrameImpl::Pong => (),
1027                    WsFrameImpl::Binary(data) => {
1028                        tmp.extend(data);
1029                        return Ok(WsFrame::Binary(tmp));
1030                    }
1031                    WsFrameImpl::Text(data) => {
1032                        tmp.extend(data);
1033                        let ret_str = String::from_utf8(tmp).unwrap_or("".to_string());
1034                        return Ok(WsFrame::Text(ret_str));
1035                    }
1036                    WsFrameImpl::PartData(data) => tmp.extend(data),
1037                },
1038                Err(_) => self.send_impl(WsFrameImpl::Ping).await?,
1039            }
1040        }
1041    }
1042
1043    async fn send_impl(&mut self, frame: WsFrameImpl) -> anyhow::Result<()> {
1044        let (fin, opcode, payload) = match frame {
1045            WsFrameImpl::Close => (true, 0x8, vec![]),
1046            WsFrameImpl::Ping => (true, 0x9, vec![]),
1047            WsFrameImpl::Pong => (true, 0xA, vec![]),
1048            WsFrameImpl::Binary(data) => (true, 0x2, data),
1049            WsFrameImpl::Text(data) => (true, 0x1, data),
1050            WsFrameImpl::PartData(data) => (false, 0x0, data),
1051        };
1052        let payload_len = payload.len();
1053        let mut buf = vec![];
1054        buf.push((fin as u8) << 7 | opcode);
1055        if payload_len < 126 {
1056            buf.push(payload_len as u8);
1057        } else if payload_len < 65536 {
1058            buf.push(126);
1059            buf.extend((payload_len as u16).to_be_bytes().iter());
1060        } else {
1061            buf.push(127);
1062            buf.extend((payload_len as u64).to_be_bytes().iter());
1063        }
1064        let mut stream = self.stream.lock().await;
1065        stream.write_all(&buf).await?;
1066        stream.write_all(&payload).await?;
1067        Ok(())
1068    }
1069
1070    pub async fn send_ping(&mut self) -> anyhow::Result<()> {
1071        self.send_impl(WsFrameImpl::Ping).await
1072    }
1073
1074    pub async fn send(&mut self, frame: WsFrame) -> anyhow::Result<()> {
1075        match frame {
1076            WsFrame::Binary(data) => self.send_impl(WsFrameImpl::Binary(data)).await,
1077            WsFrame::Text(text) => {
1078                self.send_impl(WsFrameImpl::Text(text.as_bytes().to_vec()))
1079                    .await
1080            }
1081        }
1082    }
1083
1084    pub async fn send_binary(&mut self, data: Vec<u8>) -> anyhow::Result<()> {
1085        self.send_impl(WsFrameImpl::Binary(data)).await
1086    }
1087
1088    pub async fn send_text(&mut self, data: &str) -> anyhow::Result<()> {
1089        self.send_impl(WsFrameImpl::Text(data.as_bytes().to_vec()))
1090            .await
1091    }
1092}
1093
/// A complete WebSocket data message, as returned by `Websocket::recv` and
/// accepted by `Websocket::send`.
#[derive(Debug)]
pub enum WsFrame {
    /// Binary message payload.
    Binary(Vec<u8>),
    /// Text message payload. `recv` substitutes an empty string when the
    /// received bytes are not valid UTF-8.
    Text(String),
}
1099
/// A single wire-level WebSocket frame, as parsed by `Websocket::recv_impl`
/// and serialized by `Websocket::send_impl`.
pub enum WsFrameImpl {
    /// Close control frame (opcode 0x8). Any received payload is discarded.
    Close,
    /// Ping control frame (opcode 0x9). Any received payload is discarded.
    Ping,
    /// Pong control frame (opcode 0xA). Any received payload is discarded.
    Pong,
    /// Binary data frame (opcode 0x2).
    Binary(Vec<u8>),
    /// Text data frame (opcode 0x1). Holds raw bytes, not yet UTF-8 validated.
    Text(Vec<u8>),
    /// Continuation frame (opcode 0x0). `send_impl` emits it with FIN = 0.
    PartData(Vec<u8>),
}
1108
/// A file carried in a request body. `HttpRequest::body_files` holds these,
/// presumably filled from `multipart/form-data` parts — confirm in the parser.
#[derive(Clone, Debug)]
pub struct PostFile {
    /// File name reported for the uploaded part.
    pub filename: LocalHipStr<'static>,
    /// Raw file contents.
    pub data: LocalHipByt<'static>,
}

// NOTE(review): `LocalHipStr` / `LocalHipByt` are presumably the non-atomic
// ("local") hipstr variants and therefore not `Send`. This impl overrides the
// compiler's check. It is only sound if a `PostFile` is never moved to another
// thread while a clone sharing its buffers stays behind — confirm.
unsafe impl Send for PostFile {}
1116
/// A parsed HTTP request plus per-request state attached by the server.
#[derive(Debug)]
pub struct HttpRequest {
    /// Request method.
    pub method: HttpMethod,
    /// Which request-target form the request line used; set by
    /// `parse_request_target`.
    pub target_form: HttpRequestTargetForm,
    /// Path component of the target. Holds `"*"` for asterisk-form and the raw
    /// authority for authority-form (CONNECT) targets.
    pub url_path: LocalHipStr<'static>,
    /// Query parameters split on `&` and `=`, stored without percent-decoding.
    pub url_query: HashMap<LocalHipStr<'static>, LocalHipStr<'static>>,
    /// Protocol version marker: 10 = HTTP/1.0, 11 = HTTP/1.1. Per
    /// `request_version_line`, 20/30 mark HTTP/2 and HTTP/3 internally.
    pub version: u8,
    /// Header fields, looked up first by `HeaderItem` and then by raw name
    /// (see `get_header`).
    pub headers: HashMap<HeaderOrHipStr, LocalHipStr<'static>>,
    /// Trailer fields received after a chunked body (see `read_chunked_body`).
    pub trailers: HashMap<HeaderOrHipStr, LocalHipStr<'static>>,
    /// Request body bytes; already de-chunked for chunked requests.
    pub body: LocalHipByt<'static>,
    /// Fields from a JSON-object or urlencoded body. JSON values are stored as
    /// their JSON text.
    pub body_pairs: HashMap<LocalHipStr<'static>, LocalHipStr<'static>>,
    /// Uploaded files from the request body.
    pub body_files: HashMap<LocalHipStr<'static>, PostFile>,
    /// Peer address, when known. `get_client_addr` falls back to a
    /// `SocketAddr` extension.
    pub client_addr: Option<SocketAddr>,
    /// Type-keyed extensions (see `add_ext` / `get_ext`). For example,
    /// `upgrade_websocket` takes the connection's `Mutex<HttpStream>` from here.
    pub exts: HashMap<TypeId, Arc<dyn Any + Send + Sync + 'static>>,
}

// NOTE(review): the `LocalHip*` fields are presumably non-atomically reference
// counted, so the compiler would not derive `Send`/`Sync`. These impls assert
// thread-safety manually. They are sound only if a request's shared buffers are
// never accessed from two threads at once — confirm.
unsafe impl Send for HttpRequest {}
unsafe impl Sync for HttpRequest {}
1135
1136impl Default for HttpRequest {
1137    fn default() -> Self {
1138        Self::new()
1139    }
1140}
1141
1142impl HttpRequest {
1143    fn bad_request(msg: impl Into<String>) -> anyhow::Error {
1144        anyhow::Error::new(HttpRequestParseError::BadRequest(msg.into()))
1145    }
1146
1147    fn expectation_failed(msg: impl Into<String>) -> anyhow::Error {
1148        anyhow::Error::new(HttpRequestParseError::ExpectationFailed(msg.into()))
1149    }
1150
1151    fn not_implemented(msg: impl Into<String>) -> anyhow::Error {
1152        anyhow::Error::new(HttpRequestParseError::NotImplemented(msg.into()))
1153    }
1154
1155    fn request_header_fields_too_large(msg: impl Into<String>) -> anyhow::Error {
1156        anyhow::Error::new(HttpRequestParseError::RequestHeaderFieldsTooLarge(
1157            msg.into(),
1158        ))
1159    }
1160
1161    pub fn bad_request_message(err: &anyhow::Error) -> Option<&str> {
1162        err.downcast_ref::<HttpRequestParseError>()
1163            .and_then(|parse_err| match parse_err {
1164                HttpRequestParseError::BadRequest(msg) => Some(msg.as_str()),
1165                HttpRequestParseError::NotImplemented(_) => None,
1166                HttpRequestParseError::ExpectationFailed(_) => None,
1167                HttpRequestParseError::RequestHeaderFieldsTooLarge(_) => None,
1168                HttpRequestParseError::PayloadTooLarge(_) => None,
1169            })
1170    }
1171
1172    pub fn parse_error_response(err: &anyhow::Error) -> Option<HttpResponse> {
1173        err.downcast_ref::<HttpRequestParseError>()
1174            .map(|parse_err| {
1175                let (status, msg) = match parse_err {
1176                    HttpRequestParseError::BadRequest(msg) => (400, msg.as_str()),
1177                    HttpRequestParseError::NotImplemented(msg) => (501, msg.as_str()),
1178                    HttpRequestParseError::ExpectationFailed(msg) => (417, msg.as_str()),
1179                    HttpRequestParseError::RequestHeaderFieldsTooLarge(msg) => (431, msg.as_str()),
1180                    HttpRequestParseError::PayloadTooLarge(msg) => (413, msg.as_str()),
1181                };
1182                let mut res = HttpResponse::text(msg.to_string());
1183                res.http_code = status;
1184                res.add_header("Connection".into(), "close".into());
1185                res
1186            })
1187    }
1188
1189    fn ensure_header_section_size(buf: &[u8], max_header_bytes: usize) -> anyhow::Result<()> {
1190        let header_len = match buf.windows(4).position(|w| w == b"\r\n\r\n") {
1191            Some(pos) => pos + 4,
1192            None => buf.len(),
1193        };
1194        if header_len > max_header_bytes {
1195            Err(Self::request_header_fields_too_large(format!(
1196                "request headers too large: {header_len} bytes exceeds {max_header_bytes} bytes"
1197            )))?;
1198        }
1199        Ok(())
1200    }
1201
1202    async fn process_expect_header(
1203        &self,
1204        stream: &mut HttpStream,
1205        has_request_body: bool,
1206    ) -> anyhow::Result<()> {
1207        let Some(expect) = self.get_header_key(HeaderItem::Expect) else {
1208            return Ok(());
1209        };
1210        if self.version < 11 {
1211            Err(Self::expectation_failed(
1212                "Expect header is not supported for HTTP versions below 1.1",
1213            ))?;
1214        }
1215        let mut has_100_continue = false;
1216        for token in expect.split(',').map(|token| token.trim()) {
1217            if token.is_empty() {
1218                Err(Self::expectation_failed("empty Expect header"))?;
1219            }
1220            if token.eq_ignore_ascii_case("100-continue") {
1221                has_100_continue = true;
1222                continue;
1223            }
1224            Err(Self::expectation_failed(format!(
1225                "unsupported Expect header: {token}"
1226            )))?;
1227        }
1228        if has_100_continue && has_request_body {
1229            stream.write_all(b"HTTP/1.1 100 Continue\r\n\r\n").await?;
1230        }
1231        Ok(())
1232    }
1233
1234    pub fn new() -> Self {
1235        Self {
1236            method: HttpMethod::GET,
1237            target_form: HttpRequestTargetForm::Origin,
1238            url_path: LocalHipStr::from("/"),
1239            url_query: HashMap::with_capacity(16),
1240            version: 11,
1241            headers: HashMap::with_capacity(16),
1242            trailers: HashMap::with_capacity(4),
1243            body: LocalHipByt::new(),
1244            body_pairs: HashMap::with_capacity(16),
1245            body_files: HashMap::with_capacity(4),
1246            client_addr: None,
1247            exts: HashMap::with_capacity(2),
1248        }
1249    }
1250
1251    pub fn query_string(&self) -> String {
1252        let mut q = "?".to_string();
1253        for (k, v) in self.url_query.iter() {
1254            q.push_str(k);
1255            q.push('=');
1256            q.push_str(v);
1257            q.push('&');
1258        }
1259        q.pop();
1260        q
1261    }
1262
1263    fn add_ext<T: Any + Send + Sync + 'static>(&mut self, item: Arc<T>) {
1264        let type_id = TypeId::of::<T>();
1265        self.exts.insert(type_id, item);
1266    }
1267
1268    fn get_ext<T: Any + Send + Sync + 'static>(&self) -> Option<Arc<T>> {
1269        self.exts
1270            .get(&TypeId::of::<T>())
1271            .and_then(|any| any.clone().downcast().ok())
1272    }
1273
1274    fn remove_ext<T: Any + Send + Sync + 'static>(&mut self) -> Option<Arc<T>> {
1275        self.exts
1276            .remove(&TypeId::of::<T>())
1277            .and_then(|any| any.clone().downcast().ok())
1278    }
1279
1280    fn parse_path_and_query(&mut self, target: &str) {
1281        self.url_query.clear();
1282        match target.find('?') {
1283            Some(p) => {
1284                self.url_path = LocalHipStr::from(&target[..p]);
1285                self.url_query = target[p + 1..]
1286                    .split('&')
1287                    .map(|s| s.split_once('=').unwrap_or((s, "")))
1288                    .map(|(a, b)| (LocalHipStr::from(a), LocalHipStr::from(b)))
1289                    .collect();
1290            }
1291            None => {
1292                self.url_path = LocalHipStr::from(target);
1293            }
1294        }
1295    }
1296
1297    fn request_target(&self) -> String {
1298        if matches!(
1299            self.target_form,
1300            HttpRequestTargetForm::Asterisk | HttpRequestTargetForm::Authority
1301        ) {
1302            return self.url_path.to_string();
1303        }
1304        let mut target = self.url_path.to_string();
1305        if !self.url_query.is_empty() {
1306            target.push_str(&self.query_string());
1307        }
1308        target
1309    }
1310
1311    fn request_version_line(&self) -> String {
1312        match self.version {
1313            10 => "HTTP/1.0".to_string(),
1314            11 => "HTTP/1.1".to_string(),
1315            // Request-line serialization is HTTP/1.x only. Non-H1 internal markers
1316            // (e.g. 20/30 for H2/H3) must not leak into an HTTP/1 request line.
1317            _ => "HTTP/1.1".to_string(),
1318        }
1319    }
1320
    /// Classifies the request-line `target` into one of the four
    /// request-target forms and fills `target_form`, `url_path` and `url_query`.
    ///
    /// `self.method` must already be set: each form is validated against it
    /// (asterisk-form only with OPTIONS, authority-form only with CONNECT, and
    /// CONNECT with no other form).
    ///
    /// # Errors
    ///
    /// Returns a 400 parse error for a method/form mismatch, an unparsable or
    /// incomplete absolute-form URI, or a target matching no form.
    fn parse_request_target(&mut self, target: &str) -> anyhow::Result<()> {
        // asterisk-form: `OPTIONS * HTTP/1.1`
        if target == "*" {
            if self.method != HttpMethod::OPTIONS {
                Err(Self::bad_request(
                    "asterisk-form request-target requires OPTIONS",
                ))?;
            }
            self.target_form = HttpRequestTargetForm::Asterisk;
            self.url_query.clear();
            self.url_path = LocalHipStr::from("*");
            return Ok(());
        }

        // origin-form. Checked before the `://` test so that a query such as
        // `/p?next=http://x` is not mistaken for absolute-form.
        if target.starts_with('/') {
            if self.method == HttpMethod::CONNECT {
                Err(Self::bad_request(
                    "CONNECT requires authority-form request-target",
                ))?;
            }
            self.target_form = HttpRequestTargetForm::Origin;
            self.parse_path_and_query(target);
            return Ok(());
        }

        // absolute-form, e.g. proxy requests `GET http://host/path HTTP/1.1`.
        // NOTE(review): RFC 9112 §3.2.2 says the target's authority should
        // override the Host header; the Host header is not updated here.
        if target.contains("://") {
            if self.method == HttpMethod::CONNECT {
                Err(Self::bad_request(
                    "CONNECT requires authority-form, absolute-form is invalid",
                ))?;
            }
            let uri = target
                .parse::<Uri>()
                .map_err(|_| Self::bad_request("invalid absolute-form request-target"))?;
            if uri.scheme().is_none() || uri.authority().is_none() {
                Err(Self::bad_request(
                    "absolute-form request-target must include scheme and authority",
                ))?;
            }
            self.target_form = HttpRequestTargetForm::Absolute;
            let path_and_query = uri.path_and_query().map(|v| v.as_str()).unwrap_or("/");
            self.parse_path_and_query(path_and_query);
            return Ok(());
        }

        // authority-form: `CONNECT host:port HTTP/1.1`. The raw authority is
        // kept in `url_path`.
        if http::uri::Authority::from_str(target).is_ok() {
            if self.method != HttpMethod::CONNECT {
                Err(Self::bad_request(
                    "authority-form request-target is only valid for CONNECT",
                ))?;
            }
            self.target_form = HttpRequestTargetForm::Authority;
            self.url_query.clear();
            self.url_path = LocalHipStr::from(target);
            return Ok(());
        }

        Err(Self::bad_request("unsupported request-target form"))
    }
1379
1380    pub fn get_uri(&self, is_https: bool) -> anyhow::Result<http::Uri> {
1381        let mut q = self.url_path.to_string();
1382        let mut is_first = true;
1383        for (k, v) in self.url_query.iter() {
1384            match is_first {
1385                true => {
1386                    is_first = false;
1387                    q.push('?');
1388                }
1389                false => q.push('&'),
1390            }
1391            q.push_str(k);
1392            q.push('=');
1393            q.push_str(v);
1394        }
1395        Ok(http::uri::Builder::new()
1396            .scheme(if is_https { "https" } else { "http" })
1397            .path_and_query(q)
1398            .build()?)
1399    }
1400
1401    pub fn from_url(url: &str, method: HttpMethod) -> anyhow::Result<(Self, bool, u16)> {
1402        let uri = url.parse::<Uri>()?;
1403        let mut req = Self::new();
1404        req.method = method;
1405        req.target_form = HttpRequestTargetForm::Origin;
1406        req.parse_path_and_query(uri.path_and_query().map(|v| v.as_str()).unwrap_or("/"));
1407        req.headers.insert(
1408            HeaderOrHipStr::from_str("Host"),
1409            uri.authority()
1410                .map(|authority| authority.as_str())
1411                .unwrap_or("localhost")
1412                .into(),
1413        );
1414        let use_ssl = uri.scheme() == Some(&Scheme::HTTPS);
1415        let port = uri.port_u16().unwrap_or(if use_ssl { 443 } else { 80 });
1416        Ok((req, use_ssl, port))
1417    }
1418
1419    pub fn set_header(
1420        &mut self,
1421        key: impl Into<HeaderOrHipStr>,
1422        value: impl Into<LocalHipStr<'static>>,
1423    ) {
1424        self.headers.insert(key.into(), value.into());
1425    }
1426
1427    pub fn get_header(&self, key: &str) -> Option<&str> {
1428        if let Some(header_item) = HeaderItem::try_from_str(key) {
1429            if let Some(value) = self.headers.get(&HeaderOrHipStr::HeaderItem(header_item)) {
1430                return Some(&value[..]);
1431            }
1432        }
1433        self.headers
1434            .get(&HeaderOrHipStr::HipStr(LocalHipStr::from(key)))
1435            .map(|a| &a[..])
1436    }
1437
1438    pub fn get_header_key(&self, key: HeaderItem) -> Option<&str> {
1439        self.headers.get(&key.into()).map(|a| &a[..])
1440    }
1441
1442    pub fn set_trailer(
1443        &mut self,
1444        key: impl Into<HeaderOrHipStr>,
1445        value: impl Into<LocalHipStr<'static>>,
1446    ) {
1447        self.trailers.insert(key.into(), value.into());
1448    }
1449
1450    pub fn get_trailer(&self, key: &str) -> Option<&str> {
1451        if let Some(header_item) = HeaderItem::try_from_str(key) {
1452            if let Some(value) = self.trailers.get(&HeaderOrHipStr::HeaderItem(header_item)) {
1453                return Some(&value[..]);
1454            }
1455        }
1456        self.trailers
1457            .get(&HeaderOrHipStr::HipStr(LocalHipStr::from(key)))
1458            .map(|a| &a[..])
1459    }
1460
1461    pub fn get_header_accept_encoding(&self) -> CompressMode {
1462        Self::negotiate_accept_encoding(
1463            self.get_header_key(HeaderItem::Accept_Encoding)
1464                .unwrap_or(""),
1465        )
1466    }
1467
1468    fn negotiate_accept_encoding(header: &str) -> CompressMode {
1469        let mut explicit_gzip_q: Option<u16> = None;
1470        let mut wildcard_q: Option<u16> = None;
1471
1472        for item in header.split(',') {
1473            let trimmed = item.trim();
1474            if trimmed.is_empty() {
1475                continue;
1476            }
1477
1478            let mut parts = trimmed.split(';');
1479            let coding = parts.next().unwrap_or("").trim().to_ascii_lowercase();
1480            if coding.is_empty() {
1481                continue;
1482            }
1483
1484            let mut quality = 1000u16;
1485            let mut malformed_q = false;
1486            for param in parts {
1487                let param = param.trim();
1488                if param.is_empty() {
1489                    continue;
1490                }
1491                let mut key_val = param.splitn(2, '=');
1492                let key = key_val.next().unwrap_or("").trim().to_ascii_lowercase();
1493                if key != "q" {
1494                    continue;
1495                }
1496                let val = key_val.next().unwrap_or("").trim();
1497                if let Some(parsed_q) = Self::parse_qvalue_thousandths(val) {
1498                    quality = parsed_q;
1499                } else {
1500                    malformed_q = true;
1501                }
1502                break;
1503            }
1504
1505            if malformed_q {
1506                continue;
1507            }
1508
1509            match coding.as_str() {
1510                "gzip" => {
1511                    explicit_gzip_q =
1512                        Some(explicit_gzip_q.map_or(quality, |prev| prev.max(quality)));
1513                }
1514                "*" => {
1515                    wildcard_q = Some(wildcard_q.map_or(quality, |prev| prev.max(quality)));
1516                }
1517                _ => {}
1518            }
1519        }
1520
1521        let selected_q = explicit_gzip_q.or(wildcard_q).unwrap_or(0);
1522        if selected_q > 0 {
1523            CompressMode::Gzip
1524        } else {
1525            CompressMode::None
1526        }
1527    }
1528
1529    fn parse_qvalue_thousandths(raw: &str) -> Option<u16> {
1530        let val = raw.trim();
1531        if val == "1" || val == "1.0" || val == "1.00" || val == "1.000" {
1532            return Some(1000);
1533        }
1534        if val == "0" {
1535            return Some(0);
1536        }
1537        let frac = val.strip_prefix("0.")?;
1538        if frac.is_empty() || frac.len() > 3 || !frac.chars().all(|ch| ch.is_ascii_digit()) {
1539            return None;
1540        }
1541        let mut digits = frac.to_string();
1542        while digits.len() < 3 {
1543            digits.push('0');
1544        }
1545        digits.parse::<u16>().ok()
1546    }
1547
1548    pub fn get_header_host(&self) -> Option<&str> {
1549        self.get_header_key(HeaderItem::Host)
1550    }
1551
1552    pub fn get_header_connection(&self) -> HttpConnection {
1553        if let Some(conn) = self.get_header_key(HeaderItem::Connection) {
1554            HttpConnection::from_str(conn).unwrap_or(HttpConnection::Close)
1555        } else if self.version >= 11 {
1556            HttpConnection::KeepAlive
1557        } else {
1558            HttpConnection::Close
1559        }
1560    }
1561
1562    pub fn get_header_content_length(&self) -> usize {
1563        self.get_header_key(HeaderItem::Content_Length)
1564            .map_or(0, |val| val.parse::<usize>().unwrap_or(0))
1565    }
1566
1567    fn parse_header_content_length(&self) -> anyhow::Result<Option<usize>> {
1568        let Some(raw_val) = self.get_header_key(HeaderItem::Content_Length) else {
1569            return Ok(None);
1570        };
1571        let value = raw_val.trim();
1572        if value.is_empty() {
1573            Err(anyhow!("empty Content-Length header"))?
1574        }
1575        Ok(Some(value.parse::<usize>()?))
1576    }
1577
1578    fn has_chunked_transfer_encoding(&self) -> anyhow::Result<bool> {
1579        let Some(raw_val) = self.get_header_key(HeaderItem::Transfer_Encoding) else {
1580            return Ok(false);
1581        };
1582        let codings: Vec<String> = raw_val
1583            .split(',')
1584            .map(|part| part.trim().to_ascii_lowercase())
1585            .filter(|part| !part.is_empty())
1586            .collect();
1587        if codings.is_empty() {
1588            Err(Self::bad_request("empty Transfer-Encoding header"))?
1589        }
1590        if codings.len() == 1 && codings[0] == "chunked" {
1591            return Ok(true);
1592        }
1593        Err(Self::not_implemented(format!(
1594            "unsupported Transfer-Encoding: {raw_val}"
1595        )))
1596    }
1597
1598    async fn read_chunked_body(
1599        buf: &mut Vec<u8>,
1600        stream: &mut HttpStream,
1601        hdr_len: usize,
1602        allowed_trailers: &HashSet<String>,
1603    ) -> anyhow::Result<(
1604        LocalHipByt<'static>,
1605        HashMap<HeaderOrHipStr, LocalHipStr<'static>>,
1606        usize,
1607    )> {
1608        let mut cursor = hdr_len;
1609        let mut body = Vec::new();
1610        let mut trailers = HashMap::with_capacity(4);
1611        let mut tmp_buf = [0u8; 4096];
1612
1613        loop {
1614            let line_end = loop {
1615                if let Some(pos) = buf[cursor..].windows(2).position(|part| part == b"\r\n") {
1616                    break cursor + pos;
1617                }
1618                let n = stream.read(&mut tmp_buf).await?;
1619                if n == 0 {
1620                    Err(anyhow::Error::msg("connection closed"))?;
1621                }
1622                buf.extend(&tmp_buf[..n]);
1623            };
1624
1625            let chunk_size = {
1626                let size_line = str::from_utf8(&buf[cursor..line_end])?;
1627                let size_token = size_line
1628                    .split_once(';')
1629                    .map_or(size_line, |(size, _)| size)
1630                    .trim();
1631                if size_token.is_empty() {
1632                    Err(anyhow!("invalid chunk size"))?;
1633                }
1634                usize::from_str_radix(size_token, 16)?
1635            };
1636            cursor = line_end + 2;
1637
1638            if chunk_size == 0 {
1639                let trailer_end = loop {
1640                    let line_start = cursor;
1641                    let line_end = loop {
1642                        if let Some(pos) = buf[cursor..].windows(2).position(|part| part == b"\r\n")
1643                        {
1644                            break cursor + pos;
1645                        }
1646                        let n = stream.read(&mut tmp_buf).await?;
1647                        if n == 0 {
1648                            Err(anyhow::Error::msg("connection closed"))?;
1649                        }
1650                        buf.extend(&tmp_buf[..n]);
1651                    };
1652                    cursor = line_end + 2;
1653                    if line_end == line_start {
1654                        break cursor;
1655                    }
1656
1657                    let (name, value) = parse_trailer_line(&buf[line_start..line_end])?;
1658                    let name_lower = name.to_ascii_lowercase();
1659                    if is_forbidden_trailer_field(&name_lower) {
1660                        Err(anyhow!("forbidden trailer field: {name}"))?;
1661                    }
1662                    if !allowed_trailers.contains(&name_lower) {
1663                        Err(anyhow!("unexpected trailer field: {name}"))?;
1664                    }
1665                    trailers.insert(HeaderOrHipStr::from_str(&name), value.into());
1666                };
1667                cursor = trailer_end;
1668                break;
1669            }
1670
1671            while buf.len() < cursor + chunk_size + 2 {
1672                let n = stream.read(&mut tmp_buf).await?;
1673                if n == 0 {
1674                    Err(anyhow::Error::msg("connection closed"))?;
1675                }
1676                buf.extend(&tmp_buf[..n]);
1677            }
1678            body.extend_from_slice(&buf[cursor..cursor + chunk_size]);
1679            if &buf[cursor + chunk_size..cursor + chunk_size + 2] != b"\r\n" {
1680                Err(anyhow!("invalid chunk terminator"))?;
1681            }
1682            cursor += chunk_size + 2;
1683        }
1684
1685        Ok((LocalHipByt::from(body), trailers, cursor - hdr_len))
1686    }
1687
1688    pub fn get_header_content_type<'a>(&'a self) -> Option<HttpContentType<'a>> {
1689        HttpContentType::from_str(self.get_header_key(HeaderItem::Content_Type).unwrap_or(""))
1690    }
1691
1692    pub fn is_websocket(&self) -> bool {
1693        if self.method != HttpMethod::GET {
1694            return false;
1695        }
1696        if self.get_header_connection() != HttpConnection::Upgrade {
1697            return false;
1698        }
1699        if self
1700            .get_header_key(HeaderItem::Upgrade)
1701            .is_some_and(|val| val.to_lowercase() != "websocket")
1702        {
1703            return false;
1704        }
1705        if self
1706            .get_header("Sec-WebSocket-Version")
1707            .is_some_and(|val| val != "13")
1708        {
1709            return false;
1710        }
1711        if self
1712            .get_header("Sec-WebSocket-Key")
1713            .is_some_and(|val| val.is_empty())
1714        {
1715            return false;
1716        }
1717        true
1718    }
1719
1720    pub async fn upgrade_websocket(&mut self) -> anyhow::Result<Websocket> {
1721        if !self.is_websocket() {
1722            Err(anyhow!("it is not a websocket request"))?;
1723        }
1724        let ws_key = self
1725            .get_header("Sec-WebSocket-Key")
1726            .unwrap_or("")
1727            .to_string();
1728        // let ws_ext = req.get_header("Sec-WebSocket-Extensions").unwrap_or("".to_string());
1729        let stream = match self.remove_ext::<Mutex<HttpStream>>() {
1730            Some(stream) => stream,
1731            None => Err(anyhow!("connot get stream"))?,
1732        };
1733        {
1734            let mut stream = stream.lock().await;
1735            let res = HttpResponse::from_websocket(&ws_key);
1736            stream.write_all(&res.as_bytes(CompressMode::None)).await?;
1737        }
1738        Ok(Websocket { stream })
1739    }
1740
1741    pub async fn get_client_addr(&self) -> anyhow::Result<SocketAddr> {
1742        if let Some(addr) = self.client_addr {
1743            return Ok(addr);
1744        }
1745        match self.get_ext::<SocketAddr>() {
1746            Some(addr) => Ok(*addr),
1747            None => Err(anyhow!("no addr info")),
1748        }
1749    }
1750
1751    async fn from_stream_impl(
1752        buf: &mut Vec<u8>,
1753        stream: &mut HttpStream,
1754    ) -> anyhow::Result<(Self, usize)> {
1755        let mut tmp_buf = [0u8; 4096];
1756        let (mut req, hdr_len) = loop {
1757            match HttpRequest::from_headers_part(&buf[..])? {
1758                Some((req, hdr_len)) => break (req, hdr_len),
1759                None => {
1760                    let n = stream.read(&mut tmp_buf).await?;
1761                    if n == 0 {
1762                        return Err(anyhow::Error::msg("connection closed"));
1763                    }
1764                    buf.extend(&tmp_buf[0..n]);
1765                }
1766            }
1767        };
1768        let has_chunked_transfer_encoding = req.has_chunked_transfer_encoding()?;
1769        let mut content_length = 0usize;
1770        if has_chunked_transfer_encoding {
1771            if req.get_header_key(HeaderItem::Content_Length).is_some() {
1772                Err(Self::bad_request(
1773                    "conflicting headers: Transfer-Encoding and Content-Length",
1774                ))?;
1775            }
1776        } else {
1777            content_length = req
1778                .parse_header_content_length()
1779                .map_err(|err| Self::bad_request(err.to_string()))?
1780                .unwrap_or(0);
1781        }
1782
1783        let has_request_body = has_chunked_transfer_encoding || content_length > 0;
1784        req.process_expect_header(stream, has_request_body).await?;
1785
1786        let bdy_len;
1787        if has_chunked_transfer_encoding {
1788            let allowed_trailers =
1789                parse_declared_trailer_names(req.get_header_key(HeaderItem::Trailer));
1790            let (body, trailers, consumed_len) =
1791                Self::read_chunked_body(buf, stream, hdr_len, &allowed_trailers)
1792                    .await
1793                    .map_err(|err| Self::bad_request(err.to_string()))?;
1794            req.body = body;
1795            req.trailers = trailers;
1796            bdy_len = consumed_len;
1797        } else {
1798            while hdr_len + content_length > buf.len() {
1799                let t = stream.read(&mut tmp_buf).await?;
1800                if t == 0 {
1801                    return Err(anyhow::Error::msg("connection closed"));
1802                }
1803                buf.extend(&tmp_buf[0..t]);
1804            }
1805            if content_length > 0 {
1806                req.body = LocalHipByt::from(&buf[hdr_len..hdr_len + content_length]);
1807            }
1808            bdy_len = content_length;
1809        }
1810
1811        // 先获取Content-Type的字符串值,避免借用冲突
1812        let content_type_str = {
1813            req.get_header_key(HeaderItem::Content_Type)
1814                .map(|s| s.to_string())
1815        };
1816
1817        // 根据内容类型字符串处理请求体
1818        if let Some(content_type_str) = content_type_str {
1819            // 解析内容类型
1820            let content_type_parsed = HttpContentType::from_str(&content_type_str);
1821
1822            match content_type_parsed {
1823                Some(HttpContentType::ApplicationJson) => {
1824                    if let Ok(body_str) = std::str::from_utf8(&req.body) {
1825                        if let Ok(serde_json::Value::Object(obj)) =
1826                            serde_json::from_str::<serde_json::Value>(body_str)
1827                        {
1828                            for (k, v) in obj {
1829                                req.body_pairs
1830                                    .insert(LocalHipStr::from(k), LocalHipStr::from(v.to_string()));
1831                            }
1832                        }
1833                    }
1834                }
1835                Some(HttpContentType::ApplicationXWwwFormUrlencoded) => {
1836                    if let Ok(body_str) = std::str::from_utf8(&req.body) {
1837                        body_str.split('&').for_each(|s| {
1838                            if let Some((a, b)) = s.split_once('=') {
1839                                req.body_pairs
1840                                    .insert(a.url_decode().into(), b.url_decode().into());
1841                            }
1842                        });
1843                    }
1844                }
1845                Some(HttpContentType::MultipartFormData(boundary)) => {
1846                    if let Ok(body_str) = std::str::from_utf8(&req.body) {
1847                        let split_str = ssformat!(64, "--{boundary}");
1848                        for mut s in body_str.split(split_str.as_str()) {
1849                            if s == "--" {
1850                                break;
1851                            }
1852                            if s.ends_with("\r\n") {
1853                                s = &s[..s.len() - 2];
1854                            }
1855                            if let Some((key_str, content)) = s.split_once("\r\n\r\n") {
1856                                let keys: Vec<&str> = key_str
1857                                    .split("\r\n")
1858                                    .map(|p| p.split(";").collect::<Vec<_>>())
1859                                    .collect::<Vec<_>>()
1860                                    .into_iter()
1861                                    .flatten()
1862                                    .collect();
1863                                let mut name = None;
1864                                let mut filename = None;
1865                                for key in keys.into_iter() {
1866                                    if let Some((k, v)) = key.trim().split_once('=') {
1867                                        if k == "name" {
1868                                            name = Some(LocalHipStr::from(&v[1..v.len() - 1]));
1869                                        } else if k == "filename" {
1870                                            filename = Some(LocalHipStr::from(&v[1..v.len() - 1]));
1871                                        }
1872                                    }
1873                                }
1874                                if let Some(name) = name {
1875                                    if let Some(filename) = filename {
1876                                        let data = LocalHipByt::from(content.as_bytes());
1877                                        req.body_files.insert(name, PostFile { filename, data });
1878                                    } else {
1879                                        req.body_pairs.insert(name, LocalHipStr::from(content));
1880                                    }
1881                                }
1882                            }
1883                        }
1884                    }
1885                }
1886                None => {}
1887            }
1888        }
1889        Ok((req, hdr_len + bdy_len))
1890    }
1891
1892    pub async fn from_stream(
1893        buf: &mut Vec<u8>,
1894        stream: Arc<Mutex<HttpStream>>,
1895    ) -> anyhow::Result<(Self, usize)> {
1896        let mut stream = stream.lock().await;
1897        Self::from_stream_impl(buf, &mut stream).await
1898    }
1899
    /// Parses the request line and header section at the start of `buf` into an `HttpRequest`.
    ///
    /// Only the head is parsed; no body bytes are read or stored here.
    ///
    /// # Returns
    /// - `Ok(None)` when httparse reports the header section as incomplete (more bytes needed).
    /// - `Ok(Some((req, n)))` on success, where `n` is the byte count httparse reports for the
    ///   request line plus headers.
    ///
    /// # Errors
    /// - the header section exceeds `ServerConfig::get_max_header_bytes()`
    ///   (checked by `ensure_header_section_size`);
    /// - more than `ServerConfig::get_max_header_count()` headers
    ///   (`request_header_fields_too_large`);
    /// - any other httparse syntax error;
    /// - a method outside the supported list below (`not_implemented`);
    /// - a header value that is not valid UTF-8;
    /// - an empty `Content-Length`, or repeated `Content-Length` headers with different values;
    /// - more than one `Host`, an empty or unparsable `Host`, or (for HTTP/1.1) no `Host`
    ///   at all (`bad_request`).
    pub fn from_headers_part(buf: &[u8]) -> anyhow::Result<Option<(Self, usize)>> {
        let max_header_count = ServerConfig::get_max_header_count();
        let max_header_bytes = ServerConfig::get_max_header_bytes();
        Self::ensure_header_section_size(buf, max_header_bytes)?;

        // One slot per allowed header; running out of slots surfaces as TooManyHeaders below.
        let mut headers = vec![httparse::EMPTY_HEADER; max_header_count];
        let (rreq, n) = {
            let mut req = httparse::Request::new(&mut headers);
            let n = match httparse::ParserConfig::default().parse_request(&mut req, buf) {
                Ok(httparse::Status::Complete(n)) => n,
                // Head not fully received yet; caller should read more and retry.
                Ok(httparse::Status::Partial) => return Ok(None),
                Err(httparse::Error::TooManyHeaders) => {
                    Err(Self::request_header_fields_too_large(format!(
                        "too many request headers: exceeds configured limit {max_header_count}"
                    )))?
                }
                Err(err) => Err(anyhow!(err))?,
            };
            (req, n)
        };
        // NOTE(review): `headers` only has `max_header_count` slots, so this count should never
        // exceed the limit; looks like a defensive re-check — confirm before relying on it.
        let parsed_header_count = rreq.headers.iter().filter(|h| !h.name.is_empty()).count();
        if parsed_header_count > max_header_count {
            Err(Self::request_header_fields_too_large(format!(
                "too many request headers: {parsed_header_count} exceeds {max_header_count}"
            )))?;
        }
        let mut req = HttpRequest::new();
        // First Content-Length value seen, so later repeats can be checked for consistency.
        let mut content_length_seen: Option<String> = None;
        let mut host_header_count = 0usize;
        let mut has_valid_host = false;
        // Map the method token to `HttpMethod`; anything else is rejected as not implemented.
        req.method = {
            let method = rreq
                .method
                .ok_or_else(|| anyhow!("Missing HTTP method in request"))?;
            match method.len() {
                3 if method == "GET" => HttpMethod::GET,
                3 if method == "PUT" => HttpMethod::PUT,
                4 if method == "COPY" => HttpMethod::COPY,
                4 if method == "HEAD" => HttpMethod::HEAD,
                4 if method == "LOCK" => HttpMethod::LOCK,
                4 if method == "MOVE" => HttpMethod::MOVE,
                4 if method == "POST" => HttpMethod::POST,
                5 if method == "MKCOL" => HttpMethod::MKCOL,
                5 if method == "PATCH" => HttpMethod::PATCH,
                5 if method == "TRACE" => HttpMethod::TRACE,
                6 if method == "DELETE" => HttpMethod::DELETE,
                6 if method == "UNLOCK" => HttpMethod::UNLOCK,
                7 if method == "OPTIONS" => HttpMethod::OPTIONS,
                7 if method == "CONNECT" => HttpMethod::CONNECT,
                8 if method == "PROPFIND" => HttpMethod::PROPFIND,
                9 if method == "PROPPATCH" => HttpMethod::PROPPATCH,
                _ => Err(Self::not_implemented(format!(
                    "unsupported method: {method}"
                )))?,
            }
        };
        let target = rreq
            .path
            .ok_or_else(|| anyhow!("Missing HTTP path in request"))?;
        req.parse_request_target(target)?;
        // httparse reports the minor version (0 or 1); stored as 10 / 11, defaulting to 11.
        req.version = rreq.version.unwrap_or(1) + 10;
        for h in rreq.headers.iter() {
            // Stop at the first unused slot (empty name).
            if h.name.is_empty() {
                break;
            }
            // Values must be valid UTF-8; surrounding whitespace is trimmed before storing.
            let header_value = str::from_utf8(h.value)?;
            let normalized_header_value = header_value.trim();
            if h.name.eq_ignore_ascii_case("Content-Length") {
                // Only emptiness and consistency are validated here; the numeric value is not
                // parsed in this function.
                let cl = normalized_header_value;
                if cl.is_empty() {
                    Err(anyhow!("empty Content-Length header"))?;
                }
                if let Some(prev) = &content_length_seen {
                    if prev != cl {
                        Err(anyhow!("conflicting duplicate Content-Length headers"))?;
                    }
                } else {
                    content_length_seen = Some(cl.to_string());
                }
            }
            if h.name.eq_ignore_ascii_case("Host") {
                // Reject repeated, empty, or non-authority Host values.
                host_header_count += 1;
                if host_header_count > 1 {
                    Err(Self::bad_request("multiple Host headers are not allowed"))?;
                }
                if normalized_header_value.is_empty() {
                    Err(Self::bad_request("empty Host header"))?;
                }
                if http::uri::Authority::from_str(normalized_header_value).is_err() {
                    Err(Self::bad_request("invalid Host header"))?;
                }
                has_valid_host = true;
            }
            if h.name.eq_ignore_ascii_case("Expect") {
                // Repeated Expect headers are merged into one comma-separated value.
                let expect_key: HeaderOrHipStr = HeaderItem::Expect.into();
                if let Some(existing) = req.headers.get(&expect_key) {
                    req.headers.insert(
                        expect_key,
                        LocalHipStr::from(format!(
                            "{}, {normalized_header_value}",
                            existing.as_str()
                        )),
                    );
                } else {
                    req.headers
                        .insert(expect_key, LocalHipStr::from(normalized_header_value));
                }
            } else {
                // Any other header is inserted as-is; a repeated name presumably overwrites the
                // earlier value (depends on the `headers` map's insert semantics — confirm).
                req.headers.insert(
                    HeaderOrHipStr::from_str(h.name),
                    LocalHipStr::from(normalized_header_value),
                );
            }
        }
        // HTTP/1.1 requests must carry a Host header; HTTP/1.0 requests may omit it.
        if req.version >= 11 && !has_valid_host {
            Err(Self::bad_request("missing required Host header"))?;
        }
        Ok(Some((req, n)))
    }
2019
2020    /// Check HTTP conditional preflight headers to determine if special status codes should be returned
2021    ///
2022    /// This method handles the following HTTP conditional headers:
2023    /// - If-Modified-Since: Return 304 if resource is not modified
2024    /// - If-None-Match: Return 304 if ETag matches
2025    /// - If-Match: Return 412 if ETag doesn't match
2026    /// - If-Unmodified-Since: Return 412 if resource is modified
2027    ///
2028    /// # Parameters
2029    /// - `meta`: File metadata (optional)
2030    /// - `etag`: Resource's ETag value (optional)
2031    ///
2032    /// # Return Values
2033    /// - `PreflightResult::Proceed`: Pass preflight check, can continue processing
2034    /// - `PreflightResult::NotModified`: Should return 304 status code
2035    /// - `PreflightResult::PreconditionFailed`: Should return 412 status code
2036    pub fn check_precondition_headers(
2037        &self,
2038        meta: Option<&Metadata>,
2039        etag: Option<&str>,
2040    ) -> PreflightResult {
2041        use crate::utils::refstr::HeaderItem;
2042        use std::time::UNIX_EPOCH;
2043
2044        let is_get_or_head = matches!(self.method, HttpMethod::GET | HttpMethod::HEAD);
2045
2046        // Get file's last modified time (Unix timestamp in seconds)
2047        let last_modified_timestamp = meta
2048            .and_then(|m| m.modified().ok())
2049            .and_then(|t| t.duration_since(UNIX_EPOCH).ok())
2050            .map(|d| d.as_secs());
2051
2052        // Check If-Match header (if exists and doesn't match, return 412)
2053        if let Some(if_match) = self.get_header_key(HeaderItem::If_Match) {
2054            if if_match != "*" {
2055                if let Some(current_etag) = etag {
2056                    // Parse ETag list in If-Match
2057                    let match_found = if_match
2058                        .split(',')
2059                        .map(|s| s.trim())
2060                        .any(|expected_etag| expected_etag == current_etag);
2061
2062                    if !match_found {
2063                        return PreflightResult::PreconditionFailed;
2064                    }
2065                } else {
2066                    // No ETag but client requires match, return 412
2067                    return PreflightResult::PreconditionFailed;
2068                }
2069            }
2070        }
2071
2072        // Check If-Unmodified-Since header (if resource is modified, return 412)
2073        if let Some(if_unmodified_since) = self.get_header_key(HeaderItem::If_Unmodified_Since) {
2074            if let Some(last_modified) = last_modified_timestamp {
2075                if let Ok(since_timestamp) = parse_http_date(if_unmodified_since) {
2076                    if last_modified > since_timestamp {
2077                        return PreflightResult::PreconditionFailed;
2078                    }
2079                }
2080            }
2081        }
2082
2083        // Check If-None-Match header
2084        if let Some(if_none_match) = self.get_header_key(HeaderItem::If_None_Match) {
2085            if if_none_match == "*" {
2086                // If resource exists, GET/HEAD -> 304, others -> 412.
2087                if etag.is_some() || meta.is_some() {
2088                    return if is_get_or_head {
2089                        PreflightResult::NotModified
2090                    } else {
2091                        PreflightResult::PreconditionFailed
2092                    };
2093                }
2094            } else if let Some(current_etag) = etag {
2095                // Check if ETag is in If-None-Match list.
2096                let match_found = if_none_match
2097                    .split(',')
2098                    .map(|s| s.trim())
2099                    .any(|expected_etag| expected_etag == current_etag);
2100
2101                if match_found {
2102                    return if is_get_or_head {
2103                        PreflightResult::NotModified
2104                    } else {
2105                        PreflightResult::PreconditionFailed
2106                    };
2107                }
2108            }
2109        }
2110
2111        // Check If-Modified-Since header (if resource is not modified, return 304)
2112        // Note: only applies to GET/HEAD and only when there's no If-None-Match header.
2113        if is_get_or_head && self.get_header_key(HeaderItem::If_None_Match).is_none() {
2114            if let Some(if_modified_since) = self.get_header_key(HeaderItem::If_Modified_Since) {
2115                if let Some(last_modified) = last_modified_timestamp {
2116                    if let Ok(since_timestamp) = parse_http_date(if_modified_since) {
2117                        if last_modified <= since_timestamp {
2118                            return PreflightResult::NotModified;
2119                        }
2120                    }
2121                }
2122            }
2123        }
2124
2125        PreflightResult::Proceed
2126    }
2127
2128    pub fn as_bytes(&self) -> Vec<u8> {
2129        let use_chunked = self
2130            .get_header_key(HeaderItem::Transfer_Encoding)
2131            .map(|encodings| {
2132                encodings
2133                    .split(',')
2134                    .map(|coding| coding.trim())
2135                    .any(|coding| coding.eq_ignore_ascii_case("chunked"))
2136            })
2137            .unwrap_or(false);
2138
2139        let declared_trailer_names =
2140            parse_declared_trailer_names(self.get_header_key(HeaderItem::Trailer));
2141        let mut outbound_trailers: Vec<(String, String)> = Vec::with_capacity(self.trailers.len());
2142        for (key, value) in self.trailers.iter() {
2143            let key_str = key.to_str();
2144            let lower = key_str.to_ascii_lowercase();
2145            if is_forbidden_trailer_field(&lower) {
2146                continue;
2147            }
2148            if !declared_trailer_names.is_empty() && !declared_trailer_names.contains(&lower) {
2149                continue;
2150            }
2151            outbound_trailers.push((key_str.to_string(), value.to_string()));
2152        }
2153
2154        let mut req_str = format!(
2155            "{} {} {}\r\n",
2156            self.method,
2157            self.request_target(),
2158            self.request_version_line()
2159        );
2160        for (k, v) in self.headers.iter() {
2161            if let HeaderOrHipStr::HeaderItem(HeaderItem::Content_Length) = k {
2162                continue;
2163            }
2164            req_str.push_str(&format!("{}: {v}\r\n", k.to_str()));
2165        }
2166        if use_chunked {
2167            if declared_trailer_names.is_empty() && !outbound_trailers.is_empty() {
2168                let trailer_names = outbound_trailers
2169                    .iter()
2170                    .map(|(name, _)| name.as_str())
2171                    .collect::<Vec<_>>()
2172                    .join(", ");
2173                req_str.push_str(&format!("Trailer: {trailer_names}\r\n"));
2174            }
2175            req_str.push_str("\r\n");
2176            let mut ret = req_str.as_bytes().to_vec();
2177            if self.body.is_empty() {
2178                if outbound_trailers.is_empty() {
2179                    ret.extend_from_slice(b"0\r\n\r\n");
2180                } else {
2181                    ret.extend_from_slice(b"0\r\n");
2182                    for (name, value) in outbound_trailers.iter() {
2183                        ret.extend_from_slice(format!("{name}: {value}\r\n").as_bytes());
2184                    }
2185                    ret.extend_from_slice(b"\r\n");
2186                }
2187            } else {
2188                ret.extend_from_slice(format!("{:x}\r\n", self.body.len()).as_bytes());
2189                ret.extend(&self.body[..]);
2190                if outbound_trailers.is_empty() {
2191                    ret.extend_from_slice(b"\r\n0\r\n\r\n");
2192                } else {
2193                    ret.extend_from_slice(b"\r\n0\r\n");
2194                    for (name, value) in outbound_trailers.iter() {
2195                        ret.extend_from_slice(format!("{name}: {value}\r\n").as_bytes());
2196                    }
2197                    ret.extend_from_slice(b"\r\n");
2198                }
2199            }
2200            ret
2201        } else {
2202            req_str.push_str(&format!(
2203                "{}: {}\r\n",
2204                HeaderItem::Content_Length.to_str(),
2205                self.body.len()
2206            ));
2207            req_str.push_str("\r\n");
2208            let mut ret = req_str.as_bytes().to_vec();
2209            ret.extend(&self.body[..]);
2210            ret
2211        }
2212    }
2213}
2214
/// Payload of an [`HttpResponse`].
#[derive(Debug)]
pub enum HttpResponseBody {
    /// The complete body, already held in memory.
    Data(Vec<u8>),
    /// A body delivered incrementally as byte chunks received from a channel
    /// (used by `HttpResponse::chunked` / `HttpResponse::sse`).
    Stream(Receiver<Vec<u8>>),
}
2220
/// Chunk-by-chunk reader over an [`HttpResponseBody`], created by
/// `HttpResponseBody::stream_data`.
pub struct HttpResponseBodyStream<'a> {
    // The body being read; borrowed mutably so a `Stream` receiver can be polled.
    body: &'a mut HttpResponseBody,
    // For `Data` bodies: whether the single in-memory chunk has already been yielded.
    data_consumed: bool,
}
2225
2226impl HttpResponseBody {
2227    pub async fn data(&mut self) -> &[u8] {
2228        // 先检查是否是 Stream 类型
2229        let is_stream = matches!(self, HttpResponseBody::Stream(_));
2230
2231        if is_stream {
2232            // 从 Stream 中替换出 rx,接收所有数据
2233            if let HttpResponseBody::Stream(mut rx) =
2234                std::mem::replace(self, HttpResponseBody::Data(vec![]))
2235            {
2236                let mut data = Vec::with_capacity(1024);
2237                while let Some(chunk) = rx.recv().await {
2238                    data.extend_from_slice(&chunk);
2239                }
2240                *self = HttpResponseBody::Data(data);
2241            }
2242        }
2243
2244        match self {
2245            HttpResponseBody::Data(data) => data.as_slice(),
2246            HttpResponseBody::Stream(_) => &[], // Should not reach here
2247        }
2248    }
2249
2250    pub fn stream_data(&mut self) -> HttpResponseBodyStream<'_> {
2251        HttpResponseBodyStream {
2252            body: self,
2253            data_consumed: false,
2254        }
2255    }
2256}
2257
2258impl HttpResponseBodyStream<'_> {
2259    pub async fn next(&mut self) -> Option<Vec<u8>> {
2260        match self.body {
2261            HttpResponseBody::Data(data) => {
2262                if self.data_consumed {
2263                    None
2264                } else {
2265                    self.data_consumed = true;
2266                    Some(data.clone())
2267                }
2268            }
2269            HttpResponseBody::Stream(rx) => rx.recv().await,
2270        }
2271    }
2272}
2273
2274#[derive(Debug)]
2275pub struct HttpResponse {
2276    pub version: String,
2277    pub http_code: u16,
2278    pub headers: HashMap<Cow<'static, str>, Cow<'static, str>>,
2279    pub trailers: HashMap<Cow<'static, str>, Cow<'static, str>>,
2280    pub body: HttpResponseBody,
2281}
2282unsafe impl Send for HttpResponse {}
2283unsafe impl Sync for HttpResponse {}
2284impl Clone for HttpResponse {
2285    fn clone(&self) -> Self {
2286        Self {
2287            version: self.version.clone(),
2288            http_code: self.http_code,
2289            headers: self.headers.clone(),
2290            trailers: self.trailers.clone(),
2291            body: match &self.body {
2292                HttpResponseBody::Data(data) => HttpResponseBody::Data(data.clone()),
2293                HttpResponseBody::Stream(_) => HttpResponseBody::Data(vec![]), // Clone stream as empty data
2294            },
2295        }
2296    }
2297}
2298
/// Generates a constructor `$fn_name(body: impl Into<String>) -> Self` that builds a `200`
/// HTTP/1.1 response with the default headers, a `Content-Type` of `$cnt_type`, and the
/// text as its body.
macro_rules! make_resp_by_text {
    ($fn_name:ident, $cnt_type:expr) => {
        /// Builds a `200` HTTP/1.1 response carrying `body` as text, with the default headers.
        pub fn $fn_name(body: impl Into<String>) -> Self {
            let body: String = body.into();
            Self {
                version: "HTTP/1.1".into(),
                http_code: 200,
                headers: Self::default_headers($cnt_type),
                trailers: HashMap::with_capacity(4),
                // Hand over the String's buffer rather than copying it.
                body: HttpResponseBody::Data(body.into_bytes()),
            }
        }
    };
}
2313
/// Generates a constructor `$fn_name(body: &[u8]) -> Self` that builds a `200` HTTP/1.1
/// response with the default headers, a `Content-Type` of `$cnt_type`, and a copy of the
/// given bytes as its body.
macro_rules! make_resp_by_binary {
    ($fn_name:ident, $cnt_type:expr) => {
        /// Builds a `200` HTTP/1.1 response carrying a copy of `body`, with the default headers.
        pub fn $fn_name(body: &[u8]) -> Self {
            let headers = Self::default_headers($cnt_type);
            let payload = HttpResponseBody::Data(Vec::from(body));
            Self {
                version: String::from("HTTP/1.1"),
                http_code: 200,
                headers,
                trailers: HashMap::with_capacity(4),
                body: payload,
            }
        }
    };
}
2327
2328impl Default for HttpResponse {
2329    fn default() -> Self {
2330        Self::new()
2331    }
2332}
2333
2334impl HttpResponse {
    // Text constructors: each builds a `200` response with the default headers and the
    // given Content-Type, taking any `impl Into<String>` as the body.
    make_resp_by_text!(html, "text/html");
    make_resp_by_text!(css, "text/css");
    make_resp_by_text!(csv, "text/csv");
    make_resp_by_text!(js, "text/javascript");
    make_resp_by_text!(text, "text/plain");
    make_resp_by_text!(json, "application/json");
    make_resp_by_text!(xml, "application/xml");
    // Binary constructor: same shape, but copies the body from a `&[u8]`.
    make_resp_by_binary!(png, "image/png");
2343
2344    fn default_headers(
2345        cnt_type: impl Into<String>,
2346    ) -> HashMap<Cow<'static, str>, Cow<'static, str>> {
2347        let now = Utc::now();
2348        let current_ts = now.timestamp();
2349
2350        static TL_TIMESTAMP: ThreadLocal<RefCell<(i64, Cow<'static, str>)>> = ThreadLocal::new();
2351        let mut tl_timestamp = TL_TIMESTAMP.get_or_default().borrow_mut();
2352        let date_str = if current_ts == tl_timestamp.0 {
2353            tl_timestamp.1.clone()
2354        } else {
2355            let new_date: Cow<'_, str> = now.format("%a, %d %b %Y %H:%M:%S GMT").to_string().into();
2356            *tl_timestamp = (current_ts, new_date.clone());
2357            new_date
2358        };
2359
2360        [
2361            ("Date".into(), date_str),
2362            ("Server".into(), SERVER_STR.clone().into()),
2363            ("Connection".into(), "keep-alive".into()),
2364            ("Content-Type".into(), cnt_type.into().into()),
2365            ("Pragma".into(), "no-cache".into()),
2366            ("Cache-Control".into(), "no-cache".into()),
2367        ]
2368        .into()
2369    }
2370
2371    pub fn new() -> Self {
2372        Self {
2373            version: "".into(),
2374            http_code: 0,
2375            headers: HashMap::with_capacity(16),
2376            trailers: HashMap::with_capacity(4),
2377            body: HttpResponseBody::Data(vec![]),
2378        }
2379    }
2380
2381    pub fn add_header(&mut self, key: Cow<'static, str>, value: Cow<'static, str>) {
2382        self.headers.insert(key, value);
2383    }
2384
2385    pub fn not_found() -> Self {
2386        let mut ret = Self::html("404 not found");
2387        ret.http_code = 404;
2388        ret
2389    }
2390
2391    pub fn error(payload: impl Into<String>) -> Self {
2392        let mut ret = Self::html(payload);
2393        ret.http_code = 500;
2394        ret
2395    }
2396
2397    pub fn empty() -> Self {
2398        Self::html("")
2399    }
2400
2401    pub fn chunked(rx: Receiver<Vec<u8>>) -> Self {
2402        Self {
2403            version: "HTTP/1.1".into(),
2404            http_code: 200,
2405            headers: [
2406                ("Transfer-Encoding".into(), "chunked".into()),
2407                ("Content-Type".into(), "application/octet-stream".into()),
2408                ("Cache-Control".into(), "no-cache".into()),
2409                ("Connection".into(), "keep-alive".into()),
2410            ]
2411            .into(),
2412            trailers: HashMap::with_capacity(4),
2413            body: HttpResponseBody::Stream(rx),
2414        }
2415    }
2416
2417    /// Create a SSE response with chunked transfer encoding
2418    pub fn sse(rx: Receiver<Vec<u8>>) -> Self {
2419        let mut res = Self::chunked(rx);
2420        res.add_header("Content-Type".into(), "text/event-stream".into());
2421        res
2422    }
2423
2424    pub fn from_file(path: &str, download: bool, meta: Option<Metadata>) -> Self {
2425        let mut buffer = vec![];
2426        if let Ok(mut file) = File::open(path) {
2427            _ = file.read_to_end(&mut buffer);
2428        }
2429        Self::from_mem_file(path, buffer, download, meta)
2430    }
2431
2432    pub fn from_mem_file(
2433        path: &str,
2434        data: Vec<u8>,
2435        download: bool,
2436        meta: Option<Metadata>,
2437    ) -> Self {
2438        if let Some(meta) = meta {
2439            let mut ret = Self::from_mem_file(path, data, download, None);
2440            // Add Last-Modified header
2441            if let Ok(modified) = meta.modified() {
2442                if let Ok(duration) = modified.duration_since(UNIX_EPOCH) {
2443                    let modified_time =
2444                        chrono::DateTime::<chrono::Utc>::from(UNIX_EPOCH + duration);
2445                    ret.add_header(
2446                        "Last-Modified".into(),
2447                        modified_time
2448                            .format("%a, %d %b %Y %H:%M:%S GMT")
2449                            .to_string()
2450                            .into(),
2451                    );
2452                }
2453            }
2454
2455            // Add ETag header (format: "hex-file-modified-time-hex-file-size")
2456            if let Ok(modified) = meta.modified() {
2457                if let Ok(duration) = modified.duration_since(UNIX_EPOCH) {
2458                    let modified_secs = duration.as_secs();
2459                    let file_size = meta.len();
2460                    let etag = format!("\"{:x}-{:x}\"", modified_secs, file_size);
2461                    ret.add_header("ETag".into(), etag.into());
2462                }
2463            }
2464            ret
2465        } else {
2466            let mut ret = Self::empty();
2467            let mime_type = match path.split('.').next_back() {
2468                Some("css") => "text/css",
2469                Some("csv") => "text/csv",
2470                Some("htm") => "text/html",
2471                Some("html") => "text/html",
2472                Some("js") => "application/javascript",
2473                Some("json") => "application/json",
2474                Some("pdf") => "application/pdf",
2475                Some("xml") => "application/xml",
2476                _ if path.ends_with('/') => "text/html",
2477                _ => "application/octet-stream",
2478            };
2479            ret.add_header("Content-Type".into(), mime_type.into());
2480            if download {
2481                let file = match path.rfind('/') {
2482                    Some(p) => &path[p + 1..],
2483                    None => path,
2484                };
2485                if !file.is_empty() {
2486                    ret.add_header(
2487                        "Content-Disposition".into(),
2488                        format!("attachment; filename={file}").into(),
2489                    );
2490                }
2491            }
2492            ret.body = HttpResponseBody::Data(data);
2493            ret
2494        }
2495    }
2496
2497    pub fn from_websocket(ws_key: &str) -> Self {
2498        #[allow(deprecated)]
2499        let ws_accept = {
2500            let mut sha1 = Sha1::default();
2501            sha1.update(ws_key);
2502            sha1.update(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
2503            base64::encode(sha1.finalize())
2504        };
2505        Self {
2506            version: "HTTP/1.1".into(),
2507            http_code: 101,
2508            headers: [
2509                (
2510                    "Date".into(),
2511                    Utc::now()
2512                        .format("%a, %d %b %Y %H:%M:%S GMT")
2513                        .to_string()
2514                        .into(),
2515                ),
2516                ("Server".into(), SERVER_STR.clone().into()),
2517                ("Connection".into(), "Upgrade".into()),
2518                ("Upgrade".into(), "websocket".into()),
2519                ("Sec-WebSocket-Accept".into(), ws_accept.into()),
2520            ]
2521            .into(),
2522            trailers: HashMap::with_capacity(4),
2523            body: HttpResponseBody::Data(vec![]),
2524        }
2525    }
2526
2527    pub fn add_trailer(&mut self, key: Cow<'static, str>, value: Cow<'static, str>) {
2528        self.trailers.insert(key, value);
2529    }
2530
2531    pub fn get_trailer(&self, key: &str) -> Option<&str> {
2532        self.trailers.get(key).map(|v| v.as_ref())
2533    }
2534
2535    fn status_disallows_response_body(status: u16) -> bool {
2536        (100..200).contains(&status) || status == 204 || status == 304
2537    }
2538
2539    fn method_disallows_response_body(request_method: Option<HttpMethod>) -> bool {
2540        request_method == Some(HttpMethod::HEAD)
2541    }
2542
2543    fn transfer_encoding_has_chunked(raw: &str) -> anyhow::Result<bool> {
2544        let codings = parse_transfer_encoding_tokens(raw)?;
2545        if codings.iter().any(|coding| coding == "chunked") {
2546            if codings.last().is_some_and(|coding| coding == "chunked") {
2547                return Ok(true);
2548            }
2549            Err(anyhow!(
2550                "invalid Transfer-Encoding order: chunked must be final coding"
2551            ))?;
2552        }
2553        Err(anyhow!("unsupported Transfer-Encoding: {raw}"))
2554    }
2555
2556    async fn read_chunked_body(
2557        buf: &mut Vec<u8>,
2558        stream: &mut HttpStream,
2559        hdr_len: usize,
2560        allowed_trailers: &HashSet<String>,
2561    ) -> anyhow::Result<(
2562        Vec<u8>,
2563        HashMap<Cow<'static, str>, Cow<'static, str>>,
2564        usize,
2565    )> {
2566        let mut cursor = hdr_len;
2567        let mut body = Vec::new();
2568        let mut trailers = HashMap::with_capacity(4);
2569        let mut tmp_buf = [0u8; 4096];
2570
2571        loop {
2572            let line_end = loop {
2573                if let Some(pos) = buf[cursor..].windows(2).position(|part| part == b"\r\n") {
2574                    break cursor + pos;
2575                }
2576                let n = stream.read(&mut tmp_buf).await?;
2577                if n == 0 {
2578                    Err(anyhow::Error::msg("connection closed"))?;
2579                }
2580                buf.extend(&tmp_buf[..n]);
2581            };
2582
2583            let chunk_size = {
2584                let size_line = str::from_utf8(&buf[cursor..line_end])?;
2585                let size_token = size_line
2586                    .split_once(';')
2587                    .map_or(size_line, |(size, _)| size)
2588                    .trim();
2589                if size_token.is_empty() {
2590                    Err(anyhow!("invalid chunk size"))?;
2591                }
2592                usize::from_str_radix(size_token, 16)?
2593            };
2594            cursor = line_end + 2;
2595
2596            if chunk_size == 0 {
2597                loop {
2598                    let line_start = cursor;
2599                    let line_end = loop {
2600                        if let Some(pos) = buf[cursor..].windows(2).position(|part| part == b"\r\n")
2601                        {
2602                            break cursor + pos;
2603                        }
2604                        let n = stream.read(&mut tmp_buf).await?;
2605                        if n == 0 {
2606                            Err(anyhow::Error::msg("connection closed"))?;
2607                        }
2608                        buf.extend(&tmp_buf[..n]);
2609                    };
2610                    cursor = line_end + 2;
2611                    if line_end == line_start {
2612                        break;
2613                    }
2614
2615                    let (name, value) = parse_trailer_line(&buf[line_start..line_end])?;
2616                    let name_lower = name.to_ascii_lowercase();
2617                    if is_forbidden_trailer_field(&name_lower) {
2618                        Err(anyhow!("forbidden trailer field: {name}"))?;
2619                    }
2620                    if !allowed_trailers.contains(&name_lower) {
2621                        Err(anyhow!("unexpected trailer field: {name}"))?;
2622                    }
2623                    trailers.insert(name.into(), value.into());
2624                }
2625                break;
2626            }
2627
2628            while buf.len() < cursor + chunk_size + 2 {
2629                let n = stream.read(&mut tmp_buf).await?;
2630                if n == 0 {
2631                    Err(anyhow::Error::msg("connection closed"))?;
2632                }
2633                buf.extend(&tmp_buf[..n]);
2634            }
2635            body.extend_from_slice(&buf[cursor..cursor + chunk_size]);
2636            if &buf[cursor + chunk_size..cursor + chunk_size + 2] != b"\r\n" {
2637                Err(anyhow!("invalid chunk terminator"))?;
2638            }
2639            cursor += chunk_size + 2;
2640        }
2641
2642        Ok((body, trailers, cursor - hdr_len))
2643    }
2644
2645    pub fn as_bytes(&self, mut cmode: CompressMode) -> Vec<u8> {
2646        match &self.body {
2647            HttpResponseBody::Data(data) => {
2648                let suppress_body = Self::status_disallows_response_body(self.http_code);
2649                let mut payload_tmp: Vec<u8> = vec![];
2650                if cmode == CompressMode::Gzip
2651                    && data.len() >= 32
2652                    && self.get_header("Content-Encoding").is_none()
2653                    && !suppress_body
2654                {
2655                    if let Ok(compressed_data) = data.compress() {
2656                        payload_tmp = compressed_data;
2657                    }
2658                }
2659                let mut payload_ref = if payload_tmp.is_empty() {
2660                    cmode = CompressMode::None;
2661                    data.as_slice()
2662                } else {
2663                    payload_tmp.as_slice()
2664                };
2665                if suppress_body {
2666                    cmode = CompressMode::None;
2667                    payload_ref = &[];
2668                }
2669                //
2670                let mut ret = smallstr::SmallString::<[u8; 4096]>::new();
2671                let status_str = self.http_code.http_code_to_desp();
2672                ret.push_str(&ssformat!(
2673                    64,
2674                    "{} {} {status_str}\r\n",
2675                    self.version,
2676                    self.http_code
2677                ));
2678                for (key, value) in self.headers.iter() {
2679                    if key == "Content-Length" || key.eq_ignore_ascii_case("Transfer-Encoding") {
2680                        continue;
2681                    }
2682                    ret.push_str(&ssformat!(512, "{key}: {value}\r\n"));
2683                }
2684                if !suppress_body {
2685                    ret.push_str(&ssformat!(64, "Content-Length: {}\r\n", payload_ref.len()));
2686                }
2687                if cmode == CompressMode::Gzip && !suppress_body {
2688                    ret.push_str("Content-Encoding: gzip\r\n");
2689                }
2690                ret.push_str("\r\n");
2691                let mut ret: Vec<u8> = ret.as_bytes().to_vec();
2692                ret.extend(payload_ref);
2693                ret
2694            }
2695            HttpResponseBody::Stream(_) => vec![], // Stream responses are handled separately
2696        }
2697    }
2698
2699    /// Write response to stream, handling both Data and Stream body types
2700    pub async fn write_to_stream(
2701        &mut self,
2702        stream: &mut crate::utils::tcp_stream::HttpStream,
2703        cmode: CompressMode,
2704        request_method: Option<HttpMethod>,
2705    ) -> anyhow::Result<()> {
2706        let suppress_body_by_status = Self::status_disallows_response_body(self.http_code);
2707        let suppress_body_by_method = Self::method_disallows_response_body(request_method);
2708        let suppress_body = suppress_body_by_status || suppress_body_by_method;
2709        let no_content_encoding = self.get_header("Content-Encoding").is_none();
2710        let declared_trailer_names = parse_declared_trailer_names(self.get_header("Trailer"));
2711        let mut outbound_stream_trailers: Vec<(String, String)> =
2712            Vec::with_capacity(self.trailers.len());
2713        for (key, value) in self.trailers.iter() {
2714            let lower = key.to_ascii_lowercase();
2715            if is_forbidden_trailer_field(&lower) {
2716                continue;
2717            }
2718            if !declared_trailer_names.is_empty() && !declared_trailer_names.contains(&lower) {
2719                continue;
2720            }
2721            outbound_stream_trailers.push((key.to_string(), value.to_string()));
2722        }
2723        match &mut self.body {
2724            HttpResponseBody::Data(data) => {
2725                let mut payload_tmp: Vec<u8> = vec![];
2726                let mut cmode = cmode;
2727                if cmode == CompressMode::Gzip
2728                    && data.len() >= 32
2729                    && no_content_encoding
2730                    && !suppress_body_by_status
2731                {
2732                    if let Ok(compressed_data) = data.compress() {
2733                        payload_tmp = compressed_data;
2734                    }
2735                }
2736                let mut payload_ref = if payload_tmp.is_empty() {
2737                    cmode = CompressMode::None;
2738                    data.as_slice()
2739                } else {
2740                    payload_tmp.as_slice()
2741                };
2742                if suppress_body_by_status {
2743                    cmode = CompressMode::None;
2744                    payload_ref = &[];
2745                }
2746
2747                let mut ret = smallstr::SmallString::<[u8; 4096]>::new();
2748                let status_str = self.http_code.http_code_to_desp();
2749                ret.push_str(&ssformat!(
2750                    64,
2751                    "{} {} {status_str}\r\n",
2752                    self.version,
2753                    self.http_code
2754                ));
2755                if self.headers.len() == 6 {
2756                    if let (
2757                        Some(date),
2758                        Some(server),
2759                        Some(connection),
2760                        Some(content_type),
2761                        Some(pragma),
2762                        Some(cache_control),
2763                    ) = (
2764                        self.headers.get("Date"),
2765                        self.headers.get("Server"),
2766                        self.headers.get("Connection"),
2767                        self.headers.get("Content-Type"),
2768                        self.headers.get("Pragma"),
2769                        self.headers.get("Cache-Control"),
2770                    ) {
2771                        ret.push_str(&ssformat!(512, "Date: {date}\r\n"));
2772                        ret.push_str(&ssformat!(512, "Server: {server}\r\n"));
2773                        ret.push_str(&ssformat!(512, "Connection: {connection}\r\n"));
2774                        ret.push_str(&ssformat!(512, "Content-Type: {content_type}\r\n"));
2775                        ret.push_str(&ssformat!(512, "Pragma: {pragma}\r\n"));
2776                        ret.push_str(&ssformat!(512, "Cache-Control: {cache_control}\r\n"));
2777                    } else {
2778                        for (key, value) in self.headers.iter() {
2779                            if key == "Content-Length"
2780                                || key.eq_ignore_ascii_case("Transfer-Encoding")
2781                            {
2782                                continue;
2783                            }
2784                            ret.push_str(&ssformat!(512, "{key}: {value}\r\n"));
2785                        }
2786                    }
2787                } else {
2788                    for (key, value) in self.headers.iter() {
2789                        if key == "Content-Length" || key.eq_ignore_ascii_case("Transfer-Encoding")
2790                        {
2791                            continue;
2792                        }
2793                        ret.push_str(&ssformat!(512, "{key}: {value}\r\n"));
2794                    }
2795                }
2796                if !suppress_body_by_status {
2797                    ret.push_str(&ssformat!(64, "Content-Length: {}\r\n", payload_ref.len()));
2798                }
2799                if cmode == CompressMode::Gzip && !suppress_body_by_status {
2800                    ret.push_str("Content-Encoding: gzip\r\n");
2801                }
2802                ret.push_str("\r\n");
2803
2804                if suppress_body || payload_ref.is_empty() {
2805                    stream.write_all(ret.as_bytes()).await?;
2806                } else {
2807                    stream
2808                        .write_all_vectored2(ret.as_bytes(), payload_ref)
2809                        .await?;
2810                }
2811            }
2812            HttpResponseBody::Stream(rx) => {
2813                // For Stream body, send headers first, then chunks
2814                let mut ret = smallstr::SmallString::<[u8; 4096]>::new();
2815                let status_str = self.http_code.http_code_to_desp();
2816                ret.push_str(&ssformat!(
2817                    64,
2818                    "{} {} {status_str}\r\n",
2819                    self.version,
2820                    self.http_code
2821                ));
2822                for (key, value) in self.headers.iter() {
2823                    if key == "Content-Length"
2824                        || (suppress_body && key.eq_ignore_ascii_case("Transfer-Encoding"))
2825                    {
2826                        continue;
2827                    }
2828                    ret.push_str(&ssformat!(512, "{key}: {value}\r\n"));
2829                }
2830                if !suppress_body {
2831                    ret.push_str("Transfer-Encoding: chunked\r\n");
2832                }
2833                if declared_trailer_names.is_empty() && !outbound_stream_trailers.is_empty() {
2834                    let trailer_names = outbound_stream_trailers
2835                        .iter()
2836                        .map(|(name, _)| name.as_str())
2837                        .collect::<Vec<_>>()
2838                        .join(", ");
2839                    ret.push_str(&ssformat!(512, "Trailer: {trailer_names}\r\n"));
2840                }
2841                ret.push_str("\r\n");
2842                let header_bytes: Vec<u8> = ret.as_bytes().to_vec();
2843                stream.write_all(&header_bytes).await?;
2844
2845                if suppress_body {
2846                    return Ok(());
2847                }
2848
2849                // Send chunks
2850                while let Some(chunk) = rx.recv().await {
2851                    if chunk.is_empty() {
2852                        break;
2853                    }
2854                    // Write chunked encoding: length in hex, \r\n, data, \r\n
2855                    let chunk_len_hex = format!("{:x}\r\n", chunk.len());
2856                    stream.write_all(chunk_len_hex.as_bytes()).await?;
2857                    stream.write_all(&chunk).await?;
2858                    stream.write_all(b"\r\n").await?;
2859                }
2860                // Send final chunk (length 0)
2861                if outbound_stream_trailers.is_empty() {
2862                    stream.write_all(b"0\r\n\r\n").await?;
2863                } else {
2864                    stream.write_all(b"0\r\n").await?;
2865                    for (name, value) in outbound_stream_trailers.iter() {
2866                        stream
2867                            .write_all(format!("{name}: {value}\r\n").as_bytes())
2868                            .await?;
2869                    }
2870                    stream.write_all(b"\r\n").await?;
2871                }
2872            }
2873        }
2874        Ok(())
2875    }
2876
2877    pub async fn from_stream(
2878        buf: &mut Vec<u8>,
2879        stream: &mut HttpStream,
2880    ) -> anyhow::Result<(Self, usize)> {
2881        Self::from_stream_with_request_method(buf, stream, None).await
2882    }
2883
2884    pub async fn from_stream_with_request_method(
2885        buf: &mut Vec<u8>,
2886        stream: &mut HttpStream,
2887        request_method: Option<HttpMethod>,
2888    ) -> anyhow::Result<(Self, usize)> {
2889        let (mut res, hdr_len) = loop {
2890            match HttpResponse::from_headers_part(&buf[..])? {
2891                Some((res, hdr_len)) => break (res, hdr_len),
2892                None => {
2893                    buf.extend_by_streams(stream).await?;
2894                }
2895            }
2896        };
2897        let mut bdy_len = 0;
2898        let skip_body = request_method == Some(HttpMethod::HEAD)
2899            || Self::status_disallows_response_body(res.http_code);
2900        if skip_body {
2901            return Ok((res, hdr_len));
2902        }
2903        let has_chunked_transfer_encoding =
2904            if let Some(raw_te) = res.headers.get("Transfer-Encoding") {
2905                Self::transfer_encoding_has_chunked(raw_te)?
2906            } else {
2907                false
2908            };
2909        if has_chunked_transfer_encoding && res.headers.contains_key("Content-Length") {
2910            Err(anyhow!(
2911                "conflicting headers: Transfer-Encoding and Content-Length"
2912            ))?;
2913        }
2914        if has_chunked_transfer_encoding {
2915            let allowed_trailers = parse_declared_trailer_names(res.get_header("Trailer"));
2916            let (chunked_body, trailers, consumed_len) =
2917                Self::read_chunked_body(buf, stream, hdr_len, &allowed_trailers).await?;
2918            bdy_len = consumed_len;
2919            res.trailers = trailers;
2920            res.body = HttpResponseBody::Data(chunked_body);
2921        } else if let Some(cnt_len) = res.headers.get("Content-Length") {
2922            bdy_len = cnt_len.parse::<usize>().unwrap_or(0);
2923            while hdr_len + bdy_len > buf.len() {
2924                buf.extend_by_streams(stream).await?;
2925            }
2926            res.body = HttpResponseBody::Data(buf[hdr_len..hdr_len + bdy_len].to_vec());
2927        }
2928
2929        Ok((res, hdr_len + bdy_len))
2930    }
2931
2932    pub fn from_headers_part(buf: &[u8]) -> anyhow::Result<Option<(Self, usize)>> {
2933        let mut headers = [httparse::EMPTY_HEADER; 96];
2934        let (rres, n) = {
2935            let mut res = httparse::Response::new(&mut headers);
2936            let n = match httparse::ParserConfig::default().parse_response(&mut res, buf)? {
2937                httparse::Status::Complete(n) => n,
2938                httparse::Status::Partial => return Ok(None),
2939            };
2940            (res, n)
2941        };
2942
2943        let mut req = HttpResponse::new();
2944        let mut content_length_seen: Option<String> = None;
2945        let mut transfer_encoding_seen = false;
2946        req.version = format!("HTTP/1.{}", rres.version.unwrap_or(0));
2947        req.http_code = rres.code.unwrap_or(0);
2948        for h in rres.headers.iter() {
2949            if h.name.is_empty() {
2950                break;
2951            }
2952            let header_value = str::from_utf8(h.value)?.trim();
2953            if h.name.eq_ignore_ascii_case("Content-Length") {
2954                if header_value.is_empty() {
2955                    Err(anyhow!("empty Content-Length header"))?;
2956                }
2957                if let Some(prev) = &content_length_seen {
2958                    if prev != header_value {
2959                        Err(anyhow!("conflicting duplicate Content-Length headers"))?;
2960                    }
2961                } else {
2962                    content_length_seen = Some(header_value.to_string());
2963                    req.headers
2964                        .insert("Content-Length".into(), header_value.to_string().into());
2965                }
2966                continue;
2967            }
2968            if h.name.eq_ignore_ascii_case("Transfer-Encoding") {
2969                if header_value.is_empty() {
2970                    Err(anyhow!("empty Transfer-Encoding header"))?;
2971                }
2972                transfer_encoding_seen = true;
2973                if let Some(existing) = req.headers.get_mut("Transfer-Encoding") {
2974                    *existing = format!("{}, {header_value}", existing.as_ref()).into();
2975                } else {
2976                    req.headers
2977                        .insert("Transfer-Encoding".into(), header_value.to_string().into());
2978                }
2979                continue;
2980            }
2981            req.headers.insert(
2982                h.name.http_std_case().into(),
2983                header_value.to_string().into(),
2984            );
2985        }
2986        if transfer_encoding_seen && content_length_seen.is_some() {
2987            Err(anyhow!(
2988                "conflicting headers: Transfer-Encoding and Content-Length"
2989            ))?;
2990        }
2991        Ok(Some((req, n)))
2992    }
2993
2994    pub fn get_header(&self, key: &str) -> Option<&str> {
2995        let header_key = key.http_std_case();
2996        self.headers.get(header_key.as_str()).map(|a| a.as_ref())
2997    }
2998}
2999
3000pub fn load_embed<T: Embed>() -> HashMap<String, Cow<'static, [u8]>> {
3001    let mut ret = HashMap::with_capacity(16);
3002    for name in T::iter() {
3003        if let Some(file) = T::get(&name) {
3004            if name.ends_with("index.htm") || name.ends_with("index.html") {
3005                if let Some(path) = Path::new(&name[..]).parent() {
3006                    if let Some(path) = path.to_str() {
3007                        ret.insert(path.to_string(), file.data.clone());
3008                    }
3009                }
3010            }
3011            ret.insert(name.to_string(), file.data);
3012        }
3013    }
3014    ret
3015}
3016
/// WebTransport client connection macro.
///
/// Connects to a WebTransport server, optionally with custom request headers.
/// Expands to `WebTransport::connect(url, headers)`; the result must be awaited.
///
/// Header names can be given as identifiers (`Authorization = ...`), or as
/// string literals for names that are not valid identifiers
/// (`"X-Api-Key" = ...`). A single invocation uses one form or the other.
/// Values are converted with `to_string()`.
///
/// # Examples
///
/// ```rust,no_run
/// #[tokio::main]
/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
///     // Basic connection
///     let mut wt = potato::webtransport!("https://server.com/wt").await?;
///
///     // Connection with custom headers
///     let mut wt = potato::webtransport!(
///         "https://server.com/wt",
///         Authorization = "Bearer token"
///     ).await?;
///
///     // Header names that are not Rust identifiers
///     let mut wt = potato::webtransport!(
///         "https://server.com/wt",
///         "X-Api-Key" = "secret",
///     ).await?;
///     Ok(())
/// }
/// ```
#[cfg(feature = "http3")]
#[macro_export]
macro_rules! webtransport {
    ($url:expr $(,)?) => {
        $crate::WebTransport::connect($url, vec![])
    };
    ($url:expr, $($key:ident = $value:expr),+ $(,)?) => {
        $crate::WebTransport::connect($url, vec![
            $(
                $crate::Headers::Custom((stringify!($key).to_string(), $value.to_string())),
            )+
        ])
    };
    ($url:expr, $($key:literal = $value:expr),+ $(,)?) => {
        $crate::WebTransport::connect($url, vec![
            $(
                $crate::Headers::Custom(($key.to_string(), $value.to_string())),
            )+
        ])
    };
}
3051
#[cfg(test)]
mod tests {
    use super::{CompressMode, HttpMethod, HttpRequest, HttpResponse};

    #[test]
    fn accept_encoding_supports_simple_gzip_token() {
        let mut req = HttpRequest::new();
        req.set_header("Accept-Encoding", "gzip");
        assert_eq!(req.get_header_accept_encoding(), CompressMode::Gzip);
    }

    #[test]
    fn accept_encoding_supports_qvalue_for_gzip() {
        let mut req = HttpRequest::new();
        req.set_header("Accept-Encoding", "br;q=1, gzip;q=0.3");
        assert_eq!(req.get_header_accept_encoding(), CompressMode::Gzip);
    }

    #[test]
    fn accept_encoding_uses_wildcard_when_gzip_not_listed() {
        let mut req = HttpRequest::new();
        req.set_header("Accept-Encoding", "br;q=1, *;q=0.8");
        assert_eq!(req.get_header_accept_encoding(), CompressMode::Gzip);
    }

    #[test]
    fn accept_encoding_respects_explicit_gzip_zero_over_wildcard() {
        let mut req = HttpRequest::new();
        req.set_header("Accept-Encoding", "gzip;q=0, *;q=1");
        assert_eq!(req.get_header_accept_encoding(), CompressMode::None);
    }

    #[test]
    fn accept_encoding_ignores_invalid_qvalue() {
        let mut req = HttpRequest::new();
        req.set_header("Accept-Encoding", "gzip;q=xyz");
        assert_eq!(req.get_header_accept_encoding(), CompressMode::None);
    }

    #[test]
    fn request_parser_returns_431_for_too_many_headers() {
        let mut raw = String::from("GET / HTTP/1.1\r\nHost: example.com\r\n");
        for i in 0..64 {
            raw.push_str(&format!("X-Header-{i}: value\r\n"));
        }
        raw.push_str("\r\n");

        let err = HttpRequest::from_headers_part(raw.as_bytes()).unwrap_err();
        let res = HttpRequest::parse_error_response(&err).unwrap();
        assert_eq!(res.http_code, 431);
    }

    #[test]
    fn request_parser_returns_431_for_oversized_header_section() {
        let oversized = "a".repeat(20 * 1024);
        let raw = format!("GET / HTTP/1.1\r\nHost: example.com\r\nX-Large: {oversized}\r\n\r\n");

        let err = HttpRequest::from_headers_part(raw.as_bytes()).unwrap_err();
        let res = HttpRequest::parse_error_response(&err).unwrap();
        assert_eq!(res.http_code, 431);
    }

    #[test]
    fn request_serialization_includes_query_and_version() {
        let (mut req, _, _) =
            HttpRequest::from_url("http://example.com/search?q=rust", HttpMethod::GET).unwrap();
        req.version = 10;
        let serialized = String::from_utf8(req.as_bytes()).unwrap();

        assert!(serialized.starts_with("GET /search?q=rust HTTP/1.0\r\n"));
    }

    #[test]
    fn request_serialization_falls_back_to_http11_for_non_h1_versions() {
        let (mut req, _, _) =
            HttpRequest::from_url("http://example.com/", HttpMethod::GET).unwrap();
        req.version = 20;
        let serialized = String::from_utf8(req.as_bytes()).unwrap();

        assert!(serialized.starts_with("GET / HTTP/1.1\r\n"));
    }

    #[test]
    fn from_url_host_header_keeps_non_default_port() {
        let (req, _, _) =
            HttpRequest::from_url("http://example.com:8080/demo", HttpMethod::GET).unwrap();
        assert_eq!(req.get_header("Host"), Some("example.com:8080"));
    }

    #[test]
    fn response_parser_rejects_conflicting_duplicate_content_length() {
        let raw = b"HTTP/1.1 200 OK\r\nContent-Length: 5\r\nContent-Length: 6\r\n\r\n";
        assert!(HttpResponse::from_headers_part(raw).is_err());
    }

    #[test]
    fn response_parser_rejects_transfer_encoding_content_length_conflict() {
        let raw = b"HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\nContent-Length: 5\r\n\r\n";
        assert!(HttpResponse::from_headers_part(raw).is_err());
    }

    #[test]
    fn response_transfer_encoding_supports_case_insensitive_lists() {
        assert!(HttpResponse::transfer_encoding_has_chunked("gzip, Chunked").unwrap());
    }

    #[test]
    fn response_transfer_encoding_rejects_chunked_not_final() {
        // `chunked` must be the last coding applied.
        assert!(HttpResponse::transfer_encoding_has_chunked("chunked, gzip").is_err());
    }

    #[test]
    fn response_transfer_encoding_rejects_unsupported_coding() {
        assert!(HttpResponse::transfer_encoding_has_chunked("gzip").is_err());
    }

    #[test]
    fn response_status_codes_without_body_are_detected() {
        for code in [100u16, 101, 199, 204, 304] {
            assert!(HttpResponse::status_disallows_response_body(code), "{code}");
        }
        for code in [200u16, 201, 205, 301, 404, 500] {
            assert!(!HttpResponse::status_disallows_response_body(code), "{code}");
        }
    }

    #[test]
    fn response_data_serialization_omits_transfer_encoding() {
        let mut res = HttpResponse::text("hello");
        res.add_header("Transfer-Encoding".into(), "chunked".into());

        let serialized = String::from_utf8(res.as_bytes(CompressMode::None)).unwrap();
        assert!(serialized.contains("Content-Length: 5\r\n"));
        assert!(!serialized.contains("Transfer-Encoding:"));
    }

    #[test]
    fn response_data_serialization_suppresses_body_for_204() {
        let mut res = HttpResponse::text("hello");
        res.http_code = 204;

        let serialized = String::from_utf8(res.as_bytes(CompressMode::None)).unwrap();
        assert!(!serialized.contains("Content-Length:"));
        assert!(!serialized.contains("hello"));
        assert!(serialized.ends_with("\r\n\r\n"));
    }
}