Skip to main content

buffa_types/
any_ext.rs

1//! Ergonomic helpers for [`google::protobuf::Any`](crate::google::protobuf::Any).
2
3use alloc::string::String;
4
5use crate::google::protobuf::Any;
6
7impl Any {
8    /// Pack a message into an [`Any`] with the given type URL.
9    ///
10    /// The type URL is conventionally of the form
11    /// `type.googleapis.com/fully.qualified.TypeName`, but this method does
12    /// not enforce that convention — any string is accepted.
13    pub fn pack(msg: &impl buffa::Message, type_url: impl Into<String>) -> Self {
14        Any {
15            type_url: type_url.into(),
16            value: msg.encode_to_vec(),
17            ..Default::default()
18        }
19    }
20
21    /// Unpack the contained message, decoding its bytes as `T`, **without
22    /// checking the `type_url`**.
23    ///
24    /// This method always attempts to decode the payload as `T` regardless
25    /// of whether `type_url` actually identifies `T`. Use [`Any::unpack_if`]
26    /// when you need to verify the stored type before decoding.
27    ///
28    /// # Errors
29    ///
30    /// Returns a [`buffa::DecodeError`] if the bytes cannot be decoded as `T`.
31    pub fn unpack_unchecked<T: buffa::Message>(&self) -> Result<T, buffa::DecodeError> {
32        T::decode(&mut self.value.as_slice())
33    }
34
35    /// Unpack the contained message as `T`, but only if the `type_url`
36    /// matches `expected_type_url`.
37    ///
38    /// Returns `Ok(None)` when the type URL does not match.
39    ///
40    /// # Errors
41    ///
42    /// Returns a [`buffa::DecodeError`] if the type URL matches but the bytes
43    /// cannot be decoded as `T`.
44    pub fn unpack_if<T: buffa::Message>(
45        &self,
46        expected_type_url: &str,
47    ) -> Result<Option<T>, buffa::DecodeError> {
48        if self.type_url != expected_type_url {
49            return Ok(None);
50        }
51        T::decode(&mut self.value.as_slice()).map(Some)
52    }
53
54    /// Returns `true` if this [`Any`]'s `type_url` matches the given string.
55    pub fn is_type(&self, type_url: &str) -> bool {
56        self.type_url == type_url
57    }
58
59    /// Returns the type URL stored in this [`Any`].
60    pub fn type_url(&self) -> &str {
61        &self.type_url
62    }
63}
64
65// ── WKT type registry ───────────────────────────────────────────────────────
66
/// Installs JSON converters for every well-known type into `registry`
/// (an [`AnyRegistry`](buffa::any_registry::AnyRegistry)).
///
/// Covered types: Duration, Timestamp, FieldMask, Value, Struct, ListValue,
/// every wrapper type (BoolValue through BytesValue), Any itself, and Empty.
/// Once registered, these payloads inside a `google.protobuf.Any` serialize
/// to and from proto3 JSON instead of the opaque base64 fallback.
///
/// Every type except Empty is registered with `is_wkt = true`, so its JSON
/// form is nested under a `"value"` key next to `@type`. This includes
/// `Any`, so a nested `Any` is wrapped rather than inlined. Empty is
/// registered with `is_wkt = false`, so its (empty) field set is inlined.
///
/// # Example
///
/// ```rust,no_run
/// use buffa::any_registry::AnyRegistry;
///
/// let mut registry = AnyRegistry::new();
/// buffa_types::register_wkt_types(&mut registry);
/// ```
#[cfg(feature = "json")]
pub fn register_wkt_types(registry: &mut buffa::any_registry::AnyRegistry) {
    use crate::google::protobuf::*;
    use alloc::string::ToString;
    use buffa::any_registry::AnyTypeEntry;

    // Register every listed message type with the same `is_wkt` flag.
    // Each type gets its own non-capturing decode/encode closures.
    macro_rules! register_all {
        (is_wkt = $wkt:expr; $($ty:ty),+ $(,)?) => {
            $(
                registry.register(AnyTypeEntry {
                    type_url: <$ty>::TYPE_URL,
                    to_json: |raw| {
                        let decoded = <$ty as buffa::Message>::decode(&mut &*raw)
                            .map_err(|err| err.to_string())?;
                        serde_json::to_value(&decoded).map_err(|err| err.to_string())
                    },
                    from_json: |json| {
                        let parsed: $ty =
                            serde_json::from_value(json).map_err(|err| err.to_string())?;
                        Ok(buffa::Message::encode_to_vec(&parsed))
                    },
                    is_wkt: $wkt,
                });
            )+
        };
    }

    // Types with a dedicated JSON mapping: wrapped under "value" inside Any.
    register_all!(
        is_wkt = true;
        Duration,
        Timestamp,
        FieldMask,
        Value,
        Struct,
        ListValue,
        BoolValue,
        Int32Value,
        UInt32Value,
        Int64Value,
        UInt64Value,
        FloatValue,
        DoubleValue,
        StringValue,
        BytesValue,
        Any,
    );

    // Ordinary messages: their JSON object fields are inlined beside "@type".
    register_all!(is_wkt = false; Empty);
}
126
127// ── serde impls ──────────────────────────────────────────────────────────────
128//
129// Proto3 JSON for `Any` uses the global `AnyRegistry` to serialize the
130// embedded message with its fields inline (regular messages) or wrapped in a
131// `"value"` key (WKTs). Falls back to base64-encoded `value` when the
132// registry is absent or the type URL is not registered.
133
/// Borrowed payload bytes that serialize as a proto3 JSON `bytes` string
/// (base64), delegating to `buffa::json_helpers::bytes`.
#[cfg(feature = "json")]
struct Base64Bytes<'a>(&'a [u8]);

