use std::env;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use crate::{
AssClassificationPolicy, PgsOcrConfig, ProviderConfig, TranslationOptions, TranslatorError,
};
/// Name of the application directory appended to the config base directory
/// chosen by `default_config_dir`.
const DEFAULT_CONFIG_DIR_NAME: &str = "shinkai-translator";
/// File name of the TOML config file inside that application directory.
const DEFAULT_CONFIG_FILE_NAME: &str = "config.toml";
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AppConfig {
#[serde(default)]
pub provider: ProviderConfig,
#[serde(default)]
pub translation: TranslationDefaults,
#[serde(default)]
pub ocr: PgsOcrConfig,
#[serde(default)]
pub classification: AssClassificationPolicy,
#[serde(default)]
pub tools: ToolConfig,
}
impl Default for AppConfig {
fn default() -> Self {
Self {
provider: ProviderConfig::default(),
translation: TranslationDefaults::default(),
ocr: PgsOcrConfig::default(),
classification: AssClassificationPolicy::default(),
tools: ToolConfig::default(),
}
}
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TranslationDefaults {
#[serde(default)]
pub source_language: Option<String>,
#[serde(default)]
pub target_language: Option<String>,
#[serde(default = "default_max_batch_items")]
pub max_batch_items: usize,
#[serde(default = "default_max_batch_characters")]
pub max_batch_characters: usize,
#[serde(default = "default_max_parallel_batches")]
pub max_parallel_batches: usize,
#[serde(default)]
pub system_prompt: Option<String>,
}
impl Default for TranslationDefaults {
fn default() -> Self {
let defaults = TranslationOptions::default();
Self {
source_language: defaults.source_language,
target_language: None,
max_batch_items: defaults.max_batch_items,
max_batch_characters: defaults.max_batch_characters,
max_parallel_batches: defaults.max_parallel_batches,
system_prompt: defaults.system_prompt,
}
}
}
/// Names or paths of the external executables the application invokes.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ToolConfig {
    /// ffmpeg executable; defaults to the bare name `"ffmpeg"`.
    #[serde(default = "default_ffmpeg_bin")]
    pub ffmpeg_bin: String,
    /// ffprobe executable; defaults to the bare name `"ffprobe"`.
    #[serde(default = "default_ffprobe_bin")]
    pub ffprobe_bin: String,
}
impl Default for ToolConfig {
    fn default() -> Self {
        // Reuse the serde default helpers so a missing key in the file and
        // `ToolConfig::default()` always produce the same values.
        let ffmpeg_bin = default_ffmpeg_bin();
        let ffprobe_bin = default_ffprobe_bin();
        Self {
            ffmpeg_bin,
            ffprobe_bin,
        }
    }
}
/// A config file location paired with its parsed contents.
#[derive(Clone, Debug, PartialEq)]
pub struct LoadedConfig {
    /// Location the config was read from and that `save` writes back to.
    pub path: PathBuf,
    /// Parsed configuration, or [`AppConfig::default`] when no file was found.
    pub data: AppConfig,
    /// Whether a file was present at `path` when loaded, or has since been saved.
    pub exists: bool,
}
impl LoadedConfig {
    /// Loads the config from `path`, or from [`default_config_path`] when `None`.
    ///
    /// A missing file is not an error: the result holds [`AppConfig::default`]
    /// with `exists == false`.
    ///
    /// # Errors
    ///
    /// Any read error other than "not found" (e.g. permission denied) is
    /// propagated through `?`. A file that is not valid config TOML yields
    /// `TranslatorError::InvalidConfig` naming the path.
    pub fn load(path: Option<&Path>) -> Result<Self, TranslatorError> {
        let path = path
            .map(Path::to_path_buf)
            .unwrap_or_else(default_config_path);
        // Read first and classify the failure instead of pre-checking
        // `path.exists()`. `exists()` also returns false when the metadata
        // lookup itself fails (e.g. permission denied), which would silently
        // replace an unreadable config with defaults. A separate check also
        // races with the read.
        let source = match std::fs::read_to_string(&path) {
            Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
                return Ok(Self {
                    path,
                    data: AppConfig::default(),
                    exists: false,
                });
            }
            result => result?,
        };
        let data = toml::from_str(&source).map_err(|error| {
            TranslatorError::InvalidConfig(format!(
                "failed to parse config {}: {error}",
                path.display()
            ))
        })?;
        Ok(Self {
            path,
            data,
            exists: true,
        })
    }

    /// Writes `data` to `path` as pretty-printed TOML, creating missing parent
    /// directories first. Sets `exists` to `true` on success.
    ///
    /// # Errors
    ///
    /// Returns `TranslatorError::InvalidConfig` if serialization fails. I/O
    /// errors from directory creation or the write are propagated through `?`.
    pub fn save(&mut self) -> Result<(), TranslatorError> {
        if let Some(parent) = self.path.parent() {
            std::fs::create_dir_all(parent)?;
        }
        let body = toml::to_string_pretty(&self.data).map_err(|error| {
            TranslatorError::InvalidConfig(format!("failed to serialize config: {error}"))
        })?;
        // NOTE(review): `fs::write` truncates and rewrites in place, so an
        // interrupted save can leave a partial file. Consider write-to-temp +
        // rename if that matters.
        std::fs::write(&self.path, body)?;
        self.exists = true;
        Ok(())
    }
}
/// Returns the default config file location: `config.toml` inside
/// [`default_config_dir`].
pub fn default_config_path() -> PathBuf {
    let mut file = default_config_dir();
    file.push(DEFAULT_CONFIG_FILE_NAME);
    file
}
/// Returns the directory that holds the application's config file.
///
/// Resolution order:
/// 1. `$XDG_CONFIG_HOME/shinkai-translator`, if the variable is set, non-blank
///    and an absolute path. The XDG Base Directory spec says relative values
///    must be ignored.
/// 2. `$HOME/.config/shinkai-translator`, if `HOME` is set and non-blank.
/// 3. `<temp dir>/shinkai-translator` as a last resort.
///
/// Variables are read with `var_os`, so non-UTF-8 paths are used rather than
/// being skipped.
pub fn default_config_dir() -> PathBuf {
    if let Some(xdg) = non_blank_env_path("XDG_CONFIG_HOME").filter(|dir| dir.is_absolute()) {
        return xdg.join(DEFAULT_CONFIG_DIR_NAME);
    }
    if let Some(home) = non_blank_env_path("HOME") {
        return home.join(".config").join(DEFAULT_CONFIG_DIR_NAME);
    }
    env::temp_dir().join(DEFAULT_CONFIG_DIR_NAME)
}

