use std::{num::NonZeroUsize, sync::Arc};
use async_graphql_parser::types::ExecutableDocument;
use futures_util::lock::Mutex;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use crate::{
    extensions::{Extension, ExtensionContext, ExtensionFactory, NextPrepareRequest},
    from_value, Request, ServerError, ServerResult,
};
/// Payload of the `persistedQuery` request extension.
///
/// Field names are mapped to camelCase on the wire, so `sha256_hash` is read
/// from the `sha256Hash` key.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct PersistedQuery {
    /// Protocol version of the extension; `prepare_request` only accepts `1`.
    version: i32,
    /// Hex-encoded SHA-256 hash identifying the query text.
    sha256_hash: String,
}
/// Pluggable storage for parsed persisted queries.
///
/// Keys are SHA-256 hex strings identifying a query; values are the parsed
/// documents registered under them.
#[async_trait::async_trait]
pub trait CacheStorage: Clone + Send + Sync + 'static {
    /// Stores the parsed `query` under `key`.
    async fn set(&self, key: String, query: ExecutableDocument);

    /// Returns the document previously stored under `key`, or `None` if the
    /// storage does not (or no longer) hold one.
    async fn get(&self, key: String) -> Option<ExecutableDocument>;
}
/// A [`CacheStorage`] backed by an in-memory `lru::LruCache` guarded by an async mutex.
///
/// The cache lives behind an `Arc`, so clones of this value share the same
/// underlying cache.
#[derive(Clone)]
pub struct LruCacheStorage(Arc<Mutex<lru::LruCache<String, ExecutableDocument>>>);

impl LruCacheStorage {
    /// Creates a storage that holds at most `cap` parsed documents.
    ///
    /// # Panics
    ///
    /// Panics if `cap` is zero, since the underlying `lru::LruCache` requires a
    /// non-zero capacity.
    pub fn new(cap: usize) -> Self {
        let cap = NonZeroUsize::new(cap)
            .expect("LruCacheStorage capacity must be greater than zero");
        Self(Arc::new(Mutex::new(lru::LruCache::new(cap))))
    }
}
#[async_trait::async_trait]
impl CacheStorage for LruCacheStorage {
    /// Looks up `key` and returns a clone of the cached document, if present.
    ///
    /// `LruCache::get` needs mutable access (it updates recency), which is why
    /// the lock is taken even for reads.
    async fn get(&self, key: String) -> Option<ExecutableDocument> {
        self.0.lock().await.get(&key).cloned()
    }

    /// Inserts `query` under `key`, replacing any previous entry for that key.
    async fn set(&self, key: String, query: ExecutableDocument) {
        self.0.lock().await.put(key, query);
    }
}
/// Extension factory implementing Apollo Automatic Persisted Queries.
///
/// Parsed queries are cached in the supplied [`CacheStorage`] so that clients
/// can later send only the query's SHA-256 hash instead of its full text.
#[cfg_attr(docsrs, doc(cfg(feature = "apollo_persisted_queries")))]
pub struct ApolloPersistedQueries<T>(T);

