use std::convert::{TryFrom, TryInto};
use std::sync::Arc;
use bytes::{BytesMut, BufMut};
use snafu::OptionExt;
use uuid::Uuid;
use edgedb_errors::{Error, ErrorKind};
use edgedb_errors::{ClientEncodingError, ProtocolError, DescriptorMismatch};
use crate::codec::{self, Codec, build_codec};
use crate::descriptors::Descriptor;
use crate::descriptors::TypePos;
use crate::errors;
use crate::features::ProtocolVersion;
use crate::value::Value;
use crate::model::range;
/// Encoding state handed to every argument encoder in this module.
///
/// Pairs the descriptor context of the query input with the buffer that
/// encoded bytes are appended to.
pub struct Encoder<'a> {
    /// Descriptors (and protocol version) describing the expected input.
    pub ctx: &'a DescriptorContext<'a>,
    /// Output buffer; encoders only ever append to it or patch bytes they
    /// wrote themselves.
    pub buf: &'a mut BytesMut,
}
/// A single value that can be passed as one query argument.
pub trait QueryArg: Send + Sync + Sized {
    /// Appends this argument to the encoder's buffer as one complete slot.
    ///
    /// The implementations in this file write either a 4-byte big-endian
    /// length followed by the payload, or `-1` for a missing value.
    fn encode_slot(&self, encoder: &mut Encoder)
        -> Result<(), Error>;
    /// Checks this argument against the descriptor stored at `pos` in `ctx`.
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos)
        -> Result<(), Error>;
    /// Converts this argument into a dynamically typed [`Value`].
    fn to_value(&self) -> Result<Value, Error>;
}
/// A scalar argument whose descriptor check depends only on its type.
///
/// Every `ScalarArg` is also a [`QueryArg`] through the blanket impl in this
/// file, which wraps [`ScalarArg::encode`] in a length prefix.
pub trait ScalarArg: Send + Sync + Sized {
    /// Appends only the payload bytes of this value; the caller is
    /// responsible for any length prefix.
    fn encode(&self, encoder: &mut Encoder)
        -> Result<(), Error>;
    /// Checks that the descriptor at `pos` is acceptable for this type.
    ///
    /// This is an associated function: it takes no `self`, so the check
    /// cannot depend on a particular value.
    fn check_descriptor(ctx: &DescriptorContext, pos: TypePos)
        -> Result<(), Error>;
    /// Converts this scalar into a dynamically typed [`Value`].
    fn to_value(&self) -> Result<Value, Error>;
}
/// A complete set of arguments for one query.
///
/// Implemented in this file for `()`, for [`Value`], and for tuples of
/// [`QueryArg`] items.
pub trait QueryArgs: Send + Sync {
    /// Validates the arguments against the input descriptors held by the
    /// encoder and appends their encoded form to its buffer.
    fn encode(&self, encoder: &mut Encoder)
        -> Result<(), Error>;
}
/// The type descriptors describing the input a query expects, together with
/// the protocol version in use.
pub struct DescriptorContext<'a> {
    // Protocol version; the encoders in this file branch on it
    // (`is_at_most(0, 11)` / `is_at_least(0, 12)`).
    #[allow(dead_code)]
    pub(crate) proto: &'a ProtocolVersion,
    // Position of the root input descriptor. The tuple encoders report
    // "no arguments expected by the server" when this is `None`.
    pub(crate) root_pos: Option<TypePos>,
    // Descriptor table that `TypePos` values index into.
    pub(crate) descriptors: &'a [Descriptor],
}
impl<'a> Encoder<'a> {
    /// Creates an encoder that validates against `ctx` and appends to `buf`.
    pub fn new(ctx: &'a DescriptorContext<'a>, buf: &'a mut BytesMut)
        -> Encoder<'a>
    {
        Encoder { ctx, buf }
    }
    /// Runs `f` and prefixes whatever it appended with its byte length.
    ///
    /// A 4-byte placeholder is written first, `f` appends the payload, and
    /// the placeholder is then overwritten with the payload length in
    /// big-endian order.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `f`. Returns a
    /// `ClientEncodingError` when the payload length does not fit in an
    /// `i32`. In both cases the bytes appended so far stay in the buffer.
    pub fn length_prefixed(&mut self,
        f: impl FnOnce(&mut Encoder) -> Result<(), Error>)
        -> Result<(), Error>
    {
        self.buf.reserve(4);
        let pos = self.buf.len();
        // Placeholder; patched below once the payload size is known.
        self.buf.put_u32(0);
        f(self)?;
        let len = self.buf.len() - pos - 4;
        // Bound by `i32`, like the slot encoder for scalars in this file:
        // `-1` is used as the "missing value" marker, so a length with the
        // top bit set must never be emitted.
        let len = i32::try_from(len)
            .map_err(|_| ClientEncodingError::with_message(
                "encoded value is too long"))?;
        self.buf[pos..pos + 4].copy_from_slice(&len.to_be_bytes());
        Ok(())
    }
}
impl DescriptorContext<'_> {
    /// Looks up the descriptor stored at `type_pos`.
    ///
    /// Returns a `ProtocolError` when the position is outside the
    /// descriptor table.
    pub fn get(&self, type_pos: TypePos)
        -> Result<&Descriptor, Error>
    {
        let index = type_pos.0 as usize;
        match self.descriptors.get(index) {
            Some(descriptor) => Ok(descriptor),
            None => Err(ProtocolError::with_message(
                "invalid type descriptor")),
        }
    }
    /// Builds the codec for the root input descriptor.
    ///
    /// A failure from the codec builder is wrapped in a `ProtocolError`.
    pub fn build_codec(&self) -> Result<Arc<dyn Codec>, Error> {
        match build_codec(self.root_pos, self.descriptors) {
            Ok(input_codec) => Ok(input_codec),
            Err(e) => Err(ProtocolError::with_source(e)
                .context("error decoding input codec")),
        }
    }
    /// Builds the `DescriptorMismatch` error reported when the server's
    /// descriptor does not match what the client is about to send.
    pub fn wrong_type(&self, descriptor: &Descriptor, expected: &str) -> Error
    {
        let message = format!(
            "\nEdgeDB returned unexpected type {:?}\nClient expected {}",
            descriptor, expected);
        DescriptorMismatch::with_message(message)
    }
    /// Builds the `DescriptorMismatch` error reported when the number of
    /// fields differs from the expected one.
    pub fn field_number(&self, expected: usize, unexpected: usize)
        -> Error
    {
        let message = format!("expected {expected} fields, got {unexpected}");
        DescriptorMismatch::with_message(message)
    }
}
/// A reference to a scalar argument behaves exactly like the scalar itself:
/// every method forwards to the implementation for `T`.
impl<T: ScalarArg> ScalarArg for &T {
    fn encode(&self, encoder: &mut Encoder)
        -> Result<(), Error>
    {
        let inner: &T = *self;
        <T as ScalarArg>::encode(inner, encoder)
    }
    fn check_descriptor(ctx: &DescriptorContext, pos: TypePos)
        -> Result<(), Error>
    {
        <T as ScalarArg>::check_descriptor(ctx, pos)
    }
    fn to_value(&self) -> Result<Value, Error> {
        let inner: &T = *self;
        <T as ScalarArg>::to_value(inner)
    }
}
/// The empty argument list.
impl QueryArgs for () {
    /// Succeeds only when the server expects no arguments.
    ///
    /// With a root descriptor present, the only accepted case is protocol
    /// 0.11 or older with an empty tuple whose id is `0xFF`; anything else
    /// is a `DescriptorMismatch`. For protocol 0.11 or older a zero `u32`
    /// is appended to the buffer; newer protocols append nothing.
    fn encode(&self, enc: &mut Encoder)
        -> Result<(), Error>
    {
        let legacy = enc.ctx.proto.is_at_most(0, 11);
        if let Some(root_pos) = enc.ctx.root_pos {
            let is_empty_tuple = legacy && matches!(
                enc.ctx.get(root_pos),
                Ok(Descriptor::Tuple(t))
                    if t.id == Uuid::from_u128(0xFF)
                    && t.element_types.is_empty()
            );
            if !is_empty_tuple {
                return Err(DescriptorMismatch::with_message(
                    "query arguments expected"));
            }
        }
        if legacy {
            enc.buf.reserve(4);
            enc.buf.put_u32(0);
        }
        Ok(())
    }
}
/// Dynamically typed values as query arguments.
impl QueryArg for Value {
    /// Encodes the value as one slot by delegating to the encoder of the
    /// wrapped type.
    ///
    /// `Nothing` is written as `-1`. Sets, objects, sparse objects, tuples
    /// and named tuples are rejected with a `ClientEncodingError`.
    fn encode_slot(&self, enc: &mut Encoder)
        -> Result<(), Error>
    {
        use Value::*;
        match self {
            Nothing => {
                enc.buf.reserve(4);
                enc.buf.put_i32(-1);
            }
            Uuid(v) => v.encode_slot(enc)?,
            Str(v) => v.encode_slot(enc)?,
            Bytes(v) => v.encode_slot(enc)?,
            Int16(v) => v.encode_slot(enc)?,
            Int32(v) => v.encode_slot(enc)?,
            Int64(v) => v.encode_slot(enc)?,
            Float32(v) => v.encode_slot(enc)?,
            Float64(v) => v.encode_slot(enc)?,
            BigInt(v) => v.encode_slot(enc)?,
            ConfigMemory(v) => v.encode_slot(enc)?,
            Decimal(v) => v.encode_slot(enc)?,
            Bool(v) => v.encode_slot(enc)?,
            Datetime(v) => v.encode_slot(enc)?,
            LocalDatetime(v) => v.encode_slot(enc)?,
            LocalDate(v) => v.encode_slot(enc)?,
            LocalTime(v) => v.encode_slot(enc)?,
            Duration(v) => v.encode_slot(enc)?,
            RelativeDuration(v) => v.encode_slot(enc)?,
            DateDuration(v) => v.encode_slot(enc)?,
            Json(v) => v.encode_slot(enc)?,
            Set(_) => return Err(ClientEncodingError::with_message(
                    "set cannot be query argument")),
            Object {..} => return Err(ClientEncodingError::with_message(
                    "object cannot be query argument")),
            SparseObject(_) => return Err(ClientEncodingError::with_message(
                    "sparse object cannot be query argument")),
            Tuple(_) => return Err(ClientEncodingError::with_message(
                    "tuple object cannot be query argument")),
            NamedTuple {..} => return Err(ClientEncodingError::with_message(
                    "named tuple object cannot be query argument")),
            Array(v) => v.encode_slot(enc)?,
            Enum(v) => v.encode_slot(enc)?,
            Range(v) => v.encode_slot(enc)?,
        }
        Ok(())
    }
    /// Checks that the descriptor at `pos` matches the variant held.
    ///
    /// A `Scalar` descriptor is first replaced by its base type. `Nothing`
    /// matches any descriptor. Arrays and ranges are validated, together
    /// with their elements or bounds, by the container implementations in
    /// this file. Everything else is a `DescriptorMismatch`.
    ///
    /// # Panics
    ///
    /// Panics if the base type of a `Scalar` descriptor is itself a
    /// `Scalar` descriptor.
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos)
        -> Result<(), Error>
    {
        use Descriptor::*;
        use Value::*;
        let mut desc = ctx.get(pos)?;
        if let Scalar(d) = desc {
            desc = ctx.get(d.base_type_pos)?;
        }
        match (desc, self) {
            (_, Nothing) => Ok(()),
            (Scalar(_), _) => {
                unreachable!("scalar dereference to a non-base type");
            }
            (BaseScalar(d), Uuid(_)) if d.id == codec::STD_UUID => Ok(()),
            (BaseScalar(d), Str(_)) if d.id == codec::STD_STR => Ok(()),
            (BaseScalar(d), Bytes(_)) if d.id == codec::STD_BYTES => Ok(()),
            (BaseScalar(d), Int16(_)) if d.id == codec::STD_INT16 => Ok(()),
            (BaseScalar(d), Int32(_)) if d.id == codec::STD_INT32 => Ok(()),
            (BaseScalar(d), Int64(_)) if d.id == codec::STD_INT64 => Ok(()),
            (BaseScalar(d), Float32(_)) if d.id == codec::STD_FLOAT32 => Ok(()),
            (BaseScalar(d), Float64(_)) if d.id == codec::STD_FLOAT64 => Ok(()),
            (BaseScalar(d), BigInt(_)) if d.id == codec::STD_BIGINT => Ok(()),
            (BaseScalar(d), ConfigMemory(_))
                if d.id == codec::CFG_MEMORY => Ok(()),
            (BaseScalar(d), Decimal(_))
                if d.id == codec::STD_DECIMAL => Ok(()),
            (BaseScalar(d), Bool(_)) if d.id == codec::STD_BOOL => Ok(()),
            (BaseScalar(d), Datetime(_))
                if d.id == codec::STD_DATETIME => Ok(()),
            (BaseScalar(d), LocalDatetime(_))
                if d.id == codec::CAL_LOCAL_DATETIME => Ok(()),
            (BaseScalar(d), LocalDate(_))
                if d.id == codec::CAL_LOCAL_DATE => Ok(()),
            (BaseScalar(d), LocalTime(_))
                if d.id == codec::CAL_LOCAL_TIME => Ok(()),
            (BaseScalar(d), Duration(_))
                if d.id == codec::STD_DURATION => Ok(()),
            (BaseScalar(d), RelativeDuration(_))
                if d.id == codec::CAL_RELATIVE_DURATION => Ok(()),
            (BaseScalar(d), DateDuration(_))
                if d.id == codec::CAL_DATE_DURATION => Ok(()),
            (BaseScalar(d), Json(_))
                if d.id == codec::STD_JSON => Ok(()),
            // `encode_slot` accepts arrays and ranges, so they have to be
            // accepted here as well. The container impls look the
            // descriptor up themselves, hence the original `pos`. The
            // variants are fully qualified because both glob imports above
            // bring an `Array` and a `Range` into scope.
            (_, Value::Array(items)) => items.check_descriptor(ctx, pos),
            (_, Value::Range(rng)) => rng.check_descriptor(ctx, pos),
            // NOTE(review): `Value::Enum` is encodable above but still
            // lands here; its descriptor check is not visible in this file.
            (desc, _) => Err(ctx.wrong_type(desc, self.kind())),
        }
    }
    /// Returns a clone of the value.
    fn to_value(&self) -> Result<Value, Error>
    {
        Ok(self.clone())
    }
}
/// A dynamically typed value used as the complete argument set.
impl QueryArgs for Value {
    /// Builds the codec for the root input descriptor and lets it encode
    /// this value. Codec failures are wrapped in a `ClientEncodingError`.
    fn encode(&self, enc: &mut Encoder)
        -> Result<(), Error>
    {
        let input_codec = enc.ctx.build_codec()?;
        let outcome = input_codec.encode(&mut enc.buf, self);
        outcome.map_err(ClientEncodingError::with_source)
    }
}
/// Every scalar argument is a query argument: its payload is wrapped in a
/// length prefix.
impl<T: ScalarArg> QueryArg for T {
    /// Writes a 4-byte placeholder, lets the scalar append its payload, and
    /// then patches the placeholder with the payload length (big-endian).
    ///
    /// Fails with a `ClientEncodingError` when the length does not fit in
    /// an `i32`.
    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
        enc.buf.reserve(4);
        let start = enc.buf.len();
        // Placeholder for the length, patched below.
        enc.buf.put_u32(0);
        ScalarArg::encode(self, enc)?;
        let payload_len = enc.buf.len() - start - 4;
        let prefix = i32::try_from(payload_len)
            .ok()
            .context(errors::ElementTooLong)
            .map_err(ClientEncodingError::with_source)?;
        enc.buf[start..start + 4].copy_from_slice(&prefix.to_be_bytes());
        Ok(())
    }
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos)
        -> Result<(), Error>
    {
        <T as ScalarArg>::check_descriptor(ctx, pos)
    }
    fn to_value(&self) -> Result<Value, Error> {
        <T as ScalarArg>::to_value(self)
    }
}
/// Optional scalar arguments.
impl<T: ScalarArg> QueryArg for Option<T> {
    /// `Some` encodes like the wrapped scalar; `None` is written as `-1`
    /// with no payload.
    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
        match self {
            Some(inner) => QueryArg::encode_slot(inner, enc),
            None => {
                enc.buf.reserve(4);
                enc.buf.put_i32(-1);
                Ok(())
            }
        }
    }
    /// The descriptor check depends only on `T`, not on whether a value is
    /// present.
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos)
        -> Result<(), Error>
    {
        <T as ScalarArg>::check_descriptor(ctx, pos)
    }
    /// `None` becomes `Value::Nothing`.
    fn to_value(&self) -> Result<Value, Error> {
        if let Some(inner) = self {
            ScalarArg::to_value(inner)
        } else {
            Ok(Value::Nothing)
        }
    }
}
/// Arrays of scalar arguments.
impl<T: ScalarArg> QueryArg for Vec<T> {
    /// Writes the array as one length-prefixed slot.
    ///
    /// An empty vector is three zero words. Otherwise the header is
    /// `1, 0, 0, <element count>, 1`, followed by every element with its
    /// own length prefix.
    // NOTE(review): the header words presumably are ndims, two reserved
    // words, the dimension length and its lower bound; confirm against the
    // protocol's array wire format.
    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
        enc.buf.reserve(8);
        enc.length_prefixed(|enc| {
            if self.is_empty() {
                enc.buf.reserve(12);
                for _ in 0..3 {
                    enc.buf.put_u32(0);
                }
                return Ok(());
            }
            enc.buf.reserve(20);
            enc.buf.put_u32(1);
            enc.buf.put_u32(0);
            enc.buf.put_u32(0);
            let count: u32 = self.len().try_into()
                .map_err(|_| ClientEncodingError::with_message(
                        "array is too long"))?;
            enc.buf.put_u32(count);
            enc.buf.put_u32(1);
            for element in self {
                enc.length_prefixed(|enc| element.encode(enc))?;
            }
            Ok(())
        })
    }
    /// Requires an array descriptor whose element type is acceptable for
    /// `T`.
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos)
        -> Result<(), Error>
    {
        let desc = ctx.get(pos)?;
        match desc {
            Descriptor::Array(arr) => T::check_descriptor(ctx, arr.type_pos),
            _ => Err(ctx.wrong_type(desc, "array")),
        }
    }
    /// Converts every element and wraps the results in `Value::Array`.
    fn to_value(&self) -> Result<Value, Error> {
        self.iter()
            .map(|element| element.to_value())
            .collect::<Result<_, _>>()
            .map(Value::Array)
    }
}
/// Arrays of dynamically typed values.
impl QueryArg for Vec<Value> {
    /// Writes the array as one length-prefixed slot.
    ///
    /// An empty vector is three zero words. Otherwise the header is
    /// `1, 0, 0, <element count>, 1`, followed by one slot per element.
    // NOTE(review): the header words presumably are ndims, two reserved
    // words, the dimension length and its lower bound; confirm against the
    // protocol's array wire format.
    fn encode_slot(&self, enc: &mut Encoder) -> Result<(), Error> {
        enc.buf.reserve(8);
        enc.length_prefixed(|enc| {
            if self.is_empty() {
                enc.buf.reserve(12);
                enc.buf.put_u32(0);
                enc.buf.put_u32(0);
                enc.buf.put_u32(0);
                return Ok(());
            }
            enc.buf.reserve(20);
            enc.buf.put_u32(1);
            enc.buf.put_u32(0);
            enc.buf.put_u32(0);
            enc.buf.put_u32(self.len().try_into()
                .map_err(|_| ClientEncodingError::with_message(
                        "array is too long"))?);
            enc.buf.put_u32(1);
            for item in self {
                // Elements must be written as slots. `item.encode(..)`
                // would resolve to `QueryArgs::encode`, which encodes the
                // value with the codec of the ROOT argument descriptor
                // rather than as an element. `encode_slot` also writes its
                // own length prefix (or `-1` for `Nothing`).
                item.encode_slot(enc)?;
            }
            Ok(())
        })
    }
    /// Requires an array descriptor and checks every element against its
    /// element type.
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos)
        -> Result<(), Error>
    {
        let desc = ctx.get(pos)?;
        if let Descriptor::Array(arr) = desc {
            for val in self {
                val.check_descriptor(ctx, arr.type_pos)?
            }
            Ok(())
        } else {
            Err(ctx.wrong_type(desc, "array"))
        }
    }
    /// Wraps the converted elements in `Value::Array`.
    fn to_value(&self) -> Result<Value, Error> {
        Ok(Value::Array(self.iter().map(|v| v.to_value())
                        .collect::<Result<_, _>>()?))
    }
}
/// Ranges of dynamically typed values.
impl QueryArg for range::Range<Box<Value>> {
    /// Writes the range as one length-prefixed slot: a flags byte followed
    /// by a slot for each bound that is present.
    ///
    /// The flags byte is `range::EMPTY` for an empty range; otherwise it
    /// combines the inclusive flags with an "infinite" flag for every
    /// missing bound.
    fn encode_slot(&self, encoder: &mut Encoder)
        -> Result<(), Error>
    {
        encoder.length_prefixed(|encoder| {
            let flags =
                if self.empty { range::EMPTY } else {
                    (if self.inc_lower { range::LB_INC } else { 0 }) |
                    (if self.inc_upper { range::UB_INC } else { 0 }) |
                    (if self.lower.is_none() { range::LB_INF } else { 0 }) |
                    (if self.upper.is_none() { range::UB_INF } else { 0 })
                };
            encoder.buf.reserve(1);
            encoder.buf.put_u8(flags as u8);
            // Bounds must be written as slots. `bound.encode(..)` would
            // resolve to `QueryArgs::encode`, which encodes the value with
            // the codec of the ROOT argument descriptor rather than as a
            // bound. `encode_slot` also writes its own length prefix.
            if let Some(lower) = &self.lower {
                lower.encode_slot(encoder)?;
            }
            if let Some(upper) = &self.upper {
                upper.encode_slot(encoder)?;
            }
            Ok(())
        })
    }
    /// Requires a range descriptor and checks each bound that is present
    /// against its element type.
    fn check_descriptor(&self, ctx: &DescriptorContext, pos: TypePos)
        -> Result<(), Error>
    {
        let desc = ctx.get(pos)?;
        if let Descriptor::Range(rng) = desc {
            self.lower.as_ref()
                .map(|v| v.check_descriptor(ctx, rng.type_pos))
                .transpose()?;
            self.upper.as_ref()
                .map(|v| v.check_descriptor(ctx, rng.type_pos))
                .transpose()?;
            Ok(())
        } else {
            Err(ctx.wrong_type(desc, "range"))
        }
    }
    /// Wraps a clone of the range in `Value::Range`.
    fn to_value(&self) -> Result<Value, Error> {
        Ok(Value::Range(self.clone()))
    }
}
// Implements `QueryArgs` for a tuple of `$count` positional arguments.
//
// `$count` is the tuple arity and `$name` are the type parameters, one per
// element. The generated `encode` first validates every element against the
// root input descriptor and only then writes anything to the buffer.
macro_rules! implement_tuple {
    ( $count:expr, $($name:ident,)+ ) => {
        impl<$($name:QueryArg),+> QueryArgs for ($($name,)+) {
            fn encode(&self, enc: &mut Encoder)
                -> Result<(), Error>
            {
                // The type parameter names double as local variable names
                // when the tuple is destructured below.
                #![allow(non_snake_case)]
                let root_pos = enc.ctx.root_pos
                    .ok_or_else(|| DescriptorMismatch::with_message(
                        format!(
                            "provided {} positional arguments, \
                             but no arguments expected by the server",
                             $count)))?;
                let desc = enc.ctx.get(root_pos)?;
                match desc {
                    // Protocol 0.12 and newer: arguments are described by
                    // an object shape whose element names must be the
                    // positional indexes "0", "1", ...
                    Descriptor::ObjectShape(desc)
                    if enc.ctx.proto.is_at_least(0, 12)
                    => {
                        if desc.elements.len() != $count {
                            return Err(enc.ctx.field_number(
                                desc.elements.len(), $count));
                        }
                        let mut els = desc.elements.iter().enumerate();
                        let ($(ref $name,)+) = self;
                        $(
                            // Cannot fail: the length check above ensures
                            // there is one element per tuple item.
                            let (idx, el) = els.next().unwrap();
                            if el.name.parse() != Ok(idx) {
                                return Err(DescriptorMismatch::with_message(
                                    format!("expected positional arguments, \
                                             got {} instead of {}",
                                             el.name, idx)));
                            }
                            $name.check_descriptor(enc.ctx, el.type_pos)?;
                        )+
                    }
                    // Protocol 0.11 and older: arguments are described by
                    // a tuple descriptor.
                    Descriptor::Tuple(desc) if enc.ctx.proto.is_at_most(0, 11)
                    => {
                        if desc.element_types.len() != $count {
                            return Err(enc.ctx.field_number(
                                desc.element_types.len(), $count));
                        }
                        let mut els = desc.element_types.iter();
                        let ($(ref $name,)+) = self;
                        $(
                            // Cannot fail: the length check above ensures
                            // there is one element type per tuple item.
                            let type_pos = els.next().unwrap();
                            $name.check_descriptor(enc.ctx, *type_pos)?;
                        )+
                    }
                    _ => return Err(enc.ctx.wrong_type(desc,
                        if enc.ctx.proto.is_at_least(0, 12) { "object" }
                        else { "tuple" }))
                }
                // Element count, then for every element a zero `u32`
                // followed by the element's slot.
                enc.buf.reserve(4 + 8*$count);
                enc.buf.put_u32($count);
                let ($(ref $name,)+) = self;
                $(
                    enc.buf.reserve(8);
                    enc.buf.put_u32(0);
                    QueryArg::encode_slot($name, enc)?;
                )*
                Ok(())
            }
        }
    }
}
// `QueryArgs` for tuples of 1 through 12 positional arguments. The first
// macro argument must equal the number of type parameters that follow.
implement_tuple!{1, T0, }
implement_tuple!{2, T0, T1, }
implement_tuple!{3, T0, T1, T2, }
implement_tuple!{4, T0, T1, T2, T3, }
implement_tuple!{5, T0, T1, T2, T3, T4, }
implement_tuple!{6, T0, T1, T2, T3, T4, T5, }
implement_tuple!{7, T0, T1, T2, T3, T4, T5, T6, }
implement_tuple!{8, T0, T1, T2, T3, T4, T5, T6, T7, }
implement_tuple!{9, T0, T1, T2, T3, T4, T5, T6, T7, T8, }
implement_tuple!{10, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, }
implement_tuple!{11, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, }
implement_tuple!{12, T0, T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, }