// Documentation
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
use std::{
    collections::VecDeque,
    pin::Pin,
    sync::{Arc, Mutex},
    task::{Context, Poll},
};

use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_core::ready;

use crate::StdError;
use http::HeaderMap;
use http_body::Body;
use pin_project::pin_project;
use thiserror::Error;

/// Error produced by [`BufferedBody`]'s `poll_data` when a replay is requested
/// but the replay buffer holds no data and its capacity counter has reached zero.
#[derive(Error, Debug)]
#[error("buffered body reach max capacity.")]
pub struct ReachMaxCapacityError;

/// A wrapper around `hyper::Body` that records the chunks it yields so that a
/// clone can replay them.
///
/// The buffering state lives in exactly one place at a time: either in the
/// `owned` field of one handle, or in the `shared` slot after that handle has
/// been dropped.
pub struct BufferedBody {
    /// Slot shared by all clones; `Drop` stores the owned state here and the
    /// poll methods take it back out when `owned` is `None`.
    shared: Arc<Mutex<Option<OwnedBufferedBody>>>,
    /// Buffering state currently held by this handle (`None` for a fresh clone).
    owned: Option<OwnedBufferedBody>,
    /// When `true`, the next `poll_data` first replays the buffered chunks.
    /// Set to `true` by `clone`, `false` by `new`.
    replay_body: bool,
    /// When `true`, the next `poll_trailers` first replays cached trailers.
    /// Set to `true` by `clone`, `false` by `new`.
    replay_trailers: bool,
    /// Initialised from the inner body's `is_end_stream()` at construction and
    /// also set to `true` once the inner body yields `None` from `poll_data`.
    is_empty: bool,
    /// Size hint captured from the inner body at construction time.
    size_hint: http_body::SizeHint,
}

/// The state that is handed from one [`BufferedBody`] handle to the next.
pub struct OwnedBufferedBody {
    /// The wrapped body that data and trailers are polled from.
    body: hyper::Body,
    /// Trailers cached from the most recent successful `poll_trailers` on `body`.
    trailers: Option<HeaderMap>,
    /// Chunks recorded so far, together with the remaining byte budget.
    buf: InnerBuffer,
}

impl BufferedBody {
    pub fn new(body: hyper::Body, buf_size: usize) -> Self {
        let size_hint = body.size_hint();
        let is_empty = body.is_end_stream();
        BufferedBody {
            shared: Default::default(),
            owned: Some(OwnedBufferedBody {
                body,
                trailers: None,
                buf: InnerBuffer {
                    bufs: Default::default(),
                    capacity: buf_size,
                },
            }),
            replay_body: false,
            replay_trailers: false,
            is_empty,
            size_hint,
        }
    }
}

impl Clone for BufferedBody {
    fn clone(&self) -> Self {
        Self {
            shared: self.shared.clone(),
            owned: None,
            replay_body: true,
            replay_trailers: true,
            is_empty: self.is_empty,
            size_hint: self.size_hint.clone(),
        }
    }
}

impl Drop for BufferedBody {
    /// Hands the owned state back to the shared slot so that another handle can
    /// claim it. A handle that owns nothing has nothing to return; if the mutex
    /// is poisoned the state is simply dropped.
    fn drop(&mut self) {
        let owned = match self.owned.take() {
            Some(owned) => owned,
            None => return,
        };

        if let Ok(mut slot) = self.shared.lock() {
            *slot = Some(owned);
        }
    }
}

impl Body for BufferedBody {
    type Data = BytesData;
    type Error = StdError;

    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        let mut_self = self.get_mut();

        let owned_body = mut_self.owned.get_or_insert_with(|| {
            let lock = mut_self.shared.lock();
            if let Err(e) = lock {
                panic!("buffered body get shared data lock failed. {}", e);
            }
            let mut data = lock.unwrap();

            data.take().expect("cannot get shared buffered body.")
        });

        if mut_self.replay_body {
            mut_self.replay_body = false;
            if owned_body.buf.has_remaining() {
                return Poll::Ready(Some(Ok(BytesData::BufferedBytes(owned_body.buf.clone()))));
            }

            if owned_body.buf.is_capped() {
                return Poll::Ready(Some(Err(ReachMaxCapacityError.into())));
            }
        }

        if mut_self.is_empty {
            return Poll::Ready(None);
        }

        let mut data = {
            let pin = Pin::new(&mut owned_body.body);
            let data = ready!(pin.poll_data(cx));
            match data {
                Some(Ok(data)) => data,
                Some(Err(e)) => return Poll::Ready(Some(Err(e.into()))),
                None => {
                    mut_self.is_empty = true;
                    return Poll::Ready(None);
                }
            }
        };

        let len = data.remaining();

        owned_body.buf.capacity = owned_body.buf.capacity.saturating_sub(len);

        let data = if owned_body.buf.is_capped() {
            if owned_body.buf.has_remaining() {
                owned_body.buf.bufs = VecDeque::default();
            }
            data.copy_to_bytes(len)
        } else {
            owned_body.buf.push_bytes(data.copy_to_bytes(len))
        };

