Skip to main content

ceres_core/
config.rs

1//! Configuration types for Ceres components.
2//!
3//! # Configuration Improvements
4//!
5//! TODO(config): Make all configuration values environment-configurable
6//! Currently all defaults are hardcoded. Should support:
7//! - `DB_MAX_CONNECTIONS` for database pool size
8//! - `SYNC_CONCURRENCY` for parallel dataset processing
9//! - `HTTP_TIMEOUT` for API request timeout
10//! - `HTTP_MAX_RETRIES` for retry attempts
11//!
12//! Consider using the `config` crate for layered configuration:
13//! defaults -> config file -> environment variables -> CLI args
14
15use serde::{Deserialize, Serialize};
16use std::fmt;
17use std::path::{Path, PathBuf};
18use std::str::FromStr;
19use std::time::Duration;
20
21use crate::circuit_breaker::CircuitBreakerConfig;
22use crate::error::AppError;
23
24// =============================================================================
25// Embedding Provider Configuration
26// =============================================================================
27
28/// Embedding provider type.
29///
30/// Determines which embedding API to use for generating text embeddings.
31#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
32#[serde(rename_all = "lowercase")]
33pub enum EmbeddingProviderType {
34    /// Google Gemini gemini-embedding-001 (768 dimensions).
35    #[default]
36    Gemini,
37    /// OpenAI text-embedding-3-small (1536d) or text-embedding-3-large (3072d).
38    OpenAI,
39}
40
41impl fmt::Display for EmbeddingProviderType {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        match self {
44            Self::Gemini => write!(f, "gemini"),
45            Self::OpenAI => write!(f, "openai"),
46        }
47    }
48}
49
50impl FromStr for EmbeddingProviderType {
51    type Err = AppError;
52
53    fn from_str(s: &str) -> Result<Self, Self::Err> {
54        match s.to_lowercase().as_str() {
55            "gemini" => Ok(Self::Gemini),
56            "openai" => Ok(Self::OpenAI),
57            _ => Err(AppError::ConfigError(format!(
58                "Unknown embedding provider: '{}'. Valid options: gemini, openai",
59                s
60            ))),
61        }
62    }
63}
64
65/// Gemini embedding provider configuration.
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct GeminiEmbeddingConfig {
68    /// Gemini model name.
69    #[serde(default = "default_gemini_model")]
70    pub model: String,
71}
72
73fn default_gemini_model() -> String {
74    "gemini-embedding-001".to_string()
75}
76
77impl Default for GeminiEmbeddingConfig {
78    fn default() -> Self {
79        Self {
80            model: default_gemini_model(),
81        }
82    }
83}
84
85/// OpenAI embedding provider configuration.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct OpenAIEmbeddingConfig {
88    /// OpenAI model name.
89    #[serde(default = "default_openai_model")]
90    pub model: String,
91    /// Custom API endpoint (for Azure OpenAI or proxies).
92    pub endpoint: Option<String>,
93}
94
95fn default_openai_model() -> String {
96    "text-embedding-3-small".to_string()
97}
98
99impl Default for OpenAIEmbeddingConfig {
100    fn default() -> Self {
101        Self {
102            model: default_openai_model(),
103            endpoint: None,
104        }
105    }
106}
107
108/// Returns the embedding dimension for a given provider and model.
109///
110/// # Arguments
111///
112/// * `provider` - The embedding provider type
113/// * `model` - The model name (optional, uses default if None)
114pub fn embedding_dimension(provider: EmbeddingProviderType, model: Option<&str>) -> usize {
115    match provider {
116        EmbeddingProviderType::Gemini => 768, // gemini-embedding-001 with output_dimensionality=768
117        EmbeddingProviderType::OpenAI => match model.unwrap_or("text-embedding-3-small") {
118            "text-embedding-3-large" => 3072,
119            _ => 1536, // text-embedding-3-small and ada-002
120        },
121    }
122}
123
/// Database connection pool configuration.
///
/// Derives `Debug` so it can be logged and `Clone`/`PartialEq` so callers can
/// copy and compare configurations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DbConfig {
    /// Maximum number of pooled database connections (default: 5).
    pub max_connections: u32,
}

