use anyhow::{Context, Result};
use clap::{Parser, Subcommand};
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
// Top-level command-line arguments, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used on this type on purpose. clap's derive
// turns `///` doc comments into user-visible help text, so adding them here
// would change the program's `--help` output.
#[derive(Parser, Debug)]
#[command(author, version, about = "Tmux Multi Agent Interface")]
pub struct Config {
    // `-d` / `--debug` flag; `global = true` makes it valid after a subcommand
    // as well. How the flag is consumed is not visible in this file.
    #[arg(short, long, global = true)]
    pub debug: bool,
    // `-c` / `--config <PATH>`; also global.
    // NOTE(review): presumably the path handed to `Settings::load` -- confirm at the call site.
    #[arg(short, long, global = true)]
    pub config: Option<PathBuf>,
    // `-i` / `--poll-interval`; when present it replaces
    // `Settings::poll_interval_ms` in `Settings::merge_cli`.
    #[arg(short = 'i', long)]
    pub poll_interval: Option<u64>,
    // `-l` / `--capture-lines`; when present it replaces
    // `Settings::capture_lines` in `Settings::merge_cli`.
    #[arg(short = 'l', long)]
    pub capture_lines: Option<u32>,
    // `--attached-only <true|false>`; `ArgAction::Set` means an explicit value
    // is required, so the user can force either state. `None` leaves the
    // configured value untouched in `Settings::merge_cli`.
    #[arg(long, action = clap::ArgAction::Set)]
    pub attached_only: Option<bool>,
    // `--audit` flag; when set, `Settings::merge_cli` forces
    // `audit.enabled = true`. It cannot be used to turn auditing off.
    #[arg(long)]
    pub audit: bool,
    // Optional subcommand; `None` when the program is run without one.
    #[command(subcommand)]
    pub command: Option<Command>,
}
// Subcommands accepted by the CLI.
//
// NOTE: `//` comments are used instead of `///` because clap would otherwise
// publish the text as subcommand help.
#[derive(Subcommand, Debug, Clone)]
pub enum Command {
    // Collects every remaining argument verbatim (including values that start
    // with `-`). `Config::get_wrap_args` treats the first entry as the program
    // and the rest as its arguments.
    Wrap {
        #[arg(trailing_var_arg = true, allow_hyphen_values = true)]
        args: Vec<String>,
    },
    // Takes no arguments; detected through `Config::is_demo_mode`.
    Demo,
    // Groups the nested audit subcommands (see `AuditCommand`).
    Audit {
        #[command(subcommand)]
        subcommand: AuditCommand,
    },
}
// Nested subcommands of `Command::Audit`.
//
// NOTE: `//` comments are used instead of `///` because clap would otherwise
// publish the text as subcommand help.
// NOTE(review): what each report contains is not visible in this file; the
// comments below describe only the argument surface.
#[derive(Subcommand, Debug, Clone)]
pub enum AuditCommand {
    // `stats [--top <N>]`; `top` defaults to 20.
    Stats {
        #[arg(long, default_value = "20")]
        top: usize,
    },
    // `misdetections [-n|--limit <N>]`; `limit` defaults to 50.
    Misdetections {
        #[arg(long, short = 'n', default_value = "50")]
        limit: usize,
    },
    // `disagreements [-n|--limit <N>]`; `limit` defaults to 50.
    Disagreements {
        #[arg(long, short = 'n', default_value = "50")]
        limit: usize,
    },
}
impl Config {
pub fn parse_args() -> Self {
Self::parse()
}
pub fn is_wrap_mode(&self) -> bool {
matches!(self.command, Some(Command::Wrap { .. }))
}
pub fn is_demo_mode(&self) -> bool {
matches!(self.command, Some(Command::Demo))
}
pub fn is_audit_mode(&self) -> bool {
matches!(self.command, Some(Command::Audit { .. }))
}
pub fn get_audit_command(&self) -> Option<&AuditCommand> {
match &self.command {
Some(Command::Audit { subcommand }) => Some(subcommand),
_ => None,
}
}
pub fn get_wrap_args(&self) -> Option<(String, Vec<String>)> {
match &self.command {
Some(Command::Wrap { args }) if !args.is_empty() => {
let command = args[0].clone();
let cmd_args = args[1..].to_vec();
Some((command, cmd_args))
}
_ => None,
}
}
}
/// Application settings, deserialized from TOML by `Settings::load`.
///
/// Every field carries a serde default, so a partial (or empty) document still
/// produces a complete value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Settings {
    /// Poll interval in milliseconds; defaults to 500 and is raised to at
    /// least 1 by `Settings::validate`.
    #[serde(default = "default_poll_interval")]
    pub poll_interval_ms: u64,
    /// Poll interval in milliseconds for passthrough; defaults to 10 and is
    /// raised to at least 1 by `Settings::validate`.
    /// NOTE(review): what "passthrough" refers to is not visible in this file.
    #[serde(default = "default_passthrough_poll_interval")]
    pub passthrough_poll_interval_ms: u64,
    /// Number of lines to capture; defaults to 100.
    #[serde(default = "default_capture_lines")]
    pub capture_lines: u32,
    /// Defaults to `true`.
    /// NOTE(review): presumably restricts monitoring to attached tmux sessions -- confirm.
    #[serde(default = "default_attached_only")]
    pub attached_only: bool,
    /// User-supplied agent patterns; empty by default.
    #[serde(default)]
    pub agent_patterns: Vec<AgentPattern>,
    /// `[ui]` section.
    #[serde(default)]
    pub ui: UiSettings,
    /// `[web]` section.
    #[serde(default)]
    pub web: WebSettings,
    /// `[exfil_detection]` section.
    #[serde(default)]
    pub exfil_detection: ExfilDetectionSettings,
    /// `[teams]` section.
    #[serde(default)]
    pub teams: TeamSettings,
    /// `[audit]` section.
    #[serde(default)]
    pub audit: AuditSettings,
    /// `[auto_approve]` section.
    #[serde(default)]
    pub auto_approve: AutoApproveSettings,
    /// `[create_process]` section.
    #[serde(default)]
    pub create_process: CreateProcessSettings,
    /// `[usage]` section.
    #[serde(default)]
    pub usage: UsageSettings,
}
/// Serde default for `Settings::poll_interval_ms` (milliseconds).
fn default_poll_interval() -> u64 {
    const DEFAULT_POLL_INTERVAL_MS: u64 = 500;
    DEFAULT_POLL_INTERVAL_MS
}
/// Serde default for `Settings::passthrough_poll_interval_ms` (milliseconds).
fn default_passthrough_poll_interval() -> u64 {
    const DEFAULT_PASSTHROUGH_POLL_INTERVAL_MS: u64 = 10;
    DEFAULT_PASSTHROUGH_POLL_INTERVAL_MS
}
/// Serde default for `Settings::capture_lines`.
fn default_capture_lines() -> u32 {
    const DEFAULT_CAPTURE_LINES: u32 = 100;
    DEFAULT_CAPTURE_LINES
}
/// Serde default for `Settings::attached_only`.
fn default_attached_only() -> bool {
    const ATTACHED_ONLY_BY_DEFAULT: bool = true;
    ATTACHED_ONLY_BY_DEFAULT
}
/// One entry of `Settings::agent_patterns`.
///
/// Both fields are required whenever an entry is present (neither has a serde
/// default).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentPattern {
    /// Pattern text.
    /// NOTE(review): the matching syntax (regex, glob, substring) and what it
    /// is matched against are not visible in this file -- confirm at the use site.
    pub pattern: String,
    /// Agent type label associated with `pattern`.
    pub agent_type: String,
}
/// `[ui]` section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UiSettings {
    /// Defaults to `true`.
    #[serde(default = "default_show_preview")]
    pub show_preview: bool,
    /// Defaults to 40.
    /// NOTE(review): the unit (rows vs. percent of the screen) is not visible here -- confirm.
    #[serde(default = "default_preview_height")]
    pub preview_height: u16,
    /// Defaults to `true`.
    #[serde(default = "default_color")]
    pub color: bool,
    /// Defaults to `true`.
    #[serde(default = "default_show_activity_name")]
    pub show_activity_name: bool,
}
/// `[web]` section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSettings {
    /// Defaults to `true`.
    #[serde(default = "default_web_enabled")]
    pub enabled: bool,
    /// Port number; defaults to 9876.
    #[serde(default = "default_web_port")]
    pub port: u16,
}
/// Serde default for `WebSettings::enabled`.
fn default_web_enabled() -> bool {
    const WEB_ENABLED_BY_DEFAULT: bool = true;
    WEB_ENABLED_BY_DEFAULT
}
/// Serde default for `WebSettings::port`.
fn default_web_port() -> u16 {
    const DEFAULT_WEB_PORT: u16 = 9876;
    DEFAULT_WEB_PORT
}
impl Default for WebSettings {
fn default() -> Self {
Self {
enabled: default_web_enabled(),
port: default_web_port(),
}
}
}
/// `[exfil_detection]` section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExfilDetectionSettings {
    /// Defaults to `true`.
    #[serde(default = "default_exfil_enabled")]
    pub enabled: bool,
    /// Extra command names supplied by the user; empty by default.
    /// NOTE(review): presumably added to a built-in list of commands to flag -- confirm.
    #[serde(default)]
    pub additional_commands: Vec<String>,
}
/// Serde default for `ExfilDetectionSettings::enabled`.
fn default_exfil_enabled() -> bool {
    const EXFIL_DETECTION_ENABLED_BY_DEFAULT: bool = true;
    EXFIL_DETECTION_ENABLED_BY_DEFAULT
}
impl Default for ExfilDetectionSettings {
fn default() -> Self {
Self {
enabled: default_exfil_enabled(),
additional_commands: Vec::new(),
}
}
}
/// `[teams]` section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TeamSettings {
    /// Defaults to `true`.
    #[serde(default = "default_team_enabled")]
    pub enabled: bool,
    /// Defaults to 5.
    /// NOTE(review): the unit (seconds vs. poll cycles) is not visible here -- confirm.
    #[serde(default = "default_scan_interval")]
    pub scan_interval: u32,
}
/// Serde default for `TeamSettings::enabled`.
fn default_team_enabled() -> bool {
    const TEAMS_ENABLED_BY_DEFAULT: bool = true;
    TEAMS_ENABLED_BY_DEFAULT
}
/// Serde default for `TeamSettings::scan_interval`.
fn default_scan_interval() -> u32 {
    const DEFAULT_SCAN_INTERVAL: u32 = 5;
    DEFAULT_SCAN_INTERVAL
}
impl Default for TeamSettings {
fn default() -> Self {
Self {
enabled: default_team_enabled(),
scan_interval: default_scan_interval(),
}
}
}
/// `[audit]` section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditSettings {
    /// Defaults to `false`; the `--audit` CLI flag forces it to `true` in
    /// `Settings::merge_cli`.
    #[serde(default = "default_audit_enabled")]
    pub enabled: bool,
    /// Size limit in bytes; defaults to 10_485_760 (10 MiB).
    /// NOTE(review): presumably the audit log size cap -- what happens when it
    /// is reached is not visible in this file.
    #[serde(default = "default_audit_max_size")]
    pub max_size_bytes: u64,
    /// Defaults to `false`.
    #[serde(default)]
    pub log_source_disagreement: bool,
}
/// Serde default for `AuditSettings::enabled`: auditing is opt-in.
fn default_audit_enabled() -> bool {
    const AUDIT_ENABLED_BY_DEFAULT: bool = false;
    AUDIT_ENABLED_BY_DEFAULT
}
/// Serde default for `AuditSettings::max_size_bytes`: 10 MiB.
fn default_audit_max_size() -> u64 {
    const MEBIBYTE: u64 = 1024 * 1024;
    10 * MEBIBYTE
}
impl Default for AuditSettings {
fn default() -> Self {
Self {
enabled: default_audit_enabled(),
max_size_bytes: default_audit_max_size(),
log_source_disagreement: false,
}
}
}
/// `[auto_approve]` section of the configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutoApproveSettings {
    /// Defaults to `false`. Only consulted by `effective_mode` when `mode` is
    /// not set.
    #[serde(default)]
    pub enabled: bool,
    /// Explicit mode; when present it takes precedence over `enabled`
    /// (see `effective_mode`).
    #[serde(default)]
    pub mode: Option<crate::auto_approve::types::AutoApproveMode>,
    /// Rule toggles and extra allow patterns.
    #[serde(default)]
    pub rules: RuleSettings,
    /// Defaults to `"claude_haiku"`.
    #[serde(default = "default_aa_provider")]
    pub provider: String,
    /// Defaults to `"haiku"`.
    #[serde(default = "default_aa_model")]
    pub model: String,
    /// Timeout in seconds; defaults to 30. `Settings::validate` replaces 0
    /// with 5.
    #[serde(default = "default_aa_timeout")]
    pub timeout_secs: u64,
    /// Cooldown in seconds; defaults to 10.
    #[serde(default = "default_aa_cooldown")]
    pub cooldown_secs: u64,
    /// Check interval in milliseconds; defaults to 1000. `Settings::validate`
    /// raises it to at least 100.
    #[serde(default = "default_aa_interval")]
    pub check_interval_ms: u64,
    /// Empty by default.
    /// NOTE(review): how an empty list is interpreted (allow all vs. allow
    /// none) is not visible in this file -- confirm.
    #[serde(default)]
    pub allowed_types: Vec<String>,
    /// Defaults to 3. `Settings::validate` raises 0 to 1.
    #[serde(default = "default_aa_max_concurrent")]
    pub max_concurrent: usize,
    /// Optional custom command; unset by default.
    #[serde(default)]
    pub custom_command: Option<String>,
}
/// Rule toggles nested under `AutoApproveSettings::rules`.
///
/// Every `allow_*` flag defaults to `true`.
/// NOTE(review): the exact class of operations each flag covers is defined by
/// the rule engine, which is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuleSettings {
    #[serde(default = "default_true")]
    pub allow_read: bool,
    #[serde(default = "default_true")]
    pub allow_tests: bool,
    #[serde(default = "default_true")]
    pub allow_fetch: bool,
    #[serde(default = "default_true")]
    pub allow_git_readonly: bool,
    #[serde(default = "default_true")]
    pub allow_format_lint: bool,
    /// Additional user-supplied patterns; empty by default.
    #[serde(default)]
    pub allow_patterns: Vec<String>,
}
/// Serde default helper for boolean fields that should be on unless the
/// configuration says otherwise.
fn default_true() -> bool {
    const ON: bool = true;
    ON
}
impl Default for RuleSettings {
fn default() -> Self {
Self {
allow_read: true,
allow_tests: true,
allow_fetch: true,
allow_git_readonly: true,
allow_format_lint: true,
allow_patterns: Vec::new(),
}
}
}
/// Serde default for `AutoApproveSettings::provider`.
fn default_aa_provider() -> String {
    String::from("claude_haiku")
}
/// Serde default for `AutoApproveSettings::model`.
fn default_aa_model() -> String {
    String::from("haiku")
}
/// Serde default for `AutoApproveSettings::timeout_secs` (seconds).
fn default_aa_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
/// Serde default for `AutoApproveSettings::cooldown_secs` (seconds).
fn default_aa_cooldown() -> u64 {
    const DEFAULT_COOLDOWN_SECS: u64 = 10;
    DEFAULT_COOLDOWN_SECS
}
/// Serde default for `AutoApproveSettings::check_interval_ms` (milliseconds).
fn default_aa_interval() -> u64 {
    const DEFAULT_CHECK_INTERVAL_MS: u64 = 1000;
    DEFAULT_CHECK_INTERVAL_MS
}
/// Serde default for `AutoApproveSettings::max_concurrent`.
fn default_aa_max_concurrent() -> usize {
    const DEFAULT_MAX_CONCURRENT: usize = 3;
    DEFAULT_MAX_CONCURRENT
}
impl AutoApproveSettings {
    /// Resolves the auto-approve mode that should actually be used.
    ///
    /// An explicitly configured `mode` always wins. When `mode` is absent the
    /// boolean `enabled` flag decides: `true` maps to `AutoApproveMode::Ai`
    /// and `false` maps to `AutoApproveMode::Off`.
    pub fn effective_mode(&self) -> crate::auto_approve::types::AutoApproveMode {
        use crate::auto_approve::types::AutoApproveMode;

        if let Some(explicit) = self.mode {
            return explicit;
        }

        // No explicit mode: derive one from the on/off switch.
        if self.enabled {
            AutoApproveMode::Ai
        } else {
            AutoApproveMode::Off
        }
    }
}
/// `[usage]` section of the configuration.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct UsageSettings {
    /// Defaults to 0.
    /// NOTE(review): presumably an auto-refresh period in minutes with 0
    /// meaning "disabled" -- confirm at the use site.
    #[serde(default)]
    pub auto_refresh_min: u32,
}
/// `[create_process]` section of the configuration.
///
/// Both lists are empty by default.
/// NOTE(review): how the entries are interpreted (e.g. as filesystem paths) is
/// not visible in this file -- confirm at the use site.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct CreateProcessSettings {
    #[serde(default)]
    pub base_directories: Vec<String>,
    #[serde(default)]
    pub pinned: Vec<String>,
}
impl Default for AutoApproveSettings {
fn default() -> Self {
Self {
enabled: false,
mode: None,
rules: RuleSettings::default(),
provider: default_aa_provider(),
model: default_aa_model(),
timeout_secs: default_aa_timeout(),
cooldown_secs: default_aa_cooldown(),
check_interval_ms: default_aa_interval(),
allowed_types: Vec::new(),
max_concurrent: default_aa_max_concurrent(),
custom_command: None,
}
}
}
/// Serde default for `UiSettings::show_preview`.
fn default_show_preview() -> bool {
    const SHOW_PREVIEW_BY_DEFAULT: bool = true;
    SHOW_PREVIEW_BY_DEFAULT
}
/// Serde default for `UiSettings::preview_height`.
fn default_preview_height() -> u16 {
    const DEFAULT_PREVIEW_HEIGHT: u16 = 40;
    DEFAULT_PREVIEW_HEIGHT
}
/// Serde default for `UiSettings::color`.
fn default_color() -> bool {
    const COLOR_BY_DEFAULT: bool = true;
    COLOR_BY_DEFAULT
}
/// Serde default for `UiSettings::show_activity_name`.
fn default_show_activity_name() -> bool {
    const SHOW_ACTIVITY_NAME_BY_DEFAULT: bool = true;
    SHOW_ACTIVITY_NAME_BY_DEFAULT
}
impl Default for UiSettings {
fn default() -> Self {
Self {
show_preview: default_show_preview(),
preview_height: default_preview_height(),
color: default_color(),
show_activity_name: default_show_activity_name(),
}
}
}
impl Default for Settings {
fn default() -> Self {
Self {
poll_interval_ms: default_poll_interval(),
passthrough_poll_interval_ms: default_passthrough_poll_interval(),
capture_lines: default_capture_lines(),
attached_only: default_attached_only(),
agent_patterns: Vec::new(),
ui: UiSettings::default(),
web: WebSettings::default(),
exfil_detection: ExfilDetectionSettings::default(),
teams: TeamSettings::default(),
audit: AuditSettings::default(),
auto_approve: AutoApproveSettings::default(),
create_process: CreateProcessSettings::default(),
usage: UsageSettings::default(),
}
}
}
impl Settings {
    /// Loads settings from a TOML file, falling back to built-in defaults.
    ///
    /// Lookup order:
    /// 1. `path`, when it is given and the file exists;
    /// 2. `tmai/config.toml` under `dirs::config_dir()`;
    /// 3. `.config/tmai/config.toml` under `dirs::home_dir()`;
    /// 4. `.tmai.toml` under `dirs::home_dir()`.
    ///
    /// Only the first existing file is read. An explicitly supplied `path`
    /// that does not exist is skipped silently and the search continues with
    /// the standard locations. When no file is found, `Settings::default()`
    /// is returned.
    ///
    /// # Errors
    ///
    /// Returns an error when the selected file cannot be read or does not
    /// parse as TOML for `Settings`; the message names the offending file.
    pub fn load(path: Option<&PathBuf>) -> Result<Self> {
        // Shared read-and-parse step for both the explicit and the standard
        // locations.
        let read_settings = |file: &PathBuf| -> Result<Self> {
            let text = std::fs::read_to_string(file)
                .with_context(|| format!("Failed to read config file: {:?}", file))?;
            toml::from_str(&text)
                .with_context(|| format!("Failed to parse config file: {:?}", file))
        };

        if let Some(explicit) = path.filter(|candidate| candidate.exists()) {
            return read_settings(explicit);
        }

        let standard_locations = [
            dirs::config_dir().map(|base| base.join("tmai/config.toml")),
            dirs::home_dir().map(|home| home.join(".config/tmai/config.toml")),
            dirs::home_dir().map(|home| home.join(".tmai.toml")),
        ];

        // Locations whose base directory could not be determined are `None`
        // and get dropped by `flatten`.
        match standard_locations
            .iter()
            .flatten()
            .find(|candidate| candidate.exists())
        {
            Some(found) => read_settings(found),
            None => Ok(Self::default()),
        }
    }

