use std::borrow::Cow;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::iter;
use std::rc::Rc;
use std::sync::Arc;

use wgpu::{BindGroup, BindGroupLayout, BlendState, Buffer, BufferDescriptor, BufferSize, BufferUsages, ColorTargetState, ColorWrites, CompareFunction, DepthStencilState, Device, FragmentState, IndexFormat, MultisampleState, PipelineLayoutDescriptor, RenderPass, RenderPipeline, RenderPipelineDescriptor, ShaderModuleDescriptor, ShaderSource, VertexState};

use crate::asset::AssetManagerRc;
use crate::model::{InstGridBuf, InstPhongColorBuf, InstShaderImplType, InstShaderSize, InstShaderType, InstSimpleColorBuf, InstWindowBuf, Mesh};
use crate::output::OutputInfoRc;
use crate::ui::UIManagerRc;
use crate::util::StatsRc;

/// WGSL declaration of the uniform block `Uni`, bound at group 0 / binding 0.
///
/// `ModelRenderer::new` replaces `#VIEW_LEN#` with the output's view count and then
/// pastes the result into each shader source wherever `#UNI#` appears.
// Keep render->Uni and modelrender->UNI_TMPL in-sync
// (render->Uni is presumably the CPU-side counterpart of this struct; verify there).
// TODO: make it model-specific, like vertex/inst shader?
const UNI_TMPL: &str = "
struct Uni {
    light_pos: vec3<f32>,
    _unused1: f32,
    cam_pos: vec3<f32>,
    _unused2: f32,
    view_m: array<mat4x4<f32>, #VIEW_LEN#>,
}

@group(0) @binding(0) var<uniform> uni: Uni;
";

/// Describes one kind of model and knows how to build a single instance of it.
///
/// `ModelRegistry::create` groups models by `get_name()`: the first factory seen for a
/// name supplies the shared mesh through `get_mesh()`, later ones reuse that mesh.
pub trait ModelFactory {
    /// Concrete model type returned by `create()`.
    type Model: Model + 'static;

    /// Name used as the grouping key; models with the same name share one mesh.
    fn get_name() -> &'static str;
    /// Builds the mesh shared by all models registered under `get_name()`.
    fn get_mesh(asset_mgr: AssetManagerRc, device: &Device) -> Mesh;
    /// Consumes the factory and builds the model.
    ///
    /// `handle` lets the model toggle its own visibility per submesh.
    /// `inst_sh_impls` holds one entry per submesh of the mesh, in submesh order.
    fn create(self, handle: ModelHandle, device: &Device, inst_sh_impls: &mut [InstShaderImplType], ui_manager: UIManagerRc) -> Self::Model;
}

/// Source of per-instance data for one model.
///
/// `ModelRenderer::render` calls the `fill_*` method that matches the submesh's
/// `InstShaderType`, passing the submesh index as `_inst_index`. A model only needs
/// to override the methods for the instance shader types its mesh uses; every
/// default implementation panics.
pub trait Model {
    /// Instance data for a submesh using `InstShaderType::SimpleColor`.
    fn fill_simple_color(&self, _inst_index: u32) -> InstSimpleColorBuf {
        panic!("Method is not implemented");
    }

    /// Instance data for a submesh using `InstShaderType::PhongColor`.
    fn fill_phong_color(&self, _inst_index: u32) -> InstPhongColorBuf {
        panic!("Method is not implemented");
    }

    /// Instance data for a submesh using `InstShaderType::Grid`.
    fn fill_grid(&self, _inst_index: u32) -> InstGridBuf {
        panic!("Method is not implemented");
    }

    /// Instance data for a submesh using `InstShaderType::Window`.
    fn fill_window(&self, _inst_index: u32) -> InstWindowBuf {
        panic!("Method is not implemented");
    }
}

// Model groups keyed by the name returned from ModelFactory::get_name().
type ModelInfos = HashMap<String, ModelInfo>; // [model_name]
// Visibility sets shared between the ModelHandles of a mesh and the renderer:
// a model is drawn for a submesh only while its index is in that submesh's set.
type Visibles = Rc<RefCell<Box<[HashSet<u32>]>>>; // Box[inst_index]->HashSet[model_index]

