//! Documentation: an async, sharded concurrent hash map (see [`FastMap`]).
use alloc::boxed::Box;
use alloc::sync::Arc;
use alloc::vec::Vec;
use core::hash::{Hash, Hasher};
use core::mem;

use super::Waiter;
use super::lock::Mutex;

/// Number of shards a [`FastMap`] uses when its const parameter `N` is not given
/// (it is the default value of `N` on the `FastMap` type).
const DEFAULT_SHARDS: usize = 16;

/// A single key-value entry in a `FastMap` bucket.
///
/// Entries that land in the same bucket form a singly linked chain through `next`.
struct Entry<K, V> {
    /// The key, compared with `==` during lookups.
    key: K,
    /// The stored value.
    value: V,
    /// Next entry in the same bucket's chain, or `None` at the tail.
    next: Option<Box<Entry<K, V>>>,
}

/// One shard of a `FastMap`: its own array of bucket chains behind a single lock.
///
/// Every key is owned by exactly one shard, so operations on keys in different
/// shards lock different `Mutex`es.
struct Shard<K, V> {
    /// Bucket heads. Each slot is `None` (empty bucket) or the first [`Entry`] of a chain.
    buckets: Mutex<Vec<Option<Box<Entry<K, V>>>>>,
}

/// An async, sharded concurrent hash map.
///
/// Keys are spread over `N` shards (const generic, default 16). Each shard keeps
/// its buckets behind its own `Mutex`, so single-key operations on keys in
/// different shards do not wait on each other. More shards means less lock
/// contention but more memory overhead; 16 works well for most use cases.
///
/// Each shard holds a fixed-size array of separately chained buckets, sized by
/// [`FastMap::with_capacity`]. The arrays are not resized as entries are added,
/// so pick a capacity close to the expected number of entries.
///
/// # Examples
///
/// ```ignore
/// // Not compiled as a doctest: import `FastMap` from its module path and run
/// // the async calls on an executor.
///
/// // Default of 16 shards.
/// let map: FastMap<String, i32> = FastMap::new();
/// map.insert("a".to_string(), 1).await;
/// assert_eq!(map.get(&"a".to_string()).await, Some(1));
///
/// // More shards for higher concurrency.
/// let map: FastMap<String, i32, 64> = FastMap::new();
/// ```
pub struct FastMap<K, V, const N: usize = DEFAULT_SHARDS> {
    shards: [Arc<Shard<K, V>>; N],
}

