use std::error::Error;
use std::fmt::Write as _;
use std::fs;
use std::net::TcpStream;
use std::path::{Path, PathBuf};
use std::process::{Child, Command};
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use clap::{Args as ClapArgs, ValueEnum};
use serde_json::Value;
use toml_edit::{value as toml_value, DocumentMut, Item, Table};
use crate::seed::{
client_integrations as seed_client_integrations, ClientIntegration, ConfigFormat,
ModelArgPosition,
};
use crate::DEFAULT_MODEL;
/// Base URL used for `--base-url` when the flag is not supplied.
const DEFAULT_BASE_URL: &str = "http://127.0.0.1:8080";
/// Contents written to the backup file when no config file existed before the
/// first global write; on undo, a backup equal to this sentinel causes the
/// config file to be deleted instead of restored.
const EMPTY_BACKUP_SENTINEL: &str = "# formal-ai-empty-config-backup-v1\n";
// Protocols that can be requested with `--protocol`.
// NOTE: plain `//` comments are used on purpose here; clap's derive macros
// pick up `///` doc comments as help text, which would change CLI output.
#[derive(Debug, Clone, Copy, ValueEnum)]
pub enum ClientProtocol {
Openai,
Gemini,
Vertex,
Anthropic,
}
impl ClientProtocol {
/// Returns the lowercase identifier for this protocol; `render_context`
/// compares it against an integration's supported protocol names.
const fn as_str(self) -> &'static str {
match self {
Self::Openai => "openai",
Self::Gemini => "gemini",
Self::Vertex => "vertex",
Self::Anthropic => "anthropic",
}
}
}
// Command-line arguments for launching a client tool against formal-ai, or for
// writing / undoing that tool's global configuration.
// NOTE: plain `//` comments are used on purpose; clap's derive macros turn
// `///` doc comments into help text, which would change `--help` output.
#[derive(Debug, Clone, ClapArgs)]
#[command(trailing_var_arg = true)]
#[allow(clippy::struct_excessive_bools)]
pub struct WithFormalAiArgs {
// Write settings into the tool's global config file instead of launching it.
#[arg(short = 'g', long = "global", default_value_t = false)]
pub global: bool,
// Restore the tool's global config from the formal-ai backup; takes
// precedence over `global` when both are set.
#[arg(long, default_value_t = false)]
pub undo: bool,
// Apply `--global` / `--undo` to every known integration; rejected otherwise.
#[arg(long, default_value_t = false)]
pub all: bool,
// Base URL of the formal-ai server.
#[arg(long, default_value = DEFAULT_BASE_URL)]
pub base_url: String,
// When set, replaces the port in `base_url` and is used when starting the server.
#[arg(long)]
pub port: Option<u16>,
// Start a local server before launching the tool if nothing accepts
// connections on the target address.
#[arg(long, default_value_t = false)]
pub start_server: bool,
// Overrides the integration's default protocol; must be one it supports.
#[arg(long, value_enum)]
pub protocol: Option<ClientProtocol>,
// Model name substituted for `{model}` in templates.
#[arg(long, default_value = DEFAULT_MODEL)]
pub model: String,
// Identifier of the integration to run or configure.
#[arg(value_name = "TOOL")]
pub tool: Option<String>,
// Remaining arguments, forwarded to the tool when it is launched.
#[arg(
value_name = "ARGS",
allow_hyphen_values = true,
trailing_var_arg = true
)]
pub tool_args: Vec<String>,
}
/// Values substituted for `{placeholder}` tokens by `render_template`.
#[derive(Debug, Clone)]
struct RenderContext {
/// Server base URL after trimming and the optional port override.
base_url: String,
/// `base_url` joined with the endpoint path of the selected protocol.
endpoint_base_url: String,
/// Provider identifier copied from the integration.
provider_id: String,
/// Model name taken from the command line.
model: String,
/// The integration's rendered `model_selector` template, or `model` when
/// that template is empty.
model_selector: String,
/// Name of the environment variable the integration reads its API key from.
api_key_env: String,
/// API key resolved from the environment, falling back to the integration default.
api_key: String,
/// Name of the base-URL environment variable chosen for the selected protocol.
protocol_base_env: String,
}
/// A uniquely named directory under the system temp dir that is removed
/// (best effort) when the value is dropped.
struct TempConfigDir {
    path: PathBuf,
}