        Poll::Ready(Some(Ok(BytesData::OriginBytes(data))))
    }

    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
        let mut_self = self.get_mut();
        let owned_body = mut_self.owned.get_or_insert_with(|| {
            let lock = mut_self.shared.lock();
            if let Err(e) = lock {
                panic!("buffered body get shared data lock failed. {}", e);
            }
            let mut data = lock.unwrap();

            data.take().expect("cannot get shared buffered body.")
        });

        if mut_self.replay_trailers {
            mut_self.replay_trailers = false;
            if let Some(ref trailers) = owned_body.trailers {
                return Poll::Ready(Ok(Some(trailers.clone())));
            }
        }

        let mut_body = &mut owned_body.body;
        if !mut_body.is_end_stream() {
            let trailers = ready!(Pin::new(mut_body).poll_trailers(cx)).map(|trailers| {
                owned_body.trailers = trailers.clone();
                trailers
            });
            return Poll::Ready(trailers.map_err(|e| e.into()));
        }

        Poll::Ready(Ok(None))
    }

    fn is_end_stream(&self) -> bool {
        if self.is_empty {
            return true;
        }

        let is_end = self
            .owned
            .as_ref()
            .map(|owned| owned.body.is_end_stream())
            .unwrap_or(false);

        !self.replay_body && !self.replay_trailers && is_end
    }

    fn size_hint(&self) -> http_body::SizeHint {
        self.size_hint.clone()
    }
}

/// A queue of recorded chunks that can be read as a single `Buf`.
#[derive(Clone)]
pub struct InnerBuffer {
    /// Recorded chunks, oldest first.
    bufs: VecDeque<Bytes>,
    /// Remaining byte budget; `is_capped` reports `true` once this is zero.
    capacity: usize,
}

impl InnerBuffer {
    /// Appends a copy of `bytes` to the queue and returns `bytes` to the caller.
    pub fn push_bytes(&mut self, bytes: Bytes) -> Bytes {
        let recorded = bytes.clone();
        self.bufs.push_back(recorded);
        bytes
    }

    /// Returns `true` when the remaining byte budget is zero.
    pub fn is_capped(&self) -> bool {
        matches!(self.capacity, 0)
    }
}

impl Buf for InnerBuffer {
    fn remaining(&self) -> usize {
        self.bufs.iter().map(|bytes| bytes.remaining()).sum()
    }

    fn chunk(&self) -> &[u8] {
        self.bufs.front().map(Buf::chunk).unwrap_or(&[])
    }

    fn chunks_vectored<'a>(&'a self, dst: &mut [std::io::IoSlice<'a>]) -> usize {
        if dst.is_empty() {
            return 0;
        }

        let mut filled = 0;

        for bytes in self.bufs.iter() {
            filled += bytes.chunks_vectored(&mut dst[filled..])
        }

        filled
    }

    fn advance(&mut self, mut cnt: usize) {
        while cnt > 0 {
            let first = self.bufs.front_mut();
            if first.is_none() {
                break;
            }
            let first = first.unwrap();
            let first_remaining = first.remaining();
            if first_remaining > cnt {
                first.advance(cnt);
                break;
            }

            first.advance(first_remaining);
            cnt = cnt - first_remaining;
            self.bufs.pop_front();
        }
    }

    fn copy_to_bytes(&mut self, len: usize) -> bytes::Bytes {
        match self.bufs.front_mut() {
            Some(buf) if len <= buf.remaining() => {
                let bytes = buf.copy_to_bytes(len);
                if buf.remaining() == 0 {
                    self.bufs.pop_front();
                }
                bytes
            }
            _ => {
                let mut bytes = BytesMut::with_capacity(len);
                bytes.put(self.take(len));
                bytes.freeze()
            }
        }
    }
}

/// The data type yielded by [`BufferedBody`].
pub enum BytesData {
    /// A copy of the recorded chunks, yielded when a handle replays the body.
    BufferedBytes(InnerBuffer),
    /// A chunk that was just read from the inner body.
    OriginBytes(Bytes),
}

/// Every method forwards to the wrapped buffer of the active variant.
impl Buf for BytesData {
    fn remaining(&self) -> usize {
        match self {
            BytesData::BufferedBytes(bytes) => bytes.remaining(),
            BytesData::OriginBytes(bytes) => bytes.remaining(),
        }
    }

    fn chunk(&self) -> &[u8] {
        match self {
            BytesData::BufferedBytes(bytes) => bytes.chunk(),
            BytesData::OriginBytes(bytes) => bytes.chunk(),
        }
    }

    // Forwarded explicitly: the trait's default only exposes a single chunk,
    // which would hide the multi-chunk implementation of `InnerBuffer`.
    fn chunks_vectored<'a>(&'a self, dst: &mut [std::io::IoSlice<'a>]) -> usize {
        match self {
            BytesData::BufferedBytes(bytes) => bytes.chunks_vectored(dst),
            BytesData::OriginBytes(bytes) => bytes.chunks_vectored(dst),
        }
    }

    fn advance(&mut self, cnt: usize) {
        match self {
            BytesData::BufferedBytes(bytes) => bytes.advance(cnt),
            BytesData::OriginBytes(bytes) => bytes.advance(cnt),
        }
    }

    fn copy_to_bytes(&mut self, len: usize) -> bytes::Bytes {
        match self {
            BytesData::BufferedBytes(bytes) => bytes.copy_to_bytes(len),
            BytesData::OriginBytes(bytes) => bytes.copy_to_bytes(len),
        }
    }
}

/// A cloneable HTTP body: a pin-projected newtype over [`BufferedBody`].
#[pin_project]
pub struct CloneBody(#[pin] BufferedBody);

impl CloneBody {
    pub fn new(inner_body: hyper::Body) -> Self {
        let inner_body = BufferedBody::new(inner_body, 1024 * 64);
        CloneBody(inner_body)
    }
}

impl Body for CloneBody {
    type Data = BytesData;

    type Error = StdError;

    fn poll_data(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Self::Data, Self::Error>>> {
        self.project().0.poll_data(cx)
    }

    fn poll_trailers(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<Option<http::HeaderMap>, Self::Error>> {
        self.project().0.poll_trailers(cx)
    }

    fn size_hint(&self) -> http_body::SizeHint {
        self.0.size_hint()
    }
}

impl Clone for CloneBody {
    fn clone(&self) -> Self {
        Self(self.0.clone())
    }
}