use super::atom_info::{AtomIdent, AtomInfo};
use super::read::{nested_atom, skip_unneeded, AtomReader};
use crate::config::ParsingMode;
use crate::error::{LoftyError, Result};
use crate::macros::{decode_err, err, try_vec};
use crate::properties::FileProperties;
use crate::util::math::RoundedDivision;
use std::io::{Cursor, Read, Seek, SeekFrom};
use std::time::Duration;
use byteorder::{BigEndian, ReadBytesExt};
/// The audio codec of an MP4 file's audio track
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum Mp4Codec {
	/// The codec could not be determined
	#[default]
	Unknown,
	/// Advanced Audio Coding
	AAC,
	/// Apple Lossless Audio Codec
	ALAC,
	/// MPEG Audio Layer III
	MP3,
	/// Free Lossless Audio Codec
	FLAC,
}
/// MPEG-4 audio object types (ISO/IEC 14496-3)
///
/// Each discriminant is the numeric identifier used in an `AudioSpecificConfig`.
#[derive(Default, Debug, Copy, Clone, PartialEq, Eq)]
#[non_exhaustive]
pub enum AudioObjectType {
	/// Null object
	#[default]
	NULL = 0,
	/// AAC Main
	AacMain = 1,
	/// AAC Low Complexity
	AacLowComplexity = 2,
	/// AAC Scalable Sample Rate
	AacScalableSampleRate = 3,
	/// AAC Long Term Prediction
	AacLongTermPrediction = 4,
	/// Spectral Band Replication
	SpectralBandReplication = 5,
	/// AAC Scalable
	AACScalable = 6,
	/// TwinVQ
	TwinVQ = 7,
	/// Code Excited Linear Prediction
	CodeExcitedLinearPrediction = 8,
	/// Harmonic Vector eXcitation Coding
	HarmonicVectorExcitationCoding = 9,
	/// Text-to-Speech Interface
	TextToSpeechtInterface = 12,
	/// Main Synthetic
	MainSynthetic = 13,
	/// Wavetable Synthesis
	WavetableSynthesis = 14,
	/// General MIDI
	GeneralMIDI = 15,
	/// Algorithmic Synthesis and Audio Effects
	AlgorithmicSynthesis = 16,
	/// Error Resilient AAC Low Complexity
	ErrorResilientAacLowComplexity = 17,
	/// Error Resilient AAC Long Term Prediction
	ErrorResilientAacLongTermPrediction = 19,
	/// Error Resilient AAC Scalable
	ErrorResilientAacScalable = 20,
	/// Error Resilient TwinVQ
	ErrorResilientAacTwinVQ = 21,
	/// Error Resilient Bit-Sliced Arithmetic Coding
	ErrorResilientAacBitSlicedArithmeticCoding = 22,
	/// Error Resilient AAC Low Delay
	ErrorResilientAacLowDelay = 23,
	/// Error Resilient Code Excited Linear Prediction
	ErrorResilientCodeExcitedLinearPrediction = 24,
	/// Error Resilient Harmonic Vector eXcitation Coding
	ErrorResilientHarmonicVectorExcitationCoding = 25,
	/// Error Resilient Harmonic and Individual Lines plus Noise
	ErrorResilientHarmonicIndividualLinesNoise = 26,
	/// Error Resilient Parametric
	ErrorResilientParametric = 27,
	/// SinuSoidal Coding
	SinuSoidalCoding = 28,
	/// Parametric Stereo
	ParametricStereo = 29,
	/// MPEG Surround
	MpegSurround = 30,
	/// MPEG Layer 1
	MpegLayer1 = 32,
	/// MPEG Layer 2
	MpegLayer2 = 33,
	/// MPEG Layer 3
	MpegLayer3 = 34,
	/// Direct Stream Transfer
	DirectStreamTransfer = 35,
	/// Audio Lossless Coding
	AudioLosslessCoding = 36,
	/// Scalable Lossless Coding
	ScalableLosslessCoding = 37,
	/// Scalable Lossless Coding (non-core)
	ScalableLosslessCodingNoneCore = 38,
	/// Error Resilient AAC Enhanced Low Delay
	ErrorResilientAacEnhancedLowDelay = 39,
	/// Symbolic Music Representation Simple
	SymbolicMusicRepresentationSimple = 40,
	/// Symbolic Music Representation Main
	SymbolicMusicRepresentationMain = 41,
	/// Unified Speech and Audio Coding
	UnifiedSpeechAudioCoding = 42,
	/// Spatial Audio Object Coding
	SpatialAudioObjectCoding = 43,
	/// Low Delay MPEG Surround
	LowDelayMpegSurround = 44,
	/// Spatial Audio Object Coding Dialogue Enhancement
	SpatialAudioObjectCodingDialogueEnhancement = 45,
	/// Audio Sync
	AudioSync = 46,
}
impl TryFrom<u8> for AudioObjectType {
	type Error = LoftyError;

