use std::collections::HashMap;
use std::collections::HashSet;
use std::fmt;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use opentelemetry::trace::TraceContextExt;
use opentelemetry::Context as OtelContext;
use parking_lot::Mutex as PMutex;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tower::BoxError;
use tracing::Instrument;
use tracing::Span;
use crate::error::FetchError;
use crate::error::SubgraphBatchingError;
use crate::graphql;
use crate::plugins::telemetry::otel::span_ext::OpenTelemetrySpanExt;
use crate::query_planner::fetch::QueryHash;
use crate::services::http::HttpClientServiceFactory;
use crate::services::process_batches;
use crate::services::router::body::get_body_bytes;
use crate::services::router::body::RouterBody;
use crate::services::SubgraphRequest;
use crate::services::SubgraphResponse;
use crate::Context;
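/// The handle through which a single query participates in a [`Batch`].
///
/// `index` identifies the query within the batch, `remaining` counts the
/// fetches still expected from it, and `sender` carries signals to the batch
/// handler task.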
#[derive(Clone, Debug)]
pub(crate) struct BatchQuery {
    index: usize,
    sender: Arc<Mutex<Option<mpsc::Sender<BatchHandlerMessage>>>>,
    remaining: Arc<AtomicUsize>,
    batch: Arc<Batch>,
}
impl fmt::Display for BatchQuery {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "index: {}, ", self.index)?;
        write!(f, "remaining: {}, ", self.remaining.load(Ordering::Acquire))?;
        write!(f, "sender: {:?}, ", self.sender)?;
        write!(f, "batch: {:?}, ", self.batch)?;
        Ok(())
    }
}
impl BatchQuery {
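    /// Returns `true` when no expected fetches remain for this query.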
    pub(crate) fn finished(&self) -> bool {
        self.remaining.load(Ordering::Acquire) == 0
    }
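    /// Inform the batch handler of the query hashes which this query will
    /// resolve, setting the number of fetch signals to expect.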
    pub(crate) async fn set_query_hashes(
        &self,
        query_hashes: Vec<Arc<QueryHash>>,
    ) -> Result<(), BoxError> {
        self.remaining.store(query_hashes.len(), Ordering::Release);
        self.sender
            .lock()
            .await
            .as_ref()
            .ok_or(SubgraphBatchingError::SenderUnavailable)?
            .send(BatchHandlerMessage::Begin {
                index: self.index,
                query_hashes,
            })
            .await?;
        Ok(())
    }
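    /// Signal to the batch handler that this query has made progress,
    /// returning a receiver on which the subgraph response will eventually be
    /// delivered.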
    pub(crate) async fn signal_progress(
        &self,
        client_factory: HttpClientServiceFactory,
        request: SubgraphRequest,
        gql_request: graphql::Request,
    ) -> Result<oneshot::Receiver<Result<SubgraphResponse, BoxError>>, BoxError> {
        let (tx, rx) = oneshot::channel();
        tracing::debug!(
            "index: {}, REMAINING: {}",
            self.index,
            self.remaining.load(Ordering::Acquire)
        );
        self.sender
            .lock()
            .await
            .as_ref()
            .ok_or(SubgraphBatchingError::SenderUnavailable)?
            .send(BatchHandlerMessage::Progress {
                index: self.index,
                client_factory,
                request,
                gql_request,
                response_sender: tx,
                span_context: Span::current().context(),
            })
            .await?;
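        // One expected fetch has progressed: decrement the count unless it is
        // already zero, and drop the sender once the query is finished so no
        // further signals can be sent.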
        if !self.finished() {
            self.remaining.fetch_sub(1, Ordering::AcqRel);
        }
        if self.finished() {
            let mut sender = self.sender.lock().await;
            *sender = None;
        }
        Ok(rx)
    }
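    /// Signal to the batch handler that one expected fetch for this query has
    /// been cancelled, giving a human-readable reason.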
    pub(crate) async fn signal_cancelled(&self, reason: String) -> Result<(), BoxError> {
        self.sender
            .lock()
            .await
            .as_ref()
            .ok_or(SubgraphBatchingError::SenderUnavailable)?
            .send(BatchHandlerMessage::Cancel {
                index: self.index,
                reason,
            })
            .await?;
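        // Same accounting as signal_progress: decrement the expected fetch
        // count and drop the sender once the query is finished.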
        if !self.finished() {
            self.remaining.fetch_sub(1, Ordering::AcqRel);
        }
        if self.finished() {
            let mut sender = self.sender.lock().await;
            *sender = None;
        }
        Ok(())
    }
}
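/// Messages sent to the batch handler task spawned in [`Batch::spawn_handler`].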
enum BatchHandlerMessage {
    Cancel { index: usize, reason: String },
    Progress {
        index: usize,
        client_factory: HttpClientServiceFactory,
        request: SubgraphRequest,
        gql_request: graphql::Request,
        response_sender: oneshot::Sender<Result<SubgraphResponse, BoxError>>,
        span_context: OtelContext,
    },
    Begin {
        index: usize,
        query_hashes: Vec<Arc<QueryHash>>,
    },
}
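/// A single entry in a batch: the subgraph request, the GraphQL request it
/// carries, and the channel on which to deliver the eventual response.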
pub(crate) struct BatchQueryInfo {
    request: SubgraphRequest,
    gql_request: graphql::Request,
    sender: oneshot::Sender<Result<SubgraphResponse, BoxError>>,
}
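/// A set of queries which are dispatched together, one batched request per
/// subgraph. Holds a sender for each participating query and the handle of
/// the spawned handler task.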
#[derive(Debug)]
pub(crate) struct Batch {
    senders: PMutex<Vec<Option<mpsc::Sender<BatchHandlerMessage>>>>,
    spawn_handle: JoinHandle<Result<(), BoxError>>,
    #[allow(dead_code)]
    size: usize,
}
impl Batch {
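    /// Create a new batch, spawning an async task which tracks the state of
    /// the batch's `size` participating queries and dispatches the assembled
    /// requests once every query has either progressed or been cancelled.
    ///
    /// A minimal usage sketch (mirroring the tests below):
    /// ```ignore
    /// let batch = Arc::new(Batch::spawn_handler(2));
    /// let query = Batch::query_for_index(batch.clone(), 0)?;
    /// query.set_query_hashes(vec![Arc::new(QueryHash::default())]).await?;
    /// ```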
    pub(crate) fn spawn_handler(size: usize) -> Self {
        tracing::debug!("New batch created with size {size}");
        let (spawn_tx, mut rx) = mpsc::channel(size);
        let mut senders = vec![];
        for _ in 0..size {
            senders.push(Some(spawn_tx.clone()));
        }
        let spawn_handle = tokio::spawn(async move {
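            // Per-query bookkeeping: which query hashes are expected
            // (registered) and which of them have committed or cancelled so far.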
            #[derive(Debug)]
            struct BatchQueryState {
                registered: HashSet<Arc<QueryHash>>,
                committed: HashSet<Arc<QueryHash>>,
                cancelled: HashSet<Arc<QueryHash>>,
            }
            impl BatchQueryState {
                fn is_ready(&self) -> bool {
                    self.registered
                        .iter()
                        .all(|hash| self.committed.contains(hash) || self.cancelled.contains(hash))
                }
            }
            let mut batch_state: HashMap<usize, BatchQueryState> = HashMap::with_capacity(size);
            let mut requests: Vec<Vec<BatchQueryInfo>> =
                Vec::from_iter((0..size).map(|_| Vec::new()));
            let mut master_client_factory = None;
            tracing::debug!("Batch about to await messages...");
            while let Some(msg) = rx.recv().await {
                match msg {
                    BatchHandlerMessage::Cancel { index, reason } => {
                        tracing::debug!("Cancelling index: {index}, {reason}");
                        if let Some(state) = batch_state.get_mut(&index) {
                            let cancelled_requests = std::mem::take(&mut requests[index]);
                            for BatchQueryInfo {
                                request, sender, ..
                            } in cancelled_requests
                            {
                                let subgraph_name = request
                                    .subgraph_name
                                    .ok_or(SubgraphBatchingError::MissingSubgraphName)?;
                                if let Err(log_error) =
                                    sender.send(Err(Box::new(FetchError::SubrequestBatchingError {
                                        service: subgraph_name.clone(),
                                        reason: format!("request cancelled: {reason}"),
                                    })))
                                {
                                    tracing::error!(
                                        service = subgraph_name,
                                        error = ?log_error,
                                        "failed to notify waiter that request is cancelled"
                                    );
                                }
                            }
                            state.committed.clear();
                            state.cancelled = state.registered.clone();
                        }
                    }
                    BatchHandlerMessage::Begin {
                        index,
                        query_hashes,
                    } => {
                        tracing::debug!("Beginning batch for index {index} with {query_hashes:?}");
                        batch_state.insert(
                            index,
                            BatchQueryState {
                                cancelled: HashSet::with_capacity(query_hashes.len()),
                                committed: HashSet::with_capacity(query_hashes.len()),
                                registered: HashSet::from_iter(query_hashes),
                            },
                        );
                    }
                    BatchHandlerMessage::Progress {
                        index,
                        client_factory,
                        request,
                        gql_request,
                        response_sender,
                        span_context,
                    } => {
                        tracing::debug!("Progress index: {index}");
                        if let Some(state) = batch_state.get_mut(&index) {
                            state.committed.insert(request.query_hash.clone());
                        }
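                        // Remember the first client factory we see; it is used
                        // later to dispatch the assembled batches.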
                        if master_client_factory.is_none() {
                            master_client_factory = Some(client_factory);
                        }
                        Span::current().add_link(span_context.span().span_context().clone());
                        requests[index].push(BatchQueryInfo {
                            request,
                            gql_request,
                            sender: response_sender,
                        })
                    }
                }
            }
            if batch_state.values().any(|f| !f.is_ready()) {
                tracing::error!(
                    "All senders for the batch have dropped before reaching the ready state: {batch_state:#?}"
                );
                return Err(SubgraphBatchingError::ProcessingFailed(
                    "batch senders not ready when required".to_string(),
                )
                .into());
            }
            tracing::debug!("Assembling {size} requests into batches");
            let all_in_one: Vec<_> = requests.into_iter().flatten().collect();
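            // Group the requests by subgraph name so that each subgraph
            // receives a single batched request.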
            let mut svc_map: HashMap<String, Vec<BatchQueryInfo>> = HashMap::new();
            for BatchQueryInfo {
                request: sg_request,
                gql_request,
                sender: tx,
            } in all_in_one
            {
                let subgraph_name = sg_request
                    .subgraph_name
                    .clone()
                    .ok_or(SubgraphBatchingError::MissingSubgraphName)?;
                let value = svc_map.entry(subgraph_name).or_default();
                value.push(BatchQueryInfo {
                    request: sg_request,
                    gql_request,
                    sender: tx,
                });
            }
            if let Some(client_factory) = master_client_factory {
                process_batches(client_factory, svc_map).await?;
            }
            Ok(())
        }.instrument(tracing::info_span!("batch_request", size)));
        Self {
            senders: PMutex::new(senders),
            spawn_handle,
            size,
        }
    }
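    /// Take the sender for the query at `index` out of the batch and wrap it
    /// in a [`BatchQuery`]. Each index may be taken at most once.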
    pub(crate) fn query_for_index(
        batch: Arc<Batch>,
        index: usize,
    ) -> Result<BatchQuery, SubgraphBatchingError> {
        let mut guard = batch.senders.lock();
        if index >= guard.len() {
            return Err(SubgraphBatchingError::ProcessingFailed(format!(
                "tried to retriever sender for index: {index} which does not exist"
            )));
        }
        let opt_sender = std::mem::take(&mut guard[index]);
        if opt_sender.is_none() {
            return Err(SubgraphBatchingError::ProcessingFailed(format!(
                "tried to retriever sender for index: {index} which has already been taken"
            )));
        }
        drop(guard);
        Ok(BatchQuery {
            index,
            sender: Arc::new(Mutex::new(opt_sender)),
            remaining: Arc::new(AtomicUsize::new(0)),
            batch,
        })
    }
}
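// Abort the spawned handler when the batch is dropped so the task does not
// outlive the batch it services.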
impl Drop for Batch {
    fn drop(&mut self) {
        self.spawn_handle.abort();
    }
}
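/// Assemble a batch of requests into a single HTTP request whose body is a
/// JSON array of the individual GraphQL requests.
///
/// Returns the operation name of the first request, the request contexts (in
/// input order), the combined HTTP request, and the response senders for each
/// participant. A usage sketch (see `it_assembles_batch` below):
/// ```ignore
/// let (operation_name, contexts, http_request, senders) =
///     assemble_batch(requests).await?;
/// ```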
pub(crate) async fn assemble_batch(
    requests: Vec<BatchQueryInfo>,
) -> Result<
    (
        String,
        Vec<Context>,
        http::Request<RouterBody>,
        Vec<oneshot::Sender<Result<SubgraphResponse, BoxError>>>,
    ),
    BoxError,
> {
    let (txs, request_pairs): (Vec<_>, Vec<_>) = requests
        .into_iter()
        .map(|r| (r.sender, (r.request, r.gql_request)))
        .unzip();
    let (requests, gql_requests): (Vec<_>, Vec<_>) = request_pairs.into_iter().unzip();
    let bytes = get_body_bytes(serde_json::to_string(&gql_requests)?).await?;
    let contexts = requests
        .iter()
        .map(|x| x.context.clone())
        .collect::<Vec<Context>>();
    let first_request = requests
        .into_iter()
        .next()
        .ok_or(SubgraphBatchingError::RequestsIsEmpty)?
        .subgraph_request;
    let operation_name = first_request
        .body()
        .operation_name
        .clone()
        .unwrap_or_default();
    let (parts, _) = first_request.into_parts();
    let request = http::Request::from_parts(parts, RouterBody::from(bytes));
    Ok((operation_name, contexts, request, txs))
}
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::time::Duration;
    use tokio::sync::oneshot;
    use super::assemble_batch;
    use super::Batch;
    use super::BatchQueryInfo;
    use crate::graphql;
    use crate::plugins::traffic_shaping::Http2Config;
    use crate::query_planner::fetch::QueryHash;
    use crate::services::http::HttpClientServiceFactory;
    use crate::services::SubgraphRequest;
    use crate::services::SubgraphResponse;
    use crate::Configuration;
    use crate::Context;
    #[tokio::test(flavor = "multi_thread")]
    async fn it_assembles_batch() {
        let (receivers, requests): (Vec<_>, Vec<_>) = (0..2)
            .map(|index| {
                let (tx, rx) = oneshot::channel();
                let gql_request = graphql::Request::fake_builder()
                    .operation_name(format!("batch_test_{index}"))
                    .query(format!("query batch_test {{ slot{index} }}"))
                    .build();
                (
                    rx,
                    BatchQueryInfo {
                        request: SubgraphRequest::fake_builder()
                            .subgraph_request(
                                http::Request::builder().body(gql_request.clone()).unwrap(),
                            )
                            .subgraph_name(format!("slot{index}"))
                            .build(),
                        gql_request,
                        sender: tx,
                    },
                )
            })
            .unzip();
        let input_context_ids = requests
            .iter()
            .map(|r| r.request.context.id.clone())
            .collect::<Vec<String>>();
        let (op_name, contexts, request, txs) = assemble_batch(requests)
            .await
            .expect("it can assemble a batch");
        let output_context_ids = contexts
            .iter()
            .map(|r| r.id.clone())
            .collect::<Vec<String>>();
        assert_eq!(input_context_ids, output_context_ids);
        assert_eq!(op_name, "batch_test_0");
        let actual: Vec<graphql::Request> = serde_json::from_str(
            std::str::from_utf8(&request.into_body().to_bytes().await.unwrap()).unwrap(),
        )
        .unwrap();
        let expected: Vec<_> = (0..2)
            .map(|index| {
                graphql::Request::fake_builder()
                    .operation_name(format!("batch_test_{index}"))
                    .query(format!("query batch_test {{ slot{index} }}"))
                    .build()
            })
            .collect();
        assert_eq!(actual, expected);
        assert_eq!(txs.len(), receivers.len());
        for (index, (tx, rx)) in Iterator::zip(txs.into_iter(), receivers).enumerate() {
            let data = serde_json_bytes::json!({
                "data": {
                    format!("slot{index}"): "valid"
                }
            });
            let response = SubgraphResponse {
                response: http::Response::builder()
                    .body(graphql::Response::builder().data(data.clone()).build())
                    .unwrap(),
                context: Context::new(),
                subgraph_name: None,
            };
            tx.send(Ok(response)).unwrap();
            let received = tokio::time::timeout(Duration::from_millis(10), rx)
                .await
                .unwrap()
                .unwrap()
                .unwrap();
            assert_eq!(received.response.into_body().data, Some(data));
        }
    }
    #[tokio::test(flavor = "multi_thread")]
    async fn it_rejects_index_out_of_bounds() {
        let batch = Arc::new(Batch::spawn_handler(2));
        assert!(Batch::query_for_index(batch.clone(), 2).is_err());
    }
    #[tokio::test(flavor = "multi_thread")]
    async fn it_rejects_duplicated_index_get() {
        let batch = Arc::new(Batch::spawn_handler(2));
        assert!(Batch::query_for_index(batch.clone(), 0).is_ok());
        assert!(Batch::query_for_index(batch.clone(), 0).is_err());
    }
    #[tokio::test(flavor = "multi_thread")]
    async fn it_limits_the_number_of_cancelled_sends() {
        let batch = Arc::new(Batch::spawn_handler(2));
        let bq = Batch::query_for_index(batch.clone(), 0).expect("it's a valid index");
        assert!(bq
            .set_query_hashes(vec![Arc::new(QueryHash::default())])
            .await
            .is_ok());
        assert!(!bq.finished());
        assert!(bq.signal_cancelled("why not?".to_string()).await.is_ok());
        assert!(bq.finished());
        assert!(bq
            .signal_cancelled("only once though".to_string())
            .await
            .is_err());
    }
    #[tokio::test(flavor = "multi_thread")]
    async fn it_limits_the_number_of_progressed_sends() {
        let batch = Arc::new(Batch::spawn_handler(2));
        let bq = Batch::query_for_index(batch.clone(), 0).expect("it's a valid index");
        let factory = HttpClientServiceFactory::from_config(
            "testbatch",
            &Configuration::default(),
            Http2Config::Disable,
        );
        let request = SubgraphRequest::fake_builder()
            .subgraph_request(
                http::Request::builder()
                    .body(graphql::Request::default())
                    .unwrap(),
            )
            .subgraph_name("whatever".to_string())
            .build();
        assert!(bq
            .set_query_hashes(vec![Arc::new(QueryHash::default())])
            .await
            .is_ok());
        assert!(!bq.finished());
        assert!(bq
            .signal_progress(
                factory.clone(),
                request.clone(),
                graphql::Request::default()
            )
            .await
            .is_ok());
        assert!(bq.finished());
        assert!(bq
            .signal_progress(factory, request, graphql::Request::default())
            .await
            .is_err());
    }
    #[tokio::test(flavor = "multi_thread")]
    async fn it_limits_the_number_of_mixed_sends() {
        let batch = Arc::new(Batch::spawn_handler(2));
        let bq = Batch::query_for_index(batch.clone(), 0).expect("it's a valid index");
        let factory = HttpClientServiceFactory::from_config(
            "testbatch",
            &Configuration::default(),
            Http2Config::Disable,
        );
        let request = SubgraphRequest::fake_builder()
            .subgraph_request(
                http::Request::builder()
                    .body(graphql::Request::default())
                    .unwrap(),
            )
            .subgraph_name("whatever".to_string())
            .build();
        assert!(bq
            .set_query_hashes(vec![Arc::new(QueryHash::default())])
            .await
            .is_ok());
        assert!(!bq.finished());
        assert!(bq
            .signal_progress(factory, request, graphql::Request::default())
            .await
            .is_ok());
        assert!(bq.finished());
        assert!(bq
            .signal_cancelled("only once though".to_string())
            .await
            .is_err());
    }
    #[tokio::test(flavor = "multi_thread")]
    async fn it_limits_the_number_of_mixed_sends_two_query_hashes() {
        let batch = Arc::new(Batch::spawn_handler(2));
        let bq = Batch::query_for_index(batch.clone(), 0).expect("it's a valid index");
        let factory = HttpClientServiceFactory::from_config(
            "testbatch",
            &Configuration::default(),
            Http2Config::Disable,
        );
        let request = SubgraphRequest::fake_builder()
            .subgraph_request(
                http::Request::builder()
                    .body(graphql::Request::default())
                    .unwrap(),
            )
            .subgraph_name("whatever".to_string())
            .build();
        let qh = Arc::new(QueryHash::default());
        assert!(bq.set_query_hashes(vec![qh.clone(), qh]).await.is_ok());
        assert!(!bq.finished());
        assert!(bq
            .signal_progress(factory, request, graphql::Request::default())
            .await
            .is_ok());
        assert!(!bq.finished());
        assert!(bq
            .signal_cancelled("only twice though".to_string())
            .await
            .is_ok());
        assert!(bq.finished());
        assert!(bq
            .signal_cancelled("only twice though".to_string())
            .await
            .is_err());
    }
}