shinkai-translator 0.1.3

CLI tool for translating video subtitles with LLMs through OpenAI-compatible APIs, with native OCR for PGS (bitmap) subtitles.
use std::env;
use std::path::{Path, PathBuf};

use serde::{Deserialize, Serialize};

use crate::{
    AssClassificationPolicy, PgsOcrConfig, ProviderConfig, TranslationOptions, TranslatorError,
};

const DEFAULT_CONFIG_DIR_NAME: &str = "shinkai-translator";
const DEFAULT_CONFIG_FILE_NAME: &str = "config.toml";

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct AppConfig {
    #[serde(default)]
    pub provider: ProviderConfig,
    #[serde(default)]
    pub translation: TranslationDefaults,
    #[serde(default)]
    pub ocr: PgsOcrConfig,
    #[serde(default)]
    pub classification: AssClassificationPolicy,
    #[serde(default)]
    pub tools: ToolConfig,
}

impl Default for AppConfig {
    fn default() -> Self {
        Self {
            provider: ProviderConfig::default(),
            translation: TranslationDefaults::default(),
            ocr: PgsOcrConfig::default(),
            classification: AssClassificationPolicy::default(),
            tools: ToolConfig::default(),
        }
    }
}

/// Translation settings stored in the `[translation]` table of the config file.
///
/// Numeric fields that are missing from the table fall back to the matching
/// values of `TranslationOptions::default()` (through the `default_max_*`
/// helpers). Missing `Option` fields fall back to `None`.
///
/// NOTE(review): when the whole `[translation]` table is absent,
/// `TranslationDefaults::default()` is used instead. That takes `source_language`
/// and `system_prompt` from `TranslationOptions::default()`. The two paths agree
/// only if those defaults are `None` — confirm.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct TranslationDefaults {
    /// Optional source language; `None` when not set in the table.
    #[serde(default)]
    pub source_language: Option<String>,
    /// Optional target language; `None` when not set (and in `Default`).
    #[serde(default)]
    pub target_language: Option<String>,
    /// Maximum items per batch; defaults to `TranslationOptions::default().max_batch_items`.
    #[serde(default = "default_max_batch_items")]
    pub max_batch_items: usize,
    /// Maximum characters per batch; defaults to
    /// `TranslationOptions::default().max_batch_characters`.
    #[serde(default = "default_max_batch_characters")]
    pub max_batch_characters: usize,
    /// Maximum batches run in parallel; defaults to
    /// `TranslationOptions::default().max_parallel_batches`.
    #[serde(default = "default_max_parallel_batches")]
    pub max_parallel_batches: usize,
    /// Optional custom system prompt; `None` when not set in the table.
    #[serde(default)]
    pub system_prompt: Option<String>,
}

impl Default for TranslationDefaults {
    fn default() -> Self {
        let defaults = TranslationOptions::default();
        Self {
            source_language: defaults.source_language,
            target_language: None,
            max_batch_items: defaults.max_batch_items,
            max_batch_characters: defaults.max_batch_characters,
            max_parallel_batches: defaults.max_parallel_batches,
            system_prompt: defaults.system_prompt,
        }
    }
}

/// Names or paths of the external binaries the tool invokes (`[tools]` table).
///
/// Each field falls back to its plain program name when missing from the file.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ToolConfig {
    /// ffmpeg executable; defaults to `"ffmpeg"`.
    #[serde(default = "default_ffmpeg_bin")]
    pub ffmpeg_bin: String,
    /// ffprobe executable; defaults to `"ffprobe"`.
    #[serde(default = "default_ffprobe_bin")]
    pub ffprobe_bin: String,
}

impl Default for ToolConfig {
    /// Uses the same fallbacks as deserialization: `"ffmpeg"` and `"ffprobe"`.
    fn default() -> Self {
        let ffmpeg_bin = default_ffmpeg_bin();
        let ffprobe_bin = default_ffprobe_bin();
        Self {
            ffmpeg_bin,
            ffprobe_bin,
        }
    }
}

/// An [`AppConfig`] paired with the file path it was loaded from and will be saved to.
#[derive(Clone, Debug, PartialEq)]
pub struct LoadedConfig {
    /// Location of the config file; read by `load` and written by `save`.
    pub path: PathBuf,
    /// Parsed configuration, or `AppConfig::default()` when the file did not exist.
    pub data: AppConfig,
    /// Whether the file existed at load time; set to `true` after a successful `save`.
    pub exists: bool,
}

impl LoadedConfig {
    pub fn load(path: Option<&Path>) -> Result<Self, TranslatorError> {
        let path = path
            .map(Path::to_path_buf)
            .unwrap_or_else(default_config_path);

        if !path.exists() {
            return Ok(Self {
                path,
                data: AppConfig::default(),
                exists: false,
            });
        }

        let source = std::fs::read_to_string(&path)?;
        let data = toml::from_str(&source).map_err(|error| {
            TranslatorError::InvalidConfig(format!(
                "failed to parse config {}: {error}",
                path.display()
            ))
        })?;

        Ok(Self {
            path,
            data,
            exists: true,
        })
    }

    pub fn save(&mut self) -> Result<(), TranslatorError> {
        if let Some(parent) = self.path.parent() {
            std::fs::create_dir_all(parent)?;
        }

        let body = toml::to_string_pretty(&self.data).map_err(|error| {
            TranslatorError::InvalidConfig(format!("failed to serialize config: {error}"))
        })?;
        std::fs::write(&self.path, body)?;
        self.exists = true;
        Ok(())
    }
}