/// Reads `key` from the environment, treating unset, empty and
/// whitespace-only values as absent. The value itself is not trimmed.
fn non_blank_env_path(key: &str) -> Option<PathBuf> {
    let value = env::var_os(key)?;
    // A non-UTF-8 value cannot be empty (the empty string is valid UTF-8),
    // so only textual values need the blank check.
    let is_blank = match value.to_str() {
        Some(text) => text.trim().is_empty(),
        None => false,
    };
    if is_blank {
        None
    } else {
        Some(PathBuf::from(value))
    }
}
/// Serde default for `tools.ffmpeg_bin`: the bare executable name.
fn default_ffmpeg_bin() -> String {
    String::from("ffmpeg")
}
/// Serde default for `tools.ffprobe_bin`: the bare executable name.
fn default_ffprobe_bin() -> String {
    String::from("ffprobe")
}
fn default_max_batch_items() -> usize {
TranslationOptions::default().max_batch_items
}
fn default_max_batch_characters() -> usize {
TranslationOptions::default().max_batch_characters
}
fn default_max_parallel_batches() -> usize {
TranslationOptions::default().max_parallel_batches
}
#[cfg(test)]
mod tests {
    use super::{AppConfig, LoadedConfig, ToolConfig, TranslationDefaults};

    #[test]
    fn translation_defaults_start_without_target_language() {
        let defaults = TranslationDefaults::default();
        assert_eq!(defaults.target_language, None);
        assert_eq!(defaults.max_batch_items, 8);
        assert_eq!(defaults.max_batch_characters, 4_000);
        assert_eq!(defaults.max_parallel_batches, 5);
    }

    #[test]
    fn config_roundtrip_preserves_values() {
        let temp_dir = tempfile::tempdir().expect("temp dir should exist");
        let path = temp_dir.path().join("config.toml");
        let mut config = LoadedConfig {
            path: path.clone(),
            data: AppConfig {
                translation: TranslationDefaults {
                    target_language: Some("Portuguese (Brazil)".to_owned()),
                    ..TranslationDefaults::default()
                },
                tools: ToolConfig {
                    ffmpeg_bin: "/usr/local/bin/ffmpeg".to_owned(),
                    ffprobe_bin: "/usr/local/bin/ffprobe".to_owned(),
                },
                ..AppConfig::default()
            },
            exists: false,
        };
        config.save().expect("config should save");
        let loaded = LoadedConfig::load(Some(&path)).expect("config should load");
        assert!(loaded.exists);
        assert_eq!(
            loaded.data.translation.target_language.as_deref(),
            Some("Portuguese (Brazil)")
        );
        assert_eq!(loaded.data.tools.ffmpeg_bin, "/usr/local/bin/ffmpeg");
        assert_eq!(loaded.data.tools.ffprobe_bin, "/usr/local/bin/ffprobe");
    }

    /// A missing file is not an error: defaults are returned and `exists` is false.
    #[test]
    fn load_missing_file_returns_defaults() {
        let temp_dir = tempfile::tempdir().expect("temp dir should exist");
        let path = temp_dir.path().join("missing.toml");
        let loaded = LoadedConfig::load(Some(&path)).expect("missing config should not error");
        assert!(!loaded.exists);
        assert_eq!(loaded.path, path);
        assert_eq!(loaded.data, AppConfig::default());
    }

    /// Keys and sections absent from the file fall back to their serde defaults.
    #[test]
    fn load_partial_file_fills_missing_values_with_defaults() {
        let temp_dir = tempfile::tempdir().expect("temp dir should exist");
        let path = temp_dir.path().join("config.toml");
        std::fs::write(&path, "[tools]\nffmpeg_bin = \"/opt/ffmpeg\"\n")
            .expect("config should be written");
        let loaded = LoadedConfig::load(Some(&path)).expect("partial config should load");
        assert!(loaded.exists);
        assert_eq!(loaded.data.tools.ffmpeg_bin, "/opt/ffmpeg");
        assert_eq!(loaded.data.tools.ffprobe_bin, "ffprobe");
        assert_eq!(loaded.data.translation, TranslationDefaults::default());
    }
}