Skip to main content

hen/
error.rs

1use std::{
2    collections::BTreeSet,
3    fmt::{self, Display, Formatter},
4    path::Path,
5};
6
7use pest::{RuleType, error::{ErrorVariant, LineColLocation}};
8use serde_json::{Value, json};
9
10/// Alias for results that return a `HenError`.
11pub type HenResult<T> = Result<T, HenError>;
12
/// Coarse category carried by every [`HenError`].
///
/// The category is rendered as a bracketed label at the start of the error's
/// `Display` output, which makes structured reporting easier.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HenErrorKind {
    Cli,
    Input,
    Io,
    Parse,
    Planner,
    Execution,
    Benchmark,
    Prompt,
}
25
/// Severity level attached to a diagnostic.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HenDiagnosticSeverity {
    Error,
    Warning,
    Information,
    Hint,
}

impl HenDiagnosticSeverity {
    /// Returns the lowercase name of this severity, e.g. `"warning"`.
    pub fn label(self) -> &'static str {
        match self {
            Self::Error => "error",
            Self::Warning => "warning",
            Self::Information => "information",
            Self::Hint => "hint",
        }
    }
}
44
/// Processing phase a diagnostic is attributed to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HenDiagnosticPhase {
    Parse,
    Preprocess,
    Inspect,
    Validate,
}

