use serde::{Deserialize, Serialize};
use std::fmt;
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::time::Duration;
use crate::circuit_breaker::CircuitBreakerConfig;
use crate::error::AppError;
/// Selects which embedding provider backend is used.
///
/// Serialized in lowercase (`"gemini"`, `"openai"`, `"ollama"`); `Display`
/// emits the same spellings and `FromStr` accepts them case-insensitively.
/// `Gemini` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingProviderType {
#[default]
Gemini,
OpenAI,
Ollama,
}
impl fmt::Display for EmbeddingProviderType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Gemini => write!(f, "gemini"),
Self::OpenAI => write!(f, "openai"),
Self::Ollama => write!(f, "ollama"),
}
}
}
impl FromStr for EmbeddingProviderType {
type Err = AppError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"gemini" => Ok(Self::Gemini),
"openai" => Ok(Self::OpenAI),
"ollama" => Ok(Self::Ollama),
_ => Err(AppError::ConfigError(format!(
"Unknown embedding provider: '{}'. Valid options: gemini, openai, ollama",
s
))),
}
}
}
/// Settings for the Gemini embedding provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeminiEmbeddingConfig {
/// Embedding model name; serde fills in `default_gemini_model()`
/// (`"gemini-embedding-001"`) when the key is missing.
#[serde(default = "default_gemini_model")]
pub model: String,
}
fn default_gemini_model() -> String {
"gemini-embedding-001".to_string()
}
impl Default for GeminiEmbeddingConfig {
fn default() -> Self {
Self {
model: default_gemini_model(),
}
}
}
/// Settings for the OpenAI embedding provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIEmbeddingConfig {
/// Embedding model name; serde fills in `default_openai_model()`
/// (`"text-embedding-3-small"`) when the key is missing.
#[serde(default = "default_openai_model")]
pub model: String,
/// Optional endpoint override (`None` by default).
/// NOTE(review): presumably `None` means the client's standard OpenAI
/// endpoint — confirm against the provider implementation.
pub endpoint: Option<String>,
}
fn default_openai_model() -> String {
"text-embedding-3-small".to_string()
}
impl Default for OpenAIEmbeddingConfig {
fn default() -> Self {
Self {
model: default_openai_model(),
endpoint: None,
}
}
}
/// Settings for the Ollama embedding provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OllamaEmbeddingConfig {
/// Embedding model name; serde fills in `default_ollama_model()`
/// (`"nomic-embed-text"`) when the key is missing.
#[serde(default = "default_ollama_model")]
pub model: String,
/// Ollama server URL; serde fills in `default_ollama_endpoint()`
/// (`"http://localhost:11434"`) when the key is missing.
#[serde(default = "default_ollama_endpoint")]
pub endpoint: String,
}
/// Default Ollama embedding model; also the serde default for
/// `OllamaEmbeddingConfig::model`.
fn default_ollama_model() -> String {
    String::from("nomic-embed-text")
}

/// Default Ollama server URL; also the serde default for
/// `OllamaEmbeddingConfig::endpoint`.
fn default_ollama_endpoint() -> String {
    String::from("http://localhost:11434")
}

