Documentation
use futures::StreamExt;
use tracing::{
	info_span,
	Span,
	Instrument
};
use tokio::sync::watch;
use tokio::task::JoinError;

use crate::control::{
	KeyArg,
	RouterReq
};
use crate::peer::Address;
use crate::util::*;

use super::{
	InterfaceAddress,
	ChaCha20Server
};
use super::interface::{
	Interface,
	InterfaceLauncher,
	Terminated
};

/// Owns every interface and moves packets between the interfaces and the
/// rest of the program.
///
/// Built together with its channels by [`Router::setup`] and driven by
/// [`Router::run`].
pub struct Router {
	/// Cloned into every interface's `ServerChannel`; packets received by
	/// interfaces come out of `RouterSetup::inbound_packet_receiver`.
	inbound_pkt_tx: Sender<PktFrom>,
	/// Publishes `true` while at least one interface is registered (active)
	/// and `false` while none are (dormant).
	monitor_tx: watch::Sender<bool>,
	/// Packets to send; each is handed to the first interface that can send it.
	outbound_pkt_rx: Receiver<PktTo>,
	/// Control requests; the run loop exits when this channel yields `None`.
	query_rx: Receiver<RouterReq>,
	/// Key passed to every ChaCha20 server bound through `bind_interface`.
	encryption_key: [u8; 32],
	/// Registered interfaces and their tasks; `next()` yields an interface
	/// together with its task's termination result.
	interfaces: Registry<Terminated, Interface>,
	/// Span named `router` (child of the span given to `setup`), used as the
	/// parent of this router's log events.
	span: Span
}

/// Receiving end of the router's activity monitor: `true` while at least one
/// interface is registered, `false` while the router is dormant (initially
/// `false`).
pub type RouterMonitor = watch::Receiver<bool>;

/// Result of [`Router::setup`]: the router itself plus the caller-facing
/// ends of the channels it uses.
pub struct RouterSetup {
	/// The router; drive it with [`Router::run`].
	pub router: Router,
	/// Receives packets that arrived on any interface.
	pub inbound_packet_receiver: Receiver<PktFrom>,
	/// Watches whether the router is active (`true`) or dormant (`false`).
	pub router_monitor: RouterMonitor,
	/// Sends packets for the router to dispatch to an interface. When the
	/// router sees this channel closed (presumably once every sender is
	/// dropped), its run loop exits.
	pub outbound_packet_sender: Sender<PktTo>,
	/// Sends control requests to the router. When the router sees this
	/// channel closed (presumably once every sender is dropped), it shuts down.
	pub router_query_sender: Sender<RouterReq>
}

impl Router {

	/// Create a router together with the channels used to talk to it.
	///
	/// * inbound packets: interfaces → `RouterSetup::inbound_packet_receiver`
	/// * outbound packets: `RouterSetup::outbound_packet_sender` → router → interface
	/// * queries: `RouterSetup::router_query_sender` → router
	/// * monitor: router → `RouterSetup::router_monitor` (initially `false`)
	///
	/// The router's `router` span is created as a child of `parent_span`.
	pub fn setup(encryption_key: KeyArg, parent_span: &Span) -> RouterSetup {
		// Create inbound packet channel
		let (inbound_packet_sender, inbound_packet_receiver)
			= new_channel::<PktFrom>();

		// Create outbound packet channel
		let (outbound_packet_sender, outbound_packet_receiver)
			= new_channel::<PktTo>();

		// Create router control channel
		let (router_query_sender, query_receiver)
			= new_channel::<RouterReq>();

		// Create router dormant monitor
		let (monitor_tx, router_monitor)
			= watch::channel::<bool>(false);

		let interfaces = Registry::new();

		// The router manages all the interfaces and interface tasks
		let router = Self {
			inbound_pkt_tx: inbound_packet_sender,
			monitor_tx,
			outbound_pkt_rx: outbound_packet_receiver,
			query_rx: query_receiver,
			encryption_key: encryption_key.get(),
			interfaces,
			span: parent_span.in_scope(|| info_span!("router"))
		};

		RouterSetup {
			router,
			inbound_packet_receiver,
			router_monitor,
			outbound_packet_sender,
			router_query_sender
		}
	}

