use std::sync::Arc;
use std::time::Duration;
use futures::future::join_all;
use futures::stream::futures_unordered::FuturesUnordered;
use futures::stream::StreamExt;
use futures_util::future::FutureExt;
use tokio::select;
use tokio::sync::watch;
use opentelemetry::KeyValue;
use opentelemetry::{
	trace::{FutureExt as OtelFutureExt, Span, TraceContextExt, Tracer},
	Context,
};
pub use netapp::endpoint::{Endpoint, EndpointHandler, StreamingEndpointHandler};
use netapp::message::IntoReq;
pub use netapp::message::{
	Message as Rpc, OrderTag, Req, RequestPriority, Resp, PRIO_BACKGROUND, PRIO_HIGH, PRIO_NORMAL,
	PRIO_SECONDARY,
};
use netapp::peering::fullmesh::FullMeshPeeringStrategy;
pub use netapp::{self, NetApp, NodeID};
use garage_util::background::BackgroundRunner;
use garage_util::data::*;
use garage_util::error::Error;
use garage_util::metrics::RecordDuration;
use crate::metrics::RpcMetrics;
use crate::ring::Ring;
/// Fallback RPC timeout (300 seconds), used by `RpcHelper::new` when the
/// caller does not supply an explicit `rpc_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(300);
/// Describes how an RPC (or a group of RPCs) should be carried out:
/// required quorum, behaviour once the quorum is reached, priority and timeout.
#[derive(Copy, Clone)]
pub struct RequestStrategy {
		/// Number of successful responses required by `try_call_many`.
		/// When `None`, the number of target nodes is used instead.
		pub rs_quorum: Option<usize>,
		/// When `true`, `try_call_many` only starts as many requests as are
		/// needed to reach the quorum (starting more to compensate for errors).
		/// When `false`, all requests are started at once and the ones still
		/// pending when the quorum is reached are left running in the background.
		pub rs_interrupt_after_quorum: bool,
		/// Priority handed to the endpoint when the request is sent.
		pub rs_priority: RequestPriority,
		/// Timeout policy applied to each individual call.
		rs_timeout: Timeout,
}
/// Timeout policy of a `RequestStrategy`, interpreted by `RpcHelper::call`.
#[derive(Copy, Clone)]
enum Timeout {
	/// The call never times out.
	None,
	/// The call uses the timeout configured on the `RpcHelper`.
	Default,
	/// The call uses the given duration as its timeout.
	Custom(Duration),
}
impl RequestStrategy {
	/// Build a strategy with the given priority, no explicit quorum,
	/// no interruption after quorum, and the default timeout.
	pub fn with_priority(prio: RequestPriority) -> Self {
		Self {
			rs_priority: prio,
			rs_quorum: None,
			rs_interrupt_after_quorum: false,
			rs_timeout: Timeout::Default,
		}
	}

	/// Return a copy of this strategy that requires `quorum` successful responses.
	pub fn with_quorum(self, quorum: usize) -> Self {
		Self {
			rs_quorum: Some(quorum),
			..self
		}
	}

	/// Return a copy of this strategy with the "interrupt after quorum"
	/// flag set to `interrupt`.
	pub fn interrupt_after_quorum(self, interrupt: bool) -> Self {
		Self {
			rs_interrupt_after_quorum: interrupt,
			..self
		}
	}

	/// Return a copy of this strategy whose calls never time out.
	pub fn without_timeout(self) -> Self {
		Self {
			rs_timeout: Timeout::None,
			..self
		}
	}

