use crate::template::OutputFormat;
use crate::tokenizer::TokenizerType;
use crate::{sort::FileSortMethod, tokenizer::TokenFormat};
use derive_builder::Builder;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
/// Runtime settings for a code2prompt run.
///
/// Build one with [`Code2PromptConfig::builder`] or convert from the TOML form
/// with [`TomlConfig::to_code2prompt_config`]. The struct-level
/// `#[builder(setter(into), default)]` means every setter accepts anything
/// convertible into the field's type. Any field not set on the builder takes
/// its value from `Code2PromptConfig::default()`.
///
/// NOTE(review): `no_codeblock`, `follow_symlinks`, `hidden`, `no_ignore` and
/// `custom_template` have no counterpart in [`TomlConfig`]. They are therefore
/// not written by [`export_config_to_toml`], and they stay at their defaults
/// after [`TomlConfig::to_code2prompt_config`].
#[derive(Debug, Clone, Default, Builder)]
#[builder(setter(into), default)]
pub struct Code2PromptConfig {
/// Path the run operates on (presumably the root directory to scan — confirm).
pub path: PathBuf,
/// Patterns of files to include (assumed to be glob strings — TODO confirm).
pub include_patterns: Vec<String>,
/// Patterns of files to exclude (assumed to be glob strings — TODO confirm).
pub exclude_patterns: Vec<String>,
pub line_numbers: bool,
pub absolute_path: bool,
pub full_directory_tree: bool,
/// Not persisted to TOML.
pub no_codeblock: bool,
/// Not persisted to TOML.
pub follow_symlinks: bool,
/// Not persisted to TOML.
pub hidden: bool,
/// Not persisted to TOML.
pub no_ignore: bool,
/// `None` means no sort method was configured; its interpretation is up to
/// the consumer of this config (not visible here).
pub sort_method: Option<FileSortMethod>,
pub output_format: OutputFormat,
/// Optional custom template. Not persisted to TOML.
pub custom_template: Option<String>,
pub encoding: TokenizerType,
pub token_format: TokenFormat,
pub diff_enabled: bool,
/// Pair of branch names; stored in TOML as a two-element list.
pub diff_branches: Option<(String, String)>,
/// Pair of branch names; stored in TOML as a two-element list.
pub log_branches: Option<(String, String)>,
/// An empty string is treated as "unset" when exported to TOML.
pub template_name: String,
/// An empty string is treated as "unset" when exported to TOML.
pub template_str: String,
/// Extra key/value pairs (presumably exposed to the template — confirm).
pub user_variables: HashMap<String, String>,
pub token_map_enabled: bool,
}
impl Code2PromptConfig {
pub fn builder() -> Code2PromptConfigBuilder {
Code2PromptConfigBuilder::default()
}
}
/// Output destination recorded in [`TomlConfig::default_output`].
///
/// Serialized in lowercase (`"stdout"`, `"clipboard"`, `"file"`) because of
/// `#[serde(rename_all = "lowercase")]`. `Stdout` is the `Default` variant.
/// How each destination is actually handled is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
#[serde(rename_all = "lowercase")]
pub enum OutputDestination {
/// Standard output; the default variant.
#[default]
Stdout,
/// Clipboard destination (handling lives outside this file).
Clipboard,
/// File destination (the target path is not part of this enum).
File,
}
/// TOML (on-disk) representation of the configuration.
///
/// The container-level `#[serde(default)]` makes every key optional: a key
/// missing from the document takes its value from `TomlConfig::default()`.
/// Convert to the runtime form with [`TomlConfig::to_code2prompt_config`].
/// [`export_config_to_toml`] produces this form from a [`Code2PromptConfig`].
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(default)]
pub struct TomlConfig {
/// Written as `"stdout"`, `"clipboard"` or `"file"`; defaults to `Stdout`.
pub default_output: OutputDestination,
/// Path as a string. Export uses `to_string_lossy`, so non-UTF-8 paths are
/// not preserved exactly. `None` leaves the runtime path empty.
pub path: Option<String>,
pub include_patterns: Vec<String>,
pub exclude_patterns: Vec<String>,
pub line_numbers: bool,
pub absolute_path: bool,
pub full_directory_tree: bool,
/// `None` becomes `OutputFormat::default()` on conversion.
pub output_format: Option<OutputFormat>,
/// Passed through unchanged (including `None`) on conversion.
pub sort_method: Option<FileSortMethod>,
/// `None` becomes `TokenizerType::default()` on conversion.
pub encoding: Option<TokenizerType>,
/// `None` becomes `TokenFormat::default()` on conversion.
pub token_format: Option<TokenFormat>,
pub diff_enabled: bool,
/// Takes effect only with exactly two branch names. Any other length is
/// silently ignored on conversion, leaving the runtime field `None`.
pub diff_branches: Option<Vec<String>>,
/// Takes effect only with exactly two branch names. Any other length is
/// silently ignored on conversion, leaving the runtime field `None`.
pub log_branches: Option<Vec<String>>,
/// `None` leaves the runtime template name as an empty string.
pub template_name: Option<String>,
/// `None` leaves the runtime template text as an empty string.
pub template_str: Option<String>,
pub user_variables: HashMap<String, String>,
pub token_map_enabled: bool,
}
impl TomlConfig {
    /// Parses a TOML document into a [`TomlConfig`].
    ///
    /// Keys absent from `content` take their `Default` value because the
    /// struct carries `#[serde(default)]`.
    ///
    /// # Errors
    ///
    /// Returns the [`toml::de::Error`] produced by `toml::from_str` when the
    /// input is not valid TOML or does not match this struct's shape.
    pub fn from_toml_str(content: &str) -> Result<Self, toml::de::Error> {
        toml::from_str::<Self>(content)
    }

