use super::sources::{FileId, Source, SourceType};
use crate::hir::TypeSystem;
use ariadne::{Cache as AriadneCache, Source as AriadneSource};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
#[derive(Debug, thiserror::Error)]
#[error("Unknown file ID")]
struct UnknownFileError;
/// Snapshot of the database's source files in the form ariadne needs to
/// render diagnostics, keyed by [`FileId`].
///
/// Built by the `source_cache` query.
#[derive(Clone, PartialEq, Eq)]
pub struct SourceCache {
    /// Filename of each file, used as the label shown by ariadne.
    paths: HashMap<FileId, PathBuf>,
    /// Ariadne view of each file's text (see `source_with_lines`).
    sources: HashMap<FileId, Arc<AriadneSource>>,
}
/// Lets ariadne read file contents and file labels straight out of a
/// [`SourceCache`].
impl AriadneCache<FileId> for &SourceCache {
    /// Returns the cached ariadne source for `id`, or an [`UnknownFileError`]
    /// if the cache holds no such file.
    fn fetch(&mut self, id: &FileId) -> Result<&AriadneSource, Box<dyn std::fmt::Debug>> {
        let source = self.sources.get(id);
        source
            .map(|arc| &**arc)
            .ok_or_else(|| Box::new(UnknownFileError) as Box<dyn std::fmt::Debug>)
    }

    /// Returns the label ariadne prints for `id` (its path), or `None` when the
    /// cache has no path for that file.
    fn display<'a>(&self, id: &'a FileId) -> Option<Box<dyn std::fmt::Display + 'a>> {
        let path = self.paths.get(id)?;
        // `to_str` would yield `None` for non-UTF-8 paths and the file would be
        // shown with no label at all; the lossy conversion replaces invalid
        // sequences with U+FFFD so the file is still identifiable.
        let label: Box<dyn std::fmt::Display + 'a> =
            Box::new(path.to_string_lossy().into_owned());
        Some(label)
    }
}
impl std::fmt::Debug for SourceCache {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_map()
            .entries({
                let mut paths: Vec<_> = self.paths.iter().collect();
                paths.sort_by(|a, b| a.0.cmp(b.0));
                paths.into_iter().map(|(id, path)| (id.as_u64(), path))
            })
            .finish()
    }
}
/// Salsa query group holding the compiler's raw inputs (limits, source files,
/// an optional precomputed type system) and the queries derived directly from
/// them.
#[salsa::query_group(InputStorage)]
pub trait InputDatabase {
    /// Recursion limit supplied as an input, if any.
    #[salsa::input]
    fn recursion_limit(&self) -> Option<usize>;

    /// Token limit supplied as an input, if any.
    #[salsa::input]
    fn token_limit(&self) -> Option<usize>;

    /// Optional precomputed type system. When set, `source_code` reads a file's
    /// text from its `inputs` before falling back to [`InputDatabase::input`].
    #[salsa::input]
    fn type_system_hir_input(&self) -> Option<Arc<TypeSystem>>;

    /// The source registered for `file_id`.
    #[salsa::input]
    fn input(&self, file_id: FileId) -> Source;

    /// Text of `file_id`, preferring the precomputed type system's copy when
    /// one is available.
    #[salsa::invoke(source_code)]
    fn source_code(&self, file_id: FileId) -> Arc<str>;

    /// The [`SourceType`] of `file_id`.
    #[salsa::invoke(source_type)]
    fn source_type(&self, file_id: FileId) -> SourceType;

    /// IDs of all registered source files.
    #[salsa::input]
    fn source_files(&self) -> Vec<FileId>;

    /// ID of the registered source file whose filename equals `path`, if any.
    // Explicit `invoke`, like every other derived query in this group.
    #[salsa::invoke(source_file)]
    fn source_file(&self, path: PathBuf) -> Option<FileId>;

    /// The text of `file_id` wrapped in an ariadne source for diagnostics.
    #[salsa::invoke(source_with_lines)]
    fn source_with_lines(&self, file_id: FileId) -> Arc<AriadneSource>;

