squawk-ide 2.50.0

Linter for Postgres migrations & SQL
Documentation
use smol_str::SmolStr;
use squawk_syntax::ast::{self, AstNode};
use std::fmt;

use squawk_syntax::quote::normalize_identifier;

/// An SQL identifier, normally holding text produced by `normalize_identifier`.
///
/// Build it with [`Name::from_string`] or [`Name::from_node`] so that spellings
/// which normalize identically compare equal (this module's tests check that
/// `foo`, `FOO` and `"foo"` do). The field is `pub(crate)`, so a value built
/// directly with `Name(..)` skips normalization.
#[derive(Clone, Debug, Hash, Eq, PartialEq)]
pub(crate) struct Name(pub(crate) SmolStr);

/// A schema qualifier (the `public` in `public.users`), stored as a [`Name`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct Schema(pub(crate) Name);

impl Schema {
    pub(crate) fn new(name: impl Into<SmolStr>) -> Self {
        Schema(Name::from_string(name))
    }
}

impl fmt::Display for Schema {
    /// Writes the stored (normalized) schema name.
    ///
    /// Uses `Formatter::pad` rather than `write!(f, "{}", ..)`, which would format
    /// the inner string with a fresh default spec. Going through `pad` means the
    /// caller's width, fill, alignment and precision (e.g. `{:>10}`) are honored,
    /// just as they are for a plain `&str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0.0)
    }
}

impl Name {
    /// Builds a `Name` from raw identifier text by running it through
    /// `normalize_identifier` and storing the result.
    pub(crate) fn from_string(text: impl Into<SmolStr>) -> Self {
        let raw: SmolStr = text.into();
        Name(normalize_identifier(&raw).into())
    }

    /// Builds a `Name` from the full source text of a name-like syntax node,
    /// normalized the same way as [`Name::from_string`].
    pub(crate) fn from_node(node: &impl ast::NameLike) -> Self {
        Self::from_string(node.syntax().text().to_string())
    }
}

impl fmt::Display for Name {
    /// Writes the stored (normalized) identifier text.
    ///
    /// Uses `Formatter::pad` rather than `write!(f, "{}", ..)`, which would format
    /// the inner string with a fresh default spec. Going through `pad` means the
    /// caller's width, fill, alignment and precision (e.g. `{:<20}` when printing
    /// aligned columns) are honored, just as they are for a plain `&str`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(&self.0)
    }
}

/// Splits a possibly qualified path (e.g. `public.users`) into its optional
/// schema qualifier and its final name.
///
/// Returns `None` when [`table_name`] cannot find a name in the path's last
/// segment.
pub(crate) fn schema_and_name_path(path: &ast::Path) -> Option<(Option<Schema>, Name)> {
    let name = table_name(path)?;
    let schema = schema_name(path);
    Some((schema, name))
}

/// Resolves the `(schema, name)` pair for a name reference.
///
/// If any ancestor of `name_ref` is an `ast::Path`, the nearest one is split
/// with [`schema_and_name_path`]. Otherwise the bare name is returned with no
/// schema.
pub(crate) fn schema_and_table_name(name_ref: &ast::NameRef) -> Option<(Option<Schema>, Name)> {
    match name_ref.syntax().ancestors().find_map(ast::Path::cast) {
        Some(enclosing_path) => schema_and_name_path(&enclosing_path),
        None => Some((None, Name::from_node(name_ref))),
    }
}

/// Returns `name_ref`'s own name together with a schema taken from a direct
/// `FieldExpr` parent.
///
/// When the parent of `name_ref` is a `FieldExpr` whose base is a plain
/// `NameRef` (as in `schema.name`), that base becomes the schema. Otherwise the
/// schema is `None`.
///
/// NOTE(review): this does not check that `name_ref` is the field (rather than
/// the base) of that `FieldExpr`. If called on the base itself, the schema and
/// the name come out identical. Confirm callers only pass the field.
pub(crate) fn schema_and_name(name_ref: &ast::NameRef) -> (Option<Schema>, Name) {
    let schema = name_ref
        .syntax()
        .parent()
        .and_then(ast::FieldExpr::cast)
        .and_then(|field_expr| field_expr.base())
        .and_then(|base| ast::NameRef::cast(base.syntax().clone()))
        .map(|schema_ref| Schema(Name::from_node(&schema_ref)));

    (schema, Name::from_node(name_ref))
}

/// Extracts the optional schema and the function name from a call's callee.
///
/// - A bare `NameRef` callee (`f(..)`) yields `(None, f)`.
/// - A `FieldExpr` callee whose base is a `NameRef` (`s.f(..)`) yields
///   `(Some(s), f)`.
/// - Any other callee shape, or missing pieces, yields `None`.
pub(crate) fn schema_and_func_name(call_expr: &ast::CallExpr) -> Option<(Option<Schema>, Name)> {
    let qualified = match call_expr.expr()? {
        ast::Expr::NameRef(func_ref) => return Some((None, Name::from_node(&func_ref))),
        ast::Expr::FieldExpr(qualified) => qualified,
        _ => return None,
    };

    let func_name = Name::from_node(&qualified.field()?);
    match qualified.base()? {
        ast::Expr::NameRef(schema_ref) => {
            let schema = Schema(Name::from_node(&schema_ref));
            Some((Some(schema), func_name))
        }
        _ => None,
    }
}

