use std::path::PathBuf;
use arrrg_derive::CommandLine;
use crate::Budget;
use crate::types::{KnownModel, MessageCreateTemplate, Model, SystemPrompt, ThinkingConfig};
/// Maximum tokens per response used when no explicit limit has been configured.
const DEFAULT_MAX_TOKENS: u32 = 4096;
/// Command-line options for a chat session, declared through the `CommandLine` derive.
///
/// Every option is optional; the struct is converted into a request template or a
/// `ChatConfig` through the `TryFrom<ChatArgs>` implementations in this file.
#[derive(CommandLine, Debug, Default, PartialEq, Eq)]
pub struct ChatArgs {
// Model name; a name that does not parse as a `Model` is carried as a custom model
// during conversion.
#[arrrg(optional, "Model to use (default: claude-haiku-4-5)", "MODEL")]
pub model: Option<String>,
#[arrrg(optional, "System prompt for the conversation", "PROMPT")]
pub system: Option<String>,
#[arrrg(optional, "Max tokens per response (default: 4096)", "TOKENS")]
pub max_tokens: Option<u32>,
// Kept as raw text and parsed to `f32` during conversion, so a malformed value is
// reported as a `ChatArgsError` naming the option.
#[arrrg(optional, "Sampling temperature (0.0 to 1.0)", "TEMP")]
pub temperature: Option<String>,
// Kept as raw text for the same reason as `temperature`.
#[arrrg(optional, "Top-p (nucleus) sampling (0.0 to 1.0)", "TOP_P")]
pub top_p: Option<String>,
#[arrrg(optional, "Top-k sampling", "TOP_K")]
pub top_k: Option<u32>,
// When present, conversion enables thinking with this token budget.
#[arrrg(
optional,
"Thinking budget in tokens (enables extended thinking)",
"TOKENS"
)]
pub thinking: Option<u32>,
// Inverted into `ChatConfig::use_color` during conversion.
#[arrrg(flag, "Disable ANSI colors/styles")]
pub no_color: bool,
}
/// Error returned when `ChatArgs` cannot be converted into a request template or a
/// `ChatConfig`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChatArgsError {
// Human-readable description; this is exactly the text rendered by `Display`.
message: String,
}
impl std::fmt::Display for ChatArgsError {
    /// Writes the stored message verbatim.
    ///
    /// Width, precision and alignment requested by the caller are not applied to
    /// the message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.message)
    }
}
// Empty impl: only the default `Error` methods are used, so no source error is exposed.
impl std::error::Error for ChatArgsError {}
/// Parses the raw text of a floating-point command-line option.
///
/// `value` is the text supplied by the user. `name` is the option name without the
/// leading `--` and is used only to build the error message.
///
/// # Errors
///
/// Returns a [`ChatArgsError`] when `value` is not a decimal number, or when it
/// parses to NaN or an infinity. `str::parse::<f32>` accepts spellings such as
/// `nan` and `inf`, and rounds overflowing literals such as `1e40` to infinity;
/// none of those is a usable sampling parameter, so they are rejected here
/// instead of being forwarded in a request.
fn parse_f32_arg(value: &str, name: &str) -> Result<f32, ChatArgsError> {
    match value.parse::<f32>() {
        Ok(number) if number.is_finite() => Ok(number),
        // Covers both unparsable text and non-finite results.
        _ => Err(ChatArgsError {
            message: format!(
                "invalid value for --{}: '{}' is not a valid number",
                name, value
            ),
        }),
    }
}
impl TryFrom<ChatArgs> for MessageCreateTemplate {
type Error = ChatArgsError;
fn try_from(args: ChatArgs) -> Result<Self, Self::Error> {
let mut template = MessageCreateTemplate::new();
if let Some(model) = args.model {
let parsed = model.parse::<Model>().unwrap_or(Model::Custom(model));
template = template.with_model(parsed);
}
if let Some(system) = args.system {
template = template.with_system(system);
}
if let Some(max_tokens) = args.max_tokens {
template = template.with_max_tokens(max_tokens);
}
if let Some(ref temp) = args.temperature {
template.temperature = Some(parse_f32_arg(temp, "temperature")?);
}
if let Some(ref top_p) = args.top_p {
template.top_p = Some(parse_f32_arg(top_p, "top-p")?);
}
template.top_k = args.top_k;
if let Some(thinking) = args.thinking {
template.thinking = Some(ThinkingConfig::enabled(thinking));
}
Ok(template)
}
}
/// Settings for a chat session: a request template plus session-level options.
#[derive(Debug, Clone)]
pub struct ChatConfig {
/// Request parameters (model, system prompt, token limit, sampling options,
/// stop sequences, thinking) held as a template.
pub template: MessageCreateTemplate,
/// Whether ANSI colors/styles are enabled; defaults to `true` and is cleared
/// when `ChatArgs::no_color` is set.
pub use_color: bool,
/// Optional budget for the session; `None` means no budget is configured.
pub session_budget: Option<Budget>,
/// Optional transcript path; `None` by default.
/// NOTE(review): how and when the transcript is written is not visible here.
pub transcript_path: Option<PathBuf>,
/// Caching toggle; defaults to `true`.
/// NOTE(review): presumably prompt caching — confirm against the consumer.
pub caching_enabled: bool,
}
impl ChatConfig {
pub fn new() -> Self {
Self {
template: default_template(),
use_color: true,
session_budget: None,
transcript_path: None,
caching_enabled: true,
}
}
pub fn with_model(mut self, model: Model) -> Self {
self.template.model = Some(model);
self
}
pub fn with_system_prompt(mut self, prompt: String) -> Self {
self.template.system = Some(SystemPrompt::from(prompt));
self
}
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.template.max_tokens = Some(max_tokens);
self
}
pub fn without_color(mut self) -> Self {
self.use_color = false;
self
}
pub fn with_temperature(mut self, temperature: Option<f32>) -> Self {
self.template.temperature = temperature;
self
}
pub fn with_top_p(mut self, top_p: Option<f32>) -> Self {
self.template.top_p = top_p;
self
}
pub fn with_top_k(mut self, top_k: Option<u32>) -> Self {
self.template.top_k = top_k;
self
}
pub fn with_stop_sequences(mut self, stop_sequences: Vec<String>) -> Self {
self.template.stop_sequences = Some(stop_sequences);
self
}
pub fn with_thinking_budget(mut self, budget: Option<u32>) -> Self {
self.template.thinking = budget.map(ThinkingConfig::enabled);
self
}
pub fn with_session_budget(mut self, budget: Option<u64>) -> Self {
self.session_budget = budget.map(Self::token_budget);
self
}
pub fn with_transcript_path(mut self, path: Option<PathBuf>) -> Self {
self.transcript_path = path;
self
}
pub fn with_caching(mut self, enabled: bool) -> Self {
self.caching_enabled = enabled;
self
}
pub fn model(&self) -> Model {
self.template
.model
.clone()
.unwrap_or(Model::Known(KnownModel::ClaudeHaiku45))
}
pub fn max_tokens(&self) -> u32 {
self.template.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS)
}
pub fn system_prompt_text(&self) -> Option<&str> {
match self.template.system.as_ref()? {
SystemPrompt::String(text) => Some(text.as_str()),
SystemPrompt::Blocks(_) => None,
}
}
pub fn stop_sequences(&self) -> &[String] {
self.template.stop_sequences.as_deref().unwrap_or(&[])
}
pub fn thinking_budget(&self) -> Option<u32> {
match self.template.thinking {
Some(ThinkingConfig::Enabled { budget_tokens }) => Some(budget_tokens),
_ => None,
}
}
pub fn set_model(&mut self, model: Model) {
self.template.model = Some(model);
}
pub fn set_system_prompt(&mut self, prompt: Option<String>) {
self.template.system = prompt.map(SystemPrompt::from);
}
pub fn set_max_tokens(&mut self, max_tokens: u32) {
self.template.max_tokens = Some(max_tokens);
}
pub fn set_temperature(&mut self, temperature: Option<f32>) {
self.template.temperature = temperature;
}
pub fn set_top_p(&mut self, top_p: Option<f32>) {
self.template.top_p = top_p;
}
pub fn set_top_k(&mut self, top_k: Option<u32>) {
self.template.top_k = top_k;
}
pub fn set_thinking_budget(&mut self, budget: Option<u32>) {
self.template.thinking = budget.map(ThinkingConfig::enabled);
}
pub fn set_session_budget(&mut self, budget: Option<u64>) {
self.session_budget = budget.map(Self::token_budget);
}
fn token_budget(limit_tokens: u64) -> Budget {
Budget::new_with_rates(limit_tokens, 1, 1, 1, 1)
}
}
impl Default for ChatConfig {
fn default() -> Self {
Self::new()
}
}
impl TryFrom<ChatArgs> for ChatConfig {
    type Error = ChatArgsError;

