Skip to main content

ceres_core/
config.rs

//! Configuration types for Ceres components.
//!
//! # Configuration Improvements
//!
//! TODO(config): Make all configuration values environment-configurable.
//! Portals are already read from `portals.toml`, but the database, sync, and
//! HTTP defaults are still hardcoded. Should support:
//! - `DB_MAX_CONNECTIONS` for database pool size
//! - `SYNC_CONCURRENCY` for parallel dataset processing
//! - `HTTP_TIMEOUT` for API request timeout
//! - `HTTP_MAX_RETRIES` for retry attempts
//!
//! Consider using the `config` crate for layered configuration:
//! defaults -> config file -> environment variables -> CLI args
14
15use serde::{Deserialize, Serialize};
16use std::fmt;
17use std::path::{Path, PathBuf};
18use std::str::FromStr;
19use std::time::Duration;
20
21use crate::circuit_breaker::CircuitBreakerConfig;
22use crate::error::AppError;
23
24// =============================================================================
25// Embedding Provider Configuration
26// =============================================================================
27
28/// Embedding provider type.
29///
30/// Determines which embedding API to use for generating text embeddings.
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
32#[serde(rename_all = "lowercase")]
33pub enum EmbeddingProviderType {
34    /// Google Gemini text-embedding-004 (768 dimensions).
35    #[default]
36    Gemini,
37    /// OpenAI text-embedding-3-small (1536d) or text-embedding-3-large (3072d).
38    OpenAI,
39}
40
41impl fmt::Display for EmbeddingProviderType {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            Self::Gemini => write!(f, "gemini"),
45            Self::OpenAI => write!(f, "openai"),
46        }
47    }
48}
49
50impl FromStr for EmbeddingProviderType {
51    type Err = AppError;
52
53    fn from_str(s: &str) -> Result<Self, Self::Err> {
54        match s.to_lowercase().as_str() {
55            "gemini" => Ok(Self::Gemini),
56            "openai" => Ok(Self::OpenAI),
57            _ => Err(AppError::ConfigError(format!(
58                "Unknown embedding provider: '{}'. Valid options: gemini, openai",
59                s
60            ))),
61        }
62    }
63}
64
65/// Gemini embedding provider configuration.
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct GeminiEmbeddingConfig {
68    /// Gemini model name.
69    #[serde(default = "default_gemini_model")]
70    pub model: String,
71}
72
73fn default_gemini_model() -> String {
74    "text-embedding-004".to_string()
75}
76
77impl Default for GeminiEmbeddingConfig {
78    fn default() -> Self {
79        Self {
80            model: default_gemini_model(),
81        }
82    }
83}
84
85/// OpenAI embedding provider configuration.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct OpenAIEmbeddingConfig {
88    /// OpenAI model name.
89    #[serde(default = "default_openai_model")]
90    pub model: String,
91    /// Custom API endpoint (for Azure OpenAI or proxies).
92    pub endpoint: Option<String>,
93}
94
95fn default_openai_model() -> String {
96    "text-embedding-3-small".to_string()
97}
98
99impl Default for OpenAIEmbeddingConfig {
100    fn default() -> Self {
101        Self {
102            model: default_openai_model(),
103            endpoint: None,
104        }
105    }
106}
107
108/// Returns the embedding dimension for a given provider and model.
109///
110/// # Arguments
111///
112/// * `provider` - The embedding provider type
113/// * `model` - The model name (optional, uses default if None)
114pub fn embedding_dimension(provider: EmbeddingProviderType, model: Option<&str>) -> usize {
115    match provider {
116        EmbeddingProviderType::Gemini => 768, // text-embedding-004 is always 768
117        EmbeddingProviderType::OpenAI => match model.unwrap_or("text-embedding-3-small") {
118            "text-embedding-3-large" => 3072,
119            _ => 1536, // text-embedding-3-small and ada-002
120        },
121    }
122}
123
/// Database connection pool configuration.
///
/// TODO(config): Support environment variable `DB_MAX_CONNECTIONS`
/// Default of 5 may be insufficient for high-concurrency scenarios.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbConfig {
    /// Maximum size of the database connection pool (default: 5).
    pub max_connections: u32,
}

impl Default for DbConfig {
    fn default() -> Self {
        // TODO(config): Read from DB_MAX_CONNECTIONS env var
        Self { max_connections: 5 }
    }
}
138
/// HTTP client configuration for external API calls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpConfig {
    /// Timeout applied to API requests (default: 30 s).
    pub timeout: Duration,
    /// Maximum number of retry attempts (default: 3).
    pub max_retries: u32,
    /// Base delay used when scheduling retries (default: 500 ms).
    pub retry_base_delay: Duration,
}

