kosame_dsl 0.3.0

Macro-based Rust ORM focused on developer ergonomics
Documentation
use std::cell::Cell;

use proc_macro2::TokenStream;
use quote::{ToTokens, format_ident, quote};
use syn::Ident;

use crate::{
    clause::{FromItem, WithItem},
    command::Command,
    inferred_type::InferredType,
    part::TablePath,
    path_ext::PathExt,
    query::{self, Query, QueryNodePath},
    scopes::Scoped,
    visit::Visit,
};

thread_local! {
    /// Next raw value handed out by [`CorrelationId::new`]; reset to 0 by [`CorrelationId::reset`].
    static CORRELATION_ID_AUTO_INCREMENT: Cell<u32> = const { Cell::new(0) };
    /// Correlation ID installed by the innermost active [`CorrelationId::scope`] call, if any.
    static CORRELATION_ID_CONTEXT: Cell<Option<CorrelationId>> = const { Cell::new(None) };
}

/// Opaque identifier for one correlation that generated code can refer to by name.
///
/// IDs are allocated from a per-thread counter (see [`CorrelationId::new`]) and rendered by
/// the `ToTokens` impl as the identifier `correlation_<n>`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy)]
pub struct CorrelationId(u32);

impl CorrelationId {
    /// Allocates the next ID from the thread-local counter.
    ///
    /// IDs are sequential per thread, starting at 0 after [`CorrelationId::reset`].
    #[must_use]
    pub fn new() -> Self {
        let id = CORRELATION_ID_AUTO_INCREMENT.get();
        CORRELATION_ID_AUTO_INCREMENT.set(id + 1);
        Self(id)
    }

    /// Runs `f` with `self` installed as the current correlation ID, so that
    /// [`CorrelationId::of_scope`] returns it for the duration of the call.
    ///
    /// Scopes nest: the previously active ID (or none) is restored afterwards. The restore
    /// happens in a drop guard, so it also runs when `f` panics and the panic is caught
    /// further up. Otherwise a stale ID would remain installed for later code on this thread.
    pub fn scope(&self, f: impl FnOnce()) {
        /// Reinstalls the previously active correlation ID when dropped.
        struct RestoreGuard(Option<CorrelationId>);

        impl Drop for RestoreGuard {
            fn drop(&mut self) {
                CORRELATION_ID_CONTEXT.set(self.0);
            }
        }

        let _restore = RestoreGuard(CORRELATION_ID_CONTEXT.replace(Some(*self)));
        f();
    }

    /// Returns the correlation ID installed by the innermost enclosing [`CorrelationId::scope`].
    ///
    /// # Panics
    ///
    /// Panics if called outside of any `CorrelationId::scope` on the current thread.
    #[must_use]
    pub fn of_scope() -> CorrelationId {
        CORRELATION_ID_CONTEXT
            .get()
            .expect("`CorrelationId::of_scope` was called outside of a CorrelationId scope")
    }

    /// Resets the thread-local counter so the next [`CorrelationId::new`] returns ID 0 again.
    ///
    /// The currently active scope (if any) is left untouched.
    pub fn reset() {
        CORRELATION_ID_AUTO_INCREMENT.set(0);
    }
}

impl Default for CorrelationId {
    fn default() -> Self {
        Self::new()
    }
}

impl ToTokens for CorrelationId {
    /// Renders the ID as the identifier `correlation_<n>`, e.g. `correlation_0`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let Self(raw) = *self;
        let ident = format_ident!("correlation_{}", raw);
        ident.to_tokens(tokens);
    }
}

/// Every correlation collected from one command or query, in visit order.
///
/// Built via the `From<&Command>` / `From<&Query>` impls and rendered by `ToTokens` as a
/// `mod correlations` containing one item per entry.
pub struct Correlations<'a> {
    correlations: Vec<Correlation<'a>>,
}

