//! Documentation: UDP server interface whose datagrams are encrypted with ChaCha20-Poly1305.
use std::net::SocketAddr;
//use tokio::net;
use tracing::{
	Instrument,
	trace_span,
	info_span,
	Span
};
use chacha20poly1305::{
	ChaCha20Poly1305,
	KeyInit
};
use chacha20poly1305::aead::AeadInPlace;
use rand::Rng;

use crate::peer::Address;
use crate::message;
use super::InterfaceAddress;
use super::interface::Terminated;
use crate::util::*;

/// Size of the receive buffer in bytes (64 KiB).
const BUFFER_SIZE: usize = 1 << 16;

/// Length of the random nonce that prefixes every datagram.
const NONCE_SIZE: usize = 12;
/// Length of the Poly1305 authentication tag stored right after the nonce.
const TAG_SIZE: usize = 16;
/// Byte offset at which the ciphertext begins (after nonce and tag).
const CIPHERTEXT_OFFSET: usize = TAG_SIZE + NONCE_SIZE;

/// UDP server whose datagrams are sealed with ChaCha20-Poly1305.
pub struct Server {
	/// Interface address the server was bound with.
	addr: InterfaceAddress,
	/// AEAD instance, keyed once at bind time.
	cipher: ChaCha20Poly1305,
	/// Bound UDP socket, used for both sending and receiving.
	socket: tokio::net::UdpSocket,
	/// Tracing span for this server; per-operation spans are created inside it.
	span: Span,
}

impl Server {
	/*
	 */
	pub fn get_address(&self) -> InterfaceAddress {
		self.addr.clone()
	}

	pub async fn bind(
		key: &[u8; 32],
		addr: InterfaceAddress,
		parent_span: &Span)
		-> Option<Self>
	{
		let socket_addr = addr.get_socket_address()?;
		let span = parent_span.in_scope(|| info_span!("chacha20", %addr));
		let bind_span = info_span!("chacha_bind", %addr);
		let _g = bind_span.enter();

		let cipher = ChaCha20Poly1305::new(key.into());
		let socket = tokio::net::UdpSocket::bind(socket_addr)
			.in_current_span()
			.await
			.map_err(|e| error!("Unable to bind to address: {}", e))
			.ok()?;

		info!("started");
		Some(Self{addr, cipher, socket, span})
	}

	pub async fn launch(self, channel: ServerChannel) -> Terminated {

		info!(@self, "listening");
		let (tx, mut rx): (Sender<PktFrom>, Receiver<PktTo>)
			= channel.split();

		loop { tokio::select! {
			biased;
			pkt_to = rx.recv() => match pkt_to {
				Some((addr, pkt)) => {
					let addr = self.unwrap_socket_addr(addr);
					self.send_message(&pkt, &addr).await;
					continue
				},
				None => {
					info!(@self,"received termination signal, shutting down!");
					info!(@self, "stopped!");
					return Terminated::Shutdown
				}
			},
			_ = self.proxy_loop(&tx) => {
				info!(@self, "crashed!");
				return Terminated::Crashed("proxy_loop".to_string())
			}
		}}
	}

	fn unwrap_socket_addr(&self, address: Address) -> SocketAddr {
		match address {
			Address::V4UdpChaCha20(addr) => SocketAddr::V4(addr),
			Address::V6UdpChaCha20(addr) => SocketAddr::V6(addr),
			_ => {
				error!(@self, %address, "invalid outbound address, panicking!");
				panic!("ChaCha20 server received invalid outbound address");
			}
		}
	}

	fn wrap_socket_addr(socket_addr: SocketAddr) -> Address {
		match socket_addr {
			SocketAddr::V4(addr) => Address::V4UdpChaCha20(addr),
			SocketAddr::V6(addr) => Address::V6UdpChaCha20(addr)
		}
	}

