use crate::core::*;
use crate::renderer::*;
use std::sync::Arc;
/// Level of detail used when rendering a terrain patch.
///
/// `Terrain` maps each variant to an index buffer that uses every vertex (`High`), every
/// fourth vertex (`Medium`) or every sixteenth vertex (`Low`) along each side of a patch.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Lod {
    /// Full resolution: every vertex of the patch is used.
    High,
    /// Every fourth vertex along each side is used.
    Medium,
    /// Every sixteenth vertex along each side is used.
    Low,
}
/// Number of vertices along each side of a terrain patch: 32 quads plus the closing vertex,
/// so the 32-quad side is divisible by every level-of-detail step (1, 4 and 16).
const VERTICES_PER_SIDE: usize = 32 + 1;
/// Terrain rendered as a square grid of patches around a movable center.
///
/// Patch vertex heights are sampled from a user-supplied height map; patches are created and
/// dropped as the center moves (see `set_center`).
pub struct Terrain<M: Material> {
    /// Context used to create the vertex and index buffers of new patches.
    context: Context,
    /// Grid index (as computed by `pos2patch`) of the patch the grid is centered on.
    center: (i32, i32),
    /// Currently loaded patches; they cover the window of `half_patches_per_side` patches on
    /// each side of `center`.
    patches: Vec<Gm<TerrainPatch, M>>,
    /// Index buffer using every vertex of a patch; selected for `Lod::High`.
    index_buffer1: Arc<ElementBuffer>,
    /// Index buffer using every fourth vertex; selected for `Lod::Medium`.
    index_buffer4: Arc<ElementBuffer>,
    /// Index buffer using every sixteenth vertex; selected for `Lod::Low`.
    index_buffer16: Arc<ElementBuffer>,
    /// Material cloned onto every newly created patch.
    material: M,
    /// Maps a patch's distance from the center position to its level of detail.
    lod: Arc<dyn Fn(f32) -> Lod + Send + Sync>,
    /// Returns the terrain height for a world-space `(x, z)` position.
    height_map: Arc<dyn Fn(f32, f32) -> f32 + Send + Sync>,
    /// Requested side length of the covered area (see `half_patches_per_side`).
    side_length: f32,
    /// Distance between neighbouring vertices inside a patch.
    vertex_distance: f32,
}
impl<M: Material + Clone> Terrain<M> {
    pub fn new(
        context: &Context,
        material: M,
        height_map: Arc<dyn Fn(f32, f32) -> f32 + Send + Sync>,
        side_length: f32,
        vertex_distance: f32,
        center: Vec2,
    ) -> Self {
        let index_buffer1 = Self::indices(context, 1);
        let mut patches = Vec::new();
        let (x0, y0) = pos2patch(vertex_distance, center);
        let half_patches_per_side = half_patches_per_side(vertex_distance, side_length);
        for ix in x0 - half_patches_per_side..x0 + half_patches_per_side + 1 {
            for iy in y0 - half_patches_per_side..y0 + half_patches_per_side + 1 {
                let patch = TerrainPatch::new(
                    context,
                    &*height_map.clone(),
                    (ix, iy),
                    index_buffer1.clone(),
                    vertex_distance,
                );
                patches.push(Gm::new(patch, material.clone()));
            }
        }
        Self {
            context: context.clone(),
            center: (x0, y0),
            patches,
            index_buffer1,
            index_buffer4: Self::indices(context, 4),
            index_buffer16: Self::indices(context, 16),
            lod: Arc::new(|_| Lod::High),
            material,
            height_map,
            side_length,
            vertex_distance,
        }
    }
    pub fn height_at(&self, position: Vec2) -> f32 {
        (*self.height_map)(position.x, position.y)
    }
    pub fn set_lod(&mut self, lod: Arc<dyn Fn(f32) -> Lod + Send + Sync>) {
        self.lod = lod;
    }
    pub fn set_center(&mut self, center: Vec2) {
        let (x0, y0) = pos2patch(self.vertex_distance, center);
        let half_patches_per_side = half_patches_per_side(self.vertex_distance, self.side_length);
        while x0 > self.center.0 {
            self.center.0 += 1;
            for iy in
                self.center.1 - half_patches_per_side..self.center.1 + half_patches_per_side + 1
            {
                self.patches.push(Gm::new(
                    TerrainPatch::new(
                        &self.context,
                        &*self.height_map.clone(),
                        (self.center.0 + half_patches_per_side, iy),
                        self.index_buffer1.clone(),
                        self.vertex_distance,
                    ),
                    self.material.clone(),
                ));
            }
        }
        while x0 < self.center.0 {
            self.center.0 -= 1;
            for iy in
                self.center.1 - half_patches_per_side..self.center.1 + half_patches_per_side + 1
            {
                self.patches.push(Gm::new(
                    TerrainPatch::new(
                        &self.context,
                        &*self.height_map.clone(),
                        (self.center.0 - half_patches_per_side, iy),
                        self.index_buffer1.clone(),
                        self.vertex_distance,
                    ),
                    self.material.clone(),
                ));
            }
        }
        while y0 > self.center.1 {
            self.center.1 += 1;
            for ix in
                self.center.0 - half_patches_per_side..self.center.0 + half_patches_per_side + 1
            {
                self.patches.push(Gm::new(
                    TerrainPatch::new(
                        &self.context,
                        &*self.height_map.clone(),
                        (ix, self.center.1 + half_patches_per_side),
                        self.index_buffer1.clone(),
                        self.vertex_distance,
                    ),
                    self.material.clone(),
                ));
            }
        }
        while y0 < self.center.1 {
            self.center.1 -= 1;
            for ix in
                self.center.0 - half_patches_per_side..self.center.0 + half_patches_per_side + 1
            {
                self.patches.push(Gm::new(
                    TerrainPatch::new(
                        &self.context,
                        &*self.height_map.clone(),
                        (ix, self.center.1 - half_patches_per_side),
                        self.index_buffer1.clone(),
                        self.vertex_distance,
                    ),
                    self.material.clone(),
                ));
            }
        }
        self.patches.retain(|p| {
            let (ix, iy) = p.index();
            (x0 - ix).abs() <= half_patches_per_side && (y0 - iy).abs() <= half_patches_per_side
        });
        self.patches.iter_mut().for_each(|p| {
            let distance = p.center().distance(center);
            p.index_buffer = match (*self.lod)(distance) {
                Lod::Low => self.index_buffer16.clone(),
                Lod::Medium => self.index_buffer4.clone(),
                Lod::High => self.index_buffer1.clone(),
            };
        })
    }
    fn indices(context: &Context, resolution: u32) -> Arc<ElementBuffer> {
        let mut indices: Vec<u32> = Vec::new();
        let stride = VERTICES_PER_SIDE as u32;
        let max = (stride - 1) / resolution;
        for r in 0..max {
            for c in 0..max {
                indices.push(r * resolution + c * resolution * stride);
                indices.push(r * resolution + resolution + c * resolution * stride);
                indices.push(r * resolution + (c * resolution + resolution) * stride);
                indices.push(r * resolution + (c * resolution + resolution) * stride);
                indices.push(r * resolution + resolution + c * resolution * stride);
                indices.push(r * resolution + resolution + (c * resolution + resolution) * stride);
            }
        }
        Arc::new(ElementBuffer::new_with_data(context, &indices))
    }
}
impl<'a, M: Material> IntoIterator for &'a Terrain<M> {
    type Item = &'a dyn Object;
    type IntoIter = std::vec::IntoIter<&'a dyn Object>;