impl Default for DbConfig {
    fn default() -> Self {
        Self { max_connections: 5 }
    }
}
134
/// HTTP client configuration for external API calls.
///
/// Derives `Debug` so it can be logged and `Clone`/`PartialEq` so callers can
/// copy and compare configurations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpConfig {
    /// Per-request timeout (default: 30 s).
    pub timeout: Duration,
    /// Maximum number of retry attempts (default: 3).
    pub max_retries: u32,
    /// Base delay used when retrying (default: 500 ms); how it scales between
    /// attempts is decided by the HTTP client using this config.
    pub retry_base_delay: Duration,
}

impl Default for HttpConfig {
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(30),
            max_retries: 3,
            retry_base_delay: Duration::from_millis(500),
        }
    }
}
151
152/// Portal synchronization configuration.
153///
154/// TODO(config): Support CLI arg `--concurrency` and env var `SYNC_CONCURRENCY`
155/// Optimal value depends on portal rate limits and system resources.
156/// Consider auto-tuning based on API response times.
157#[derive(Clone)]
158pub struct SyncConfig {
159    /// Number of concurrent dataset processing tasks.
160    pub concurrency: usize,
161    /// Maximum number of texts per embedding API batch call.
162    /// The actual batch size is `min(this, provider.max_batch_size())`.
163    pub embedding_batch_size: usize,
164    /// Force full sync even if incremental sync is available.
165    pub force_full_sync: bool,
166    /// Preview mode: fetch and compare datasets without writing to DB or calling embedding API.
167    pub dry_run: bool,
168    /// Circuit breaker configuration for API resilience.
169    pub circuit_breaker: CircuitBreakerConfig,
170}
171
172impl Default for SyncConfig {
173    fn default() -> Self {
174        // TODO(config): Read from SYNC_CONCURRENCY env var
175        Self {
176            concurrency: 10,
177            embedding_batch_size: 64,
178            force_full_sync: false,
179            dry_run: false,
180            circuit_breaker: CircuitBreakerConfig::default(),
181        }
182    }
183}
184
185impl SyncConfig {
186    /// Creates a new SyncConfig with force_full_sync enabled.
187    pub fn with_full_sync(mut self) -> Self {
188        self.force_full_sync = true;
189        self
190    }
191
192    /// Creates a new SyncConfig with dry_run enabled.
193    pub fn with_dry_run(mut self) -> Self {
194        self.dry_run = true;
195        self
196    }
197
198    /// Creates a new SyncConfig with a custom embedding batch size.
199    pub fn with_embedding_batch_size(mut self, size: usize) -> Self {
200        self.embedding_batch_size = size.max(1);
201        self
202    }
203
204    /// Creates a new SyncConfig with custom circuit breaker configuration.
205    pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
206        self.circuit_breaker = config;
207        self
208    }
209}
210
211// =============================================================================
212// Portal Configuration (portals.toml)
213// =============================================================================
214
215/// Portal type identifier.
216///
217/// Determines which portal API client to use for harvesting.
218#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
219#[serde(rename_all = "lowercase")]
220pub enum PortalType {
221    /// CKAN open data portal (default).
222    #[default]
223    Ckan,
224    /// Socrata open data portal (US cities: NYC, Chicago, SF).
225    Socrata,
226    /// DCAT-AP / SPARQL endpoint (EU portals, data.europa.eu).
227    Dcat,
228}
229
230impl fmt::Display for PortalType {
231    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
232        match self {
233            Self::Ckan => write!(f, "ckan"),
234            Self::Socrata => write!(f, "socrata"),
235            Self::Dcat => write!(f, "dcat"),
236        }
237    }
238}
239
240impl FromStr for PortalType {
241    type Err = AppError;
242
243    fn from_str(s: &str) -> Result<Self, Self::Err> {
244        match s.to_lowercase().as_str() {
245            "ckan" => Ok(Self::Ckan),
246            "socrata" => Ok(Self::Socrata),
247            "dcat" => Ok(Self::Dcat),
248            _ => Err(AppError::ConfigError(format!(
249                "Unknown portal type: '{}'. Valid options: ckan, socrata, dcat",
250                s
251            ))),
252        }
253    }
254}
255
/// Serde default for `PortalEntry::enabled`: an entry without an explicit
/// `enabled` key takes part in batch harvesting.
const fn default_enabled() -> bool {
    true
}
260
261/// Root configuration structure for portals.toml.
262///
263/// This structure represents the entire configuration file containing
264/// an array of portal definitions.
265///
266/// # Example
267///
268/// ```toml
269/// [[portals]]
270/// name = "dati-gov-it"
271/// url = "https://dati.gov.it"
272/// type = "ckan"
273/// description = "Italian national open data portal"
274///
275/// [[portals]]
276/// name = "milano"
277/// url = "https://dati.comune.milano.it"
278/// enabled = true
279/// ```
280#[derive(Debug, Clone, Serialize, Deserialize)]
281pub struct PortalsConfig {
282    /// Array of portal configurations.
283    pub portals: Vec<PortalEntry>,
284}
285
286impl PortalsConfig {
287    /// Returns only enabled portals.
288    ///
289    /// Portals with `enabled = false` are excluded from batch harvesting.
290    pub fn enabled_portals(&self) -> Vec<&PortalEntry> {
291        self.portals.iter().filter(|p| p.enabled).collect()
292    }
293
294    /// Find a portal by name (case-insensitive).
295    ///
296    /// # Arguments
297    /// * `name` - The portal name to search for.
298    ///
299    /// # Returns
300    /// The matching portal entry, or None if not found.
301    pub fn find_by_name(&self, name: &str) -> Option<&PortalEntry> {
302        self.portals
303            .iter()
304            .find(|p| p.name.eq_ignore_ascii_case(name))
305    }
306}
307
308/// A single portal entry in the configuration file.
309///
310/// Each portal entry defines a CKAN portal to harvest, including
311/// its URL, type, and whether it's enabled for batch harvesting.
312#[derive(Debug, Clone, Serialize, Deserialize)]
313pub struct PortalEntry {
314    /// Human-readable portal name.
315    ///
316    /// Used for `--portal <name>` lookup and logging.
317    pub name: String,
318
319    /// Base URL of the CKAN portal.
320    ///
321    /// Example: "<https://dati.comune.milano.it>"
322    pub url: String,
323
324    /// Portal type: ckan, socrata, or dcat.
325    ///
326    /// Defaults to `Ckan` if not specified.
327    #[serde(rename = "type", default)]
328    pub portal_type: PortalType,
329
330    /// Whether this portal is enabled for batch harvesting.
331    ///
332    /// Defaults to `true` if not specified.
333    #[serde(default = "default_enabled")]
334    pub enabled: bool,
335
336    /// Optional description of the portal.
337    pub description: Option<String>,
338
339    /// Optional URL template for dataset landing pages.
340    ///
341    /// Supports placeholders:
342    /// - `{id}` — dataset UUID from the CKAN API
343    /// - `{name}` — dataset slug/name
344    ///
345    /// If not set, defaults to `{portal_url}/dataset/{name}`.
346    pub url_template: Option<String>,
347
348    /// Preferred language for multilingual portals (e.g., `"en"`, `"de"`, `"fr"`).
349    ///
350    /// Some portals return title and description as language-keyed objects.
351    /// This field controls which language is selected when resolving those fields.
352    /// Defaults to `"en"` when not specified.
353    pub language: Option<String>,
354}
355
356impl PortalEntry {
357    /// Returns the preferred language, defaulting to `"en"`.
358    pub fn language(&self) -> &str {
359        self.language.as_deref().unwrap_or("en")
360    }
361}
362
363/// Default configuration file name.
364pub const CONFIG_FILE_NAME: &str = "portals.toml";
365
366/// Returns the default configuration directory path.
367///
368/// Uses XDG Base Directory specification: `~/.config/ceres/`
369pub fn default_config_dir() -> Option<PathBuf> {
370    dirs::config_dir().map(|p| p.join("ceres"))
371}
372
373/// Returns the default configuration file path.
374///
375/// Path: `~/.config/ceres/portals.toml`
376pub fn default_config_path() -> Option<PathBuf> {
377    default_config_dir().map(|p| p.join(CONFIG_FILE_NAME))
378}
379
/// Default template content for a new portals.toml file.
///
/// Besides usage notes written as TOML comments, it defines two CKAN portals:
/// the City of Milan and the Sicily Region. Neither sets `enabled`, so both
/// are enabled by default and `ceres harvest` works immediately without
/// manual configuration.
const DEFAULT_CONFIG_TEMPLATE: &str = r#"# Ceres Portal Configuration
#
# Usage:
#   ceres harvest                 # Harvest all enabled portals
#   ceres harvest --portal milano # Harvest specific portal by name
#   ceres harvest https://...     # Harvest single URL (ignores this file)
#
# Set enabled = false to skip a portal during batch harvest.
# Use url_template for portals with non-standard frontends:
#   url_template = "https://example.com/dataset?id={id}"
#   Placeholders: {id} = dataset UUID, {name} = dataset slug

