use floating_ui_utils::{
    get_alignment, get_alignment_sides, get_expanded_placements, get_opposite_axis_placements,
    get_opposite_placement, get_side, Alignment, Placement,
};
use serde::{Deserialize, Serialize};
use crate::{
    detect_overflow::{detect_overflow, DetectOverflowOptions},
    middleware::arrow::{ArrowData, ARROW_NAME},
    types::{
        Derivable, DerivableFn, Middleware, MiddlewareReturn, MiddlewareState,
        MiddlewareWithOptions, Reset, ResetValue,
    },
};
/// Name of the flip middleware; [`Flip`] reads its previous [`FlipData`] from the middleware
/// data stored under this key.
pub const FLIP_NAME: &str = "flip";

/// How [`Flip`] picks a placement once every candidate has been tried and none of the recorded
/// attempts fits on its first (main-axis) overflow value.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, Hash)]
pub enum FallbackStrategy {
    /// Pick the tried placement whose summed positive overflow is smallest.
    #[default]
    BestFit,
    /// Go back to the initial placement.
    InitialPlacement,
}
/// Options for the [`Flip`] middleware.
///
/// Every field is optional; `None` selects the default described on the field.
#[derive(Clone, Debug)]
pub struct FlipOptions<Element: Clone> {
    /// Options forwarded to `detect_overflow`. `None` uses `DetectOverflowOptions::default()`.
    pub detect_overflow: Option<DetectOverflowOptions<Element>>,
    /// Whether overflow on the placement's own side is checked. Defaults to `true`.
    pub main_axis: Option<bool>,
    /// Whether overflow on the two sides returned by `get_alignment_sides` is checked.
    /// Defaults to `true`.
    pub cross_axis: Option<bool>,
    /// Placements to try, in order, after the initial placement.
    ///
    /// When `None`, the list is derived from the initial placement: its opposite placement, or
    /// the result of `get_expanded_placements` when the initial placement has an alignment and
    /// `flip_alignment` is enabled.
    pub fallback_placements: Option<Vec<Placement>>,
    /// Strategy used once all candidates have been tried without a fit. Defaults to
    /// [`FallbackStrategy::BestFit`].
    pub fallback_strategy: Option<FallbackStrategy>,
    /// When set and `fallback_placements` is `None`, the placements returned by
    /// `get_opposite_axis_placements` (which receives this value) are appended to the derived
    /// fallback list. Defaults to `None` (nothing appended).
    pub fallback_axis_side_direction: Option<Alignment>,
    /// When `true` (the default) and the initial placement has an alignment, the derived
    /// fallback list comes from `get_expanded_placements` rather than only the opposite
    /// placement. Also forwarded to `get_opposite_axis_placements`.
    pub flip_alignment: Option<bool>,
}
/// Builder-style setters. Each consumes the options and returns them with one field set, so
/// calls can be chained, e.g. `FlipOptions::default().main_axis(true).cross_axis(false)`.
impl<Element: Clone> FlipOptions<Element> {
    /// Sets the options forwarded to `detect_overflow`.
    #[must_use]
    pub fn detect_overflow(mut self, value: DetectOverflowOptions<Element>) -> Self {
        self.detect_overflow = Some(value);
        self
    }

    /// Sets whether overflow on the placement's own side is checked.
    #[must_use]
    pub fn main_axis(mut self, value: bool) -> Self {
        self.main_axis = Some(value);
        self
    }

    /// Sets whether overflow on the placement's alignment sides is checked.
    #[must_use]
    pub fn cross_axis(mut self, value: bool) -> Self {
        self.cross_axis = Some(value);
        self
    }

    /// Sets the explicit list of placements to try after the initial placement.
    #[must_use]
    pub fn fallback_placements(mut self, value: Vec<Placement>) -> Self {
        self.fallback_placements = Some(value);
        self
    }

    /// Sets the strategy used when no tried placement fits.
    #[must_use]
    pub fn fallback_strategy(mut self, value: FallbackStrategy) -> Self {
        self.fallback_strategy = Some(value);
        self
    }