    /// Yields every currently loaded patch as a renderable object.
    fn into_iter(self) -> Self::IntoIter {
        let mut objects: Vec<&'a dyn Object> = Vec::with_capacity(self.patches.len());
        for patch in &self.patches {
            objects.push(patch);
        }
        objects.into_iter()
    }
}
/// World-space side length of one patch: `VERTICES_PER_SIDE - 1` quads of `vertex_distance`.
fn patch_size(vertex_distance: f32) -> f32 {
    let quads_per_side = (VERTICES_PER_SIDE - 1) as f32;
    quads_per_side * vertex_distance
}
/// Number of patches on each side of the center patch, so the grid is `2 * half + 1` wide.
///
/// The patch count needed to span `side_length` is rounded up; an even count then loses one
/// patch because the grid width is always odd. Non-positive lengths yield 0.
fn half_patches_per_side(vertex_distance: f32, side_length: f32) -> i32 {
    // `as u32` saturates negative/NaN quotients to 0 before converting to a signed count.
    let patch_count = (side_length / patch_size(vertex_distance)).ceil() as u32 as i32;
    (patch_count - 1) / 2
}
/// Converts a world-space `(x, z)` position (given as a `Vec2`) into the grid index of the
/// patch containing it; coordinates are floored, so negative positions map to negative indices.
fn pos2patch(vertex_distance: f32, position: Vec2) -> (i32, i32) {
    let size = (VERTICES_PER_SIDE - 1) as f32 * vertex_distance;
    let to_index = |coordinate: f32| (coordinate / size).floor() as i32;
    (to_index(position.x), to_index(position.y))
}
/// One square tile of the terrain: `VERTICES_PER_SIDE` x `VERTICES_PER_SIDE` sampled vertices
/// uploaded to vertex buffers.
struct TerrainPatch {
    /// Context used when rendering this patch.
    context: Context,
    /// Grid coordinate of the patch; its minimum corner is at `index * patch_size` in world
    /// `(x, z)`.
    index: (i32, i32),
    /// Vertex positions `(x, height, z)`.
    positions_buffer: VertexBuffer,
    /// Per-vertex normals estimated from central differences of the height map.
    normals_buffer: VertexBuffer,
    /// World-space `(x, z)` midpoint of the patch, used for the LOD distance.
    center: Vec2,
    /// Bounding box of the vertex positions.
    aabb: AxisAlignedBoundingBox,
    /// Triangle indices for the current level of detail; replaced by `Terrain::set_center`.
    pub index_buffer: Arc<ElementBuffer>,
}
impl TerrainPatch {
    pub fn new(
        context: &Context,
        height_map: impl Fn(f32, f32) -> f32 + Clone,
        index: (i32, i32),
        index_buffer: Arc<ElementBuffer>,
        vertex_distance: f32,
    ) -> Self {
        let patch_size = patch_size(vertex_distance);
        let offset = vec2(index.0 as f32 * patch_size, index.1 as f32 * patch_size);
        let positions = Self::positions(height_map.clone(), offset, vertex_distance);
        let aabb = AxisAlignedBoundingBox::new_with_positions(&positions);
        let normals = Self::normals(height_map, offset, &positions, vertex_distance);
        let positions_buffer = VertexBuffer::new_with_data(context, &positions);
        let normals_buffer = VertexBuffer::new_with_data(context, &normals);
        Self {
            context: context.clone(),
            index,
            index_buffer,
            positions_buffer,
            normals_buffer,
            aabb,
            center: offset + vec2(0.5 * patch_size, 0.5 * patch_size),
        }
    }
    pub fn center(&self) -> Vec2 {
        self.center
    }
    pub fn index(&self) -> (i32, i32) {
        self.index
    }
    fn positions(
        height_map: impl Fn(f32, f32) -> f32,
        offset: Vec2,
        vertex_distance: f32,
    ) -> Vec<Vec3> {
        let mut data = vec![vec3(0.0, 0.0, 0.0); VERTICES_PER_SIDE * VERTICES_PER_SIDE];
        for r in 0..VERTICES_PER_SIDE {
            for c in 0..VERTICES_PER_SIDE {
                let vertex_id = r * VERTICES_PER_SIDE + c;
                let x = offset.x + r as f32 * vertex_distance;
                let z = offset.y + c as f32 * vertex_distance;
                data[vertex_id] = vec3(x, height_map(x, z), z);
            }
        }
        data
    }
    fn normals(
        height_map: impl Fn(f32, f32) -> f32,
        offset: Vec2,
        positions: &[Vec3],
        vertex_distance: f32,
    ) -> Vec<Vec3> {
        let mut data = vec![vec3(0.0, 0.0, 0.0); VERTICES_PER_SIDE * VERTICES_PER_SIDE];
        let h = vertex_distance;
        for r in 0..VERTICES_PER_SIDE {
            for c in 0..VERTICES_PER_SIDE {
                let vertex_id = r * VERTICES_PER_SIDE + c;
                let x = offset.x + r as f32 * vertex_distance;
                let z = offset.y + c as f32 * vertex_distance;
                let xp = if r == VERTICES_PER_SIDE - 1 {
                    height_map(x + h, z)
                } else {
                    positions[vertex_id + VERTICES_PER_SIDE][1]
                };
                let xm = if r == 0 {
                    height_map(x - h, z)
                } else {
                    positions[vertex_id - VERTICES_PER_SIDE][1]
                };
                let zp = if c == VERTICES_PER_SIDE - 1 {
                    height_map(x, z + h)
                } else {
                    positions[vertex_id + 1][1]
                };
                let zm = if c == 0 {
                    height_map(x, z - h)
                } else {
                    positions[vertex_id - 1][1]
                };
                let dx = xp - xm;
                let dz = zp - zm;
                data[vertex_id] = vec3(-dx, 2.0 * h, -dz).normalize();
            }
        }
        data
    }
}
impl Geometry for TerrainPatch {
    /// Returns the terrain vertex shader source, prefixed with `#define USE_NORMALS` when the
    /// fragment stage requires normals or tangents.
    fn vertex_shader_source(&self, required_attributes: FragmentAttributes) -> String {
        let source = include_str!("shaders/terrain.vert");
        if required_attributes.normal || required_attributes.tangents {
            format!("#define USE_NORMALS\n{}", source)
        } else {
            source.to_owned()
        }
    }

