1#[cfg(feature = "storage")]
2mod assets;
3mod contracts;
4mod defaults;
5mod entries;
6mod imex;
7pub(crate) mod metered_reads;
8mod migrate;
9mod schema;
10mod site;
11
12use std::{collections::BTreeMap, sync::Arc, time::Instant};
13
14use sqlx::{postgres::PgPoolOptions, Executor, PgPool, Postgres, Row, Transaction};
15use tokio::sync::RwLock;
16
17use crate::config::CedrosDataConfig;
18use crate::error::{CedrosDataError, Result};
19use crate::models::{Collection, CollectionMode, ContractSchema, Site};
20
21use self::entries::typed_table_columns_with;
22
/// Column set of a typed collection table, keyed by column name, as produced by
/// `entries::typed_table_columns_with`. The `bool` flag's meaning is defined by
/// that helper (NOTE(review): presumably nullability — confirm).
type TypedTableColumns = BTreeMap<String, bool>;
/// Cache key for typed-table column lookups: `(schema_name, table_name)`.
type TypedTableCacheKey = (String, String);
/// In-memory map from `(schema_name, table_name)` to that table's column set.
type TypedTableColumnsCache = BTreeMap<TypedTableCacheKey, TypedTableColumns>;
26
27#[derive(Clone)]
28pub struct CedrosData {
29 pool: PgPool,
30 typed_table_columns_cache: Arc<RwLock<TypedTableColumnsCache>>,
31 free_reads_cache: Arc<RwLock<Option<(Instant, i64)>>>,
33}
34
35impl CedrosData {
36 pub async fn from_env() -> Result<Self> {
37 let config = CedrosDataConfig::from_env()?;
38 Self::connect(&config.postgres_uri).await
39 }
40
41 pub async fn connect(postgres_uri: &str) -> Result<Self> {
42 let pool = PgPoolOptions::new()
43 .max_connections(10)
44 .connect(postgres_uri)
45 .await?;
46 Ok(Self {
47 pool,
48 typed_table_columns_cache: Arc::new(RwLock::new(BTreeMap::new())),
49 free_reads_cache: Arc::new(RwLock::new(None)),
50 })
51 }
52
53 pub fn pool(&self) -> &PgPool {
54 &self.pool
55 }
56
57 async fn current_site(&self) -> Result<Site> {
58 self.ensure_site_exists().await?;
59 load_current_site_with(&self.pool).await
60 }
61
62 async fn find_collection(&self, collection_name: &str) -> Result<Collection> {
63 find_collection_with(&self.pool, collection_name).await
64 }
65
66 async fn latest_contract(&self, collection_name: &str) -> Result<Option<ContractSchema>> {
67 latest_contract_with(&self.pool, collection_name).await
68 }
69
70 pub(super) async fn typed_table_columns_in_transaction(
71 &self,
72 tx: &mut Transaction<'_, Postgres>,
73 schema_name: &str,
74 table_name: &str,
75 ) -> Result<TypedTableColumns> {
76 let cache_key = typed_table_cache_key(schema_name, table_name);
77 if let Some(columns) = self
78 .typed_table_columns_cache
79 .read()
80 .await
81 .get(&cache_key)
82 .cloned()
83 {
84 return Ok(columns);
85 }
86
87 let columns = typed_table_columns_with(&mut **tx, schema_name, table_name).await?;
88 self.typed_table_columns_cache
89 .write()
90 .await
91 .insert(cache_key, columns.clone());
92 Ok(columns)
93 }
94
95 pub(super) async fn validate_typed_table_compatibility_in_transaction(
96 &self,
97 tx: &mut Transaction<'_, Postgres>,
98 schema_name: &str,
99 table_name: &str,
100 ) -> Result<()> {
101 self.typed_table_columns_in_transaction(tx, schema_name, table_name)
102 .await
103 .map(|_| ())
104 }
105
106 pub(super) async fn invalidate_typed_table_columns_cache(
107 &self,
108 schema_name: &str,
109 table_name: &str,
110 ) {
111 let mut cache = self.typed_table_columns_cache.write().await;
112 invalidate_typed_table_columns_cache_entry(&mut cache, schema_name, table_name);
113 }
114
115 pub(super) async fn invalidate_typed_table_columns_cache_for_schema(&self, schema_name: &str) {
116 let mut cache = self.typed_table_columns_cache.write().await;
117 invalidate_typed_table_columns_cache_entries_for_schema(&mut cache, schema_name);
118 }
119}
120
121fn typed_table_cache_key(schema_name: &str, table_name: &str) -> TypedTableCacheKey {
122 (schema_name.to_string(), table_name.to_string())
123}
124
125fn invalidate_typed_table_columns_cache_entry(
126 cache: &mut TypedTableColumnsCache,
127 schema_name: &str,
128 table_name: &str,
129) {
130 cache.remove(&typed_table_cache_key(schema_name, table_name));
131}
132
133fn invalidate_typed_table_columns_cache_entries_for_schema(
134 cache: &mut TypedTableColumnsCache,
135 schema_name: &str,
136) {
137 cache.retain(|(cached_schema, _), _| cached_schema != schema_name);
138}
139
140pub(super) async fn load_site_if_configured_with<'e, E>(executor: E) -> Result<Option<Site>>
141where
142 E: Executor<'e, Database = Postgres>,
143{
144 let row = sqlx::query(
145 "SELECT display_name, metadata
146 FROM cedros_data.site
147 WHERE id = 1",
148 )
149 .fetch_optional(executor)
150 .await?;
151
152 row.map(map_site_row).transpose()
153}
154
155pub(super) async fn load_current_site_with<'e, E>(executor: E) -> Result<Site>
156where
157 E: Executor<'e, Database = Postgres>,
158{
159 load_site_if_configured_with(executor)
160 .await?
161 .ok_or(CedrosDataError::SiteNotConfigured)
162}
163
164pub(super) async fn find_collection_with<'e, E>(
165 executor: E,
166 collection_name: &str,
167) -> Result<Collection>
168where
169 E: Executor<'e, Database = Postgres>,
170{
171 let row = sqlx::query(
172 "SELECT collection_name, mode, table_name, strict_contract
173 FROM cedros_data.collections
174 WHERE collection_name = $1",
175 )
176 .bind(collection_name)
177 .fetch_optional(executor)
178 .await?;
179
180 row.map(map_collection_row)
181 .transpose()?
182 .ok_or_else(|| CedrosDataError::CollectionNotFound(collection_name.to_string()))
183}
184
185pub(super) async fn latest_contract_with<'e, E>(
186 executor: E,
187 collection_name: &str,
188) -> Result<Option<ContractSchema>>
189where
190 E: Executor<'e, Database = Postgres>,
191{
192 let row = sqlx::query(
193 "SELECT contract
194 FROM cedros_data.collection_contracts
195 WHERE collection_name = $1
196 ORDER BY version DESC
197 LIMIT 1",
198 )
199 .bind(collection_name)
200 .fetch_optional(executor)
201 .await?;
202
203 let Some(row) = row else {
204 return Ok(None);
205 };
206
207 let contract_value = row.get::<serde_json::Value, _>("contract");
208 Ok(Some(serde_json::from_value(contract_value)?))
209}
210
211pub(super) async fn next_contract_version_with<'e, E>(
212 executor: E,
213 collection_name: &str,
214) -> Result<i32>
215where
216 E: Executor<'e, Database = Postgres>,
217{
218 let row = sqlx::query(
219 "SELECT COALESCE(MAX(version), 0) AS version
220 FROM cedros_data.collection_contracts
221 WHERE collection_name = $1",
222 )
223 .bind(collection_name)
224 .fetch_one(executor)
225 .await?;
226
227 Ok(row.get::<i32, _>("version") + 1)
228}
229
230pub(super) async fn insert_contract_in_transaction(
231 tx: &mut Transaction<'_, Postgres>,
232 collection_name: &str,
233 contract: &ContractSchema,
234) -> Result<i32> {
235 lock_collection_for_contracts(tx, collection_name).await?;
236 let version = next_contract_version_with(&mut **tx, collection_name).await?;
237 insert_contract_version_in_transaction(tx, collection_name, version, contract).await?;
238 Ok(version)
239}
240
241pub(super) async fn insert_contract_version_in_transaction(
242 tx: &mut Transaction<'_, Postgres>,
243 collection_name: &str,
244 version: i32,
245 contract: &ContractSchema,
246) -> Result<()> {
247 let existing = sqlx::query(
248 "SELECT contract
249 FROM cedros_data.collection_contracts
250 WHERE collection_name = $1 AND version = $2",
251 )
252 .bind(collection_name)
253 .bind(version)
254 .fetch_optional(&mut **tx)
255 .await?;
256
257 if let Some(existing_row) = existing {
258 let existing_contract =
259 serde_json::from_value(existing_row.get::<serde_json::Value, _>("contract"))?;
260 ensure_matching_contract_version(collection_name, version, &existing_contract, contract)?;
261 return Ok(());
262 }
263
264 sqlx::query(
265 "INSERT INTO cedros_data.collection_contracts (collection_name, version, contract)
266 VALUES ($1, $2, $3)",
267 )
268 .bind(collection_name)
269 .bind(version)
270 .bind(serde_json::to_value(contract)?)
271 .execute(&mut **tx)
272 .await?;
273 Ok(())
274}
275
276async fn lock_collection_for_contracts(
277 tx: &mut Transaction<'_, Postgres>,
278 collection_name: &str,
279) -> Result<()> {
280 let row = sqlx::query(
281 "SELECT 1
282 FROM cedros_data.collections
283 WHERE collection_name = $1
284 FOR UPDATE",
285 )
286 .bind(collection_name)
287 .fetch_optional(&mut **tx)
288 .await?;
289
290 if row.is_some() {
291 return Ok(());
292 }
293
294 Err(CedrosDataError::InvalidRequest(format!(
295 "collection {} is missing during contract version allocation",
296 collection_name
297 )))
298}
299
300fn ensure_matching_contract_version(
301 collection_name: &str,
302 version: i32,
303 existing: &ContractSchema,
304 incoming: &ContractSchema,
305) -> Result<()> {
306 if serde_json::to_value(existing)? == serde_json::to_value(incoming)? {
307 return Ok(());
308 }
309
310 Err(CedrosDataError::InvalidRequest(format!(
311 "import contract version conflict for {collection_name} v{version}"
312 )))
313}
314
315pub(super) fn map_site_row(row: sqlx::postgres::PgRow) -> Result<Site> {
316 Ok(Site {
317 display_name: row.get("display_name"),
318 metadata: row.get("metadata"),
319 })
320}
321
322pub(super) fn map_collection_row(row: sqlx::postgres::PgRow) -> Result<Collection> {
323 let mode = match row.get::<String, _>("mode").as_str() {
324 "jsonb" => CollectionMode::Jsonb,
325 "typed" => CollectionMode::Typed,
326 other => {
327 return Err(CedrosDataError::InvalidRequest(format!(
328 "unknown collection mode: {other}"
329 )));
330 }
331 };
332
333 let strict_contract = row
334 .get::<Option<serde_json::Value>, _>("strict_contract")
335 .map(serde_json::from_value)
336 .transpose()?;
337
338 Ok(Collection {
339 collection_name: row.get("collection_name"),
340 mode,
341 table_name: row.get("table_name"),
342 strict_contract,
343 })
344}
345
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use crate::models::{ContractField, ContractSchema, ValueType};

    use super::{
        ensure_matching_contract_version, invalidate_typed_table_columns_cache_entries_for_schema,
        invalidate_typed_table_columns_cache_entry, typed_table_cache_key, TypedTableColumnsCache,
    };

    /// `(schema, table)` pairs present in every fresh fixture cache.
    const FIXTURE_TABLES: [(&str, &str); 3] = [
        ("site_data", "pages"),
        ("site_data", "articles"),
        ("site_blog", "pages"),
    ];

    #[test]
    fn matching_contract_versions_are_idempotent() {
        let contract = sample_contract("title");
        let outcome = ensure_matching_contract_version("articles", 2, &contract, &contract);
        assert!(outcome.is_ok());
    }

    #[test]
    fn conflicting_contract_versions_are_rejected() {
        let outcome = ensure_matching_contract_version(
            "articles",
            2,
            &sample_contract("title"),
            &sample_contract("summary"),
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn invalidating_one_typed_table_cache_entry_preserves_other_entries() {
        let mut cache = sample_typed_table_cache();

        invalidate_typed_table_columns_cache_entry(&mut cache, "site_data", "pages");

        assert!(!is_cached(&cache, "site_data", "pages"));
        assert!(is_cached(&cache, "site_data", "articles"));
        assert!(is_cached(&cache, "site_blog", "pages"));
    }

    #[test]
    fn invalidating_typed_table_cache_by_schema_removes_matching_entries() {
        let mut cache = sample_typed_table_cache();

        invalidate_typed_table_columns_cache_entries_for_schema(&mut cache, "site_data");

        assert!(!is_cached(&cache, "site_data", "pages"));
        assert!(!is_cached(&cache, "site_data", "articles"));
        assert!(is_cached(&cache, "site_blog", "pages"));
    }

    /// A one-field contract whose only required string field lives at `path`.
    fn sample_contract(path: &str) -> ContractSchema {
        let field = ContractField {
            path: path.to_owned(),
            required: true,
            types: vec![ValueType::String],
        };
        ContractSchema {
            fields: vec![field],
        }
    }

    /// Cache holding one `entry_key` column set for each of `FIXTURE_TABLES`.
    fn sample_typed_table_cache() -> TypedTableColumnsCache {
        FIXTURE_TABLES
            .iter()
            .map(|&(schema, table)| {
                (
                    typed_table_cache_key(schema, table),
                    BTreeMap::from([(String::from("entry_key"), false)]),
                )
            })
            .collect()
    }

    /// Whether `cache` still holds an entry for `schema.table`.
    fn is_cached(cache: &TypedTableColumnsCache, schema: &str, table: &str) -> bool {
        cache.contains_key(&typed_table_cache_key(schema, table))
    }
}