impl<K, V, const N: usize> FastMap<K, V, N>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    /// Create a new FastMap with default capacity
    pub fn new() -> Self {
        Self::with_capacity(64)
    }

    /// Create a new FastMap with specified initial capacity
    pub fn with_capacity(capacity: usize) -> Self {
        let capacity_per_shard = (capacity + N - 1) / N;

        // Create array of shards using array::from_fn
        let shards = core::array::from_fn(|_| {
            let mut buckets = Vec::with_capacity(capacity_per_shard);
            buckets.resize_with(capacity_per_shard, || None);

            Arc::new(Shard {
                buckets: Mutex::new(buckets),
            })
        });

        Self { shards }
    }

    /// Get the shard index for a given key
    fn shard_index(&self, key: &K) -> usize {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        hasher.finish() as usize % N
    }

    /// Get the bucket index within a shard
    fn bucket_index(&self, key: &K, _shard_idx: usize) -> usize {
        let mut hasher = DefaultHasher::new();
        key.hash(&mut hasher);
        let hash = hasher.finish();

        // Use a different part of the hash for bucket selection
        let bucket_hash = hash.wrapping_mul(0x9e3779b9);

        // This will be calculated when we have the lock
        bucket_hash as usize
    }

    /// Insert a key-value pair
    pub async fn insert(&self, key: K, value: V) -> Option<V> {
        let shard_idx = self.shard_index(&key);
        let shard = &self.shards[shard_idx];

        let mut buckets = shard.buckets.lock(Waiter::default()).await;
        let bucket_idx = {
            let mut hasher = DefaultHasher::new();
            key.hash(&mut hasher);
            let hash = hasher.finish();
            let bucket_hash = hash.wrapping_mul(0x9e3779b9);
            bucket_hash as usize % buckets.len()
        };

        // Search for existing entry in the chain
        {
            let mut current = &mut buckets[bucket_idx];

            // Check if key already exists
            while let Some(entry) = current {
                if entry.key == key {
                    // Replace existing value
                    let old_value = mem::replace(&mut entry.value, value);
                    return Some(old_value);
                }
                current = &mut entry.next;
            }
        }

        // Key doesn't exist, insert new entry at the head of the chain
        let old_head = buckets[bucket_idx].take();
        let new_entry = Box::new(Entry {
            key,
            value,
            next: old_head,
        });
        buckets[bucket_idx] = Some(new_entry);

        None
    }

    /// Get a value by key
    pub async fn get(&self, key: &K) -> Option<V>
    where
        V: Clone,
    {
        let shard_idx = self.shard_index(key);
        let shard = &self.shards[shard_idx];

        let buckets = shard.buckets.lock(Waiter::default()).await;
        let bucket_idx = {
            let mut hasher = DefaultHasher::new();
            key.hash(&mut hasher);
            let hash = hasher.finish();
            let bucket_hash = hash.wrapping_mul(0x9e3779b9);
            bucket_hash as usize % buckets.len()
        };

        let mut current = &buckets[bucket_idx];
        while let Some(entry) = current {
            if entry.key == *key {
                return Some(entry.value.clone());
            }
            current = &entry.next;
        }

        None
    }

    /// Get a reference to a value by key
    pub async fn get_ref<F, R>(&self, key: &K, f: F) -> Option<R>
    where
        F: FnOnce(&V) -> R,
    {
        let shard_idx = self.shard_index(key);
        let shard = &self.shards[shard_idx];

        let buckets = shard.buckets.lock(Waiter::default()).await;
        let bucket_idx = {
            let mut hasher = DefaultHasher::new();
            key.hash(&mut hasher);
            let hash = hasher.finish();
            let bucket_hash = hash.wrapping_mul(0x9e3779b9);
            bucket_hash as usize % buckets.len()
        };

        let mut current = &buckets[bucket_idx];
        while let Some(entry) = current {
            if entry.key == *key {
                return Some(f(&entry.value));
            }
            current = &entry.next;
        }

        None
    }

    /// Remove a key-value pair
    pub async fn remove(&self, key: &K) -> Option<V> {
        let shard_idx = self.shard_index(key);
        let shard = &self.shards[shard_idx];

        let mut buckets = shard.buckets.lock(Waiter::default()).await;
        let bucket_idx = {
            let mut hasher = DefaultHasher::new();
            key.hash(&mut hasher);
            let hash = hasher.finish();
            let bucket_hash = hash.wrapping_mul(0x9e3779b9);
            bucket_hash as usize % buckets.len()
        };

        let bucket = &mut buckets[bucket_idx];

        // Special case: first entry matches
        if let Some(entry) = bucket {
            if entry.key == *key {
                let mut removed = bucket.take().unwrap();
                *bucket = removed.next.take();
                return Some(removed.value);
            }
        }

        // Search in the chain
        let mut current = bucket;
        while let Some(entry) = current {
            if let Some(next) = &mut entry.next {
                if next.key == *key {
                    let mut removed = entry.next.take().unwrap();
                    entry.next = removed.next.take();
                    return Some(removed.value);
                }
            }
            current = &mut entry.next;
        }

        None
    }

    /// Check if a key exists
    pub async fn contains_key(&self, key: &K) -> bool {
        let shard_idx = self.shard_index(key);
        let shard = &self.shards[shard_idx];

        let buckets = shard.buckets.lock(Waiter::default()).await;
        let bucket_idx = {
            let mut hasher = DefaultHasher::new();
            key.hash(&mut hasher);
            let hash = hasher.finish();
            let bucket_hash = hash.wrapping_mul(0x9e3779b9);
            bucket_hash as usize % buckets.len()
        };

        let mut current = &buckets[bucket_idx];
        while let Some(entry) = current {
            if entry.key == *key {
                return true;
            }
            current = &entry.next;
        }

        false
    }

    /// Get the number of entries (expensive operation)
    pub async fn len(&self) -> usize {
        let mut total = 0;
        for shard in &self.shards {
            let buckets = shard.buckets.lock(Waiter::default()).await;
            for bucket in buckets.iter() {
                let mut current = bucket;
                while let Some(entry) = current {
                    total += 1;
                    current = &entry.next;
                }
            }
        }
        total
    }

    /// Check if the map is empty
    pub async fn is_empty(&self) -> bool {
        self.len().await == 0
    }

    /// Clear all entries
    pub async fn clear(&self) {
        for shard in &self.shards {
            let mut buckets = shard.buckets.lock(Waiter::default()).await;
            for bucket in buckets.iter_mut() {
                *bucket = None;
            }
        }
    }
}

