Skip to main content

a2a/
errordetails.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize, de::Deserializer, ser::Serializer};
6use serde_json::Value;
7
/// `@type` URL identifying a `google.rpc.ErrorInfo` detail.
pub const ERROR_INFO_TYPE: &str = "type.googleapis.com/google.rpc.ErrorInfo";
/// `@type` URL identifying a `google.rpc.BadRequest` detail.
pub const BAD_REQUEST_TYPE: &str = "type.googleapis.com/google.rpc.BadRequest";
/// `@type` URL identifying a `google.protobuf.Struct` detail; also the type
/// assigned to a deserialized detail that carries no `@type` member.
pub const STRUCT_TYPE: &str = "type.googleapis.com/google.protobuf.Struct";
/// Value for the `domain` field of `ErrorInfo` details issued under the A2A protocol.
pub const PROTOCOL_DOMAIN: &str = "a2a-protocol.org";
12
13/// A field-level validation error, matching `google.rpc.BadRequest.FieldViolation`.
14#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
15pub struct FieldViolation {
16    pub field: String,
17    pub description: String,
18}
19
20/// A typed error detail object using ProtoJSON `Any` representation.
21///
22/// Each detail object carries a `@type` URL identifying its schema and
23/// a map of additional fields. When serialized to JSON, `@type` is
24/// flattened into the object alongside the value fields.
25#[derive(Debug, Clone, PartialEq)]
26pub struct TypedDetail {
27    pub type_url: String,
28    pub value: HashMap<String, Value>,
29}
30
31impl TypedDetail {
32    pub fn new(type_url: impl Into<String>, value: HashMap<String, Value>) -> Self {
33        Self {
34            type_url: type_url.into(),
35            value,
36        }
37    }
38
39    /// Create a `google.rpc.ErrorInfo` detail.
40    pub fn error_info(
41        reason: impl Into<String>,
42        domain: impl Into<String>,
43        metadata: Option<HashMap<String, String>>,
44    ) -> Self {
45        let mut value = HashMap::new();
46        value.insert("reason".to_string(), Value::String(reason.into()));
47        value.insert("domain".to_string(), Value::String(domain.into()));
48        if let Some(meta) = metadata {
49            let meta_obj: serde_json::Map<String, Value> = meta
50                .into_iter()
51                .map(|(k, v)| (k, Value::String(v)))
52                .collect();
53            value.insert("metadata".to_string(), Value::Object(meta_obj));
54        }
55        Self {
56            type_url: ERROR_INFO_TYPE.to_string(),
57            value,
58        }
59    }
60
61    /// Create a `google.rpc.BadRequest` detail with field violations.
62    pub fn bad_request(field_violations: Vec<FieldViolation>) -> Self {
63        let violations: Vec<Value> = field_violations
64            .into_iter()
65            .map(|fv| serde_json::to_value(fv).unwrap_or_default())
66            .collect();
67        let mut value = HashMap::new();
68        value.insert("fieldViolations".to_string(), Value::Array(violations));
69        Self {
70            type_url: BAD_REQUEST_TYPE.to_string(),
71            value,
72        }
73    }
74
75    /// Create a typed detail from a struct (arbitrary map).
76    pub fn from_struct(fields: HashMap<String, Value>) -> Self {
77        Self {
78            type_url: STRUCT_TYPE.to_string(),
79            value: fields,
80        }
81    }
82}
83
84impl Serialize for TypedDetail {
85    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
86        use serde::ser::SerializeMap;
87        let mut map = serializer.serialize_map(Some(self.value.len() + 1))?;
88        map.serialize_entry("@type", &self.type_url)?;
89        for (k, v) in &self.value {
90            map.serialize_entry(k, v)?;
91        }
92        map.end()
93    }
94}
95
96impl<'de> Deserialize<'de> for TypedDetail {
97    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
98        let mut map: HashMap<String, Value> = HashMap::deserialize(deserializer)?;
99        let type_url = map
100            .remove("@type")
101            .and_then(|v| v.as_str().map(String::from))
102            .unwrap_or_else(|| STRUCT_TYPE.to_string());
103        Ok(Self {
104            type_url,
105            value: map,
106        })
107    }
108}
109
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialize `detail` into a JSON value, panicking on failure.
    fn to_json(detail: &TypedDetail) -> Value {
        serde_json::to_value(detail).unwrap()
    }

    /// Deserialize a `TypedDetail` from `json`, panicking on failure.
    fn from_json(json: Value) -> TypedDetail {
        serde_json::from_value(json).unwrap()
    }

    #[test]
    fn test_serialize_typed_detail() {
        let json = to_json(&TypedDetail::error_info(
            "TASK_NOT_FOUND",
            PROTOCOL_DOMAIN,
            None,
        ));
        for (key, expected) in [
            ("@type", ERROR_INFO_TYPE),
            ("reason", "TASK_NOT_FOUND"),
            ("domain", PROTOCOL_DOMAIN),
        ] {
            assert_eq!(json[key], expected);
        }
    }

    #[test]
    fn test_deserialize_typed_detail() {
        let parsed = from_json(serde_json::json!({
            "@type": ERROR_INFO_TYPE,
            "reason": "TASK_NOT_FOUND",
            "domain": PROTOCOL_DOMAIN,
            "metadata": {"taskId": "t1"}
        }));
        assert_eq!(parsed.type_url, ERROR_INFO_TYPE);

        let members = &parsed.value;
        assert_eq!(members["reason"], "TASK_NOT_FOUND");
        assert_eq!(members["domain"], PROTOCOL_DOMAIN);
        assert_eq!(members["metadata"]["taskId"], "t1");
    }

    #[test]
    fn test_deserialize_without_type_defaults_to_struct() {
        let parsed = from_json(serde_json::json!({"resource": "task"}));
        assert_eq!(parsed.type_url, STRUCT_TYPE);
        assert_eq!(parsed.value["resource"], "task");
    }

    #[test]
    fn test_round_trip() {
        let metadata = HashMap::from([("taskId".to_string(), "t1".to_string())]);
        let original = TypedDetail::error_info("TASK_NOT_FOUND", PROTOCOL_DOMAIN, Some(metadata));
        let restored = from_json(to_json(&original));
        assert_eq!(restored, original);
    }

    #[test]
    fn test_bad_request() {
        let detail = TypedDetail::bad_request(vec![
            FieldViolation {
                field: "message.parts".into(),
                description: "At least one part is required".into(),
            },
            FieldViolation {
                field: "message.role".into(),
                description: "Role must be 'user' or 'agent'".into(),
            },
        ]);
        assert_eq!(detail.type_url, BAD_REQUEST_TYPE);

        let json = to_json(&detail);
        assert_eq!(json["@type"], BAD_REQUEST_TYPE);

        let entries = json["fieldViolations"].as_array().unwrap();
        assert_eq!(entries.len(), 2);
        let (first, second) = (&entries[0], &entries[1]);
        assert_eq!(first["field"], "message.parts");
        assert_eq!(first["description"], "At least one part is required");
        assert_eq!(second["field"], "message.role");
    }

    #[test]
    fn test_bad_request_round_trip() {
        let original = TypedDetail::bad_request(vec![FieldViolation {
            field: "task.id".into(),
            description: "Must not be empty".into(),
        }]);
        let restored = from_json(to_json(&original));
        assert_eq!(restored, original);
    }

    #[test]
    fn test_from_struct() {
        let members = HashMap::from([("key".to_string(), Value::String("val".to_string()))]);
        let detail = TypedDetail::from_struct(members.clone());
        assert_eq!(detail.type_url, STRUCT_TYPE);
        assert_eq!(detail.value, members);
    }
}