impl Default for HttpConfig {
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(30),
            max_retries: 3,
            retry_base_delay: Duration::from_millis(500),
        }
    }
}
155
156/// Portal synchronization configuration.
157///
158/// TODO(config): Support CLI arg `--concurrency` and env var `SYNC_CONCURRENCY`
159/// Optimal value depends on portal rate limits and system resources.
160/// Consider auto-tuning based on API response times.
161#[derive(Clone)]
162pub struct SyncConfig {
163    /// Number of concurrent dataset processing tasks.
164    pub concurrency: usize,
165    /// Force full sync even if incremental sync is available.
166    pub force_full_sync: bool,
167    /// Circuit breaker configuration for API resilience.
168    pub circuit_breaker: CircuitBreakerConfig,
169}
170
171impl Default for SyncConfig {
172    fn default() -> Self {
173        // TODO(config): Read from SYNC_CONCURRENCY env var
174        Self {
175            concurrency: 10,
176            force_full_sync: false,
177            circuit_breaker: CircuitBreakerConfig::default(),
178        }
179    }
180}
181
182impl SyncConfig {
183    /// Creates a new SyncConfig with force_full_sync enabled.
184    pub fn with_full_sync(mut self) -> Self {
185        self.force_full_sync = true;
186        self
187    }
188
189    /// Creates a new SyncConfig with custom circuit breaker configuration.
190    pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
191        self.circuit_breaker = config;
192        self
193    }
194}
195
196// =============================================================================
197// Portal Configuration (portals.toml)
198// =============================================================================
199
/// Default portal type when not specified in configuration.
///
/// Serde calls this when a portal entry has no `type` key.
fn default_portal_type() -> String {
    String::from("ckan")
}

/// Default enabled status when not specified in configuration.
///
/// Serde calls this when a portal entry has no `enabled` key, so portals
/// take part in batch harvesting unless explicitly disabled.
fn default_enabled() -> bool {
    const ENABLED_UNLESS_DISABLED: bool = true;
    ENABLED_UNLESS_DISABLED
}
209
210/// Root configuration structure for portals.toml.
211///
212/// This structure represents the entire configuration file containing
213/// an array of portal definitions.
214///
215/// # Example
216///
217/// ```toml
218/// [[portals]]
219/// name = "dati-gov-it"
220/// url = "https://dati.gov.it"
221/// type = "ckan"
222/// description = "Italian national open data portal"
223///
224/// [[portals]]
225/// name = "milano"
226/// url = "https://dati.comune.milano.it"
227/// enabled = true
228/// ```
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct PortalsConfig {
231    /// Array of portal configurations.
232    pub portals: Vec<PortalEntry>,
233}
234
235impl PortalsConfig {
236    /// Returns only enabled portals.
237    ///
238    /// Portals with `enabled = false` are excluded from batch harvesting.
239    pub fn enabled_portals(&self) -> Vec<&PortalEntry> {
240        self.portals.iter().filter(|p| p.enabled).collect()
241    }
242
243    /// Find a portal by name (case-insensitive).
244    ///
245    /// # Arguments
246    /// * `name` - The portal name to search for.
247    ///
248    /// # Returns
249    /// The matching portal entry, or None if not found.
250    pub fn find_by_name(&self, name: &str) -> Option<&PortalEntry> {
251        self.portals
252            .iter()
253            .find(|p| p.name.eq_ignore_ascii_case(name))
254    }
255}
256
257/// A single portal entry in the configuration file.
258///
259/// Each portal entry defines a CKAN portal to harvest, including
260/// its URL, type, and whether it's enabled for batch harvesting.
261#[derive(Debug, Clone, Serialize, Deserialize)]
262pub struct PortalEntry {
263    /// Human-readable portal name.
264    ///
265    /// Used for `--portal <name>` lookup and logging.
266    pub name: String,
267
268    /// Base URL of the CKAN portal.
269    ///
270    /// Example: "<https://dati.comune.milano.it>"
271    pub url: String,
272
273    /// Portal type: "ckan", "socrata", or "dcat".
274    ///
275    /// Defaults to "ckan" if not specified.
276    #[serde(rename = "type", default = "default_portal_type")]
277    pub portal_type: String,
278
279    /// Whether this portal is enabled for batch harvesting.
280    ///
281    /// Defaults to `true` if not specified.
282    #[serde(default = "default_enabled")]
283    pub enabled: bool,
284
285    /// Optional description of the portal.
286    pub description: Option<String>,
287}
288
289/// Default configuration file name.
290pub const CONFIG_FILE_NAME: &str = "portals.toml";
291
292/// Returns the default configuration directory path.
293///
294/// Uses XDG Base Directory specification: `~/.config/ceres/`
295pub fn default_config_dir() -> Option<PathBuf> {
296    dirs::config_dir().map(|p| p.join("ceres"))
297}
298
299/// Returns the default configuration file path.
300///
301/// Path: `~/.config/ceres/portals.toml`
302pub fn default_config_path() -> Option<PathBuf> {
303    default_config_dir().map(|p| p.join(CONFIG_FILE_NAME))
304}
305
/// Default template content for a new portals.toml file.
///
/// Includes pre-configured Italian open data portals (Comune di Milano and
/// Regione Siciliana, both `type = "ckan"`) so users can immediately run
/// `ceres harvest` without manual configuration.
///
/// `create_default_config` writes this text verbatim, and `load_portals_config`
/// then parses the written file. The template must therefore stay valid TOML
/// for `PortalsConfig`.
const DEFAULT_CONFIG_TEMPLATE: &str = r#"# Ceres Portal Configuration
#
# Usage:
#   ceres harvest                 # Harvest all enabled portals
#   ceres harvest --portal milano # Harvest specific portal by name
#   ceres harvest https://...     # Harvest single URL (ignores this file)
#
# Set enabled = false to skip a portal during batch harvest.

