1use std::future::Future;
17use std::pin::Pin;
18
19use async_trait::async_trait;
20use schema_core::common::{ColumnName, IndexName};
21use schema_core::{
22 AggregateOp, Column, DatabaseSchema, Field, FieldSource, FlussoType, Geo, Relation, TableName,
23};
24
25use crate::{Result, SourceSpec};
26
/// Metadata about a single database column, as returned by [`Catalog::column`].
#[derive(Debug, Clone)]
pub struct ColumnInfo {
    /// The column's SQL type as text; validation passes it to `FlussoType::accepts_pg`.
    pub sql_type: String,
    /// Whether the database column allows null.
    pub nullable: bool,
}
35
/// Source of database column metadata consulted while validating index definitions.
///
/// Implementations must be `Send + Sync`: validation holds a `&dyn Catalog` inside a future that
/// is required to be `Send`.
#[async_trait]
pub trait Catalog: Send + Sync {
    /// Looks up `column` of `table` within `schema` and returns its [`ColumnInfo`].
    ///
    /// # Errors
    ///
    /// NOTE(review): the failure conditions (presumably an unknown table/column or a failed
    /// database query) are defined by the implementation — confirm against the implementors.
    async fn column(
        &self,
        schema: &DatabaseSchema,
        table: &TableName,
        column: &ColumnName,
    ) -> Result<ColumnInfo>;
}
50
/// How serious a [`Diagnostic`] is.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Severity {
    /// Reported in this file when a declared type does not accept the column's database type.
    Error,
    /// Reported in this file when a field is declared non-null but the database column allows
    /// null.
    Warning,
}
61
/// A single finding produced while validating an index definition against the database.
#[derive(Debug, Clone)]
pub struct Diagnostic {
    /// Name of the index whose definition triggered the finding.
    pub index: IndexName,
    /// Name of the field the finding is about; for fields nested in a group or join this is the
    /// nested field's own name.
    pub field: FieldName,
    /// Whether the finding is an error or a warning.
    pub severity: Severity,
    /// Human-readable description of the problem.
    pub message: String,
}
70
/// Local shorthand for `schema_core::common::FieldName`, the field identifier used in diagnostics.
type FieldName = schema_core::common::FieldName;
72
73pub async fn validate_indexes(spec: &SourceSpec, catalog: &dyn Catalog) -> Result<Vec<Diagnostic>> {
77 let mut diagnostics = Vec::new();
78 for (name, schema) in spec.indexes() {
79 validate_fields(
80 name,
81 &schema.db_schema,
82 &schema.table,
83 &schema.fields,
84 schema.primary_key.as_ref(),
85 catalog,
86 &mut diagnostics,
87 )
88 .await?;
89 }
90 Ok(diagnostics)
91}
92
93fn validate_fields<'a>(
96 index: &'a IndexName,
97 db_schema: &'a DatabaseSchema,
98 table: &'a TableName,
99 fields: &'a [Field],
100 primary_key: Option<&'a ColumnName>,
101 catalog: &'a dyn Catalog,
102 out: &'a mut Vec<Diagnostic>,
103) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
104 Box::pin(async move {
105 for field in fields {
106 validate_field(index, db_schema, table, field, primary_key, catalog, out).await?;
107 }
108 Ok(())
109 })
110}
111
112async fn validate_field(
113 index: &IndexName,
114 db_schema: &DatabaseSchema,
115 table: &TableName,
116 field: &Field,
117 primary_key: Option<&ColumnName>,
118 catalog: &dyn Catalog,
119 out: &mut Vec<Diagnostic>,
120) -> Result<()> {
121 match &field.source {
122 FieldSource::Column(column) => {
123 validate_column(
124 index,
125 db_schema,
126 table,
127 &field.field,
128 column,
129 primary_key,
130 catalog,
131 out,
132 )
133 .await?;
134 }
135 FieldSource::Relation(Relation::Aggregate(aggregate)) => {
136 let column = match &aggregate.op {
137 AggregateOp::Sum(c) | AggregateOp::Min(c) | AggregateOp::Max(c) => Some(c),
138 AggregateOp::Count | AggregateOp::Avg(_) => None,
139 };
140 if let (Some(column), Some(value_type)) = (column, &aggregate.value_type) {
141 check_type(
142 index,
143 db_schema,
144 &aggregate.table,
145 &field.field,
146 column,
147 value_type,
148 catalog,
149 out,
150 )
151 .await?;
152 }
153 }
154 FieldSource::Group(fields) => {
155 validate_fields(index, db_schema, table, fields, primary_key, catalog, out).await?;
156 }
157 FieldSource::Relation(Relation::Join(join)) => {
158 validate_fields(
159 index,
160 db_schema,
161 &join.table,
162 &join.fields,
163 Some(&join.primary_key),
164 catalog,
165 out,
166 )
167 .await?;
168 }
169 FieldSource::Geo(geo) => {
170 validate_geo(index, db_schema, table, &field.field, geo, catalog, out).await?;
171 }
172 FieldSource::Constant(_) => {}
173 }
174 Ok(())
175}
176
177async fn validate_geo(
179 index: &IndexName,
180 db_schema: &DatabaseSchema,
181 table: &TableName,
182 field: &FieldName,
183 geo: &Geo,
184 catalog: &dyn Catalog,
185 out: &mut Vec<Diagnostic>,
186) -> Result<()> {
187 const NUMERIC: &[FlussoType] = &[
188 FlussoType::Double,
189 FlussoType::Float,
190 FlussoType::Decimal,
191 FlussoType::Integer,
192 FlussoType::Long,
193 FlussoType::Short,
194 ];
195 for column in [&geo.lat, &geo.lon] {
196 let info = catalog.column(db_schema, table, column).await?;
197 if !NUMERIC.iter().any(|ty| ty.accepts_pg(&info.sql_type)) {
198 out.push(Diagnostic {
199 index: index.clone(),
200 field: field.clone(),
201 severity: Severity::Error,
202 message: format!(
203 "geo_point coordinate column `{column}` must be numeric, found `{}`",
204 info.sql_type
205 ),
206 });
207 }
208 }
209 Ok(())
210}
211
212#[allow(clippy::too_many_arguments)]
213async fn validate_column(
214 index: &IndexName,
215 db_schema: &DatabaseSchema,
216 table: &TableName,
217 field: &FieldName,
218 column: &Column,
219 primary_key: Option<&ColumnName>,
220 catalog: &dyn Catalog,
221 out: &mut Vec<Diagnostic>,
222) -> Result<()> {
223 let info = catalog.column(db_schema, table, &column.column).await?;
224
225 if !column.ty.accepts_pg(&info.sql_type) {
226 out.push(Diagnostic {
227 index: index.clone(),
228 field: field.clone(),
229 severity: Severity::Error,
230 message: format!(
231 "declared type does not accept the column's database type `{}`",
232 info.sql_type
233 ),
234 });
235 }
236
237 let forced_non_null = primary_key == Some(&column.column) || column.default.is_some();
240 if !column.nullable && info.nullable && !forced_non_null {
241 out.push(Diagnostic {
242 index: index.clone(),
243 field: field.clone(),
244 severity: Severity::Warning,
245 message: "declared non-null (`required`) but the database column allows null"
246 .to_owned(),
247 });
248 }
249
250 Ok(())
251}
252
253#[allow(clippy::too_many_arguments)]
254async fn check_type(
255 index: &IndexName,
256 db_schema: &DatabaseSchema,
257 table: &TableName,
258 field: &FieldName,
259 column: &ColumnName,
260 declared: &FlussoType,
261 catalog: &dyn Catalog,
262 out: &mut Vec<Diagnostic>,
263) -> Result<()> {
264 let info = catalog.column(db_schema, table, column).await?;
265 if !declared.accepts_pg(&info.sql_type) {
266 out.push(Diagnostic {
267 index: index.clone(),
268 field: field.clone(),
269 severity: Severity::Error,
270 message: format!(
271 "declared aggregate type does not accept the column's database type `{}`",
272 info.sql_type
273 ),
274 });
275 }
276 Ok(())
277}