Skip to main content

cedros_data/store/
mod.rs

1#[cfg(feature = "storage")]
2mod assets;
3mod contracts;
4mod defaults;
5mod entries;
6mod imex;
7pub(crate) mod metered_reads;
8mod migrate;
9mod schema;
10mod site;
11
12use std::{collections::BTreeMap, sync::Arc, time::Instant};
13
14use sqlx::{postgres::PgPoolOptions, Executor, PgPool, Postgres, Row, Transaction};
15use tokio::sync::RwLock;
16
17use crate::config::CedrosDataConfig;
18use crate::error::{CedrosDataError, Result};
19use crate::models::{Collection, CollectionMode, ContractSchema, Site};
20
21use self::entries::typed_table_columns_with;
22
23type TypedTableColumns = BTreeMap<String, bool>;
24type TypedTableCacheKey = (String, String);
25type TypedTableColumnsCache = BTreeMap<TypedTableCacheKey, TypedTableColumns>;
26
27#[derive(Clone)]
28pub struct CedrosData {
29    pool: PgPool,
30    typed_table_columns_cache: Arc<RwLock<TypedTableColumnsCache>>,
31    /// Cached (timestamp, value) for freeReadsPerMonth with 5-min TTL.
32    free_reads_cache: Arc<RwLock<Option<(Instant, i64)>>>,
33}
34
35impl CedrosData {
36    pub async fn from_env() -> Result<Self> {
37        let config = CedrosDataConfig::from_env()?;
38        Self::connect(&config.postgres_uri).await
39    }
40
41    pub async fn connect(postgres_uri: &str) -> Result<Self> {
42        let pool = PgPoolOptions::new()
43            .max_connections(10)
44            .connect(postgres_uri)
45            .await?;
46        Ok(Self {
47            pool,
48            typed_table_columns_cache: Arc::new(RwLock::new(BTreeMap::new())),
49            free_reads_cache: Arc::new(RwLock::new(None)),
50        })
51    }
52
53    pub fn pool(&self) -> &PgPool {
54        &self.pool
55    }
56
57    async fn current_site(&self) -> Result<Site> {
58        self.ensure_site_exists().await?;
59        load_current_site_with(&self.pool).await
60    }
61
62    async fn find_collection(&self, collection_name: &str) -> Result<Collection> {
63        find_collection_with(&self.pool, collection_name).await
64    }
65
66    async fn latest_contract(&self, collection_name: &str) -> Result<Option<ContractSchema>> {
67        latest_contract_with(&self.pool, collection_name).await
68    }
69
70    pub(super) async fn typed_table_columns_in_transaction(
71        &self,
72        tx: &mut Transaction<'_, Postgres>,
73        schema_name: &str,
74        table_name: &str,
75    ) -> Result<TypedTableColumns> {
76        let cache_key = typed_table_cache_key(schema_name, table_name);
77        if let Some(columns) = self
78            .typed_table_columns_cache
79            .read()
80            .await
81            .get(&cache_key)
82            .cloned()
83        {
84            return Ok(columns);
85        }
86
87        let columns = typed_table_columns_with(&mut **tx, schema_name, table_name).await?;
88        self.typed_table_columns_cache
89            .write()
90            .await
91            .insert(cache_key, columns.clone());
92        Ok(columns)
93    }
94
95    pub(super) async fn validate_typed_table_compatibility_in_transaction(
96        &self,
97        tx: &mut Transaction<'_, Postgres>,
98        schema_name: &str,
99        table_name: &str,
100    ) -> Result<()> {
101        self.typed_table_columns_in_transaction(tx, schema_name, table_name)
102            .await
103            .map(|_| ())
104    }
105
106    pub(super) async fn invalidate_typed_table_columns_cache(
107        &self,
108        schema_name: &str,
109        table_name: &str,
110    ) {
111        let mut cache = self.typed_table_columns_cache.write().await;
112        invalidate_typed_table_columns_cache_entry(&mut cache, schema_name, table_name);
113    }
114
115    pub(super) async fn invalidate_typed_table_columns_cache_for_schema(&self, schema_name: &str) {
116        let mut cache = self.typed_table_columns_cache.write().await;
117        invalidate_typed_table_columns_cache_entries_for_schema(&mut cache, schema_name);
118    }
119}
120
121fn typed_table_cache_key(schema_name: &str, table_name: &str) -> TypedTableCacheKey {
122    (schema_name.to_string(), table_name.to_string())
123}
124
125fn invalidate_typed_table_columns_cache_entry(
126    cache: &mut TypedTableColumnsCache,
127    schema_name: &str,
128    table_name: &str,
129) {
130    cache.remove(&typed_table_cache_key(schema_name, table_name));
131}
132
133fn invalidate_typed_table_columns_cache_entries_for_schema(
134    cache: &mut TypedTableColumnsCache,
135    schema_name: &str,
136) {
137    cache.retain(|(cached_schema, _), _| cached_schema != schema_name);
138}
139
140pub(super) async fn load_site_if_configured_with<'e, E>(executor: E) -> Result<Option<Site>>
141where
142    E: Executor<'e, Database = Postgres>,
143{
144    let row = sqlx::query(
145        "SELECT display_name, metadata
146         FROM site
147         WHERE id = 1",
148    )
149    .fetch_optional(executor)
150    .await?;
151
152    row.map(map_site_row).transpose()
153}
154
155pub(super) async fn load_current_site_with<'e, E>(executor: E) -> Result<Site>
156where
157    E: Executor<'e, Database = Postgres>,
158{
159    load_site_if_configured_with(executor)
160        .await?
161        .ok_or(CedrosDataError::SiteNotConfigured)
162}
163
164pub(super) async fn find_collection_with<'e, E>(
165    executor: E,
166    collection_name: &str,
167) -> Result<Collection>
168where
169    E: Executor<'e, Database = Postgres>,
170{
171    let row = sqlx::query(
172        "SELECT collection_name, mode, table_name, strict_contract
173         FROM collections
174         WHERE collection_name = $1",
175    )
176    .bind(collection_name)
177    .fetch_optional(executor)
178    .await?;
179
180    row.map(map_collection_row)
181        .transpose()?
182        .ok_or_else(|| CedrosDataError::CollectionNotFound(collection_name.to_string()))
183}
184
185pub(super) async fn latest_contract_with<'e, E>(
186    executor: E,
187    collection_name: &str,
188) -> Result<Option<ContractSchema>>
189where
190    E: Executor<'e, Database = Postgres>,
191{
192    let row = sqlx::query(
193        "SELECT contract
194         FROM collection_contracts
195         WHERE collection_name = $1
196         ORDER BY version DESC
197         LIMIT 1",
198    )
199    .bind(collection_name)
200    .fetch_optional(executor)
201    .await?;
202
203    let Some(row) = row else {
204        return Ok(None);
205    };
206
207    let contract_value = row.get::<serde_json::Value, _>("contract");
208    Ok(Some(serde_json::from_value(contract_value)?))
209}
210
211pub(super) async fn next_contract_version_with<'e, E>(
212    executor: E,
213    collection_name: &str,
214) -> Result<i32>
215where
216    E: Executor<'e, Database = Postgres>,
217{
218    let row = sqlx::query(
219        "SELECT COALESCE(MAX(version), 0) AS version
220         FROM collection_contracts
221         WHERE collection_name = $1",
222    )
223    .bind(collection_name)
224    .fetch_one(executor)
225    .await?;
226
227    Ok(row.get::<i32, _>("version") + 1)
228}
229
230pub(super) async fn insert_contract_in_transaction(
231    tx: &mut Transaction<'_, Postgres>,
232    collection_name: &str,
233    contract: &ContractSchema,
234) -> Result<i32> {
235    lock_collection_for_contracts(tx, collection_name).await?;
236    let version = next_contract_version_with(&mut **tx, collection_name).await?;
237    insert_contract_version_in_transaction(tx, collection_name, version, contract).await?;
238    Ok(version)
239}
240
241pub(super) async fn insert_contract_version_in_transaction(
242    tx: &mut Transaction<'_, Postgres>,
243    collection_name: &str,
244    version: i32,
245    contract: &ContractSchema,
246) -> Result<()> {
247    let existing = sqlx::query(
248        "SELECT contract
249         FROM collection_contracts
250         WHERE collection_name = $1 AND version = $2",
251    )
252    .bind(collection_name)
253    .bind(version)
254    .fetch_optional(&mut **tx)
255    .await?;
256
257    if let Some(existing_row) = existing {
258        let existing_contract =
259            serde_json::from_value(existing_row.get::<serde_json::Value, _>("contract"))?;
260        ensure_matching_contract_version(collection_name, version, &existing_contract, contract)?;
261        return Ok(());
262    }
263
264    sqlx::query(
265        "INSERT INTO collection_contracts (collection_name, version, contract)
266         VALUES ($1, $2, $3)",
267    )
268    .bind(collection_name)
269    .bind(version)
270    .bind(serde_json::to_value(contract)?)
271    .execute(&mut **tx)
272    .await?;
273    Ok(())
274}
275
276async fn lock_collection_for_contracts(
277    tx: &mut Transaction<'_, Postgres>,
278    collection_name: &str,
279) -> Result<()> {
280    let row = sqlx::query(
281        "SELECT 1
282         FROM collections
283         WHERE collection_name = $1
284         FOR UPDATE",
285    )
286    .bind(collection_name)
287    .fetch_optional(&mut **tx)
288    .await?;
289
290    if row.is_some() {
291        return Ok(());
292    }
293
294    Err(CedrosDataError::InvalidRequest(format!(
295        "collection {} is missing during contract version allocation",
296        collection_name
297    )))
298}
299
300fn ensure_matching_contract_version(
301    collection_name: &str,
302    version: i32,
303    existing: &ContractSchema,
304    incoming: &ContractSchema,
305) -> Result<()> {
306    if serde_json::to_value(existing)? == serde_json::to_value(incoming)? {
307        return Ok(());
308    }
309
310    Err(CedrosDataError::InvalidRequest(format!(
311        "import contract version conflict for {collection_name} v{version}"
312    )))
313}
314
315pub(super) fn map_site_row(row: sqlx::postgres::PgRow) -> Result<Site> {
316    Ok(Site {
317        display_name: row.get("display_name"),
318        metadata: row.get("metadata"),
319    })
320}
321
322pub(super) fn map_collection_row(row: sqlx::postgres::PgRow) -> Result<Collection> {
323    let mode = match row.get::<String, _>("mode").as_str() {
324        "jsonb" => CollectionMode::Jsonb,
325        "typed" => CollectionMode::Typed,
326        other => {
327            return Err(CedrosDataError::InvalidRequest(format!(
328                "unknown collection mode: {other}"
329            )));
330        }
331    };
332
333    let strict_contract = row
334        .get::<Option<serde_json::Value>, _>("strict_contract")
335        .map(serde_json::from_value)
336        .transpose()?;
337
338    Ok(Collection {
339        collection_name: row.get("collection_name"),
340        mode,
341        table_name: row.get("table_name"),
342        strict_contract,
343    })
344}
345
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::models::{ContractField, ContractSchema, ValueType};

    use super::{
        ensure_matching_contract_version, invalidate_typed_table_columns_cache_entries_for_schema,
        invalidate_typed_table_columns_cache_entry, typed_table_cache_key,
    };

    /// Comparing a contract against itself for an existing version succeeds.
    #[test]
    fn matching_contract_versions_are_idempotent() {
        let schema = sample_contract("title");

        let outcome = ensure_matching_contract_version("articles", 2, &schema, &schema);

        assert!(outcome.is_ok());
    }

    /// Contracts with different field paths for the same version are rejected.
    #[test]
    fn conflicting_contract_versions_are_rejected() {
        let stored = sample_contract("title");
        let submitted = sample_contract("summary");

        let outcome = ensure_matching_contract_version("articles", 2, &stored, &submitted);

        assert!(outcome.is_err());
    }

    /// Single-entry invalidation leaves every other entry in place.
    #[test]
    fn invalidating_one_typed_table_cache_entry_preserves_other_entries() {
        let mut cache = sample_typed_table_cache();

        invalidate_typed_table_columns_cache_entry(&mut cache, "site_data", "pages");

        let removed = typed_table_cache_key("site_data", "pages");
        let same_schema = typed_table_cache_key("site_data", "articles");
        let other_schema = typed_table_cache_key("site_blog", "pages");
        assert!(!cache.contains_key(&removed));
        assert!(cache.contains_key(&same_schema));
        assert!(cache.contains_key(&other_schema));
    }

    /// Schema-wide invalidation removes all tables of that schema and nothing else.
    #[test]
    fn invalidating_typed_table_cache_by_schema_removes_matching_entries() {
        let mut cache = sample_typed_table_cache();

        invalidate_typed_table_columns_cache_entries_for_schema(&mut cache, "site_data");

        for table in ["pages", "articles"] {
            assert!(!cache.contains_key(&typed_table_cache_key("site_data", table)));
        }
        assert!(cache.contains_key(&typed_table_cache_key("site_blog", "pages")));
    }

    /// Contract with one required string field at `path`.
    fn sample_contract(path: &str) -> ContractSchema {
        let field = ContractField {
            path: path.to_string(),
            required: true,
            types: vec![ValueType::String],
        };

        ContractSchema {
            fields: vec![field],
        }
    }

    /// Cache with two tables under `site_data` and one under `site_blog`.
    fn sample_typed_table_cache() -> BTreeMap<(String, String), BTreeMap<String, bool>> {
        let mut cache = BTreeMap::new();
        for (schema, table) in [
            ("site_data", "pages"),
            ("site_data", "articles"),
            ("site_blog", "pages"),
        ] {
            let columns = BTreeMap::from([(String::from("entry_key"), false)]);
            cache.insert(typed_table_cache_key(schema, table), columns);
        }
        cache
    }
}