# City of Milan open data
[[portals]]
name = "milano"
url = "https://dati.comune.milano.it"
type = "ckan"
description = "Open data del Comune di Milano"

# Sicily Region open data
[[portals]]
name = "sicilia"
url = "https://dati.regione.sicilia.it"
type = "ckan"
description = "Open data della Regione Siciliana"
"#;
333
334/// Load portal configuration from a TOML file.
335///
336/// # Arguments
337/// * `path` - Optional custom path. If `None`, uses default XDG path.
338///
339/// # Returns
340/// * `Ok(Some(config))` - Configuration loaded successfully
341/// * `Ok(None)` - No configuration file found (not an error for backward compatibility)
342/// * `Err(e)` - Configuration file exists but is invalid
343///
344/// # Behavior
345/// If no configuration file exists at the default path, a template file
346/// is automatically created to help users get started.
347pub fn load_portals_config(path: Option<PathBuf>) -> Result<Option<PortalsConfig>, AppError> {
348    let using_default_path = path.is_none();
349    let config_path = match path {
350        Some(p) => p,
351        None => match default_config_path() {
352            Some(p) => p,
353            None => return Ok(None),
354        },
355    };
356
357    if !config_path.exists() {
358        // Auto-create template if using default path
359        if using_default_path {
360            match create_default_config(&config_path) {
361                Ok(()) => {
362                    // Template created successfully - read it and return the config
363                    // This allows the user to immediately harvest without re-running
364                    tracing::info!(
365                        "Config file created at {}. Starting harvest with default portals...",
366                        config_path.display()
367                    );
368                    // Continue to read the newly created file below
369                }
370                Err(e) => {
371                    // Log warning but don't fail - user might not have write permissions
372                    tracing::warn!("Could not create default config template: {}", e);
373                    return Ok(None);
374                }
375            }
376        } else {
377            // Custom path specified but doesn't exist - that's an error
378            return Err(AppError::ConfigError(format!(
379                "Config file not found: {}",
380                config_path.display()
381            )));
382        }
383    }
384
385    let content = std::fs::read_to_string(&config_path).map_err(|e| {
386        AppError::ConfigError(format!(
387            "Failed to read config file '{}': {}",
388            config_path.display(),
389            e
390        ))
391    })?;
392
393    let config: PortalsConfig = toml::from_str(&content).map_err(|e| {
394        AppError::ConfigError(format!(
395            "Invalid TOML in '{}': {}",
396            config_path.display(),
397            e
398        ))
399    })?;
400
401    Ok(Some(config))
402}
403
404/// Create a default configuration file with a template.
405///
406/// Creates the parent directory if it doesn't exist.
407///
408/// # Arguments
409/// * `path` - The path where the config file should be created.
410fn create_default_config(path: &Path) -> std::io::Result<()> {
411    // Create parent directory if needed
412    if let Some(parent) = path.parent() {
413        std::fs::create_dir_all(parent)?;
414    }
415
416    std::fs::write(path, DEFAULT_CONFIG_TEMPLATE)?;
417    tracing::info!("Created default config template at: {}", path.display());
418
419    Ok(())
420}
421
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Parses an inline TOML snippet into a `PortalsConfig`, panicking on error.
    fn parse_portals(toml_text: &str) -> PortalsConfig {
        toml::from_str(toml_text).unwrap()
    }

    /// Writes `contents` plus a trailing newline to a fresh temporary file and
    /// runs `load_portals_config` on that explicit path.
    ///
    /// The temporary file stays alive until loading has finished.
    fn load_from_temp_file(
        contents: std::fmt::Arguments<'_>,
    ) -> Result<Option<PortalsConfig>, AppError> {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "{}", contents).unwrap();
        load_portals_config(Some(file.path().to_path_buf()))
    }

    /// Like `load_from_temp_file`, but requires a successfully loaded config.
    fn load_ok(contents: std::fmt::Arguments<'_>) -> PortalsConfig {
        load_from_temp_file(contents).unwrap().unwrap()
    }

    // =========================================================================
    // Component Default Tests
    // =========================================================================

    #[test]
    fn test_db_config_defaults() {
        let DbConfig { max_connections } = DbConfig::default();
        assert_eq!(max_connections, 5);
    }

    #[test]
    fn test_http_config_defaults() {
        let HttpConfig {
            timeout,
            max_retries,
            retry_base_delay,
        } = HttpConfig::default();
        assert_eq!(timeout, Duration::from_secs(30));
        assert_eq!(max_retries, 3);
        assert_eq!(retry_base_delay, Duration::from_millis(500));
    }

    #[test]
    fn test_sync_config_defaults() {
        assert_eq!(SyncConfig::default().concurrency, 10);
    }

    // =========================================================================
    // Portal Configuration Tests
    // =========================================================================

    #[test]
    fn test_portals_config_deserialize() {
        let config = parse_portals(
            r#"
[[portals]]
name = "test-portal"
url = "https://example.com"
type = "ckan"
"#,
        );
        assert_eq!(config.portals.len(), 1);
        let portal = &config.portals[0];
        assert_eq!(portal.name, "test-portal");
        assert_eq!(portal.url, "https://example.com");
        assert_eq!(portal.portal_type, "ckan");
        // `enabled` was omitted, so the serde default applies.
        assert!(portal.enabled);
        assert!(portal.description.is_none());
    }

    #[test]
    fn test_portals_config_defaults() {
        let config = parse_portals(
            r#"
[[portals]]
name = "minimal"
url = "https://example.com"
"#,
        );
        let portal = &config.portals[0];
        // Both `type` and `enabled` fall back to their defaults.
        assert_eq!(portal.portal_type, "ckan");
        assert!(portal.enabled);
    }

    #[test]
    fn test_portals_config_enabled_filter() {
        let config = parse_portals(
            r#"
[[portals]]
name = "enabled-portal"
url = "https://a.com"

[[portals]]
name = "disabled-portal"
url = "https://b.com"
enabled = false
"#,
        );
        let enabled = config.enabled_portals();
        assert_eq!(enabled.len(), 1);
        assert_eq!(enabled[0].name, "enabled-portal");
    }

    #[test]
    fn test_portals_config_find_by_name() {
        let config = parse_portals(
            r#"
[[portals]]
name = "Milano"
url = "https://dati.comune.milano.it"
"#,
        );

        // Lookup ignores case.
        for query in ["milano", "MILANO", "Milano"] {
            assert!(config.find_by_name(query).is_some());
        }

        // Unknown names are not found.
        assert!(config.find_by_name("roma").is_none());
    }

    #[test]
    fn test_portals_config_with_description() {
        let config = parse_portals(
            r#"
[[portals]]
name = "test"
url = "https://example.com"
description = "A test portal"
"#,
        );
        assert_eq!(
            config.portals[0].description.as_deref(),
            Some("A test portal")
        );
    }

    #[test]
    fn test_portals_config_multiple_portals() {
        let config = parse_portals(
            r#"
[[portals]]
name = "portal-1"
url = "https://a.com"

[[portals]]
name = "portal-2"
url = "https://b.com"

[[portals]]
name = "portal-3"
url = "https://c.com"
enabled = false
"#,
        );
        assert_eq!(config.portals.len(), 3);
        assert_eq!(config.enabled_portals().len(), 2);
    }

    #[test]
    fn test_default_config_path() {
        // Only checks that the call doesn't panic; the directory is platform-specific.
        if let Some(path) = default_config_path() {
            assert!(path.ends_with("portals.toml"));
        }
    }

    // =========================================================================
    // load_portals_config() tests with real files
    // =========================================================================

    #[test]
    fn test_load_portals_config_valid_file() {
        let config = load_ok(format_args!(
            r#"
[[portals]]
name = "test"
url = "https://test.com"
"#
        ));

        assert_eq!(config.portals.len(), 1);
        let portal = &config.portals[0];
        assert_eq!(portal.name, "test");
        assert_eq!(portal.url, "https://test.com");
    }

    #[test]
    fn test_load_portals_config_custom_path_not_found() {
        let missing = PathBuf::from("/nonexistent/path/to/config.toml");
        let err = load_portals_config(Some(missing)).unwrap_err();
        assert!(matches!(err, AppError::ConfigError(_)));
    }

    #[test]
    fn test_load_portals_config_invalid_toml() {
        let err = load_from_temp_file(format_args!("this is not valid toml {{{{")).unwrap_err();
        assert!(matches!(err, AppError::ConfigError(_)));
    }

    #[test]
    fn test_load_portals_config_multiple_portals_with_enabled_filter() {
        let config = load_ok(format_args!(
            r#"
[[portals]]
name = "enabled-portal"
url = "https://a.com"

[[portals]]
name = "disabled-portal"
url = "https://b.com"
enabled = false

[[portals]]
name = "another-enabled"
url = "https://c.com"
enabled = true
"#
        ));

        assert_eq!(config.portals.len(), 3);
        assert_eq!(config.enabled_portals().len(), 2);
    }

    #[test]
    fn test_load_portals_config_with_all_fields() {
        let config = load_ok(format_args!(
            r#"
[[portals]]
name = "full-config"
url = "https://example.com"
type = "ckan"
enabled = true
description = "A fully configured portal"
"#
        ));

        let portal = &config.portals[0];
        assert_eq!(portal.name, "full-config");
        assert_eq!(portal.url, "https://example.com");
        assert_eq!(portal.portal_type, "ckan");
        assert!(portal.enabled);
        assert_eq!(
            portal.description.as_deref(),
            Some("A fully configured portal")
        );
    }

    #[test]
    fn test_load_portals_config_empty_portals_array() {
        let config = load_ok(format_args!("portals = []"));

        assert!(config.portals.is_empty());
        assert!(config.enabled_portals().is_empty());
    }

    // =========================================================================
    // Embedding Provider Configuration Tests
    // =========================================================================

    #[test]
    fn test_embedding_provider_type_from_str() {
        let cases = [
            ("gemini", EmbeddingProviderType::Gemini),
            ("openai", EmbeddingProviderType::OpenAI),
            ("GEMINI", EmbeddingProviderType::Gemini),
            ("OpenAI", EmbeddingProviderType::OpenAI),
        ];
        for (input, expected) in cases {
            assert_eq!(input.parse::<EmbeddingProviderType>().unwrap(), expected);
        }
    }

    #[test]
    fn test_embedding_provider_type_invalid() {
        assert!("invalid".parse::<EmbeddingProviderType>().is_err());
    }

    #[test]
    fn test_embedding_provider_type_display() {
        let cases = [
            (EmbeddingProviderType::Gemini, "gemini"),
            (EmbeddingProviderType::OpenAI, "openai"),
        ];
        for (provider, expected) in cases {
            assert_eq!(provider.to_string(), expected);
        }
    }

    #[test]
    fn test_embedding_dimension() {
        use EmbeddingProviderType::{Gemini, OpenAI};

        let cases: [(EmbeddingProviderType, Option<&str>, usize); 5] = [
            // Gemini is always 768.
            (Gemini, None, 768),
            (Gemini, Some("text-embedding-004"), 768),
            // OpenAI defaults to 1536.
            (OpenAI, None, 1536),
            (OpenAI, Some("text-embedding-3-small"), 1536),
            (OpenAI, Some("text-embedding-3-large"), 3072),
        ];
        for (provider, model, expected) in cases {
            assert_eq!(embedding_dimension(provider, model), expected);
        }
    }

    #[test]
    fn test_gemini_embedding_config_default() {
        assert_eq!(GeminiEmbeddingConfig::default().model, "text-embedding-004");
    }

    #[test]
    fn test_openai_embedding_config_default() {
        let OpenAIEmbeddingConfig { model, endpoint } = OpenAIEmbeddingConfig::default();
        assert_eq!(model, "text-embedding-3-small");
        assert!(endpoint.is_none());
    }
}