use std::future::Future;
use std::pin::Pin;
use async_trait::async_trait;
use schema_core::common::{ColumnName, IndexName};
use schema_core::{
AggregateOp, Column, DatabaseSchema, Field, FieldSource, FlussoType, Geo, Relation, TableName,
};
use crate::{Result, SourceSpec};
/// Metadata about one database column, as reported by [`Catalog::column`].
#[derive(Debug, Clone)]
pub struct ColumnInfo {
/// Database type of the column. The validators in this file hand it to
/// `accepts_pg` to decide whether a declared field type is compatible.
pub sql_type: String,
/// Whether the database column allows null.
pub nullable: bool,
}
/// Source of column metadata consulted while validating index definitions.
///
/// Implementors must be `Send + Sync`; the validators use it as `&dyn Catalog`.
#[async_trait]
pub trait Catalog: Send + Sync {
/// Looks up `column` of `table` within `schema` and returns its [`ColumnInfo`].
///
/// The validators in this file propagate any error returned here unchanged.
/// NOTE(review): presumably a missing table or column is reported as `Err` —
/// confirm against the implementors.
async fn column(
&self,
schema: &DatabaseSchema,
table: &TableName,
column: &ColumnName,
) -> Result<ColumnInfo>;
}
/// How serious a [`Diagnostic`] is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Severity {
/// Emitted in this file when a declared type does not accept the column's
/// database type (plain columns, aggregates, and geo coordinates).
Error,
/// Emitted in this file when a field is declared non-null but the database
/// column allows null.
Warning,
}
/// A single finding produced while validating an index definition.
#[derive(Debug, Clone)]
pub struct Diagnostic {
/// Index whose definition triggered the finding.
pub index: IndexName,
/// Field within that index the finding is attached to.
pub field: FieldName,
/// Whether the finding is an error or a warning.
pub severity: Severity,
/// Human-readable description of the problem.
pub message: String,
}
/// Local shorthand for `schema_core::common::FieldName`, used by [`Diagnostic`]
/// and the validator signatures in this file.
type FieldName = schema_core::common::FieldName;
/// Validates every index declared in `spec` against the live `catalog`.
///
/// Each index's field tree is walked in turn, and all findings are gathered
/// into a single list in the order they were discovered.
///
/// # Errors
///
/// Stops at, and returns, the first error produced by a catalog lookup.
pub async fn validate_indexes(spec: &SourceSpec, catalog: &dyn Catalog) -> Result<Vec<Diagnostic>> {
    let mut findings: Vec<Diagnostic> = Vec::new();

    for (index_name, index_schema) in spec.indexes() {
        // The root table's primary key (if any) is forwarded so that key
        // columns are exempt from the nullability warning.
        let root_key = index_schema.primary_key.as_ref();
        validate_fields(
            index_name,
            &index_schema.db_schema,
            &index_schema.table,
            &index_schema.fields,
            root_key,
            catalog,
            &mut findings,
        )
        .await?;
    }

    Ok(findings)
}
/// Validates each entry of `fields` in order, appending findings to `out`.
///
/// The future is returned boxed rather than via `async fn` because
/// [`validate_field`] calls back into this function for groups and joins;
/// boxing gives the recursive future a finite size.
///
/// # Errors
///
/// Resolves to the first error returned by [`validate_field`]; later fields
/// are not visited.
fn validate_fields<'a>(
    index: &'a IndexName,
    db_schema: &'a DatabaseSchema,
    table: &'a TableName,
    fields: &'a [Field],
    primary_key: Option<&'a ColumnName>,
    catalog: &'a dyn Catalog,
    out: &'a mut Vec<Diagnostic>,
) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
    Box::pin(async move {
        for entry in fields.iter() {
            validate_field(index, db_schema, table, entry, primary_key, catalog, out).await?;
        }
        Ok(())
    })
}
/// Validates a single field by dispatching on where its value comes from.
///
/// * constants need no catalog check;
/// * plain columns and geo points are checked against `table`;
/// * groups recurse into their children on the same `table`;
/// * joins recurse into their own fields, using the join's table and primary key;
/// * aggregates check the operand column of `Sum`/`Min`/`Max` on the aggregate's
///   table, and only when a value type was declared (`Count` and `Avg` are not
///   type-checked here).
///
/// # Errors
///
/// Propagates the error of whichever validator the field is delegated to.
async fn validate_field(
    index: &IndexName,
    db_schema: &DatabaseSchema,
    table: &TableName,
    field: &Field,
    primary_key: Option<&ColumnName>,
    catalog: &dyn Catalog,
    out: &mut Vec<Diagnostic>,
) -> Result<()> {
    let field_name = &field.field;

    match &field.source {
        FieldSource::Constant(_) => Ok(()),
        FieldSource::Column(column) => {
            validate_column(
                index,
                db_schema,
                table,
                field_name,
                column,
                primary_key,
                catalog,
                out,
            )
            .await
        }
        FieldSource::Geo(geo) => {
            validate_geo(index, db_schema, table, field_name, geo, catalog, out).await
        }
        FieldSource::Group(children) => {
            validate_fields(index, db_schema, table, children, primary_key, catalog, out).await
        }
        FieldSource::Relation(Relation::Join(join)) => {
            // Joined fields live on the joined table, which has its own key.
            let joined_key = Some(&join.primary_key);
            validate_fields(
                index,
                db_schema,
                &join.table,
                &join.fields,
                joined_key,
                catalog,
                out,
            )
            .await
        }
        FieldSource::Relation(Relation::Aggregate(aggregate)) => {
            // Only operators that carry a checked operand column yield `Some`.
            let operand = match &aggregate.op {
                AggregateOp::Count | AggregateOp::Avg(_) => None,
                AggregateOp::Sum(c) | AggregateOp::Min(c) | AggregateOp::Max(c) => Some(c),
            };
            match (operand, &aggregate.value_type) {
                (Some(column), Some(value_type)) => {
                    check_type(
                        index,
                        db_schema,
                        &aggregate.table,
                        field_name,
                        column,
                        value_type,
                        catalog,
                        out,
                    )
                    .await
                }
                _ => Ok(()),
            }
        }
    }
}
/// Checks that both coordinate columns of a geo field are numeric.
///
/// The latitude column is looked up first, then the longitude column. A column
/// whose database type is accepted by none of the numeric types yields an
/// error diagnostic attached to `field`.
///
/// # Errors
///
/// Returns the first error produced by a catalog lookup.
async fn validate_geo(
    index: &IndexName,
    db_schema: &DatabaseSchema,
    table: &TableName,
    field: &FieldName,
    geo: &Geo,
    catalog: &dyn Catalog,
    out: &mut Vec<Diagnostic>,
) -> Result<()> {
    /// Declared types treated as numeric for coordinate columns.
    const NUMERIC: &[FlussoType] = &[
        FlussoType::Double,
        FlussoType::Float,
        FlussoType::Decimal,
        FlussoType::Integer,
        FlussoType::Long,
        FlussoType::Short,
    ];

    for column in [&geo.lat, &geo.lon] {
        let info = catalog.column(db_schema, table, column).await?;

        let is_numeric = NUMERIC
            .iter()
            .any(|candidate| candidate.accepts_pg(&info.sql_type));
        if is_numeric {
            continue;
        }

        let message = format!(
            "geo_point coordinate column `{column}` must be numeric, found `{}`",
            info.sql_type
        );
        out.push(Diagnostic {
            index: index.clone(),
            field: field.clone(),
            severity: Severity::Error,
            message,
        });
    }

    Ok(())
}
/// Validates a plain column field against the catalog.
///
/// Two independent checks are made, in this order:
///
/// 1. the declared type must accept the column's database type (error);
/// 2. a field declared non-null must not map to a nullable database column
///    (warning), unless the column is the primary key or has a default.
///
/// # Errors
///
/// Returns the error of the catalog lookup, if it fails.
// The validation context is threaded through as separate references, which
// pushes the argument count over clippy's limit.
#[allow(clippy::too_many_arguments)]
async fn validate_column(
    index: &IndexName,
    db_schema: &DatabaseSchema,
    table: &TableName,
    field: &FieldName,
    column: &Column,
    primary_key: Option<&ColumnName>,
    catalog: &dyn Catalog,
    out: &mut Vec<Diagnostic>,
) -> Result<()> {
    let info = catalog.column(db_schema, table, &column.column).await?;

    let type_accepted = column.ty.accepts_pg(&info.sql_type);
    if !type_accepted {
        let message = format!(
            "declared type does not accept the column's database type `{}`",
            info.sql_type
        );
        out.push(Diagnostic {
            index: index.clone(),
            field: field.clone(),
            severity: Severity::Error,
            message,
        });
    }

    // Primary keys and defaulted columns are exempt from the null warning.
    let is_primary_key = primary_key == Some(&column.column);
    let has_default = column.default.is_some();
    let exempt = is_primary_key || has_default;

    let declared_required = !column.nullable;
    if declared_required && info.nullable && !exempt {
        let message =
            String::from("declared non-null (`required`) but the database column allows null");
        out.push(Diagnostic {
            index: index.clone(),
            field: field.clone(),
            severity: Severity::Warning,
            message,
        });
    }

    Ok(())
}
/// Checks that the `declared` aggregate type accepts the database type of
/// `column` on `table`, recording an error diagnostic for `field` if not.
///
/// # Errors
///
/// Returns the error of the catalog lookup, if it fails.
// The validation context is threaded through as separate references, which
// pushes the argument count over clippy's limit.
#[allow(clippy::too_many_arguments)]
async fn check_type(
    index: &IndexName,
    db_schema: &DatabaseSchema,
    table: &TableName,
    field: &FieldName,
    column: &ColumnName,
    declared: &FlussoType,
    catalog: &dyn Catalog,
    out: &mut Vec<Diagnostic>,
) -> Result<()> {
    let info = catalog.column(db_schema, table, column).await?;

    // Compatible types need no diagnostic.
    if declared.accepts_pg(&info.sql_type) {
        return Ok(());
    }

    let message = format!(
        "declared aggregate type does not accept the column's database type `{}`",
        info.sql_type
    );
    out.push(Diagnostic {
        index: index.clone(),
        field: field.clone(),
        severity: Severity::Error,
        message,
    });
    Ok(())
}