impl<'a> Correlations<'a> {
    /// Infers the type of `column` as seen through the correlation identified by
    /// `correlation_id`.
    ///
    /// Returns `None` when that correlation cannot type the column, for example a command
    /// that has no field named `column`.
    ///
    /// # Panics
    ///
    /// Panics if no correlation with `correlation_id` was collected into this set. Callers
    /// are expected to pass only IDs taken from the same syntax tree.
    #[must_use]
    pub fn infer_type(
        &'a self,
        correlation_id: CorrelationId,
        column: &'a Ident,
    ) -> Option<InferredType<'a>> {
        let correlation = self
            .correlations
            .iter()
            .find(|correlation| correlation.id() == correlation_id)
            .unwrap_or_else(|| {
                panic!("no correlation with ID {correlation_id:?} was collected into this set")
            });
        correlation.infer_type(column)
    }
}

impl ToTokens for Correlations<'_> {
    /// Wraps every collected correlation's item in a single `mod correlations { ... }`.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let items = self.correlations.iter();
        tokens.extend(quote! {
            mod correlations {
                #(#items)*
            }
        });
    }
}

/// One entry in a [`Correlations`] set: a syntax-tree node that a [`CorrelationId`] resolves to.
///
/// Each variant borrows the node it was collected from; `Correlation::id` shows which
/// correlation ID each variant is registered under.
enum Correlation<'a> {
    /// A table reference. The optional `WithItem` is set when the table name resolved to a
    /// `WITH` item in scope; the generated code then re-exports that item's correlation
    /// instead of the table path.
    Table(&'a TablePath, Option<&'a WithItem>),
    /// A command; its output fields become the `columns` of its generated module.
    Command(&'a Command),
    /// A `WITH` item, re-exporting the `columns` of the command it wraps under its alias.
    WithItem(&'a WithItem),
    /// An item in a `FROM` clause: a table or a subquery, optionally aliased.
    FromItem(&'a FromItem),
    /// A node of a `Query` tree, resolved from the query's root table via `node_path`.
    QueryNodePath {
        node: &'a query::Node,
        table_path: &'a TablePath,
        node_path: QueryNodePath,
    },
}

impl<'a> Correlation<'a> {
    fn id(&self) -> CorrelationId {
        match self {
            Self::Table(inner, _) => inner.correlation_id,
            Self::Command(inner) => inner.correlation_id,
            Self::WithItem(inner) => inner.correlation_id,
            Self::FromItem(inner) => inner.correlation_id(),
            Self::QueryNodePath { node, .. } => node.correlation_id,
        }
    }

    fn _source_id(&self) -> Option<CorrelationId> {
        match self {
            Self::Table(_, with_item) => {
                with_item.as_ref().map(|with_item| with_item.correlation_id)
            }
            Self::Command(_) => None,
            Self::WithItem(inner) => Some(inner.command.correlation_id),
            Self::FromItem(inner) => match inner {
                FromItem::Table { table_path, .. } => Some(table_path.correlation_id),
                FromItem::Subquery { command, .. } => Some(command.correlation_id),
            },
            Self::QueryNodePath { .. } => None,
        }
    }

    pub fn infer_type(&'a self, column: &'a Ident) -> Option<InferredType<'a>> {
        match self {
            Self::Table(table_path, with_item) => match with_item {
                Some(with_item) => Some(InferredType::Correlation {
                    correlation_id: with_item.correlation_id,
                    column,
                    nullable: false,
                }),
                None => Some(InferredType::TableColumn { table_path, column }),
            },
            Self::Command(command) => {
                let field = command
                    .fields()?
                    .iter()
                    .find(|field| field.infer_name() == Some(column))?;
                match command.select_chain() {
                    Some(select_chain) => field.infer_type(select_chain.start.scope_id()),
                    None => field.infer_type(command.scope_id),
                }
            }
            Self::WithItem(with_item) => Some(InferredType::Correlation {
                correlation_id: with_item.command.correlation_id,
                column,
                nullable: false,
            }),
            Self::FromItem(from_item) => match from_item {
                FromItem::Table { table_path, .. } => Some(InferredType::Correlation {
                    correlation_id: table_path.correlation_id,
                    column,
                    nullable: false,
                }),
                FromItem::Subquery { command, .. } => Some(InferredType::Correlation {
                    correlation_id: command.correlation_id,
                    column,
                    nullable: false,
                }),
            },
            Self::QueryNodePath { table_path, .. } => {
                Some(InferredType::TableColumn { table_path, column })
            }
        }
    }
}

impl ToTokens for Correlation<'_> {
    /// Emits the item that makes this correlation nameable as `correlation_<n>` inside the
    /// generated `mod correlations`.
    ///
    /// Depending on the variant, this is one of two forms:
    /// - a `pub use` re-export of another correlation (or of the table's path);
    /// - a `pub mod` exposing `TABLE_NAME` and/or `columns`.
    ///
    /// # Panics
    ///
    /// Panics (via `unimplemented!`) for a `WITH` item whose alias carries an explicit column
    /// list, which is not supported yet.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        let id = &self.id();

        match self {
            Self::Table(_, Some(with_item)) => {
                // The table name refers to a `WITH` item in scope: alias that correlation.
                let source_id = with_item.correlation_id;
                quote! {
                    pub use #source_id as #id;
                }
            }
            Self::Table(table_path, None) => {
                // `2` is the module-depth argument of `PathExt::to_call_site`; see that
                // helper for how the path is rewritten.
                let table_path = table_path.as_path().to_call_site(2);
                quote! {
                    pub use #table_path as #id;
                }
            }
            Self::Command(command) => {
                if let Some(fields) = command.fields() {
                    let fields = fields.columns();
                    let field_strings = fields.iter().map(std::string::ToString::to_string);
                    quote! {
                        pub mod #id {
                            pub mod columns {
                                #(
                                    pub mod #fields {
                                        pub const COLUMN_NAME: &str = #field_strings;
                                    }
                                )*
                            }
                        }
                    }
                } else {
                    quote! { pub mod #id {} }
                }
            }
            Self::WithItem(with_item) => {
                if with_item.alias.columns.is_some() {
                    // Renaming columns would need a dedicated `columns` module rather than a
                    // re-export of the underlying command's.
                    unimplemented!("column lists on `WITH` item aliases are not supported yet");
                } else {
                    let source_id = with_item.command.correlation_id;
                    let alias = with_item.alias.name.to_string();
                    quote! {
                        pub mod #id {
                            pub const TABLE_NAME: &str = #alias;
                            pub use super::#source_id::columns;
                        }
                    }
                }
            }
            Self::FromItem(from_item) => {
                // Tables and subqueries render identically; only the source correlation differs.
                let (source_id, alias) = match from_item {
                    FromItem::Table {
                        table_path, alias, ..
                    } => (
                        table_path.correlation_id,
                        alias.as_ref().map(|alias| alias.name.to_string()),
                    ),
                    FromItem::Subquery { command, alias, .. } => (
                        command.correlation_id,
                        alias.as_ref().map(|alias| alias.name.to_string()),
                    ),
                };
                match alias {
                    // An alias introduces a new table name but keeps the source's columns.
                    Some(alias) => quote! {
                        pub mod #id {
                            pub const TABLE_NAME: &str = #alias;
                            pub use super::#source_id::columns;
                        }
                    },
                    // Without an alias the item is just another name for its source.
                    None => quote! { pub use #source_id as #id; },
                }
            }
            Self::QueryNodePath {
                table_path,
                node_path,
                ..
            } => {
                let table_path = node_path.resolve(&table_path.as_path().to_call_site(2));
                quote! {
                    pub use #table_path as #id;
                }
            }
        }
        .to_tokens(tokens);
    }
}