	/// Hand `pkt_to` to the first interface whose `can_send(pkt_to.0)`
	/// returns `true`.
	///
	/// If no interface can send it, the packet is dropped without notice.
	fn send_packet(&self, pkt_to: PktTo) {
		let target = self.interfaces
			.iter()
			.find(|&iface| iface.can_send(pkt_to.0));

		if let Some(iface) = target {
			iface.send(pkt_to);
		}
	}

	/// Ask every registered interface to shut down.
	///
	/// This only signals the interfaces; callers wait for the tasks to finish
	/// by draining `self.interfaces.next()` (as `run` does).
	fn terminate_all(&mut self) {
		for interface in self.interfaces.iter_mut() {
			trace!(:self.span, %interface, "terminating");
			interface.shutdown();
		}
	}

	/// Shut down every interface whose address equals `addr`.
	///
	/// Returns `true` if at least one interface matched and `false` otherwise;
	/// having nothing to remove is not treated as an error.
	fn remove(&mut self, addr: Address) -> bool {
		let mut any_removed = false;
		for interface in self.interfaces
			.iter_mut()
			.filter(|interface| interface.get_address() == addr)
		{
			trace!(:self.span, %interface, "terminating");
			interface.shutdown();
			any_removed = true;
		}
		any_removed
	}

	/// Bind a ChaCha20 server on `address` and register it as a new interface.
	///
	/// If an interface with the same address is already registered, that
	/// interface gets `address` via `replace_address`; no new server is
	/// bound, and `true` is returned.
	///
	/// Returns `false` only when `ChaCha20Server::bind` does not yield a server.
	async fn bind_interface(&mut self, address: InterfaceAddress) -> bool {
		// No `Span::enter` guard in this async fn: a guard held across
		// `.await` keeps the span entered on the worker thread while this
		// future is suspended, misattributing other tasks' events. Events get
		// an explicit parent and the bind future is instrumented instead.
		let already_bound = self.interfaces
			.iter_mut()
			.find(|iface| iface.get_address() == address.get_address());

		if let Some(interface) = already_bound {
			trace!(:self.span, %interface, %address, "already bound, replacing routes");
			interface.replace_address(address);
			return true;
		}

		let Some(server) = ChaCha20Server::bind(
			&self.encryption_key,
			address.clone(),
			&self.span
		).instrument(self.span.clone()).await else { return false };

		// create new outbound packet channel
		let (outbound_packet_sender, outbound_packet_receiver)
			= new_channel::<PktTo>();

		let server_channel = ServerChannel::combine(
			self.inbound_pkt_tx.clone(),
			outbound_packet_receiver
		);

		let interface = Interface::new(address, outbound_packet_sender);

		trace!(@self, %interface, "spawned chacha20 task");
		self.interfaces.launch(interface, server.launch(server_channel));

		true
	}

	/// Register an interface for an already-prepared server and spawn its task.
	///
	/// The interface gets its own outbound packet channel; its inbound side
	/// shares the router's inbound packet sender.
	fn launch_interface(&mut self, launcher: InterfaceLauncher) {
		let address = launcher.get_address();

		// create new outbound packet channel
		let (outbound_packet_sender, outbound_packet_receiver)
			= new_channel::<PktTo>();

		let interface = Interface::new(address, outbound_packet_sender);

		let channel = ServerChannel::combine(
			self.inbound_pkt_tx.clone(),
			outbound_packet_receiver
		);

		match launcher {
			InterfaceLauncher::Chacha20Udp(server) => {
				let address = server.get_address();
				self.interfaces.launch(interface, server.launch(channel));
				trace!(@self, %address, "spawned chacha20 task");
			},
			InterfaceLauncher::Dummy(server) => {
				let address = server.get_address();
				self.interfaces.launch(interface, server.run(channel));
				trace!(@self, %address, "spawned dummy server task");
			}
		}
	}

