kosame_dsl 0.3.0

Macro-based Rust ORM focused on developer ergonomics
Documentation
use std::{cell::Cell, collections::HashSet};

use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote};
use syn::Ident;

use crate::{
    clause::{FromChain, FromItem, SelectCore, With},
    command::{Command, SelectChain},
    correlations::CorrelationId,
    inferred_type::InferredType,
    part::TargetTable,
    query::{self, Query},
    visit::Visit,
};

thread_local! {
    /// Scope currently installed on this thread by `ScopeId::scope`, if any.
    static SCOPE_ID_CONTEXT: Cell<Option<ScopeId>> = const { Cell::new(None) };

    /// Next value handed out by `ScopeId::new` on this thread.
    static SCOPE_ID_AUTO_INCREMENT: Cell<u32> = const { Cell::new(0) };
}

/// Opaque, cheaply copyable identifier of a name-resolution scope.
///
/// When interpolated into generated code it renders as the identifier `scope_<n>`.
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub struct ScopeId(u32);

impl ScopeId {
    /// Allocates the next ID from this thread's counter.
    ///
    /// IDs are handed out sequentially, starting at `0` or at the last [`ScopeId::reset`].
    #[must_use]
    pub fn new() -> Self {
        let id = SCOPE_ID_AUTO_INCREMENT.get();
        SCOPE_ID_AUTO_INCREMENT.set(id + 1);
        Self(id)
    }

    /// Runs `f` with `self` installed as the current scope.
    ///
    /// [`ScopeId::of_scope`] returns `self` for the duration of the call. Afterwards the
    /// previously installed scope (if any) is reinstated. This also happens when `f` panics,
    /// so a caught panic cannot leave a stale scope behind on this thread.
    pub fn scope(&self, f: impl FnOnce()) {
        /// Drop guard that reinstates the previous scope, including during unwinding.
        struct RestoreOnDrop(Option<ScopeId>);

        impl Drop for RestoreOnDrop {
            fn drop(&mut self) {
                SCOPE_ID_CONTEXT.set(self.0);
            }
        }

        let _restore = RestoreOnDrop(SCOPE_ID_CONTEXT.replace(Some(*self)));
        f();
    }

    /// Returns the scope currently installed on this thread by [`ScopeId::scope`].
    ///
    /// # Panics
    ///
    /// Panics if called outside of any [`ScopeId::scope`] call.
    #[must_use]
    pub fn of_scope() -> ScopeId {
        SCOPE_ID_CONTEXT
            .get()
            .expect("`ScopeId::of_scope` was called outside of a ScopeId scope")
    }

    /// Restarts ID allocation at `0` for this thread.
    ///
    /// Only the counter is reset; the currently installed scope is left untouched.
    pub fn reset() {
        SCOPE_ID_AUTO_INCREMENT.set(0);
    }
}

impl Default for ScopeId {
    fn default() -> Self {
        Self::new()
    }
}

impl ToTokens for ScopeId {
    /// Interpolates the ID as the identifier `scope_<n>`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let ident = format_ident!("scope_{}", self.0);
        ident.to_tokens(tokens);
    }
}

/// A syntax node that introduces its own name-resolution scope.
///
/// Only [`Scoped::scope_id`] and [`Scoped::from_chain`] must be implemented. The other
/// accessors default to `None` for nodes that have no such clause.
pub trait Scoped {
    /// ID of the scope introduced by this node.
    #[must_use]
    fn scope_id(&self) -> ScopeId;

    /// The `FromChain` whose items become visible in this scope, if any.
    #[must_use]
    // `from_chain` is an accessor for the FROM clause, not a `from_*` conversion constructor.
    #[allow(clippy::wrong_self_convention)]
    fn from_chain(&self) -> Option<&FromChain>;

    /// The `With` clause attached to this node, if any.
    #[must_use]
    fn with(&self) -> Option<&With> {
        None
    }

    /// The target table of this node, if any.
    #[must_use]
    fn target_table(&self) -> Option<&TargetTable> {
        None
    }

    /// The select chain nested in this node, if any.
    #[must_use]
    fn select_chain(&self) -> Option<&SelectChain> {
        None
    }
}

/// Every scope collected from a single command or query, in visitation order.
pub struct Scopes<'a> {
    /// The collected scopes; each one renders as its own generated module.
    scopes: Vec<Scope<'a>>,
}

impl Scopes<'_> {
    /// Infers the type of `table.column` as seen from the scope identified by `scope_id`.
    ///
    /// Returns `None` when the reference cannot be resolved within that scope.
    ///
    /// # Panics
    ///
    /// Panics if `scope_id` does not belong to any collected scope.
    #[must_use]
    pub fn infer_type<'a>(
        &self,
        scope_id: ScopeId,
        table: Option<&Ident>,
        column: &'a Ident,
    ) -> Option<InferredType<'a>> {
        self.scopes
            .iter()
            .find(|candidate| candidate.id == scope_id)
            .expect("scope ID must be valid")
            .infer_type(table, column)
    }
}

