Skip to main content

hh_cli/config/
settings.rs

1use serde::{Deserialize, Serialize};
2use std::collections::BTreeMap;
3use std::path::PathBuf;
4
5use crate::core::system_prompt::default_system_prompt;
6
7#[derive(Debug, Clone, Serialize, Deserialize, Default)]
8pub struct Settings {
9    #[serde(default)]
10    pub models: ModelSettings,
11    #[serde(default)]
12    pub providers: BTreeMap<String, ProviderConfig>,
13    pub agent: AgentSettings,
14    pub tools: ToolSettings,
15    pub permission: PermissionSettings,
16    pub session: SessionSettings,
17    #[serde(default)]
18    pub selected_agent: Option<String>,
19    #[serde(default)]
20    pub agents: BTreeMap<String, AgentSpecificSettings>,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ModelSettings {
25    #[serde(default = "default_model_ref")]
26    pub default: String,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ProviderConfig {
31    #[serde(default)]
32    pub display_name: String,
33    pub base_url: String,
34    pub api_key_env: String,
35    #[serde(default)]
36    pub models: BTreeMap<String, ModelMetadata>,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
40pub struct ModelMetadata {
41    #[serde(default, alias = "provider_model_id")]
42    pub id: String,
43    #[serde(default)]
44    pub display_name: String,
45    #[serde(default)]
46    pub modalities: ModelModalities,
47    #[serde(default)]
48    pub limits: ModelLimits,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
52#[serde(rename_all = "snake_case")]
53pub enum ModelModalityType {
54    #[default]
55    Text,
56    Image,
57    Audio,
58    Video,
59}
60
61impl std::fmt::Display for ModelModalityType {
62    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
63        let label = match self {
64            Self::Text => "text",
65            Self::Image => "image",
66            Self::Audio => "audio",
67            Self::Video => "video",
68        };
69        f.write_str(label)
70    }
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
74pub struct ModelModalities {
75    #[serde(default = "default_input_modalities")]
76    pub input: Vec<ModelModalityType>,
77    #[serde(default = "default_output_modalities")]
78    pub output: Vec<ModelModalityType>,
79}
80
81impl Default for ModelModalities {
82    fn default() -> Self {
83        Self {
84            input: default_input_modalities(),
85            output: default_output_modalities(),
86        }
87    }
88}
89
90#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
91pub struct ModelLimits {
92    #[serde(default = "default_model_context_limit")]
93    pub context: usize,
94    #[serde(default = "default_model_output_limit")]
95    pub output: usize,
96}
97
98impl Default for ModelLimits {
99    fn default() -> Self {
100        Self {
101            context: default_model_context_limit(),
102            output: default_model_output_limit(),
103        }
104    }
105}
106
107#[derive(Debug, Clone)]
108pub struct ResolvedModel<'a> {
109    pub provider_id: String,
110    pub model_id: String,
111    pub provider: &'a ProviderConfig,
112    pub model: &'a ModelMetadata,
113}
114
115impl<'a> ResolvedModel<'a> {
116    pub fn full_id(&self) -> String {
117        format!("{}/{}", self.provider_id, self.model_id)
118    }
119}
120
121impl Settings {
122    pub fn selected_model_ref(&self) -> &str {
123        self.models.default.as_str()
124    }
125
126    pub fn selected_model(&self) -> Option<ResolvedModel<'_>> {
127        self.resolve_model_ref(self.models.default.as_str())
128    }
129
130    pub fn resolve_model_ref(&self, model_ref: &str) -> Option<ResolvedModel<'_>> {
131        let (provider_id, model_id) = split_model_ref(model_ref)?;
132        let provider = self.providers.get(provider_id)?;
133        let model = provider.models.get(model_id)?;
134        Some(ResolvedModel {
135            provider_id: provider_id.to_string(),
136            model_id: model_id.to_string(),
137            provider,
138            model,
139        })
140    }
141
142    pub fn model_refs(&self) -> Vec<String> {
143        let mut refs = Vec::new();
144        for (provider_id, provider) in &self.providers {
145            for model_id in provider.models.keys() {
146                refs.push(format!("{provider_id}/{model_id}"));
147            }
148        }
149        refs
150    }
151
152    pub fn normalize_models(&mut self) {
153        if self.models.default.trim().is_empty() {
154            self.models.default = default_model_ref();
155        }
156
157        if self.providers.is_empty() {
158            self.providers = default_providers();
159        }
160
161        if !self.models.default.contains('/')
162            && let Some(provider_id) = self.providers.keys().next().cloned()
163        {
164            self.models.default = format!("{provider_id}/{}", self.models.default);
165        }
166
167        if let Some((provider_id, model_id)) = split_model_ref(self.models.default.as_str())
168            && let Some(provider) = self.providers.get_mut(provider_id)
169            && !provider.models.contains_key(model_id)
170        {
171            provider.models.insert(
172                model_id.to_string(),
173                ModelMetadata {
174                    id: model_id.to_string(),
175                    display_name: model_id.to_string(),
176                    modalities: ModelModalities::default(),
177                    limits: ModelLimits::default(),
178                },
179            );
180        }
181
182        for provider in self.providers.values_mut() {
183            for (model_id, model) in &mut provider.models {
184                if model.id.trim().is_empty() {
185                    model.id = model_id.clone();
186                }
187            }
188        }
189
190        if self.selected_model().is_none()
191            && let Some((provider_id, provider)) = self.providers.iter().next()
192            && let Some((model_id, _)) = provider.models.iter().next()
193        {
194            self.models.default = format!("{provider_id}/{model_id}");
195        }
196    }
197}
198
199impl Default for ModelSettings {
200    fn default() -> Self {
201        Self {
202            default: default_model_ref(),
203        }
204    }
205}
206
/// Provider id used by the built-in configuration.
fn default_provider_id() -> String {
    String::from("openai")
}

/// Model id used by the built-in configuration.
fn default_provider_model() -> String {
    String::from("gpt-4.1-mini")
}

/// Built-in default model reference in `provider/model` form.
fn default_model_ref() -> String {
    [default_provider_id(), default_provider_model()].join("/")
}

/// API base URL of the built-in provider.
fn default_provider_base_url() -> String {
    String::from("https://api.openai.com/v1")
}

/// Environment variable holding the built-in provider's API key.
fn default_api_key_env() -> String {
    String::from("OPENAI_API_KEY")
}

/// Human-readable name of the built-in provider.
fn default_provider_display_name() -> String {
    String::from("OpenAI")
}
230
231fn default_providers() -> BTreeMap<String, ProviderConfig> {
232    let mut providers = BTreeMap::new();
233    providers.insert(
234        default_provider_id(),
235        ProviderConfig {
236            display_name: default_provider_display_name(),
237            base_url: default_provider_base_url(),
238            api_key_env: default_api_key_env(),
239            models: BTreeMap::from([(
240                default_provider_model(),
241                ModelMetadata {
242                    id: default_provider_model(),
243                    display_name: "GPT-4.1 mini".to_string(),
244                    modalities: ModelModalities::default(),
245                    limits: ModelLimits::default(),
246                },
247            )]),
248        },
249    );
250    providers
251}
252
/// Splits a `provider/model` reference on its first `/`.
///
/// Both halves are trimmed; everything after the first `/` belongs to the
/// model id, so `"router/org/model"` yields `("router", "org/model")`.
/// Returns `None` when there is no `/` or either trimmed half is empty.
fn split_model_ref(model_ref: &str) -> Option<(&str, &str)> {
    let mut halves = model_ref.splitn(2, '/');
    let provider = halves.next()?.trim();
    let model = halves.next()?.trim();
    (!provider.is_empty() && !model.is_empty()).then_some((provider, model))
}
262
263fn default_input_modalities() -> Vec<ModelModalityType> {
264    vec![ModelModalityType::Text, ModelModalityType::Image]
265}
266
267fn default_output_modalities() -> Vec<ModelModalityType> {
268    vec![ModelModalityType::Text]
269}
270
/// Context limit assumed when a model does not declare one.
fn default_model_context_limit() -> usize {
    128 * 1_000
}

/// Output limit assumed when a model does not declare one.
/// NOTE(review): this equals the context limit, i.e. output may fill the whole
/// window — confirm this is intended rather than a copy of the value above.
fn default_model_output_limit() -> usize {
    128 * 1_000
}
278
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct AgentSettings {
281    pub max_steps: usize,
282    #[serde(default = "default_sub_agent_max_depth")]
283    pub sub_agent_max_depth: usize,
284    #[serde(default = "default_parallel_subagents")]
285    pub parallel_subagents: bool,
286    #[serde(default = "default_max_parallel_subagents")]
287    pub max_parallel_subagents: usize,
288    #[serde(default, skip_serializing_if = "Option::is_none")]
289    pub system_prompt: Option<String>,
290}
291
292impl Default for AgentSettings {
293    fn default() -> Self {
294        Self {
295            max_steps: 0,
296            sub_agent_max_depth: default_sub_agent_max_depth(),
297            parallel_subagents: default_parallel_subagents(),
298            max_parallel_subagents: default_max_parallel_subagents(),
299            system_prompt: None,
300        }
301    }
302}
303
/// Default maximum sub-agent nesting depth.
fn default_sub_agent_max_depth() -> usize {
    const DEPTH: usize = 2;
    DEPTH
}

/// Sub-agents may run in parallel unless configured otherwise.
fn default_parallel_subagents() -> bool {
    const ENABLED: bool = true;
    ENABLED
}

/// Default cap on concurrently running sub-agents.
fn default_max_parallel_subagents() -> usize {
    const LIMIT: usize = 2;
    LIMIT
}
315
316impl AgentSettings {
317    pub fn resolved_system_prompt(&self) -> String {
318        self.system_prompt
319            .clone()
320            .filter(|s| !s.trim().is_empty())
321            .unwrap_or_else(default_system_prompt)
322    }
323}
324
325#[derive(Debug, Clone, Serialize, Deserialize)]
326pub struct ToolSettings {
327    pub fs: bool,
328    pub bash: bool,
329    pub web: bool,
330}
331
332impl Default for ToolSettings {
333    fn default() -> Self {
334        Self {
335            fs: true,
336            bash: true,
337            web: true,
338        }
339    }
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct PermissionSettings {
344    pub read: String,
345    pub list: String,
346    pub glob: String,
347    pub grep: String,
348    pub write: String,
349    #[serde(default = "default_edit_permission")]
350    pub edit: String,
351    #[serde(default = "default_todo_write_permission")]
352    pub todo_write: String,
353    #[serde(default = "default_todo_read_permission")]
354    pub todo_read: String,
355    #[serde(default = "default_question_permission")]
356    pub question: String,
357    #[serde(default = "default_task_permission")]
358    pub task: String,
359    pub bash: String,
360    pub web: String,
361    #[serde(default)]
362    pub capabilities: BTreeMap<String, String>,
363}
364
365impl Default for PermissionSettings {
366    fn default() -> Self {
367        Self {
368            read: "allow".to_string(),
369            list: "allow".to_string(),
370            glob: "allow".to_string(),
371            grep: "allow".to_string(),
372            write: "ask".to_string(),
373            edit: default_edit_permission(),
374            todo_write: default_todo_write_permission(),
375            todo_read: default_todo_read_permission(),
376            question: default_question_permission(),
377            task: default_task_permission(),
378            bash: "ask".to_string(),
379            web: "ask".to_string(),
380            capabilities: BTreeMap::new(),
381        }
382    }
383}
384
/// Editing files asks for confirmation by default.
fn default_edit_permission() -> String {
    String::from("ask")
}

/// Writing the todo list is allowed by default.
fn default_todo_write_permission() -> String {
    String::from("allow")
}

/// Reading the todo list is allowed by default.
fn default_todo_read_permission() -> String {
    String::from("allow")
}

/// Asking the user a question is allowed by default.
fn default_question_permission() -> String {
    String::from("allow")
}

/// Spawning a task is allowed by default.
fn default_task_permission() -> String {
    String::from("allow")
}
404
405#[derive(Debug, Clone, Serialize, Deserialize)]
406pub struct SessionSettings {
407    pub root: PathBuf,
408}
409
410impl Default for SessionSettings {
411    fn default() -> Self {
412        let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
413        Self {
414            root: home.join(".local/state/hh/sessions"),
415        }
416    }
417}
418
419#[derive(Debug, Clone, Serialize, Deserialize, Default)]
420pub struct AgentSpecificSettings {
421    #[serde(default)]
422    pub model: Option<String>,
423}
424
425impl Settings {
426    pub fn apply_agent_settings(&mut self, agent: &crate::agent::AgentConfig) {
427        // Apply agent system prompt if specified
428        if let Some(prompt) = &agent.system_prompt {
429            self.agent.system_prompt = Some(prompt.clone());
430        }
431
432        // Apply agent model or use global override
433        let model_to_use = agent
434            .model
435            .as_ref()
436            .or_else(|| self.agents.get(&agent.name).and_then(|s| s.model.as_ref()))
437            .or(Some(&self.models.default));
438
439        if let Some(model) = model_to_use {
440            self.models.default = model.clone();
441        }
442
443        // Apply permission overrides
444        for (capability, policy) in &agent.permission_overrides {
445            match capability.as_str() {
446                "read" => self.permission.read = policy.clone(),
447                "list" => self.permission.list = policy.clone(),
448                "glob" => self.permission.glob = policy.clone(),
449                "grep" => self.permission.grep = policy.clone(),
450                "write" => self.permission.write = policy.clone(),
451                "edit" => self.permission.edit = policy.clone(),
452                "todo_write" => self.permission.todo_write = policy.clone(),
453                "todo_read" => self.permission.todo_read = policy.clone(),
454                "question" => self.permission.question = policy.clone(),
455                "task" => self.permission.task = policy.clone(),
456                "bash" => self.permission.bash = policy.clone(),
457                "web" => self.permission.web = policy.clone(),
458                _ => {
459                    self.permission
460                        .capabilities
461                        .insert(capability.clone(), policy.clone());
462                }
463            }
464        }
465    }
466}