Skip to main content

outrig_cli/
config_init.rs

//! `outrig config init` -- interactive writer for the global config.
//!
//! Walks the user through providers, models, and `default-model`, then writes
//! a parseable + validated TOML file to the resolved global-config path. The
//! command refuses to clobber an existing file without `--force`. Atomic
//! writes go through `tempfile::NamedTempFile::persist` so an interrupted
//! prompt never leaves a half-written config behind.
//!
//! `run` constructs real terminal I/O; `run_with` is the test seam that takes
//! an arbitrary `PromptSource`, `HfTreeFetcher`, and target path.
11
12use std::collections::BTreeMap;
13use std::path::{Path, PathBuf};
14
15use serde::Serialize;
16
17use crate::error::{OutrigError, Result};
18use crate::hf::{self, HfTreeFetcher};
19use crate::init::prompt::{self, Field, PromptSource};
20use crate::paths::{global_config_path, write_atomic};
21use outrig::config::{ApiKeyRef, LlmProvider, Model};
22
23/// Public entry: resolve the path, pick a `PromptSource` via
24/// `prompt::auto()` (dialoguer on a TTY, line-based on piped stdin), and
25/// delegate to `run_with`. `global_override` plumbs the top-level
26/// `--global-config` flag into the same resolver [`global_config_path`]
27/// uses elsewhere.
28pub async fn run(force: bool, global_override: Option<&Path>) -> Result<()> {
29    let path = global_config_path(global_override);
30    eprintln!("[outrig] writing global config to {}", path.display());
31    let mut prompt = prompt::auto();
32    let mut hf = hf::auto();
33    run_with(force, &path, &mut prompt, &mut hf).await?;
34    eprintln!("[outrig] wrote {}", path.display());
35    Ok(())
36}
37
38/// Drives the interactive flow against an arbitrary `PromptSource`. The flow
39/// short-circuits on existing files when `force == false` so an accidental
40/// re-run doesn't burn through prompts before bailing. `hf` is the
41/// HuggingFace tree-listing client used to discover GGUF files for
42/// mistralrs `model-id` configs; tests pass a stub.
43pub async fn run_with(
44    force: bool,
45    path: &Path,
46    prompt: &mut impl PromptSource,
47    hf: &mut impl HfTreeFetcher,
48) -> Result<()> {
49    if path.exists() && !force {
50        return Err(OutrigError::Configuration(format!(
51            "{} already exists; pass --force to overwrite.",
52            path.display()
53        ))
54        .into());
55    }
56
57    let mut providers = prompt_providers(prompt).await?;
58    let models = prompt_models(prompt, &mut providers, hf).await?;
59    let default_model = prompt_default_model(prompt, &models).await?;
60
61    let toml_text = render(default_model.as_deref(), &providers, &models)?;
62    write_atomic(path, &toml_text)?;
63    Ok(())
64}
65
66// ---- prompt-flow helpers --------------------------------------------------
67
/// Provider styles offered by the style picker, as `(key, description)`
/// pairs. The key is the string `prompt_provider_body` dispatches on and is
/// also passed as the suggested provider name in `prompt_providers`.
const STYLES: &[(&str, &str)] = &[
    (
        "openai",
        "OpenAI Chat Completions wire format. Works with OpenAI, OpenRouter, vLLM, Ollama.",
    ),
    (
        "mistralrs",
        "In-process LLM via the mistralrs crate. Loads a local or HuggingFace model.",
    ),
];

/// Select prompt for the provider style; its options are the `STYLES` table,
/// so the index returned by `ask_select` indexes straight into `STYLES`.
const STYLE_FIELD: Field = Field {
    name: "Pick a provider style",
    description: "Which wire format / runtime this provider speaks.",
    options: STYLES,
    doc_link: "doc/concepts/llm-providers.md",
};
85
/// Free-text prompt for the `[providers.<name>]` key, asked in
/// `prompt_providers` with the chosen style key as the suggestion.
const PROVIDER_NAME_FIELD: Field = Field {
    name: "Provider name",
    description: "Used as the [providers.<name>] key and referenced from models.",
    options: &[],
    doc_link: "doc/reference/config.md",
};

/// Free-text prompt for an OpenAI-style provider's base URL, asked in
/// `prompt_openai_provider`.
const BASE_URL_FIELD: Field = Field {
    name: "Base URL",
    description: "HTTPS endpoint for the OpenAI-compatible API.",
    options: &[],
    doc_link: "doc/concepts/llm-providers.md",
};