impl TempConfigDir {
    /// Creates a fresh directory named after `tool`, the process id and the
    /// current time in nanoseconds.
    ///
    /// # Errors
    /// Fails if the system clock is before the Unix epoch, or if the directory
    /// cannot be created — including when a directory with that name already
    /// exists.
    fn new(tool: &str) -> Result<Self, Box<dyn Error>> {
        let nanos = SystemTime::now().duration_since(UNIX_EPOCH)?.as_nanos();
        let path = std::env::temp_dir().join(format!(
            "formal-ai-{tool}-config-{}-{nanos}",
            std::process::id()
        ));
        // The temp root itself may be missing; creating it is harmless.
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        // Non-recursive create fails if the path already exists, so a
        // directory (or symlink) planted under this predictable name is never
        // adopted. Rendered settings written here can embed the API key, so on
        // Unix the directory is restricted to its owner.
        #[cfg_attr(not(unix), allow(unused_mut))]
        let mut builder = fs::DirBuilder::new();
        #[cfg(unix)]
        {
            use std::os::unix::fs::DirBuilderExt;
            builder.mode(0o700);
        }
        builder.create(&path)?;
        Ok(Self { path })
    }
}

impl Drop for TempConfigDir {
    fn drop(&mut self) {
        // Best effort: cleanup failures must not mask the tool's own result.
        let _ = fs::remove_dir_all(&self.path);
    }
}
/// Owns a spawned server process and stops it when the guard goes out of scope.
struct ServerGuard {
    child: Child,
}

