Skip to main content

cedros_data/store/
schema.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use sqlx::Row;
4use sqlx::{Postgres, Transaction};
5
6use crate::custom_schema::{
7    apply_additive_schema, diff_custom_schema, parse_ref_table, validate_custom_schema,
8};
9use crate::defaults::SITE_SCHEMA_NAME;
10use crate::error::{CedrosDataError, Result};
11use crate::models::{CustomSchemaApplyReport, CustomSchemaDefinition, RegisterCustomSchemaRequest};
12
13use super::CedrosData;
14
/// Result of classifying a schema diff, as produced by `schema_apply_decision`.
///
/// Each variant carries the schema version that the resulting report exposes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SchemaApplyDecision {
    /// The diff contains neither additive nor breaking changes.
    Noop { version: i32 },
    /// The diff contains at least one breaking change.
    Breaking { version: i32 },
    /// The diff contains additive changes only.
    Apply { version: i32 },
}
21
22impl CedrosData {
23    pub async fn register_custom_schema(
24        &self,
25        request: RegisterCustomSchemaRequest,
26    ) -> Result<CustomSchemaApplyReport> {
27        self.ensure_site_exists().await?;
28        let mut tx = self.pool().begin().await?;
29        let report = self
30            .register_custom_schema_in_transaction(&mut tx, &request.definition)
31            .await?;
32        if report.applied {
33            tx.commit().await?;
34        }
35        Ok(report)
36    }
37
38    pub(super) async fn register_custom_schema_in_transaction(
39        &self,
40        tx: &mut Transaction<'_, Postgres>,
41        definition: &CustomSchemaDefinition,
42    ) -> Result<CustomSchemaApplyReport> {
43        validate_custom_schema(definition)?;
44
45        let (current_version, current_definition) = load_current_state(tx).await?;
46        validate_foreign_key_targets(
47            tx,
48            SITE_SCHEMA_NAME,
49            current_definition.as_ref(),
50            definition,
51        )
52        .await?;
53        let diff = diff_custom_schema(current_definition.as_ref(), definition);
54        let next_version = match schema_apply_decision(
55            current_version,
56            !diff.additive_changes.is_empty(),
57            !diff.breaking_changes.is_empty(),
58        ) {
59            SchemaApplyDecision::Noop { version } => {
60                return Ok(CustomSchemaApplyReport {
61                    applied: true,
62                    version,
63                    additive_changes: Vec::new(),
64                    breaking_changes: Vec::new(),
65                    generated_sql: Vec::new(),
66                });
67            }
68            SchemaApplyDecision::Breaking { version } => {
69                return Ok(CustomSchemaApplyReport {
70                    applied: false,
71                    version,
72                    additive_changes: diff.additive_changes,
73                    breaking_changes: diff.breaking_changes,
74                    generated_sql: Vec::new(),
75                });
76            }
77            SchemaApplyDecision::Apply { version } => version,
78        };
79
80        let generated_sql = apply_additive_schema(
81            tx,
82            SITE_SCHEMA_NAME,
83            current_definition.as_ref(),
84            definition,
85        )
86        .await?;
87        self.invalidate_typed_table_columns_cache_for_schema(SITE_SCHEMA_NAME)
88            .await;
89        let report = CustomSchemaApplyReport {
90            applied: true,
91            version: next_version,
92            additive_changes: diff.additive_changes,
93            breaking_changes: Vec::new(),
94            generated_sql: generated_sql.clone(),
95        };
96
97        upsert_schema_state(tx, next_version, definition).await?;
98        record_schema_migration(tx, next_version, &report, &generated_sql).await?;
99        Ok(report)
100    }
101}
102
103fn schema_apply_decision(
104    current_version: i32,
105    has_additive_changes: bool,
106    has_breaking_changes: bool,
107) -> SchemaApplyDecision {
108    if !has_additive_changes && !has_breaking_changes {
109        return SchemaApplyDecision::Noop {
110            version: current_version,
111        };
112    }
113
114    let next_version = current_version + 1;
115    if has_breaking_changes {
116        return SchemaApplyDecision::Breaking {
117            version: next_version,
118        };
119    }
120
121    SchemaApplyDecision::Apply {
122        version: next_version,
123    }
124}
125
126async fn load_current_state(
127    tx: &mut Transaction<'_, Postgres>,
128) -> Result<(i32, Option<CustomSchemaDefinition>)> {
129    let row = sqlx::query(
130        "SELECT version, definition
131         FROM custom_schema_state
132         WHERE id = 1",
133    )
134    .fetch_optional(&mut **tx)
135    .await?;
136
137    parse_current_state(row)
138}
139
140fn parse_current_state(
141    row: Option<sqlx::postgres::PgRow>,
142) -> Result<(i32, Option<CustomSchemaDefinition>)> {
143    let Some(row) = row else {
144        return Ok((0, None));
145    };
146
147    let definition_value = row.get::<serde_json::Value, _>("definition");
148    let definition = serde_json::from_value(definition_value)?;
149    Ok((row.get::<i32, _>("version"), Some(definition)))
150}
151
152async fn validate_foreign_key_targets(
153    tx: &mut Transaction<'_, Postgres>,
154    schema_name: &str,
155    current_definition: Option<&CustomSchemaDefinition>,
156    incoming_definition: &CustomSchemaDefinition,
157) -> Result<()> {
158    let incoming_tables = table_columns_by_name(incoming_definition);
159    let current_tables = current_definition
160        .map(table_columns_by_name)
161        .unwrap_or_default();
162    let mut db_tables = BTreeMap::new();
163
164    for table in &incoming_definition.tables {
165        for foreign_key in &table.foreign_keys {
166            let (ref_schema, ref_table) = parse_ref_table(schema_name, &foreign_key.ref_table);
167            let ref_columns = match local_ref_table_columns(
168                schema_name,
169                &ref_schema,
170                &ref_table,
171                &incoming_tables,
172                &current_tables,
173            ) {
174                Some(columns) => columns.clone(),
175                None => load_table_columns(tx, &ref_schema, &ref_table, &mut db_tables).await?,
176            };
177
178            if ref_columns.is_empty() {
179                return Err(CedrosDataError::InvalidRequest(format!(
180                    "foreign key references unknown table {}.{}",
181                    ref_schema, ref_table
182                )));
183            }
184
185            for ref_column in &foreign_key.ref_columns {
186                if ref_columns.contains(ref_column) {
187                    continue;
188                }
189                return Err(CedrosDataError::InvalidRequest(format!(
190                    "foreign key references unknown column {}.{}",
191                    ref_table, ref_column
192                )));
193            }
194        }
195    }
196
197    Ok(())
198}
199
200fn table_columns_by_name(
201    definition: &CustomSchemaDefinition,
202) -> BTreeMap<String, BTreeSet<String>> {
203    definition
204        .tables
205        .iter()
206        .map(|table| {
207            (
208                table.name.clone(),
209                table
210                    .columns
211                    .iter()
212                    .map(|column| column.name.clone())
213                    .collect::<BTreeSet<String>>(),
214            )
215        })
216        .collect()
217}
218
/// Looks up the columns of `ref_table` in the in-memory definitions.
///
/// Returns `None` when `ref_schema` differs from `schema_name`. Otherwise the incoming
/// definition is consulted first and the current definition is used as a fallback.
fn local_ref_table_columns<'a>(
    schema_name: &str,
    ref_schema: &str,
    ref_table: &str,
    incoming_tables: &'a BTreeMap<String, BTreeSet<String>>,
    current_tables: &'a BTreeMap<String, BTreeSet<String>>,
) -> Option<&'a BTreeSet<String>> {
    if ref_schema != schema_name {
        return None;
    }

    match incoming_tables.get(ref_table) {
        found @ Some(_) => found,
        None => current_tables.get(ref_table),
    }
}
234
235async fn load_table_columns(
236    tx: &mut Transaction<'_, Postgres>,
237    table_schema: &str,
238    table_name: &str,
239    cache: &mut BTreeMap<(String, String), BTreeSet<String>>,
240) -> Result<BTreeSet<String>> {
241    let cache_key = (table_schema.to_string(), table_name.to_string());
242    if let Some(columns) = cache.get(&cache_key) {
243        return Ok(columns.clone());
244    }
245
246    let rows = sqlx::query(
247        "SELECT column_name
248         FROM information_schema.columns
249         WHERE table_schema = $1 AND table_name = $2",
250    )
251    .bind(table_schema)
252    .bind(table_name)
253    .fetch_all(&mut **tx)
254    .await?;
255
256    let columns = rows
257        .into_iter()
258        .map(|row| row.get::<String, _>("column_name"))
259        .collect::<BTreeSet<String>>();
260    cache.insert(cache_key, columns.clone());
261    Ok(columns)
262}
263
264async fn upsert_schema_state(
265    tx: &mut Transaction<'_, Postgres>,
266    version: i32,
267    definition: &CustomSchemaDefinition,
268) -> Result<()> {
269    sqlx::query(
270        "INSERT INTO custom_schema_state (id, version, definition, updated_at)
271         VALUES (1, $1, $2, NOW())
272         ON CONFLICT (id)
273         DO UPDATE SET
274             version = EXCLUDED.version,
275             definition = EXCLUDED.definition,
276             updated_at = NOW()",
277    )
278    .bind(version)
279    .bind(serde_json::to_value(definition)?)
280    .execute(&mut **tx)
281    .await?;
282    Ok(())
283}
284
285async fn record_schema_migration(
286    tx: &mut Transaction<'_, Postgres>,
287    version: i32,
288    report: &CustomSchemaApplyReport,
289    generated_sql: &[String],
290) -> Result<()> {
291    sqlx::query(
292        "INSERT INTO custom_schema_migrations (version, report, generated_sql)
293         VALUES ($1, $2, $3)
294         ON CONFLICT (version)
295         DO NOTHING",
296    )
297    .bind(version)
298    .bind(serde_json::to_value(report)?)
299    .bind(serde_json::to_value(generated_sql)?)
300    .execute(&mut **tx)
301    .await?;
302    Ok(())
303}
304
#[cfg(test)]
mod tests {
    use super::{schema_apply_decision, SchemaApplyDecision};

    // An empty diff reports the current version unchanged.
    #[test]
    fn schema_apply_decision_keeps_version_for_noop_diff() {
        let decision = schema_apply_decision(3, false, false);
        assert_eq!(decision, SchemaApplyDecision::Noop { version: 3 });
    }

    // A breaking diff reports the version after the current one.
    #[test]
    fn schema_apply_decision_advances_version_for_breaking_diff() {
        let decision = schema_apply_decision(3, false, true);
        assert_eq!(decision, SchemaApplyDecision::Breaking { version: 4 });
    }

    // An additive-only diff reports the version after the current one.
    #[test]
    fn schema_apply_decision_advances_version_for_additive_diff() {
        let decision = schema_apply_decision(3, true, false);
        assert_eq!(decision, SchemaApplyDecision::Apply { version: 4 });
    }
}