    /// Sets the alignment used to append opposite-axis fallback placements.
    #[must_use]
    pub fn fallback_axis_side_direction(mut self, value: Alignment) -> Self {
        self.fallback_axis_side_direction = Some(value);
        self
    }

    /// Sets whether alignment variants are included in the derived fallback placements.
    #[must_use]
    pub fn flip_alignment(mut self, value: bool) -> Self {
        self.flip_alignment = Some(value);
        self
    }
}
// NOTE: written by hand rather than derived, presumably because `#[derive(Default)]` would add
// an `Element: Default` bound that this type does not otherwise need.
impl<Element: Clone> Default for FlipOptions<Element> {
    /// Returns options with every field unset, so each one falls back to its documented default.
    fn default() -> Self {
        Self {
            detect_overflow: None,
            main_axis: None,
            cross_axis: None,
            fallback_placements: None,
            fallback_strategy: None,
            fallback_axis_side_direction: None,
            flip_alignment: None,
        }
    }
}
/// Overflow values recorded for one placement that the [`Flip`] middleware tried.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FlipDataOverflow {
    /// The placement that was tried.
    pub placement: Placement,
    /// Overflow amounts for that placement. The value for the placement's own side comes first
    /// (when main-axis checking is enabled), followed by the two alignment sides (when
    /// cross-axis checking is enabled). Values greater than zero are treated as overflowing.
    pub overflows: Vec<f64>,
}
/// State the [`Flip`] middleware carries between runs, stored in middleware data under
/// [`FLIP_NAME`].
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct FlipData {
    /// Index of the placement most recently requested via reset. The index is into the
    /// candidate list, which is the initial placement followed by the fallback placements.
    pub index: usize,
    /// Overflow records for every placement tried so far, in the order they were tried.
    pub overflows: Vec<FlipDataOverflow>,
}
/// Middleware that changes the floating element's placement when the current placement
/// overflows, trying the initial placement and then a list of fallback placements in order.
pub struct Flip<'a, Element: Clone, Window: Clone> {
    /// Options, either fixed or evaluated against the middleware state on every run.
    options: Derivable<'a, Element, Window, FlipOptions<Element>>,
}
impl<'a, Element: Clone, Window: Clone> Flip<'a, Element, Window> {
    pub fn new(options: FlipOptions<Element>) -> Self {
        Flip {
            options: options.into(),
        }
    }
    pub fn new_derivable(options: Derivable<'a, Element, Window, FlipOptions<Element>>) -> Self {
        Flip { options }
    }
    pub fn new_derivable_fn(
        options: DerivableFn<'a, Element, Window, FlipOptions<Element>>,
    ) -> Self {
        Flip {
            options: options.into(),
        }
    }
}
impl<'a, Element: Clone, Window: Clone> Clone for Flip<'a, Element, Window> {
    fn clone(&self) -> Self {
        Self {
            options: self.options.clone(),
        }
    }
}
impl<'a, Element: Clone, Window: Clone> Middleware<Element, Window> for Flip<'a, Element, Window> {
    fn name(&self) -> &'static str {
        FLIP_NAME
    }
    fn compute(&self, state: MiddlewareState<Element, Window>) -> MiddlewareReturn {
        let options = self.options.evaluate(state.clone());
        let MiddlewareState {
            placement,
            initial_placement,
            middleware_data,
            elements,
            rects,
            platform,
            ..
        } = state;
        let data: FlipData = middleware_data.get_as(self.name()).unwrap_or(FlipData {
            index: 0,
            overflows: vec![],
        });
        let check_main_axis = options.main_axis.unwrap_or(true);
        let check_cross_axis = options.cross_axis.unwrap_or(true);
        let specified_fallback_placements = options.fallback_placements.clone();
        let fallback_strategy = options.fallback_strategy.unwrap_or_default();
        let fallback_axis_side_direction = options.fallback_axis_side_direction;
        let flip_alignment = options.flip_alignment.unwrap_or(true);
        let arrow_data: Option<ArrowData> = middleware_data.get_as(ARROW_NAME);
        if arrow_data.map_or(false, |arrow_data| arrow_data.alignment_offset.is_some()) {
            return MiddlewareReturn {
                x: None,
                y: None,
                data: None,
                reset: None,
            };
        }
        let side = get_side(placement);
        let is_base_placement = get_alignment(initial_placement).is_none();
        let rtl = platform.is_rtl(elements.floating);
        let has_specified_fallback_placements = specified_fallback_placements.is_some();
        let mut placements =
            specified_fallback_placements.unwrap_or(match is_base_placement || !flip_alignment {
                true => vec![get_opposite_placement(initial_placement)],
                false => get_expanded_placements(initial_placement),
            });
        if !has_specified_fallback_placements && fallback_axis_side_direction.is_some() {
            placements.append(&mut get_opposite_axis_placements(
                initial_placement,
                flip_alignment,
                fallback_axis_side_direction,
                rtl,
            ));
        }
        placements.insert(0, initial_placement);
        let overflow = detect_overflow(
            MiddlewareState {
                elements: elements.clone(),
                ..state
            },
            options.detect_overflow.unwrap_or_default(),
        );
        let mut overflows: Vec<f64> = Vec::new();
        let mut overflows_data = data.overflows;
        if check_main_axis {
            overflows.push(overflow.side(side));
        }
        if check_cross_axis {
            let sides = get_alignment_sides(placement, rects, rtl);
            overflows.push(overflow.side(sides.0));
            overflows.push(overflow.side(sides.1));
        }
        overflows_data.push(FlipDataOverflow {
            placement,
            overflows: overflows.clone(),
        });
        if !overflows.into_iter().all(|side| side <= 0.0) {
            let next_index = data.index + 1;
            let next_placement = placements.get(next_index);
            if let Some(next_placement) = next_placement {
                return MiddlewareReturn {
                    x: None,
                    y: None,
                    data: Some(
                        serde_json::to_value(FlipData {
                            index: next_index,
                            overflows: overflows_data,
                        })
                        .expect("Data should be valid JSON."),
                    ),
                    reset: Some(Reset::Value(ResetValue {
                        placement: Some(*next_placement),
                        rects: None,
                    })),
                };
            }
            let mut reset_placement: Vec<&FlipDataOverflow> = overflows_data
                .iter()
                .filter(|overflow| overflow.overflows[0] <= 0.0)
                .collect();
            reset_placement.sort_by(|a, b| a.overflows[1].total_cmp(&b.overflows[1]));
            let mut reset_placement = reset_placement.first().map(|overflow| overflow.placement);
            if reset_placement.is_none() {
                match fallback_strategy {
                    FallbackStrategy::BestFit => {
                        let mut placement: Vec<(Placement, f64)> = overflows_data
                            .into_iter()
                            .map(|overflow| {
                                (
                                    overflow.placement,
                                    overflow
                                        .overflows
                                        .into_iter()
                                        .filter(|overflow| *overflow > 0.0)
                                        .sum::<f64>(),
                                )
                            })
                            .collect();
                        placement.sort_by(|a, b| a.1.total_cmp(&b.1));
                        let placement = placement.first().map(|v| v.0);
                        if placement.is_some() {
                            reset_placement = placement;
                        }
                    }
                    FallbackStrategy::InitialPlacement => {
                        reset_placement = Some(initial_placement);
                    }
                }
            }
            if placement != reset_placement.expect("Reset placement is not none.") {
                return MiddlewareReturn {
                    x: None,
                    y: None,
                    data: None,
                    reset: Some(Reset::Value(ResetValue {
                        placement: reset_placement,
                        rects: None,
                    })),
                };
            }
        }
        MiddlewareReturn {
            x: None,
            y: None,
            data: None,
            reset: None,
        }
    }
}
impl<Element: Clone, Window: Clone> MiddlewareWithOptions<Element, Window, FlipOptions<Element>>
    for Flip<'_, Element, Window>
{
    /// Returns the (possibly derivable) options this middleware was constructed with.
    fn options(&self) -> &Derivable<'_, Element, Window, FlipOptions<Element>> {
        &self.options
    }
}