use std::collections::HashMap;
use std::io::BufRead;
use std::net::SocketAddr;
use std::path::{Path, PathBuf};
use crate::registry::builtin_provider_defs;
/// Inputs consumed by `write_init_config` and the YAML / `.env` generators.
pub struct InitOptions {
/// Names of built-in providers to configure. Each name is looked up in the
/// built-in provider definitions to find its env-var prefix; when no prefix
/// is found the upper-cased name is used instead.
pub providers: Vec<String>,
/// API key values keyed by provider name (built-in name or custom provider
/// `name`). A missing entry results in an empty value in the `.env` file.
pub api_keys: HashMap<String, String>,
/// User-defined providers emitted after the built-in ones.
pub custom_providers: Vec<CustomProviderInit>,
/// Address written to `server.listen`; `127.0.0.1:8787` is used when `None`.
pub listen_addr: Option<SocketAddr>,
/// Directory that receives `bitrouter.yaml`, `.env`, `.gitignore` and the
/// `run/` and `logs/` sub-directories.
pub home_dir: PathBuf,
}
/// A user-defined provider entry to emit into the generated configuration.
///
/// Derives `Debug` and `Clone` so values can be logged, compared in test
/// output and duplicated by callers building several option sets.
#[derive(Debug, Clone)]
pub struct CustomProviderInit {
    /// Key used for this provider under `providers:` in the generated YAML,
    /// and the lookup key into `InitOptions::api_keys`.
    pub name: String,
    /// Value written verbatim to the provider's `derives:` field.
    pub derives: String,
    /// Value written (double-quoted) to the provider's `api_base:` field.
    pub api_base: String,
    /// Name of the environment variable holding the API key. It is referenced
    /// as `${VAR}` in the YAML and written as `VAR=<value>` in the `.env` file.
    pub env_key_var: String,
}
/// Summary of what `write_init_config` wrote.
///
/// Derives `Debug` so a `Result<InitResult, _>` can be used with
/// `unwrap_err`/`{:?}`, and `Clone` so callers can keep a copy for reporting.
#[derive(Debug, Clone)]
pub struct InitResult {
    /// Path of the generated `bitrouter.yaml`.
    pub config_path: PathBuf,
    /// Path of the generated or merged `.env` file.
    pub env_path: PathBuf,
    /// Built-in provider names followed by custom provider names, in the
    /// order they were supplied.
    pub providers_configured: Vec<String>,
}
pub fn write_init_config(
options: &InitOptions,
overwrite: bool,
) -> crate::error::Result<InitResult> {
std::fs::create_dir_all(&options.home_dir).map_err(|e| {
crate::error::ConfigError::ConfigRead {
path: options.home_dir.clone(),
source: e,
}
})?;
std::fs::create_dir_all(options.home_dir.join("run")).ok();
std::fs::create_dir_all(options.home_dir.join("logs")).ok();
let config_path = options.home_dir.join("bitrouter.yaml");
let env_path = options.home_dir.join(".env");
if overwrite {
backup_if_exists(&config_path);
backup_if_exists(&env_path);
}
let yaml = generate_config_yaml(options);
std::fs::write(&config_path, &yaml).map_err(|e| crate::error::ConfigError::ConfigRead {
path: config_path.clone(),
source: e,
})?;
let env_content = if !overwrite && env_path.exists() {
merge_env_file(&env_path, options)
} else {
generate_env_content(options)
};
std::fs::write(&env_path, &env_content).map_err(|e| crate::error::ConfigError::ConfigRead {
path: env_path.clone(),
source: e,
})?;
let gitignore_path = options.home_dir.join(".gitignore");
if !gitignore_path.exists() {
std::fs::write(&gitignore_path, "logs/\nrun/\n.env\n").ok();
}
let mut providers_configured = options.providers.clone();
providers_configured.extend(options.custom_providers.iter().map(|cp| cp.name.clone()));
Ok(InitResult {
config_path,
env_path,
providers_configured,
})
}
/// Renames `path` to a sibling named `<file name>.bak` when `path` exists.
///
/// Does nothing when `path` is absent. A failed rename is deliberately
/// ignored: the backup is best effort and must not abort initialisation.
fn backup_if_exists(path: &Path) {
    if !path.exists() {
        return;
    }
    // Lossy conversion mirrors how the name is rendered into the backup name.
    let original_name = path.file_name().unwrap_or_default().to_string_lossy();
    let backup = path.with_file_name(format!("{original_name}.bak"));
    let _ = std::fs::rename(path, backup);
}
fn generate_config_yaml(options: &InitOptions) -> String {
let defs = builtin_provider_defs();
let mut yaml = String::from(
"# BitRouter configuration\n\
# Generated by `bitrouter init`\n\n",
);
let listen = options
.listen_addr
.unwrap_or_else(|| "127.0.0.1:8787".parse().unwrap());
yaml.push_str(&format!("server:\n listen: \"{listen}\"\n\n"));
let has_providers = !options.providers.is_empty() || !options.custom_providers.is_empty();
if has_providers {
yaml.push_str("providers:\n");
for name in &options.providers {
let fallback = name.to_uppercase();
let prefix = defs
.get(name)
.and_then(|bp| bp.config.env_prefix.as_deref())
.unwrap_or(&fallback);
yaml.push_str(&format!(
" {name}:\n api_key: \"${{{prefix}_API_KEY}}\"\n\n"
));
}
for cp in &options.custom_providers {
yaml.push_str(&format!(
" {}:\n derives: {}\n api_base: \"{}\"\n api_key: \"${{{}}}\"\n\n",
cp.name, cp.derives, cp.api_base, cp.env_key_var,
));
}
}
yaml
}
/// Renders a fresh `.env` file for `options`.
///
/// After a two-line comment header, one `VAR=value` line is written per
/// built-in provider (variable `<PREFIX>_API_KEY`) followed by one per custom
/// provider (variable `env_key_var`). Providers without an entry in
/// `api_keys` get an empty value.
fn generate_env_content(options: &InitOptions) -> String {
    let defs = builtin_provider_defs();
    let mut content = String::from(
        "# BitRouter environment variables\n\
         # This file is ignored by git.\n\n",
    );

    for name in &options.providers {
        // Prefer the prefix declared by the built-in definition, else the
        // upper-cased provider name.
        let key_var = match defs
            .get(name)
            .and_then(|bp| bp.config.env_prefix.as_deref())
        {
            Some(prefix) => format!("{prefix}_API_KEY"),
            None => format!("{}_API_KEY", name.to_uppercase()),
        };
        let key_value = options.api_keys.get(name).map_or("", String::as_str);
        content.push_str(&key_var);
        content.push('=');
        content.push_str(key_value);
        content.push('\n');
    }

    for cp in &options.custom_providers {
        let key_value = match options.api_keys.get(&cp.name) {
            Some(value) => value.as_str(),
            None => "",
        };
        content.push_str(&cp.env_key_var);
        content.push('=');
        content.push_str(key_value);
        content.push('\n');
    }

    content
}
/// Merges the variables required by `options` into the existing `.env` file
/// at `env_path` and returns the new file contents.
///
/// Rules:
/// - every existing line is preserved, in order;
/// - a line `KEY=` whose value is empty (after trimming) and whose key is one
///   of the required variables is replaced by `KEY=<new value>`;
/// - a required variable that already has a non-empty value is left untouched;
/// - required variables that do not appear in the file are appended, built-in
///   providers first, then custom providers, each variable at most once;
/// - the result always ends with a newline.
///
/// If the file cannot be opened, only the required variables are returned.
///
/// Lines are read as raw bytes and decoded lossily, so a line that is not
/// valid UTF-8 no longer causes every following line to be dropped from the
/// rewritten file.
fn merge_env_file(env_path: &Path, options: &InitOptions) -> String {
    let defs = builtin_provider_defs();

    // Required variables: `new_vars` holds the value (last writer wins when two
    // providers map to the same variable), `ordered_keys` the first-seen order
    // so the append phase below is deterministic and free of duplicates.
    let mut new_vars: HashMap<String, String> = HashMap::new();
    let mut ordered_keys: Vec<String> = Vec::new();
    let mut register = |key_var: String, provider: &str| {
        let value = options.api_keys.get(provider).cloned().unwrap_or_default();
        if new_vars.insert(key_var.clone(), value).is_none() {
            ordered_keys.push(key_var);
        }
    };
    for name in &options.providers {
        let fallback = name.to_uppercase();
        let prefix = defs
            .get(name)
            .and_then(|bp| bp.config.env_prefix.as_deref())
            .unwrap_or(&fallback);
        register(format!("{prefix}_API_KEY"), name.as_str());
    }
    for cp in &options.custom_providers {
        register(cp.env_key_var.clone(), cp.name.as_str());
    }

    let mut lines: Vec<String> = Vec::new();
    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();

    if let Ok(file) = std::fs::File::open(env_path) {
        let reader = std::io::BufReader::new(file);
        // Split on bytes rather than using `lines()`: `lines()` yields an error
        // for invalid UTF-8, which would end the iteration and truncate the file.
        for raw in reader.split(b'\n').map_while(Result::ok) {
            let mut line = String::from_utf8_lossy(&raw).into_owned();
            if line.ends_with('\r') {
                line.pop();
            }

            // `Some(text)` only when this line is a required variable whose
            // current value is empty and must therefore be filled in.
            let replacement = line.trim().split_once('=').and_then(|(key, existing)| {
                let key = key.trim();
                let new_value = new_vars.get(key)?;
                seen.insert(key.to_owned());
                existing
                    .trim()
                    .is_empty()
                    .then(|| format!("{key}={new_value}"))
            });
            lines.push(replacement.unwrap_or(line));
        }
    }

    // Append whatever the existing file did not mention.
    for key_var in &ordered_keys {
        if seen.contains(key_var) {
            continue;
        }
        let value = new_vars.get(key_var).map_or("", String::as_str);
        lines.push(format!("{key_var}={value}"));
    }

    let mut result = lines.join("\n");
    if !result.ends_with('\n') {
        result.push('\n');
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds options for built-in providers only: one provider per
    /// `(name, api_key)` pair, an optional listen address, and `/tmp` as home.
    fn options_with(keys: &[(&str, &str)], listen: Option<&str>) -> InitOptions {
        InitOptions {
            providers: keys.iter().map(|(name, _)| name.to_string()).collect(),
            api_keys: keys
                .iter()
                .map(|(name, key)| (name.to_string(), key.to_string()))
                .collect(),
            listen_addr: listen.map(|addr| addr.parse().unwrap()),
            custom_providers: Vec::new(),
            home_dir: PathBuf::from("/tmp"),
        }
    }

    /// The generated document mentions the provider and parses as YAML.
    #[test]
    fn generates_valid_yaml() {
        let options = options_with(&[("openai", "sk-test")], None);
        let yaml = generate_config_yaml(&options);
        for needle in ["providers:", "openai:", "${OPENAI_API_KEY}"] {
            assert!(yaml.contains(needle));
        }
        let _: serde_yaml::Value = serde_yaml::from_str(&yaml).unwrap();
    }

    /// Each provider's key ends up on its own `VAR=value` line.
    #[test]
    fn generates_env_content() {
        let options = options_with(
            &[("openai", "sk-test"), ("anthropic", "sk-ant-test")],
            None,
        );
        let env = generate_env_content(&options);
        for expected in ["OPENAI_API_KEY=sk-test", "ANTHROPIC_API_KEY=sk-ant-test"] {
            assert!(env.contains(expected));
        }
    }

    /// The generated YAML loads through the real config loader and keeps the
    /// requested listen address and provider.
    #[test]
    fn yaml_round_trips_through_config() {
        let options = options_with(&[("openai", "sk-test")], Some("127.0.0.1:9090"));
        let yaml = generate_config_yaml(&options);
        let config = crate::config::BitrouterConfig::load_from_str(&yaml, None).unwrap();
        assert_eq!(config.server.listen, "127.0.0.1:9090".parse().unwrap());
        assert!(config.providers.contains_key("openai"));
    }
}