    /// Serializes this configuration as pretty-printed TOML.
    ///
    /// # Errors
    ///
    /// Returns the [`toml::ser::Error`] reported by `toml::to_string_pretty`.
    pub fn to_string(&self) -> Result<String, toml::ser::Error> {
        toml::to_string_pretty(self)
    }

    /// Builds the runtime [`Code2PromptConfig`] from this TOML representation.
    ///
    /// - `output_format`, `encoding` and `token_format` fall back to their
    ///   `Default` when absent; `sort_method` is passed through as-is.
    /// - A branch list is honoured only when it holds exactly two names.
    ///   Otherwise the runtime field stays `None`.
    /// - `path`, `template_name` and `template_str` are set only when present.
    ///   Absent ones keep the builder defaults.
    /// - Runtime fields with no TOML counterpart keep the builder defaults.
    pub fn to_code2prompt_config(&self) -> Code2PromptConfig {
        // Turn a `[first, second]` list into a pair; any other length yields `None`.
        fn as_pair(branches: Option<&[String]>) -> Option<(String, String)> {
            match branches {
                Some([first, second]) => Some((first.clone(), second.clone())),
                _ => None,
            }
        }

        let mut builder = Code2PromptConfig::builder();

        // Optional values: only touch the builder when the TOML supplied them.
        if let Some(raw_path) = self.path.as_deref() {
            builder.path(PathBuf::from(raw_path));
        }
        if let Some(pair) = as_pair(self.diff_branches.as_deref()) {
            builder.diff_branches(Some(pair));
        }
        if let Some(pair) = as_pair(self.log_branches.as_deref()) {
            builder.log_branches(Some(pair));
        }
        if let Some(name) = self.template_name.as_ref() {
            builder.template_name(name.to_owned());
        }
        if let Some(text) = self.template_str.as_ref() {
            builder.template_str(text.to_owned());
        }

        // Values that are always copied over.
        builder
            .include_patterns(self.include_patterns.clone())
            .exclude_patterns(self.exclude_patterns.clone())
            .line_numbers(self.line_numbers)
            .absolute_path(self.absolute_path)
            .full_directory_tree(self.full_directory_tree)
            .output_format(self.output_format.unwrap_or_default())
            .sort_method(self.sort_method)
            .encoding(self.encoding.unwrap_or_default())
            .token_format(self.token_format.unwrap_or_default())
            .diff_enabled(self.diff_enabled)
            .user_variables(self.user_variables.clone())
            .token_map_enabled(self.token_map_enabled);

        // Every field has a builder default, so `build` is not expected to fail;
        // `unwrap_or_default` is only a fallback.
        builder.build().unwrap_or_default()
    }
}
/// Serializes a [`Code2PromptConfig`] to pretty-printed TOML.
///
/// The output always:
/// - names `stdout` as the default destination;
/// - includes a `path` entry, converted with `to_string_lossy`, so non-UTF-8
///   bytes are replaced.
///
/// Empty template name or text, and absent branch pairs, are omitted.
/// Runtime fields without a [`TomlConfig`] counterpart are not written.
///
/// # Errors
///
/// Propagates the [`toml::ser::Error`] returned by [`TomlConfig::to_string`].
pub fn export_config_to_toml(config: &Code2PromptConfig) -> Result<String, toml::ser::Error> {
    // (a, b) -> [a, b]: the two-element list shape TomlConfig stores.
    let pair_as_list = |pair: &Option<(String, String)>| {
        pair.as_ref()
            .map(|(first, second)| vec![first.clone(), second.clone()])
    };
    // Empty strings are written as "absent" instead of as `""`.
    let unless_empty = |text: &str| (!text.is_empty()).then(|| text.to_owned());

    let serializable = TomlConfig {
        default_output: OutputDestination::Stdout,
        path: Some(config.path.to_string_lossy().into_owned()),
        include_patterns: config.include_patterns.clone(),
        exclude_patterns: config.exclude_patterns.clone(),
        line_numbers: config.line_numbers,
        absolute_path: config.absolute_path,
        full_directory_tree: config.full_directory_tree,
        output_format: Some(config.output_format),
        sort_method: config.sort_method,
        encoding: Some(config.encoding),
        token_format: Some(config.token_format),
        diff_enabled: config.diff_enabled,
        diff_branches: pair_as_list(&config.diff_branches),
        log_branches: pair_as_list(&config.log_branches),
        template_name: unless_empty(config.template_name.as_str()),
        template_str: unless_empty(config.template_str.as_str()),
        user_variables: config.user_variables.clone(),
        token_map_enabled: config.token_map_enabled,
    };
    serializable.to_string()
}