pub mod cli;
pub mod defaults;
pub mod file;
use std::env;
use std::error::Error;
use std::fmt;
use std::fs;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
/// Errors produced while reading, writing, or validating a configuration.
#[derive(Debug)]
pub enum ConfigError {
    /// Reading or writing a configuration file failed.
    IoError(std::io::Error),
    /// The file extension is missing or unsupported, or the contents could not be
    /// deserialized / serialized.
    ParseError(String),
    /// A configuration value failed validation.
    ValidationError(String),
}

impl fmt::Display for ConfigError {
    /// Writes a one-line description prefixed with the error category.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::IoError(ref source) => write!(f, "IO error: {}", source),
            Self::ParseError(ref detail) => write!(f, "Parse error: {}", detail),
            Self::ValidationError(ref detail) => write!(f, "Validation error: {}", detail),
        }
    }
}

/// `source()` keeps its default of `None`; the underlying I/O error is instead
/// included in the `Display` text.
impl Error for ConfigError {}

/// Allows `?` to convert `std::io::Error` into [`ConfigError::IoError`].
impl From<std::io::Error> for ConfigError {
    fn from(source: std::io::Error) -> Self {
        Self::IoError(source)
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub message_only: bool,
pub no_diff_stats: bool,
pub show_raw_diff: bool,
pub context_lines: u32,
pub max_lines_per_file: usize,
pub max_line_width: usize,
pub max_file_lines: usize,
pub provider: String,
pub model: Option<String>,
pub temperature: Option<f32>,
pub include_recent_commits: bool,
pub recent_commits_count: usize,
pub template: Option<String>,
pub hint: Option<String>,
}
impl Default for Config {
fn default() -> Self {
Self {
message_only: defaults::MESSAGE_ONLY,
no_diff_stats: defaults::NO_DIFF_STATS,
show_raw_diff: defaults::SHOW_RAW_DIFF,
context_lines: defaults::CONTEXT_LINES,
max_lines_per_file: defaults::MAX_LINES_PER_FILE,
max_line_width: defaults::MAX_LINE_WIDTH,
max_file_lines: defaults::MAX_FILE_LINES,
provider: defaults::DEFAULT_PROVIDER.to_string(),
model: None,
temperature: None,
include_recent_commits: defaults::INCLUDE_RECENT_COMMITS,
recent_commits_count: defaults::RECENT_COMMITS_COUNT,
template: None,
hint: None,
}
}
}
/// On-disk formats accepted by `Config::from_file` and `Config::save_to_file`,
/// chosen from the path's extension.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConfigFormat {
    Toml,
    Json,
}

impl ConfigFormat {
    /// Maps `path`'s extension to a format: exactly `toml` or `json` (case-sensitive).
    ///
    /// # Errors
    ///
    /// Returns [`ConfigError::ParseError`] when the path has no extension or an
    /// unsupported one.
    fn from_path(path: &Path) -> Result<Self, ConfigError> {
        let ext = path
            .extension()
            .ok_or_else(|| ConfigError::ParseError("Unknown file format".to_string()))?;
        if ext == "toml" {
            Ok(Self::Toml)
        } else if ext == "json" {
            Ok(Self::Json)
        } else {
            Err(ConfigError::ParseError(format!(
                "Unsupported file format: {:?}",
                ext
            )))
        }
    }
}

impl Config {
    /// Creates a configuration holding the built-in defaults (same as `Config::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Reads and deserializes a configuration file, choosing TOML or JSON by extension.
    ///
    /// # Errors
    ///
    /// - [`ConfigError::IoError`] if the file cannot be read.
    /// - [`ConfigError::ParseError`] if the extension is missing/unsupported or the
    ///   contents do not deserialize into a `Config`.
    pub fn from_file(path: &Path) -> Result<Self, ConfigError> {
        // Read first so a missing file is reported as an I/O error regardless of extension.
        let content = fs::read_to_string(path)?;
        match ConfigFormat::from_path(path)? {
            ConfigFormat::Toml => {
                toml::from_str(&content).map_err(|e| ConfigError::ParseError(e.to_string()))
            }
            ConfigFormat::Json => serde_json::from_str(&content)
                .map_err(|e| ConfigError::ParseError(e.to_string())),
        }
    }

    /// Serializes this configuration as pretty-printed TOML or JSON (chosen by the
    /// extension of `path`) and writes it to `path` via `fs::write`.
    ///
    /// # Errors
    ///
    /// - [`ConfigError::ParseError`] if the extension is missing/unsupported or
    ///   serialization fails; nothing is written in that case.
    /// - [`ConfigError::IoError`] if writing the file fails.
    pub fn save_to_file(&self, path: &Path) -> Result<(), ConfigError> {
        let content = match ConfigFormat::from_path(path)? {
            ConfigFormat::Toml => toml::to_string_pretty(self)
                .map_err(|e| ConfigError::ParseError(e.to_string()))?,
            ConfigFormat::Json => serde_json::to_string_pretty(self)
                .map_err(|e| ConfigError::ParseError(e.to_string()))?,
        };
        fs::write(path, content)?;
        Ok(())
    }

