Skip to main content

grpcurl_core/
format.rs

1use std::cell::Cell;
2use std::fmt;
3use std::io::{self, Read};
4use std::str::FromStr;
5
6use prost_reflect::{DeserializeOptions, DynamicMessage, MessageDescriptor, SerializeOptions};
7
8use crate::error::{GrpcurlError, Result};
9
/// Wire format used for request input and response output.
///
/// Parsed from the `--format` flag value (`"json"` or `"text"`) and displayed
/// back as the same lowercase token.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Format {
    Json,
    Text,
}

impl FromStr for Format {
    type Err = String;

    /// Accepts exactly `"json"` or `"text"` (case-sensitive); anything else
    /// yields a user-facing message that quotes the rejected value.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        if s == "json" {
            return Ok(Self::Json);
        }
        if s == "text" {
            return Ok(Self::Text);
        }
        Err(format!(
            "The --format option must be 'json' or 'text', got '{s}'."
        ))
    }
}

impl fmt::Display for Format {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let token = match self {
            Self::Json => "json",
            Self::Text => "text",
        };
        f.write_str(token)
    }
}
39
/// Knobs shared by request parsing and response formatting.
///
/// Mirrors Go grpcurl's `FormatOptions` (format.go:380-398). Both flags
/// default to `false`.
#[derive(Default, Debug, Clone)]
pub struct FormatOptions {
    /// When `true`, JSON output also contains fields that hold their default
    /// value; the JSON formatter passes `!emit_defaults` to prost-reflect's
    /// `SerializeOptions::skip_default_fields`.
    pub emit_defaults: bool,