    /// Ariadne-compatible cache of every file listed in `source_files`.
    #[salsa::invoke(source_cache)]
    fn source_cache(&self) -> Arc<SourceCache>;

    /// Files whose source type is schema, document, or built-in, i.e. files
    /// that may contain type system definitions.
    #[salsa::invoke(type_definition_files)]
    fn type_definition_files(&self) -> Vec<FileId>;

    /// Files whose source type is executable or document, i.e. files that may
    /// contain executable definitions.
    #[salsa::invoke(executable_definition_files)]
    fn executable_definition_files(&self) -> Vec<FileId>;
}
/// Returns the text of `file_id`.
///
/// If a precomputed type system is set and contains `file_id`, its copy of the
/// text wins; otherwise the text comes from `db.input(file_id)`.
fn source_code(db: &dyn InputDatabase, file_id: FileId) -> Arc<str> {
    db.type_system_hir_input()
        .and_then(|type_system| type_system.inputs.get(&file_id).map(|src| src.text()))
        // Only touch the raw input when the precomputed lookup came up empty.
        .unwrap_or_else(|| db.input(file_id).text())
}
/// Finds the first registered source file whose filename equals `path`.
///
/// Files are checked in `source_files` order; returns `None` when no file
/// matches.
fn source_file(db: &dyn InputDatabase, path: PathBuf) -> Option<FileId> {
    for candidate in db.source_files() {
        if db.input(candidate).filename() == path {
            return Some(candidate);
        }
    }
    None
}
/// Returns the [`SourceType`] of `file_id`.
///
/// Mirrors `source_code`: when a precomputed type system is set and contains
/// `file_id`, its copy of the file is authoritative, and only otherwise does
/// this fall back to `db.input(file_id)`. Reading from the same place as
/// `source_code` keeps a file's text and type consistent.
fn source_type(db: &dyn InputDatabase, file_id: FileId) -> SourceType {
    if let Some(precomputed) = db.type_system_hir_input() {
        if let Some(source) = precomputed.inputs.get(&file_id) {
            // NOTE(review): presumably files that only exist in the precomputed
            // type system have no `db.input` entry; checking here first avoids
            // depending on one.
            return source.source_type();
        }
    }
    db.input(file_id).source_type()
}
/// Wraps the text of `file_id` (from the `source_code` query) in an ariadne
/// source so diagnostics can be rendered against it.
fn source_with_lines(db: &dyn InputDatabase, file_id: FileId) -> Arc<AriadneSource> {
    let ariadne_source: AriadneSource = db.source_code(file_id).into();
    Arc::new(ariadne_source)
}
fn source_cache(db: &dyn InputDatabase) -> Arc<SourceCache> {
    let file_ids = db.source_files();
    let sources = file_ids
        .iter()
        .map(|&id| (id, db.source_with_lines(id)))
        .collect();
    let paths = file_ids
        .iter()
        .map(|&id| (id, db.input(id).filename().to_owned()))
        .collect();
    Arc::new(SourceCache { sources, paths })
}
/// Returns the source files, in `source_files` order, whose type is
/// [`SourceType::Schema`], [`SourceType::Document`], or [`SourceType::BuiltIn`].
fn type_definition_files(db: &dyn InputDatabase) -> Vec<FileId> {
    let mut files = db.source_files();
    // `retain` visits elements once, in order, and keeps their relative order.
    files.retain(|&file_id| {
        let kind = db.source_type(file_id);
        matches!(kind, SourceType::Schema | SourceType::Document | SourceType::BuiltIn)
    });
    files
}
/// Returns the source files, in `source_files` order, whose type is
/// [`SourceType::Executable`] or [`SourceType::Document`].
fn executable_definition_files(db: &dyn InputDatabase) -> Vec<FileId> {
    let mut files = db.source_files();
    // `retain` visits elements once, in order, and keeps their relative order.
    files.retain(|&file_id| {
        let kind = db.source_type(file_id);
        matches!(kind, SourceType::Executable | SourceType::Document)
    });
    files
}