    /// Binds the camera's view-projection matrix and the vertex attributes, then draws the
    /// patch with its current level-of-detail index buffer.
    fn draw(
        &self,
        camera: &Camera,
        program: &Program,
        render_states: RenderStates,
        attributes: FragmentAttributes,
    ) {
        program.use_uniform("viewProjectionMatrix", camera.projection() * camera.view());
        program.use_vertex_attribute("position", &self.positions_buffer);
        if attributes.normal || attributes.tangents {
            program.use_vertex_attribute("normal", &self.normals_buffer);
        }
        program.draw_elements(render_states, camera.viewport(), &self.index_buffer);
    }

    /// Identifier of the vertex shader variant produced by `vertex_shader_source`; the two
    /// values differ because the normal/tangent branch emits different shader code.
    fn id(&self, required_attributes: FragmentAttributes) -> u16 {
        // Parentheses make the (unchanged) `<<`-before-`|` precedence explicit.
        if required_attributes.normal || required_attributes.tangents {
            (1u16 << 15) | 0b10
        } else {
            (1u16 << 15) | 0b11
        }
    }

    /// Renders the patch with `material`, lit by `lights`, as seen from `camera`.
    fn render_with_material(
        &self,
        material: &dyn Material,
        camera: &Camera,
        lights: &[&dyn Light],
    ) {
        // `self` is already a reference; pass it directly, as `render_with_effect` does.
        render_with_material(&self.context, camera, self, material, lights);
    }

    /// Renders the patch with the post-processing `effect`, optionally reading the given
    /// color and depth textures.
    fn render_with_effect(
        &self,
        effect: &dyn Effect,
        camera: &Camera,
        lights: &[&dyn Light],
        color_texture: Option<ColorTexture>,
        depth_texture: Option<DepthTexture>,
    ) {
        render_with_effect(
            &self.context,
            camera,
            self,
            effect,
            lights,
            color_texture,
            depth_texture,
        )
    }

    /// Bounding box of the patch's vertex positions.
    fn aabb(&self) -> AxisAlignedBoundingBox {
        self.aabb
    }
}