impl Default for OllamaEmbeddingConfig {
    /// Default model served from the default local endpoint.
    fn default() -> Self {
        let (model, endpoint) = (default_ollama_model(), default_ollama_endpoint());
        Self { model, endpoint }
    }
}
/// Returns the length of the vectors produced by `provider` with `model`.
///
/// `model = None` selects the provider's default model (`text-embedding-3-small`
/// for OpenAI, `nomic-embed-text` for Ollama). Gemini always reports 768,
/// whatever `model` is. Unrecognised models fall back to 1536 (OpenAI) or
/// 768 (Ollama).
///
/// Ollama model references may carry a `:tag` suffix. The tag is ignored except
/// for `snowflake-arctic-embed`, whose tags select differently sized models:
/// `22m`/`33m`/`xs`/`s` map to 384, `110m`/`137m`/`m`/`m-long` map to 768, and
/// anything else (`335m`, `l`, `latest`, no tag) maps to 1024. Quantization
/// suffixes such as `-fp16` after the size are ignored.
pub fn embedding_dimension(provider: EmbeddingProviderType, model: Option<&str>) -> usize {
    match provider {
        EmbeddingProviderType::Gemini => 768,
        EmbeddingProviderType::OpenAI => match model.unwrap_or("text-embedding-3-small") {
            "text-embedding-3-large" => 3072,
            _ => 1536,
        },
        EmbeddingProviderType::Ollama => {
            let reference = model.unwrap_or("nomic-embed-text");
            // `name[:tag]` — the name is everything before the first ':'.
            let (name, tag) = match reference.split_once(':') {
                Some((name, tag)) => (name, Some(tag)),
                None => (reference, None),
            };
            // Leading size component of the tag, e.g. "22m" in "22m-xs-fp16"
            // or "m" in "m-long".
            let size = tag.and_then(|t| t.split('-').next());
            match (name, size) {
                // The smaller arctic-embed variants are distinct models with
                // narrower outputs; ignoring the tag would report 1024 for them.
                ("snowflake-arctic-embed", Some("22m" | "33m" | "xs" | "s")) => 384,
                ("snowflake-arctic-embed", Some("110m" | "137m" | "m")) => 768,
                (
                    "mxbai-embed-large"
                    | "snowflake-arctic-embed"
                    | "snowflake-arctic-embed2"
                    | "bge-m3"
                    | "bge-large",
                    _,
                ) => 1024,
                ("all-minilm", _) => 384,
                _ => 768,
            }
        }
    }
}
/// Database settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DbConfig {
    /// Maximum number of database connections (default: 5).
    pub max_connections: u32,
}

impl Default for DbConfig {
    fn default() -> Self {
        Self { max_connections: 5 }
    }
}
/// HTTP client settings.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HttpConfig {
    /// Request timeout (default: 30 s).
    pub timeout: Duration,
    /// Maximum number of retries (default: 3).
    pub max_retries: u32,
    /// Base delay for retries (default: 500 ms).
    /// NOTE(review): the backoff scheme built on this is defined by the caller.
    pub retry_base_delay: Duration,
}

impl Default for HttpConfig {
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(30),
            max_retries: 3,
            retry_base_delay: Duration::from_millis(500),
        }
    }
}
/// Settings for the harvest stage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HarvestConfig {
    /// Concurrency level (default: 10).
    pub concurrency: usize,
    /// Number of records per upsert batch (default: 500).
    pub upsert_batch_size: usize,
    /// Force a full sync (default: false).
    pub force_full_sync: bool,
    /// Dry-run flag (default: false).
    pub dry_run: bool,
}

impl Default for HarvestConfig {
    fn default() -> Self {
        Self {
            concurrency: 10,
            upsert_batch_size: 500,
            force_full_sync: false,
            dry_run: false,
        }
    }
}

impl HarvestConfig {
    /// Returns this config with `force_full_sync` enabled.
    #[must_use]
    pub fn with_full_sync(mut self) -> Self {
        self.force_full_sync = true;
        self
    }

    /// Returns this config with `dry_run` enabled.
    #[must_use]
    pub fn with_dry_run(mut self) -> Self {
        self.dry_run = true;
        self
    }
}
/// Settings for the embedding service: batch size and circuit-breaker policy.
#[derive(Clone)]
pub struct EmbeddingServiceConfig {
    /// Embedding batch size (default: 64).
    pub batch_size: usize,
    /// Circuit-breaker settings (default: `CircuitBreakerConfig::default()`).
    pub circuit_breaker: CircuitBreakerConfig,
}

impl Default for EmbeddingServiceConfig {
    fn default() -> Self {
        let circuit_breaker = CircuitBreakerConfig::default();
        Self {
            batch_size: 64,
            circuit_breaker,
        }
    }
}
/// Settings for a sync run, covering both the harvest and embedding stages.
///
/// [`SyncConfig::harvest_config`] and [`SyncConfig::embedding_service_config`]
/// derive the per-stage configs from it.
#[derive(Clone)]
pub struct SyncConfig {
    /// Copied into `HarvestConfig::concurrency` (default: 10).
    pub concurrency: usize,
    /// Copied into `EmbeddingServiceConfig::batch_size` (default: 64).
    pub embedding_batch_size: usize,
    /// Copied into `HarvestConfig::upsert_batch_size` (default: 500).
    pub upsert_batch_size: usize,
    /// Copied into `HarvestConfig::force_full_sync` (default: false).
    pub force_full_sync: bool,
    /// Copied into `HarvestConfig::dry_run` (default: false).
    pub dry_run: bool,
    /// Copied into `EmbeddingServiceConfig::circuit_breaker`.
    pub circuit_breaker: CircuitBreakerConfig,
}