#[cfg(feature = "json")]
impl serde::Serialize for Base64Bytes<'_> {
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let raw: &[u8] = self.0;
        buffa::json_helpers::bytes::serialize(raw, serializer)
    }
}
143
#[cfg(feature = "json")]
impl serde::Serialize for Any {
    /// Serialize this `Any` in proto3 JSON form.
    ///
    /// The output shape depends on the global `AnyRegistry`:
    ///
    /// * empty `Any` (no `type_url`, no `value`) → `{}`;
    /// * registered type with `is_wkt = true` → `{"@type": URL, "value": <json>}`;
    /// * registered type with `is_wkt = false` → `{"@type": URL, ...fields}`,
    ///   where the fields come from the converter's JSON object;
    /// * no registry, or URL not registered → `{"@type": URL, "value": "<base64>"}`.
    ///
    /// # Errors
    ///
    /// Fails in three cases:
    ///
    /// * `type_url` is empty while `value` is not, because no JSON form can
    ///   preserve the payload without a type URL;
    /// * the registered `to_json` converter fails;
    /// * a non-WKT converter returns something other than a JSON object.
    fn serialize<S: serde::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
        use serde::ser::SerializeMap;

        if self.type_url.is_empty() {
            if !self.value.is_empty() {
                // Emitting `{}` here would silently discard the payload.
                return Err(serde::ser::Error::custom(
                    "Any: type_url is empty but value is not",
                ));
            }
            return s.serialize_map(Some(0))?.end();
        }

        // Only the converter and its flag are copied out of the registry
        // closure. `to_json` itself runs after `with_any_registry` returns,
        // which matters for a nested `Any` payload: its converter re-enters
        // this impl and queries the registry again.
        let lookup = buffa::any_registry::with_any_registry(|reg| {
            reg.and_then(|r| r.lookup(&self.type_url))
                .map(|e| (e.to_json, e.is_wkt))
        });