impl<'a> From<&'a Command> for Correlations<'a> {
    /// Collects every correlation reachable from `value`, including nested commands in
    /// `WITH` items and `FROM` subqueries.
    ///
    /// Single-identifier table names in `FROM` items are matched against the `WITH` items
    /// visible at that point. A reference to a common table expression therefore correlates
    /// with that item rather than with a schema table.
    fn from(value: &'a Command) -> Self {
        #[derive(Default)]
        struct Visitor<'a> {
            correlations: Vec<Correlation<'a>>,
            // `WITH` items visible from the command currently being visited, in the order
            // they were registered (outer commands first).
            inherited_with_items: Vec<&'a WithItem>,
        }

        impl<'a> Visit<'a> for Visitor<'a> {
            fn visit_with_item(&mut self, with_item: &'a WithItem) {
                self.correlations.push(Correlation::WithItem(with_item));
                // The item's own command is visited before the item is registered, so it
                // cannot resolve its own name; items registered earlier remain visible.
                self.visit_command(&with_item.command);
                self.inherited_with_items.push(with_item);
            }

            fn visit_target_table(&mut self, target_table: &'a crate::part::TargetTable) {
                self.correlations
                    .push(Correlation::Table(&target_table.table, None));
            }

            fn visit_from_item(&mut self, from_item: &'a FromItem) {
                self.correlations.push(Correlation::FromItem(from_item));

                match from_item {
                    FromItem::Table { table_path, .. } => {
                        // Only a path that `get_ident` reduces to a single identifier can name
                        // a `WITH` item. The most recently registered match wins, so inner
                        // items shadow outer ones.
                        let with_item = table_path.get_ident().and_then(|table| {
                            self.inherited_with_items
                                .iter()
                                .rev()
                                .find(|with_item| with_item.alias.name == *table)
                                .copied()
                        });
                        self.correlations
                            .push(Correlation::Table(table_path, with_item));
                    }
                    FromItem::Subquery { command, .. } => self.visit_command(command),
                }
            }

            fn visit_command(&mut self, command: &'a Command) {
                self.correlations.push(Correlation::Command(command));
                // `WITH` items declared inside this command must not leak to its siblings.
                let with_items_truncate = self.inherited_with_items.len();

                if let Some(with) = &command.with {
                    self.visit_with(with);
                }

                if let Some(target_table) = command.target_table() {
                    self.visit_target_table(target_table);
                }

                if let Some(select_chain) = command.select_chain() {
                    self.visit_select_chain(select_chain);
                } else if let Some(from_chain) = command.from_chain() {
                    // We would risk processing the from chain twice if we collected both the
                    // select chain and the from chain of a select.
                    self.visit_from_chain(from_chain);
                }

                self.inherited_with_items.truncate(with_items_truncate);
            }
        }

        let mut visitor = Visitor::default();
        visitor.visit_command(value);
        Correlations {
            correlations: visitor.correlations,
        }
    }
}

impl<'a> From<&'a Query> for Correlations<'a> {
    fn from(value: &'a Query) -> Self {
        struct Visitor<'a> {
            correlations: Vec<Correlation<'a>>,
            query: &'a Query,
            node_path: QueryNodePath,
        }

        impl<'a> Visit<'a> for Visitor<'a> {
            fn visit_node(&mut self, node: &'a query::Node) {
                for field in &node.fields {
                    if let query::Field::Relation { node, name, .. } = field {
                        self.node_path.append(name.clone());
                        self.visit_node(node);
                        self.node_path.pop();
                    }
                }

                self.correlations.push(Correlation::QueryNodePath {
                    node,
                    table_path: &self.query.table,
                    node_path: self.node_path.clone(),
                });
            }
        }

        let mut visitor = Visitor {
            correlations: Vec::new(),
            query: value,
            node_path: QueryNodePath::new(),
        };
        visitor.visit_node(&value.body);
        Correlations {
            correlations: visitor.correlations,
        }
    }
}