impl ToTokens for Scopes<'_> {
    /// Emits `mod scopes { ... }` containing one submodule per collected scope.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let modules = self.scopes.iter();
        tokens.extend(quote! {
            mod scopes {
                #(#modules)*
            }
        });
    }
}

/// One scope: its ID plus every item that can be referenced by name inside it.
struct Scope<'a> {
    /// Identifier of the scope; also used as the name of its generated module.
    id: ScopeId,
    /// Items visible in this scope.
    items: Vec<ScopeItem<'a>>,
}

impl<'a> Scope<'a> {
    fn new(id: ScopeId, items: Vec<ScopeItem<'a>>) -> Self {
        Self { id, items }
    }

    pub fn infer_type<'b>(
        &self,
        table: Option<&Ident>,
        column: &'b Ident,
    ) -> Option<InferredType<'b>> {
        let table = table?;
        let item = self.items.iter().find(|item| item.name() == Some(table))?;
        Some(InferredType::Correlation {
            correlation_id: item.correlation_id(),
            column,
            nullable: item.nullable(),
        })
    }
}

impl ToTokens for Scope<'_> {
    /// Emits `pub mod scope_<n>`, which contains two submodules:
    ///
    /// - `tables`: re-exports the correlation of every named item.
    /// - `columns`: glob-re-exports the columns of every named, non-inherited item.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let module = &self.id;

        // Items without a name cannot be referenced, so they get no re-export.
        let table_reexports = self.items.iter().filter_map(|item| {
            let alias = item.name()?;
            let correlation_id = item.correlation_id();
            Some(quote! {
                pub use super::super::super::correlations::#correlation_id as #alias;
            })
        });

        // Inherited items are reachable through `tables` only; their columns are not
        // brought into this scope unqualified.
        let own_tables = self.items.iter().filter_map(|item| {
            if item.is_inherited() {
                None
            } else {
                item.name()
            }
        });

        tokens.extend(quote! {
            pub mod #module {
                pub mod tables {
                    #(#table_reexports)*
                }
                pub mod columns {
                    #(pub use super::tables::#own_tables::columns::*;)*
                }
            }
        });
    }
}

/// Something that can be referenced by name inside a scope.
pub enum ScopeItem<'a> {
    /// The target table of a scoped node (see `Scoped::target_table`).
    TargetTable {
        target_table: &'a TargetTable,
    },
    /// An item of a `FromChain`, declared in this scope or inherited from an enclosing one.
    FromItem {
        from_item: &'a FromItem,
        /// Scope that declared the item, or `None` when it is declared in this scope.
        inherited_from: Option<ScopeId>,
        /// Whether columns of this item are inferred as nullable.
        nullable: bool,
    },
    /// A node of a `Query`, exposed under the name it is referenced by.
    QueryNode {
        node: &'a query::Node,
        name: &'a Ident,
    },
}

impl ScopeItem<'_> {
    /// Correlation this item refers to.
    #[must_use]
    pub fn correlation_id(&self) -> CorrelationId {
        match self {
            Self::QueryNode { node, .. } => node.correlation_id,
            Self::FromItem { from_item, .. } => from_item.correlation_id(),
            Self::TargetTable { target_table } => target_table.table.correlation_id,
        }
    }

    /// Name under which the item can be referenced; `None` for anonymous `FROM` items.
    #[must_use]
    pub fn name(&self) -> Option<&Ident> {
        match self {
            Self::QueryNode { name, .. } => Some(*name),
            Self::FromItem { from_item, .. } => from_item.name(),
            Self::TargetTable { target_table } => Some(target_table.name()),
        }
    }

    /// Whether columns of this item are inferred as nullable.
    ///
    /// Only `FromItem`s can be nullable.
    #[must_use]
    pub fn nullable(&self) -> bool {
        matches!(self, Self::FromItem { nullable: true, .. })
    }

    /// Whether this item was inherited from an enclosing scope rather than declared here.
    #[must_use]
    pub fn is_inherited(&self) -> bool {
        matches!(
            self,
            Self::FromItem {
                inherited_from: Some(_),
                ..
            }
        )
    }
}