/// Returns the name held by the last segment of `path`.
///
/// The segment's `name_ref` is preferred and its `name` is the fallback.
/// Returns `None` if the path has no segment or the segment has neither.
pub(crate) fn table_name(path: &ast::Path) -> Option<Name> {
    let segment = path.segment()?;
    segment
        .name_ref()
        .map(|name_ref| Name::from_node(&name_ref))
        .or_else(|| segment.name().map(|name| Name::from_node(&name)))
}

/// Returns the schema named by `path`'s qualifier (the `s` in `s.t`), if any.
///
/// Only a qualifier segment holding a `name_ref` is recognized. Anything else
/// yields `None`.
pub(crate) fn schema_name(path: &ast::Path) -> Option<Schema> {
    let qualifier = path.qualifier()?;
    let schema_ref = qualifier.segment()?.name_ref()?;
    Some(Schema(Name::from_node(&schema_ref)))
}

/// Extracts the optional schema and the table name referenced by a `FROM` item.
///
/// - A direct name reference (`t`) yields `(None, t)`.
/// - A field expression whose base is a `NameRef` (`s.t`) yields `(Some(s), t)`.
/// - Anything else, or missing pieces, yields `None`.
pub(crate) fn schema_and_table_from_from_item(
    from_item: &ast::FromItem,
) -> Option<(Option<Schema>, Name)> {
    if let Some(table_ref) = from_item.name_ref() {
        return Some((None, Name::from_node(&table_ref)));
    }

    let qualified = from_item.field_expr()?;
    let table = Name::from_node(&qualified.field()?);
    match qualified.base()? {
        ast::Expr::NameRef(schema_ref) => {
            let schema = Schema(Name::from_node(&schema_ref));
            Some((Some(schema), table))
        }
        _ => None,
    }
}

/// Derives the `(schema, table)` qualifier from the base of a field expression.
///
/// - If the base is a `NameRef` (`a` in `a.b`), the result is `(None, a)`.
/// - If the base is itself a `FieldExpr` whose own base is a `NameRef`
///   (`a.b` in `a.b.c`, assuming the usual left-nested parse), the result is
///   `(Some(a), b)`.
/// - Any other shape, or missing pieces, yields `None`.
pub(crate) fn schema_and_table_from_field_expr(
    field_expr: &ast::FieldExpr,
) -> Option<(Option<Schema>, Name)> {
    let inner = match field_expr.base()? {
        ast::Expr::NameRef(table_ref) => return Some((None, Name::from_node(&table_ref))),
        ast::Expr::FieldExpr(inner) => inner,
        _ => return None,
    };

    let table_ref = inner.field()?;
    let ast::Expr::NameRef(schema_ref) = inner.base()? else {
        return None;
    };
    let schema = Schema(Name::from_node(&schema_ref));
    Some((Some(schema), Name::from_node(&table_ref)))
}

pub(crate) fn schema_and_type_name(ty: &ast::Type) -> Option<(Option<Schema>, Name)> {
    match ty {
        ast::Type::ArrayType(array_type) => {
            let inner = array_type.ty()?;
            schema_and_type_name(&inner)
        }
        ast::Type::BitType(bit_type) => {
            let name = if bit_type.varying_token().is_some() {
                "varbit"
            } else {
                "bit"
            };
            Some((None, Name::from_string(name)))
        }
        ast::Type::IntervalType(_) => Some((None, Name::from_string("interval"))),
        ast::Type::PathType(path_type) => {
            let path = path_type.path()?;
            schema_and_name_path(&path)
        }
        ast::Type::ExprType(expr_type) => {
            if let ast::Expr::FieldExpr(field_expr) = expr_type.expr()?
                && let Some(field) = field_expr.field()
                && let Some(ast::Expr::NameRef(schema_name_ref)) = field_expr.base()
            {
                let type_name = Name::from_node(&field);
                let schema = Some(Schema(Name::from_node(&schema_name_ref)));
                Some((schema, type_name))
            } else {
                None
            }
        }
        ast::Type::CharType(char_type) => {
            let name = if char_type.varchar_token().is_some() || char_type.varying_token().is_some()
            {
                "varchar"
            } else {
                "bpchar"
            };
            Some((None, Name::from_string(name)))
        }
        ast::Type::DoubleType(_) => Some((None, Name::from_string("float8"))),
        ast::Type::TimeType(time_type) => {
            let mut name = if time_type.timestamp_token().is_some() {
                "timestamp".to_string()
            } else {
                "time".to_string()
            };
            if let Some(ast::Timezone::WithTimezone(_)) = time_type.timezone() {
                name.push_str("tz");
            }
            Some((None, Name::from_string(name)))
        }
        ast::Type::PercentType(_) => None,
    }
}

#[cfg(test)]
mod test {
    use super::*;

    /// Unquoted identifiers that differ only in case normalize to the same name.
    #[test]
    fn name_case_insensitive_compare() {
        assert_eq!(Name::from_string("foo"), Name::from_string("FOO"));
    }

    /// A quoted identifier compares equal to the unquoted spelling of the same text.
    #[test]
    fn name_quote_comparing() {
        assert_eq!(Name::from_string(r#""foo""#), Name::from_string("foo"));
    }

    /// `Schema::new` applies the same normalization as `Name::from_string`.
    #[test]
    fn schema_new_normalizes_like_name() {
        assert_eq!(Schema::new("PUBLIC"), Schema(Name::from_string("public")));
        assert_eq!(Schema::new(r#""public""#), Schema(Name::from_string("public")));
    }

    /// `Display` writes back the normalized text for both wrappers.
    #[test]
    fn display_round_trips_normalized_text() {
        assert_eq!(Name::from_string("foo").to_string(), "foo");
        assert_eq!(Schema::new("public").to_string(), "public");
    }
}