/// Hasher used by `FastMap`. It is unrelated to `std`'s SipHash-based `DefaultHasher`.
///
/// Bytes are folded in as `state = state * 0x100000001b3 + byte`: the 64-bit FNV
/// prime, combined by addition rather than FNV-1a's XOR. The loop starts from a
/// fixed seed, so hashes are reproducible across runs. The hasher is not designed
/// to resist adversarially chosen colliding keys.
struct DefaultHasher {
    state: u64,
}

impl DefaultHasher {
    /// Create a hasher starting from the fixed seed.
    fn new() -> Self {
        Self {
            state: 0xcafe_babe_dead_beef,
        }
    }
}

impl Hasher for DefaultHasher {
    /// Fold `bytes` into the running state, one byte at a time.
    fn write(&mut self, bytes: &[u8]) {
        for &byte in bytes {
            self.state = self
                .state
                .wrapping_mul(0x100000001b3)
                .wrapping_add(u64::from(byte));
        }
    }

    /// Return the accumulated state passed through MurmurHash3's `fmix64` finalizer.
    ///
    /// The byte loop only carries information upwards, because multiply and add
    /// never change lower bits. The raw `state` therefore has weak low bits. For
    /// example, the `u64` keys `0, 16, 32, ..., 240` all share `state % 16`, while
    /// `FastMap` picks shards with `hash % N`. The finalizer is a bijection that
    /// spreads every input bit over the whole output.
    fn finish(&self) -> u64 {
        let mut h = self.state;
        h ^= h >> 33;
        h = h.wrapping_mul(0xff51_afd7_ed55_8ccd);
        h ^= h >> 33;
        h = h.wrapping_mul(0xc4ce_b9fe_1a85_ec53);
        h ^ (h >> 33)
    }
}

impl<K, V, const N: usize> Default for FastMap<K, V, N>
where
    K: Hash + Eq + Clone + Send + Sync + 'static,
    V: Send + Sync + 'static,
{
    fn default() -> Self {
        Self::new()
    }
}

// Implement Send and Sync for FastMap
//
// SAFETY (review notes; verify against `super::lock::Mutex`):
// - `FastMap`'s only field is `shards: [Arc<Shard<K, V>>; N]`. In the code visible
//   here, the `Arc`s are created in `with_capacity` and never cloned or handed out.
//   Moving a `FastMap` therefore moves sole ownership of every shard.
// - The map's methods read and write bucket contents only through
//   `shard.buckets.lock(..)`. Soundness of `Sync` relies on that lock granting
//   exclusive access to the bucket `Vec`.
// - NOTE(review): if `super::lock::Mutex<T>` is already `Send + Sync` for `T: Send`,
//   the compiler would derive these impls automatically and they could be removed.
//   If it is deliberately not `Sync` (e.g. its waiter state is thread-bound), these
//   impls override that decision. Confirm which case applies.
// - NOTE(review): the `Sync` impl also requires `K: Sync, V: Sync`. That is
//   presumably stricter than an exclusive lock alone needs (std's `Mutex<T>` is
//   `Sync` for `T: Send`), i.e. conservative rather than unsound.
unsafe impl<K: Send, V: Send, const N: usize> Send for FastMap<K, V, N> {}
unsafe impl<K: Send + Sync, V: Send + Sync, const N: usize> Sync for FastMap<K, V, N> {}