impl Default for SyncConfig {
    fn default() -> Self {
        Self {
            concurrency: 10,
            embedding_batch_size: 64,
            upsert_batch_size: 500,
            force_full_sync: false,
            dry_run: false,
            circuit_breaker: CircuitBreakerConfig::default(),
        }
    }
}

impl SyncConfig {
    /// Returns this config with `force_full_sync` enabled.
    #[must_use]
    pub fn with_full_sync(mut self) -> Self {
        self.force_full_sync = true;
        self
    }

    /// Returns this config with `dry_run` enabled.
    #[must_use]
    pub fn with_dry_run(mut self) -> Self {
        self.dry_run = true;
        self
    }

    /// Returns this config with the given embedding batch size.
    ///
    /// A size of 0 is raised to 1 so that batching always makes progress.
    #[must_use]
    pub fn with_embedding_batch_size(mut self, size: usize) -> Self {
        self.embedding_batch_size = size.max(1);
        self
    }

    /// Returns this config with the given circuit-breaker settings.
    #[must_use]
    pub fn with_circuit_breaker(mut self, config: CircuitBreakerConfig) -> Self {
        self.circuit_breaker = config;
        self
    }

    /// Builds the harvest-stage config from the shared fields.
    #[must_use]
    pub fn harvest_config(&self) -> HarvestConfig {
        HarvestConfig {
            concurrency: self.concurrency,
            upsert_batch_size: self.upsert_batch_size,
            force_full_sync: self.force_full_sync,
            dry_run: self.dry_run,
        }
    }

    /// Builds the embedding-service config (batch size and a clone of the
    /// circuit-breaker settings).
    #[must_use]
    pub fn embedding_service_config(&self) -> EmbeddingServiceConfig {
        EmbeddingServiceConfig {
            batch_size: self.embedding_batch_size,
            circuit_breaker: self.circuit_breaker.clone(),
        }
    }
}
/// Portal API flavour, selected by the `type` key of a `[[portals]]` entry.
///
/// Serialized in lowercase (`"ckan"`, `"socrata"`, `"dcat"`); `Display`
/// emits the same spellings and `FromStr` accepts them case-insensitively.
/// `Ckan` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum PortalType {
#[default]
Ckan,
Socrata,
Dcat,
}
impl fmt::Display for PortalType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Ckan => write!(f, "ckan"),
Self::Socrata => write!(f, "socrata"),
Self::Dcat => write!(f, "dcat"),
}
}
}
impl FromStr for PortalType {
type Err = AppError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"ckan" => Ok(Self::Ckan),
"socrata" => Ok(Self::Socrata),
"dcat" => Ok(Self::Dcat),
_ => Err(AppError::ConfigError(format!(
"Unknown portal type: '{}'. Valid options: ckan, socrata, dcat",
s
))),
}
}
}
/// Serde default for `PortalEntry::enabled`: a portal is enabled unless the
/// config file sets `enabled = false`.
fn default_enabled() -> bool {
true
}
/// Top-level contents of `portals.toml`: the list of `[[portals]]` entries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PortalsConfig {
/// All configured portals, in file order (enabled and disabled).
pub portals: Vec<PortalEntry>,
}
impl PortalsConfig {
    /// Returns the portals whose `enabled` flag is set, preserving file order.
    pub fn enabled_portals(&self) -> Vec<&PortalEntry> {
        let mut active = Vec::new();
        for portal in &self.portals {
            if portal.enabled {
                active.push(portal);
            }
        }
        active
    }