impl Drop for ServerGuard {
    fn drop(&mut self) {
        let Self { child } = self;
        // Both results are ignored deliberately: the process may already have
        // exited, and there is nothing useful to do with a failure here.
        let _kill_result = child.kill();
        // Reap the process after the kill attempt.
        let _wait_result = child.wait();
    }
}
/// Entry point for the command.
///
/// With `--global` or `--undo`, writes or restores the global config of the
/// selected integration(s) and returns without launching anything (`--undo`
/// wins when both are set). Otherwise launches the named tool with formal-ai
/// settings applied, optionally starting a local server first.
///
/// # Errors
/// Returns an error when `--all` is used outside `--global`/`--undo`, when no
/// tool is named, when the tool is unknown, or when any underlying step fails.
pub fn run_with_formal_ai(args: &WithFormalAiArgs) -> Result<(), Box<dyn Error>> {
    let integrations = seed_client_integrations();
    if args.global || args.undo {
        for integration in select_integrations(args, &integrations)? {
            if args.undo {
                undo_global_config(integration)?;
            } else {
                write_global_config(integration, args)?;
            }
        }
        return Ok(());
    }
    if args.all {
        return Err("--all is only valid with --global or --undo".into());
    }
    let Some(tool) = args.tool.as_deref() else {
        // List the tools that are actually registered rather than a fixed set,
        // so the hint stays accurate when integrations are added or removed.
        let supported = integrations
            .iter()
            .map(|integration| integration.id.as_str())
            .collect::<Vec<_>>()
            .join(", ");
        return Err(format!("missing tool; expected one of: {supported}").into());
    };
    let integration = find_integration(tool, &integrations)?;
    let context = render_context(integration, args)?;
    // Held until the tool exits so a server started here is stopped afterwards.
    let _server = if args.start_server {
        maybe_start_server(&context.base_url, args.port)?
    } else {
        None
    };
    run_ephemeral(integration, &args.tool_args, &context)
}
/// Picks the integrations a global operation applies to: all of them when
/// `--all` is set, otherwise the single one named by the tool argument.
///
/// # Errors
/// Fails when neither `--all` nor a tool name was given, or the name is unknown.
fn select_integrations<'a>(
    args: &WithFormalAiArgs,
    integrations: &'a [ClientIntegration],
) -> Result<Vec<&'a ClientIntegration>, Box<dyn Error>> {
    if args.all {
        return Ok(integrations.iter().collect());
    }
    match args.tool.as_deref() {
        Some(tool) => find_integration(tool, integrations).map(|found| vec![found]),
        None => Err("missing tool; pass a tool name or --all".into()),
    }
}
/// Looks up the integration whose id equals `tool`.
///
/// # Errors
/// Returns an error listing every known integration id when there is no match.
fn find_integration<'a>(
    tool: &str,
    integrations: &'a [ClientIntegration],
) -> Result<&'a ClientIntegration, Box<dyn Error>> {
    for candidate in integrations {
        if candidate.id == tool {
            return Ok(candidate);
        }
    }
    let known: Vec<&str> = integrations
        .iter()
        .map(|candidate| candidate.id.as_str())
        .collect();
    let supported = known.join(", ");
    Err(format!("unsupported tool `{tool}`; supported tools: {supported}").into())
}
/// Builds the substitution context for one integration from the CLI arguments.
///
/// The protocol is the `--protocol` value or the integration's default. The
/// API key is the first non-empty value among the integration's own
/// environment variable, `FORMAL_AI_API_KEY`, and the integration default.
///
/// # Errors
/// Fails when the integration does not support the protocol or has no
/// endpoint path for it.
fn render_context(
    integration: &ClientIntegration,
    args: &WithFormalAiArgs,
) -> Result<RenderContext, Box<dyn Error>> {
    let protocol = match args.protocol {
        Some(chosen) => chosen.as_str(),
        None => integration.default_protocol.as_str(),
    };
    let is_supported = integration
        .supported_protocols
        .iter()
        .any(|supported| supported == protocol);
    if !is_supported {
        return Err(format!("{} does not support protocol `{protocol}`", integration.id).into());
    }
    let Some(endpoint_path) = integration.endpoint_path_for(protocol) else {
        return Err(format!("{} has no endpoint for {protocol}", integration.id).into());
    };

    let base_url = base_url_with_port(&args.base_url, args.port);
    let endpoint_base_url = join_url_path(&base_url, endpoint_path);

    // An environment variable that is set but empty counts as unset.
    let non_empty_env = |name: &str| std::env::var(name).ok().filter(|value| !value.is_empty());
    let api_key = non_empty_env(integration.api_key_env.as_str())
        .or_else(|| non_empty_env("FORMAL_AI_API_KEY"))
        .unwrap_or_else(|| integration.api_key_default.clone());

    let protocol_base_env = String::from(match protocol {
        "vertex" => "GOOGLE_VERTEX_BASE_URL",
        "gemini" => "GOOGLE_GEMINI_BASE_URL",
        "openai" => "OPENAI_BASE_URL",
        "anthropic" => "ANTHROPIC_BASE_URL",
        _ => "FORMAL_AI_BASE_URL",
    });

    let mut context = RenderContext {
        base_url,
        endpoint_base_url,
        provider_id: integration.provider_id.clone(),
        model: args.model.clone(),
        // Filled in below; the selector template is rendered while this is
        // still empty.
        model_selector: String::new(),
        api_key_env: integration.api_key_env.clone(),
        api_key,
        protocol_base_env,
    };
    let selector = if integration.model_selector.is_empty() {
        context.model.clone()
    } else {
        render_template(&integration.model_selector, &context)
    };
    context.model_selector = selector;
    Ok(context)
}
/// Runs the integration's command once with formal-ai settings supplied through
/// environment variables, an optional temporary JSON config, and arguments.
///
/// # Errors
/// Fails if the temporary config cannot be written, the command cannot be
/// started, or it exits unsuccessfully (reported with its exit code, or
/// `signal` when there is none).
fn run_ephemeral(
    integration: &ClientIntegration,
    user_args: &[String],
    context: &RenderContext,
) -> Result<(), Box<dyn Error>> {
    let invocation = &integration.invocation;
    let mut command = Command::new(&integration.command);
    for variable in &invocation.env {
        let key = render_template(&variable.key, context);
        let value = render_template(&variable.value, context);
        command.env(key, value);
    }

    // Kept alive until the child has exited so the config stays on disk.
    let temp_config: Option<TempConfigDir> = if invocation.config_json_settings.is_empty() {
        None
    } else {
        let temp = TempConfigDir::new(&integration.id)?;
        let config_path = temp.path.join(format!("{}.json", integration.id));
        let rendered = render_json_settings(&invocation.config_json_settings, context)?;
        fs::write(&config_path, rendered)?;
        if !invocation.config_env.is_empty() {
            command.env(&invocation.config_env, &config_path);
        }
        if !invocation.config_dir_env.is_empty() {
            command.env(&invocation.config_dir_env, &temp.path);
        }
        Some(temp)
    };

    command.args(build_invocation_args(integration, user_args, context));
    let status = command.status()?;
    drop(temp_config);

    if !status.success() {
        let outcome = match status.code() {
            Some(code) => code.to_string(),
            None => String::from("signal"),
        };
        return Err(format!("{} exited with status {}", integration.command, outcome).into());
    }
    Ok(())
}
/// Assembles the argument list passed to the tool: the integration's rendered
/// `prepend_args` and `args`, then the user's arguments, with the model flag
/// and value inserted unless the integration has no model flag or the user
/// already passed one.
///
/// With `ModelArgPosition::AfterFirstArg` and at least one user argument, the
/// model flag goes right after the first user argument; otherwise it goes
/// before all of them.
fn build_invocation_args(
    integration: &ClientIntegration,
    user_args: &[String],
    context: &RenderContext,
) -> Vec<String> {
    let invocation = &integration.invocation;
    let mut args: Vec<String> = Vec::new();
    for template in invocation.prepend_args.iter().chain(invocation.args.iter()) {
        args.push(render_template(template, context));
    }

    let inject_model = !invocation.model_arg.is_empty() && !contains_model_arg(user_args);
    if !inject_model {
        args.extend_from_slice(user_args);
        return args;
    }

    let flag = render_template(&invocation.model_arg, context);
    let selector = context.model_selector.clone();
    let after_first = matches!(
        invocation.model_arg_position,
        Some(ModelArgPosition::AfterFirstArg)
    );
    // Split the user's arguments into what precedes and follows the model flag.
    let (leading, trailing) = match user_args.split_first() {
        Some((first, rest)) if after_first => (Some(first), rest),
        _ => (None, user_args),
    };
    args.extend(leading.cloned());
    args.push(flag);
    args.push(selector);
    args.extend_from_slice(trailing);
    args
}
/// Reports whether the arguments already select a model via `-m`, `--model`
/// or `--model=<value>`.
fn contains_model_arg(args: &[String]) -> bool {
    for arg in args {
        let text = arg.as_str();
        let is_bare_flag = text == "-m" || text == "--model";
        if is_bare_flag || text.starts_with("--model=") {
            return true;
        }
    }
    false
}
/// Merges the integration's formal-ai settings into its global config file.
///
/// The current file is read and merged first; only once that has succeeded is
/// a backup ensured and the file rewritten. Prints whether the file was
/// changed or was already configured.
///
/// # Errors
/// Fails if the context cannot be rendered, the config path cannot be
/// resolved, the existing file exists but cannot be read or parsed, or the
/// backup or config cannot be written.
fn write_global_config(
    integration: &ClientIntegration,
    args: &WithFormalAiArgs,
) -> Result<(), Box<dyn Error>> {
    let context = render_context(integration, args)?;
    let path = global_config_path(&integration.global_config.path)?;
    let backup_path = backup_path(&path, &integration.global_config.backup_suffix);
    // Only a missing file counts as empty. Any other read failure (permissions,
    // invalid UTF-8, ...) must stop here; treating it as empty would replace
    // the user's config with just the managed settings.
    let existing = match fs::read_to_string(&path) {
        Ok(contents) => contents,
        Err(error) if error.kind() == std::io::ErrorKind::NotFound => String::new(),
        Err(error) => {
            return Err(format!("cannot read {}: {error}", path.display()).into());
        }
    };
    let next = match integration.global_config.format {
        ConfigFormat::Toml => merge_toml_config(integration, &existing, &context)?,
        ConfigFormat::Json => merge_json_config(integration, &existing, &context)?,
        ConfigFormat::ShellEnv => merge_shell_env_config(integration, &existing, &context),
    };
    // Back up only after the merge succeeded, so a failed run does not leave a
    // backup that a later `--undo` would restore over newer user edits.
    ensure_backup(&path, &backup_path)?;
    if next == existing {
        println!(
            "{} already configured at {}",
            integration.id,
            path.display()
        );
    } else {
        write_file(&path, &next)?;
        println!("configured {} at {}", integration.id, path.display());
    }
    Ok(())
}
/// Restores an integration's global config from its formal-ai backup.
///
/// Without a backup this only prints a notice. A backup holding the
/// empty-config sentinel means no config existed originally, so the config
/// file is deleted (a missing file is fine); otherwise the backup contents are
/// written back. The backup file is removed afterwards.
///
/// # Errors
/// Fails if the path cannot be resolved or any read, write or removal fails.
fn undo_global_config(integration: &ClientIntegration) -> Result<(), Box<dyn Error>> {
    let path = global_config_path(&integration.global_config.path)?;
    let backup_path = backup_path(&path, &integration.global_config.backup_suffix);
    if !backup_path.exists() {
        println!(
            "no formal-ai backup for {} at {}",
            integration.id,
            path.display()
        );
        return Ok(());
    }

    let backup = fs::read_to_string(&backup_path)?;
    if backup != EMPTY_BACKUP_SENTINEL {
        write_file(&path, &backup)?;
    } else if let Err(error) = fs::remove_file(&path) {
        // The file having vanished already is the desired end state.
        if error.kind() != std::io::ErrorKind::NotFound {
            return Err(error.into());
        }
    }

    fs::remove_file(&backup_path)?;
    println!("restored {} from {}", integration.id, backup_path.display());
    Ok(())
}
/// Creates the backup for `path` unless one already exists.
///
/// An existing config is copied to `backup_path`; when there is no config, the
/// empty-config sentinel is written instead. An existing backup is never
/// touched.
///
/// # Errors
/// Fails if the backup's parent directory or the backup itself cannot be written.
fn ensure_backup(path: &Path, backup_path: &Path) -> Result<(), Box<dyn Error>> {
    if backup_path.exists() {
        return Ok(());
    }
    if let Some(parent) = backup_path.parent() {
        fs::create_dir_all(parent)?;
    }
    if !path.exists() {
        fs::write(backup_path, EMPTY_BACKUP_SENTINEL)?;
        return Ok(());
    }
    fs::copy(path, backup_path)?;
    Ok(())
}
/// Applies the integration's TOML settings on top of `existing` and returns the
/// resulting document text, newline-terminated.
///
/// Blank input starts from an empty document. Setting paths and values are
/// rendered against `context` before being applied.
///
/// # Errors
/// Fails if `existing` is not valid TOML or a setting path is empty.
fn merge_toml_config(
    integration: &ClientIntegration,
    existing: &str,
    context: &RenderContext,
) -> Result<String, Box<dyn Error>> {
    let mut document = match existing.trim() {
        "" => DocumentMut::new(),
        _ => existing.parse::<DocumentMut>()?,
    };
    for (path, value) in &integration.global_config.toml_settings {
        let dotted_path = render_template(path, context);
        let rendered = render_template(value, context);
        set_toml_string(document.as_table_mut(), &dotted_path, &rendered)?;
    }
    let serialized = document.to_string();
    Ok(ensure_trailing_newline(serialized))
}
/// Stores `value` as a string at `dotted_path` inside `table`, creating the
/// intermediate tables along the way.
///
/// Path segments are split on `.`, trimmed, and empty segments are ignored.
///
/// # Errors
/// Fails when the path has no non-empty segment.
fn set_toml_string(
    table: &mut Table,
    dotted_path: &str,
    value: &str,
) -> Result<(), Box<dyn Error>> {
    let mut segments: Vec<&str> = dotted_path
        .split('.')
        .map(str::trim)
        .filter(|segment| !segment.is_empty())
        .collect();
    // The final segment is the key; everything before it names parent tables.
    let Some(key) = segments.pop() else {
        return Err("empty TOML setting path".into());
    };
    let parent = table_at_path_mut(table, &segments);
    parent[key] = toml_value(value);
    Ok(())
}
/// Walks `parts` from `table`, returning the table at the end of the path.
///
/// Missing entries are created as empty tables, and an entry for which
/// `is_table()` is false is replaced by a new empty table.
// NOTE(review): the replacement discards whatever the entry held before; if
// `is_table()` is false for inline tables (`key = { ... }`), an existing inline
// table on the path would be dropped — confirm against toml_edit and decide
// whether that is intended.
fn table_at_path_mut<'a>(mut table: &'a mut Table, parts: &[&str]) -> &'a mut Table {
for part in parts {
let item = table
.entry(part)
.or_insert_with(|| Item::Table(Table::new()));
if !item.is_table() {
*item = Item::Table(Table::new());
}
// Cannot fail: the item was either already a table or was just replaced by one.
table = item.as_table_mut().expect("table item");
}
table
}
fn merge_json_config(
integration: &ClientIntegration,
existing: &str,
context: &RenderContext,
) -> Result<String, Box<dyn Error>> {
let mut base = if existing.trim().is_empty() {
Value::Object(serde_json::Map::new())
} else {
serde_json::from_str(existing)?
};
let overlay = json_settings_value(&integration.global_config.json_settings, context)?;
merge_json_value(&mut base, overlay);
Ok(format!("{}\n", serde_json::to_string_pretty(&base)?))
}
/// Renders `settings` into a standalone pretty-printed JSON document ending
/// with a newline.
///
/// # Errors
/// Fails if a setting path is invalid or serialization fails.
fn render_json_settings(
    settings: &[(String, String)],
    context: &RenderContext,
) -> Result<String, Box<dyn Error>> {
    let document = json_settings_value(settings, context)?;
    let mut output = serde_json::to_string_pretty(&document)?;
    output.push('\n');
    Ok(output)
}
/// Builds a JSON object from `(dotted path, value template)` pairs, applying
/// them in order.
///
/// # Errors
/// Propagates the first failure from `set_json_string`.
fn json_settings_value(
    settings: &[(String, String)],
    context: &RenderContext,
) -> Result<Value, Box<dyn Error>> {
    let mut root = Value::Object(serde_json::Map::new());
    settings
        .iter()
        .try_for_each(|(path, template)| set_json_string(&mut root, path, template, context))?;
    Ok(root)
}
/// Inserts the rendered `value` as a JSON string at `dotted_path` under `root`.
///
/// The path is split on `.`; segments are trimmed, empty ones dropped, and each
/// remaining segment is rendered against `context`. Missing intermediate keys
/// are created as empty objects.
///
/// # Errors
/// Fails when the path has no non-empty segment, or when `root` or an
/// intermediate value on the path is not a JSON object.
fn set_json_string(
root: &mut Value,
dotted_path: &str,
value: &str,
context: &RenderContext,
) -> Result<(), Box<dyn Error>> {
let parts = dotted_path
.split('.')
.map(str::trim)
.filter(|part| !part.is_empty())
.map(|part| render_template(part, context))
.collect::<Vec<_>>();
// The last segment is the key to set; the preceding ones are parent objects.
let Some((last, parents)) = parts.split_last() else {
return Err("empty JSON setting path".into());
};
// Descend one object per parent segment, creating objects as needed.
let mut current = root;
for part in parents {
let object = current
.as_object_mut()
.ok_or("JSON setting path conflicts with a scalar value")?;
current = object
.entry(part.clone())
.or_insert_with(|| Value::Object(serde_json::Map::new()));
}
let object = current
.as_object_mut()
.ok_or("JSON setting path conflicts with a scalar value")?;
// Replaces any value already stored under the final key.
object.insert(last.clone(), Value::String(render_template(value, context)));
Ok(())
}
/// Recursively merges `overlay` into `base`.
///
/// When both are objects, keys are merged one by one (recursing into keys
/// present on both sides); in every other case `overlay` replaces `base`.
fn merge_json_value(base: &mut Value, overlay: Value) {
    let incoming = match overlay {
        Value::Object(map) if base.is_object() => map,
        replacement => {
            *base = replacement;
            return;
        }
    };
    // `base` is known to be an object here because of the guard above.
    let Some(target) = base.as_object_mut() else {
        return;
    };
    for (key, value) in incoming {
        if let Some(slot) = target.get_mut(&key) {
            merge_json_value(slot, value);
        } else {
            target.insert(key, value);
        }
    }
}
/// Rewrites a shell file so it ends with a fresh managed block of `export`
/// lines for this integration.
///
/// Any previous managed block for the same integration is removed first. The
/// new block sits between `# >>> formal-ai <id>` and `# <<< formal-ai <id>`
/// marker lines.
fn merge_shell_env_config(
    integration: &ClientIntegration,
    existing: &str,
    context: &RenderContext,
) -> String {
    let mut next = remove_managed_block(existing, &integration.id);
    let needs_separator = !(next.is_empty() || next.ends_with('\n'));
    if needs_separator {
        next.push('\n');
    }
    // Writing into a `String` cannot fail, so the results are discarded.
    let _ = writeln!(next, "# >>> formal-ai {}", integration.id);
    for variable in &integration.global_config.shell_env {
        let key = render_template(&variable.key, context);
        let quoted = shell_double_quote(&render_template(&variable.value, context));
        let _ = writeln!(next, "export {key}={quoted}");
    }
    let _ = writeln!(next, "# <<< formal-ai {}", integration.id);
    next
}
/// Returns `existing` without the managed blocks for `tool`.
///
/// A managed block runs from a `# >>> formal-ai <tool>` line to the next
/// `# <<< formal-ai <tool>` line, inclusive. Every kept line is emitted with a
/// trailing `\n`.
///
/// If a start marker is never closed (or another start marker appears first),
/// only the dangling marker line is dropped and the lines after it are kept, so
/// a damaged or hand-edited file does not lose the rest of its contents.
fn remove_managed_block(existing: &str, tool: &str) -> String {
    let start = format!("# >>> formal-ai {tool}");
    let end = format!("# <<< formal-ai {tool}");
    let mut out = String::with_capacity(existing.len());
    // Body lines of the block currently open; restored if it never closes.
    let mut held: Vec<&str> = Vec::new();
    let mut skipping = false;
    for line in existing.lines() {
        if line == start {
            // A start marker while a block is open means the earlier block was
            // never terminated: keep its body.
            for kept in held.drain(..) {
                out.push_str(kept);
                out.push('\n');
            }
            skipping = true;
            continue;
        }
        if skipping {
            if line == end {
                held.clear();
                skipping = false;
            } else {
                held.push(line);
            }
            continue;
        }
        out.push_str(line);
        out.push('\n');
    }
    // Reached the end of input inside an unterminated block: keep its body.
    for kept in held {
        out.push_str(kept);
        out.push('\n');
    }
    out
}
/// Wraps `value` in double quotes for use in a POSIX shell assignment.
///
/// Backslash, double quote, dollar sign and backtick are backslash-escaped.
/// These are the characters that stay special inside double quotes, so the
/// value is taken literally (no parameter expansion or command substitution)
/// when the file is sourced.
fn shell_double_quote(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for ch in value.chars() {
        if matches!(ch, '\\' | '"' | '$' | '`') {
            quoted.push('\\');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
/// Expands the `{placeholder}` tokens in `template` using values from `context`.
///
/// Substitutions are applied one after another in a fixed order; unknown
/// placeholders are left untouched.
fn render_template(template: &str, context: &RenderContext) -> String {
    let substitutions: [(&str, &str); 8] = [
        ("{provider_id}", context.provider_id.as_str()),
        ("{model}", context.model.as_str()),
        ("{model_selector}", context.model_selector.as_str()),
        ("{endpoint_base_url}", context.endpoint_base_url.as_str()),
        ("{base_url}", context.base_url.as_str()),
        ("{api_key_env}", context.api_key_env.as_str()),
        ("{api_key}", context.api_key.as_str()),
        ("{protocol_base_env}", context.protocol_base_env.as_str()),
    ];
    let mut rendered = template.to_string();
    for (placeholder, replacement) in substitutions {
        rendered = rendered.replace(placeholder, replacement);
    }
    rendered
}
/// Resolves a config path: absolute paths are returned unchanged, relative
/// ones are joined onto the home directory taken from `HOME`, or
/// `USERPROFILE` when `HOME` is not set.
///
/// # Errors
/// Fails for a relative path when neither variable is set.
fn global_config_path(relative: &str) -> Result<PathBuf, Box<dyn Error>> {
    let requested = Path::new(relative);
    if requested.is_absolute() {
        return Ok(requested.to_path_buf());
    }
    match std::env::var_os("HOME").or_else(|| std::env::var_os("USERPROFILE")) {
        Some(home) => Ok(PathBuf::from(home).join(requested)),
        None => Err("HOME is not set; cannot resolve global config path".into()),
    }
}
/// Returns `path` with `suffix` appended directly to its final component
/// (plain concatenation, not an extension change).
fn backup_path(path: &Path, suffix: &str) -> PathBuf {
    let mut raw = path.to_path_buf().into_os_string();
    raw.push(suffix);
    raw.into()
}
/// Writes `contents` to `path`, creating missing parent directories first.
///
/// # Errors
/// Fails if the directories or the file cannot be written.
fn write_file(path: &Path, contents: &str) -> Result<(), Box<dyn Error>> {
    match path.parent() {
        Some(directory) => fs::create_dir_all(directory)?,
        None => {}
    }
    Ok(fs::write(path, contents)?)
}
/// Normalizes `base_url` (surrounding whitespace and trailing slashes removed)
/// and, when `port` is given, rewrites the URL to use that port.
fn base_url_with_port(base_url: &str, port: Option<u16>) -> String {
    let normalized = base_url.trim().trim_end_matches('/');
    match port {
        Some(port) => replace_url_port(normalized, port),
        None => normalized.to_string(),
    }
}
/// Returns `url` with its port set to `port`, keeping the scheme (if any), the
/// host and the path.
///
/// A URL without a `scheme://` prefix is treated as `authority[/path]`, so an
/// existing port is replaced and the new port lands before the path rather
/// than being appended to the end of the string.
fn replace_url_port(url: &str, port: u16) -> String {
    let (prefix, rest) = match url.split_once("://") {
        Some((scheme, rest)) => (format!("{scheme}://"), rest),
        None => (String::new(), url),
    };
    let (authority, path) = rest.split_once('/').unwrap_or((rest, ""));
    // Bracketed IPv6 hosts contain colons, so they need their own handling.
    let host = authority.strip_prefix('[').map_or_else(
        || unbracketed_authority_host(authority),
        |stripped| bracketed_authority_host(authority, stripped),
    );
    if path.is_empty() {
        format!("{prefix}{host}:{port}")
    } else {
        format!("{prefix}{host}:{port}/{path}")
    }
}
/// Extracts the bracketed host from an authority that starts with `[`.
///
/// `stripped` is `authority` without its leading `[`. The text up to the first
/// `]` is returned re-wrapped in brackets; without a closing bracket the whole
/// `authority` is returned unchanged.
fn bracketed_authority_host(authority: &str, stripped: &str) -> String {
    match stripped.find(']') {
        Some(close) => format!("[{}]", &stripped[..close]),
        None => authority.to_string(),
    }
}
/// Returns the authority without its port.
///
/// Userinfo (`user:password@`) is kept and its colon is not mistaken for the
/// port separator. A bracketed IPv6 host following userinfo is kept through
/// its closing bracket.
fn unbracketed_authority_host(authority: &str) -> String {
    // Everything up to the last `@` is userinfo.
    let (userinfo, host_port) = match authority.rsplit_once('@') {
        Some((userinfo, rest)) => (Some(userinfo), rest),
        None => (None, authority),
    };
    let host = if host_port.starts_with('[') {
        match host_port.find(']') {
            Some(close) => &host_port[..=close],
            None => host_port,
        }
    } else {
        host_port.split_once(':').map_or(host_port, |(host, _)| host)
    };
    match userinfo {
        Some(userinfo) => format!("{userinfo}@{host}"),
        None => host.to_string(),
    }
}
/// Appends `endpoint_path` to `base_url` unless the base already ends with
/// that path.
///
/// Trailing slashes on the base and leading slashes on the endpoint are
/// ignored when joining. The "already ends with" check is made on a path
/// segment boundary, so a base ending in `/av1` is not mistaken for one that
/// already ends with the endpoint `v1`.
fn join_url_path(base_url: &str, endpoint_path: &str) -> String {
    let base = base_url.trim_end_matches('/');
    if endpoint_path.is_empty() {
        return base.to_string();
    }
    let suffix = endpoint_path.trim_start_matches('/');
    if !suffix.is_empty() && base.ends_with(format!("/{suffix}").as_str()) {
        return base.to_string();
    }
    format!("{base}/{suffix}")
}
/// Returns `value` with a `\n` appended when it does not already end with one.
fn ensure_trailing_newline(mut value: String) -> String {
    match value.chars().last() {
        Some('\n') => {}
        _ => value.push('\n'),
    }
    value
}
/// Starts `formal-ai serve` for the host and port of `base_url` unless
/// something already accepts TCP connections there.
///
/// Returns `None` when a listener was already present, otherwise a guard that
/// stops the spawned server when dropped.
///
/// # Errors
/// Fails if the URL cannot be parsed, the binary cannot be spawned, or the
/// server does not start listening in time. In the last case the spawned
/// process is stopped before the error is returned.
fn maybe_start_server(
    base_url: &str,
    port_override: Option<u16>,
) -> Result<Option<ServerGuard>, Box<dyn Error>> {
    let (host, port) = parse_host_port(base_url, port_override)?;
    let address = format!("{host}:{port}");
    if TcpStream::connect(&address).is_ok() {
        return Ok(None);
    }
    let binary = formal_ai_binary_path()?;
    let child = Command::new(binary)
        .args(["serve", "--host", &host, "--port", &port.to_string()])
        .spawn()?;
    // Wrap the child before waiting so that an early return from the wait
    // drops the guard and kills the process instead of leaving it running.
    let mut guard = ServerGuard { child };
    wait_for_server(&address, &mut guard.child)?;
    Ok(Some(guard))
}
/// Extracts the host and port to connect to from `base_url`.
///
/// Userinfo (`user:password@`) in the authority is ignored. Bracketed IPv6
/// hosts are returned without their brackets. The port is `port_override` if
/// given, else the port in the URL if it parses, else `8080`.
///
/// # Errors
/// Fails when the URL has no `scheme://` prefix or a bracketed host is missing
/// its closing `]`.
fn parse_host_port(
    base_url: &str,
    port_override: Option<u16>,
) -> Result<(String, u16), Box<dyn Error>> {
    let (_, rest) = base_url
        .split_once("://")
        .ok_or("base URL must include a scheme, for example http://127.0.0.1:8080")?;
    let authority = rest.split('/').next().unwrap_or(rest);
    // Drop userinfo so its colon is not taken for the host/port separator.
    let host_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_userinfo, after)| after);
    let (host, parsed_port) = if let Some(stripped) = host_port.strip_prefix('[') {
        let (inside, after) = stripped
            .split_once(']')
            .ok_or("invalid bracketed IPv6 host in base URL")?;
        let port = after.strip_prefix(':').and_then(|value| value.parse().ok());
        (inside.to_string(), port)
    } else if let Some((host, port)) = host_port.split_once(':') {
        (host.to_string(), port.parse().ok())
    } else {
        (host_port.to_string(), None)
    };
    let port = port_override.or(parsed_port).unwrap_or(8080);
    Ok((host, port))
}
/// Chooses the `formal-ai` executable to spawn.
///
/// Preference order: the running executable when its file stem is `formal-ai`;
/// a `formal-ai` executable in the same directory; otherwise the bare name
/// `formal-ai`.
///
/// # Errors
/// Fails if the current executable path cannot be determined.
fn formal_ai_binary_path() -> Result<PathBuf, Box<dyn Error>> {
    let current = std::env::current_exe()?;
    let runs_as_formal_ai =
        current.file_stem().and_then(|stem| stem.to_str()) == Some("formal-ai");
    if runs_as_formal_ai {
        return Ok(current);
    }
    let sibling_name = format!("formal-ai{}", std::env::consts::EXE_SUFFIX);
    let sibling = current.with_file_name(sibling_name);
    Ok(if sibling.exists() {
        sibling
    } else {
        PathBuf::from("formal-ai")
    })
}
/// Polls every 50 ms, for up to five seconds, until `address` accepts a TCP
/// connection.
///
/// # Errors
/// Fails if the child exits before listening, if its status cannot be
/// queried, or if the time limit passes without a successful connection.
fn wait_for_server(address: &str, child: &mut Child) -> Result<(), Box<dyn Error>> {
    let started = Instant::now();
    let limit = Duration::from_secs(5);
    while started.elapsed() < limit {
        // A server that has already exited will never start listening.
        match child.try_wait()? {
            Some(status) => {
                return Err(format!("formal-ai serve exited before listening: {status}").into());
            }
            None => {}
        }
        if TcpStream::connect(address).is_ok() {
            return Ok(());
        }
        std::thread::sleep(Duration::from_millis(50));
    }
    Err(format!("formal-ai serve did not listen on {address}").into())
}