/// Free-text prompt for the *name* of the env var holding the API key;
/// `prompt_openai_provider` wraps the answer as `${VAR}` before parsing it.
const API_KEY_ENV_FIELD: Field = Field {
    name: "API key environment variable",
    description: "Name of the env var that holds the API key. Stored as ${VAR}.",
    options: &[],
    doc_link: "doc/reference/config.md",
};

/// Yes/no prompt that decides whether `prompt_providers` loops once more.
const ADD_PROVIDER_FIELD: Field = Field {
    name: "Add another provider?",
    description: "Whether to define one more [providers.<name>] entry.",
    options: &[],
    doc_link: "doc/reference/config.md",
};
113
/// Yes/no prompt in `prompt_mistralrs_model` choosing between a HuggingFace
/// `model-id` (yes) and a local `model-path` (no).
const AUTO_DOWNLOAD_FIELD: Field = Field {
    name: "Use auto-download by model ID?",
    description: "Yes: pull weights from HuggingFace by repo ID. No: load a local GGUF file by path.",
    options: &[],
    doc_link: "doc/concepts/in-process-llm.md",
};

/// Required free-text prompt for the HuggingFace repo id (asked through
/// `ask_required`, so an empty answer is re-prompted).
const MODEL_ID_FIELD: Field = Field {
    name: "HuggingFace model-id",
    description: "Repo identifier, e.g. microsoft/Phi-3-mini-4k-instruct-gguf.",
    options: &[],
    doc_link: "doc/concepts/in-process-llm.md",
};

/// Optional free-text prompt for the repo revision; an empty answer is
/// stored as `None` by `prompt_mistralrs_model`.
const REVISION_FIELD: Field = Field {
    name: "revision (blank for `main`)",
    description: "Git ref on the HuggingFace repo to pin. Defaults to `main`.",
    options: &[],
    doc_link: "doc/concepts/in-process-llm.md",
};

/// Required free-text prompt for a local GGUF path, used when auto-download
/// is declined.
const MODEL_PATH_FIELD: Field = Field {
    name: "Local model-path",
    description: "Filesystem path to a GGUF file.",
    options: &[],
    doc_link: "doc/concepts/in-process-llm.md",
};

/// Required free-text prompt for a GGUF filename; `resolve_model_file` falls
/// back to it when listing the repo's files fails.
const MODEL_FILE_FIELD: Field = Field {
    name: "GGUF model-file",
    description: "Filename inside the HF repo, e.g. \
                  qwen2.5-coder-1.5b-instruct-q4_k_m.gguf. Used to pick \
                  one quantization out of a multi-file repo.",
    options: &[],
    doc_link: "doc/concepts/in-process-llm.md",
};

/// Picker prompt shown by `resolve_model_file` when the repo lists more than
/// one GGUF file; the answer is parsed by `parse_pick_input`.
const MODEL_FILE_PICK_FIELD: Field = Field {
    name: "Pick GGUF file(s) from the repo",
    description: "Comma-separated numbers (e.g. `1,3`) or filenames. Pick \
                  multiple only when one quantization is split across \
                  shards (model-00001-of-00003.gguf, ...). The first \
                  option is the default.",
    options: &[],
    doc_link: "doc/concepts/in-process-llm.md",
};

/// Optional free-text prompt for a context-window override; a non-empty
/// answer must parse as `u32` or `prompt_mistralrs_model` returns an error.
const CONTEXT_LENGTH_FIELD: Field = Field {
    name: "context-length (blank for the model's default)",
    description: "Override the model's default context window. Integer.",
    options: &[],
    doc_link: "doc/concepts/in-process-llm.md",
};
167
/// Yes/no gate asked by `prompt_models` before entering the model loop.
const DEFINE_MODEL_FIELD: Field = Field {
    name: "Define a model now?",
    description: "Whether to add a [models.<name>] entry to the new config.",
    options: &[],
    doc_link: "doc/reference/config.md",
};

/// Free-text prompt for the `[models.<name>]` key; `prompt_models_loop`
/// suggests `fast`.
const MODEL_NAME_FIELD: Field = Field {
    name: "Model name",
    description: "Used as the [models.<name>] key and referenced from agents.",
    options: &[],
    doc_link: "doc/reference/config.md",
};

/// Free-text prompt for the provider-side model identifier, asked only for
/// models whose provider is `LlmProvider::OpenAi`.
const MODEL_IDENTIFIER_FIELD: Field = Field {
    name: "Model identifier",
    description: "Identifier passed to the provider API (e.g. gpt-4o-mini).",
    options: &[],
    doc_link: "doc/reference/config.md",
};