        match lookup {
            Some((to_json, is_wkt)) => {
                let json_val = to_json(&self.value).map_err(serde::ser::Error::custom)?;
                if is_wkt {
                    let mut map = s.serialize_map(Some(2))?;
                    map.serialize_entry("@type", &self.type_url)?;
                    map.serialize_entry("value", &json_val)?;
                    map.end()
                } else {
                    // Regular messages are inlined, which is only possible
                    // when the converter produced a JSON object.
                    let fields = match &json_val {
                        serde_json::Value::Object(m) => m,
                        _ => {
                            return Err(serde::ser::Error::custom(
                                "Any: to_json for non-WKT must return a JSON object",
                            ))
                        }
                    };
                    let mut map = s.serialize_map(Some(1 + fields.len()))?;
                    map.serialize_entry("@type", &self.type_url)?;
                    for (k, v) in fields {
                        map.serialize_entry(k, v)?;
                    }
                    map.end()
                }
            }
            None => {
                // No converter available: keep the raw payload as base64 so
                // the `Deserialize` fallback can round-trip it.
                let mut map = s.serialize_map(Some(2))?;
                map.serialize_entry("@type", &self.type_url)?;
                map.serialize_entry("value", &Base64Bytes(&self.value))?;
                map.end()
            }
        }
    }
}
192
#[cfg(feature = "json")]
impl<'de> serde::Deserialize<'de> for Any {
    /// Parse an `Any` from its proto3 JSON form.
    ///
    /// * `{}` → `Any::default()`.
    /// * Registered WKT → the `"value"` member is handed to the converter.
    /// * Registered regular message → the remaining members (everything but
    ///   `@type`) are handed to the converter as one JSON object.
    /// * Unregistered type, or no registry → `"value"` must be a base64
    ///   string (the shape the `Serialize` fallback emits). An absent or
    ///   `null` value means an empty payload.
    ///
    /// # Errors
    ///
    /// Fails if any of the following holds:
    ///
    /// * `@type` is missing while other members are present;
    /// * `@type` is not a string, or is not a URL containing `/`;
    /// * the registered converter fails;
    /// * for an unregistered type, `"value"` is neither a string nor `null`,
    ///   or other members are present that could not be decoded.
    fn deserialize<D: serde::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
        // Buffer the entire object so @type can appear at any position.
        let mut obj: serde_json::Map<String, serde_json::Value> =
            serde::Deserialize::deserialize(d)?;

        let type_url = match obj.remove("@type") {
            Some(serde_json::Value::String(s)) => s,
            Some(_) => {
                return Err(serde::de::Error::custom("@type must be a string"));
            }
            // `{}` is the JSON form of the default `Any`.
            None if obj.is_empty() => return Ok(Any::default()),
            // Without a type URL the remaining members cannot be
            // interpreted. Reject them rather than silently dropping them.
            None => {
                return Err(serde::de::Error::custom(
                    "Any: JSON object has fields but no @type",
                ));
            }
        };

        // The type URL must be non-empty and contain a '/' separating the
        // host/authority from the fully-qualified type name (e.g.
        // "type.googleapis.com/google.protobuf.Duration").
        if type_url.is_empty() || !type_url.contains('/') {
            return Err(serde::de::Error::custom(
                "@type must be a valid type URL containing a '/' (e.g. type.googleapis.com/pkg.Type)",
            ));
        }

        let lookup = buffa::any_registry::with_any_registry(|reg| {
            reg.and_then(|r| r.lookup(&type_url))
                .map(|e| (e.from_json, e.is_wkt))
        });

