mod audio;
mod image;
mod path;
mod slide;
mod video;
use clap::Parser;
use serde_json::json;
use std::collections::HashMap;
use std::path::Path;
use std::path::PathBuf;
use tracing::subscriber::SetGlobalDefaultError;
use transformrs::Provider;
// Command-line interface, parsed by clap's derive API.
// NOTE: comments here deliberately use `//`, not `///` — clap turns doc
// comments into `--help` text, which would change the CLI's output.
#[derive(Parser)]
#[command(author, version, about = "Text and image to video")]
struct Arguments {
    // Path to the input file (a Typst `.typ` document; see
    // `copy_input_with_includes`, which copies it to `<out_dir>/input.typ`).
    #[arg(long)]
    input: String,
    // Enable DEBUG-level logging (default is INFO).
    #[arg(long)]
    verbose: bool,
    // TTS provider name, e.g. "google", "deepinfra", or
    // "openai-compatible(<domain>)". Defaults to DeepInfra when absent.
    #[arg(long)]
    provider: Option<String>,
    // Model identifier forwarded to the audio generation step.
    #[arg(long)]
    model: Option<String>,
    // Voice identifier for text-to-speech.
    #[arg(long, default_value = "am_adam")]
    voice: String,
    // Speech speed multiplier; provider default when unset.
    #[arg(long)]
    speed: Option<f32>,
    // Audio container/format requested from the TTS provider
    // (falls back to "mp3" in `main` when unset).
    #[arg(long)]
    audio_format: Option<String>,
    // BCP-47-style language code forwarded to the TTS provider.
    // NOTE(review): exact accepted values depend on the provider — confirm.
    #[arg(long)]
    language_code: Option<String>,
    // Directory where intermediate and final artifacts are written.
    #[arg(long, default_value = "_out")]
    out_dir: String,
    // Reuse previously generated audio/video artifacts when possible.
    #[arg(long, default_value = "true")]
    cache: bool,
    // Additionally produce a distributable "release.mp4".
    #[arg(long, default_value = "false")]
    release: bool,
    // Audio codec used for the release encode.
    #[arg(long, default_value = "opus")]
    audio_codec: String,
}
/// Parse a `--provider` CLI value into a [`Provider`].
///
/// Accepts `"google"`, `"deepinfra"`, or `"openai-compatible(<domain>)"`.
/// For the OpenAI-compatible form, a URL scheme is added when missing:
/// `http://` for domains containing "localhost", `https://` otherwise.
/// An explicit `http://` or `https://` scheme is kept untouched.
///
/// # Panics
/// Panics on an unrecognized provider string or on a malformed
/// `openai-compatible(...)` value (missing closing parenthesis).
fn provider_from_str(s: &str) -> Provider {
    if let Some(inner) = s.strip_prefix("openai-compatible(") {
        let domain = inner
            .strip_suffix(')')
            .unwrap_or_else(|| panic!("Malformed provider `{}`: missing closing `)`", s));
        // Bug fix: the old check was `!starts_with("https")`, which re-prefixed
        // an explicit `http://<non-localhost>` as `https://http://...`.
        let domain = if domain.starts_with("http://") || domain.starts_with("https://") {
            domain.to_string()
        } else if domain.contains("localhost") {
            // Local development servers typically speak plain HTTP.
            format!("http://{}", domain)
        } else {
            format!("https://{}", domain)
        };
        Provider::OpenAICompatible(domain)
    } else if s == "google" {
        Provider::Google
    } else if s == "deepinfra" {
        Provider::DeepInfra
    } else {
        panic!("Unsupported provider: {}. Try setting a key like `GOOGLE_KEY` and not passing `--provider`.", s);
    }
}
/// Install a process-wide tracing subscriber that writes plain,
/// timestamp-free lines (no target field) to stdout, filtered by `level`.
///
/// # Errors
/// Returns [`SetGlobalDefaultError`] if a global subscriber was already set.
fn init_subscriber(level: tracing::Level) -> Result<(), SetGlobalDefaultError> {
    let builder = tracing_subscriber::FmtSubscriber::builder()
        .without_time()
        .with_target(false)
        .with_writer(std::io::stdout)
        .with_max_level(level);
    tracing::subscriber::set_global_default(builder.finish())
}
/// Expand `#include "file"` directives in `content`, resolving each included
/// path relative to `input_dir`. Non-directive lines are copied through
/// unchanged; every output line is newline-terminated.
///
/// Included files are inserted verbatim — nested `#include` directives inside
/// them are NOT expanded (single-level expansion only).
///
/// # Panics
/// Panics when a directive has no filename or the included file cannot be read
/// (with a message naming the offending line/path instead of a bare unwrap).
fn include_includes(input_dir: &Path, content: &str) -> String {
    // Expanded output is at least as large as the input.
    let mut output = String::with_capacity(content.len());
    for line in content.lines() {
        if line.starts_with("#include") {
            let include = line
                .split_whitespace()
                .nth(1)
                .unwrap_or_else(|| panic!("#include directive without a filename: `{}`", line))
                .trim_matches('"');
            let include_path = input_dir.join(include);
            tracing::info!("Including file: {}", include_path.display());
            let content = std::fs::read_to_string(&include_path).unwrap_or_else(|e| {
                panic!("Failed to read include `{}`: {}", include_path.display(), e)
            });
            // Re-emit line by line so the inclusion always ends with a newline,
            // even if the included file lacks a trailing one.
            for line in content.lines() {
                output.push_str(line);
                output.push('\n');
            }
        } else {
            output.push_str(line);
            output.push('\n');
        }
    }
    output
}
/// Copy the `input` file into `dir` as `input.typ`, expanding any
/// `#include` directives (resolved relative to `input`'s directory)
/// along the way. Returns the path of the written copy.
fn copy_input_with_includes(dir: &str, input: &str) -> PathBuf {
    let source = Path::new(input);
    let base_dir = source.parent().unwrap();
    let raw = std::fs::read_to_string(source).unwrap();
    let expanded = include_includes(base_dir, &raw);
    let destination = Path::new(dir).join("input.typ");
    std::fs::write(&destination, expanded).unwrap();
    destination
}
#[tokio::main]
async fn main() {
let args = Arguments::parse();
if args.verbose {
init_subscriber(tracing::Level::DEBUG).unwrap();
} else {
init_subscriber(tracing::Level::INFO).unwrap();
}
let dir = &args.out_dir;
let path = Path::new(dir);
if !path.exists() {
std::fs::create_dir_all(path).unwrap();
}
let input = copy_input_with_includes(dir, &args.input);
let provider = args.provider.map(|p| provider_from_str(&p));
let provider = provider.unwrap_or(Provider::DeepInfra);
let mut other = HashMap::new();
if provider != Provider::Google {
other.insert("seed".to_string(), json!(42));
}
let config = transformrs::text_to_speech::TTSConfig {
voice: Some(args.voice.clone()),
output_format: args.audio_format.clone(),
speed: args.speed,
other: Some(other),
language_code: args.language_code.clone(),
};
let slides = slide::slides(input.to_str().unwrap());
if slides.is_empty() {
panic!("No slides found in input file: {}", args.input);
}
image::generate_images(&input, dir);
let audio_ext = config.output_format.clone().unwrap_or("mp3".to_string());
audio::generate_audio_files(
&provider,
dir,
&slides,
args.cache,
&config,
&args.model,
&audio_ext,
)
.await;
let output = "out.mkv";
video::generate_video(dir, &slides, args.cache, &config, output, &audio_ext);
if args.release {
video::generate_release_video(dir, output, "release.mp4", &args.audio_codec);
}
}