	/// Return a copy of this strategy whose calls time out after `timeout`.
	pub fn with_custom_timeout(self, timeout: Duration) -> Self {
		Self {
			rs_timeout: Timeout::Custom(timeout),
			..self
		}
	}
}
/// Handle used to send RPCs to other nodes.
/// Cloning it only clones the inner `Arc`, so all clones share the same state.
#[derive(Clone)]
pub struct RpcHelper(Arc<RpcHelperInner>);
/// Shared state behind an `RpcHelper`.
struct RpcHelperInner {
	/// Identifier of the local node; used in metric/trace attributes and to
	/// put the local node first in `request_order`.
	our_node_id: Uuid,
	/// Peering strategy, queried for the current peer list (and average pings).
	fullmesh: Arc<FullMeshPeeringStrategy>,
	/// Background runner that receives the requests still pending after a
	/// quorum was reached in `try_call_many`.
	background: Arc<BackgroundRunner>,
	/// Watch channel giving access to the current ring (used for zone lookups).
	ring: watch::Receiver<Arc<Ring>>,
	/// RPC counters and duration recorder.
	metrics: RpcMetrics,
	/// Timeout applied to calls whose strategy uses the default timeout.
	rpc_timeout: Duration,
}
impl RpcHelper {
	/// Create a new `RpcHelper`.
	///
	/// `rpc_timeout` is the timeout used by calls with the default timeout
	/// policy; when `None`, `DEFAULT_TIMEOUT` is used.
	pub(crate) fn new(
		our_node_id: Uuid,
		fullmesh: Arc<FullMeshPeeringStrategy>,
		background: Arc<BackgroundRunner>,
		ring: watch::Receiver<Arc<Ring>>,
		rpc_timeout: Option<Duration>,
	) -> Self {
		let inner = RpcHelperInner {
			our_node_id,
			fullmesh,
			background,
			ring,
			metrics: RpcMetrics::new(),
			rpc_timeout: rpc_timeout.unwrap_or(DEFAULT_TIMEOUT),
		};
		Self(Arc::new(inner))
	}
	pub fn rpc_timeout(&self) -> Duration {
		self.0.rpc_timeout
	}
	/// Send one RPC to node `to` through `endpoint` and wait for its response.
	///
	/// The call is raced against the timeout selected by `strat`: no timeout,
	/// the helper's configured timeout, or a custom duration.
	///
	/// Metrics: the RPC counter is incremented and the call duration recorded
	/// for every call; a dedicated counter is incremented when the transport
	/// layer fails, when the remote handler returns an error, or when the
	/// timeout fires.
	///
	/// # Errors
	/// Returns the converted transport error, the error returned by the
	/// remote handler, or `Error::Timeout` if the timeout elapsed first.
	pub async fn call<M, N, H, S>(
		&self,
		endpoint: &Endpoint<M, H>,
		to: Uuid,
		msg: N,
		strat: RequestStrategy,
	) -> Result<S, Error>
	where
		M: Rpc<Response = Result<S, Error>>,
		N: IntoReq<M> + Send,
		H: StreamingEndpointHandler<M>,
	{
		let metrics = &self.0.metrics;
		let metric_tags = [
			KeyValue::new("rpc_endpoint", endpoint.path().to_string()),
			KeyValue::new("from", format!("{:?}", self.0.our_node_id)),
			KeyValue::new("to", format!("{:?}", to)),
		];
		metrics.rpc_counter.add(1, &metric_tags);

		let node_id = to.into();
		let rpc_call = endpoint
			.call_streaming(&node_id, msg, strat.rs_priority)
			.record_duration(&metrics.rpc_duration, &metric_tags);

		// Resolve the timeout policy to an optional duration up front;
		// `None` means the timeout future never completes.
		let limit = match strat.rs_timeout {
			Timeout::None => None,
			Timeout::Default => Some(self.0.rpc_timeout),
			Timeout::Custom(t) => Some(t),
		};
		let timeout = async move {
			match limit {
				Some(duration) => tokio::time::sleep(duration).await,
				None => futures::future::pending().await,
			}
		};

		select! {
			outcome = rpc_call => {
				// Transport-level failure
				let resp = match outcome {
					Ok(resp) => resp,
					Err(e) => {
						metrics.rpc_netapp_error_counter.add(1, &metric_tags);
						return Err(e.into());
					}
				};
				// Error returned by the remote handler
				match resp.into_msg() {
					Ok(value) => Ok(value),
					Err(e) => {
						metrics.rpc_garage_error_counter.add(1, &metric_tags);
						Err(e)
					}
				}
			}
			() = timeout => {
				metrics.rpc_timeout_counter.add(1, &metric_tags);
				Err(Error::Timeout)
			}
		}
	}
	/// Send the same message to every node in `to`, all calls running
	/// concurrently, and wait for all of them to finish.
	///
	/// Returns one `(node, result)` pair per entry of `to`, in the same order.
	///
	/// # Errors
	/// Fails only if the message cannot be converted into a request;
	/// per-node failures are reported inside the returned vector.
	pub async fn call_many<M, N, H, S>(
		&self,
		endpoint: &Endpoint<M, H>,
		to: &[Uuid],
		msg: N,
		strat: RequestStrategy,
	) -> Result<Vec<(Uuid, Result<S, Error>)>, Error>
	where
		M: Rpc<Response = Result<S, Error>>,
		N: IntoReq<M>,
		H: StreamingEndpointHandler<M>,
	{
		let msg = msg.into_req().map_err(netapp::error::Error::from)?;

		let calls = to
			.iter()
			.map(|node| self.call(endpoint, *node, msg.clone(), strat));
		let results = join_all(calls).await;

		// `join_all` keeps the input order, so results line up with `to`.
		let paired = to.iter().copied().zip(results).collect::<Vec<_>>();
		Ok(paired)
	}
	/// Send the same message to every peer currently listed by the peering
	/// strategy, using `call_many`.
	///
	/// Returns one `(node, result)` pair per peer.
	pub async fn broadcast<M, N, H, S>(
		&self,
		endpoint: &Endpoint<M, H>,
		msg: N,
		strat: RequestStrategy,
	) -> Result<Vec<(Uuid, Result<S, Error>)>, Error>
	where
		M: Rpc<Response = Result<S, Error>>,
		N: IntoReq<M>,
		H: StreamingEndpointHandler<M>,
	{
		let peer_list = self.0.fullmesh.get_peer_list();
		let targets: Vec<Uuid> = peer_list.iter().map(|peer| peer.id.into()).collect();
		self.call_many(endpoint, &targets, msg, strat).await
	}
			/// Send the same message to the nodes in `to` and return the successful
	/// responses once the quorum is reached.
	///
	/// The quorum is `strategy.rs_quorum`, or the number of target nodes when
	/// unset. The whole operation runs inside a tracing span carrying the
	/// sender, the targets, the quorum and the interruption flag.
	///
	/// # Errors
	/// Returns `Error::Quorum` when the quorum could not be reached, or a
	/// conversion error if the message cannot be turned into a request.
	pub async fn try_call_many<M, N, H, S>(
		&self,
		endpoint: &Arc<Endpoint<M, H>>,
		to: &[Uuid],
		msg: N,
		strategy: RequestStrategy,
	) -> Result<Vec<S>, Error>
	where
		M: Rpc<Response = Result<S, Error>> + 'static,
		N: IntoReq<M>,
		H: StreamingEndpointHandler<M> + 'static,
		S: Send + 'static,
	{
		let target_count = to.len();
		let quorum = match strategy.rs_quorum {
			Some(q) => q,
			None => target_count,
		};
		let interrupt = strategy.rs_interrupt_after_quorum;

		let span_name = if interrupt {
			format!("RPC {} to {} of {}", endpoint.path(), quorum, target_count)
		} else {
			format!(
				"RPC {} to {} (quorum {})",
				endpoint.path(),
				target_count,
				quorum
			)
		};

		let tracer = opentelemetry::global::tracer("garage");
		let mut span = tracer.start(span_name);

		let from_attr = KeyValue::new("from", format!("{:?}", self.0.our_node_id));
		span.set_attribute(from_attr);
		let to_attr = KeyValue::new("to", format!("{:?}", to));
		span.set_attribute(to_attr);
		let quorum_attr = KeyValue::new("quorum", quorum as i64);
		span.set_attribute(quorum_attr);
		let interrupt_attr = KeyValue::new("interrupt_after_quorum", interrupt.to_string());
		span.set_attribute(interrupt_attr);

		let ctx = Context::current_with_span(span);
		self.try_call_many_internal(endpoint, to, msg, strategy, quorum)
			.with_context(ctx)
			.await
	}
	/// Send `msg` to the nodes in `to` and collect successful responses until
	/// `quorum` of them have been obtained.
	///
	/// Two modes, selected by `strategy.rs_interrupt_after_quorum`:
	///
	/// - `true`: only as many requests as are needed to reach the quorum are
	///   in flight at any time. Nodes are tried in the order given by
	///   `request_order`, and a new request is started each time one fails.
	///   If there are not enough nodes left to reach the quorum, we stop early.
	/// - `false`: all requests are started at once. As soon as the quorum is
	///   reached the function returns, and the requests still pending are
	///   handed to the background runner so that they keep running.
	///
	/// # Errors
	/// Returns `Error::Quorum` (with the stringified individual errors) when
	/// fewer than `quorum` requests succeeded, or a conversion error if `msg`
	/// cannot be turned into a request.
	///
	/// # Panics
	/// In the interrupting mode, panics if a spawned request task itself
	/// panicked or was cancelled.
	async fn try_call_many_internal<M, N, H, S>(
		&self,
		endpoint: &Arc<Endpoint<M, H>>,
		to: &[Uuid],
		msg: N,
		strategy: RequestStrategy,
		quorum: usize,
	) -> Result<Vec<S>, Error>
	where
		M: Rpc<Response = Result<S, Error>> + 'static,
		N: IntoReq<M>,
		H: StreamingEndpointHandler<M> + 'static,
		S: Send + 'static,
	{
		let msg = msg.into_req().map_err(netapp::error::Error::from)?;

		// One future per target node. Each future owns clones of everything
		// it needs, so that it can be spawned as a 'static task.
		// Futures do nothing until they are polled or spawned.
		let requests = to.iter().cloned().map(|to| {
			let self2 = self.clone();
			let msg = msg.clone();
			let endpoint2 = endpoint.clone();
			(to, async move {
				self2.call(&endpoint2, to, msg, strategy).await
			})
		});

		let mut successes = vec![];
		let mut errors = vec![];

		if strategy.rs_interrupt_after_quorum {
			// Reorder the requests following the preferred node order.
			// Each entry of the order consumes one not-yet-assigned request
			// for that node, so a node listed several times in `to` simply
			// gets several requests instead of leaving a hole in the ordering
			// (which previously caused a panic).
			let mut pending = requests.collect::<Vec<_>>();
			let mut ordered = Vec::with_capacity(pending.len());
			for node in self.request_order(to) {
				if let Some(i) = pending.iter().position(|(id, _)| *id == node) {
					ordered.push(pending.swap_remove(i));
				}
			}
			let mut requests = ordered.into_iter();

			// Requests currently in flight (none yet, they are added below).
			let mut resp_stream = FuturesUnordered::new();

			'request_loop: while successes.len() < quorum {
				// Start new requests as long as the ones in flight would not
				// be enough to reach the quorum even if they all succeeded.
				while successes.len() + resp_stream.len() < quorum {
					if let Some((req_to, fut)) = requests.next() {
						let tracer = opentelemetry::global::tracer("garage");
						let span = tracer.start(format!("RPC to {:?}", req_to));
						resp_stream.push(tokio::spawn(
							fut.with_context(Context::current_with_span(span)),
						));
					} else {
						// No node left to try: the quorum cannot be reached.
						break 'request_loop;
					}
				}
				// Holds because successes.len() < quorum
				// and successes.len() + resp_stream.len() >= quorum.
				assert!(!resp_stream.is_empty());

				// Wait for one request to terminate.
				match resp_stream.next().await.unwrap().unwrap() {
					Ok(msg) => {
						successes.push(msg);
					}
					Err(e) => {
						errors.push(e);
					}
				}
			}
		} else {
			// Start all the requests in parallel and return as soon as the
			// quorum is reached.
			let mut resp_stream = requests
				.map(|(_, fut)| fut)
				.collect::<FuturesUnordered<_>>();

			while let Some(resp) = resp_stream.next().await {
				match resp {
					Ok(msg) => {
						successes.push(msg);
						if successes.len() >= quorum {
							break;
						}
					}
					Err(e) => {
						errors.push(e);
					}
				}
			}

			if !resp_stream.is_empty() {
				// Let the remaining requests continue in the background;
				// their results are discarded.
				let wait_finished_fut = tokio::spawn(async move {
					resp_stream.collect::<Vec<Result<_, _>>>().await;
				});
				self.0.background.spawn(wait_finished_fut.map(|_| Ok(())));
			}
		}