    /// Finds the first portal whose name equals `name`, ignoring ASCII case.
    pub fn find_by_name(&self, name: &str) -> Option<&PortalEntry> {
        for portal in &self.portals {
            if portal.name.eq_ignore_ascii_case(name) {
                return Some(portal);
            }
        }
        None
    }
}
/// One `[[portals]]` entry from `portals.toml`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PortalEntry {
/// Portal identifier; `PortalsConfig::find_by_name` matches it ignoring ASCII case.
pub name: String,
/// Base URL of the portal.
pub url: String,
/// Portal API type, read from the `type` key; defaults to `PortalType::Ckan`.
#[serde(rename = "type", default)]
pub portal_type: PortalType,
/// Whether batch harvest includes this portal; defaults to `true`.
#[serde(default = "default_enabled")]
pub enabled: bool,
/// Optional human-readable description.
pub description: Option<String>,
/// Optional dataset URL template for portals with non-standard frontends;
/// the default config template documents `{id}` and `{name}` placeholders.
pub url_template: Option<String>,
/// Optional portal language; `language()` falls back to `"en"`.
pub language: Option<String>,
/// Optional portal profile (tests use `"sparql"` with `type = "dcat"`).
#[serde(default)]
pub profile: Option<String>,
/// Optional SPARQL endpoint URL.
/// NOTE(review): presumably consulted by the `sparql` profile — confirm.
#[serde(default)]
pub sparql_endpoint: Option<String>,
}
impl PortalEntry {
    /// Configured language, or `"en"` when none is set.
    pub fn language(&self) -> &str {
        if let Some(lang) = self.language.as_deref() {
            return lang;
        }
        "en"
    }

    /// Configured profile, if any.
    pub fn profile(&self) -> Option<&str> {
        Option::as_deref(&self.profile)
    }

    /// Configured SPARQL endpoint, if any.
    pub fn sparql_endpoint(&self) -> Option<&str> {
        Option::as_deref(&self.sparql_endpoint)
    }
}
/// File name of the portal configuration inside the config directory.
pub const CONFIG_FILE_NAME: &str = "portals.toml";

/// The `ceres` subdirectory of the platform configuration directory.
///
/// Returns `None` when `dirs::config_dir()` cannot determine one.
pub fn default_config_dir() -> Option<PathBuf> {
    let base = dirs::config_dir()?;
    Some(base.join("ceres"))
}