    /// When `true`, unrecognized fields in JSON request input are accepted
    /// instead of rejected; the JSON parser passes `!allow_unknown_fields` to
    /// prost-reflect's `DeserializeOptions::deny_unknown_fields`.
    pub allow_unknown_fields: bool,
}
53
54/// Parse error indicating end of input.
55#[derive(Debug)]
56pub enum ParseError {
57    /// No more messages to parse.
58    Eof,
59    /// A parse error occurred.
60    Error(GrpcurlError),
61}
62
63impl From<GrpcurlError> for ParseError {
64    fn from(err: GrpcurlError) -> Self {
65        ParseError::Error(err)
66    }
67}
68
69/// Stream-based request message parser.
70///
71/// Equivalent to Go's `RequestParser` interface (format.go:24-33).
72/// Reads one message at a time from the input, supporting multiple
73/// concatenated messages (separated by whitespace).
74pub struct JsonRequestParser {
75    data: String,
76    offset: usize,
77    num_requests: usize,
78    options: DeserializeOptions,
79}
80
81impl JsonRequestParser {
82    /// Create a new JSON request parser from the input data.
83    ///
84    /// If `data` is "@", reads from stdin. Otherwise uses the string directly.
85    pub fn new(data: Option<&str>, options: &FormatOptions) -> Result<Self> {
86        let input = match data {
87            Some("@") => {
88                let mut buf = String::new();
89                io::stdin().read_to_string(&mut buf).map_err(|e| {
90                    GrpcurlError::Io(io::Error::new(e.kind(), format!("reading stdin: {e}")))
91                })?;
92                buf
93            }
94            Some(s) => s.to_string(),
95            None => String::new(),
96        };
97
98        let de_options =
99            DeserializeOptions::new().deny_unknown_fields(!options.allow_unknown_fields);
100
101        Ok(JsonRequestParser {
102            data: input,
103            offset: 0,
104            num_requests: 0,
105            options: de_options,
106        })
107    }
108
109    /// Parse the next message from the input stream.
110    ///
111    /// Returns `ParseError::Eof` when there are no more messages.
112    /// Multiple JSON objects can be concatenated with whitespace between them.
113    pub fn next(
114        &mut self,
115        desc: &MessageDescriptor,
116    ) -> std::result::Result<DynamicMessage, ParseError> {
117        // Skip whitespace
118        let remaining = &self.data[self.offset..];
119        let trimmed = remaining.trim_start();
120        if trimmed.is_empty() {
121            return Err(ParseError::Eof);
122        }
123
124        // Update offset past whitespace
125        self.offset += remaining.len() - trimmed.len();
126
127        // Use serde_json's stream deserializer to read exactly one JSON value
128        let mut de = serde_json::Deserializer::from_str(trimmed).into_iter::<serde_json::Value>();
129
130        match de.next() {
131            Some(Ok(value)) => {
132                // Advance our offset by the bytes consumed
133                let bytes_consumed = de.byte_offset();
134                self.offset += bytes_consumed;
135                self.num_requests += 1;
136
137                // Deserialize the JSON value into a DynamicMessage
138                let msg =
139                    DynamicMessage::deserialize_with_options(desc.clone(), value, &self.options)
140                        .map_err(|e| {
141                            ParseError::Error(GrpcurlError::Proto(format!(
142                                "failed to parse JSON request: {e}"
143                            )))
144                        })?;
145
146                Ok(msg)
147            }
148            Some(Err(e)) => Err(ParseError::Error(GrpcurlError::Proto(format!(
149                "invalid JSON in request data: {e}"
150            )))),
151            None => Err(ParseError::Eof),
152        }
153    }
154
155    /// Return the number of messages parsed so far.
156    pub fn num_requests(&self) -> usize {
157        self.num_requests
158    }
159}
160
/// Request parser for protobuf text-format input.
///
/// Counterpart of Go's `textRequestParser` (format.go:84-88). Individual
/// messages in the input are delimited by the ASCII record separator `0x1E`.
pub struct TextRequestParser {
    /// Entire request input (inline data or everything read from stdin).
    data: String,
    /// Byte index into `data` where the next segment starts.
    offset: usize,
    /// Number of messages handed out so far.
    num_requests: usize,
}
170
171impl TextRequestParser {
172    /// Create a new text format request parser from the input data.
173    ///
174    /// If `data` is "@", reads from stdin. Otherwise uses the string directly.
175    pub fn new(data: Option<&str>) -> Result<Self> {
176        let input = match data {
177            Some("@") => {
178                let mut buf = String::new();
179                io::stdin().read_to_string(&mut buf).map_err(|e| {
180                    GrpcurlError::Io(io::Error::new(e.kind(), format!("reading stdin: {e}")))
181                })?;
182                buf
183            }
184            Some(s) => s.to_string(),
185            None => String::new(),
186        };
187
188        Ok(TextRequestParser {
189            data: input,
190            offset: 0,
191            num_requests: 0,
192        })
193    }
194
195    /// Parse the next message from the input stream.
196    ///
197    /// Messages are separated by the 0x1E record separator character.
198    /// Returns `ParseError::Eof` when there are no more messages.
199    ///
200    /// Matches Go behavior: on the first call with empty input, returns an
201    /// empty DynamicMessage (empty text is a valid empty proto message).
202    /// Subsequent calls return `ParseError::Eof`.
203    pub fn next(
204        &mut self,
205        desc: &MessageDescriptor,
206    ) -> std::result::Result<DynamicMessage, ParseError> {
207        let remaining = &self.data[self.offset..];
208        if remaining.trim().is_empty() {
209            // On the very first call, empty input produces one empty message
210            // (matching Go's text parser semantics).
211            if self.num_requests == 0 {
212                self.offset = self.data.len();
213                self.num_requests += 1;
214                return Ok(DynamicMessage::new(desc.clone()));
215            }
216            return Err(ParseError::Eof);
217        }
218
219        // Read until 0x1E separator or end of input
220        let (text, consumed) = if let Some(pos) = remaining.find('\x1e') {
221            (&remaining[..pos], pos + 1)
222        } else {
223            (remaining, remaining.len())
224        };
225
226        let text = text.trim();
227        if text.is_empty() {
228            self.offset += consumed;
229            // Empty segment on first read still produces one empty message
230            if self.num_requests == 0 {
231                self.num_requests += 1;
232                return Ok(DynamicMessage::new(desc.clone()));
233            }
234            return Err(ParseError::Eof);
235        }
236
237        self.offset += consumed;
238        self.num_requests += 1;
239
240        DynamicMessage::parse_text_format(desc.clone(), text).map_err(|e| {
241            ParseError::Error(GrpcurlError::Proto(format!(
242                "failed to parse text format request: {e}"
243            )))
244        })
245    }
246
247    /// Return the number of messages parsed so far.
248    pub fn num_requests(&self) -> usize {
249        self.num_requests
250    }
251}
252
253/// Unified request parser that dispatches to the appropriate format.
254///
255/// This enum wraps either a JSON or text format parser, providing a
256/// common interface for the invocation engine.
257pub enum RequestParser {
258    Json(JsonRequestParser),
259    Text(TextRequestParser),
260}
261
262impl RequestParser {
263    /// Parse the next message from the input stream.
264    pub fn next(
265        &mut self,
266        desc: &MessageDescriptor,
267    ) -> std::result::Result<DynamicMessage, ParseError> {
268        match self {
269            RequestParser::Json(p) => p.next(desc),
270            RequestParser::Text(p) => p.next(desc),
271        }
272    }
273
274    /// Return the number of messages parsed so far.
275    pub fn num_requests(&self) -> usize {
276        match self {
277            RequestParser::Json(p) => p.num_requests(),
278            RequestParser::Text(p) => p.num_requests(),
279        }
280    }
281}
282
283/// Create a template DynamicMessage with default values for all fields.
284///
285/// Equivalent to Go's `MakeTemplate()` (grpcurl.go:396-510).
286///
287/// The template is useful for showing users what a valid JSON request
288/// looks like. Scalar fields are left at defaults; repeated fields get
289/// one default element; message fields are recursively populated.
290pub fn make_template(desc: &MessageDescriptor) -> DynamicMessage {
291    make_template_inner(desc, &mut Vec::new())
292}
293
294fn make_template_inner(desc: &MessageDescriptor, path: &mut Vec<String>) -> DynamicMessage {
295    let full_name = desc.full_name().to_string();
296
297    // Handle well-known types with special JSON representations.
298    // Matches Go's MakeTemplate() (grpcurl.go:407-449).
299    match full_name.as_str() {
300        "google.protobuf.Any" => {
301            let mut msg = DynamicMessage::new(desc.clone());
302            if let Some(type_url_field) = desc.get_field_by_name("type_url") {
303                msg.set_field(
304                    &type_url_field,
305                    prost_reflect::Value::String(
306                        "type.googleapis.com/google.protobuf.Empty".into(),
307                    ),
308                );
309            }
310            // value field left as empty bytes (default), producing {"@type":"..."}
311            return msg;
312        }
313        "google.protobuf.Value" => {
314            // Value supports arbitrary JSON; provide a string hint
315            let mut msg = DynamicMessage::new(desc.clone());
316            if let Some(string_value_field) = desc.get_field_by_name("string_value") {
317                msg.set_field(
318                    &string_value_field,
319                    prost_reflect::Value::String(
320                        "google.protobuf.Value supports arbitrary JSON".into(),
321                    ),
322                );
323            }
324            return msg;
325        }
326        "google.protobuf.ListValue" => {
327            // ListValue is a JSON array; provide one Value element
328            let mut msg = DynamicMessage::new(desc.clone());
329            if let Some(values_field) = desc.get_field_by_name("values") {
330                let value_desc = match values_field.kind() {
331                    prost_reflect::Kind::Message(m) => m,
332                    _ => return msg,
333                };
334                let mut value_msg = DynamicMessage::new(value_desc.clone());
335                if let Some(string_value_field) = value_desc.get_field_by_name("string_value") {
336                    value_msg.set_field(
337                        &string_value_field,
338                        prost_reflect::Value::String(
339                            "google.protobuf.Value supports arbitrary JSON".into(),
340                        ),
341                    );
342                }
343                msg.set_field(
344                    &values_field,
345                    prost_reflect::Value::List(vec![prost_reflect::Value::Message(value_msg)]),
346                );
347            }
348            return msg;
349        }
350        "google.protobuf.Struct" => {
351            // Struct is a JSON object; provide one key-value pair
352            let mut msg = DynamicMessage::new(desc.clone());
353            if let Some(fields_field) = desc.get_field_by_name("fields") {
354                let entry_desc = match fields_field.kind() {
355                    prost_reflect::Kind::Message(m) => m,
356                    _ => return msg,
357                };
358                let value_field_desc = entry_desc.get_field(2);
359                if let Some(value_field_desc) = value_field_desc {
360                    let value_msg_desc = match value_field_desc.kind() {
361                        prost_reflect::Kind::Message(m) => m,
362                        _ => return msg,
363                    };
364                    let mut value_msg = DynamicMessage::new(value_msg_desc.clone());
365                    if let Some(string_value_field) =
366                        value_msg_desc.get_field_by_name("string_value")
367                    {
368                        value_msg.set_field(
369                            &string_value_field,
370                            prost_reflect::Value::String(
371                                "google.protobuf.Struct supports arbitrary JSON objects".into(),
372                            ),
373                        );
374                    }
375                    let mut map = std::collections::HashMap::new();
376                    map.insert(
377                        prost_reflect::MapKey::String("key".into()),
378                        prost_reflect::Value::Message(value_msg),
379                    );
380                    msg.set_field(&fields_field, prost_reflect::Value::Map(map));
381                }
382            }
383            return msg;
384        }
385        _ => {}
386    }
387
388    // Cycle detection: if we've already seen this message type, return empty
389    if path.contains(&full_name) {
390        return DynamicMessage::new(desc.clone());
391    }
392
393    path.push(full_name);
394
395    let mut msg = DynamicMessage::new(desc.clone());
396
397    for field in desc.fields() {
398        if field.is_map() {
399            // Map field: add one entry with default key and value
400            let kind = field.kind();
401            let entry_desc = kind.as_message().expect("map field has message type");
402            let key_field = entry_desc.get_field(1).expect("map entry has key field");
403            let value_field = entry_desc.get_field(2).expect("map entry has value field");
404
405            let key = default_map_key(&key_field);
406            let value = if let prost_reflect::Kind::Message(value_desc) = value_field.kind() {
407                prost_reflect::Value::Message(make_template_inner(&value_desc, path))
408            } else {
409                default_value_for_kind(&value_field)
410            };
411
412            let mut map = std::collections::HashMap::new();
413            map.insert(key, value);
414            msg.set_field(&field, prost_reflect::Value::Map(map));
415        } else if field.is_list() {
416            // Repeated field: add one default element
417            let element = if let prost_reflect::Kind::Message(elem_desc) = field.kind() {
418                prost_reflect::Value::Message(make_template_inner(&elem_desc, path))
419            } else {
420                default_value_for_kind(&field)
421            };
422            msg.set_field(&field, prost_reflect::Value::List(vec![element]));
423        } else if let prost_reflect::Kind::Message(sub_desc) = field.kind() {
424            // Non-repeated message field: recursively populate
425            let sub_msg = make_template_inner(&sub_desc, path);
426            msg.set_field(&field, prost_reflect::Value::Message(sub_msg));
427        }
428        // Scalar non-repeated fields: leave at defaults (emit_defaults will show them)
429    }
430
431    path.pop();
432    msg
433}
434
435/// Return a default MapKey for a given field descriptor.
436fn default_map_key(field: &prost_reflect::FieldDescriptor) -> prost_reflect::MapKey {
437    use prost_reflect::Kind;
438    match field.kind() {
439        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => prost_reflect::MapKey::I32(0),
440        Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => prost_reflect::MapKey::I64(0),
441        Kind::Uint32 | Kind::Fixed32 => prost_reflect::MapKey::U32(0),
442        Kind::Uint64 | Kind::Fixed64 => prost_reflect::MapKey::U64(0),
443        Kind::Bool => prost_reflect::MapKey::Bool(false),
444        Kind::String => prost_reflect::MapKey::String(String::new()),
445        _ => prost_reflect::MapKey::I32(0),
446    }
447}
448
449/// Return a default Value for a scalar field.
450fn default_value_for_kind(field: &prost_reflect::FieldDescriptor) -> prost_reflect::Value {
451    use prost_reflect::Kind;
452    match field.kind() {
453        Kind::Double => prost_reflect::Value::F64(0.0),
454        Kind::Float => prost_reflect::Value::F32(0.0),
455        Kind::Int32 | Kind::Sint32 | Kind::Sfixed32 => prost_reflect::Value::I32(0),
456        Kind::Int64 | Kind::Sint64 | Kind::Sfixed64 => prost_reflect::Value::I64(0),
457        Kind::Uint32 | Kind::Fixed32 => prost_reflect::Value::U32(0),
458        Kind::Uint64 | Kind::Fixed64 => prost_reflect::Value::U64(0),
459        Kind::Bool => prost_reflect::Value::Bool(false),
460        Kind::String => prost_reflect::Value::String(String::new()),
461        Kind::Bytes => prost_reflect::Value::Bytes(Default::default()),
462        Kind::Enum(e) => {
463            // Use first enum value (typically 0)
464            prost_reflect::Value::EnumNumber(e.default_value().number())
465        }
466        Kind::Message(m) => prost_reflect::Value::Message(DynamicMessage::new(m)),
467    }
468}
469
470/// Type alias for a response formatter function.
471///
472/// Equivalent to Go's `Formatter` type (format.go:129).
473pub type Formatter = Box<dyn Fn(&DynamicMessage) -> Result<String>>;
474
475/// Create a JSON response formatter.
476///
477/// Produces pretty-printed JSON with 2-space indentation.
478/// If `emit_defaults` is true, includes fields with default/zero values.
479///
480/// Equivalent to Go's `NewJSONFormatter()` (format.go:137-157).
481pub fn json_formatter(options: &FormatOptions) -> Formatter {
482    let serialize_options = SerializeOptions::new()
483        .skip_default_fields(!options.emit_defaults)
484        .stringify_64_bit_integers(true);
485
486    Box::new(move |msg: &DynamicMessage| {
487        let mut buf = Vec::new();
488        let mut serializer = serde_json::Serializer::pretty(&mut buf);
489
490        msg.serialize_with_options(&mut serializer, &serialize_options)
491            .map_err(|e| GrpcurlError::Proto(format!("failed to format response as JSON: {e}")))?;
492
493        let json = String::from_utf8(buf)
494            .map_err(|e| GrpcurlError::Proto(format!("JSON output is not valid UTF-8: {e}")))?;
495
496        // Post-process to match Go's float formatting: strip trailing ".0" from
497        // whole-valued doubles (e.g., "42.0" -> "42"). Go's encoding/json omits
498        // the decimal point for whole numbers, while serde_json always includes it.
499        Ok(normalize_json_floats(&json))
500    })
501}
502
/// Drop the `.0` from whole-valued JSON numbers (`42.0` -> `42`, `-3.0` -> `-3`)
/// to match Go's encoding/json output.
///
/// The input is scanned token by token. Text inside string literals (keys and
/// values, including escaped quotes) is copied verbatim, and only complete
/// number tokens of the exact form `-?digits.0` are rewritten. Numbers with a
/// real fractional part or an exponent (`42.5`, `1e21`) are left alone.
/// Object values and array elements are handled alike, in both pretty-printed
/// and compact JSON.
fn normalize_json_floats(json: &str) -> String {
    let bytes = json.as_bytes();
    let mut out = String::with_capacity(json.len());
    // `json[copied..]` has not been appended to `out` yet. Every index stored
    // here or sliced at sits on an ASCII byte, so slicing never splits a UTF-8
    // character.
    let mut copied = 0;
    let mut in_string = false;
    let mut escaped = false;
    let mut i = 0;

    while i < bytes.len() {
        let b = bytes[i];
        if in_string {
            if escaped {
                escaped = false;
            } else if b == b'\\' {
                escaped = true;
            } else if b == b'"' {
                in_string = false;
            }
            i += 1;
            continue;
        }

        match b {
            b'"' => {
                in_string = true;
                i += 1;
            }
            b'-' | b'0'..=b'9' => {
                // Consume the whole number token so that, e.g., "1.05" or
                // "1.0e5" is never mistaken for the whole number "1.0".
                let start = i;
                while i < bytes.len()
                    && matches!(bytes[i], b'0'..=b'9' | b'-' | b'+' | b'.' | b'e' | b'E')
                {
                    i += 1;
                }
                let token = &json[start..i];
                if let Some(int_part) = token.strip_suffix(".0") {
                    let digits = int_part.strip_prefix('-').unwrap_or(int_part);
                    if !digits.is_empty() && digits.bytes().all(|d| d.is_ascii_digit()) {
                        out.push_str(&json[copied..start]);
                        out.push_str(int_part);
                        copied = i;
                    }
                }
            }
            _ => i += 1,
        }
    }

    out.push_str(&json[copied..]);
    out
}
519
520/// Create a protobuf text format response formatter.
521///
522/// When `use_separator` is true, prepends a 0x1E record separator
523/// character between messages (after the first).
524///
525/// Equivalent to Go's `NewTextFormatter()` (format.go:164-213).
526pub fn text_formatter(use_separator: bool) -> Formatter {
527    let num_formatted = Cell::new(0usize);
528
529    Box::new(move |msg: &DynamicMessage| {
530        let mut output = String::new();
531
532        if use_separator && num_formatted.get() > 0 {
533            output.push('\x1e');
534        }
535
536        // Use Display with alternate flag for pretty-printed (indented) text format.
537        // prost-reflect uses curly braces {} (modern proto text format) while Go
538        // uses angle brackets <> (legacy). Both are valid protobuf text format.
539        let text = format!("{msg:#}");
540        // Remove trailing newline (matching Go behavior)
541        let text = text.trim_end_matches('\n');
542        output.push_str(text);
543
544        num_formatted.set(num_formatted.get() + 1);
545        Ok(output)
546    })
547}
548
549/// Map a tonic gRPC status code to its canonical name.
550///
551/// Equivalent to Go's `codes.Code.String()`.
552pub fn status_code_name(code: tonic::Code) -> &'static str {
553    match code {
554        tonic::Code::Ok => "OK",
555        tonic::Code::Cancelled => "Canceled",
556        tonic::Code::Unknown => "Unknown",
557        tonic::Code::InvalidArgument => "InvalidArgument",
558        tonic::Code::DeadlineExceeded => "DeadlineExceeded",
559        tonic::Code::NotFound => "NotFound",
560        tonic::Code::AlreadyExists => "AlreadyExists",
561        tonic::Code::PermissionDenied => "PermissionDenied",
562        tonic::Code::ResourceExhausted => "ResourceExhausted",
563        tonic::Code::FailedPrecondition => "FailedPrecondition",
564        tonic::Code::Aborted => "Aborted",
565        tonic::Code::OutOfRange => "OutOfRange",
566        tonic::Code::Unimplemented => "Unimplemented",
567        tonic::Code::Internal => "Internal",
568        tonic::Code::Unavailable => "Unavailable",
569        tonic::Code::DataLoss => "DataLoss",
570        tonic::Code::Unauthenticated => "Unauthenticated",
571    }
572}
573
574/// Print a gRPC status to stderr in the standard format.
575///
576/// Equivalent to Go's `PrintStatus()` (format.go:517-554).
577///
578/// Format:
579/// ```text
580/// ERROR:
581///   Code: <CODE_NAME>
582///   Message: <message>
583/// ```
584pub fn print_status(status: &tonic::Status, formatter: Option<&Formatter>) {
585    write_status(&mut io::stderr(), status, formatter);
586}
587
588/// Write a gRPC status to the given writer.
589///
590/// Allows callers to direct status output to any writer (stderr, buffer, etc.)
591/// rather than hardcoding to stderr. The `print_status` function uses this
592/// with `io::stderr()`.
593pub fn write_status(w: &mut dyn io::Write, status: &tonic::Status, formatter: Option<&Formatter>) {
594    if status.code() == tonic::Code::Ok {
595        let _ = writeln!(w, "OK");
596        return;
597    }
598    let _ = writeln!(w, "ERROR:");
599    let _ = writeln!(w, "  Code: {}", status_code_name(status.code()));
600    let _ = writeln!(w, "  Message: {}", status.message());
601
602    // Parse status details from grpc-status-details-bin trailer.
603    // This contains a serialized google.rpc.Status with Any-typed details.
604    let details_bytes = status.details();
605    if details_bytes.is_empty() {
606        return;
607    }
608
609    // Decode as google.rpc.Status (manually, since prost_types doesn't include it).
610    // The wire format is: field 1 (int32 code), field 2 (string message),
611    // field 3 (repeated google.protobuf.Any details).
612    // We only need the details field, so we decode the Any messages directly.
613    let any_messages = decode_status_details(details_bytes);
614    if any_messages.is_empty() {
615        return;
616    }
617
618    for (i, any) in any_messages.iter().enumerate() {
619        if i == 0 {
620            let _ = writeln!(w, "  Details:");
621        }
622        // Try to format the Any message using the formatter if available
623        let formatted = formatter.and_then(|fmt| format_any_detail(any, fmt).ok());
624
625        if let Some(text) = formatted {
626            let _ = writeln!(w, "  - {}", any.type_url);
627            for line in text.lines() {
628                let _ = writeln!(w, "      {line}");
629            }
630        } else {
631            // Fallback: show type URL and raw base64 value
632            let _ = writeln!(w, "  - {} ({} bytes)", any.type_url, any.value.len());
633        }
634    }
635}
636
637/// Decode the details field (field 3, repeated Any) from a serialized google.rpc.Status.
638///
639/// google.rpc.Status wire format:
640///   field 1: int32 code
641///   field 2: string message
642///   field 3: repeated google.protobuf.Any
643///
644/// google.protobuf.Any wire format:
645///   field 1: string type_url
646///   field 2: bytes value
647fn decode_status_details(data: &[u8]) -> Vec<prost_types::Any> {
648    use prost::Message;
649
650    // Use prost's low-level decoding by defining the Status message structure
651    #[derive(Message, Clone)]
652    struct RpcStatus {
653        #[prost(int32, tag = "1")]
654        _code: i32,
655        #[prost(string, tag = "2")]
656        _message: String,
657        #[prost(message, repeated, tag = "3")]
658        details: Vec<prost_types::Any>,
659    }
660
661    match RpcStatus::decode(data) {
662        Ok(status) => status.details,
663        Err(_) => Vec::new(),
664    }
665}
666
667/// Attempt to format an Any-typed detail message as JSON.
668///
669/// Uses a well-known types descriptor pool to decode common error detail types
670/// like google.rpc.ErrorInfo, google.rpc.BadRequest, etc.
671fn format_any_detail(
672    any: &prost_types::Any,
673    formatter: &Formatter,
674) -> std::result::Result<String, Box<dyn std::error::Error>> {
675    // Extract the message type name from the type_url
676    let type_name = any
677        .type_url
678        .rsplit_once('/')
679        .map(|(_, name)| name)
680        .unwrap_or(&any.type_url);
681
682    // Try to find the message type in a pool with well-known types
683    let pool = prost_reflect::DescriptorPool::global();
684    let msg_desc = pool.get_message_by_name(type_name).ok_or("unknown type")?;
685
686    let msg = DynamicMessage::decode(msg_desc, any.value.as_slice())
687        .map_err(|e| format!("failed to decode detail: {e}"))?;
688
689    (formatter)(&msg).map_err(|e| e.into())
690}
691
692#[cfg(test)]
693mod tests {
694    use super::*;
695    use prost_reflect::DescriptorPool;
696
697    fn make_pool() -> DescriptorPool {
698        let fds = prost_types::FileDescriptorSet {
699            file: vec![prost_types::FileDescriptorProto {
700                name: Some("test.proto".into()),
701                package: Some("test.v1".into()),
702                message_type: vec![prost_types::DescriptorProto {
703                    name: Some("HelloRequest".into()),
704                    field: vec![
705                        prost_types::FieldDescriptorProto {
706                            name: Some("name".into()),
707                            number: Some(1),
708                            r#type: Some(9), // TYPE_STRING
709                            label: Some(1),  // LABEL_OPTIONAL
710                            json_name: Some("name".into()),
711                            ..Default::default()
712                        },
713                        prost_types::FieldDescriptorProto {
714                            name: Some("count".into()),
715                            number: Some(2),
716                            r#type: Some(5), // TYPE_INT32
717                            label: Some(1),
718                            json_name: Some("count".into()),
719                            ..Default::default()
720                        },
721                    ],
722                    ..Default::default()
723                }],
724                syntax: Some("proto3".into()),
725                ..Default::default()
726            }],
727        };
728        DescriptorPool::from_file_descriptor_set(fds).unwrap()
729    }
730
731    #[test]
732    fn parse_single_json_message() {
733        let pool = make_pool();
734        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
735        let opts = FormatOptions::default();
736        let mut parser =
737            JsonRequestParser::new(Some(r#"{"name": "world", "count": 42}"#), &opts).unwrap();
738
739        let msg = parser.next(&desc).unwrap();
740        assert_eq!(parser.num_requests(), 1);
741
742        // Verify fields
743        let name_field = desc.get_field_by_name("name").unwrap();
744        let name_val = msg.get_field(&name_field);
745        assert_eq!(name_val.as_str(), Some("world"));
746
747        let count_field = desc.get_field_by_name("count").unwrap();
748        let count_val = msg.get_field(&count_field);
749        assert_eq!(count_val.as_i32(), Some(42));
750    }
751
752    #[test]
753    fn parse_multiple_json_messages() {
754        let pool = make_pool();
755        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
756        let opts = FormatOptions::default();
757        let mut parser =
758            JsonRequestParser::new(Some(r#"{"name": "first"} {"name": "second"}"#), &opts).unwrap();
759
760        let msg1 = parser.next(&desc).unwrap();
761        let name1 = msg1.get_field(&desc.get_field_by_name("name").unwrap());
762        assert_eq!(name1.as_str(), Some("first"));
763
764        let msg2 = parser.next(&desc).unwrap();
765        let name2 = msg2.get_field(&desc.get_field_by_name("name").unwrap());
766        assert_eq!(name2.as_str(), Some("second"));
767
768        assert!(matches!(parser.next(&desc), Err(ParseError::Eof)));
769        assert_eq!(parser.num_requests(), 2);
770    }
771
772    #[test]
773    fn parse_empty_input() {
774        let pool = make_pool();
775        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
776        let opts = FormatOptions::default();
777        let mut parser = JsonRequestParser::new(None, &opts).unwrap();
778
779        assert!(matches!(parser.next(&desc), Err(ParseError::Eof)));
780        assert_eq!(parser.num_requests(), 0);
781    }
782
783    #[test]
784    fn format_json_without_defaults() {
785        let pool = make_pool();
786        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
787        let opts = FormatOptions {
788            emit_defaults: false,
789            ..Default::default()
790        };
791        let formatter = json_formatter(&opts);
792
793        let mut msg = DynamicMessage::new(desc.clone());
794        let name_field = desc.get_field_by_name("name").unwrap();
795        msg.set_field(&name_field, prost_reflect::Value::String("world".into()));
796
797        let output = (formatter)(&msg).unwrap();
798        assert!(output.contains("\"name\": \"world\""));
799        // count field has default value 0 and should be skipped
800        assert!(!output.contains("count"));
801    }
802
803    #[test]
804    fn format_json_with_defaults() {
805        let pool = make_pool();
806        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
807        let opts = FormatOptions {
808            emit_defaults: true,
809            ..Default::default()
810        };
811        let formatter = json_formatter(&opts);
812
813        let mut msg = DynamicMessage::new(desc.clone());
814        let name_field = desc.get_field_by_name("name").unwrap();
815        msg.set_field(&name_field, prost_reflect::Value::String("world".into()));
816
817        let output = (formatter)(&msg).unwrap();
818        assert!(output.contains("\"name\": \"world\""));
819        // count field should now be included with default value
820        assert!(output.contains("\"count\""));
821    }
822
823    #[test]
824    fn parse_unknown_fields_rejected_by_default() {
825        let pool = make_pool();
826        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
827        let opts = FormatOptions::default();
828        let mut parser =
829            JsonRequestParser::new(Some(r#"{"name": "test", "unknown_field": 42}"#), &opts)
830                .unwrap();
831
832        let result = parser.next(&desc);
833        assert!(matches!(result, Err(ParseError::Error(_))));
834    }
835
836    #[test]
837    fn parse_unknown_fields_allowed() {
838        let pool = make_pool();
839        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
840        let opts = FormatOptions {
841            allow_unknown_fields: true,
842            ..Default::default()
843        };
844        let mut parser =
845            JsonRequestParser::new(Some(r#"{"name": "test", "unknown_field": 42}"#), &opts)
846                .unwrap();
847
848        let msg = parser.next(&desc).unwrap();
849        let name_val = msg.get_field(&desc.get_field_by_name("name").unwrap());
850        assert_eq!(name_val.as_str(), Some("test"));
851    }
852
853    #[test]
854    fn parse_text_format_single_message() {
855        let pool = make_pool();
856        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
857        let mut parser = TextRequestParser::new(Some("name: \"world\" count: 42")).unwrap();
858
859        let msg = parser.next(&desc).unwrap();
860        assert_eq!(parser.num_requests(), 1);
861
862        let name_val = msg.get_field(&desc.get_field_by_name("name").unwrap());
863        assert_eq!(name_val.as_str(), Some("world"));
864
865        let count_val = msg.get_field(&desc.get_field_by_name("count").unwrap());
866        assert_eq!(count_val.as_i32(), Some(42));
867    }
868
869    #[test]
870    fn parse_text_format_multiple_with_separator() {
871        let pool = make_pool();
872        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
873        let mut parser =
874            TextRequestParser::new(Some("name: \"first\"\x1ename: \"second\"")).unwrap();
875
876        let msg1 = parser.next(&desc).unwrap();
877        let name1 = msg1.get_field(&desc.get_field_by_name("name").unwrap());
878        assert_eq!(name1.as_str(), Some("first"));
879
880        let msg2 = parser.next(&desc).unwrap();
881        let name2 = msg2.get_field(&desc.get_field_by_name("name").unwrap());
882        assert_eq!(name2.as_str(), Some("second"));
883
884        assert!(matches!(parser.next(&desc), Err(ParseError::Eof)));
885        assert_eq!(parser.num_requests(), 2);
886    }
887
888    #[test]
889    fn parse_text_format_empty_input() {
890        let pool = make_pool();
891        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
892        let mut parser = TextRequestParser::new(None).unwrap();
893
894        // First call with empty input returns an empty message (matching Go behavior)
895        let msg = parser.next(&desc).unwrap();
896        assert_eq!(parser.num_requests(), 1);
897        // Verify it's an empty/default message
898        let name_val = msg.get_field(&desc.get_field_by_name("name").unwrap());
899        assert_eq!(name_val.as_str(), Some(""));
900
901        // Second call returns Eof
902        assert!(matches!(parser.next(&desc), Err(ParseError::Eof)));
903        assert_eq!(parser.num_requests(), 1);
904    }
905
906    #[test]
907    fn format_text_output() {
908        let pool = make_pool();
909        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
910        let formatter = text_formatter(false);
911
912        let mut msg = DynamicMessage::new(desc.clone());
913        let name_field = desc.get_field_by_name("name").unwrap();
914        msg.set_field(&name_field, prost_reflect::Value::String("world".into()));
915
916        let output = (formatter)(&msg).unwrap();
917        assert!(output.contains("name"));
918        assert!(output.contains("world"));
919    }
920
921    #[test]
922    fn format_text_with_separator() {
923        let pool = make_pool();
924        let desc = pool.get_message_by_name("test.v1.HelloRequest").unwrap();
925        let formatter = text_formatter(true);
926
927        let mut msg1 = DynamicMessage::new(desc.clone());
928        let name_field = desc.get_field_by_name("name").unwrap();
929        msg1.set_field(&name_field, prost_reflect::Value::String("first".into()));
930
931        let mut msg2 = DynamicMessage::new(desc.clone());
932        msg2.set_field(&name_field, prost_reflect::Value::String("second".into()));
933
934        let out1 = (formatter)(&msg1).unwrap();
935        assert!(!out1.contains('\x1e')); // No separator for first message
936
937        let out2 = (formatter)(&msg2).unwrap();
938        assert!(out2.starts_with('\x1e')); // Separator for subsequent messages
939    }
940}