	/// Maps an MPEG-4 audio object type number to its variant.
	///
	/// # Errors
	///
	/// `value` has no corresponding variant (the reserved values 10, 11 and 18, the escape
	/// value 31, and anything above 46).
	#[rustfmt::skip]
	fn try_from(value: u8) -> std::result::Result<Self, Self::Error> {
		match value {
			// `NULL` is a defined variant, so 0 must not be rejected as invalid
			0  => Ok(Self::NULL),
			1  => Ok(Self::AacMain),
			2  => Ok(Self::AacLowComplexity),
			3  => Ok(Self::AacScalableSampleRate),
			4  => Ok(Self::AacLongTermPrediction),
			5  => Ok(Self::SpectralBandReplication),
			6  => Ok(Self::AACScalable),
			7  => Ok(Self::TwinVQ),
			8  => Ok(Self::CodeExcitedLinearPrediction),
			9  => Ok(Self::HarmonicVectorExcitationCoding),
			12 => Ok(Self::TextToSpeechtInterface),
			13 => Ok(Self::MainSynthetic),
			14 => Ok(Self::WavetableSynthesis),
			15 => Ok(Self::GeneralMIDI),
			16 => Ok(Self::AlgorithmicSynthesis),
			17 => Ok(Self::ErrorResilientAacLowComplexity),
			19 => Ok(Self::ErrorResilientAacLongTermPrediction),
			20 => Ok(Self::ErrorResilientAacScalable),
			21 => Ok(Self::ErrorResilientAacTwinVQ),
			22 => Ok(Self::ErrorResilientAacBitSlicedArithmeticCoding),
			23 => Ok(Self::ErrorResilientAacLowDelay),
			24 => Ok(Self::ErrorResilientCodeExcitedLinearPrediction),
			25 => Ok(Self::ErrorResilientHarmonicVectorExcitationCoding),
			26 => Ok(Self::ErrorResilientHarmonicIndividualLinesNoise),
			27 => Ok(Self::ErrorResilientParametric),
			28 => Ok(Self::SinuSoidalCoding),
			29 => Ok(Self::ParametricStereo),
			30 => Ok(Self::MpegSurround),
			32 => Ok(Self::MpegLayer1),
			33 => Ok(Self::MpegLayer2),
			34 => Ok(Self::MpegLayer3),
			35 => Ok(Self::DirectStreamTransfer),
			36 => Ok(Self::AudioLosslessCoding),
			37 => Ok(Self::ScalableLosslessCoding),
			38 => Ok(Self::ScalableLosslessCodingNoneCore),
			39 => Ok(Self::ErrorResilientAacEnhancedLowDelay),
			40 => Ok(Self::SymbolicMusicRepresentationSimple),
			41 => Ok(Self::SymbolicMusicRepresentationMain),
			42 => Ok(Self::UnifiedSpeechAudioCoding),
			43 => Ok(Self::SpatialAudioObjectCoding),
			44 => Ok(Self::LowDelayMpegSurround),
			45 => Ok(Self::SpatialAudioObjectCodingDialogueEnhancement),
			46 => Ok(Self::AudioSync),
			_ => decode_err!(@BAIL Mp4, "Encountered an invalid audio object type"),
		}
	}
}
/// Audio properties of an MP4 file
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[non_exhaustive]
pub struct Mp4Properties {
	/// Codec of the audio track
	pub(crate) codec: Mp4Codec,
	/// MPEG-4 audio object type from an `AudioSpecificConfig`, if one was read
	pub(crate) extended_audio_object_type: Option<AudioObjectType>,
	/// Duration of the audio track
	pub(crate) duration: Duration,
	/// Overall bitrate, in kbps
	pub(crate) overall_bitrate: u32,
	/// Audio bitrate, in kbps
	pub(crate) audio_bitrate: u32,
	/// Sample rate, in Hz
	pub(crate) sample_rate: u32,
	/// Bits per sample, if known
	pub(crate) bit_depth: Option<u8>,
	/// Number of channels
	pub(crate) channels: u8,
	/// Whether a `drms` sample entry was encountered
	pub(crate) drm_protected: bool,
}
impl From<Mp4Properties> for FileProperties {
	fn from(input: Mp4Properties) -> Self {
		Self {
			duration: input.duration,
			overall_bitrate: Some(input.overall_bitrate),
			audio_bitrate: Some(input.audio_bitrate),
			sample_rate: Some(input.sample_rate),
			bit_depth: input.bit_depth,
			channels: Some(input.channels),
			channel_mask: None,
		}
	}
}
impl Mp4Properties {
	/// Duration of the audio track
	pub fn duration(&self) -> Duration {
		self.duration
	}