/// Collects models, grouped by factory name, until `build()` turns the registry
/// into a `ModelRenderer`.
pub struct ModelRegistry {
    /// Passed to `ModelFactory::get_mesh` and on to the renderer.
    asset_mgr: AssetManagerRc,
    /// Provides the device used while creating meshes and models.
    output_info: OutputInfoRc,
    /// Handed to every `ModelFactory::create` call.
    ui_manager: UIManagerRc,
    /// One entry per distinct factory name.
    model_infos: ModelInfos,
}

/// Everything registered under one factory name: the shared mesh plus its models.
struct ModelInfo {
    /// Mesh shared by all models of this group.
    mesh: Mesh,
    /// One instance shader implementation per submesh, in submesh order.
    inst_sh_impls: Box<[InstShaderImplType]>,
    /// Models in creation order; the position is the model's index.
    models: Vec<Rc<dyn Model>>, // [model_index]
    /// Per-submesh sets of visible model indexes, shared with every ModelHandle of this group.
    visibles: Visibles,
}

impl ModelRegistry {
    /// Creates an empty registry; models are added with `create()`.
    pub fn new(asset_mgr: AssetManagerRc, output_info: OutputInfoRc, ui_manager: UIManagerRc) -> Self {
        Self {
            asset_mgr,
            output_info,
            ui_manager,
            model_infos: HashMap::new(),
        }
    }

    /// Builds a model with `factory`, registers it and returns it.
    ///
    /// The first call for a given `F::get_name()` also loads the mesh and sets up one
    /// instance shader implementation and one (initially empty) visibility set per
    /// submesh. A new model is therefore invisible until its `ModelHandle` says otherwise.
    ///
    /// # Panics
    ///
    /// Panics if more than `u32::MAX` models are registered under one name.
    pub fn create<F: ModelFactory>(&mut self, factory: F) -> Rc<F::Model> {
        // Models are grouped by name, so we can do instanced rendering,
        // see ModelRenderer->render().

        let device = self.output_info.get_device();

        let model_info = self.model_infos.entry(F::get_name().to_string()).or_insert_with(|| {
            let mesh = F::get_mesh(Arc::clone(&self.asset_mgr), device);
            let submeshes = mesh.get_submeshes();

            let inst_sh_impls = submeshes.iter().map(|submesh| submesh.get_inst_sh_type().create_impl()).collect();
            let visibles = Rc::new(RefCell::new(iter::repeat_with(HashSet::new).take(submeshes.len()).collect()));

            ModelInfo {
                mesh,
                inst_sh_impls,
                models: Vec::new(),
                visibles,
            }
        });

        // The model's index is its position in the group's model list.
        let model_index = u32::try_from(model_info.models.len()).expect("number of models per name must fit in u32");
        let handle = ModelHandle::new(Rc::clone(&model_info.visibles), model_index);
        let model = Rc::new(factory.create(handle, device, &mut model_info.inst_sh_impls, Rc::clone(&self.ui_manager)));

        model_info.models.push(Rc::clone(&model) as Rc<dyn Model>);

        model
    }

    /// Consumes the registry and creates the renderer for all registered models.
    pub fn build(self, stats: StatsRc, uni_bg_layout: &BindGroupLayout) -> ModelRenderer {
        // `self` is consumed here, so the asset manager is moved rather than cloned.
        ModelRenderer::new(self.asset_mgr, self.output_info, stats, self.model_infos, uni_bg_layout)
    }
}

/// Handle given to a model at creation time for controlling its own visibility.
pub struct ModelHandle {
    /// Visibility sets shared with the renderer and the other models of the same mesh.
    visibles: Visibles,
    /// Index of the owning model within its group.
    model_index: u32,
}

impl ModelHandle {
    /// Pairs the shared visibility sets of a mesh with the index of one model.
    fn new(visibles: Visibles, model_index: u32) -> Self {
        Self { visibles, model_index }
    }

    /// Shows or hides this model for the submesh at `inst_index`.
    ///
    /// Panics if `inst_index` is out of range, or if the visibility sets are
    /// already borrowed elsewhere.
    pub fn set_visible(&self, inst_index: u32, visible: bool) {
        // One HashSet per submesh keeps visibility changes fast.
        let mut per_submesh = self.visibles.borrow_mut();
        let shown = &mut per_submesh[inst_index as usize];

        if visible {
            shown.insert(self.model_index);
            return;
        }

        shown.remove(&self.model_index);
    }