impl HenDiagnosticPhase {
    /// Returns the lowercase name of this phase, e.g. `"preprocess"`.
    pub fn label(self) -> &'static str {
        match self {
            Self::Parse => "parse",
            Self::Preprocess => "preprocess",
            Self::Inspect => "inspect",
            Self::Validate => "validate",
        }
    }
}
63
/// A single point in a source text, expressed as a line and a character offset.
///
/// NOTE(review): whether the values are zero- or one-based is decided by the
/// code that builds positions, which is not shown here — confirm before relying on it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HenDiagnosticPosition {
    /// Line component of the position.
    pub line: usize,
    /// Character component of the position within the line.
    pub character: usize,
}
69
70#[derive(Debug, Clone, PartialEq, Eq)]
71pub struct HenDiagnosticRange {
72    pub start: HenDiagnosticPosition,
73    pub end: HenDiagnosticPosition,
74}
75
76#[derive(Debug, Clone, PartialEq, Eq)]
77pub struct HenDiagnosticLocation {
78    pub path: Option<String>,
79    pub range: HenDiagnosticRange,
80}
81
82#[derive(Debug, Clone, PartialEq, Eq)]
83pub struct HenDiagnosticRelatedInformation {
84    pub message: String,
85    pub location: HenDiagnosticLocation,
86}
87
/// Description of a named symbol that a diagnostic is about.
///
/// All three fields are free-form strings; symbols derived from pest errors in
/// this file use kinds such as `"oauthProfile"`, `"request"` and
/// `"environment"` with the role `"reference"`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HenDiagnosticSymbol {
    /// Category of the symbol.
    pub kind: String,
    /// Name of the symbol.
    pub name: String,
    /// How the symbol is being used at the diagnostic's location.
    pub role: String,
}
94
95#[derive(Debug, Clone, PartialEq, Eq)]
96pub struct HenDiagnosticSuggestion {
97    pub kind: String,
98    pub label: String,
99    pub range: Option<HenDiagnosticRange>,
100    pub text: Option<String>,
101}
102
103#[derive(Debug, Clone, PartialEq, Eq)]
104pub struct HenDiagnostic {
105    pub code: String,
106    pub severity: HenDiagnosticSeverity,
107    pub phase: HenDiagnosticPhase,
108    pub message: String,
109    pub source: &'static str,
110    pub location: HenDiagnosticLocation,
111    pub related_information: Vec<HenDiagnosticRelatedInformation>,
112    pub symbol: Option<HenDiagnosticSymbol>,
113    pub suggestions: Vec<HenDiagnosticSuggestion>,
114    pub data: Option<Value>,
115}
116
117#[derive(Debug, Clone, Copy, PartialEq, Eq)]
118struct PestDiagnosticMetadata {
119    code: &'static str,
120    phase: HenDiagnosticPhase,
121    source: &'static str,
122}
123
124impl HenDiagnostic {
125    pub fn from_pest_error<T>(err: &pest::error::Error<T>, path: Option<&Path>) -> Self
126    where
127        T: RuleType + fmt::Debug,
128    {
129        let metadata = pest_error_metadata(&err.variant);
130        let message = pest_error_message(&err.variant);
131
132        Self {
133            code: metadata.code.to_string(),
134            severity: HenDiagnosticSeverity::Error,
135            phase: metadata.phase,
136            message: message.clone(),
137            source: metadata.source,
138            location: HenDiagnosticLocation {
139                path: path.map(|value| value.display().to_string()),
140                range: diagnostic_range(&err.line_col),
141            },
142            related_information: Vec::new(),
143            symbol: pest_error_symbol(metadata.code, &message),
144            suggestions: Vec::new(),
145            data: pest_error_data(metadata.code, &message),
146        }
147    }
148
149    pub fn with_symbol(mut self, symbol: HenDiagnosticSymbol) -> Self {
150        self.symbol = Some(symbol);
151        self
152    }
153
154    pub fn with_data(mut self, data: Value) -> Self {
155        self.data = Some(data);
156        self
157    }
158
159    pub fn with_suggestions(mut self, suggestions: Vec<HenDiagnosticSuggestion>) -> Self {
160        self.suggestions = suggestions;
161        self
162    }
163}
164
165impl HenErrorKind {
166    pub fn label(self) -> &'static str {
167        match self {
168            HenErrorKind::Cli => "CLI",
169            HenErrorKind::Input => "Input",
170            HenErrorKind::Io => "IO",
171            HenErrorKind::Parse => "Parse",
172            HenErrorKind::Planner => "Planner",
173            HenErrorKind::Execution => "Execution",
174            HenErrorKind::Benchmark => "Benchmark",
175            HenErrorKind::Prompt => "Prompt",
176        }
177    }
178}
179
180/// Application error with a short summary and optional detail lines.
181#[derive(Debug, Clone)]
182pub struct HenError {
183    kind: HenErrorKind,
184    summary: String,
185    details: Vec<String>,
186    diagnostics: Vec<HenDiagnostic>,
187    exit_code: i32,
188}
189
190impl HenError {
191    /// Construct a new error with the provided kind and summary line.
192    pub fn new(kind: HenErrorKind, summary: impl Into<String>) -> Self {
193        Self {
194            kind,
195            summary: summary.into(),
196            details: Vec::new(),
197            diagnostics: Vec::new(),
198            exit_code: 1,
199        }
200    }
201
202    /// Attach a detail line to the error for additional context.
203    pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
204        self.details.push(detail.into());
205        self
206    }
207
208    pub fn with_diagnostic(mut self, diagnostic: HenDiagnostic) -> Self {
209        self.diagnostics.push(diagnostic);
210        self
211    }
212
213    pub fn with_diagnostics(mut self, diagnostics: Vec<HenDiagnostic>) -> Self {
214        self.diagnostics.extend(diagnostics);
215        self
216    }
217
218    /// Override the process exit code returned for this error.
219    pub fn with_exit_code(mut self, code: i32) -> Self {
220        self.exit_code = code;
221        self
222    }
223
224    pub fn kind(&self) -> HenErrorKind {
225        self.kind
226    }
227
228    pub fn summary(&self) -> &str {
229        &self.summary
230    }
231
232    pub fn details(&self) -> &[String] {
233        &self.details
234    }
235
236    pub fn diagnostics(&self) -> &[HenDiagnostic] {
237        &self.diagnostics
238    }
239
240    pub fn exit_code(&self) -> i32 {
241        self.exit_code
242    }
243
244    pub fn from_pest_error<T>(err: pest::error::Error<T>, path: Option<&Path>) -> Self
245    where
246        T: RuleType + fmt::Debug,
247    {
248        let diagnostic = HenDiagnostic::from_pest_error(&err, path);
249
250        HenError::new(HenErrorKind::Parse, "Failed to parse hen file")
251            .with_diagnostic(diagnostic)
252            .with_detail(err.to_string())
253    }
254}
255
256impl Display for HenError {
257    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
258        writeln!(f, "[{}] {}", self.kind.label(), self.summary)?;
259        for detail in &self.details {
260            writeln!(f, "  - {}", detail)?;
261        }
262        Ok(())
263    }
264}
265
266impl std::error::Error for HenError {}
267
268/// Render an error to stderr using the structured format.
269pub fn print_error(err: &HenError) {
270    eprint!("{}", err);
271}
272
273impl<T> From<pest::error::Error<T>> for HenError
274where
275    T: RuleType + fmt::Debug,
276{
277    fn from(err: pest::error::Error<T>) -> Self {
278        HenError::from_pest_error(err, None)
279    }
280}
281
282fn pest_error_metadata<T>(variant: &ErrorVariant<T>) -> PestDiagnosticMetadata {
283    match variant {
284        ErrorVariant::ParsingError { .. } => {
285            diagnostic_metadata("parse_error", HenDiagnosticPhase::Parse)
286        }
287        ErrorVariant::CustomError { message } => custom_pest_error_metadata(message),
288    }
289}
290
291fn custom_pest_error_metadata(message: &str) -> PestDiagnosticMetadata {
292    if message.starts_with("Failed to read import '") {
293        return diagnostic_metadata("fragment_import_io", HenDiagnosticPhase::Preprocess);
294    }
295
296    if message == "GraphQL directives require 'protocol = graphql'" {
297        return protocol_diagnostic_metadata("graphql_protocol_required");
298    }
299
300    if message == "MCP directives require 'protocol = mcp'" {
301        return protocol_diagnostic_metadata("mcp_protocol_required");
302    }
303
304    if message == "WebSocket directives require 'protocol = ws'" {
305        return protocol_diagnostic_metadata("ws_protocol_required");
306    }
307
308    if message == "SSE directives require 'protocol = sse'" {
309        return protocol_diagnostic_metadata("sse_protocol_required");
310    }
311
312    if message == "request is missing a URL" {
313        return diagnostic_metadata("missing_request_url", HenDiagnosticPhase::Validate);
314    }
315
316    if message == "request is missing an HTTP method" {
317        return diagnostic_metadata("missing_request_method", HenDiagnosticPhase::Validate);
318    }
319
320    if message == "schema and scalar declarations must appear before the first ---" {
321        return diagnostic_metadata("misplaced_declaration", HenDiagnosticPhase::Validate);
322    }
323
324    if message == "dotenv directives must appear before the first ---" {
325        return diagnostic_metadata("misplaced_dotenv", HenDiagnosticPhase::Validate);
326    }
327
328    if message.starts_with("Unknown environment '") {
329        return diagnostic_metadata("unknown_environment", HenDiagnosticPhase::Validate);
330    }
331
332    if message.starts_with("request references unknown OAuth profile '") {
333        return diagnostic_metadata("unknown_oauth_profile", HenDiagnosticPhase::Validate);
334    }
335
336    if message.starts_with("Duplicate environment '") {
337        return diagnostic_metadata("duplicate_environment", HenDiagnosticPhase::Validate);
338    }
339
340    if message.starts_with("Duplicate OAuth profile '") {
341        return diagnostic_metadata("duplicate_oauth_profile", HenDiagnosticPhase::Validate);
342    }
343
344    if message.starts_with("OAuth profile '") && message.contains(" uses unsupported field '") {
345        return diagnostic_metadata("unsupported_oauth_field", HenDiagnosticPhase::Validate);
346    }
347
348    if message.starts_with("OAuth profile '") && message.contains(" defines field '") && message.contains(" more than once.") {
349        return diagnostic_metadata("duplicate_oauth_field", HenDiagnosticPhase::Validate);
350    }
351
352    if message.starts_with("OAuth profile '") && message.contains(" defines param '") && message.contains(" more than once.") {
353        return diagnostic_metadata("duplicate_oauth_param", HenDiagnosticPhase::Validate);
354    }
355
356    if message.starts_with("OAuth profile '") && message.contains(" maps field '") && message.contains(" more than once.") {
357        return diagnostic_metadata("duplicate_oauth_mapping_source", HenDiagnosticPhase::Validate);
358    }
359
360    if message.starts_with("OAuth profile '") && message.contains(" maps more than one field into '") {
361        return diagnostic_metadata("duplicate_oauth_mapping_target", HenDiagnosticPhase::Validate);
362    }
363
364    if message.starts_with("OAuth profile '") && message.contains(" requires 'grant = ...'.") {
365        return diagnostic_metadata("missing_oauth_grant", HenDiagnosticPhase::Validate);
366    }
367
368    if message.starts_with("OAuth profile '")
369        && message.contains(" must define exactly one of 'issuer' or 'token_url'.")
370    {
371        return diagnostic_metadata("invalid_oauth_endpoint_config", HenDiagnosticPhase::Validate);
372    }
373
374    if message.starts_with("OAuth profile '")
375        && message.contains(" with grant '")
376        && message.contains(" requires '")
377    {
378        return diagnostic_metadata("missing_oauth_field", HenDiagnosticPhase::Validate);
379    }
380
381    if message.starts_with("OAuth profile '") && message.contains(" uses unsupported grant '") {
382        return diagnostic_metadata("unsupported_oauth_grant", HenDiagnosticPhase::Validate);
383    }
384
385    if message.starts_with("session-backed request omitted its method/URL, but session '") {
386        return diagnostic_metadata(
387            "missing_session_inherited_target",
388            HenDiagnosticPhase::Validate,
389        );
390    }
391
392    if message.starts_with("session '") && message.contains(" already uses protocol '") {
393        return protocol_diagnostic_metadata("session_protocol_conflict");
394    }
395
396    if message.starts_with("unsupported protocol '") {
397        return protocol_diagnostic_metadata("unsupported_protocol");
398    }
399
400    if message == "GraphQL requests currently require POST" {
401        return protocol_diagnostic_metadata("graphql_requires_post");
402    }
403
404    if message == "GraphQL requests do not support form fields" {
405        return protocol_diagnostic_metadata("graphql_form_fields_unsupported");
406    }
407
408    if message == "GraphQL requests require a ~~~graphql document block" {
409        return protocol_diagnostic_metadata("graphql_missing_document");
410    }
411
412    if message.starts_with("MCP-over-HTTP requests currently require POST") {
413        return protocol_diagnostic_metadata("mcp_requires_post");
414    }
415
416    if message.starts_with("MCP-over-HTTP requests do not support explicit body") {
417        return protocol_diagnostic_metadata("mcp_body_unsupported");
418    }
419
420    if message == "MCP requests require 'call = ...'" {
421        return protocol_diagnostic_metadata("mcp_missing_call");
422    }
423
424    if message == "'tool' and 'arguments' are only valid with 'call = tools/call'" {
425        return protocol_diagnostic_metadata("mcp_tool_arguments_invalid_for_call");
426    }
427
428    if message.starts_with("'call = tools/list' does not accept ") {
429        return protocol_diagnostic_metadata("mcp_tools_list_directives_unsupported");
430    }
431
432    if message.starts_with("'call = resources/list' does not accept ") {
433        return protocol_diagnostic_metadata("mcp_resources_list_directives_unsupported");
434    }
435
436    if message == "initialize override directives are only valid with 'call = initialize'" {
437        return protocol_diagnostic_metadata("mcp_initialize_overrides_invalid_for_call");
438    }
439
440    if message == "'call = tools/call' requires 'tool = ...'" {
441        return protocol_diagnostic_metadata("mcp_missing_tool");
442    }
443
444    if message.starts_with("unsupported MCP call '") {
445        return protocol_diagnostic_metadata("unsupported_mcp_call");
446    }
447
448    if message == "SSE requests require 'session = ...'" {
449        return protocol_diagnostic_metadata("sse_missing_session");
450    }
451
452    if message == "SSE requests currently require GET" {
453        return protocol_diagnostic_metadata("sse_requires_get");
454    }
455
456    if message.starts_with("SSE requests do not support explicit body") {
457        return protocol_diagnostic_metadata("sse_body_unsupported");
458    }
459
460    if message == "SSE receive steps require 'within = ...'" {
461        return protocol_diagnostic_metadata("sse_missing_within");
462    }
463
464    if message.starts_with("invalid within duration '") {
465        return protocol_diagnostic_metadata("invalid_within_duration");
466    }
467
468    if message == "'within = ...' is only valid with 'receive'" {
469        return protocol_diagnostic_metadata("within_requires_receive");
470    }
471
472    if message == "WebSocket requests require 'session = ...'" {
473        return protocol_diagnostic_metadata("ws_missing_session");
474    }
475
476    if message == "WebSocket requests currently require GET" {
477        return protocol_diagnostic_metadata("ws_requires_get");
478    }
479
480    if message == "WebSocket requests do not support form fields" {
481        return protocol_diagnostic_metadata("ws_form_fields_unsupported");
482    }
483
484    if message == "WebSocket requests cannot combine 'send = ...' with 'receive'" {
485        return protocol_diagnostic_metadata("ws_send_receive_conflict");
486    }
487
488    if message == "WebSocket receive steps require 'within = ...'" {
489        return protocol_diagnostic_metadata("ws_missing_within");
490    }
491
492    if message == "WebSocket send steps require a body block"
493        || (message.starts_with("WebSocket send kind '")
494            && message.ends_with("' requires a body block"))
495    {
496        return protocol_diagnostic_metadata("ws_missing_body");
497    }
498
499    if message.starts_with("WebSocket ") && message.contains(" do not support explicit body or content type blocks") {
500        return protocol_diagnostic_metadata("ws_body_unsupported");
501    }
502
503    if message.starts_with("unsupported WebSocket send kind '") {
504        return protocol_diagnostic_metadata("unsupported_ws_send_kind");
505    }
506
507    if message.starts_with("unsupported WebSocket body block type '") {
508        return protocol_diagnostic_metadata("unsupported_ws_body_content_type");
509    }
510
511    if message.starts_with("WebSocket send kind '") && message.contains(" conflicts with body block type '") {
512        return protocol_diagnostic_metadata("conflicting_ws_send_kind");
513    }
514
515    if message.starts_with("invalid WebSocket JSON payload:") {
516        return protocol_diagnostic_metadata("invalid_ws_json_payload");
517    }
518
519    if message.starts_with("invalid GraphQL variables JSON:") {
520        return protocol_diagnostic_metadata("invalid_graphql_variables_json");
521    }
522
523    if message.starts_with("invalid MCP ") && message.contains(" JSON:") {
524        return protocol_diagnostic_metadata("invalid_mcp_json");
525    }
526
527    if message.starts_with("MCP ") && message.ends_with(" must be a JSON object") {
528        return protocol_diagnostic_metadata("invalid_mcp_json_object");
529    }
530
531    if message == "scalar expressions can only declare one base type" {
532        return diagnostic_metadata("multiple_scalar_base_types", HenDiagnosticPhase::Validate);
533    }
534
535    if message == "scalar declaration must define a base type or predicate" {
536        return diagnostic_metadata("missing_scalar_base_or_predicate", HenDiagnosticPhase::Validate);
537    }
538
539    if message.starts_with("invalid ") && message.ends_with("() predicate") {
540        return diagnostic_metadata("invalid_scalar_predicate", HenDiagnosticPhase::Validate);
541    }
542
543    if message.contains("() requires min..max syntax") {
544        return diagnostic_metadata("invalid_scalar_bounds_syntax", HenDiagnosticPhase::Validate);
545    }
546
547    if message.contains("() requires at least one bound") {
548        return diagnostic_metadata("missing_scalar_bounds", HenDiagnosticPhase::Validate);
549    }
550
551    if message == "range() bounds must be numbers" {
552        return diagnostic_metadata("invalid_range_bounds", HenDiagnosticPhase::Validate);
553    }
554
555    if message.ends_with(" is reserved and cannot be redefined") {
556        return diagnostic_metadata("reserved_declaration_name", HenDiagnosticPhase::Validate);
557    }
558
559    if message.ends_with(" is already defined") {
560        return diagnostic_metadata("duplicate_declaration_name", HenDiagnosticPhase::Validate);
561    }
562
563    if message.contains(" references unknown validation target ") {
564        return diagnostic_metadata("unknown_validation_target", HenDiagnosticPhase::Validate);
565    }
566
567    if message.starts_with("scalar ") && message.contains(" cannot use schema ") {
568        return diagnostic_metadata(
569            "invalid_scalar_base_reference",
570            HenDiagnosticPhase::Validate,
571        );
572    }
573
574    if message.starts_with("schema declarations contain a circular reference:") {
575        return diagnostic_metadata("circular_schema_reference", HenDiagnosticPhase::Validate);
576    }
577
578    if message.starts_with("Unknown schema validation target '") {
579        return diagnostic_metadata(
580            "unknown_schema_validation_target",
581            HenDiagnosticPhase::Validate,
582        );
583    }
584
585    if message.starts_with("Environment '")
586        && message.contains(" overrides unknown or non-scalar variable '")
587    {
588        return diagnostic_metadata("unknown_environment_variable", HenDiagnosticPhase::Validate);
589    }
590
591    if message.starts_with("Environment '")
592        && message.contains(" defines variable '")
593        && message.contains(" with unsupported value '")
594    {
595        return diagnostic_metadata("invalid_environment_value", HenDiagnosticPhase::Validate);
596    }
597
598    if message.starts_with("Secret reference in '") && message.contains(" has invalid syntax '") {
599        return diagnostic_metadata("invalid_secret_reference", HenDiagnosticPhase::Validate);
600    }
601
602    if message.starts_with("Secret reference in '") && message.contains(" uses unsupported provider '") {
603        return diagnostic_metadata("unsupported_secret_provider", HenDiagnosticPhase::Validate);
604    }
605
606    if message.starts_with("Secret reference in '")
607        && message.contains(" requires environment variable '")
608    {
609        return diagnostic_metadata("missing_env_secret", HenDiagnosticPhase::Validate);
610    }
611
612    if message.starts_with("Secret reference in '") && message.contains(" failed to read file '") {
613        return diagnostic_metadata("file_secret_io", HenDiagnosticPhase::Validate);
614    }
615
616    if message.starts_with("Request '")
617        && message.contains(" references ")
618        && message.contains(" array variables but the limit is ")
619    {
620        return diagnostic_metadata("too_many_array_variables", HenDiagnosticPhase::Validate);
621    }
622
623    if message.starts_with("Request '")
624        && message.contains(" references array variable '")
625        && message.contains(" which is not defined as an array.")
626    {
627        return diagnostic_metadata("missing_array_values", HenDiagnosticPhase::Validate);
628    }
629
630    if message.starts_with("Request '")
631        && message.contains(" references array variable '")
632        && message.contains(" but it contains no values.")
633    {
634        return diagnostic_metadata("empty_array_values", HenDiagnosticPhase::Validate);
635    }
636
637    if message.starts_with("Request '")
638        && message.contains(" expands into ")
639        && message.contains(" combinations which exceeds the limit of ")
640    {
641        return diagnostic_metadata("too_many_combinations", HenDiagnosticPhase::Validate);
642    }
643
644    if message.starts_with("Request '")
645        && message.contains(" defines array variable '")
646        && message.contains(" with unsupported value '")
647    {
648        return diagnostic_metadata("invalid_array_value", HenDiagnosticPhase::Validate);
649    }
650
651    if message.starts_with("Request '")
652        && message.contains(" depends on '")
653        && message.contains(" expands into multiple iterations.")
654    {
655        return diagnostic_metadata(
656            "mapped_request_dependency_unsupported",
657            HenDiagnosticPhase::Validate,
658        );
659    }
660
661    if message.starts_with("Request '") && message.contains(" declares unknown dependency '") {
662        return diagnostic_metadata("unknown_dependency", HenDiagnosticPhase::Validate);
663    }
664
665    if message.starts_with("Request '") && message.contains(" has invalid fragment guard '") {
666        return diagnostic_metadata("invalid_fragment_guard", HenDiagnosticPhase::Validate);
667    }
668
669    if message.starts_with("Request '") && message.contains(" failed to load fragment '") {
670        return diagnostic_metadata("fragment_load_error", HenDiagnosticPhase::Validate);
671    }
672
673    if message.starts_with("Request '") && message.contains(" failed to parse fragment '") {
674        return diagnostic_metadata("fragment_parse_error", HenDiagnosticPhase::Validate);
675    }
676
677    if message.starts_with("Request '") && message.contains(" cannot include line '") {
678        return diagnostic_metadata("fragment_unsupported_line", HenDiagnosticPhase::Validate);
679    }
680
681    if message.starts_with("Request '")
682        && message.contains(" requires prompt '")
683        && message.contains(" but no input was provided")
684    {
685        return diagnostic_metadata("missing_prompt_input", HenDiagnosticPhase::Validate);
686    }
687
688    if message.starts_with("Request '") && message.contains(" defines invalid ") {
689        return diagnostic_metadata("invalid_reliability_value", HenDiagnosticPhase::Validate);
690    }
691
692    if message.starts_with("Missing value for prompt '") {
693        return diagnostic_metadata("missing_prompt_input", HenDiagnosticPhase::Validate);
694    }
695
696    diagnostic_metadata("parse_custom_error", HenDiagnosticPhase::Parse)
697}
698
699fn diagnostic_metadata(
700    code: &'static str,
701    phase: HenDiagnosticPhase,
702) -> PestDiagnosticMetadata {
703    PestDiagnosticMetadata {
704        code,
705        phase,
706        source: match phase {
707            HenDiagnosticPhase::Preprocess => "hen.preprocess",
708            HenDiagnosticPhase::Parse
709            | HenDiagnosticPhase::Inspect
710            | HenDiagnosticPhase::Validate => "hen.parser",
711        },
712    }
713}
714
715fn protocol_diagnostic_metadata(code: &'static str) -> PestDiagnosticMetadata {
716    PestDiagnosticMetadata {
717        code,
718        phase: HenDiagnosticPhase::Validate,
719        source: "hen.protocol",
720    }
721}
722
723fn pest_error_symbol(code: &str, message: &str) -> Option<HenDiagnosticSymbol> {
724    match code {
725        "unknown_oauth_profile" => first_quoted_value(message).map(|name| HenDiagnosticSymbol {
726            kind: "oauthProfile".to_string(),
727            name,
728            role: "reference".to_string(),
729        }),
730        "unknown_dependency" => {
731            unknown_dependency_details(message).map(|(_, dependency)| HenDiagnosticSymbol {
732                kind: "request".to_string(),
733                name: dependency,
734                role: "reference".to_string(),
735            })
736        }
737        "unknown_environment" => first_quoted_value(message).map(|name| HenDiagnosticSymbol {
738            kind: "environment".to_string(),
739            name,
740            role: "reference".to_string(),
741        }),
742        _ => None,
743    }
744}
745
746fn pest_error_data(code: &str, message: &str) -> Option<Value> {
747    match code {
748        "unknown_oauth_profile" => Some(json!({
749            "expectedKinds": ["oauthProfile"],
750            "symbolName": first_quoted_value(message),
751        })),
752        "unknown_dependency" => unknown_dependency_details(message).map(|(request, dependency)| {
753            json!({
754                "expectedKinds": ["request"],
755                "ownerName": request,
756                "symbolName": dependency,
757            })
758        }),
759        "unknown_schema_validation_target" => Some(json!({
760            "expectedKinds": ["schema", "scalar"],
761            "symbolName": first_quoted_value(message),
762        })),
763        "unknown_environment" => Some(json!({
764            "expectedKinds": ["environment"],
765            "symbolName": first_quoted_value(message),
766        })),
767        "missing_array_values" => request_variable_data(
768            message,
769            "' references array variable '",
770            "' which is not defined as an array.",
771        ),
772        "empty_array_values" => request_variable_data(
773            message,
774            "' references array variable '",
775            "' but it contains no values.",
776        ).map(|mut data| {
777            data["issue"] = Value::String("emptyArray".to_string());
778            data
779        }),
780        "invalid_array_value" => request_variable_value_data(
781            message,
782            "' defines array variable '",
783            "' with unsupported value '",
784            "'.",
785        ),
786        "too_many_combinations" => too_many_combinations_data(message),
787        "invalid_secret_reference" => source_value_data(
788            message,
789            "Secret reference in '",
790            "' has invalid syntax '",
791            "'. Use secret.env(\"NAME\") or secret.file(\"PATH\").",
792        ).map(|mut data| {
793            data["supportedValues"] = json!(["secret.env(\"NAME\")", "secret.file(\"PATH\")"]);
794            data
795        }),
796        "unsupported_secret_provider" => source_value_data(
797            message,
798            "Secret reference in '",
799            "' uses unsupported provider '",
800            "'. Supported providers are env and file.",
801        ).map(|mut data| {
802            data["supportedValues"] = json!(["env", "file"]);
803            data["provider"] = data["invalidValue"].clone();
804            data
805        }),
806        "missing_env_secret" => source_value_data(
807            message,
808            "Secret reference in '",
809            "' requires environment variable '",
810            "' but it is not set.",
811        ).map(|mut data| {
812            data["requiredEnvironmentVariable"] = data["invalidValue"].clone();
813            data["directiveName"] = Value::String("secret.env".to_string());
814            data
815        }),
816        "file_secret_io" => source_value_reason_data(
817            message,
818            "Secret reference in '",
819            "' failed to read file '",
820            "': ",
821            ".",
822        ).map(|mut data| {
823            data["directiveName"] = Value::String("secret.file".to_string());
824            data
825        }),
826        "missing_prompt_input" => missing_prompt_input_data(message),
827        "invalid_reliability_value" => invalid_reliability_value_data(message),
828        _ => protocol_error_data(code, message),
829    }
830}
831
832fn request_variable_data(
833    message: &str,
834    middle_prefix: &str,
835    suffix: &str,
836) -> Option<Value> {
837    let remainder = message.strip_prefix("Request '")?;
838    let (request, remainder) = remainder.split_once(middle_prefix)?;
839    let variable = remainder.strip_suffix(suffix)?;
840    Some(json!({
841        "ownerName": request,
842        "variableName": variable,
843    }))
844}
845
846fn request_variable_value_data(
847    message: &str,
848    middle_prefix: &str,
849    value_prefix: &str,
850    suffix: &str,
851) -> Option<Value> {
852    let remainder = message.strip_prefix("Request '")?;
853    let (request, remainder) = remainder.split_once(middle_prefix)?;
854    let (variable, value) = remainder.split_once(value_prefix)?;
855    let value = value.strip_suffix(suffix)?;
856    Some(json!({
857        "ownerName": request,
858        "variableName": variable,
859        "invalidValue": value,
860    }))
861}
862
863fn too_many_combinations_data(message: &str) -> Option<Value> {
864    let remainder = message.strip_prefix("Request '")?;
865    let (request, remainder) = remainder.split_once("' expands into ")?;
866    let (count, limit) = remainder.split_once(" combinations which exceeds the limit of ")?;
867    let limit = limit.strip_suffix('.')?;
868    Some(json!({
869        "ownerName": request,
870        "count": count.parse::<usize>().ok(),
871        "limit": limit.parse::<usize>().ok(),
872    }))
873}
874
875fn source_value_data(
876    message: &str,
877    source_prefix: &str,
878    middle_prefix: &str,
879    suffix: &str,
880) -> Option<Value> {
881    let remainder = message.strip_prefix(source_prefix)?;
882    let (source, remainder) = remainder.split_once(middle_prefix)?;
883    let value = remainder.strip_suffix(suffix)?;
884    Some(json!({
885        "sourceName": source,
886        "invalidValue": value,
887    }))
888}
889
890fn source_value_reason_data(
891    message: &str,
892    source_prefix: &str,
893    middle_prefix: &str,
894    reason_prefix: &str,
895    suffix: &str,
896) -> Option<Value> {
897    let remainder = message.strip_prefix(source_prefix)?;
898    let (source, remainder) = remainder.split_once(middle_prefix)?;
899    let (value, reason) = remainder.split_once(reason_prefix)?;
900    let reason = reason.strip_suffix(suffix)?;
901    Some(json!({
902        "sourceName": source,
903        "invalidValue": value,
904        "reason": reason,
905    }))
906}
907
908fn missing_prompt_input_data(message: &str) -> Option<Value> {
909    if let Some(remainder) = message.strip_prefix("Request '") {
910        let (request, remainder) = remainder.split_once("' requires prompt '")?;
911        let (prompt, remainder) = remainder.split_once("' but no input was provided")?;
912        let default = remainder
913            .strip_prefix(" (default: ")
914            .and_then(|value| value.strip_suffix(")."))
915            .map(str::to_string);
916        return Some(json!({
917            "ownerName": request,
918            "promptName": prompt,
919            "defaultValue": default,
920        }));
921    }
922
923    let remainder = message.strip_prefix("Missing value for prompt '")?;
924    let (prompt, remainder) = remainder.split_once("'")?;
925    let default = remainder
926        .strip_prefix(" (default: ")
927        .and_then(|value| value.strip_suffix(")"))
928        .map(str::to_string);
929    Some(json!({
930        "promptName": prompt,
931        "defaultValue": default,
932    }))
933}
934
935fn invalid_reliability_value_data(message: &str) -> Option<Value> {
936    let remainder = message.strip_prefix("Request '")?;
937    let (request, remainder) = remainder.split_once("' defines invalid ")?;
938    let (field, remainder) = remainder.split_once(" '")?;
939    let (value, reason) = remainder.split_once("': ")?;
940    let reason = reason.strip_suffix('.')?;
941    Some(json!({
942        "ownerName": request,
943        "fieldName": field,
944        "invalidValue": value,
945        "reason": reason,
946    }))
947}
948
949fn protocol_error_data(code: &str, message: &str) -> Option<Value> {
950    match code {
951        "graphql_protocol_required" => Some(json!({
952            "protocol": "graphql",
953            "expectedProtocol": "graphql",
954            "requiredDirectives": ["protocol"],
955            "directiveFamilies": ["graphql"],
956        })),
957        "graphql_requires_post" => Some(json!({
958            "protocol": "graphql",
959            "expectedMethod": "POST",
960        })),
961        "mcp_protocol_required" => Some(json!({
962            "protocol": "mcp",
963            "expectedProtocol": "mcp",
964            "requiredDirectives": ["protocol"],
965            "directiveFamilies": ["mcp"],
966        })),
967        "mcp_requires_post" => Some(json!({
968            "protocol": "mcp",
969            "expectedMethod": "POST",
970        })),
971        "sse_protocol_required" => Some(json!({
972            "protocol": "sse",
973            "expectedProtocol": "sse",
974            "requiredDirectives": ["protocol"],
975            "directiveFamilies": ["sse"],
976        })),
977        "sse_requires_get" => Some(json!({
978            "protocol": "sse",
979            "expectedMethod": "GET",
980        })),
981        "ws_protocol_required" => Some(json!({
982            "protocol": "ws",
983            "expectedProtocol": "ws",
984            "requiredDirectives": ["protocol"],
985            "directiveFamilies": ["ws"],
986        })),
987        "ws_requires_get" => Some(json!({
988            "protocol": "ws",
989            "expectedMethod": "GET",
990        })),
991        "unsupported_protocol" => Some(json!({
992            "directiveName": "protocol",
993            "invalidValue": first_quoted_value(message),
994            "supportedValues": ["http", "graphql", "mcp", "sse", "ws"],
995        })),
996        "session_protocol_conflict" => session_protocol_conflict_details(message).map(|(session_name, expected_protocol, conflicting_protocol)| json!({
997            "sessionName": session_name,
998            "expectedProtocol": expected_protocol,
999            "conflictingProtocol": conflicting_protocol,
1000        })),
1001        "graphql_missing_document" => Some(json!({
1002            "protocol": "graphql",
1003            "requiredBlocks": ["graphqlDocument"],
1004            "supportedBlockTypes": ["graphql"],
1005            "insertBlockTitle": "Insert GraphQL document block",
1006            "replacementOpeningFence": "~~~graphql",
1007            "replacementBodyText": "query {\n  field\n}",
1008        })),
1009        "graphql_form_fields_unsupported" => Some(json!({
1010            "protocol": "graphql",
1011            "cleanupTargets": ["formFields"],
1012        })),
1013        "mcp_body_unsupported" => Some(json!({
1014            "protocol": "mcp",
1015            "cleanupTargets": ["bodyBlock", "formFields"],
1016        })),
1017        "mcp_missing_call" => Some(json!({
1018            "protocol": "mcp",
1019            "requiredDirectives": ["call"],
1020            "supportedValues": ["initialize", "tools/list", "resources/list", "tools/call"],
1021        })),
1022        "mcp_tool_arguments_invalid_for_call" => Some(json!({
1023            "protocol": "mcp",
1024            "requiredDirectives": ["call"],
1025            "requiredCall": "tools/call",
1026        })),
1027        "mcp_tools_list_directives_unsupported" => mcp_call_conflict_data(message),
1028        "mcp_resources_list_directives_unsupported" => mcp_call_conflict_data(message),
1029        "mcp_initialize_overrides_invalid_for_call" => Some(json!({
1030            "protocol": "mcp",
1031            "requiredDirectives": ["call"],
1032            "requiredCall": "initialize",
1033        })),
1034        "unsupported_mcp_call" => Some(json!({
1035            "protocol": "mcp",
1036            "requiredDirectives": ["call"],
1037            "supportedValues": ["initialize", "tools/list", "resources/list", "tools/call"],
1038        })),
1039        "mcp_missing_tool" => Some(json!({
1040            "protocol": "mcp",
1041            "requiredDirectives": ["tool"],
1042            "requiredCall": "tools/call",
1043            "directiveName": "tool",
1044            "replacementValue": "exampleTool",
1045        })),
1046        "sse_missing_session" => Some(json!({
1047            "protocol": "sse",
1048            "requiredDirectives": ["session"],
1049            "directiveName": "session",
1050            "replacementValue": "exampleSession",
1051        })),
1052        "sse_body_unsupported" => Some(json!({
1053            "protocol": "sse",
1054            "cleanupTargets": ["bodyBlock"],
1055        })),
1056        "ws_missing_session" => Some(json!({
1057            "protocol": "ws",
1058            "requiredDirectives": ["session"],
1059            "directiveName": "session",
1060            "replacementValue": "exampleSession",
1061        })),
1062        "ws_form_fields_unsupported" => Some(json!({
1063            "protocol": "ws",
1064            "cleanupTargets": ["formFields"],
1065        })),
1066        "sse_missing_within" => Some(json!({
1067            "protocol": "sse",
1068            "requiredDirectives": ["within"],
1069            "requiredAction": "receive",
1070            "directiveName": "within",
1071            "replacementValue": "30s",
1072        })),
1073        "ws_missing_within" => Some(json!({
1074            "protocol": "ws",
1075            "requiredDirectives": ["within"],
1076            "requiredAction": "receive",
1077            "directiveName": "within",
1078            "replacementValue": "30s",
1079        })),
1080        "within_requires_receive" => Some(json!({
1081            "requiredDirectives": ["receive"],
1082            "requiredFlagDirective": "receive",
1083            "anchorDirectiveNames": ["within"],
1084        })),
1085        "ws_missing_body" => ws_missing_body_data(message),
1086        "ws_body_unsupported" => Some(json!({
1087            "protocol": "ws",
1088            "cleanupTargets": ["bodyBlock"],
1089        })),
1090        "ws_send_receive_conflict" => Some(json!({
1091            "protocol": "ws",
1092            "conflictingDirectives": ["send", "receive"],
1093        })),
1094        "unsupported_ws_send_kind" => Some(json!({
1095            "protocol": "ws",
1096            "directiveName": "send",
1097            "invalidValue": first_quoted_value(message),
1098            "supportedValues": ["text", "json"],
1099        })),
1100        "unsupported_ws_body_content_type" => Some(json!({
1101            "protocol": "ws",
1102            "blockKind": "body",
1103            "invalidValue": first_quoted_value(message),
1104            "supportedBlockTypes": ["text", "json"],
1105        })),
1106        "conflicting_ws_send_kind" => ws_send_kind_conflict_details(message).and_then(|(current_value, body_content_type)| {
1107            let expected_value = ws_send_kind_for_body_content_type(body_content_type.as_str())?;
1108            Some(json!({
1109                "protocol": "ws",
1110                "directiveName": "send",
1111                "currentValue": current_value,
1112                "expectedValue": expected_value,
1113                "bodyContentType": body_content_type,
1114            }))
1115        }),
1116        "invalid_ws_json_payload" => Some(json!({
1117            "protocol": "ws",
1118            "blockKind": "body",
1119            "replacementBodyText": "{\n  \"type\": \"message\"\n}",
1120        })),
1121        "invalid_graphql_variables_json" => Some(json!({
1122            "protocol": "graphql",
1123            "directiveName": "variables",
1124            "replacementValue": "{}",
1125        })),
1126        "invalid_mcp_json" => mcp_invalid_json_field_name(message).map(|directive_name| json!({
1127            "protocol": "mcp",
1128            "directiveName": directive_name,
1129            "replacementValue": "{}",
1130        })),
1131        "invalid_mcp_json_object" => mcp_json_object_field_name(message).map(|directive_name| json!({
1132            "protocol": "mcp",
1133            "directiveName": directive_name,
1134            "replacementValue": "{}",
1135        })),
1136        _ => None,
1137    }
1138}
1139
/// Returns the text between the first pair of single quotes in `message`.
///
/// Yields `None` when the message has no opening quote or the opening quote
/// is never closed.
fn first_quoted_value(message: &str) -> Option<String> {
    let (_, after_open) = message.split_once('\'')?;
    let (quoted, _) = after_open.split_once('\'')?;
    Some(quoted.to_string())
}
1145
/// Collects, in order, every substring of `message` enclosed in a pair of
/// single quotes.
///
/// Quotes are paired left to right; a trailing opening quote without a
/// closing partner is ignored.
fn quoted_values(message: &str) -> Vec<String> {
    let mut rest = message;

    std::iter::from_fn(|| {
        let (_, after_open) = rest.split_once('\'')?;
        let (quoted, tail) = after_open.split_once('\'')?;
        // Continue scanning after the closing quote.
        rest = tail;
        Some(quoted.to_string())
    })
    .collect()
}
1161
1162fn ws_send_kind_conflict_details(message: &str) -> Option<(String, String)> {
1163    let values = quoted_values(message);
1164    if values.len() < 2 {
1165        return None;
1166    }
1167
1168    Some((values[0].clone(), values[1].clone()))
1169}
1170
1171fn ws_missing_body_data(message: &str) -> Option<Value> {
1172    let send_kind = first_quoted_value(message)?;
1173    let (insert_block_title, replacement_opening_fence, replacement_body_text) =
1174        match send_kind.as_str() {
1175            "json" => (
1176                "Insert WebSocket JSON body block",
1177                "~~~json",
1178                "{\n  \"type\": \"message\"\n}",
1179            ),
1180            "text" => (
1181                "Insert WebSocket body block",
1182                "~~~",
1183                "message payload",
1184            ),
1185            _ => return None,
1186        };
1187
1188    Some(json!({
1189        "protocol": "ws",
1190        "requiredBlocks": ["body"],
1191        "requiredDirectives": ["send"],
1192        "sendKind": send_kind,
1193        "insertBlockTitle": insert_block_title,
1194        "replacementOpeningFence": replacement_opening_fence,
1195        "replacementBodyText": replacement_body_text,
1196    }))
1197}
1198
1199fn session_protocol_conflict_details(message: &str) -> Option<(String, String, String)> {
1200    let values = quoted_values(message);
1201    if values.len() < 3 {
1202        return None;
1203    }
1204
1205    Some((values[0].clone(), values[1].clone(), values[2].clone()))
1206}
1207
/// Extracts `<field>` from a message shaped like `invalid MCP <field> JSON:…`.
///
/// The field name is trimmed of surrounding whitespace. Returns `None` when
/// the prefix or the ` JSON:` marker is missing.
fn mcp_invalid_json_field_name(message: &str) -> Option<String> {
    let (field, _) = message
        .strip_prefix("invalid MCP ")?
        .split_once(" JSON:")?;
    Some(field.trim().to_string())
}
1213
/// Extracts `<field>` from a message shaped exactly like
/// `MCP <field> must be a JSON object`.
///
/// The field name is trimmed of surrounding whitespace. Returns `None` when
/// the message does not start and end with the expected text.
fn mcp_json_object_field_name(message: &str) -> Option<String> {
    message
        .strip_prefix("MCP ")
        .and_then(|rest| rest.strip_suffix(" must be a JSON object"))
        .map(|field| field.trim().to_string())
}
1219
1220fn mcp_call_conflict_data(message: &str) -> Option<Value> {
1221    let replacements = mcp_call_conflict_replacements(message);
1222    if replacements.is_empty() {
1223        return None;
1224    }
1225
1226    if replacements.len() == 1 {
1227        return Some(json!({
1228            "protocol": "mcp",
1229            "requiredDirectives": ["call"],
1230            "requiredCall": replacements[0],
1231        }));
1232    }
1233
1234    Some(json!({
1235        "protocol": "mcp",
1236        "requiredDirectives": ["call"],
1237        "supportedValues": replacements,
1238    }))
1239}
1240
1241fn mcp_call_conflict_replacements(message: &str) -> Vec<String> {
1242    quoted_values(message)
1243        .into_iter()
1244        .skip(1)
1245        .filter_map(|value| value.strip_prefix("call = ").map(str::to_string))
1246        .collect()
1247}
1248
/// Maps a body content type to the matching WebSocket send kind.
///
/// Surrounding whitespace is ignored and the comparison is case-sensitive.
/// Returns `None` for content types not listed in the table.
fn ws_send_kind_for_body_content_type(value: &str) -> Option<&'static str> {
    // (accepted content type, resulting send kind)
    const SEND_KINDS: [(&str, &str); 4] = [
        ("json", "json"),
        ("application/json", "json"),
        ("text", "text"),
        ("text/plain", "text"),
    ];

    let normalized = value.trim();
    SEND_KINDS
        .iter()
        .find(|(content_type, _)| *content_type == normalized)
        .map(|&(_, send_kind)| send_kind)
}
1256
/// Parses a message shaped like
/// `Request '<request>' declares unknown dependency '<dependency>'.`
///
/// Returns the `(request, dependency)` pair, or `None` when the message does
/// not follow that shape.
fn unknown_dependency_details(message: &str) -> Option<(String, String)> {
    let (request, rest) = message
        .strip_prefix("Request '")?
        .split_once("' declares unknown dependency '")?;

    rest.strip_suffix("'.")
        .map(|dependency| (request.to_string(), dependency.to_string()))
}
1263
1264fn pest_error_message<T>(variant: &ErrorVariant<T>) -> String
1265where
1266    T: fmt::Debug,
1267{
1268    match variant {
1269        ErrorVariant::CustomError { message } => message.clone(),
1270        ErrorVariant::ParsingError {
1271            positives,
1272            negatives,
1273        } => {
1274            let expected = format_rule_list(positives);
1275            let unexpected = format_rule_list(negatives);
1276
1277            if !expected.is_empty() {
1278                return format!("Unexpected input; expected {expected}");
1279            }
1280
1281            if !unexpected.is_empty() {
1282                return format!("Unexpected input; disallowed {unexpected}");
1283            }
1284
1285            "Unexpected input".to_string()
1286        }
1287    }
1288}
1289
/// Renders `rules` as a comma-separated list of their `Debug` names.
///
/// Names are de-duplicated and ordered by string comparison. An empty slice
/// yields an empty string.
fn format_rule_list<T>(rules: &[T]) -> String
where
    T: fmt::Debug,
{
    // The set both removes duplicates and sorts the rendered names.
    let unique: BTreeSet<String> = rules.iter().map(|rule| format!("{rule:?}")).collect();
    let names: Vec<String> = unique.into_iter().collect();
    names.join(", ")
}
1307
1308fn diagnostic_range(line_col: &LineColLocation) -> HenDiagnosticRange {
1309    let range = match line_col {
1310        LineColLocation::Pos((line, character)) => HenDiagnosticRange {
1311            start: diagnostic_position(*line, *character),
1312            end: diagnostic_position(*line, character.saturating_add(1)),
1313        },
1314        LineColLocation::Span((start_line, start_character), (end_line, end_character)) => {
1315            HenDiagnosticRange {
1316                start: diagnostic_position(*start_line, *start_character),
1317                end: diagnostic_position(*end_line, *end_character),
1318            }
1319        }
1320    };
1321
1322    ensure_non_empty_range(range)
1323}
1324
1325fn diagnostic_position(line: usize, character: usize) -> HenDiagnosticPosition {
1326    HenDiagnosticPosition {
1327        line: line.saturating_sub(1),
1328        character: character.saturating_sub(1),
1329    }
1330}
1331
1332fn ensure_non_empty_range(mut range: HenDiagnosticRange) -> HenDiagnosticRange {
1333    if range.end.line < range.start.line
1334        || (range.end.line == range.start.line && range.end.character <= range.start.character)
1335    {
1336        range.end.line = range.start.line;
1337        range.end.character = range.start.character.saturating_add(1);
1338    }
1339
1340    range
1341}