/// Default location of the portal configuration file
/// (`<config dir>/ceres/portals.toml`), or `None` if unknown.
pub fn default_config_path() -> Option<PathBuf> {
    let dir = default_config_dir()?;
    Some(dir.join(CONFIG_FILE_NAME))
}
/// Contents written to `portals.toml` by `create_default_config` when no
/// config file exists at the default path.
///
/// Documents the file format in TOML comments and lists two CKAN portals
/// (`milano`, `sicilia`), both enabled by default.
const DEFAULT_CONFIG_TEMPLATE: &str = r#"# Ceres Portal Configuration
#
# Usage:
# ceres harvest # Harvest all enabled portals
# ceres harvest --portal milano # Harvest specific portal by name
# ceres harvest https://... # Harvest single URL (ignores this file)
#
# Set enabled = false to skip a portal during batch harvest.
# Use url_template for portals with non-standard frontends:
# url_template = "https://example.com/dataset?id={id}"
# Placeholders: {id} = dataset UUID, {name} = dataset slug
# City of Milan open data
[[portals]]
name = "milano"
url = "https://dati.comune.milano.it"
type = "ckan"
description = "Open data del Comune di Milano"
# Sicily Region open data
[[portals]]
name = "sicilia"
url = "https://dati.regione.sicilia.it"
type = "ckan"
description = "Open data della Regione Siciliana"
"#;
pub fn load_portals_config(path: Option<PathBuf>) -> Result<Option<PortalsConfig>, AppError> {
let using_default_path = path.is_none();
let config_path = match path {
Some(p) => p,
None => match default_config_path() {
Some(p) => p,
None => return Ok(None),
},
};
if !config_path.exists() {
if using_default_path {
match create_default_config(&config_path) {
Ok(()) => {
tracing::info!(
"Config file created at {}. Starting harvest with default portals...",
config_path.display()
);
}
Err(e) => {
tracing::warn!("Could not create default config template: {}", e);
return Ok(None);
}
}
} else {
return Err(AppError::ConfigError(format!(
"Config file not found: {}",
config_path.display()
)));
}
}
let content = std::fs::read_to_string(&config_path).map_err(|e| {
AppError::ConfigError(format!(
"Failed to read config file '{}': {}",
config_path.display(),
e
))
})?;
let config: PortalsConfig = toml::from_str(&content).map_err(|e| {
AppError::ConfigError(format!(
"Invalid TOML in '{}': {}",
config_path.display(),
e
))
})?;
Ok(Some(config))
}
/// Writes `DEFAULT_CONFIG_TEMPLATE` to `path`, creating parent directories.
///
/// The file is opened with `create_new`, so an existing file is never
/// overwritten. This covers a file created by a concurrent process after the
/// caller's existence check, or a path that is a (possibly dangling) symlink.
/// In that case this returns `Ok(())` without writing anything.
///
/// # Errors
/// Propagates I/O errors from creating the parent directories or from
/// creating/writing the file, except `AlreadyExists`.
fn create_default_config(path: &Path) -> std::io::Result<()> {
    use std::io::Write;

    if let Some(parent) = path.parent() {
        std::fs::create_dir_all(parent)?;
    }
    let mut file = match std::fs::OpenOptions::new()
        .write(true)
        .create_new(true)
        .open(path)
    {
        Ok(file) => file,
        // Someone else created the file first; keep their copy.
        Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => return Ok(()),
        Err(e) => return Err(e),
    };
    file.write_all(DEFAULT_CONFIG_TEMPLATE.as_bytes())?;
    tracing::info!("Created default config template at: {}", path.display());
    Ok(())
}
// Unit tests for config defaults, portal file parsing and loading, and the
// embedding provider helpers.
#[cfg(test)]
mod tests {
use super::*;
// Default values of the plain runtime config structs.
#[test]
fn test_db_config_defaults() {
let config = DbConfig::default();
assert_eq!(config.max_connections, 5);
}
#[test]
fn test_http_config_defaults() {
let config = HttpConfig::default();
assert_eq!(config.timeout, Duration::from_secs(30));
assert_eq!(config.max_retries, 3);
assert_eq!(config.retry_base_delay, Duration::from_millis(500));
}
#[test]
fn test_sync_config_defaults() {
let config = SyncConfig::default();
assert_eq!(config.concurrency, 10);
assert_eq!(config.upsert_batch_size, 500);
}
// The harvest stage must receive the upsert batch size, not the embedding one.
#[test]
fn test_sync_config_harvest_config_upsert_batch_size() {
let config = SyncConfig::default();
let harvest = config.harvest_config();
assert_eq!(harvest.upsert_batch_size, 500);
assert_ne!(harvest.upsert_batch_size, config.embedding_batch_size);
}
// Parsing `portals.toml` content directly with `toml::from_str`.
#[test]
fn test_portals_config_deserialize() {
let toml = r#"
[[portals]]
name = "test-portal"
url = "https://example.com"
type = "ckan"
"#;
let config: PortalsConfig = toml::from_str(toml).unwrap();
assert_eq!(config.portals.len(), 1);
assert_eq!(config.portals[0].name, "test-portal");
assert_eq!(config.portals[0].url, "https://example.com");
assert_eq!(config.portals[0].portal_type, PortalType::Ckan);
assert!(config.portals[0].enabled); assert!(config.portals[0].description.is_none());
}
// Omitted `type` and `enabled` fall back to ckan / true.
#[test]
fn test_portals_config_defaults() {
let toml = r#"
[[portals]]
name = "minimal"
url = "https://example.com"
"#;
let config: PortalsConfig = toml::from_str(toml).unwrap();
assert_eq!(config.portals[0].portal_type, PortalType::Ckan); assert!(config.portals[0].enabled); }
#[test]
fn test_portals_config_enabled_filter() {
let toml = r#"
[[portals]]
name = "enabled-portal"
url = "https://a.com"
[[portals]]
name = "disabled-portal"
url = "https://b.com"
enabled = false
"#;
let config: PortalsConfig = toml::from_str(toml).unwrap();
let enabled = config.enabled_portals();
assert_eq!(enabled.len(), 1);
assert_eq!(enabled[0].name, "enabled-portal");
}
// Name lookup ignores ASCII case.
#[test]
fn test_portals_config_find_by_name() {
let toml = r#"
[[portals]]
name = "Milano"
url = "https://dati.comune.milano.it"
"#;
let config: PortalsConfig = toml::from_str(toml).unwrap();
assert!(config.find_by_name("milano").is_some());
assert!(config.find_by_name("MILANO").is_some());
assert!(config.find_by_name("Milano").is_some());
assert!(config.find_by_name("roma").is_none());
}
#[test]
fn test_portals_config_with_description() {
let toml = r#"
[[portals]]
name = "test"
url = "https://example.com"
description = "A test portal"
"#;
let config: PortalsConfig = toml::from_str(toml).unwrap();
assert_eq!(
config.portals[0].description,
Some("A test portal".to_string())
);
}
#[test]
fn test_portals_config_multiple_portals() {
let toml = r#"
[[portals]]
name = "portal-1"
url = "https://a.com"
[[portals]]
name = "portal-2"
url = "https://b.com"
[[portals]]
name = "portal-3"
url = "https://c.com"
enabled = false
"#;
let config: PortalsConfig = toml::from_str(toml).unwrap();
assert_eq!(config.portals.len(), 3);
assert_eq!(config.enabled_portals().len(), 2);
}
// Only checked when the platform exposes a config directory.
#[test]
fn test_default_config_path() {
let path = default_config_path();
if let Some(p) = path {
assert!(p.ends_with("portals.toml"));
}
}
// File-based `load_portals_config` tests using temporary files.
use std::io::Write;
use tempfile::NamedTempFile;
#[test]
fn test_load_portals_config_valid_file() {
let mut file = NamedTempFile::new().unwrap();
writeln!(
file,
r#"
[[portals]]
name = "test"
url = "https://test.com"
"#
)
.unwrap();
let config = load_portals_config(Some(file.path().to_path_buf()))
.unwrap()
.unwrap();
assert_eq!(config.portals.len(), 1);
assert_eq!(config.portals[0].name, "test");
assert_eq!(config.portals[0].url, "https://test.com");
}
// An explicit path that does not exist is an error (no template is created).
#[test]
fn test_load_portals_config_custom_path_not_found() {
let result = load_portals_config(Some("/nonexistent/path/to/config.toml".into()));
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, AppError::ConfigError(_)));
}
#[test]
fn test_load_portals_config_invalid_toml() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "this is not valid toml {{{{").unwrap();
let result = load_portals_config(Some(file.path().to_path_buf()));
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, AppError::ConfigError(_)));
}
#[test]
fn test_load_portals_config_multiple_portals_with_enabled_filter() {
let mut file = NamedTempFile::new().unwrap();
writeln!(
file,
r#"
[[portals]]
name = "enabled-portal"
url = "https://a.com"
[[portals]]
name = "disabled-portal"
url = "https://b.com"
enabled = false
[[portals]]
name = "another-enabled"
url = "https://c.com"
enabled = true
"#
)
.unwrap();
let config = load_portals_config(Some(file.path().to_path_buf()))
.unwrap()
.unwrap();
assert_eq!(config.portals.len(), 3);
assert_eq!(config.enabled_portals().len(), 2);
}
#[test]
fn test_load_portals_config_with_all_fields() {
let mut file = NamedTempFile::new().unwrap();
writeln!(
file,
r#"
[[portals]]
name = "full-config"
url = "https://example.com"
type = "ckan"
enabled = true
description = "A fully configured portal"
"#
)
.unwrap();
let config = load_portals_config(Some(file.path().to_path_buf()))
.unwrap()
.unwrap();
let portal = &config.portals[0];
assert_eq!(portal.name, "full-config");
assert_eq!(portal.url, "https://example.com");
assert_eq!(portal.portal_type, PortalType::Ckan);
assert!(portal.enabled);
assert_eq!(
portal.description,
Some("A fully configured portal".to_string())
);
}
// DCAT-specific fields: profile and language accessors.
#[test]
fn test_portals_config_dcat_profile() {
let toml = r#"
[[portals]]
name = "eu-sparql"
url = "https://data.europa.eu"
type = "dcat"
profile = "sparql"
language = "en"
"#;
let config: PortalsConfig = toml::from_str(toml).unwrap();
let portal = &config.portals[0];
assert_eq!(portal.portal_type, PortalType::Dcat);
assert_eq!(portal.profile(), Some("sparql"));
assert_eq!(portal.language(), "en");
}
#[test]
fn test_portals_config_profile_defaults_none() {
let toml = r#"
[[portals]]
name = "luxembourg"
url = "https://data.public.lu"
type = "dcat"
"#;
let config: PortalsConfig = toml::from_str(toml).unwrap();
assert_eq!(config.portals[0].profile(), None);
}
#[test]
fn test_load_portals_config_empty_portals_array() {
let mut file = NamedTempFile::new().unwrap();
writeln!(file, "portals = []").unwrap();
let config = load_portals_config(Some(file.path().to_path_buf()))
.unwrap()
.unwrap();
assert!(config.portals.is_empty());
assert!(config.enabled_portals().is_empty());
}
// Embedding provider parsing, display, default configs and dimension lookup.
#[test]
fn test_embedding_provider_type_from_str() {
assert_eq!(
"gemini".parse::<EmbeddingProviderType>().unwrap(),
EmbeddingProviderType::Gemini
);
assert_eq!(
"openai".parse::<EmbeddingProviderType>().unwrap(),
EmbeddingProviderType::OpenAI
);
assert_eq!(
"GEMINI".parse::<EmbeddingProviderType>().unwrap(),
EmbeddingProviderType::Gemini
);
assert_eq!(
"OpenAI".parse::<EmbeddingProviderType>().unwrap(),
EmbeddingProviderType::OpenAI
);
assert_eq!(
"ollama".parse::<EmbeddingProviderType>().unwrap(),
EmbeddingProviderType::Ollama
);
assert_eq!(
"OLLAMA".parse::<EmbeddingProviderType>().unwrap(),
EmbeddingProviderType::Ollama
);
}
#[test]
fn test_embedding_provider_type_invalid() {
let result = "invalid".parse::<EmbeddingProviderType>();
assert!(result.is_err());
}
#[test]
fn test_embedding_provider_type_display() {
assert_eq!(EmbeddingProviderType::Gemini.to_string(), "gemini");
assert_eq!(EmbeddingProviderType::OpenAI.to_string(), "openai");
assert_eq!(EmbeddingProviderType::Ollama.to_string(), "ollama");
}
#[test]
fn test_embedding_dimension() {
assert_eq!(
embedding_dimension(EmbeddingProviderType::Gemini, None),
768
);
assert_eq!(
embedding_dimension(EmbeddingProviderType::Gemini, Some("gemini-embedding-001")),
768
);
assert_eq!(
embedding_dimension(EmbeddingProviderType::OpenAI, None),
1536
);
assert_eq!(
embedding_dimension(
EmbeddingProviderType::OpenAI,
Some("text-embedding-3-small")
),
1536
);
assert_eq!(
embedding_dimension(
EmbeddingProviderType::OpenAI,
Some("text-embedding-3-large")
),
3072
);
}
#[test]
fn test_gemini_embedding_config_default() {
let config = GeminiEmbeddingConfig::default();
assert_eq!(config.model, "gemini-embedding-001");
}
#[test]
fn test_openai_embedding_config_default() {
let config = OpenAIEmbeddingConfig::default();
assert_eq!(config.model, "text-embedding-3-small");
assert!(config.endpoint.is_none());
}
#[test]
fn test_ollama_embedding_config_default() {
let config = OllamaEmbeddingConfig::default();
assert_eq!(config.model, "nomic-embed-text");
assert_eq!(config.endpoint, "http://localhost:11434");
}
// Ollama lookups, including `:tag` suffixes and the unknown-model fallback.
#[test]
fn test_embedding_dimension_ollama() {
assert_eq!(
embedding_dimension(EmbeddingProviderType::Ollama, None),
768
);
assert_eq!(
embedding_dimension(EmbeddingProviderType::Ollama, Some("nomic-embed-text")),
768
);
assert_eq!(
embedding_dimension(EmbeddingProviderType::Ollama, Some("mxbai-embed-large")),
1024
);
assert_eq!(
embedding_dimension(EmbeddingProviderType::Ollama, Some("all-minilm")),
384
);
assert_eq!(
embedding_dimension(EmbeddingProviderType::Ollama, Some("unknown-model")),
768
);
assert_eq!(
embedding_dimension(
EmbeddingProviderType::Ollama,
Some("snowflake-arctic-embed:335m")
),
1024
);
assert_eq!(
embedding_dimension(
EmbeddingProviderType::Ollama,
Some("nomic-embed-text:latest")
),
768
);
}
}