    /// Returns whether this model is currently visible for the submesh at `inst_index`.
    ///
    /// Panics if `inst_index` is out of range, or if the visibility sets are
    /// mutably borrowed elsewhere.
    pub fn get_visible(&self, inst_index: u32) -> bool {
        self.visibles.borrow()[inst_index as usize].contains(&self.model_index)
    }
}

/// Draws all registered models, using one instanced draw call per submesh that has
/// at least one visible model.
pub struct ModelRenderer {
    /// Provides the queue used to upload instance data.
    output_info: OutputInfoRc,
    /// Receives draw call / instance counters after each `render()`.
    stats: StatsRc,
    /// One entry per mesh (i.e. per factory name).
    render_infos: Box<[RenderInfo]>,
}

/// Render-time state of one mesh and the models that share it.
struct RenderInfo {
    /// Mesh providing the vertex/index buffers and the submeshes.
    mesh: Mesh,
    /// Models sharing the mesh; indexed by the values stored in `visibles`.
    models: Box<[Rc<dyn Model>]>, // [model_index]
    /// Per-submesh sets of visible model indexes, shared with the ModelHandles.
    visibles: Visibles,
    /// GPU state per submesh, in submesh order.
    inst_sh_infos: Box<[InstShaderInfo]>, // [inst_index]
}

/// GPU state needed to draw one submesh with instancing.
struct InstShaderInfo {
    /// Size in bytes of one instance (the array stride of the instance vertex layout).
    inst_size: u64,
    /// Instance buffer, allocated large enough for every model of the mesh.
    inst_buf: Buffer,
    /// Render pipeline; may be shared with other submeshes using the same configuration.
    pipeline: Rc<RenderPipeline>,
    /// Bind group for group 1, present only when the instance shader implementation has a bind group layout.
    bg_opt: Option<BindGroup>,
}