/// Full path of the default config file: `config.toml` inside [`default_config_dir`].
pub fn default_config_path() -> PathBuf {
    let mut config_path = default_config_dir();
    config_path.push(DEFAULT_CONFIG_FILE_NAME);
    config_path
}

/// Directory holding the `shinkai-translator` configuration.
///
/// Candidates are tried in this order:
/// 1. `$XDG_CONFIG_HOME/shinkai-translator`, if `XDG_CONFIG_HOME` is an
///    absolute path.
/// 2. `$HOME/.config/shinkai-translator`, if `HOME` is set and not blank.
/// 3. `<system temp dir>/shinkai-translator` as a last resort.
///
/// Variables are read with `var_os`, so non-UTF-8 paths are honoured rather
/// than silently skipped.
pub fn default_config_dir() -> PathBuf {
    // XDG Base Directory spec: XDG_* paths must be absolute, and a relative
    // value must be treated as invalid and ignored. An empty or whitespace-only
    // value is never absolute, so this check covers that case too.
    if let Some(xdg_config) = env::var_os("XDG_CONFIG_HOME")
        .map(PathBuf::from)
        .filter(|candidate| candidate.is_absolute())
    {
        return xdg_config.join(DEFAULT_CONFIG_DIR_NAME);
    }

    if let Some(home) =
        env::var_os("HOME").filter(|value| !value.to_string_lossy().trim().is_empty())
    {
        return PathBuf::from(home)
            .join(".config")
            .join(DEFAULT_CONFIG_DIR_NAME);
    }

    env::temp_dir().join(DEFAULT_CONFIG_DIR_NAME)
}

/// Serde/`Default` fallback for `ToolConfig::ffmpeg_bin`.
fn default_ffmpeg_bin() -> String {
    String::from("ffmpeg")
}

/// Serde/`Default` fallback for `ToolConfig::ffprobe_bin`.
fn default_ffprobe_bin() -> String {
    String::from("ffprobe")
}

fn default_max_batch_items() -> usize {
    TranslationOptions::default().max_batch_items
}

fn default_max_batch_characters() -> usize {
    TranslationOptions::default().max_batch_characters
}

fn default_max_parallel_batches() -> usize {
    TranslationOptions::default().max_parallel_batches
}

#[cfg(test)]
mod tests {
    use super::{AppConfig, LoadedConfig, ToolConfig, TranslationDefaults};
    use crate::TranslatorError;

    #[test]
    fn translation_defaults_start_without_target_language() {
        let defaults = TranslationDefaults::default();

        assert_eq!(defaults.target_language, None);
        assert_eq!(defaults.max_batch_items, 8);
        assert_eq!(defaults.max_batch_characters, 4_000);
        assert_eq!(defaults.max_parallel_batches, 5);
    }

    #[test]
    fn config_roundtrip_preserves_values() {
        let temp_dir = tempfile::tempdir().expect("temp dir should exist");
        let path = temp_dir.path().join("config.toml");
        let mut config = LoadedConfig {
            path: path.clone(),
            data: AppConfig {
                translation: TranslationDefaults {
                    target_language: Some("Portuguese (Brazil)".to_owned()),
                    ..TranslationDefaults::default()
                },
                tools: ToolConfig {
                    ffmpeg_bin: "/usr/local/bin/ffmpeg".to_owned(),
                    ffprobe_bin: "/usr/local/bin/ffprobe".to_owned(),
                },
                ..AppConfig::default()
            },
            exists: false,
        };

        config.save().expect("config should save");
        assert!(config.exists, "save should mark the config as existing");

        let loaded = LoadedConfig::load(Some(&path)).expect("config should load");
        assert!(loaded.exists);
        assert_eq!(
            loaded.data.translation.target_language.as_deref(),
            Some("Portuguese (Brazil)")
        );
        assert_eq!(loaded.data.tools.ffmpeg_bin, "/usr/local/bin/ffmpeg");
        assert_eq!(loaded.data.tools.ffprobe_bin, "/usr/local/bin/ffprobe");
    }

    #[test]
    fn load_missing_file_returns_defaults() {
        let temp_dir = tempfile::tempdir().expect("temp dir should exist");
        let path = temp_dir.path().join("missing.toml");

        let loaded = LoadedConfig::load(Some(&path)).expect("missing config should load");

        assert!(!loaded.exists);
        assert_eq!(loaded.path, path);
        assert_eq!(loaded.data, AppConfig::default());
    }

    #[test]
    fn partial_config_falls_back_to_field_defaults() {
        let temp_dir = tempfile::tempdir().expect("temp dir should exist");
        let path = temp_dir.path().join("config.toml");
        std::fs::write(&path, "[tools]\nffmpeg_bin = \"/opt/ffmpeg\"\n")
            .expect("config should be written");

        let loaded = LoadedConfig::load(Some(&path)).expect("partial config should load");

        assert!(loaded.exists);
        assert_eq!(loaded.data.tools.ffmpeg_bin, "/opt/ffmpeg");
        assert_eq!(loaded.data.tools.ffprobe_bin, "ffprobe");
        assert_eq!(loaded.data.translation, TranslationDefaults::default());
    }

    #[test]
    fn invalid_toml_is_reported_as_invalid_config() {
        let temp_dir = tempfile::tempdir().expect("temp dir should exist");
        let path = temp_dir.path().join("config.toml");
        std::fs::write(&path, "not = [valid toml").expect("config should be written");

        let error = LoadedConfig::load(Some(&path)).expect_err("invalid config should fail");

        assert!(matches!(error, TranslatorError::InvalidConfig(_)));
    }
}