Skip to main content

pgorm_check/
schema_cache.rs

1use crate::client::CheckClient;
2use crate::error::{CheckError, CheckResult};
3use crate::schema_introspect::{DbSchema, load_schema_from_db, schema_fingerprint};
4use chrono::{DateTime, Utc};
5use serde::{Deserialize, Serialize};
6use std::path::{Path, PathBuf};
7
/// Settings controlling where the schema cache file lives and which
/// PostgreSQL schemas it covers.
#[derive(Debug, Clone)]
pub struct SchemaCacheConfig {
    /// Directory holding the cache file. Defaults to `.pgorm` inside the
    /// current working directory (or `./.pgorm` if the working directory
    /// cannot be determined).
    pub cache_dir: PathBuf,
    /// Name of the cache file within `cache_dir`. Defaults to `schema.json`.
    pub cache_file_name: String,
    /// PostgreSQL schemas to introspect. Defaults to `["public"]`.
    pub schemas: Vec<String>,
}

impl Default for SchemaCacheConfig {
    fn default() -> Self {
        // Anchor the cache in the working directory; if that is unavailable
        // (e.g. it no longer exists), fall back to the relative path `.`.
        let base_dir = match std::env::current_dir() {
            Ok(dir) => dir,
            Err(_) => PathBuf::from("."),
        };

        SchemaCacheConfig {
            cache_dir: base_dir.join(".pgorm"),
            cache_file_name: String::from("schema.json"),
            schemas: vec![String::from("public")],
        }
    }
}
31
/// Reports where [`SchemaCache::load_or_refresh`] obtained the schema it returned.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SchemaCacheLoad {
    /// Served from the on-disk cache: its fingerprint still matches the database.
    CacheHit,
    /// Introspected from the database: the cache was missing or unreadable,
    /// was written for other schemas or another format version, or its
    /// fingerprint no longer matches.
    Refreshed,
}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct SchemaCache {
42    pub version: u32,
43    pub retrieved_at: DateTime<Utc>,
44    pub schemas: Vec<String>,
45    pub fingerprint: String,
46    pub schema: DbSchema,
47}
48
49impl SchemaCache {
50    pub fn cache_path(config: &SchemaCacheConfig) -> PathBuf {
51        config.cache_dir.join(&config.cache_file_name)
52    }
53
54    pub async fn load_or_refresh<C: CheckClient>(
55        client: &C,
56        config: &SchemaCacheConfig,
57    ) -> CheckResult<(Self, SchemaCacheLoad)> {
58        let cache_path = Self::cache_path(config);
59
60        if let Ok(cached) = read_cache_file(&cache_path) {
61            if cached.schemas == config.schemas && cached.version == 1 {
62                let current_fp = schema_fingerprint(client, &config.schemas).await?;
63                if current_fp == cached.fingerprint {
64                    return Ok((cached, SchemaCacheLoad::CacheHit));
65                }
66            }
67        }
68
69        let (schema, fingerprint) = load_schema_from_db(client, &config.schemas).await?;
70        let refreshed = SchemaCache {
71            version: 1,
72            retrieved_at: Utc::now(),
73            schemas: config.schemas.clone(),
74            fingerprint,
75            schema,
76        };
77
78        write_cache_file(&cache_path, &refreshed)?;
79        Ok((refreshed, SchemaCacheLoad::Refreshed))
80    }
81}
82
83fn read_cache_file(path: &Path) -> CheckResult<SchemaCache> {
84    let data = match std::fs::read(path) {
85        Ok(d) => d,
86        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
87            return Err(CheckError::Other(e.to_string()));
88        }
89        Err(e) => return Err(CheckError::Other(e.to_string())),
90    };
91
92    serde_json::from_slice::<SchemaCache>(&data)
93        .map_err(|e| CheckError::Serialization(format!("Failed to parse schema cache: {e}")))
94}
95
96fn write_cache_file(path: &Path, cache: &SchemaCache) -> CheckResult<()> {
97    if let Some(parent) = path.parent() {
98        std::fs::create_dir_all(parent).map_err(|e| CheckError::Other(e.to_string()))?;
99    }
100
101    let tmp_path = path.with_extension("json.tmp");
102    let data = serde_json::to_vec_pretty(cache)
103        .map_err(|e| CheckError::Serialization(format!("Failed to serialize schema cache: {e}")))?;
104
105    std::fs::write(&tmp_path, data).map_err(|e| CheckError::Other(e.to_string()))?;
106    std::fs::rename(&tmp_path, path).map_err(|e| CheckError::Other(e.to_string()))?;
107    Ok(())
108}