impl ModelRenderer {
    /// Creates the GPU state needed to draw every registered model.
    ///
    /// For each mesh and each of its submeshes this allocates an instance buffer large
    /// enough for all models of the mesh, then looks up or creates the pipeline layout,
    /// bind group, shader module and render pipeline. Pipeline layouts, shaders and
    /// pipelines are cached in local maps, so equal configurations share one object.
    fn new(asset_mgr: AssetManagerRc, output_info: OutputInfoRc, stats: StatsRc, model_infos: ModelInfos, uni_bg_layout: &BindGroupLayout) -> Self {
        let device = output_info.get_device();

        // Caches shared by all meshes.
        let mut pipelines = HashMap::new();
        let mut pipeline_layouts = HashMap::new();
        let mut shaders = HashMap::new();

        // Uniform declaration with the concrete view count; pasted into every shader below.
        let view_len = output_info.get_view_len();
        let uni = UNI_TMPL.replace("#VIEW_LEN#", &view_len.to_string());

        let render_infos = model_infos.into_values().map(|model_info| {
            let mesh = model_info.mesh;
            let vertex_sh_type = mesh.get_vertex_sh_type();

            let models = model_info.models.into_boxed_slice();
            let models_len = models.len() as u64;

            let inst_sh_infos = mesh.get_submeshes().iter().zip(model_info.inst_sh_impls.iter()).map(|(submesh, inst_sh_impl)| {
                let inst_sh_type = submesh.get_inst_sh_type();
                let inst_sh_layout = inst_sh_type.get_layout();
                let inst_size = inst_sh_layout.array_stride;

                // Allocate buffer, sized for the case that every model is visible.

                let inst_buf = device.create_buffer(&BufferDescriptor {
                    label: None,
                    size: inst_size * models_len,
                    usage: BufferUsages::VERTEX | BufferUsages::COPY_DST,
                    mapped_at_creation: false,
                });

                // Create pipeline layout and bind groups.
                // TODO: Use fix number for inst_sh_impl->resources->count?, so we don't have to create different pipeline layouts because of different counts.
                // TODO: Once wgpu has bindless descriptors, then we can simplify this logic.

                let pipeline_layout_key = inst_sh_impl.get_key();

                let (bg_layout_opt, pipeline_layout) = pipeline_layouts.entry(pipeline_layout_key.clone()).or_insert_with(|| {
                    // Group 0 is always the shared uniform block; group 1 is optional.
                    let mut bg_layouts = vec![Some(uni_bg_layout)];

                    let bg_layout_opt = inst_sh_impl.create_bind_group_layout(device);
                    if let Some(bg_layout) = &bg_layout_opt {
                        bg_layouts.push(Some(bg_layout));
                    }

                    let pipeline_layout = Rc::new(device.create_pipeline_layout(&PipelineLayoutDescriptor {
                        label: None,
                        bind_group_layouts: &bg_layouts, // See vertex/fragment shader->@group().
                        immediate_size: 0,
                    }));

                    (bg_layout_opt, pipeline_layout)
                });

                // The layout may be shared, but each submesh gets its own bind group.
                let bg_opt = bg_layout_opt.as_ref().map(|bg_layout| inst_sh_impl.create_bind_group(device, bg_layout));

                // Lookup pipeline, make sure we don't have duplicated pipelines/shaders.

                let primitive_st_type = submesh.get_primitive_state_type();

                let pipeline = pipelines.entry((pipeline_layout_key, vertex_sh_type.clone(), inst_sh_type.clone(), primitive_st_type.clone())).or_insert_with(|| {
                    let shader = shaders.entry((vertex_sh_type.clone(), inst_sh_type.clone())).or_insert_with(|| {
                        let name = format!("/shader/{}_{}.wgsl", vertex_sh_type.get_name(), inst_sh_type.get_name());

                        // Pre-process shaders, so we don't need to duplicate them for single
                        // or multiview (stereo) rendering.
                        // TODO: Embed compiled shaders into binary (window, xr)?

                        let asset_file = asset_mgr.open_or_err(&name);
                        let source = asset_file.read_str_or_err()
                            .replace("#UNI#", &uni)
                            .replace("#VIEW_INDEX_DEF#", output_info.get_view_index_def())
                            .replace("#VIEW_INDEX_VAL#", output_info.get_view_index_val());

                        Rc::new(device.create_shader_module(ShaderModuleDescriptor {
                            label: None,
                            source: ShaderSource::Wgsl(Cow::Owned(source)),
                        }))
                    });

                    Rc::new(device.create_render_pipeline(&RenderPipelineDescriptor {
                        label: None,
                        layout: Some(pipeline_layout),
                        vertex: VertexState {
                            module: shader,
                            entry_point: Some("vs_main"),
                            compilation_options: Default::default(),
                            buffers: &[
                                vertex_sh_type.get_layout(), // Slot 0: mesh vertices.
                                inst_sh_layout,              // Slot 1: per-instance data.
                            ],
                        },
                        fragment: Some(FragmentState {
                            module: shader,
                            entry_point: Some("fs_main"),
                            compilation_options: Default::default(),
                            targets: &[Some(ColorTargetState { // See fragment shader->@location().
                                format: output_info.get_color_format(),
                                blend: Some(BlendState::REPLACE),
                                write_mask: ColorWrites::ALL,
                            })],
                        }),
                        primitive: primitive_st_type.get_primitive(),
                        depth_stencil: Some(DepthStencilState {
                            format: output_info.get_depth_format(),
                            depth_write_enabled: Some(true),
                            depth_compare: Some(CompareFunction::Less),
                            stencil: Default::default(),
                            bias: Default::default(),
                        }),
                        multisample: MultisampleState {
                            count: output_info.get_sample_count(),
                            mask: !0,
                            alpha_to_coverage_enabled: false,
                        },
                        multiview_mask: output_info.get_view_mask(),
                        cache: None
                    }))
                });

                InstShaderInfo {
                    inst_size,
                    inst_buf,
                    pipeline: Rc::clone(pipeline),
                    bg_opt,
                }
            }).collect();

            RenderInfo {
                mesh,
                models,
                visibles: model_info.visibles,
                inst_sh_infos,
            }
        }).collect();

        Self {
            output_info,
            stats,
            render_infos,
        }
    }

