Skip to main content

rivet_cli/config/
mod.rs

1mod models;
2pub mod resolve;
3
4pub use models::*;
5#[allow(unused_imports)]
6pub(crate) use resolve::resolve_env_vars;
7pub use resolve::{parse_file_size, resolve_vars};
8
9use serde::Deserialize;
10
11#[derive(Debug, Deserialize, Clone)]
12pub struct Config {
13    pub source: SourceConfig,
14    pub exports: Vec<ExportConfig>,
15    #[serde(default)]
16    pub notifications: Option<NotificationsConfig>,
17    #[serde(default)]
18    pub parallel_exports: bool,
19    #[serde(default)]
20    pub parallel_export_processes: bool,
21}
22
23impl Config {
24    pub fn load(path: &str) -> crate::error::Result<Self> {
25        Self::load_with_params(path, None)
26    }
27
28    pub fn load_with_params(
29        path: &str,
30        params: Option<&std::collections::HashMap<String, String>>,
31    ) -> crate::error::Result<Self> {
32        let contents = std::fs::read_to_string(path)?;
33        let resolved = resolve_vars(&contents, params);
34        Self::from_yaml(&resolved)
35    }
36
37    pub fn from_yaml(yaml: &str) -> crate::error::Result<Self> {
38        Self::check_misplaced_tuning_fields(yaml)?;
39        let config: Config = serde_yaml_ng::from_str(yaml)?;
40        config.validate()?;
41        Ok(config)
42    }
43
44    /// Detect tuning-related fields placed directly under `source:` or an
45    /// `exports[]` entry instead of inside the `tuning:` sub-key. Without this
46    /// check serde silently ignores unknown keys and the user gets unexpected
47    /// defaults (e.g. batch_size=10 000 instead of the intended 1 000).
48    fn check_misplaced_tuning_fields(yaml: &str) -> crate::error::Result<()> {
49        const TUNING_FIELDS: &[&str] = &[
50            "batch_size",
51            "batch_size_memory_mb",
52            "throttle_ms",
53            "statement_timeout_s",
54            "max_retries",
55            "retry_backoff_ms",
56            "lock_timeout_s",
57            "memory_threshold_mb",
58            "profile",
59        ];
60
61        let root: serde_yaml_ng::Value = serde_yaml_ng::from_str(yaml)?;
62
63        if let Some(source) = root.get("source") {
64            let misplaced: Vec<&str> = TUNING_FIELDS
65                .iter()
66                .copied()
67                .filter(|&f| source.get(f).is_some())
68                .collect();
69            if !misplaced.is_empty() {
70                anyhow::bail!(
71                    "source: field(s) [{}] belong under 'source.tuning:', not directly under 'source:'. \
72                     Example:\n  source:\n    tuning:\n      {}: <value>",
73                    misplaced.join(", "),
74                    misplaced[0],
75                );
76            }
77        }
78
79        if let Some(exports) = root.get("exports").and_then(|e| e.as_sequence()) {
80            for (i, export) in exports.iter().enumerate() {
81                let name = export
82                    .get("name")
83                    .and_then(|n| n.as_str())
84                    .unwrap_or("<unnamed>");
85                let misplaced: Vec<&str> = TUNING_FIELDS
86                    .iter()
87                    .copied()
88                    .filter(|&f| export.get(f).is_some())
89                    .collect();
90                if !misplaced.is_empty() {
91                    anyhow::bail!(
92                        "export '{}' (index {}): field(s) [{}] belong under 'exports[].tuning:', \
93                         not directly in the export. Example:\n  exports:\n    - name: {}\n      tuning:\n        {}: <value>",
94                        name,
95                        i,
96                        misplaced.join(", "),
97                        name,
98                        misplaced[0],
99                    );
100                }
101            }
102        }
103
104        Ok(())
105    }
106
107    fn validate(&self) -> crate::error::Result<()> {
108        if let Some(t) = &self.source.tuning
109            && t.batch_size.is_some()
110            && t.batch_size_memory_mb.is_some()
111        {
112            anyhow::bail!("tuning: batch_size and batch_size_memory_mb are mutually exclusive");
113        }
114
115        for export in &self.exports {
116            let merged = crate::tuning::merge_tuning_config(
117                self.source.tuning.as_ref(),
118                export.tuning.as_ref(),
119            );
120            if let Some(t) = merged
121                && t.batch_size.is_some()
122                && t.batch_size_memory_mb.is_some()
123            {
124                anyhow::bail!(
125                    "export '{}': effective tuning has both batch_size and batch_size_memory_mb (mutually exclusive)",
126                    export.name
127                );
128            }
129            if let Some(et) = &export.tuning
130                && et.batch_size.is_some()
131                && et.batch_size_memory_mb.is_some()
132            {
133                anyhow::bail!(
134                    "export '{}': tuning.batch_size and tuning.batch_size_memory_mb are mutually exclusive",
135                    export.name
136                );
137            }
138        }
139
140        if !self.source.has_url_fields() && !self.source.has_structured_fields() {
141            anyhow::bail!(
142                "source: must specify url, url_env, url_file, or structured fields (host/user/database)"
143            );
144        }
145
146        if self.source.has_url_fields() {
147            let url_count = [
148                &self.source.url,
149                &self.source.url_env,
150                &self.source.url_file,
151            ]
152            .iter()
153            .filter(|u| u.is_some())
154            .count();
155            if url_count > 1 {
156                anyhow::bail!("source: specify exactly one of 'url', 'url_env', or 'url_file'");
157            }
158        }
159
160        if self.source.has_url_fields() && self.source.has_structured_fields() {
161            anyhow::bail!(
162                "source: use either URL-based config (url/url_env/url_file) or structured fields (host/user/database/...), not both"
163            );
164        }
165
166        if self.source.has_structured_fields() {
167            if self.source.host.is_none() {
168                anyhow::bail!("source: structured config requires 'host'");
169            }
170            if self.source.user.is_none() {
171                anyhow::bail!("source: structured config requires 'user'");
172            }
173            if self.source.database.is_none() {
174                anyhow::bail!("source: structured config requires 'database'");
175            }
176            if self.source.password.is_some() && self.source.password_env.is_some() {
177                anyhow::bail!("source: specify 'password' or 'password_env', not both");
178            }
179        }
180
181        for export in &self.exports {
182            if export.query.is_none() && export.query_file.is_none() {
183                anyhow::bail!(
184                    "export '{}': must specify 'query' or 'query_file'",
185                    export.name
186                );
187            }
188            if export.query.is_some() && export.query_file.is_some() {
189                anyhow::bail!(
190                    "export '{}': specify either 'query' or 'query_file', not both",
191                    export.name
192                );
193            }
194            if export.destination.destination_type == DestinationType::S3 {
195                let ak = export.destination.access_key_env.is_some();
196                let sk = export.destination.secret_key_env.is_some();
197                if ak != sk {
198                    anyhow::bail!(
199                        "export '{}': S3 requires both access_key_env and secret_key_env, or neither (use default AWS credential chain)",
200                        export.name
201                    );
202                }
203            }
204
205            if export.destination.destination_type == DestinationType::Gcs
206                && export.destination.allow_anonymous
207                && export.destination.credentials_file.is_some()
208            {
209                anyhow::bail!(
210                    "export '{}': GCS allow_anonymous cannot be used together with credentials_file",
211                    export.name
212                );
213            }
214
215            if let Some(cred_path) = &export.destination.credentials_file
216                && !std::path::Path::new(cred_path).exists()
217            {
218                anyhow::bail!(
219                    "export '{}': credentials_file '{}' does not exist",
220                    export.name,
221                    cred_path
222                );
223            }
224
225            if let Some(ref size_str) = export.max_file_size {
226                parse_file_size(size_str).map_err(|_| {
227                    anyhow::anyhow!(
228                        "export '{}': invalid max_file_size '{}'",
229                        export.name,
230                        size_str
231                    )
232                })?;
233            }
234
235            if let Some(level) = export.compression_level {
236                match export.compression {
237                    CompressionType::Zstd => {
238                        if !(1..=22).contains(&level) {
239                            anyhow::bail!(
240                                "export '{}': zstd compression_level must be 1..22, got {}",
241                                export.name,
242                                level
243                            );
244                        }
245                    }
246                    CompressionType::Gzip => {
247                        if level > 10 {
248                            anyhow::bail!(
249                                "export '{}': gzip compression_level must be 0..10, got {}",
250                                export.name,
251                                level
252                            );
253                        }
254                    }
255                    _ => {
256                        anyhow::bail!(
257                            "export '{}': compression_level is only supported for zstd and gzip",
258                            export.name
259                        );
260                    }
261                }
262            }
263
264            match export.mode {
265                ExportMode::Incremental => {
266                    if export.cursor_column.is_none() {
267                        anyhow::bail!(
268                            "export '{}': incremental mode requires cursor_column",
269                            export.name
270                        );
271                    }
272                }
273                ExportMode::Chunked => {
274                    if export.chunk_column.is_none() {
275                        anyhow::bail!(
276                            "export '{}': chunked mode requires chunk_column",
277                            export.name
278                        );
279                    }
280                }
281                ExportMode::TimeWindow => {
282                    if export.time_column.is_none() {
283                        anyhow::bail!(
284                            "export '{}': time_window mode requires time_column",
285                            export.name
286                        );
287                    }
288                    if export.days_window.is_none() {
289                        anyhow::bail!(
290                            "export '{}': time_window mode requires days_window",
291                            export.name
292                        );
293                    }
294                }
295                ExportMode::Full => {}
296            }
297
298            if export.chunk_dense && export.mode != ExportMode::Chunked {
299                anyhow::bail!(
300                    "export '{}': chunk_dense is only valid with mode: chunked",
301                    export.name
302                );
303            }
304        }
305        Ok(())
306    }
307}
308
309#[cfg(test)]
310#[path = "tests.rs"]
311mod tests;