use std::rc::Rc;
use heck::ToUpperCamelCase;
use indexmap::{map::Entry, IndexMap};
use postgres::Client;
use postgres_types::{Kind, Type};
use crate::{
parser::{Module, NullableIdent, Query, Span, TypeAnnotation},
read_queries::ModuleInfo,
type_registrar::CornucopiaType,
type_registrar::TypeRegistrar,
utils::escape_keyword,
validation,
};
use self::error::Error;
/// A query after it has been prepared against the database.
#[derive(Debug, Clone)]
pub(crate) struct PreparedQuery {
    /// The query's name (the value of its name span).
    pub(crate) name: String,
    /// Location of this query's parameter struct in `PreparedModule::params`.
    ///
    /// The tuple holds two things:
    /// - the item's index in that map;
    /// - for each field of the stored item, its position in this query's own parameter list.
    ///
    /// `None` when the query has no parameters.
    pub(crate) param: Option<(usize, Vec<usize>)>,
    /// Same layout as `param`, but indexing `PreparedModule::rows`.
    /// `None` when the statement returns no columns.
    pub(crate) row: Option<(usize, Vec<usize>)>,
    /// The SQL text that was sent to Postgres for preparation.
    pub(crate) sql: String,
}
/// A single field (parameter, column or composite member) of a generated struct.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreparedField {
    /// Field name, already passed through `escape_keyword`.
    pub(crate) name: String,
    /// The registered type of this field.
    pub(crate) ty: Rc<CornucopiaType>,
    /// Set from the field's nullability annotation; `false` when there is none.
    pub(crate) is_nullable: bool,
    /// Set from the annotation's `inner_nullable` flag; `false` when there is none.
    pub(crate) is_inner_nullable: bool,
}
impl PreparedField {
    /// Builds a field whose name is escaped with `escape_keyword`.
    ///
    /// Both nullability flags are copied from `nullity` when an annotation is supplied,
    /// and default to `false` otherwise.
    pub(crate) fn new(
        name: String,
        ty: Rc<CornucopiaType>,
        nullity: Option<&NullableIdent>,
    ) -> Self {
        let (is_nullable, is_inner_nullable) = match nullity {
            Some(ident) => (ident.nullable, ident.inner_nullable),
            None => (false, false),
        };
        Self {
            name: escape_keyword(name),
            ty,
            is_nullable,
            is_inner_nullable,
        }
    }
}
impl PreparedField {
    /// Returns `own_struct()` with every `<`, `>` and `_` character removed, converted
    /// to UpperCamelCase.
    pub fn unwrapped_name(&self) -> String {
        let stripped: String = self
            .own_struct()
            .chars()
            .filter(|ch| !matches!(ch, '<' | '>' | '_'))
            .collect();
        stripped.to_upper_camel_case()
    }
}
/// A parameter or row struct shared by one or more queries of a module.
#[derive(Debug, Clone)]
pub(crate) struct PreparedItem {
    /// The struct's name, with its source span.
    pub(crate) name: Span<String>,
    /// The struct's fields, in the order they were first registered.
    pub(crate) fields: Vec<PreparedField>,
    /// `true` when every field's type reports `is_copy()`.
    pub(crate) is_copy: bool,
    /// `true` unless the item is implicit and has at most one field.
    pub(crate) is_named: bool,
    /// `true` when any field's type reports `is_ref()`.
    pub(crate) is_ref: bool,
}
impl PreparedItem {
    /// Creates an item and derives its `is_copy`, `is_ref` and `is_named` flags from
    /// `fields` and `is_implicit`.
    pub fn new(name: Span<String>, fields: Vec<PreparedField>, is_implicit: bool) -> Self {
        let is_copy = fields.iter().all(|field| field.ty.is_copy());
        let is_ref = fields.iter().any(|field| field.ty.is_ref());
        // A lone implicit field is used directly instead of being wrapped in a named struct.
        let is_named = fields.len() > 1 || !is_implicit;
        Self {
            name,
            fields,
            is_copy,
            is_named,
            is_ref,
        }
    }
}
/// A custom Postgres type (enum or composite) ready for code generation.
#[derive(PartialEq, Eq, Debug, Clone)]
pub(crate) struct PreparedType {
    /// The type's name as keyed in the type registrar.
    pub(crate) name: String,
    /// The Rust struct name recorded for this type by the registrar.
    pub(crate) struct_name: String,
    /// The enum variants or composite fields of the type.
    pub(crate) content: PreparedContent,
    /// Copied from the registrar's `is_copy` flag for this type.
    pub(crate) is_copy: bool,
    /// Copied from the registrar's `is_params` flag for this type.
    pub(crate) is_params: bool,
}
/// The body of a prepared custom type.
#[derive(PartialEq, Eq, Debug, Clone)]
pub(crate) enum PreparedContent {
    /// Postgres enum variants, each passed through `escape_keyword`.
    Enum(Vec<String>),
    /// Postgres composite members, in declaration order.
    Composite(Vec<PreparedField>),
}
/// All prepared queries of one queries file, plus the param/row structs they use.
#[derive(Debug, Clone)]
pub(crate) struct PreparedModule {
    /// Information about the source module (used for diagnostics).
    pub(crate) info: ModuleInfo,
    /// Prepared queries keyed by query name.
    pub(crate) queries: IndexMap<Span<String>, PreparedQuery>,
    /// Parameter structs keyed by name; queries refer to them by index.
    pub(crate) params: IndexMap<Span<String>, PreparedItem>,
    /// Row structs keyed by name; queries refer to them by index.
    pub(crate) rows: IndexMap<Span<String>, PreparedItem>,
}
/// The result of preparing every module of a project.
#[derive(Debug, Clone)]
pub(crate) struct Preparation {
    /// One prepared module per input module, in input order.
    pub(crate) modules: Vec<PreparedModule>,
    /// Prepared custom types grouped by schema name (domain types are not included).
    pub(crate) types: IndexMap<String, Vec<PreparedType>>,
}
impl PreparedModule {
    /// Registers a param/row struct named `name` in `map`, reusing an existing entry
    /// with the same name.
    ///
    /// Returns the entry's index in `map` together with, for each field of the stored
    /// item, that field's position in `fields`. For an unnamed item (a single implicit
    /// field) the mapping is always `[0]`.
    ///
    /// # Errors
    /// Propagates the error from `validation::named_struct_field` when a named entry
    /// already exists under `name` and `fields` fails that check.
    ///
    /// # Panics
    /// Panics if `fields` is empty.
    fn add(
        info: &ModuleInfo,
        map: &mut IndexMap<Span<String>, PreparedItem>,
        name: Span<String>,
        fields: Vec<PreparedField>,
        is_implicit: bool,
    ) -> Result<(usize, Vec<usize>), Error> {
        assert!(!fields.is_empty());
        match map.entry(name.clone()) {
            Entry::Occupied(o) => {
                let prev = o.get();
                let indexes: Vec<_> = if prev.is_named {
                    validation::named_struct_field(info, &prev.name, &prev.fields, &name, &fields)?;
                    // Map every stored field to where it sits in this query's field list.
                    prev.fields
                        .iter()
                        .map(|f| {
                            fields
                                .iter()
                                .position(|it| it == f)
                                .expect("stored field missing after named_struct_field validation")
                        })
                        .collect()
                } else {
                    vec![0]
                };
                Ok((o.index(), indexes))
            }
            Entry::Vacant(v) => {
                v.insert(PreparedItem::new(name.clone(), fields.clone(), is_implicit));
                // Re-enter through the occupied branch so the index mapping for a freshly
                // inserted item is computed exactly like for a reused one.
                Self::add(info, map, name, fields, is_implicit)
            }
        }
    }