    /// Applies command-line overrides on top of the loaded settings.
    ///
    /// Optional CLI values replace the configured value only when they were
    /// actually given. The `--audit` flag can only switch auditing on; leaving
    /// it out keeps whatever the configuration says.
    pub fn merge_cli(&mut self, cli: &Config) {
        self.poll_interval_ms = cli.poll_interval.unwrap_or(self.poll_interval_ms);
        self.capture_lines = cli.capture_lines.unwrap_or(self.capture_lines);
        self.attached_only = cli.attached_only.unwrap_or(self.attached_only);
        self.audit.enabled |= cli.audit;
    }

    /// Clamps values that would otherwise be unusable.
    ///
    /// - both poll intervals are raised to at least 1 ms;
    /// - `auto_approve.check_interval_ms` is raised to at least 100 ms;
    /// - `auto_approve.max_concurrent` is raised to at least 1;
    /// - an `auto_approve.timeout_secs` of 0 becomes 5.
    pub fn validate(&mut self) {
        const MIN_POLL_INTERVAL: u64 = 1;
        const MIN_CHECK_INTERVAL_MS: u64 = 100;
        const FALLBACK_TIMEOUT_SECS: u64 = 5;

        self.poll_interval_ms = self.poll_interval_ms.max(MIN_POLL_INTERVAL);
        self.passthrough_poll_interval_ms =
            self.passthrough_poll_interval_ms.max(MIN_POLL_INTERVAL);

        let auto_approve = &mut self.auto_approve;
        auto_approve.check_interval_ms = auto_approve.check_interval_ms.max(MIN_CHECK_INTERVAL_MS);
        auto_approve.max_concurrent = auto_approve.max_concurrent.max(1);
        // Only a zero timeout is replaced; small non-zero values are kept.
        if auto_approve.timeout_secs == 0 {
            auto_approve.timeout_secs = FALLBACK_TIMEOUT_SECS;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Settings::default()` exposes the expected baseline values.
    #[test]
    fn test_default_settings() {
        let defaults = Settings::default();

        assert_eq!(defaults.poll_interval_ms, 500);
        assert_eq!(defaults.capture_lines, 100);
        assert!(defaults.attached_only);
        assert!(defaults.ui.show_preview);
    }

    /// A partial TOML document overrides only the keys it mentions.
    #[test]
    fn test_parse_toml() {
        let document = r#"
poll_interval_ms = 1000
capture_lines = 200
[ui]
show_preview = false
"#;

        let parsed: Settings = toml::from_str(document).expect("Should parse TOML");

        assert_eq!(parsed.poll_interval_ms, 1000);
        assert_eq!(parsed.capture_lines, 200);
        assert!(!parsed.ui.show_preview);
    }
}