1use std::pin::Pin;
2use std::task::{Context, Poll};
3
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use tokio::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
6
7use crate::fastcgi::{Header, NameValuePair, Record, FastCGIRole, RecordType};
8#[cfg(feature = "web_server")]
9use http_body::Body;
10use log::trace;
11use tokio::io::AsyncWriteExt;
12
/// A FastCGI message that [`FCGIWriter::encode`] turns into one or more records.
///
/// `BeginRequest`, `AbortRequest` and `EndRequest` are each written as a single
/// fixed-size record. The stream-style variants (`Params`, `STDIN`, `STDOUT`,
/// `STDERR`, `DATA`) are written as content records followed by a zero-length
/// record of the same type. `GetValues` / `GetValuesResult` carry no request id;
/// `encode` sends them with `Record::MGMT_REQUEST_ID`.
pub enum FCGIType {
    /// Begins request `request_id`; `role` and `flags` fill the record body.
    BeginRequest {
        request_id: u16,
        role: FastCGIRole,
        flags: u8,
    },
    /// Aborts request `request_id` (header-only record).
    AbortRequest {
        request_id: u16,
    },
    /// Ends request `request_id` with the given application and protocol status.
    EndRequest {
        request_id: u16,
        app_status: u32,
        proto_status: u8,
    },
    /// Name-value pairs sent on the params stream of `request_id`.
    Params {
        request_id: u16,
        p: Vec<NameValuePair>,
    },
    /// Bytes sent on the stdin stream of `request_id`.
    STDIN {
        request_id: u16,
        data: Bytes,
    },
    /// Bytes sent on the stdout stream of `request_id`.
    STDOUT {
        request_id: u16,
        data: Bytes,
    },
    /// Bytes sent on the stderr stream of `request_id`.
    STDERR {
        request_id: u16,
        data: Bytes,
    },
    /// Bytes sent on the data stream of `request_id`.
    DATA {
        request_id: u16,
        data: Bytes,
    },
    /// Management query; `p` holds the names being asked about.
    GetValues {
        p: Vec<NameValuePair>,
    },
    /// Management answer to a `GetValues` query.
    GetValuesResult {
        p: Vec<NameValuePair>,
    },
}
54
/// Encodes FastCGI records onto a wrapped I/O object.
///
/// Reads are forwarded unchanged to the wrapped object by the `AsyncRead` /
/// `AsyncBufRead` impls for this type, so one value can serve both directions.
pub struct FCGIWriter<RW> {
    /// The wrapped I/O object all records are written to.
    io: RW,
}
59impl<RW: AsyncRead + AsyncWrite + Unpin> FCGIWriter<RW> {
60    pub fn new(io: RW) -> FCGIWriter<RW> {
61        FCGIWriter { io }
62    }
63}
64
/// Size of one outgoing record buffer in bytes: an 8-byte header slot plus the
/// largest content length that fits in the u16 length field and is a multiple
/// of 8 (0xFFF8), so full records written without padding stay 8-byte aligned.
const BUF_LEN: usize = 8 + (0xFFFF & !7);
67
68impl<R: AsyncRead + Unpin> AsyncRead for FCGIWriter<R> {
69    fn poll_read(
70        self: Pin<&mut Self>,
71        cx: &mut Context<'_>,
72        buf: &mut ReadBuf<'_>,
73    ) -> Poll<std::io::Result<()>> {
74        Pin::new(&mut self.get_mut().io).poll_read(cx, buf)
75    }
76}
77impl<R: AsyncBufRead + Unpin> AsyncBufRead for FCGIWriter<R> {
78    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
79        Pin::new(&mut self.get_mut().io).poll_fill_buf(cx)
80    }
81
82    fn consume(self: Pin<&mut Self>, amt: usize) {
83        Pin::new(&mut self.get_mut().io).consume(amt)
84    }
85}
86impl<W: AsyncWrite + Unpin> FCGIWriter<W> {
87    pub async fn shutdown(&mut self) -> std::io::Result<()> {
88        self.io.shutdown().await
89    }
90    #[inline]
91    async fn write_whole_buf<B: Buf>(&mut self, buf: &mut B) -> std::io::Result<()> {
92        while buf.has_remaining() {
93            self.io.write_buf(buf).await?;
96        }
97        Ok(())
98    }
99
100    pub async fn encode(&mut self, item: FCGIType) -> std::io::Result<()> {
101        match item {
102            FCGIType::BeginRequest {
103                request_id,
104                role,
105                flags,
106            } => {
107                let mut buf = BytesMut::with_capacity(Header::HEADER_LEN + 8);
108                Header::new(RecordType::BeginRequest, request_id, 8).write_into(&mut buf);
109                buf.put_u16(role as u16);
110                buf.put_u8(flags);
111                buf.put_slice(&[0; 5]); self.write_whole_buf(&mut buf.freeze()).await?;
113            }
114            FCGIType::AbortRequest { request_id } => {
115                let mut buf = BytesMut::with_capacity(Header::HEADER_LEN);
116                Header::new(RecordType::AbortRequest, request_id, 0).write_into(&mut buf);
117                self.write_whole_buf(&mut buf.freeze()).await?;
118            }
119            FCGIType::EndRequest {
120                request_id,
121                app_status,
122                proto_status,
123            } => {
124                let mut buf = BytesMut::with_capacity(Header::HEADER_LEN + 8);
125                Header::new(RecordType::EndRequest, request_id, 8).write_into(&mut buf);
126
127                buf.put_u32(app_status);
128                buf.put_u8(proto_status);
129                buf.put_slice(&[0; 3]); self.write_whole_buf(&mut buf.freeze()).await?;
131            }
132            FCGIType::Params { request_id, p } => {
133                self.encode_kvp(request_id, RecordType::Params, p).await?;
134            }
135            FCGIType::STDIN { request_id, data } => {
136                self.encode_data(request_id, RecordType::StdIn, data).await?;
137            }
138            FCGIType::STDOUT { request_id, data } => {
139                self.encode_data(request_id, RecordType::StdOut, data).await?;
140            }
141            FCGIType::STDERR { request_id, data } => {
142                self.encode_data(request_id, RecordType::StdErr, data).await?;
143            }
144            FCGIType::DATA { request_id, data } => {
145                self.encode_data(request_id, RecordType::Data, data).await?;
146            }
147            FCGIType::GetValues { p } => {
148                self.encode_kvp(Record::MGMT_REQUEST_ID, RecordType::GetValues, p)
149                    .await?;
150            }
151            FCGIType::GetValuesResult { p } => {
152                self.encode_kvp(Record::MGMT_REQUEST_ID, RecordType::GetValuesResult, p)
153                    .await?;
154            }
155        }
156        Ok(())
157    }
158    async fn encode_kvp(
159        &mut self,
160        request_id: u16,
161        rtype: RecordType,
162        p: Vec<NameValuePair>,
163    ) -> std::io::Result<()> {
164        let mut kvps = self.kv_stream(request_id, rtype);
165
166        for pair in p {
167            kvps.add(pair).await?;
168        }
169        kvps.flush().await?;
171        Ok(())
172    }
173    #[inline]
178    async fn append_to_stream<B>(
179        &mut self,
180        mut val: B,
181        buf: &mut BytesMut,
182        request_id: u16,
183        rtype: RecordType,
184    ) -> std::io::Result<()>
185    where
186        B: Buf,
187    {
188        while buf.len() + val.remaining() > BUF_LEN {
189            let mut part = val.take(BUF_LEN - buf.len());
190            Header::new(rtype, request_id, (BUF_LEN - Header::HEADER_LEN) as u16)
191                .write_into(&mut &mut buf[0..Header::HEADER_LEN]);
192
193            self.write_whole_buf(buf).await?;
194            self.write_whole_buf(&mut part).await?;
195            val = part.into_inner();
196
197            unsafe {
198                buf.set_len(Header::HEADER_LEN);
199            } }
201        buf.put(val);
202        Ok(())
203    }
204    #[inline]
206    async fn end_stream(
207        &mut self,
208        mut buf: BytesMut,
209        request_id: u16,
210        rtype: RecordType,
211    ) -> std::io::Result<()> {
212        if buf.len() - Header::HEADER_LEN > 0 {
213            let last_head = Header::new(rtype, request_id, (buf.len() - Header::HEADER_LEN) as u16);
215            let pad = last_head.get_padding() as usize;
216            last_head.write_into(&mut &mut buf[0..Header::HEADER_LEN]);
217            let mut buf = buf.freeze();
218            trace!("..with header: {:?}", buf);
219            let mut pad = buf.slice(0..pad);
220            self.write_whole_buf(&mut buf).await?;
221            self.write_whole_buf(&mut pad).await?;
223        }
224        let mut end = BytesMut::with_capacity(Header::HEADER_LEN);
226        Header::new(rtype, request_id, 0).write_into(&mut end);
227        self.write_whole_buf(&mut end).await?;
228        Ok(())
229    }
230    async fn encode_data(
232        &mut self,
233        request_id: u16,
234        rtype: RecordType,
235        mut data: Bytes,
236    ) -> std::io::Result<()> {
237        let mut buf = BytesMut::with_capacity(BUF_LEN);
238        unsafe {
239            buf.set_len(Header::HEADER_LEN);
240        } self.append_to_stream(&mut data, &mut buf, request_id, rtype)
242            .await?;
243        self.end_stream(buf, request_id, rtype).await?;
244        Ok(())
245    }
246    #[cfg(feature = "web_server")]
247    pub async fn data_stream<B>(
249        &mut self,
250        mut body: B,
251        request_id: u16,
252        rtype: RecordType,
253        mut len: usize,
254    ) -> std::io::Result<()>
255    where
256        B: Body + Unpin,
257    {
258        let mut buf = BytesMut::with_capacity(BUF_LEN);
259        unsafe {
260            buf.set_len(Header::HEADER_LEN);
261        } while let Some(chunk) = crate::client::connection::BodyExt::data(&mut body).await {
263            if let Ok(mut b) = chunk {
264                let val = b.copy_to_bytes(b.remaining());
266                len = len.saturating_sub(val.len());
267                self.append_to_stream(val, &mut buf, request_id, rtype)
268                    .await?;
269            }
270        }
271
272        if len > 0 {
273            let a = FCGIType::AbortRequest { request_id };
275            self.encode(a).await?;
276            return Err(std::io::Error::new(
277                std::io::ErrorKind::ConnectionAborted,
278                "body too short",
279            ));
280        }
281        self.end_stream(buf, request_id, rtype).await?;
282        Ok(())
283    }
284    #[cfg(feature = "web_server")]
285    pub async fn flush_data_chunk<B>(
289        &mut self,
290        mut data: B,
291        request_id: u16,
292        rtype: RecordType,
293    ) -> std::io::Result<()>
294    where
295        B: Buf,
296    {
297        let mut header = BytesMut::with_capacity(Header::HEADER_LEN);
298        while data.remaining() > BUF_LEN - Header::HEADER_LEN {
299            let mut part = data.take(BUF_LEN - Header::HEADER_LEN);
300            Header::new(rtype, request_id, (BUF_LEN - Header::HEADER_LEN) as u16)
301                .write_into(&mut header);
302
303            self.write_whole_buf(&mut header).await?;
304            self.write_whole_buf(&mut part).await?;
305            data = part.into_inner();
306            header.clear();
307        }
308        let last_head = Header::new(rtype, request_id, data.remaining() as u16);
309        let pad = last_head.get_padding() as usize;
310        last_head.write_into(&mut header);
311        if data.remaining() == 0 {
312            self.write_whole_buf(&mut header).await?;
313            return Ok(());
314        }
315        let mut buf = header.freeze();
316        let mut pad = buf.slice(0..pad);
317        self.write_whole_buf(&mut buf).await?;
318        self.write_whole_buf(&mut data).await?;
319        self.write_whole_buf(&mut pad).await?;
321        Ok(())
322    }
323    pub fn kv_stream(&mut self, request_id: u16, rtype: RecordType) -> NameValuePairWriter<W> {
336        let mut buf = BytesMut::with_capacity(BUF_LEN);
337        unsafe {
338            buf.set_len(Header::HEADER_LEN);
339        } NameValuePairWriter {
341            w: self,
342            request_id,
343            rtype,
344            buf,
345        }
346    }
347}
/// Incremental writer for a stream of FastCGI name-value pairs, created by
/// `FCGIWriter::kv_stream`.
///
/// Encoded pairs are packed into `buf`; full records are written to `w` as the
/// buffer fills, and `flush` sends the remainder plus a zero-length record.
pub struct NameValuePairWriter<'a, R> {
    /// Writer the records are sent through.
    w: &'a mut FCGIWriter<R>,
    /// Request id stamped on every record header of this stream.
    request_id: u16,
    /// Record type of every record in this stream.
    rtype: RecordType,
    /// Pending record: a header slot followed by buffered, unsent content.
    buf: BytesMut,
}
377impl<R: AsyncWrite + Unpin> NameValuePairWriter<'_, R> {
378    pub async fn flush(self) -> std::io::Result<()> {
380        self.w
382            .end_stream(self.buf, self.request_id, self.rtype)
383            .await?;
384        Ok(())
385    }
386    pub async fn extend<T: IntoIterator<Item = (P1, P2)>, P1: Buf, P2: Buf>(
389        &mut self,
390        iter: T,
391    ) -> std::io::Result<()> {
392        for (k, v) in iter {
393            self.add_kv(k, v).await?;
394        }
395        Ok(())
396    }
397    #[inline]
398    pub async fn add(&mut self, mut pair: NameValuePair) -> std::io::Result<()> {
401        self.add_kv(&mut pair.name_data, &mut pair.value_data).await
402    }
403    #[inline]
404    pub async fn add_kv<B1, B2>(&mut self, mut name: B1, mut val: B2) -> std::io::Result<()>
407    where
408        B1: Buf,
409        B2: Buf,
410    {
411        let mut ln = name.remaining();
412        let mut lv = val.remaining();
413        let mut blen = BytesMut::with_capacity(8);
414        if ln > 0x7f {
415            if ln > 0x7fffffff {
416                panic!();
417            }
418            ln |= 0x80000000;
419            blen.put_u32(ln as u32);
420        } else {
421            blen.put_u8(ln as u8);
422        }
423        if lv > 0x7f {
424            if lv > 0x7fffffff {
425                panic!();
426            }
427            lv |= 0x80000000;
428            blen.put_u32(lv as u32);
429        } else {
430            blen.put_u8(lv as u8);
431        }
432        self.w
433            .append_to_stream(&mut blen, &mut self.buf, self.request_id, self.rtype)
434            .await?;
435
436        self.w
437            .append_to_stream(&mut name, &mut self.buf, self.request_id, self.rtype)
438            .await?;
439
440        self.w
441            .append_to_stream(&mut val, &mut self.buf, self.request_id, self.rtype)
442            .await?;
443        Ok(())
444    }
445}