        let value = match lookup {
            Some((from_json, true)) => {
                let json_val = obj.remove("value").unwrap_or(serde_json::Value::Null);
                from_json(json_val).map_err(serde::de::Error::custom)?
            }
            Some((from_json, false)) => {
                let json_obj = serde_json::Value::Object(obj);
                from_json(json_obj).map_err(serde::de::Error::custom)?
            }
            None => {
                // Fallback: the payload can only be recovered from a base64
                // "value" string.
                let bytes = match obj.remove("value") {
                    Some(serde_json::Value::String(s)) => buffa::json_helpers::bytes::deserialize(
                        serde::de::value::StringDeserializer::<D::Error>::new(s),
                    )?,
                    None | Some(serde_json::Value::Null) => alloc::vec::Vec::new(),
                    Some(_) => {
                        return Err(serde::de::Error::custom(
                            "Any: value for an unregistered @type must be a base64 string",
                        ));
                    }
                };
                // Inline fields of an unregistered type cannot be encoded;
                // dropping them would silently lose data.
                if !obj.is_empty() {
                    return Err(serde::de::Error::custom(
                        "Any: cannot decode inline fields for an unregistered @type",
                    ));
                }
                bytes
            }
        };

        Ok(Any {
            type_url,
            value,
            ..Default::default()
        })
    }
}
249
#[cfg(test)]
mod tests {
    use super::*;
    use crate::google::protobuf::Timestamp;
    use buffa::Message as _;

    #[test]
    fn pack_and_unpack() {
        let ts = Timestamp {
            seconds: 1_000_000_000,
            nanos: 0,
            ..Default::default()
        };
        let any = Any::pack(&ts, "type.googleapis.com/google.protobuf.Timestamp");
        assert_eq!(
            any.type_url(),
            "type.googleapis.com/google.protobuf.Timestamp"
        );

        let decoded: Timestamp = any.unpack_unchecked().unwrap();
        assert_eq!(decoded, ts);
    }

    #[test]
    fn unpack_if_matching() {
        let ts = Timestamp {
            seconds: 42,
            ..Default::default()
        };
        let any = Any::pack(&ts, "type.googleapis.com/google.protobuf.Timestamp");

        let result: Option<Timestamp> = any
            .unpack_if("type.googleapis.com/google.protobuf.Timestamp")
            .unwrap();
        assert_eq!(result, Some(ts));
    }