    /// Registers a row struct for a query; see [`PreparedModule::add`] for the return value.
    ///
    /// An implicit row with a single column is keyed by that field's `unwrapped_name()`
    /// (keeping the original span) instead of by `name`.
    fn add_row(
        &mut self,
        name: Span<String>,
        fields: Vec<PreparedField>,
        is_implicit: bool,
    ) -> Result<(usize, Vec<usize>), Error> {
        let name = if fields.len() == 1 && is_implicit {
            name.map(|_| fields[0].unwrapped_name())
        } else {
            name
        };
        Self::add(&self.info, &mut self.rows, name, fields, is_implicit)
    }

    /// Registers a parameter struct for a query; see [`PreparedModule::add`] for the
    /// return value.
    fn add_param(
        &mut self,
        name: Span<String>,
        fields: Vec<PreparedField>,
        is_implicit: bool,
    ) -> Result<(usize, Vec<usize>), Error> {
        Self::add(&self.info, &mut self.params, name, fields, is_implicit)
    }

    /// Records a prepared query under `name` (an existing entry with an equal key is
    /// overwritten by `IndexMap::insert`).
    fn add_query(
        &mut self,
        name: Span<String>,
        param_idx: Option<(usize, Vec<usize>)>,
        row_idx: Option<(usize, Vec<usize>)>,
        sql: String,
    ) {
        self.queries.insert(
            name.clone(),
            PreparedQuery {
                name: name.value,
                row: row_idx,
                sql,
                param: param_idx,
            },
        );
    }
}
pub(crate) fn prepare(client: &mut Client, modules: Vec<Module>) -> Result<Preparation, Error> {
let mut registrar = TypeRegistrar::default();
let mut tmp = Preparation {
modules: Vec::new(),
types: IndexMap::new(),
};
let declared: Vec<_> = modules
.iter()
.flat_map(|it| &it.types)
.map(|ty| (*ty).clone())
.collect();
for module in modules {
tmp.modules
.push(prepare_module(client, module, &mut registrar)?);
}
for ((schema, name), ty) in ®istrar.types {
if let Some(ty) = prepare_type(®istrar, name, ty, &declared) {
match tmp.types.entry(schema.clone()) {
Entry::Occupied(mut entry) => {
entry.get_mut().push(ty);
}
Entry::Vacant(entry) => {
entry.insert(vec![ty]);
}
}
}
}
Ok(tmp)
}
/// Returns `name` with every `:` replaced by `_`, so it can be used as a Rust identifier.
fn normalize_rust_name(name: &str) -> String {
    name.chars()
        .map(|ch| if ch == ':' { '_' } else { ch })
        .collect()
}
/// Converts a registered custom type into a [`PreparedType`].
///
/// Returns `None` for non-custom types and for Postgres domains. Composite members
/// pick up nullability from the matching entry in `types` (matched by Postgres type
/// name), if one was declared.
///
/// # Panics
/// Panics (`unreachable!`) for a custom type whose kind is not enum, domain or composite.
fn prepare_type(
    registrar: &TypeRegistrar,
    name: &str,
    ty: &CornucopiaType,
    types: &[TypeAnnotation],
) -> Option<PreparedType> {
    let (pg_ty, struct_name, is_copy, is_params) = match ty {
        CornucopiaType::Custom {
            pg_ty,
            struct_name,
            is_copy,
            is_params,
            ..
        } => (pg_ty, struct_name, *is_copy, *is_params),
        _ => return None,
    };

    // Nullability annotations the user declared for this type, if any.
    let annotated: &[NullableIdent] = match types
        .iter()
        .find(|annotation| annotation.name.value == pg_ty.name())
    {
        Some(annotation) => annotation.fields.as_slice(),
        None => &[],
    };

    let content = match pg_ty.kind() {
        Kind::Domain(_) => return None,
        Kind::Enum(variants) => {
            PreparedContent::Enum(variants.iter().cloned().map(escape_keyword).collect())
        }
        Kind::Composite(fields) => {
            let mut members = Vec::with_capacity(fields.len());
            for field in fields {
                let nullity = annotated
                    .iter()
                    .find(|ident| ident.name.value == field.name());
                members.push(PreparedField::new(
                    field.name().to_string(),
                    registrar.ref_of(field.type_()),
                    nullity,
                ));
            }
            PreparedContent::Composite(members)
        }
        _ => unreachable!(),
    };

    Some(PreparedType {
        name: name.to_string(),
        struct_name: struct_name.clone(),
        content,
        is_copy,
        is_params,
    })
}
fn prepare_module(
client: &mut Client,
module: Module,
registrar: &mut TypeRegistrar,
) -> Result<PreparedModule, Error> {
validation::validate_module(&module)?;
let mut tmp_prepared_module = PreparedModule {
info: module.info.clone(),
queries: IndexMap::new(),
params: IndexMap::new(),
rows: IndexMap::new(),
};
for query in module.queries {
prepare_query(
client,
&mut tmp_prepared_module,
registrar,
&module.types,
query,
&module.info,
)?;
}
validation::validate_preparation(&tmp_prepared_module)?;
Ok(tmp_prepared_module)
}
/// Prepares one query against the database and records its parameter struct, row
/// struct and SQL in `module`.
///
/// Postgres reports the type of every bind parameter and result column, and each type
/// is registered in `registrar`. Nullability annotations (as resolved by
/// `name_and_fields`, possibly from the module's `types`) are checked against the
/// actual parameter/column names before they are attached to the generated fields.
///
/// # Errors
/// - Returns [`Error::Db`] when Postgres fails to prepare the statement.
/// - Returns a validation or type-registration error when the query's annotations or
///   types are not accepted.
fn prepare_query(
    client: &mut Client,
    module: &mut PreparedModule,
    registrar: &mut TypeRegistrar,
    types: &[TypeAnnotation],
    Query {
        name,
        param,
        bind_params,
        row,
        sql_str,
        sql_span,
    }: Query,
    module_info: &ModuleInfo,
) -> Result<(), Error> {
    let stmt = client
        .prepare(&sql_str)
        .map_err(|e| Error::new_db_err(&e, module_info, &sql_span, &name))?;

    let (nullable_params_fields, params_name) = param.name_and_fields(types, &name, Some("Params"));
    let (nullable_row_fields, row_name) = row.name_and_fields(types, &name, None);

    let params_fields = {
        // Pair each bind parameter written in the query with the type Postgres inferred.
        let params = bind_params
            .iter()
            .zip(stmt.params())
            .map(|(a, b)| (a.clone(), b.clone()))
            .collect::<Vec<(Span<String>, Type)>>();
        validation::param_on_simple_query(&module.info, &name, &sql_span, &param, &params)?;
        for nullable_col in nullable_params_fields {
            validation::nullable_param_name(&module.info, nullable_col, &params)?;
        }
        let mut param_fields = Vec::with_capacity(params.len());
        for (col_name, col_ty) in params {
            let nullity = nullable_params_fields
                .iter()
                .find(|x| x.name.value == col_name.value);
            param_fields.push(PreparedField::new(
                col_name.value.clone(),
                registrar
                    .register(&col_name.value, &col_ty, &name, module_info)?
                    .clone(),
                nullity,
            ));
        }
        param_fields
    };

    let row_fields = {
        let stmt_cols = stmt.columns();
        validation::row_on_execute(&module.info, &name, &sql_span, &row, stmt_cols)?;
        validation::duplicate_sql_col_name(&module.info, &name, stmt_cols)?;
        for nullable_col in nullable_row_fields {
            validation::nullable_column_name(&module.info, nullable_col, stmt_cols)?;
        }
        let mut row_fields = Vec::with_capacity(stmt_cols.len());
        for (col_name, col_ty) in stmt_cols.iter().map(|c| (c.name().to_owned(), c.type_())) {
            // Annotations are matched on the raw SQL column name, before normalization.
            let nullity = nullable_row_fields
                .iter()
                .find(|x| x.name.value == col_name);
            let ty = registrar
                .register(&col_name, col_ty, &name, module_info)?
                .clone();
            row_fields.push(PreparedField::new(
                normalize_rust_name(&col_name),
                ty,
                nullity,
            ));
        }
        row_fields
    };

    let row_idx = if row_fields.is_empty() {
        None
    } else {
        Some(module.add_row(row_name, row_fields, row.is_implicit())?)
    };
    let param_idx = if params_fields.is_empty() {
        None
    } else {
        Some(module.add_param(params_name, params_fields, param.is_implicit())?)
    };
    module.add_query(name, param_idx, row_idx, sql_str);
    Ok(())
}
pub(crate) mod error {
    use miette::{Diagnostic, NamedSource, SourceSpan};
    use thiserror::Error as ThisError;