	/// Log why an interface's task ended: `info` for a clean shutdown,
	/// `error` for a crash, a reported panic, or a `JoinError`.
	fn log_termination(
		&self,
		interface: Interface,
		reason: Result<Terminated, JoinError>)
	{
		match reason {
			Ok(Terminated::Shutdown) =>
				info!(@self, %interface, "shut down"),
			Ok(Terminated::Crashed(error_msg)) =>
				error!(@self, %interface, %error_msg, "crashed"),
			Ok(Terminated::Panic) =>
				error!(@self, %interface, "task panicked"),
			Err(error) =>
				error!(@self, %interface, %error, "crashed"),
		}
	}

	/// Main loop of the router.
	///
	/// Serves queries, dispatches outbound packets and logs interface
	/// terminations until the query channel or the outbound packet channel
	/// yields `None`. Every switch between having no interfaces (dormant) and
	/// having some (active) is published on the monitor channel.
	///
	/// On exit, queued outbound packets are flushed, every interface is told
	/// to shut down, and the router waits for all interface tasks to finish.
	pub async fn run(mut self) {
		info!(@self, "started");
		let mut is_active = false;

		loop {
			// Keep `is_active == !self.interfaces.is_empty()`, notifying the
			// monitor on every transition.
			if is_active == self.interfaces.is_empty() {
				if self.interfaces.is_empty() {
					debug!(@self, "entering dormant state");
					is_active = false;
				}
				else {
					debug!(@self, "entering active state");
					is_active = true;
				}
				// Best effort: this only fails when no monitor receiver is left.
				let _ = self.monitor_tx.send(is_active);
			}
			tokio::select! {
				// Query received
				q = self.query_rx.recv() => match q {
					Some(request) => self.handle_request(request)
						.in_current_span()
						.await,
					None => {
						info!(@self, "received shutdown signal");
						break
					}
				},
				// Interface task terminated; only polled while interfaces exist
				r = self.interfaces.next(), if is_active => {
					if let Some((interface, reason)) = r {
						self.log_termination(interface, reason);
					}
				},
				// Outbound Packet received
				p = self.outbound_pkt_rx.recv() => match p {
					Some(pkt_to) => self.send_packet(pkt_to),
					// Sender hung up, terminate
					None => break
				}
			}
		}

		// flush out outbound packets
		while let Ok(pkt_to) = self.outbound_pkt_rx.try_recv() {
			self.send_packet(pkt_to);
		}

		// signal all interfaces to shut down
		self.terminate_all();

		// wait for all interface tasks to shut down cleanly
		while let Some((interface, reason)) = self.interfaces.next().await {
			self.log_termination(interface, reason);
		}

		info!(@self, "finished");
	}

	/// Serve a single control request and answer it through its replier.
	async fn handle_request(&mut self, req: RouterReq) {
		match req {
			RouterReq::GetInterfaces(req) => {
				let addrs = self.interfaces
					.iter()
					.map(Interface::get_interface_address)
					.collect::<Vec<_>>();
				req.reply(addrs);
			},
			RouterReq::CanSend(req) => {
				let can_send = self.interfaces
					.iter()
					.any(|iface| iface.can_send(req.0));
				req.reply(can_send);
			},
			RouterReq::BindInterface((address, replier)) => {
				let worked = self.bind_interface(address)
					.in_current_span()
					.await;
				replier.reply(worked);
			},
			RouterReq::LaunchInterface((launcher, replier)) => {
				self.launch_interface(launcher);
				replier.reply(true);
			},
			RouterReq::RemoveInterface((addr, replier)) => {
				replier.reply(self.remove(addr));
			}
		}
	}
}

/*

Listener
Socket
Interface
Proxy

*/