impl<'a> From<&'a Command> for Scopes<'a> {
    /// Collects one [`Scope`] for every node visited through `visit_command` and
    /// `visit_select_core` while walking `value`.
    ///
    /// Each scope lists the items it declares itself (target table, then `FROM` items),
    /// followed by the named `FROM` items it inherits from enclosing scopes. An inherited item
    /// keeps the nullability it has in the scope that declared it. The nearest declaration of
    /// a name shadows any outer one, so every name is bound at most once per scope.
    fn from(value: &'a Command) -> Self {
        #[derive(Default)]
        struct Visitor<'a> {
            scopes: Vec<Scope<'a>>,
            /// `FROM` items visible to nested scopes, outermost first. Each entry holds the
            /// declaring scope, the item, and the item's nullability in that scope.
            inherited_from_items: Vec<(ScopeId, &'a FromItem, bool)>,
        }

        impl<'a> Visit<'a> for Visitor<'a> {
            fn visit_command(&mut self, command: &'a Command) {
                self.visit_scoped(command);
            }

            fn visit_select_core(&mut self, select_core: &'a SelectCore) {
                self.visit_scoped(select_core);
            }
        }

        impl<'a> Visitor<'a> {
            fn visit_scoped(&mut self, scoped: &'a dyn Scoped) {
                let scope_id = scoped.scope_id();
                // Entries pushed past this length belong to the current scope. They are popped
                // again before this scope's inherited items are computed.
                let from_items_truncate = self.inherited_from_items.len();

                let mut items = Vec::new();
                // Names bound so far in this scope; any name in here hides outer items.
                let mut shadow = HashSet::new();

                if let Some(with) = scoped.with() {
                    self.visit_with(with);
                }

                if let Some(select_chain) = scoped.select_chain() {
                    self.visit_select_chain(select_chain);
                }

                if let Some(target_table) = scoped.target_table() {
                    shadow.insert(target_table.name());
                    items.push(ScopeItem::TargetTable { target_table });
                }

                if let Some(from_chain) = scoped.from_chain() {
                    let nullables = from_chain.nullables();

                    for (from_item, nullable) in from_chain.into_iter().zip(nullables) {
                        // Make the item visible to subqueries visited while this scope is open.
                        self.inherited_from_items.push((scope_id, from_item, nullable));

                        if let Some(name) = from_item.name() {
                            shadow.insert(name);
                        }

                        // Avoid processing the subquery twice.
                        if scoped.select_chain().is_none()
                            && let FromItem::Subquery { command, .. } = from_item
                        {
                            self.visit_command(command);
                        }

                        items.push(ScopeItem::FromItem {
                            from_item,
                            inherited_from: None,
                            nullable,
                        });
                    }
                }

                self.inherited_from_items.truncate(from_items_truncate);

                // Walk enclosing scopes innermost-first so the nearest declaration of a name
                // wins. `insert` returns `false` for a name that is already bound, either locally
                // or by a closer enclosing scope. That both shadows the outer item and keeps the
                // generated `tables` module from re-exporting the same name twice.
                for &(inherited_from, from_item, nullable) in
                    self.inherited_from_items.iter().rev()
                {
                    if let Some(name) = from_item.name()
                        && shadow.insert(name)
                    {
                        items.push(ScopeItem::FromItem {
                            from_item,
                            inherited_from: Some(inherited_from),
                            // An item nullable in its declaring scope stays nullable when it is
                            // referenced from a nested scope.
                            nullable,
                        });
                    }
                }

                self.scopes.push(Scope::new(scope_id, items));
            }
        }

        let mut visitor = Visitor::default();
        visitor.visit_command(value);
        Scopes {
            scopes: visitor.scopes,
        }
    }
}

impl<'a> From<&'a Query> for Scopes<'a> {
    /// Builds one scope per query node. Each scope exposes only its own node, under the name
    /// it is referenced by:
    ///
    /// - the root node uses the last segment of the table path;
    /// - a nested node uses the name of the relation field that leads to it.
    ///
    /// Child scopes are pushed before their parent.
    fn from(value: &'a Query) -> Self {
        struct Visitor<'a> {
            scopes: Vec<Scope<'a>>,
            /// Name under which the next visited node is exposed.
            name: &'a Ident,
        }

        impl<'a> Visit<'a> for Visitor<'a> {
            fn visit_node(&mut self, node: &'a query::Node) {
                let scope_id = node.scope_id;
                // Capture the name now: visiting children overwrites `self.name`.
                let own_item = ScopeItem::QueryNode {
                    node,
                    name: self.name,
                };

                for field in &node.fields {
                    let query::Field::Relation {
                        node: child, name, ..
                    } = field
                    else {
                        continue;
                    };
                    self.name = name;
                    self.visit_node(child);
                }

                self.scopes.push(Scope::new(scope_id, vec![own_item]));
            }
        }

        let root_name = &value.table.as_path().segments.last().unwrap().ident;
        let mut visitor = Visitor {
            scopes: Vec::new(),
            name: root_name,
        };
        visitor.visit_node(&value.body);
        Self {
            scopes: visitor.scopes,
        }
    }
}