	fn encrypt(&self, pkt: &message::Packet) -> Option<Vec<u8>> {

		let _g = self.span.in_scope(|| trace_span!("encrypt", ?pkt)).entered();

		let mut buf = Vec::<u8>::with_capacity(128);
		/* Packet layout:
		 *	- [ 0..12]	NONCE
		 *	- [12..28]	TAG
		 *	- [28..]	CIPHERTEXT
		 */
		// generate random bytes for the nonce
		buf.extend_from_slice(&rand::rng().random::<[u8; NONCE_SIZE]>());
		// reserve space for the tag
		buf.resize(NONCE_SIZE + TAG_SIZE, 0);

		// write the plaintext into the buffer
		if let Err(e) = bincode_serialize_into(&mut buf, pkt) {
			warn!("Failed to serialize packet: {}!", e);
			return None;
		}

		// check length, UDP packet maximum is u16::MAX so that kinda works out
		if buf.len() > BUFFER_SIZE {
			warn!("Packet is too large to serialize!");
			return None;
		}

		// need to split because nonce is read while ciphertext is written…
		let (nonce, rest_buf) = buf.split_at_mut(NONCE_SIZE);
		let (tag_buf, data_buf) = rest_buf.split_at_mut(TAG_SIZE);
		let nonce = nonce as &[_];

		let tag = self.cipher
			.encrypt_in_place_detached(nonce.into(), &[], data_buf)
			.map_err(|e| warn!("Failed to encrypt packet: {}!", e))
			.ok()?;
		tag_buf.copy_from_slice(&tag);

		#[cfg(debug_assertions)]
		trace!(
			nonce = ?fmt_bytes(nonce),
			tag = ?fmt_bytes(tag_buf),
			ciphertext = ?fmt_bytes(data_buf)
		);

		Some(buf)
	}

	fn decrypt(&self, buf: &mut [u8; BUFFER_SIZE], len: usize)
		-> Option<message::Packet>
	{
		let _g = self.span.in_scope(|| trace_span!("decrypt")).entered();

		if len <= CIPHERTEXT_OFFSET {return None;}
		let (nonce, rest) = buf.split_at_mut(NONCE_SIZE);
		let (tag, data) = rest.split_at_mut(TAG_SIZE);
		let nonce = nonce as &[_];
		let tag = tag as &[_];
		let data_len = len - CIPHERTEXT_OFFSET;
		let data = &mut data[..data_len];

		#[cfg(debug_assertions)]
		trace!(
			length = len,
			tag = ?fmt_bytes(tag),
			nonce = ?fmt_bytes(nonce),
			ciphertext = ?fmt_bytes(data)
		);

		match self.cipher.decrypt_in_place_detached(
			nonce.into(),
			&[], // Associated data: unused
			data,
			tag.into())
		{
			Ok(()) => match bincode_deserialize::<message::Packet>(data) {
				Ok(decoded) => Some(decoded),
				Err(error) => {
					error!("Unable to decode datagram: {}", error);
					None
				}
			},
			Err(e) => {
				warn!("Failed to decrypt a packet: {}!", e);
				None
			}
		}
	}

	async fn proxy_loop(&self, sender: &Sender<PktFrom>) -> ! {
		let mut buf: [u8; BUFFER_SIZE] = [0; BUFFER_SIZE];
		loop {
			match self.socket.recv_from(&mut buf).await {
				Ok((amt, src)) => {
					self.decrypt_and_push(sender, src, amt, &mut buf);
				},
				// TODO: Handle errors instead of looping endlessly
				Err(e) => warn!(@self, "UDP error: {}", e)
			}
		}
	}

	fn decrypt_and_push(
		&self,
		sender: &Sender::<PktFrom>,
		src: SocketAddr,
		amt: usize,
		buf: &mut [u8; BUFFER_SIZE])
	{
		match self.decrypt(buf, amt) {
			Some(pkt) => {
				trace!(@self, %src, amt, "datagram received");
				debug!(@self, ?pkt, "packet received");
				let _ = sender.send((Self::wrap_socket_addr(src), pkt));
			},
			None => warn!(@self, %src, amt, "Failed to parse datagram")
		}
	}

	async fn send_message(
		&self,
		pkt: &message::Packet,
		to: &SocketAddr)
		-> Option<usize>
	{
		debug!(@self, ?pkt, "sending packet");
		match self.socket.send_to(&self.encrypt(pkt)?, to).await {
			Ok(amt) => {
				trace!(@self, ?to, amt, "datagram sent");
				Some(amt)
			},
			Err(e) => {
				error!(@self, "Failed to send message: {}", e);
				None
			}
		}
	}
}