lxy 0.1.1

A convenient async http and RPC framework in Rust
Documentation
use std::{collections::HashSet, env, path::PathBuf};

use figment::{
  Figment,
  error::Kind,
  providers::{Env, Format, Toml},
};
use serde::Deserialize;

use crate::error::{Result, TypedError};

use super::{Config, ConfigParseError};

/// Loads a TOML configuration file and hands out typed sections of it.
///
/// The type parameter tracks the provider's state: `TomlConfigProvider<()>` only knows
/// the file name and where to look for it, while `TomlConfigProvider<String>` (produced by
/// [`TomlConfigProvider::build`]) also carries the raw file contents.
#[derive(Debug)]
pub struct TomlConfigProvider<T> {
  /// Name of the file looked up in each search path (e.g. `app.toml`).
  filename: String,
  /// Directories searched for `filename`, in priority order (the first hit wins).
  /// Kept as an ordered, duplicate-free list so the lookup is deterministic; a
  /// `HashSet` would iterate in a random, per-process order.
  search_paths: Vec<PathBuf>,
  /// `()` before `build`, the raw file contents afterwards.
  content: T,
}

impl TomlConfigProvider<()> {
  /// Creates a provider that looks for `filename` first in the current working
  /// directory and then in `$CARGO_MANIFEST_DIR` (when that variable is set).
  ///
  /// Nothing is read from disk until [`build`](Self::build) is called.
  pub fn new(filename: &str) -> Self {
    // Drop duplicates (e.g. when running from the crate root) while keeping priority order.
    let mut seen = HashSet::new();
    let search_paths = env::current_dir()
      .ok()
      .into_iter()
      // Convert the raw OsString directly: a lossy UTF-8 round-trip would corrupt
      // non-UTF-8 paths.
      .chain(env::var_os("CARGO_MANIFEST_DIR").map(PathBuf::from))
      .filter(|dir| seen.insert(dir.clone()))
      .collect();

    TomlConfigProvider {
      filename: filename.to_string(),
      search_paths,
      content: (),
    }
  }

  /// Test-only constructor that skips the file system and uses `content` verbatim.
  #[cfg(test)]
  fn from(content: &str) -> TomlConfigProvider<String> {
    TomlConfigProvider {
      filename: "app.toml".to_string(),
      search_paths: Vec::new(),
      content: content.to_string(),
    }
  }

  /// Locates the configuration file and loads its contents.
  ///
  /// Search paths are tried in order, and the first one containing a regular file
  /// named `filename` is used. If no such file exists, or it cannot be read as UTF-8
  /// text, the provider is built with empty contents.
  pub fn build(self) -> TomlConfigProvider<String> {
    let content = self
      .search_paths
      .iter()
      .map(|dir| dir.join(&self.filename))
      .find(|path| path.is_file())
      // TODO: surface read errors instead of treating an unreadable file as empty.
      .and_then(|path| std::fs::read_to_string(path).ok())
      .unwrap_or_default();

    TomlConfigProvider {
      filename: self.filename,
      search_paths: self.search_paths,
      content,
    }
  }
}

impl TomlConfigProvider<String> {
  pub fn parse<T: Config + Deserialize<'static>>(&self) -> Result<T> {
    let mut figment = Figment::from(Toml::string(&self.content));

    if let Some(prefix) = T::env_prefix() {
      figment = figment.merge(Env::prefixed(prefix));
    }

    figment.extract_inner::<T>(T::name()).or_else(|e| {
      match &e.kind {
        Kind::MissingField(_) => {
          // Return default value if config section is missing
          Ok(T::default())
        }
        _ => Err(
          ConfigParseError::error(format!("Failed to parse config section '{}'", T::name()))
            .with_source(e),
        ),
      }
    })
  }
}

#[cfg(test)]
mod tests {
  use super::*;

  /// Timeout in seconds; defaults to 30 when not configured.
  #[derive(Deserialize, Debug, PartialEq)]
  struct Timeout(u32);

  impl Default for Timeout {
    fn default() -> Self {
      Self(30)
    }
  }

  /// Fixture for the `[server]` section, which also declares an env prefix.
  #[derive(Deserialize, Debug, Default)]
  struct ServerConfig {
    host: String,
    port: u16,
    #[serde(default)]
    timeout: Timeout,
  }

  impl Config for ServerConfig {
    fn name() -> &'static str {
      "server"
    }

    fn env_prefix() -> Option<&'static str> {
      Some("APP_SERVER_")
    }
  }

  /// Fixture for the `[storage]` section (no env prefix).
  #[derive(Deserialize, Debug, Default)]
  struct StorageConfig {
    layer: String,
  }

  impl Config for StorageConfig {
    fn name() -> &'static str {
      "storage"
    }
  }

  /// Parses `toml` as a `ServerConfig` and checks that it fails with an internal
  /// `ConfigParseError` whose message mentions the failed parse.
  fn assert_server_parse_fails(toml: &str) {
    use crate::config::ConfigParseError;
    use crate::error::ErrorCode;

    let outcome = TomlConfigProvider::from(toml).parse::<ServerConfig>();
    assert!(outcome.is_err());

    let failure = outcome.unwrap_err();
    assert!(failure.is_code(ErrorCode::Internal));
    assert!(failure.is(ConfigParseError));
    assert!(failure.message().contains("Failed to parse config"));
  }

  #[test]
  fn test_config_provider() {
    let provider = TomlConfigProvider::from(
      r#"
[server]
host = "127.0.0.1"
port = 8080

[storage]
layer = "memory"
"#,
    );

    let server = provider.parse::<ServerConfig>().unwrap();
    assert_eq!(server.host, "127.0.0.1");
    assert_eq!(server.port, 8080);

    let storage = provider.parse::<StorageConfig>().unwrap();
    assert_eq!(storage.layer, "memory");
  }

  #[test]
  fn test_override_default() {
    let server = TomlConfigProvider::from(
      r#"
[server]
host = "127.0.0.1"
port = 8080
timeout = 120
"#,
    )
    .parse::<ServerConfig>()
    .unwrap();

    assert_eq!(server.timeout, Timeout(120));
  }

  #[test]
  fn test_missing_section() {
    let parsed = TomlConfigProvider::from("").parse::<ServerConfig>().unwrap();
    let expected = ServerConfig::default();

    assert_eq!(parsed.port, expected.port);
    assert_eq!(parsed.host, expected.host);
    assert_eq!(parsed.timeout, expected.timeout);
  }

  #[test]
  fn test_parse_type_mismatch() {
    assert_server_parse_fails(
      r#"
[server]
host = "127.0.0.1"
port = "not a number"
"#,
    );
  }

  #[test]
  fn test_parse_invalid_toml() {
    assert_server_parse_fails(
      r#"
[server
host = "127.0.0.1"
"#,
    );
  }

  #[test]
  fn test_parse_wrong_type_in_nested_field() {
    assert_server_parse_fails(
      r#"
[server]
host = "127.0.0.1"
port = 8080
timeout = "not a number"
"#,
    );
  }

  #[test]
  fn test_parse_error_preserves_source() {
    use crate::config::ConfigParseError;

    let outcome = TomlConfigProvider::from(
      r#"
[server]
port = "invalid"
"#,
    )
    .parse::<ServerConfig>();
    assert!(outcome.is_err());

    let failure = outcome.unwrap_err();
    assert!(failure.is(ConfigParseError));

    // The underlying figment error must be kept as the source, with a non-empty message.
    assert!(failure.source().is_some());
    let cause = failure.source().unwrap();
    assert!(!cause.to_string().is_empty());
  }
}