	/// Overall bitrate (kbps), estimated from the file size and the duration
	pub fn overall_bitrate(&self) -> u32 {
		self.overall_bitrate
	}

	/// Audio bitrate (kbps)
	pub fn audio_bitrate(&self) -> u32 {
		self.audio_bitrate
	}

	/// Sample rate (Hz)
	pub fn sample_rate(&self) -> u32 {
		self.sample_rate
	}

	/// Bits per sample, if known (only read for ALAC and FLAC streams)
	pub fn bit_depth(&self) -> Option<u8> {
		self.bit_depth
	}

	/// Channel count
	pub fn channels(&self) -> u8 {
		self.channels
	}

	/// Audio codec
	pub fn codec(&self) -> &Mp4Codec {
		&self.codec
	}

	/// MPEG-4 audio object type, if the stream carries an `AudioSpecificConfig`
	pub fn audio_object_type(&self) -> Option<AudioObjectType> {
		self.extended_audio_object_type
	}

	/// Whether the sample description contains a DRM protected (`drms`) sample entry
	pub fn is_drm_protected(&self) -> bool {
		self.drm_protected
	}
}
/// Locations of the track atoms needed to read audio properties.
struct TrakChildren {
	/// The media header (`mdhd`) atom
	mdhd: AtomInfo,
	/// The media information (`minf`) atom, if the track has one
	minf: Option<AtomInfo>,
}
fn get_trak_children<R>(reader: &mut AtomReader<R>, traks: &[AtomInfo]) -> Result<TrakChildren>
where
	R: Read + Seek,
{
	let mut audio_track = false;
	let mut mdhd = None;
	let mut minf = None;
	for mdia in traks {
		if audio_track {
			break;
		}
		reader.seek(SeekFrom::Start(mdia.start + 8))?;
		let mut read = 8;
		while read < mdia.len {
			let Some(atom) = reader.next()? else { break };
			read += atom.len;
			if let AtomIdent::Fourcc(fourcc) = atom.ident {
				match &fourcc {
					b"mdhd" => {
						skip_unneeded(reader, atom.extended, atom.len)?;
						mdhd = Some(atom)
					},
					b"hdlr" => {
						reader.seek(SeekFrom::Current(8))?;
						let mut handler_type = [0; 4];
						reader.read_exact(&mut handler_type)?;
						if &handler_type == b"soun" {
							audio_track = true
						}
						skip_unneeded(reader, atom.extended, atom.len - 12)?;
					},
					b"minf" => minf = Some(atom),
					_ => {
						skip_unneeded(reader, atom.extended, atom.len)?;
					},
				}
				continue;
			}
			skip_unneeded(reader, atom.extended, atom.len)?;
		}
	}
	if !audio_track {
		decode_err!(@BAIL Mp4, "File contains no audio tracks");
	}
	let Some(mdhd) = mdhd else {
		err!(BadAtom("Expected atom \"trak.mdia.mdhd\""));
	};
	Ok(TrakChildren { mdhd, minf })
}
/// The fields of a media header (`mdhd`) atom used for duration calculations.
struct Mdhd {
	/// Number of time units per second
	timescale: u32,
	/// Length of the media, in `timescale` units
	duration: u64,
}
/// Reads the timescale and duration from an `mdhd` atom.
///
/// `reader` must be positioned at the start of the atom's payload (its version byte). Version 1
/// stores the creation/modification times and the duration as 64-bit values; any other version
/// is read with 32-bit fields.
fn read_mdhd<R>(reader: &mut AtomReader<R>) -> Result<Mdhd>
where
	R: Read + Seek,
{
	let version = reader.read_u8()?;
	let _flags = reader.read_uint(3)?;

	let (timescale, duration) = match version {
		1 => {
			let _creation_time = reader.read_u64()?;
			let _modification_time = reader.read_u64()?;
			// Tuple fields are evaluated left to right: timescale, then duration
			(reader.read_u32()?, reader.read_u64()?)
		},
		_ => {
			let _creation_time = reader.read_u32()?;
			let _modification_time = reader.read_u32()?;
			(reader.read_u32()?, u64::from(reader.read_u32()?))
		},
	};

	Ok(Mdhd {
		timescale,
		duration,
	})
}
/// A single entry of a time-to-sample (`stts`) table.
#[derive(Debug)]
struct SttsEntry {
	/// Number of consecutive samples sharing `sample_duration`
	_sample_count: u32,
	/// Duration of each of those samples, in media timescale units
	sample_duration: u32,
}
/// Reads the entries of a time-to-sample (`stts`) atom.
///
/// `reader` must be positioned at the start of the atom's payload (its version/flags). Exactly
/// `entry_count` entries are read; any bytes following them are left unread.
///
/// # Errors
///
/// The payload ends before all advertised entries were read.
fn read_stts<R>(reader: &mut R) -> Result<Vec<SttsEntry>>
where
	R: Read,
{
	// `entry_count` comes straight from the file, so it is only trusted up to this limit when
	// sizing the initial allocation. A bogus count then ends in an EOF error rather than a
	// multi-gigabyte allocation.
	const MAX_PREALLOCATED_ENTRIES: u32 = 1024;

	let _version_and_flags = reader.read_uint::<BigEndian>(4)?;

	let entry_count = reader.read_u32::<BigEndian>()?;
	let mut entries = Vec::with_capacity(entry_count.min(MAX_PREALLOCATED_ENTRIES) as usize);

	for _ in 0..entry_count {
		let sample_count = reader.read_u32::<BigEndian>()?;
		let sample_duration = reader.read_u32::<BigEndian>()?;

		entries.push(SttsEntry {
			_sample_count: sample_count,
			sample_duration,
		});
	}

	Ok(entries)
}
/// Data gathered from a track's sample table (`minf.stbl`).
struct Minf {
	/// Payload of the `stsd` atom (everything after its 8-byte header)
	stsd_data: Vec<u8>,
	/// Entries of the `stts` atom, if one was present
	stts: Option<Vec<SttsEntry>>,
}
fn read_minf<R>(
	reader: &mut AtomReader<R>,
	len: u64,
	parse_mode: ParsingMode,
) -> Result<Option<Minf>>
where
	R: Read + Seek,
{
	let Some(stbl) = nested_atom(reader, len, b"stbl", parse_mode)? else {
		return Ok(None);
	};
	let mut stsd_data = None;
	let mut stts = None;
	let mut read = 8;
	while read < stbl.len {
		let Some(atom) = reader.next()? else { break };
		read += atom.len;
		if let AtomIdent::Fourcc(fourcc) = atom.ident {
			match &fourcc {
				b"stsd" => {
					let mut stsd = try_vec![0; (atom.len - 8) as usize];
					reader.read_exact(&mut stsd)?;
					stsd_data = Some(stsd);
				},
				b"stts" => stts = Some(read_stts(reader)?),
				_ => {
					skip_unneeded(reader, atom.extended, atom.len)?;
				},
			}
			continue;
		}
	}
	let Some(stsd_data) = stsd_data else {
		return Ok(None);
	};
	Ok(Some(Minf { stsd_data, stts }))
}
/// Walks the sample entries of an `stsd` atom and hands the first supported one to its parser.
///
/// `reader` must be positioned at the start of the `stsd` payload (its version/flags).
///
/// * A `drms` entry only marks the file as DRM protected.
/// * Unrecognized entries are logged.
///
/// Both are skipped and the walk continues. The walk ends after the first `mp4a`, `alac` or
/// `fLaC` entry.
///
/// # Errors
///
/// * Fewer entries than advertised can be read
/// * An entry has a non-fourcc identifier
/// * A codec parser fails
fn read_stsd<R>(reader: &mut AtomReader<R>, properties: &mut Mp4Properties) -> Result<()>
where
	R: Read + Seek,
{
	// Version (1) + flags (3)
	reader.seek(SeekFrom::Current(4))?;

	let entry_count = reader.read_u32()?;
	for _ in 0..entry_count {
		let Some(entry) = reader.next()? else {
			err!(BadAtom("Expected sample entry atom in `stsd` atom"))
		};

		let AtomIdent::Fourcc(ref fourcc) = entry.ident else {
			err!(BadAtom("Expected fourcc atom in `stsd` atom"))
		};

		match fourcc {
			b"mp4a" => return mp4a_properties(reader, properties),
			b"alac" => return alac_properties(reader, properties),
			b"fLaC" => return flac_properties(reader, properties),
			b"drms" => properties.drm_protected = true,
			_ => log::warn!(
				"Found unsupported sample entry: {:?}",
				fourcc.escape_ascii().to_string()
			),
		}

		// Only reached for entries that were not parsed above
		skip_unneeded(reader, entry.extended, entry.len)?;
	}

	Ok(())
}
pub(super) fn read_properties<R>(
	reader: &mut AtomReader<R>,
	traks: &[AtomInfo],
	file_length: u64,
	parse_mode: ParsingMode,
) -> Result<Mp4Properties>
where
	R: Read + Seek,
{
	let TrakChildren { mdhd, minf } = get_trak_children(reader, traks)?;
	reader.seek(SeekFrom::Start(mdhd.start + 8))?;
	let Mdhd {
		timescale,
		duration,
	} = read_mdhd(reader)?;
	let mut properties = Mp4Properties::default();
	if timescale > 0 {
		let duration_millis = (duration * 1000).div_round(u64::from(timescale));
		properties.duration = Duration::from_millis(duration_millis);
	}
	let Some(minf_info) = minf else {
		return Ok(properties);
	};
	reader.seek(SeekFrom::Start(minf_info.start + 8))?;
	let Some(Minf { stsd_data, stts }) = read_minf(reader, minf_info.len, parse_mode)? else {
		return Ok(properties);
	};
	let mut cursor = Cursor::new(&*stsd_data);
	let mut stsd_reader = AtomReader::new(&mut cursor, parse_mode)?;
	read_stsd(&mut stsd_reader, &mut properties)?;
	if duration > 0 {
		let mdat_len = mdat_length(reader)?;
		if let Some(stts) = stts {
			let stts_specifies_duration = !(stts.len() == 1 && stts[0].sample_duration == 1);
			if stts_specifies_duration {
				let audio_bitrate_bps = (((u128::from(mdat_len) * 8) * u128::from(timescale))
					/ u128::from(duration)) as u32;
				properties.audio_bitrate = audio_bitrate_bps / 1000;
			}
		}
		let duration_millis = properties.duration.as_millis();
		let overall_bitrate = u128::from(file_length * 8) / duration_millis;
		properties.overall_bitrate = overall_bitrate as u32;
		if properties.audio_bitrate == 0 {
			log::warn!("Estimating audio bitrate from 'mdat' size");
			properties.audio_bitrate =
				(u128::from(mdat_length(reader)? * 8) / duration_millis) as u32;
		}
	}
	Ok(properties)
}
/// Sample rates (in Hz) indexed by an MPEG-4 `samplingFrequencyIndex`.
///
/// Indices 13 and 14 are reserved and map to 0. Index 15 signals an explicitly coded rate and
/// has no entry.
pub(crate) const SAMPLE_RATES: [u32; 15] = [
	96000, // 0x0
	88200, // 0x1
	64000, // 0x2
	48000, // 0x3
	44100, // 0x4
	32000, // 0x5
	24000, // 0x6
	22050, // 0x7
	16000, // 0x8
	12000, // 0x9
	11025, // 0xA
	8000,  // 0xB
	7350,  // 0xC
	0,     // 0xD (reserved)
	0,     // 0xE (reserved)
];
fn mp4a_properties<R>(stsd: &mut AtomReader<R>, properties: &mut Mp4Properties) -> Result<()>
where
	R: Read + Seek,
{
	const ELEMENTARY_DESCRIPTOR_TAG: u8 = 0x03;
	const DECODER_CONFIG_TAG: u8 = 0x04;
	const DECODER_SPECIFIC_DESCRIPTOR_TAG: u8 = 0x05;
	properties.codec = Mp4Codec::AAC;
	stsd.seek(SeekFrom::Current(16))?;
	properties.channels = stsd.read_u16()? as u8;
	stsd.seek(SeekFrom::Current(4))?;
	properties.sample_rate = stsd.read_u32()?;
	stsd.seek(SeekFrom::Current(2))?;
	let Ok(Some(esds)) = stsd.next() else {
		return Ok(());
	};
	if esds.ident != AtomIdent::Fourcc(*b"esds") {
		return Ok(());
	}
	if stsd.read_u32()? != 0 {
		return Ok(());
	}
	let descriptor = Descriptor::read(stsd)?;
	if descriptor.tag == ELEMENTARY_DESCRIPTOR_TAG {
		stsd.seek(SeekFrom::Current(3))?;
		let descriptor = Descriptor::read(stsd)?;
		if descriptor.tag == DECODER_CONFIG_TAG {
			let codec = stsd.read_u8()?;
			properties.codec = match codec {
				0x40 | 0x41 | 0x66 | 0x67 | 0x68 => Mp4Codec::AAC,
				0x69 | 0x6B => Mp4Codec::MP3,
				_ => Mp4Codec::Unknown,
			};
			stsd.seek(SeekFrom::Current(8))?;
			let average_bitrate = stsd.read_u32()?;
			let descriptor = Descriptor::read(stsd)?;
			if descriptor.tag == DECODER_SPECIFIC_DESCRIPTOR_TAG {
				let byte_a = stsd.read_u8()?;
				let byte_b = stsd.read_u8()?;
				let mut object_type = byte_a >> 3;
				let mut frequency_index = ((byte_a & 0x07) << 1) | (byte_b >> 7);
				let mut channel_conf = (byte_b >> 3) & 0x0F;
				let mut extended_object_type = false;
				if object_type == 31 {
					extended_object_type = true;
					object_type = 32 + ((byte_a & 7) | (byte_b >> 5));
					frequency_index = (byte_b >> 1) & 0x0F;
				}
				properties.extended_audio_object_type =
					Some(AudioObjectType::try_from(object_type)?);
				match frequency_index {
					0x0F => {
						let sample_rate;
						let explicit_sample_rate = stsd.read_u24::<BigEndian>()?;
						if extended_object_type {
							sample_rate = explicit_sample_rate >> 1;
							channel_conf = ((explicit_sample_rate >> 4) & 0x0F) as u8;
						} else {
							sample_rate = explicit_sample_rate << 1;
							let byte_c = stsd.read_u8()?;
							channel_conf =
								((explicit_sample_rate & 0x80) as u8 | (byte_c >> 1)) & 0x0F;
						}
						if sample_rate > 0 {
							properties.sample_rate = sample_rate;
						}
					},
					i if i < SAMPLE_RATES.len() as u8 => {
						properties.sample_rate = SAMPLE_RATES[i as usize];
						if extended_object_type {
							let byte_c = stsd.read_u8()?;
							channel_conf = (byte_b & 1) | (byte_c & 0xE0);
						} else {
							channel_conf = (byte_b >> 3) & 0x0F;
						}
					},
					_ => {},
				}
				if channel_conf > 0 {
					properties.channels = channel_conf;
				}
				if object_type == 36 {
					let mut ident = [0; 5];
					stsd.read_exact(&mut ident)?;
					if &ident == b"\0ALS\0" {
						properties.sample_rate = stsd.read_u32()?;
						stsd.seek(SeekFrom::Current(4))?;
						properties.channels = stsd.read_u16()? as u8 + 1;
					}
				}
			}
			if average_bitrate > 0 || properties.duration.is_zero() {
				properties.audio_bitrate = average_bitrate / 1000;
			}
		}
	}
	Ok(())
}
/// Reads the properties of an `alac` (Apple Lossless) sample entry.
///
/// `stsd` must be positioned just after the sample entry's 8-byte atom header. The 28 bytes of
/// generic `AudioSampleEntry` fields are skipped. The values come from the nested `alac` atom
/// holding the ALAC magic cookie. If that atom is missing, only the codec is set.
///
/// # Errors
///
/// Reading or seeking in `stsd` fails after the `alac` atom was found.
fn alac_properties<R>(stsd: &mut AtomReader<R>, properties: &mut Mp4Properties) -> Result<()>
where
	R: Read + Seek,
{
	// The sample entry's fourcc already identifies the codec, even if the cookie is missing
	properties.codec = Mp4Codec::ALAC;

	// Skipping the AudioSampleEntry fields relative to the current entry, rather than assuming a
	// fixed absolute offset, so this also works when other entries or extra data are present
	// Reserved (6)
	// Data reference index (2)
	// Version (2)
	// Revision level (2)
	// Vendor (4)
	// Channel count (2)
	// Sample size (2)
	// Pre-defined (2)
	// Reserved (2)
	// Sample rate (4)
	stsd.seek(SeekFrom::Current(28))?;

	let Ok(Some(alac)) = stsd.next() else {
		return Ok(());
	};
	if alac.ident != AtomIdent::Fourcc(*b"alac") {
		return Ok(());
	}

	// Skipping 9 bytes
	// Version/flags (4)
	// Frame length (4)
	// Compatible version (1)
	stsd.seek(SeekFrom::Current(9))?;

	let sample_size = stsd.read_u8()?;
	properties.bit_depth = Some(sample_size);

	// Skipping 3 bytes
	// Rice history mult (1)
	// Rice initial history (1)
	// Rice parameter limit (1)
	stsd.seek(SeekFrom::Current(3))?;

	properties.channels = stsd.read_u8()?;

	// Skipping 6 bytes
	// Max run (2)
	// Max frame bytes (4)
	stsd.seek(SeekFrom::Current(6))?;

	properties.audio_bitrate = stsd.read_u32()? / 1000;
	properties.sample_rate = stsd.read_u32()?;

	Ok(())
}
/// Reads the properties of a `fLaC` (FLAC in MP4) sample entry.
///
/// `stsd` must be positioned just after the sample entry's 8-byte atom header. The generic
/// `AudioSampleEntry` fields are read first. They are then replaced by the values of the
/// STREAMINFO block in the nested `dfLa` atom, when that atom is present and large enough to
/// hold it.
///
/// # Errors
///
/// Reading or seeking in `stsd` fails, or the STREAMINFO block cannot be parsed.
fn flac_properties<R>(stsd: &mut AtomReader<R>, properties: &mut Mp4Properties) -> Result<()>
where
	R: Read + Seek,
{
	// `dfLa` header (8) + version/flags (4) + metadata block header (4) + STREAMINFO (34)
	const MIN_DFLA_LEN: u64 = 8 + 4 + 4 + 34;

	properties.codec = Mp4Codec::FLAC;

	// Skipping 16 bytes
	// Reserved (6)
	// Data reference index (2)
	// Version (2)
	// Revision level (2)
	// Vendor (4)
	stsd.seek(SeekFrom::Current(16))?;

	properties.channels = stsd.read_u16()? as u8;
	properties.bit_depth = Some(stsd.read_u16()? as u8);

	// Skipping 4 bytes
	// Pre-defined (2)
	// Reserved (2)
	stsd.seek(SeekFrom::Current(4))?;

	// Integer part of the 16.16 fixed point sample rate, followed by its fractional part
	properties.sample_rate = u32::from(stsd.read_u16()?);
	let _reserved = stsd.read_u16()?;

	let Some(dfla) = stsd.next()? else {
		return Ok(());
	};
	if dfla.ident != AtomIdent::Fourcc(*b"dfLa") {
		return Ok(());
	}

	// Version (1) + flags (3)
	stsd.seek(SeekFrom::Current(4))?;

	// Too small for a complete STREAMINFO block. Comparing the whole length (instead of
	// subtracting from it) also avoids an underflow on a corrupt, tiny atom length.
	if dfla.len < MIN_DFLA_LEN {
		return Ok(());
	}

	let stream_info_block = crate::flac::block::Block::read(stsd)?;
	let flac_properties =
		crate::flac::properties::read_properties(&mut &stream_info_block.content[..], 0, 0)?;

	properties.sample_rate = flac_properties.sample_rate;
	properties.bit_depth = Some(flac_properties.bit_depth);
	properties.channels = flac_properties.channels;

	Ok(())
}
/// Returns the length of the first top-level `mdat` atom's payload (its length minus the
/// 8-byte header).
///
/// The reader is rewound and top-level atoms are stepped over one by one. An atom that fails to
/// parse ends the search.
///
/// # Errors
///
/// No `mdat` atom was found, or skipping an atom failed.
fn mdat_length<R>(reader: &mut AtomReader<R>) -> Result<u64>
where
	R: Read + Seek,
{
	reader.rewind()?;

	while let Ok(Some(atom)) = reader.next() {
		let is_mdat = atom.ident == AtomIdent::Fourcc(*b"mdat");
		if !is_mdat {
			skip_unneeded(reader, atom.extended, atom.len)?;
			continue;
		}

		return Ok(atom.len - 8);
	}

	decode_err!(@BAIL Mp4, "Failed to find \"mdat\" atom");
}
/// Header of an MPEG-4 descriptor found inside an `esds` atom.
struct Descriptor {
	/// Descriptor tag (e.g. `0x03` for an elementary stream descriptor)
	tag: u8,
	/// Payload size in bytes, decoded from the expandable size field
	_size: u32,
}
impl Descriptor {
	/// Reads a descriptor's tag byte followed by its expandable size field.
	///
	/// Each size byte contributes its low 7 bits, most significant group first. A set high bit
	/// means another size byte follows. At most four size bytes are consumed.
	fn read<R: Read>(reader: &mut R) -> Result<Descriptor> {
		let tag = reader.read_u8()?;

		let mut size = 0_u32;
		let mut size_bytes = 0;
		loop {
			let byte = reader.read_u8()?;
			size_bytes += 1;
			size = (size << 7) | u32::from(byte & 0x7F);

			let continues = byte & 0x80 != 0;
			if !continues || size_bytes == 4 {
				break;
			}
		}

		Ok(Descriptor { tag, _size: size })
	}
}