    /// Uploads instance data for the visible models and records the draw calls.
    ///
    /// For every submesh with at least one visible model, the instance buffer is filled
    /// with one entry per visible model and a single instanced indexed draw is recorded.
    /// Submeshes without visible models are skipped. The draw call count, instance
    /// count and uploaded byte count are reported to the stats object at the end.
    ///
    /// # Panics
    ///
    /// Panics if a visibility set is mutably borrowed during the call, or if a model
    /// does not implement the `fill_*` method required by a submesh it is visible for.
    pub fn render(&self, render_pass: &mut RenderPass) {
        let queue = self.output_info.get_queue();

        let mut stat_draw_calls = 0;
        let mut stat_inst_num = 0;
        // Accumulated as u64 so the byte count cannot be truncated while summing.
        let mut stat_inst_buf: u64 = 0;

        for render_info in &self.render_infos {
            let mesh = &render_info.mesh;
            // Vertex/index buffers are bound lazily, only if some submesh is actually drawn.
            let mut mesh_bound = false;

            let visibles = render_info.visibles.borrow();

            for (inst_index, ((visible_model_indexes, inst_sh_info), submesh)) in visibles.iter().zip(render_info.inst_sh_infos.iter()).zip(mesh.get_submeshes().iter()).enumerate() {
                let visible_model_indexes_len: u32 = visible_model_indexes.len().try_into().unwrap();

                if visible_model_indexes_len > 0 {
                    let inst_index: u32 = inst_index.try_into().unwrap();
                    let inst_size = inst_sh_info.inst_size;
                    let inst_buf = &inst_sh_info.inst_buf;
                    let total_size = visible_model_indexes_len as u64 * inst_size;

                    // Don't map full buffer, only the part which is large enough to hold the visible models.

                    let mut inst_buf_view = queue.write_buffer_with(inst_buf, 0, BufferSize::new(total_size).unwrap()).unwrap();

                    // Expands to a match over the submesh's instance shader type; each arm
                    // writes one fixed-size chunk per visible model, produced by the
                    // corresponding Model::fill_* method.
                    macro_rules! fill_inst_buf {
                        ($(($inst_sh_type:ident, $method:ident)),*) => {
                            match submesh.get_inst_sh_type() {
                                $(
                                    InstShaderType::$inst_sh_type => {
                                        let (inst_buf_arr, _) = inst_buf_view.slice(..).into_chunks::<{InstShaderSize::$inst_sh_type}>();

                                        inst_buf_arr.write_iter(visible_model_indexes.iter().map(|model_index| {
                                            let model = &render_info.models[*model_index as usize];
                                            let inst_buf_single = model.$method(inst_index);
                                            bytemuck::cast(inst_buf_single)
                                        }));
                                    },
                                )*
                            }
                        }
                    }

                    fill_inst_buf!(
                        (SimpleColor, fill_simple_color),
                        (PhongColor, fill_phong_color),
                        (Grid, fill_grid),
                        (Window, fill_window)
                    );

                    if !mesh_bound {
                        render_pass.set_vertex_buffer(0, mesh.get_vertex_buf().slice(..)); // See VertexState->buffers[0].
                        render_pass.set_index_buffer(mesh.get_index_buf().slice(..), IndexFormat::Uint16);
                        mesh_bound = true;
                    }

                    render_pass.set_pipeline(&inst_sh_info.pipeline);
                    render_pass.set_vertex_buffer(1, inst_buf.slice(..total_size)); // See VertexState->buffers[1].

                    if let Some(bg) = &inst_sh_info.bg_opt {
                        render_pass.set_bind_group(1, bg, &[]); // See PipelineLayoutDescriptor->bind_group_layouts.
                    }

                    render_pass.draw_indexed(submesh.get_indices(), submesh.get_base_vertex(), 0..visible_model_indexes_len);

                    stat_draw_calls += 1;
                    stat_inst_num += visible_model_indexes_len;
                    stat_inst_buf += total_size;
                }
            }
        }

        // Update stats.

        self.stats.set_draw_calls(stat_draw_calls);
        self.stats.set_inst_num(stat_inst_num);
        // Saturate instead of silently wrapping if the byte count exceeds u32.
        self.stats.set_inst_buf(u32::try_from(stat_inst_buf).unwrap_or(u32::MAX));
    }
}