		if successes.len() >= quorum {
			Ok(successes)
		} else {
			let errors = errors.iter().map(|e| e.to_string()).collect::<Vec<_>>();
			Err(Error::Quorum(quorum, successes.len(), to.len(), errors))
		}
	}
	/// Return `nodes` sorted by decreasing preference for sending requests.
	///
	/// The ordering criteria are, in this order:
	/// 1. the local node comes first,
	/// 2. then nodes whose zone (as found in the ring layout) is the same as
	///    the local node's zone; a node without a role has the empty zone,
	/// 3. then nodes by increasing average ping, as reported in the peer list;
	///    nodes with no known ping are treated as having a 10 second ping.
	///
	/// The sort is stable, so nodes that compare equal keep their input order.
	pub fn request_order(&self, nodes: &[Uuid]) -> Vec<Uuid> {
		let peer_list = self.0.fullmesh.get_peer_list();
		let ring: Arc<Ring> = self.0.ring.borrow().clone();

		let our_zone: &str = match ring.layout.node_role(&self.0.our_node_id) {
			Some(role) => &role.zone,
			None => "",
		};

		// Build (sort key, node) pairs; `false` sorts before `true`, so the
		// booleans are phrased as "is different from us".
		let mut keyed = Vec::with_capacity(nodes.len());
		for node in nodes {
			let zone: &str = match ring.layout.node_role(node) {
				Some(role) => &role.zone,
				None => "",
			};
			let avg_ping = peer_list
				.iter()
				.find(|peer| peer.id.as_ref() == node.as_slice())
				.and_then(|peer| peer.avg_ping)
				.unwrap_or_else(|| Duration::from_secs(10));
			let key = (*node != self.0.our_node_id, zone != our_zone, avg_ping);
			keyed.push((key, *node));
		}

		keyed.sort_by_key(|(key, _)| *key);

		keyed.into_iter().map(|(_, node)| node).collect()
	}
}