    #[test]
    fn unpack_if_wrong_type_returns_none() {
        let ts = Timestamp {
            seconds: 42,
            ..Default::default()
        };
        let any = Any::pack(&ts, "type.googleapis.com/google.protobuf.Timestamp");

        let result: Option<Timestamp> = any
            .unpack_if("type.googleapis.com/google.protobuf.Duration")
            .unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn is_type() {
        let ts = Timestamp::default();
        let any = Any::pack(&ts, "type.googleapis.com/google.protobuf.Timestamp");
        assert!(any.is_type("type.googleapis.com/google.protobuf.Timestamp"));
        assert!(!any.is_type("type.googleapis.com/google.protobuf.Duration"));
    }

    #[test]
    fn round_trip_encoding() {
        let ts = Timestamp {
            seconds: 99,
            nanos: 1,
            ..Default::default()
        };
        let any = Any::pack(&ts, "test");

        let bytes = any.encode_to_vec();
        let decoded_any = Any::decode(&mut bytes.as_slice()).unwrap();
        let decoded_ts: Timestamp = decoded_any.unpack_unchecked().unwrap();
        assert_eq!(decoded_ts, ts);
    }

    #[cfg(feature = "json")]
    mod serde_tests {
        use super::*;
        use crate::google::protobuf::Duration;
        use buffa::any_registry::{clear_any_registry, set_any_registry, AnyRegistry};

        /// Mutex to serialize tests that manipulate the global AnyRegistry.
        /// Each test binary needs its own lock since #[cfg(test)] modules
        /// cannot be shared across crates.
        static REGISTRY_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

        /// Acquire `REGISTRY_LOCK`, tolerating poisoning.
        ///
        /// A failing assertion panics while the lock is held, which poisons
        /// the mutex. With a plain `unwrap()`, every later registry test
        /// would then fail with an unrelated `PoisonError` and hide the real
        /// failure. The mutex guards `()`, so no state can be inconsistent.
        fn lock_registry() -> std::sync::MutexGuard<'static, ()> {
            REGISTRY_LOCK
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
        }

        /// Clears the global registry when dropped, including during a
        /// panic, so one test's registry never leaks into the next test.
        struct ClearRegistryOnDrop;

        impl Drop for ClearRegistryOnDrop {
            fn drop(&mut self) {
                clear_any_registry();
            }
        }

        /// Install `registry` globally until the returned guard is dropped.
        fn install_registry(registry: AnyRegistry) -> ClearRegistryOnDrop {
            set_any_registry(Box::new(registry));
            ClearRegistryOnDrop
        }

        fn with_registry<R>(f: impl FnOnce() -> R) -> R {
            let _lock = lock_registry();
            let mut registry = AnyRegistry::new();
            register_wkt_types(&mut registry);
            // Declared after `_lock`, so it is dropped (registry cleared)
            // before the lock is released.
            let _installed = install_registry(registry);
            f()
        }

        fn without_registry<R>(f: impl FnOnce() -> R) -> R {
            let _lock = lock_registry();
            clear_any_registry();
            f()
        }

        #[test]
        fn serialize_wkt_uses_value_wrapping() {
            with_registry(|| {
                let ts = Timestamp {
                    seconds: 1_000_000_000,
                    nanos: 0,
                    ..Default::default()
                };
                let any = Any::pack(&ts, Timestamp::TYPE_URL);
                let json = serde_json::to_value(&any).unwrap();
                assert_eq!(json["@type"], Timestamp::TYPE_URL);
                assert_eq!(json["value"], "2001-09-09T01:46:40Z");
            });
        }

        #[test]
        fn serialize_duration_wkt() {
            with_registry(|| {
                let dur = Duration::from_secs_nanos(1, 500_000_000);
                let any = Any::pack(&dur, Duration::TYPE_URL);
                let json = serde_json::to_value(&any).unwrap();
                assert_eq!(json["@type"], Duration::TYPE_URL);
                assert_eq!(json["value"], "1.500s");
            });
        }

        #[test]
        fn serialize_empty_any_is_empty_object() {
            with_registry(|| {
                let any = Any::default();
                let json = serde_json::to_string(&any).unwrap();
                assert_eq!(json, "{}");
            });
        }

        #[test]
        fn deserialize_wkt_from_json() {
            with_registry(|| {
                let json = r#"{
                    "@type": "type.googleapis.com/google.protobuf.Duration",
                    "value": "1.5s"
                }"#;
                let any: Any = serde_json::from_str(json).unwrap();
                assert_eq!(any.type_url, Duration::TYPE_URL);

                let dur: Duration = any.unpack_unchecked().unwrap();
                assert_eq!(dur.seconds, 1);
                assert_eq!(dur.nanos, 500_000_000);
            });
        }