/// Free-text prompt for the provider a model belongs to; an unknown name
/// leads `prompt_models_loop` to offer creating that provider inline.
const MODEL_PROVIDER_FIELD: Field = Field {
    name: "Provider for this model",
    description: "An LLM provider is a backend that hosts the model -- e.g. \
                  OpenAI, OpenRouter, vLLM, or a local mistralrs runtime. \
                  Each carries its own connection details (URL, API key, \
                  etc.). This can be the name of an existing \
                  [providers.<name>] entry or you can give a new name to \
                  create a new provider.",
    options: &[],
    doc_link: "doc/concepts/llm-providers.md",
};

/// Yes/no prompt shown after an unknown provider name: yes defines it now,
/// no loops back to the provider-name prompt.
const ADD_NEW_PROVIDER_FIELD: Field = Field {
    name: "Add this provider now?",
    description: "Yes: walk through the provider style + connection prompts \
                  to define a new [providers.<name>] entry under the name \
                  you just typed. No: re-enter the provider name.",
    options: &[],
    doc_link: "doc/concepts/llm-providers.md",
};

/// Yes/no prompt that decides whether `prompt_models_loop` runs once more.
const ADD_MODEL_FIELD: Field = Field {
    name: "Add another model?",
    description: "Whether to define one more [models.<name>] entry.",
    options: &[],
    doc_link: "doc/reference/config.md",
};

/// Yes/no prompt used by `prompt_default_model` when exactly one model was
/// defined.
const USE_DEFAULT_FIELD: Field = Field {
    name: "Use this model as default-model?",
    description: "Sets the top-level `default-model` so agents without an explicit model use it.",
    options: &[],
    doc_link: "doc/reference/config.md",
};