    /// Overlays `other` onto `self`.
    ///
    /// A non-`Option` field is copied from `other` only when it differs from its
    /// built-in default constant, and an `Option` field only when it is `Some`.
    ///
    /// Limitation: because "equal to the default" is treated as "not set", a later
    /// layer cannot reset a field back to its default once an earlier layer changed
    /// it (e.g. a project file cannot undo a global `message_only = true`).
    pub fn merge(&mut self, other: &Config) {
        if other.message_only != defaults::MESSAGE_ONLY {
            self.message_only = other.message_only;
        }
        if other.no_diff_stats != defaults::NO_DIFF_STATS {
            self.no_diff_stats = other.no_diff_stats;
        }
        if other.show_raw_diff != defaults::SHOW_RAW_DIFF {
            self.show_raw_diff = other.show_raw_diff;
        }
        if other.context_lines != defaults::CONTEXT_LINES {
            self.context_lines = other.context_lines;
        }
        if other.max_lines_per_file != defaults::MAX_LINES_PER_FILE {
            self.max_lines_per_file = other.max_lines_per_file;
        }
        if other.max_line_width != defaults::MAX_LINE_WIDTH {
            self.max_line_width = other.max_line_width;
        }
        if other.max_file_lines != defaults::MAX_FILE_LINES {
            self.max_file_lines = other.max_file_lines;
        }
        if other.provider != defaults::DEFAULT_PROVIDER {
            self.provider.clone_from(&other.provider);
        }
        if other.model.is_some() {
            self.model.clone_from(&other.model);
        }
        if other.temperature.is_some() {
            self.temperature = other.temperature;
        }
        if other.include_recent_commits != defaults::INCLUDE_RECENT_COMMITS {
            self.include_recent_commits = other.include_recent_commits;
        }
        if other.recent_commits_count != defaults::RECENT_COMMITS_COUNT {
            self.recent_commits_count = other.recent_commits_count;
        }
        if other.template.is_some() {
            self.template.clone_from(&other.template);
        }
        if other.hint.is_some() {
            self.hint.clone_from(&other.hint);
        }
    }

    /// Builds a configuration from parsed CLI arguments.
    ///
    /// `include_recent_commits` and `recent_commits_count` take their built-in
    /// defaults and `template` is `None`; they are not read from `args` here.
    pub fn from_args(args: &cli::Args) -> Self {
        Self {
            message_only: args.message_only,
            no_diff_stats: args.no_diff_stats,
            show_raw_diff: args.show_raw_diff,
            context_lines: args.context_lines,
            max_lines_per_file: args.max_lines_per_file,
            max_line_width: args.max_line_width,
            max_file_lines: args.max_file_lines,
            provider: args.provider.clone(),
            model: args.model.clone(),
            temperature: args.temperature,
            include_recent_commits: defaults::INCLUDE_RECENT_COMMITS,
            recent_commits_count: defaults::RECENT_COMMITS_COUNT,
            template: None,
            hint: args.hint.clone(),
        }
    }

    /// Loads the layered file configuration.
    ///
    /// Layers, applied in order with [`Config::merge`]:
    /// 1. built-in defaults;
    /// 2. the global file `$HOME/.config/cmt/config.toml`, if it exists;
    /// 3. the nearest `.cmt.toml` found by walking up from the current directory.
    ///
    /// Loading is best-effort: a file that cannot be read or parsed is skipped and a
    /// warning is printed to stderr, rather than being dropped silently.
    ///
    /// # Errors
    ///
    /// Currently never returns `Err`; the `Result` is kept for API stability.
    pub fn load() -> Result<Self, ConfigError> {
        let mut config = Self::default();
        let global = Self::global_config_path().filter(|path| path.exists());
        let project = Self::find_project_config();
        for path in global.into_iter().chain(project) {
            match Self::from_file(&path) {
                Ok(layer) => config.merge(&layer),
                Err(err) => eprintln!(
                    "warning: ignoring config file {}: {}",
                    path.display(),
                    err
                ),
            }
        }
        Ok(config)
    }

    /// Returns `$HOME/.config/cmt/config.toml`, or `None` when `HOME` is unset or empty.
    ///
    /// An empty `HOME` is rejected because joining onto it would yield the relative
    /// path `.config/cmt/config.toml`, which would be resolved against the current
    /// directory. `var_os` is used so a non-UTF-8 `HOME` still works.
    fn global_config_path() -> Option<PathBuf> {
        let home = env::var_os("HOME").filter(|home| !home.is_empty())?;
        Some(
            PathBuf::from(home)
                .join(".config")
                .join("cmt")
                .join("config.toml"),
        )
    }

    /// Finds the nearest `.cmt.toml`, checking the current directory and then each
    /// ancestor up to the filesystem root. Returns `None` if none exists or the
    /// current directory cannot be determined.
    fn find_project_config() -> Option<PathBuf> {
        let cwd = env::current_dir().ok()?;
        cwd.ancestors()
            .map(|dir| dir.join(".cmt.toml"))
            .find(|candidate| candidate.exists())
    }
}