# City of Milan open data
[[portals]]
name = "milano"
url = "https://dati.comune.milano.it"
type = "ckan"
description = "Open data del Comune di Milano"

# Sicily Region open data
[[portals]]
name = "sicilia"
url = "https://dati.regione.sicilia.it"
type = "ckan"
description = "Open data della Regione Siciliana"
"#;
410
411/// Load portal configuration from a TOML file.
412///
413/// # Arguments
414/// * `path` - Optional custom path. If `None`, uses default XDG path.
415///
416/// # Returns
417/// * `Ok(Some(config))` - Configuration loaded successfully
418/// * `Ok(None)` - No configuration file found (not an error for backward compatibility)
419/// * `Err(e)` - Configuration file exists but is invalid
420///
421/// # Behavior
422/// If no configuration file exists at the default path, a template file
423/// is automatically created to help users get started.
424pub fn load_portals_config(path: Option<PathBuf>) -> Result<Option<PortalsConfig>, AppError> {
425    let using_default_path = path.is_none();
426    let config_path = match path {
427        Some(p) => p,
428        None => match default_config_path() {
429            Some(p) => p,
430            None => return Ok(None),
431        },
432    };
433
434    if !config_path.exists() {
435        // Auto-create template if using default path
436        if using_default_path {
437            match create_default_config(&config_path) {
438                Ok(()) => {
439                    // Template created successfully - read it and return the config
440                    // This allows the user to immediately harvest without re-running
441                    tracing::info!(
442                        "Config file created at {}. Starting harvest with default portals...",
443                        config_path.display()
444                    );
445                    // Continue to read the newly created file below
446                }
447                Err(e) => {
448                    // Log warning but don't fail - user might not have write permissions
449                    tracing::warn!("Could not create default config template: {}", e);
450                    return Ok(None);
451                }
452            }
453        } else {
454            // Custom path specified but doesn't exist - that's an error
455            return Err(AppError::ConfigError(format!(
456                "Config file not found: {}",
457                config_path.display()
458            )));
459        }
460    }
461
462    let content = std::fs::read_to_string(&config_path).map_err(|e| {
463        AppError::ConfigError(format!(
464            "Failed to read config file '{}': {}",
465            config_path.display(),
466            e
467        ))
468    })?;
469
470    let config: PortalsConfig = toml::from_str(&content).map_err(|e| {
471        AppError::ConfigError(format!(
472            "Invalid TOML in '{}': {}",
473            config_path.display(),
474            e
475        ))
476    })?;
477
478    Ok(Some(config))
479}
480
481/// Create a default configuration file with a template.
482///
483/// Creates the parent directory if it doesn't exist.
484///
485/// # Arguments
486/// * `path` - The path where the config file should be created.
487fn create_default_config(path: &Path) -> std::io::Result<()> {
488    // Create parent directory if needed
489    if let Some(parent) = path.parent() {
490        std::fs::create_dir_all(parent)?;
491    }
492
493    std::fs::write(path, DEFAULT_CONFIG_TEMPLATE)?;
494    tracing::info!("Created default config template at: {}", path.display());
495
496    Ok(())
497}
498
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_db_config_defaults() {
        let config = DbConfig::default();
        assert_eq!(config.max_connections, 5);
    }

    #[test]
    fn test_http_config_defaults() {
        let config = HttpConfig::default();
        assert_eq!(config.timeout, Duration::from_secs(30));
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.retry_base_delay, Duration::from_millis(500));
    }

    #[test]
    fn test_sync_config_defaults() {
        let config = SyncConfig::default();
        assert_eq!(config.concurrency, 10);
        assert_eq!(config.embedding_batch_size, 64);
        assert!(!config.force_full_sync);
        assert!(!config.dry_run);
    }

    #[test]
    fn test_sync_config_builders() {
        let config = SyncConfig::default()
            .with_full_sync()
            .with_dry_run()
            .with_embedding_batch_size(0);
        assert!(config.force_full_sync);
        assert!(config.dry_run);
        // A zero batch size is clamped so every batch holds at least one text.
        assert_eq!(config.embedding_batch_size, 1);
        assert_eq!(
            SyncConfig::default()
                .with_embedding_batch_size(16)
                .embedding_batch_size,
            16
        );
    }

    // =========================================================================
    // Portal Configuration Tests
    // =========================================================================

    #[test]
    fn test_portals_config_deserialize() {
        let toml = r#"
[[portals]]
name = "test-portal"
url = "https://example.com"
type = "ckan"
"#;
        let config: PortalsConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.portals.len(), 1);
        assert_eq!(config.portals[0].name, "test-portal");
        assert_eq!(config.portals[0].url, "https://example.com");
        assert_eq!(config.portals[0].portal_type, PortalType::Ckan);
        assert!(config.portals[0].enabled); // default
        assert!(config.portals[0].description.is_none());
    }

    #[test]
    fn test_portals_config_defaults() {
        let toml = r#"
[[portals]]
name = "minimal"
url = "https://example.com"
"#;
        let config: PortalsConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.portals[0].portal_type, PortalType::Ckan); // default type
        assert!(config.portals[0].enabled); // default enabled
    }

    #[test]
    fn test_portals_config_enabled_filter() {
        let toml = r#"
[[portals]]
name = "enabled-portal"
url = "https://a.com"

[[portals]]
name = "disabled-portal"
url = "https://b.com"
enabled = false
"#;
        let config: PortalsConfig = toml::from_str(toml).unwrap();
        let enabled = config.enabled_portals();
        assert_eq!(enabled.len(), 1);
        assert_eq!(enabled[0].name, "enabled-portal");
    }

    #[test]
    fn test_portals_config_find_by_name() {
        let toml = r#"
[[portals]]
name = "Milano"
url = "https://dati.comune.milano.it"
"#;
        let config: PortalsConfig = toml::from_str(toml).unwrap();

        // Case-insensitive search
        assert!(config.find_by_name("milano").is_some());
        assert!(config.find_by_name("MILANO").is_some());
        assert!(config.find_by_name("Milano").is_some());

        // Not found
        assert!(config.find_by_name("roma").is_none());
    }

    #[test]
    fn test_portals_config_with_description() {
        let toml = r#"
[[portals]]
name = "test"
url = "https://example.com"
description = "A test portal"
"#;
        let config: PortalsConfig = toml::from_str(toml).unwrap();
        assert_eq!(
            config.portals[0].description,
            Some("A test portal".to_string())
        );
    }

    #[test]
    fn test_portals_config_multiple_portals() {
        let toml = r#"
[[portals]]
name = "portal-1"
url = "https://a.com"

[[portals]]
name = "portal-2"
url = "https://b.com"

[[portals]]
name = "portal-3"
url = "https://c.com"
enabled = false
"#;
        let config: PortalsConfig = toml::from_str(toml).unwrap();
        assert_eq!(config.portals.len(), 3);
        assert_eq!(config.enabled_portals().len(), 2);
    }

    #[test]
    fn test_portal_entry_optional_fields() {
        let toml = r#"
[[portals]]
name = "swiss"
url = "https://opendata.swiss"
type = "ckan"
url_template = "https://opendata.swiss/dataset?id={id}"
language = "de"

[[portals]]
name = "plain"
url = "https://example.com"
type = "socrata"
"#;
        let config: PortalsConfig = toml::from_str(toml).unwrap();

        let swiss = &config.portals[0];
        assert_eq!(
            swiss.url_template.as_deref(),
            Some("https://opendata.swiss/dataset?id={id}")
        );
        assert_eq!(swiss.language(), "de");

        let plain = &config.portals[1];
        assert_eq!(plain.portal_type, PortalType::Socrata);
        assert!(plain.url_template.is_none());
        assert!(plain.language.is_none());
        assert_eq!(plain.language(), "en"); // accessor fallback
    }

    #[test]
    fn test_portal_type_from_str_and_display() {
        for &(input, expected) in &[
            ("ckan", PortalType::Ckan),
            ("Socrata", PortalType::Socrata),
            ("DCAT", PortalType::Dcat),
        ] {
            let parsed: PortalType = input.parse().unwrap();
            assert_eq!(parsed, expected);
            assert_eq!(parsed.to_string(), input.to_lowercase());
        }
        assert!(matches!(
            "arcgis".parse::<PortalType>(),
            Err(AppError::ConfigError(_))
        ));
    }

    #[test]
    fn test_default_config_path() {
        // This test just verifies the function doesn't panic
        // Actual path depends on the platform
        let path = default_config_path();
        if let Some(p) = path {
            assert!(p.ends_with("portals.toml"));
        }
    }

    // =========================================================================
    // load_portals_config() tests with real files
    // =========================================================================

    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_load_portals_config_valid_file() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(
            file,
            r#"
[[portals]]
name = "test"
url = "https://test.com"
"#
        )
        .unwrap();

        let config = load_portals_config(Some(file.path().to_path_buf()))
            .unwrap()
            .unwrap();

        assert_eq!(config.portals.len(), 1);
        assert_eq!(config.portals[0].name, "test");
        assert_eq!(config.portals[0].url, "https://test.com");
    }

    #[test]
    fn test_load_portals_config_custom_path_not_found() {
        let result = load_portals_config(Some("/nonexistent/path/to/config.toml".into()));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AppError::ConfigError(_)));
    }

    #[test]
    fn test_load_portals_config_invalid_toml() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "this is not valid toml {{{{").unwrap();

        let result = load_portals_config(Some(file.path().to_path_buf()));
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, AppError::ConfigError(_)));
    }

    #[test]
    fn test_load_portals_config_multiple_portals_with_enabled_filter() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(
            file,
            r#"
[[portals]]
name = "enabled-portal"
url = "https://a.com"

[[portals]]
name = "disabled-portal"
url = "https://b.com"
enabled = false

[[portals]]
name = "another-enabled"
url = "https://c.com"
enabled = true
"#
        )
        .unwrap();

        let config = load_portals_config(Some(file.path().to_path_buf()))
            .unwrap()
            .unwrap();

        assert_eq!(config.portals.len(), 3);
        assert_eq!(config.enabled_portals().len(), 2);
    }

    #[test]
    fn test_load_portals_config_with_all_fields() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(
            file,
            r#"
[[portals]]
name = "full-config"
url = "https://example.com"
type = "ckan"
enabled = true
description = "A fully configured portal"
"#
        )
        .unwrap();

        let config = load_portals_config(Some(file.path().to_path_buf()))
            .unwrap()
            .unwrap();

        let portal = &config.portals[0];
        assert_eq!(portal.name, "full-config");
        assert_eq!(portal.url, "https://example.com");
        assert_eq!(portal.portal_type, PortalType::Ckan);
        assert!(portal.enabled);
        assert_eq!(
            portal.description,
            Some("A fully configured portal".to_string())
        );
    }

    #[test]
    fn test_load_portals_config_empty_portals_array() {
        let mut file = NamedTempFile::new().unwrap();
        writeln!(file, "portals = []").unwrap();

        let config = load_portals_config(Some(file.path().to_path_buf()))
            .unwrap()
            .unwrap();

        assert!(config.portals.is_empty());
        assert!(config.enabled_portals().is_empty());
    }

    // =========================================================================
    // Embedding Provider Configuration Tests
    // =========================================================================

    #[test]
    fn test_embedding_provider_type_from_str() {
        assert_eq!(
            "gemini".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::Gemini
        );
        assert_eq!(
            "openai".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::OpenAI
        );
        assert_eq!(
            "GEMINI".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::Gemini
        );
        assert_eq!(
            "OpenAI".parse::<EmbeddingProviderType>().unwrap(),
            EmbeddingProviderType::OpenAI
        );
    }

    #[test]
    fn test_embedding_provider_type_invalid() {
        let result = "invalid".parse::<EmbeddingProviderType>();
        assert!(result.is_err());
    }

    #[test]
    fn test_embedding_provider_type_display() {
        assert_eq!(EmbeddingProviderType::Gemini.to_string(), "gemini");
        assert_eq!(EmbeddingProviderType::OpenAI.to_string(), "openai");
    }

    #[test]
    fn test_embedding_dimension() {
        // Gemini is always 768
        assert_eq!(
            embedding_dimension(EmbeddingProviderType::Gemini, None),
            768
        );
        assert_eq!(
            embedding_dimension(EmbeddingProviderType::Gemini, Some("gemini-embedding-001")),
            768
        );

        // OpenAI defaults to 1536
        assert_eq!(
            embedding_dimension(EmbeddingProviderType::OpenAI, None),
            1536
        );
        assert_eq!(
            embedding_dimension(
                EmbeddingProviderType::OpenAI,
                Some("text-embedding-3-small")
            ),
            1536
        );
        assert_eq!(
            embedding_dimension(
                EmbeddingProviderType::OpenAI,
                Some("text-embedding-3-large")
            ),
            3072
        );
    }

    #[test]
    fn test_gemini_embedding_config_default() {
        let config = GeminiEmbeddingConfig::default();
        assert_eq!(config.model, "gemini-embedding-001");
    }

    #[test]
    fn test_openai_embedding_config_default() {
        let config = OpenAIEmbeddingConfig::default();
        assert_eq!(config.model, "text-embedding-3-small");
        assert!(config.endpoint.is_none());
    }
}