    /// Builds a configuration from parsed command-line arguments.
    ///
    /// The arguments are converted into a template and merged onto the default
    /// template. Colors are enabled unless `no_color` was given. The session
    /// budget and transcript path start unset, and caching starts enabled.
    ///
    /// # Errors
    ///
    /// Propagates the [`ChatArgsError`] from converting `args` into a template.
    fn try_from(args: ChatArgs) -> Result<Self, Self::Error> {
        // Read the flag before `args` is consumed by the template conversion.
        let colorize = !args.no_color;
        let overrides = MessageCreateTemplate::try_from(args)?;

        Ok(Self {
            template: default_template().merge(overrides),
            use_color: colorize,
            session_budget: None,
            transcript_path: None,
            caching_enabled: true,
        })
    }
}
/// Returns the baseline request template: a new template with the model set to
/// Claude Haiku 4.5 and the token limit set to `DEFAULT_MAX_TOKENS`.
fn default_template() -> MessageCreateTemplate {
    let mut defaults = MessageCreateTemplate::new();
    defaults.max_tokens = Some(DEFAULT_MAX_TOKENS);
    defaults.model = Some(Model::Known(KnownModel::ClaudeHaiku45));
    defaults
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `ChatConfig::new` yields the built-in defaults with every optional
    /// setting unset.
    #[test]
    fn default_config() {
        let cfg = ChatConfig::new();

        // Built-in values.
        assert_eq!(cfg.model(), Model::Known(KnownModel::ClaudeHaiku45));
        assert_eq!(cfg.max_tokens(), 4096);
        assert!(cfg.use_color);
        assert!(cfg.caching_enabled);

        // Optional request parameters.
        assert!(cfg.template.system.is_none());
        assert!(cfg.template.temperature.is_none());
        assert!(cfg.template.top_p.is_none());
        assert!(cfg.template.top_k.is_none());
        assert!(cfg.stop_sequences().is_empty());
        assert!(cfg.thinking_budget().is_none());

        // Optional session-level settings.
        assert!(cfg.session_budget.is_none());
        assert!(cfg.transcript_path.is_none());
    }

    /// Converting empty arguments gives the same defaults as `new`.
    #[test]
    fn config_from_args_defaults() {
        let cfg = ChatConfig::try_from(ChatArgs::default())
            .expect("default arguments must convert");

        assert_eq!(cfg.model(), Model::Known(KnownModel::ClaudeHaiku45));
        assert_eq!(cfg.max_tokens(), 4096);
        assert!(cfg.use_color);
        assert!(cfg.thinking_budget().is_none());
    }

    /// Every supplied argument overrides the corresponding default.
    #[test]
    fn config_from_args_custom() {
        let supplied = ChatArgs {
            model: Some(String::from("claude-sonnet-4-0")),
            system: Some(String::from("You are helpful.")),
            max_tokens: Some(8192),
            temperature: Some(String::from("0.7")),
            top_p: Some(String::from("0.9")),
            top_k: Some(40),
            thinking: Some(2048),
            no_color: true,
        };

        let cfg = ChatConfig::try_from(supplied).expect("valid arguments must convert");

        assert_eq!(cfg.model(), Model::Known(KnownModel::ClaudeSonnet40));
        assert_eq!(cfg.system_prompt_text(), Some("You are helpful."));
        assert_eq!(cfg.max_tokens(), 8192);
        assert_eq!(cfg.template.temperature, Some(0.7));
        assert_eq!(cfg.template.top_p, Some(0.9));
        assert_eq!(cfg.template.top_k, Some(40));
        assert_eq!(cfg.thinking_budget(), Some(2048));
        assert!(!cfg.use_color);
    }

    /// A non-numeric temperature is rejected, and the error names both the
    /// option and the offending value.
    #[test]
    fn config_from_args_invalid_temperature() {
        let supplied = ChatArgs {
            temperature: Some(String::from("not-a-number")),
            ..ChatArgs::default()
        };

        let err = ChatConfig::try_from(supplied)
            .expect_err("a non-numeric temperature must be rejected");

        assert!(err.message.contains("--temperature"));
        assert!(err.message.contains("not-a-number"));
    }

    /// A non-numeric top-p is rejected, and the error names the option.
    #[test]
    fn config_from_args_invalid_top_p() {
        let supplied = ChatArgs {
            top_p: Some(String::from("invalid")),
            ..ChatArgs::default()
        };

        let err = ChatConfig::try_from(supplied)
            .expect_err("a non-numeric top-p must be rejected");

        assert!(err.message.contains("--top-p"));
    }

    /// Each builder method is reflected by the matching accessor or field.
    #[test]
    fn config_builder_pattern() {
        let cfg = ChatConfig::new()
            .with_model(Model::Known(KnownModel::ClaudeSonnet40))
            .with_system_prompt(String::from("Test prompt"))
            .with_max_tokens(2048)
            .without_color()
            .with_temperature(Some(0.6))
            .with_top_p(Some(0.9))
            .with_top_k(Some(64))
            .with_stop_sequences(vec![String::from("END")])
            .with_thinking_budget(Some(2048))
            .with_session_budget(Some(10_000))
            .with_transcript_path(Some(PathBuf::from("transcript.json")))
            .with_caching(false);

        // Request template settings.
        assert_eq!(cfg.model(), Model::Known(KnownModel::ClaudeSonnet40));
        assert_eq!(cfg.system_prompt_text(), Some("Test prompt"));
        assert_eq!(cfg.max_tokens(), 2048);
        assert_eq!(cfg.template.temperature, Some(0.6));
        assert_eq!(cfg.template.top_p, Some(0.9));
        assert_eq!(cfg.template.top_k, Some(64));
        assert_eq!(cfg.stop_sequences(), vec![String::from("END")]);
        assert_eq!(cfg.thinking_budget(), Some(2048));

        // Session-level settings.
        assert!(!cfg.use_color);
        assert!(!cfg.caching_enabled);
        let budget_total = cfg
            .session_budget
            .as_ref()
            .map(|budget| budget.total_micro_cents());
        assert_eq!(budget_total, Some(10_000));
        assert_eq!(
            cfg.transcript_path,
            Some(PathBuf::from("transcript.json"))
        );
    }
}