        #[test]
        fn deserialize_unordered_type_tag() {
            with_registry(|| {
                // @type appears after the value field.
                let json = r#"{
                    "value": "1.5s",
                    "@type": "type.googleapis.com/google.protobuf.Duration"
                }"#;
                let any: Any = serde_json::from_str(json).unwrap();
                assert_eq!(any.type_url, Duration::TYPE_URL);

                let dur: Duration = any.unpack_unchecked().unwrap();
                assert_eq!(dur.seconds, 1);
                assert_eq!(dur.nanos, 500_000_000);
            });
        }

        #[test]
        fn roundtrip_wkt_json() {
            with_registry(|| {
                let ts = Timestamp {
                    seconds: 1_000_000_000,
                    nanos: 0,
                    ..Default::default()
                };
                let any = Any::pack(&ts, Timestamp::TYPE_URL);
                let json = serde_json::to_string(&any).unwrap();
                let decoded: Any = serde_json::from_str(&json).unwrap();
                let decoded_ts: Timestamp = decoded.unpack_unchecked().unwrap();
                assert_eq!(decoded_ts, ts);
            });
        }

        #[test]
        fn nested_any_roundtrip() {
            with_registry(|| {
                let dur = Duration::from_secs(42);
                let inner_any = Any::pack(&dur, Duration::TYPE_URL);
                let outer_any = Any::pack(&inner_any, Any::TYPE_URL);

                let json = serde_json::to_string(&outer_any).unwrap();
                let decoded_outer: Any = serde_json::from_str(&json).unwrap();
                let decoded_inner: Any = decoded_outer.unpack_unchecked().unwrap();
                let decoded_dur: Duration = decoded_inner.unpack_unchecked().unwrap();
                assert_eq!(decoded_dur.seconds, 42);
            });
        }

        #[test]
        fn fallback_base64_without_registry() {
            without_registry(|| {
                let any = Any {
                    type_url: "type.googleapis.com/unknown.Type".into(),
                    value: vec![0x08, 0x96, 0x01],
                    ..Default::default()
                };
                let json = serde_json::to_string(&any).unwrap();
                assert!(json.contains("@type"));
                assert!(json.contains("value"));

                let decoded: Any = serde_json::from_str(&json).unwrap();
                assert_eq!(decoded.type_url, any.type_url);
                assert_eq!(decoded.value, any.value);
            });
        }

        #[test]
        fn deserialize_missing_type_returns_default() {
            let json = r#"{}"#;
            let any: Any = serde_json::from_str(json).unwrap();
            assert_eq!(any, Any::default());
        }

        #[test]
        fn fallback_base64_with_registry_but_unknown_type() {
            with_registry(|| {
                let any = Any {
                    type_url: "type.googleapis.com/unknown.Type".into(),
                    value: vec![0x08, 0x96, 0x01],
                    ..Default::default()
                };
                let json = serde_json::to_string(&any).unwrap();
                let decoded: Any = serde_json::from_str(&json).unwrap();
                assert_eq!(decoded.type_url, any.type_url);
                assert_eq!(decoded.value, any.value);
            });
        }

        #[test]
        fn deserialize_rejects_empty_type_url() {
            let json = r#"{"@type": "", "value": ""}"#;
            let err = serde_json::from_str::<Any>(json).unwrap_err();
            assert!(err.to_string().contains("valid type URL"), "{err}");
        }

        #[test]
        fn deserialize_rejects_type_url_without_slash() {
            let json = r#"{"@type": "not_a_url", "value": ""}"#;
            let err = serde_json::from_str::<Any>(json).unwrap_err();
            assert!(err.to_string().contains("valid type URL"), "{err}");
        }

        // ── Non-WKT registered type (fields inlined at top level) ─────
        // WKTs use {"@type": ..., "value": <json>} wrapping.
        // Regular messages use {"@type": ..., "field1": ..., "field2": ...}.
        // The tests below cover that inlined path using a hand-written
        // converter for a user-defined type.

        /// Hand-written to_json: decode the Any bytes as a single varint
        /// field (number=1), return it as a JSON object {"id": N}.
        fn user_type_to_json(bytes: &[u8]) -> Result<serde_json::Value, String> {
            use buffa::encoding::Tag;
            let mut cur = bytes;
            let mut id = 0i64;
            while !cur.is_empty() {
                let tag = Tag::decode(&mut cur).map_err(|e| e.to_string())?;
                if tag.field_number() == 1 {
                    id =
                        buffa::encoding::decode_varint(&mut cur).map_err(|e| e.to_string())? as i64;
                } else {
                    buffa::encoding::skip_field(tag, &mut cur).map_err(|e| e.to_string())?;
                }
            }
            Ok(serde_json::json!({ "id": id }))
        }

        /// Hand-written from_json: extract {"id": N}, encode as varint field 1.
        fn user_type_from_json(value: serde_json::Value) -> Result<alloc::vec::Vec<u8>, String> {
            use buffa::encoding::{encode_varint, Tag, WireType};
            let id = value
                .get("id")
                .and_then(|v| v.as_i64())
                .ok_or_else(|| "missing or invalid 'id' field".to_string())?;
            let mut buf = alloc::vec::Vec::new();
            Tag::new(1, WireType::Varint).encode(&mut buf);
            encode_varint(id as u64, &mut buf);
            Ok(buf)
        }

        fn with_user_type_registry<R>(f: impl FnOnce() -> R) -> R {
            use buffa::any_registry::AnyTypeEntry;
            let _lock = lock_registry();
            let mut registry = AnyRegistry::new();
            // Register as NON-WKT (is_wkt=false) — fields inline at top level.
            registry.register(AnyTypeEntry {
                type_url: "type.example.com/user.Thing",
                to_json: user_type_to_json,
                from_json: user_type_from_json,
                is_wkt: false,
            });
            let _installed = install_registry(registry);
            f()
        }

        #[test]
        fn serialize_non_wkt_inlines_fields() {
            with_user_type_registry(|| {
                // Encode {id: 42} as proto wire bytes.
                let any = Any {
                    type_url: "type.example.com/user.Thing".into(),
                    // field 1, varint 42: tag=0x08, value=0x2A
                    value: vec![0x08, 0x2A],
                    ..Default::default()
                };

                let json = serde_json::to_value(&any).unwrap();
                // Non-WKT format: fields at top level alongside @type.
                assert_eq!(json["@type"], "type.example.com/user.Thing");
                assert_eq!(json["id"], 42);
                // Should NOT have a "value" wrapper key.
                assert!(
                    json.get("value").is_none(),
                    "non-WKT should not use 'value' wrapping: {json}"
                );
            });
        }

        #[test]
        fn deserialize_non_wkt_from_inlined_fields() {
            with_user_type_registry(|| {
                let json = r#"{
                    "@type": "type.example.com/user.Thing",
                    "id": 99
                }"#;
                let any: Any = serde_json::from_str(json).unwrap();
                assert_eq!(any.type_url, "type.example.com/user.Thing");
                // Verify the from_json encoded it back to wire bytes.
                assert_eq!(any.value, vec![0x08, 99]);
            });
        }

        #[test]
        fn non_wkt_round_trip() {
            with_user_type_registry(|| {
                let original = Any {
                    type_url: "type.example.com/user.Thing".into(),
                    value: vec![0x08, 0x07], // id=7
                    ..Default::default()
                };
                let json = serde_json::to_string(&original).unwrap();
                let decoded: Any = serde_json::from_str(&json).unwrap();
                assert_eq!(decoded.type_url, original.type_url);
                assert_eq!(decoded.value, original.value);
            });
        }

        #[test]
        fn serialize_non_wkt_rejects_non_object_json() {
            // If to_json for a non-WKT type returns something other than a
            // JSON object, serialization must fail (can't inline non-object
            // fields alongside @type).
            use buffa::any_registry::AnyTypeEntry;
            let _lock = lock_registry();
            let mut registry = AnyRegistry::new();
            registry.register(AnyTypeEntry {
                type_url: "type.example.com/user.BadType",
                to_json: |_bytes| Ok(serde_json::Value::Number(42.into())),
                from_json: |_v| Ok(alloc::vec::Vec::new()),
                is_wkt: false,
            });
            let installed = install_registry(registry);

            let any = Any {
                type_url: "type.example.com/user.BadType".into(),
                value: vec![],
                ..Default::default()
            };
            let result = serde_json::to_string(&any);
            // Clear the registry before asserting, as before.
            drop(installed);
            assert!(result.is_err(), "expected error for non-object to_json");
            assert!(
                result
                    .unwrap_err()
                    .to_string()
                    .contains("must return a JSON object"),
                "wrong error message"
            );
        }

        #[test]
        fn deserialize_rejects_non_string_type() {
            // @type as a non-string value → error.
            let json = r#"{"@type": 123}"#;
            let err = serde_json::from_str::<Any>(json).unwrap_err();
            assert!(err.to_string().contains("@type must be a string"), "{err}");
        }
    }
}