    use crate::{
        parser::Span, read_queries::ModuleInfo, type_registrar::error::Error as PostgresTypeError,
        utils::db_err, validation::error::Error as ValidationError,
    };

    /// Errors produced while preparing queries.
    #[derive(Debug, ThisError, Diagnostic)]
    pub enum Error {
        /// Postgres failed to prepare a query.
        #[error("Couldn't prepare query: {msg}")]
        Db {
            msg: String,
            #[help]
            help: Option<String>,
            #[source_code]
            src: NamedSource,
            #[label("error occurs near this location")]
            err_span: Option<SourceSpan>,
        },
        /// Error forwarded from the type registrar.
        #[error(transparent)]
        #[diagnostic(transparent)]
        PostgresType(#[from] PostgresTypeError),
        /// Error forwarded from query/module validation.
        #[error(transparent)]
        #[diagnostic(transparent)]
        Validation(#[from] Box<ValidationError>),
    }

    impl Error {
        /// Builds an [`Error::Db`] from a Postgres error raised while preparing a query.
        ///
        /// When `db_err` extracts a position, message and help from `err`, the label
        /// points at that position inside the query's SQL (`query_span`). Otherwise the
        /// message is `err` formatted with `{:#}` and the label points at the query name.
        pub(crate) fn new_db_err(
            err: &postgres::Error,
            module_info: &ModuleInfo,
            query_span: &SourceSpan,
            query_name: &Span<String>,
        ) -> Self {
            match db_err(err) {
                Some((position, msg, help)) => {
                    // `position` is treated as 1-based (hence the `- 1`). Saturate so a
                    // malformed 0 cannot underflow.
                    let offset = query_span.offset() + (position as usize).saturating_sub(1);
                    Self::Db {
                        msg,
                        help,
                        src: module_info.into(),
                        err_span: Some(offset.into()),
                    }
                }
                None => Self::Db {
                    msg: format!("{:#}", err),
                    help: None,
                    src: module_info.into(),
                    err_span: Some(query_name.span),
                },
            }
        }
    }
}