impl<T: CacheStorage> ApolloPersistedQueries<T> {
    /// Creates the extension factory around the given cache storage.
    pub fn new(cache_storage: T) -> Self {
        ApolloPersistedQueries(cache_storage)
    }
}
impl<T: CacheStorage> ExtensionFactory for ApolloPersistedQueries<T> {
    /// Builds an extension instance holding a clone of this factory's storage.
    ///
    /// For `LruCacheStorage` the clone shares the same underlying cache; other
    /// `CacheStorage` implementations decide their own `Clone` semantics.
    fn create(&self) -> Arc<dyn Extension> {
        let storage = self.0.clone();
        Arc::new(ApolloPersistedQueriesExtension { storage })
    }
}
/// Extension instance produced by `ApolloPersistedQueries::create`; resolves the
/// `persistedQuery` request extension against `storage` in `prepare_request`.
struct ApolloPersistedQueriesExtension<T> {
    /// Cache of parsed query documents keyed by SHA-256 hex strings.
    storage: T,
}
#[async_trait::async_trait]
impl<T: CacheStorage> Extension for ApolloPersistedQueriesExtension<T> {
    /// Resolves the Apollo `persistedQuery` request extension before the request
    /// is handed to the next extension.
    ///
    /// - No `persistedQuery` extension: the request is passed through unchanged.
    /// - Empty query text: the document is looked up by hash. A cache miss
    ///   yields a `PersistedQueryNotFound` error.
    /// - Non-empty query text: the SHA-256 of the text must match the supplied
    ///   hash. The query is then parsed and cached under that hash, and the
    ///   request continues with the parsed document.
    ///
    /// # Errors
    ///
    /// Returns an error if the extension payload is malformed, its version is
    /// not `1`, the hash does not match the query, the query fails to parse, or
    /// a hash-only request misses the cache.
    async fn prepare_request(
        &self,
        ctx: &ExtensionContext<'_>,
        mut request: Request,
        next: NextPrepareRequest<'_>,
    ) -> ServerResult<Request> {
        let value = match request.extensions.remove("persistedQuery") {
            Some(value) => value,
            // Not a persisted-query request: nothing to resolve.
            None => return next.run(ctx, request).await,
        };

        let persisted_query: PersistedQuery = from_value(value).map_err(|_| {
            ServerError::new("Invalid \"PersistedQuery\" extension configuration.", None)
        })?;
        if persisted_query.version != 1 {
            return Err(ServerError::new(
                format!(
                    "Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"{}\".",
                    persisted_query.version
                ),
                None,
            ));
        }

        // Hex digits are case-insensitive, but the digest below is formatted in
        // lowercase and used as the cache key. Normalise the client's hash so
        // that an uppercase (yet correct) hash neither fails the comparison nor
        // misses the cache.
        let mut client_hash = persisted_query.sha256_hash;
        client_hash.make_ascii_lowercase();

        let request = if request.query.is_empty() {
            match self.storage.get(client_hash).await {
                Some(doc) => Request {
                    parsed_query: Some(doc),
                    ..request
                },
                None => return Err(ServerError::new("PersistedQueryNotFound", None)),
            }
        } else {
            let computed_hash = format!("{:x}", Sha256::digest(request.query.as_bytes()));
            if client_hash != computed_hash {
                return Err(ServerError::new("provided sha does not match query", None));
            }
            let doc = async_graphql_parser::parse_query(&request.query)?;
            self.storage.set(computed_hash, doc.clone()).await;
            Request {
                query: String::new(),
                parsed_query: Some(doc),
                ..request
            }
        };

        next.run(ctx, request).await
    }
}
#[cfg(test)]
mod tests {
    /// Exercises the register / lookup round trip plus every error path of the
    /// persisted-query extension.
    #[tokio::test]
    async fn test() {
        use super::*;
        use crate::*;

        struct Query;

        #[Object(internal)]
        impl Query {
            async fn value(&self) -> i32 {
                100
            }
        }

        let schema = Schema::build(Query, EmptyMutation, EmptySubscription)
            .extension(ApolloPersistedQueries::new(LruCacheStorage::new(256)))
            .finish();

        // Builds a request whose `persistedQuery` extension carries `extension`.
        let persisted_request = |query: &str, extension: Value| {
            let mut request = Request::new(query);
            request
                .extensions
                .insert("persistedQuery".to_string(), extension);
            request
        };

        // Registering the full query (with its correct hash) executes it.
        let request = persisted_request(
            "{ value }",
            value!({
                "version": 1,
                "sha256Hash": "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap().data,
            value!({
                "value": 100
            })
        );

        // A hash-only request now resolves from the cache.
        let request = persisted_request(
            "",
            value!({
                "version": 1,
                "sha256Hash": "854174ebed716fe24fd6659c30290aecd9bc1d17dc4f47939a1848a1b8ed3c6b",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap().data,
            value!({
                "value": 100
            })
        );

        // An unknown hash misses the cache.
        let request = persisted_request(
            "",
            value!({
                "version": 1,
                "sha256Hash": "def",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new("PersistedQueryNotFound", None)]
        );

        // A query whose text does not hash to the supplied value is rejected.
        let request = persisted_request(
            "{ value }",
            value!({
                "version": 1,
                "sha256Hash": "def",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new("provided sha does not match query", None)]
        );

        // Only version 1 of the extension is accepted.
        let request = persisted_request(
            "",
            value!({
                "version": 2,
                "sha256Hash": "def",
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new(
                "Only the \"PersistedQuery\" extension of version \"1\" is supported, and the current version is \"2\".",
                None
            )]
        );

        // A payload missing required fields is reported as invalid configuration.
        let request = persisted_request(
            "",
            value!({
                "version": 1,
            }),
        );
        assert_eq!(
            schema.execute(request).await.into_result().unwrap_err(),
            vec![ServerError::new(
                "Invalid \"PersistedQuery\" extension configuration.",
                None
            )]
        );
    }
}