/// Free-text prompt used by `prompt_default_model` when several models were
/// defined; the answer must be empty or name one of them.
const DEFAULT_MODEL_FIELD: Field = Field {
    name: "Default model name",
    description: "Name of an existing model to set as `default-model`. Blank for none.",
    options: &[],
    doc_link: "doc/reference/config.md",
};
230
/// Slice of every `Field` declared in this module, for `prompt_doc_sync.rs`.
///
/// Maintained by hand: a `Field` constant added to this module has to be
/// appended here as well, or it will be missing from this list.
pub const DOC_SYNC_FIELDS: &[&Field] = &[
    &STYLE_FIELD,
    &PROVIDER_NAME_FIELD,
    &BASE_URL_FIELD,
    &API_KEY_ENV_FIELD,
    &ADD_PROVIDER_FIELD,
    &AUTO_DOWNLOAD_FIELD,
    &MODEL_ID_FIELD,
    &REVISION_FIELD,
    &MODEL_PATH_FIELD,
    &MODEL_FILE_FIELD,
    &MODEL_FILE_PICK_FIELD,
    &CONTEXT_LENGTH_FIELD,
    &DEFINE_MODEL_FIELD,
    &MODEL_NAME_FIELD,
    &MODEL_IDENTIFIER_FIELD,
    &MODEL_PROVIDER_FIELD,
    &ADD_NEW_PROVIDER_FIELD,
    &ADD_MODEL_FIELD,
    &USE_DEFAULT_FIELD,
    &DEFAULT_MODEL_FIELD,
];
254
255async fn prompt_providers(prompt: &mut impl PromptSource) -> Result<BTreeMap<String, LlmProvider>> {
256    let mut out = BTreeMap::new();
257    loop {
258        let style_idx = prompt.ask_select(&STYLE_FIELD, 0).await?;
259        let style = STYLES[style_idx].0;
260        let name = prompt.ask_string(&PROVIDER_NAME_FIELD, style).await?;
261        let provider = prompt_provider_body(prompt, style).await?;
262        out.insert(name, provider);
263
264        if !prompt.ask_bool(&ADD_PROVIDER_FIELD, false).await? {
265            break;
266        }
267    }
268    Ok(out)
269}
270
271/// Walks just the style-specific prompts for a single provider whose name
272/// is already known. Used inline by `prompt_models_loop` when the user
273/// references a provider that doesn't exist yet -- we already have the
274/// name (what they typed at the model's provider prompt) and only need to
275/// ask the style + connection details.
276pub(crate) async fn prompt_new_provider_for_name(
277    prompt: &mut impl PromptSource,
278) -> Result<LlmProvider> {
279    let style_idx = prompt.ask_select(&STYLE_FIELD, 0).await?;
280    let style = STYLES[style_idx].0;
281    prompt_provider_body(prompt, style).await
282}
283
284async fn prompt_provider_body(prompt: &mut impl PromptSource, style: &str) -> Result<LlmProvider> {
285    match style {
286        "openai" => prompt_openai_provider(prompt).await,
287        "mistralrs" => Ok(LlmProvider::Mistralrs),
288        other => Err(OutrigError::Configuration(format!("unknown provider style: {other}")).into()),
289    }
290}
291
292async fn prompt_openai_provider(prompt: &mut impl PromptSource) -> Result<LlmProvider> {
293    let base_url = prompt
294        .ask_string(&BASE_URL_FIELD, "https://api.openai.com/v1")
295        .await?;
296    // We capture the env-var name and render it as `${VAR}` -- `ApiKeyRef` only
297    // accepts that form, so feeding a bare name would be rejected at parse time.
298    let env_name = prompt
299        .ask_string(&API_KEY_ENV_FIELD, "OPENAI_API_KEY")
300        .await?;
301    let api_key = ApiKeyRef::parse(&format!("${{{env_name}}}"))?;
302    Ok(LlmProvider::OpenAi {
303        base_url,
304        api_key,
305        request_timeout_secs: None,
306    })
307}
308
309async fn prompt_models(
310    prompt: &mut impl PromptSource,
311    providers: &mut BTreeMap<String, LlmProvider>,
312    hf: &mut impl HfTreeFetcher,
313) -> Result<BTreeMap<String, Model>> {
314    if !prompt.ask_bool(&DEFINE_MODEL_FIELD, true).await? {
315        return Ok(BTreeMap::new());
316    }
317    let (models, new_providers) = prompt_models_loop(prompt, providers, hf).await?;
318    providers.extend(new_providers);
319    Ok(models)
320}
321
322/// The model-add loop without the outer `Define a model now?` gate.
323/// Returns `(models, new_providers)` -- providers added inline (when the
324/// user references one that doesn't exist yet) come back to the caller so
325/// init::repo can write them to the repo config without mutating the
326/// global providers it was passed.
327pub(crate) async fn prompt_models_loop(
328    prompt: &mut impl PromptSource,
329    existing_providers: &BTreeMap<String, LlmProvider>,
330    hf: &mut impl HfTreeFetcher,
331) -> Result<(BTreeMap<String, Model>, BTreeMap<String, LlmProvider>)> {
332    let mut out = BTreeMap::new();
333    let mut new_providers: BTreeMap<String, LlmProvider> = BTreeMap::new();
334
335    loop {
336        let name = prompt.ask_string(&MODEL_NAME_FIELD, "fast").await?;
337
338        // Print providers defined so far (existing + any added inline) so
339        // the user has the list at hand for the next prompt.
340        let provider_names: Vec<&str> = existing_providers
341            .keys()
342            .chain(new_providers.keys())
343            .map(String::as_str)
344            .collect();
345        if !provider_names.is_empty() {
346            eprintln!("[outrig] providers defined: {}", provider_names.join(", "));
347        }
348
349        let suggestion = provider_names
350            .first()
351            .copied()
352            .unwrap_or("openai")
353            .to_string();
354        let provider_name = loop {
355            let answer = prompt
356                .ask_string(&MODEL_PROVIDER_FIELD, &suggestion)
357                .await?;
358            if existing_providers.contains_key(&answer) || new_providers.contains_key(&answer) {
359                break answer;
360            }
361            eprintln!("[outrig] no provider named `{answer}` yet.");
362            if prompt.ask_bool(&ADD_NEW_PROVIDER_FIELD, true).await? {
363                let provider = prompt_new_provider_for_name(prompt).await?;
364                new_providers.insert(answer.clone(), provider);
365                break answer;
366            }
367        };
368        let provider = existing_providers
369            .get(&provider_name)
370            .or_else(|| new_providers.get(&provider_name))
371            .expect("validated above");
372        let model = match provider {
373            LlmProvider::OpenAi { .. } => {
374                let identifier = prompt
375                    .ask_string(&MODEL_IDENTIFIER_FIELD, "gpt-4o-mini")
376                    .await?;
377                Model {
378                    provider: provider_name,
379                    identifier: Some(identifier),
380                    model_id: None,
381                    model_path: None,
382                    model_file: None,
383                    revision: None,
384                    context_length: None,
385                    device: None,
386                }
387            }
388            LlmProvider::Mistralrs => prompt_mistralrs_model(prompt, hf, provider_name).await?,
389        };
390        out.insert(name, model);
391        if !prompt.ask_bool(&ADD_MODEL_FIELD, false).await? {
392            break;
393        }
394    }
395    Ok((out, new_providers))
396}
397
398async fn prompt_mistralrs_model(
399    prompt: &mut impl PromptSource,
400    hf: &mut impl HfTreeFetcher,
401    provider_name: String,
402) -> Result<Model> {
403    let auto_download = prompt.ask_bool(&AUTO_DOWNLOAD_FIELD, true).await?;
404    let (model_id, model_file, model_path, revision) = if auto_download {
405        let id = ask_required(prompt, &MODEL_ID_FIELD).await?;
406        let rev = blank_to_none(prompt.ask_string(&REVISION_FIELD, "").await?);
407        let file = resolve_model_file(prompt, hf, &id, rev.as_deref()).await?;
408        (Some(id), Some(file), None, rev)
409    } else {
410        let path = ask_required(prompt, &MODEL_PATH_FIELD).await?;
411        (None, None, Some(PathBuf::from(path)), None)
412    };
413    let context_length = blank_to_none(prompt.ask_string(&CONTEXT_LENGTH_FIELD, "").await?)
414        .map(|s| {
415            s.parse::<u32>().map_err(|_| {
416                OutrigError::Configuration(format!(
417                    "context-length must be a non-negative integer; got `{s}`"
418                ))
419            })
420        })
421        .transpose()?;
422    Ok(Model {
423        provider: provider_name,
424        identifier: None,
425        model_id,
426        model_path,
427        model_file,
428        revision,
429        context_length,
430        device: None,
431    })
432}
433
434/// Discover GGUF files in `model_id` via `hf` and pick one or more. On a
435/// successful query: 0 files -> error, 1 file -> auto-pick (status line,
436/// no prompt), many -> render a numbered list (with sizes) and prompt
437/// for a comma-separated choice (numbers or filenames). On any HF error
438/// (offline, build without `mistralrs`, transient outage), fall back to
439/// the free-form `MODEL_FILE_FIELD` text prompt so the flow still
440/// completes.
441///
442/// Multi-select supports split-quantization repos where one quantization
443/// is sharded across multiple `model-NNNNN-of-NNNNN.gguf` files;
444/// mistralrs's GGUF loader takes the whole list.
445async fn resolve_model_file(
446    prompt: &mut impl PromptSource,
447    hf: &mut impl HfTreeFetcher,
448    model_id: &str,
449    revision: Option<&str>,
450) -> Result<Vec<String>> {
451    let files = match hf.list_files(model_id, revision).await {
452        Ok(siblings) => crate::hf::filter_gguf(siblings),
453        Err(e) => {
454            eprintln!(
455                "[outrig] could not list files in {model_id:?} ({e}); \
456                 enter the GGUF filename manually."
457            );
458            return ask_required(prompt, &MODEL_FILE_FIELD)
459                .await
460                .map(|s| vec![s]);
461        }
462    };
463
464    match files.as_slice() {
465        [] => Err(OutrigError::Configuration(format!(
466            "HF repo {model_id:?} contains no .gguf files; pick a different model-id"
467        ))
468        .into()),
469        [only] => {
470            let label = format_file_label(only);
471            eprintln!("[outrig] found one GGUF in {model_id:?}: {label}; using it");
472            Ok(vec![only.path.clone()])
473        }
474        many => {
475            eprintln!("[outrig] {} GGUF files in {model_id:?}:", many.len());
476            let idx_w = (many.len() as f64).log10().floor() as usize + 1;
477            for (i, file) in many.iter().enumerate() {
478                eprintln!("  {:>idx_w$}: {}", i + 1, format_file_label(file));
479            }
480            loop {
481                let answer = prompt
482                    .ask_string(&MODEL_FILE_PICK_FIELD, many[0].path.as_str())
483                    .await?;
484                let trimmed = answer.trim();
485                if trimmed.is_empty() {
486                    return Ok(vec![many[0].path.clone()]);
487                }
488                match parse_pick_input(trimmed, many) {
489                    Ok(picked) => return Ok(picked),
490                    Err(bad) => eprintln!(
491                        "[outrig] {bad:?} is not a number 1..={} or a filename in the list",
492                        many.len()
493                    ),
494                }
495            }
496        }
497    }
498}
499
500/// Render one row of the picker: filename plus a parenthesized
501/// human-readable size when known. Centralized so the auto-pick status
502/// line and the multi-line picker share a format.
503fn format_file_label(file: &crate::hf::HfFile) -> String {
504    match file.size {
505        Some(bytes) => format!("{}  ({})", file.path, crate::hf::format_size(bytes)),
506        None => file.path.clone(),
507    }
508}
509
510/// Parse a comma-separated picker answer against `files`. Each token is
511/// either a 1-based index or a literal filename match. Whitespace around
512/// tokens is ignored. Returns the unique paths in the order the user
513/// specified, deduplicated. Returns `Err(bad)` with the first
514/// unrecognized token.
515fn parse_pick_input(
516    input: &str,
517    files: &[crate::hf::HfFile],
518) -> std::result::Result<Vec<String>, String> {
519    let mut out: Vec<String> = Vec::new();
520    for tok in input.split(',') {
521        let t = tok.trim();
522        if t.is_empty() {
523            continue;
524        }
525        let path = if let Ok(n) = t.parse::<usize>()
526            && (1..=files.len()).contains(&n)
527        {
528            files[n - 1].path.clone()
529        } else if let Some(file) = files.iter().find(|f| f.path == t) {
530            file.path.clone()
531        } else {
532            return Err(t.to_string());
533        };
534        if !out.contains(&path) {
535            out.push(path);
536        }
537    }
538    if out.is_empty() {
539        return Err(input.trim().to_string());
540    }
541    Ok(out)
542}
543
544pub(crate) async fn prompt_default_model(
545    prompt: &mut impl PromptSource,
546    models: &BTreeMap<String, Model>,
547) -> Result<Option<String>> {
548    match models.len() {
549        0 => Ok(None),
550        1 => {
551            let only = models.keys().next().expect("len==1");
552            if prompt.ask_bool(&USE_DEFAULT_FIELD, true).await? {
553                Ok(Some(only.clone()))
554            } else {
555                Ok(None)
556            }
557        }
558        _ => loop {
559            // BTreeMap iteration order is alphabetical, which is fine as a
560            // suggestion -- the user picks freely from the validated set.
561            let suggestion = models.keys().next().expect("len>1");
562            let answer = prompt
563                .ask_string(&DEFAULT_MODEL_FIELD, suggestion.as_str())
564                .await?;
565            if answer.is_empty() {
566                return Ok(None);
567            }
568            if models.contains_key(&answer) {
569                return Ok(Some(answer));
570            }
571            eprintln!(
572                "[outrig] no model named `{answer}`; defined: {}",
573                models.keys().cloned().collect::<Vec<_>>().join(", ")
574            );
575        },
576    }
577}
578
579// ---- rendering + atomic write --------------------------------------------
580
/// Borrowed, serialize-only view of the global config that `render` hands to
/// the TOML serializer. Field names are emitted in kebab-case
/// (`default_model` becomes `default-model`).
#[derive(Serialize)]
#[serde(rename_all = "kebab-case")]
struct GlobalOut<'a> {
    /// Top-level default model name; left out of the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    default_model: Option<&'a str>,
    /// Provider table keyed by provider name; left out when empty.
    #[serde(skip_serializing_if = "BTreeMap::is_empty")]
    providers: &'a BTreeMap<String, LlmProvider>,
    /// Model table keyed by model name; left out when empty.
    #[serde(skip_serializing_if = "BTreeMap::is_empty")]
    models: &'a BTreeMap<String, Model>,
}
591
592fn render(
593    default_model: Option<&str>,
594    providers: &BTreeMap<String, LlmProvider>,
595    models: &BTreeMap<String, Model>,
596) -> Result<String> {
597    let view = GlobalOut {
598        default_model,
599        providers,
600        models,
601    };
602    toml::to_string_pretty(&view)
603        .map_err(|e| OutrigError::Configuration(format!("rendering global config: {e}")).into())
604}
605
/// Maps an empty string to `None` and anything else (including
/// whitespace-only text) to `Some`, so optional prompts can be left blank.
fn blank_to_none(s: String) -> Option<String> {
    Some(s).filter(|text| !text.is_empty())
}
609
610/// `ask_string` wrapper that re-prompts on empty input. Used for fields
611/// where empty is not a meaningful answer (e.g. a HuggingFace model id).
612async fn ask_required(prompt: &mut impl PromptSource, field: &Field) -> Result<String> {
613    loop {
614        let answer = prompt.ask_string(field, "").await?;
615        if !answer.is_empty() {
616            return Ok(answer);
617        }
618        eprintln!("[outrig] this field requires a value");
619    }
620}