1use std::collections::{BTreeMap, BTreeSet};
2
3use sqlx::Row;
4use sqlx::{Postgres, Transaction};
5
6use crate::custom_schema::{
7 apply_additive_schema, diff_custom_schema, parse_ref_table, validate_custom_schema,
8};
9use crate::defaults::SITE_SCHEMA_NAME;
10use crate::error::{CedrosDataError, Result};
11use crate::models::{CustomSchemaApplyReport, CustomSchemaDefinition, RegisterCustomSchemaRequest};
12
13use super::CedrosData;
14
/// Outcome of comparing an incoming custom schema against the stored one.
///
/// Each variant carries the schema version that should be reported to the caller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SchemaApplyDecision {
    /// The diff is empty; `version` is the unchanged current version.
    Noop { version: i32 },
    /// The diff contains breaking changes; `version` is the version the change would have
    /// received. Nothing is applied.
    Breaking { version: i32 },
    /// The diff is purely additive and will be applied as `version`.
    Apply { version: i32 },
}
21
22impl CedrosData {
23 pub async fn register_custom_schema(
24 &self,
25 request: RegisterCustomSchemaRequest,
26 ) -> Result<CustomSchemaApplyReport> {
27 self.ensure_site_exists().await?;
28 let mut tx = self.pool().begin().await?;
29 let report = self
30 .register_custom_schema_in_transaction(&mut tx, &request.definition)
31 .await?;
32 if report.applied {
33 tx.commit().await?;
34 }
35 Ok(report)
36 }
37
38 pub(super) async fn register_custom_schema_in_transaction(
39 &self,
40 tx: &mut Transaction<'_, Postgres>,
41 definition: &CustomSchemaDefinition,
42 ) -> Result<CustomSchemaApplyReport> {
43 validate_custom_schema(definition)?;
44
45 let (current_version, current_definition) = load_current_state(tx).await?;
46 validate_foreign_key_targets(
47 tx,
48 SITE_SCHEMA_NAME,
49 current_definition.as_ref(),
50 definition,
51 )
52 .await?;
53 let diff = diff_custom_schema(current_definition.as_ref(), definition);
54 let next_version = match schema_apply_decision(
55 current_version,
56 !diff.additive_changes.is_empty(),
57 !diff.breaking_changes.is_empty(),
58 ) {
59 SchemaApplyDecision::Noop { version } => {
60 return Ok(CustomSchemaApplyReport {
61 applied: true,
62 version,
63 additive_changes: Vec::new(),
64 breaking_changes: Vec::new(),
65 generated_sql: Vec::new(),
66 });
67 }
68 SchemaApplyDecision::Breaking { version } => {
69 return Ok(CustomSchemaApplyReport {
70 applied: false,
71 version,
72 additive_changes: diff.additive_changes,
73 breaking_changes: diff.breaking_changes,
74 generated_sql: Vec::new(),
75 });
76 }
77 SchemaApplyDecision::Apply { version } => version,
78 };
79
80 let generated_sql = apply_additive_schema(
81 tx,
82 SITE_SCHEMA_NAME,
83 current_definition.as_ref(),
84 definition,
85 )
86 .await?;
87 self.invalidate_typed_table_columns_cache_for_schema(SITE_SCHEMA_NAME)
88 .await;
89 let report = CustomSchemaApplyReport {
90 applied: true,
91 version: next_version,
92 additive_changes: diff.additive_changes,
93 breaking_changes: Vec::new(),
94 generated_sql: generated_sql.clone(),
95 };
96
97 upsert_schema_state(tx, next_version, definition).await?;
98 record_schema_migration(tx, next_version, &report, &generated_sql).await?;
99 Ok(report)
100 }
101}
102
103fn schema_apply_decision(
104 current_version: i32,
105 has_additive_changes: bool,
106 has_breaking_changes: bool,
107) -> SchemaApplyDecision {
108 if !has_additive_changes && !has_breaking_changes {
109 return SchemaApplyDecision::Noop {
110 version: current_version,
111 };
112 }
113
114 let next_version = current_version + 1;
115 if has_breaking_changes {
116 return SchemaApplyDecision::Breaking {
117 version: next_version,
118 };
119 }
120
121 SchemaApplyDecision::Apply {
122 version: next_version,
123 }
124}
125
126async fn load_current_state(
127 tx: &mut Transaction<'_, Postgres>,
128) -> Result<(i32, Option<CustomSchemaDefinition>)> {
129 let row = sqlx::query(
130 "SELECT version, definition
131 FROM cedros_data.custom_schema_state
132 WHERE id = 1",
133 )
134 .fetch_optional(&mut **tx)
135 .await?;
136
137 parse_current_state(row)
138}
139
140fn parse_current_state(
141 row: Option<sqlx::postgres::PgRow>,
142) -> Result<(i32, Option<CustomSchemaDefinition>)> {
143 let Some(row) = row else {
144 return Ok((0, None));
145 };
146
147 let definition_value = row.get::<serde_json::Value, _>("definition");
148 let definition = serde_json::from_value(definition_value)?;
149 Ok((row.get::<i32, _>("version"), Some(definition)))
150}
151
152async fn validate_foreign_key_targets(
153 tx: &mut Transaction<'_, Postgres>,
154 schema_name: &str,
155 current_definition: Option<&CustomSchemaDefinition>,
156 incoming_definition: &CustomSchemaDefinition,
157) -> Result<()> {
158 let incoming_tables = table_columns_by_name(incoming_definition);
159 let current_tables = current_definition
160 .map(table_columns_by_name)
161 .unwrap_or_default();
162 let mut db_tables = BTreeMap::new();
163
164 for table in &incoming_definition.tables {
165 for foreign_key in &table.foreign_keys {
166 let (ref_schema, ref_table) = parse_ref_table(schema_name, &foreign_key.ref_table);
167 let ref_columns = match local_ref_table_columns(
168 schema_name,
169 &ref_schema,
170 &ref_table,
171 &incoming_tables,
172 ¤t_tables,
173 ) {
174 Some(columns) => columns.clone(),
175 None => load_table_columns(tx, &ref_schema, &ref_table, &mut db_tables).await?,
176 };
177
178 if ref_columns.is_empty() {
179 return Err(CedrosDataError::InvalidRequest(format!(
180 "foreign key references unknown table {}.{}",
181 ref_schema, ref_table
182 )));
183 }
184
185 for ref_column in &foreign_key.ref_columns {
186 if ref_columns.contains(ref_column) {
187 continue;
188 }
189 return Err(CedrosDataError::InvalidRequest(format!(
190 "foreign key references unknown column {}.{}",
191 ref_table, ref_column
192 )));
193 }
194 }
195 }
196
197 Ok(())
198}
199
200fn table_columns_by_name(
201 definition: &CustomSchemaDefinition,
202) -> BTreeMap<String, BTreeSet<String>> {
203 definition
204 .tables
205 .iter()
206 .map(|table| {
207 (
208 table.name.clone(),
209 table
210 .columns
211 .iter()
212 .map(|column| column.name.clone())
213 .collect::<BTreeSet<String>>(),
214 )
215 })
216 .collect()
217}
218
/// Resolves the columns of `ref_table` from the schema definitions, without touching the
/// database.
///
/// Only tables in `schema_name` are resolved locally. Anything in another schema yields
/// `None`. The incoming definition is consulted before the current one.
fn local_ref_table_columns<'a>(
    schema_name: &str,
    ref_schema: &str,
    ref_table: &str,
    incoming_tables: &'a BTreeMap<String, BTreeSet<String>>,
    current_tables: &'a BTreeMap<String, BTreeSet<String>>,
) -> Option<&'a BTreeSet<String>> {
    let same_schema = ref_schema == schema_name;
    same_schema
        .then(|| {
            incoming_tables
                .get(ref_table)
                .or_else(|| current_tables.get(ref_table))
        })
        .flatten()
}
234
235async fn load_table_columns(
236 tx: &mut Transaction<'_, Postgres>,
237 table_schema: &str,
238 table_name: &str,
239 cache: &mut BTreeMap<(String, String), BTreeSet<String>>,
240) -> Result<BTreeSet<String>> {
241 let cache_key = (table_schema.to_string(), table_name.to_string());
242 if let Some(columns) = cache.get(&cache_key) {
243 return Ok(columns.clone());
244 }
245
246 let rows = sqlx::query(
247 "SELECT column_name
248 FROM information_schema.columns
249 WHERE table_schema = $1 AND table_name = $2",
250 )
251 .bind(table_schema)
252 .bind(table_name)
253 .fetch_all(&mut **tx)
254 .await?;
255
256 let columns = rows
257 .into_iter()
258 .map(|row| row.get::<String, _>("column_name"))
259 .collect::<BTreeSet<String>>();
260 cache.insert(cache_key, columns.clone());
261 Ok(columns)
262}
263
264async fn upsert_schema_state(
265 tx: &mut Transaction<'_, Postgres>,
266 version: i32,
267 definition: &CustomSchemaDefinition,
268) -> Result<()> {
269 sqlx::query(
270 "INSERT INTO cedros_data.custom_schema_state (id, version, definition, updated_at)
271 VALUES (1, $1, $2, NOW())
272 ON CONFLICT (id)
273 DO UPDATE SET
274 version = EXCLUDED.version,
275 definition = EXCLUDED.definition,
276 updated_at = NOW()",
277 )
278 .bind(version)
279 .bind(serde_json::to_value(definition)?)
280 .execute(&mut **tx)
281 .await?;
282 Ok(())
283}
284
285async fn record_schema_migration(
286 tx: &mut Transaction<'_, Postgres>,
287 version: i32,
288 report: &CustomSchemaApplyReport,
289 generated_sql: &[String],
290) -> Result<()> {
291 sqlx::query(
292 "INSERT INTO cedros_data.custom_schema_migrations (version, report, generated_sql)
293 VALUES ($1, $2, $3)
294 ON CONFLICT (version)
295 DO NOTHING",
296 )
297 .bind(version)
298 .bind(serde_json::to_value(report)?)
299 .bind(serde_json::to_value(generated_sql)?)
300 .execute(&mut **tx)
301 .await?;
302 Ok(())
303}
304
#[cfg(test)]
mod tests {
    use super::{schema_apply_decision, SchemaApplyDecision};

    #[test]
    fn schema_apply_decision_keeps_version_for_noop_diff() {
        assert_eq!(
            schema_apply_decision(3, false, false),
            SchemaApplyDecision::Noop { version: 3 }
        );
    }

    #[test]
    fn schema_apply_decision_advances_version_for_breaking_diff() {
        assert_eq!(
            schema_apply_decision(3, false, true),
            SchemaApplyDecision::Breaking { version: 4 }
        );
    }

    #[test]
    fn schema_apply_decision_advances_version_for_additive_diff() {
        assert_eq!(
            schema_apply_decision(3, true, false),
            SchemaApplyDecision::Apply { version: 4 }
        );
    }

    /// A diff with both kinds of change must be rejected, never partially applied.
    #[test]
    fn schema_apply_decision_prefers_breaking_for_mixed_diff() {
        assert_eq!(
            schema_apply_decision(3, true, true),
            SchemaApplyDecision::Breaking { version: 4 }
        );
    }

    /// With no stored state the current version is 0, so the first schema becomes version 1.
    #[test]
    fn schema_apply_decision_starts_at_version_one_without_prior_state() {
        assert_eq!(
            schema_apply_decision(0, true, false),
            SchemaApplyDecision::Apply { version: 1 }
        );
    }
}