Skip to main content

dnslib/cli/
interactive.rs

1use inquire::validator::Validation;
2use inquire::{InquireError, MultiSelect, Select, Text};
3
4use crate::control_plane::config::{
5    CLOUDFLARE_DEFAULT_BASE_URL, DnsServerConfig, McpPermissions, PANGOLIN_DEFAULT_BASE_URL,
6    PIHOLE_DEFAULT_BASE_URL, ServerLocation, TECHNITIUM_DEFAULT_BASE_URL, UNIFI_DEFAULT_BASE_URL,
7    ValidationEndpointConfig, VendorKind,
8};
9use crate::control_plane::policy::PolicyRule;
10use crate::core::error::{Error, Result};
11
12pub fn run_add_wizard(existing_ids: &[String]) -> Result<DnsServerConfig> {
13    let existing: Vec<String> = existing_ids.iter().map(|s| s.to_lowercase()).collect();
14    let id = Text::new("Server ID:")
15        .with_help_message("Unique identifier for this server entry")
16        .with_validator(move |input: &str| {
17            if existing.iter().any(|id| id == &input.to_lowercase()) {
18                Ok(Validation::Invalid(
19                    format!("a server with id '{input}' already exists").into(),
20                ))
21            } else {
22                Ok(Validation::Valid)
23            }
24        })
25        .prompt()
26        .map_err(wizard_err)?;
27
28    let vendor = {
29        let choices = vec![
30            VendorChoice {
31                kind: VendorKind::Technitium,
32                label: "technitium",
33            },
34            VendorChoice {
35                kind: VendorKind::Pangolin,
36                label: "pangolin",
37            },
38            VendorChoice {
39                kind: VendorKind::Cloudflare,
40                label: "cloudflare",
41            },
42            VendorChoice {
43                kind: VendorKind::Unifi,
44                label: "unifi",
45            },
46            VendorChoice {
47                kind: VendorKind::Pihole,
48                label: "pihole",
49            },
50        ];
51        Select::new("Vendor:", choices)
52            .prompt()
53            .map_err(wizard_err)?
54            .kind
55    };
56
57    let default_url = match vendor {
58        VendorKind::Technitium => TECHNITIUM_DEFAULT_BASE_URL,
59        VendorKind::Pangolin => PANGOLIN_DEFAULT_BASE_URL,
60        VendorKind::Cloudflare => CLOUDFLARE_DEFAULT_BASE_URL,
61        VendorKind::Unifi => UNIFI_DEFAULT_BASE_URL,
62        VendorKind::Pihole => PIHOLE_DEFAULT_BASE_URL,
63    };
64
65    let base_url = optional_text(
66        "Base URL:",
67        &format!("Press Enter for default ({default_url}), or type a custom URL"),
68        Some(default_url),
69    )?;
70
71    let token_env = optional_text(
72        "Token environment variable:",
73        "Name of the env var holding the API token (recommended). Leave empty to skip.",
74        None,
75    )?;
76
77    let token = if token_env.is_none() {
78        optional_text(
79            "API token (stored in plain text — prefer token env var above):",
80            "Leave empty to skip",
81            None,
82        )?
83    } else {
84        None
85    };
86
87    let org_id = match vendor {
88        VendorKind::Pangolin => {
89            optional_text("Organisation ID (Pangolin):", "Leave empty to skip", None)?
90        }
91        VendorKind::Unifi => Some(
92            Text::new("Site name (UniFi):")
93                .with_help_message(
94                    "Human-readable site name (e.g. \"Default\") or site UUID; stored in org_id. \
95                     Run `dns settings` after saving to list valid site names.",
96                )
97                .with_validator(|input: &str| {
98                    if input.trim().is_empty() {
99                        Ok(Validation::Invalid(
100                            "site is required for UniFi".into(),
101                        ))
102                    } else {
103                        Ok(Validation::Valid)
104                    }
105                })
106                .prompt()
107                .map_err(wizard_err)?,
108        ),
109        _ => None,
110    };
111
112    let location = {
113        let choices = vec![
114            LocationChoice {
115                value: None,
116                label: "auto-detect",
117            },
118            LocationChoice {
119                value: Some(ServerLocation::Local),
120                label: "local",
121            },
122            LocationChoice {
123                value: Some(ServerLocation::External),
124                label: "external",
125            },
126        ];
127        Select::new("Location:", choices)
128            .with_help_message(
129                "auto-detect infers from the base URL (localhost/private IP → local)",
130            )
131            .prompt()
132            .map_err(wizard_err)?
133            .value
134    };
135
136    let access: Vec<PolicyRule> = {
137        let choices = vec![
138            AccessChoice {
139                rule: PolicyRule::Read,
140                label: "read   (list/export/stats/settings)",
141            },
142            AccessChoice {
143                rule: PolicyRule::Write,
144                label: "write  (create/update/import/flush)",
145            },
146            AccessChoice {
147                rule: PolicyRule::Delete,
148                label: "delete (delete zones/records/cache)",
149            },
150        ];
151        let defaults: Vec<usize> = (0..choices.len()).collect();
152        let chosen = MultiSelect::new("MCP allowed operations:", choices)
153            .with_default(&defaults)
154            .with_help_message("Select which operations are permitted for MCP tools on this server")
155            .prompt()
156            .map_err(wizard_err)?;
157        chosen.into_iter().map(|c| c.rule).collect()
158    };
159
160    let mut allowed_zones: Vec<String> = Vec::new();
161    loop {
162        let help = if allowed_zones.is_empty() {
163            "Restrict zone-targeting tools to specific zones; subdomains are also permitted. Leave empty to skip.".to_string()
164        } else {
165            format!(
166                "Added: {} — enter another, or leave empty to finish",
167                allowed_zones.join(", ")
168            )
169        };
170        let zone = match Text::new("Allowed zone:").with_help_message(&help).prompt() {
171            Ok(z) => z,
172            Err(InquireError::OperationCanceled | InquireError::OperationInterrupted) => {
173                return Err(Error::cancelled());
174            }
175            Err(e) => return Err(wizard_err(e)),
176        };
177        if zone.is_empty() {
178            break;
179        }
180        allowed_zones.push(zone);
181    }
182
183    let mut validation_endpoints: Vec<ValidationEndpointConfig> = Vec::new();
184    loop {
185        let help = if validation_endpoints.is_empty() {
186            "Optional DNS validation endpoints as name:transport:address (transport: dns, doh, dot). Leave empty to skip.".to_string()
187        } else {
188            format!(
189                "Added: {} — enter another, or leave empty to finish",
190                validation_endpoints
191                    .iter()
192                    .map(|endpoint| endpoint.name.as_str())
193                    .collect::<Vec<_>>()
194                    .join(", ")
195            )
196        };
197        let endpoint = match Text::new("Validation endpoint:")
198            .with_help_message(&help)
199            .prompt()
200        {
201            Ok(endpoint) => endpoint,
202            Err(InquireError::OperationCanceled | InquireError::OperationInterrupted) => {
203                return Err(Error::cancelled());
204            }
205            Err(e) => return Err(wizard_err(e)),
206        };
207        if endpoint.is_empty() {
208            break;
209        }
210        validation_endpoints.push(endpoint.parse::<ValidationEndpointConfig>().map_err(Error::parse)?);
211    }
212
213    Ok(DnsServerConfig {
214        id,
215        vendor,
216        location,
217        base_url,
218        base_url_env: None,
219        token,
220        token_env,
221        org_id,
222        cluster: None,
223        dns: None,
224        dot: None,
225        doh: None,
226        mcp: McpPermissions {
227            access,
228            allowed_zones,
229        },
230        validation_endpoints,
231    })
232}
233
234fn optional_text(label: &str, help: &str, default: Option<&str>) -> Result<Option<String>> {
235    let mut builder = Text::new(label).with_help_message(help);
236    if let Some(d) = default {
237        builder = builder.with_default(d);
238    }
239    let val = builder.prompt().map_err(wizard_err)?;
240    Ok(if val.is_empty() { None } else { Some(val) })
241}
242
243fn wizard_err(e: inquire::InquireError) -> Error {
244    match e {
245        InquireError::OperationCanceled | InquireError::OperationInterrupted => Error::cancelled(),
246        other => Error::io(
247            format!("interactive prompt failed: {other}"),
248            std::io::Error::other(other.to_string()),
249        ),
250    }
251}
252
253// ─── Display wrappers so Select can render enum variants ─────────────────────
254
255struct VendorChoice {
256    kind: VendorKind,
257    label: &'static str,
258}
259
260impl std::fmt::Display for VendorChoice {
261    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
262        f.write_str(self.label)
263    }
264}
265
266struct LocationChoice {
267    value: Option<ServerLocation>,
268    label: &'static str,
269}
270
271impl std::fmt::Display for LocationChoice {
272    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273        f.write_str(self.label)
274    }
275}
276
277struct AccessChoice {
278    rule: PolicyRule,
279    label: &'static str,
280}
281
282impl std::fmt::Display for AccessChoice {
283    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284        f.write_str(self.label)
285    }
286}