Skip to main content

a2a_rust/types/
auth.rs

1use serde::{Deserialize, Serialize};
2use serde_json::Value;
3
4use crate::A2AError;
5
6use super::{JsonObject, Message, Task, TaskState, TaskStatus};
7
8/// Conventional metadata payload used when a task enters `TASK_STATE_AUTH_REQUIRED`.
9#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
10#[serde(rename_all = "camelCase")]
11pub struct AuthRequiredMetadata {
12    /// Authorization URL the user should visit.
13    pub auth_url: String,
14    /// Authentication scheme, such as `oauth2` or `apiKey`.
15    pub auth_scheme: String,
16    /// Scopes requested by the agent.
17    #[serde(default, skip_serializing_if = "Vec::is_empty")]
18    pub scopes: Vec<String>,
19    /// Human-readable explanation for the authorization request.
20    pub description: String,
21}
22
23impl AuthRequiredMetadata {
24    /// Parse the convention from a metadata object.
25    pub fn from_metadata(metadata: &JsonObject) -> Result<Self, A2AError> {
26        serde_json::from_value(Value::Object(metadata.clone())).map_err(A2AError::from)
27    }
28
29    /// Convert the convention into a message metadata object.
30    pub fn into_metadata(self) -> Result<JsonObject, A2AError> {
31        match serde_json::to_value(self)? {
32            Value::Object(object) => Ok(object),
33            _ => Err(A2AError::Internal(
34                "auth-required metadata did not serialize to an object".to_owned(),
35            )),
36        }
37    }
38}
39
40impl Message {
41    /// Parse `TASK_STATE_AUTH_REQUIRED` metadata from this message, if present.
42    ///
43    /// This helper is intended for messages already known to participate in the
44    /// auth-required flow. If `metadata` exists but does not match the
45    /// `AuthRequiredMetadata` schema, this returns `Err` rather than `Ok(None)`.
46    pub fn auth_required_metadata(&self) -> Result<Option<AuthRequiredMetadata>, A2AError> {
47        self.metadata
48            .as_ref()
49            .map(AuthRequiredMetadata::from_metadata)
50            .transpose()
51    }
52
53    /// Replace this message's metadata with the auth-required convention payload.
54    pub fn set_auth_required_metadata(
55        &mut self,
56        metadata: AuthRequiredMetadata,
57    ) -> Result<(), A2AError> {
58        self.metadata = Some(metadata.into_metadata()?);
59        Ok(())
60    }
61}
62
63impl TaskStatus {
64    /// Parse auth-required metadata from the current status message when present.
65    ///
66    /// If the nested status message carries unrelated metadata, this returns
67    /// the underlying parse error instead of `Ok(None)`.
68    pub fn auth_required_metadata(&self) -> Result<Option<AuthRequiredMetadata>, A2AError> {
69        self.message
70            .as_ref()
71            .map(Message::auth_required_metadata)
72            .transpose()
73            .map(|metadata| metadata.flatten())
74    }
75
76    /// Validate that `TASK_STATE_AUTH_REQUIRED` carries the expected metadata convention.
77    pub fn validate_auth_required_metadata(&self) -> Result<(), A2AError> {
78        if self.state != TaskState::AuthRequired {
79            return Ok(());
80        }
81
82        let Some(message) = &self.message else {
83            return Err(A2AError::InvalidRequest(
84                "TASK_STATE_AUTH_REQUIRED requires a status message carrying auth metadata"
85                    .to_owned(),
86            ));
87        };
88
89        if message.auth_required_metadata()?.is_none() {
90            return Err(A2AError::InvalidRequest(
91                "TASK_STATE_AUTH_REQUIRED status message metadata must include authUrl, authScheme, scopes, and description"
92                    .to_owned(),
93            ));
94        }
95
96        Ok(())
97    }
98}
99
100impl Task {
101    /// Return auth-required metadata from the current status message or last history item.
102    ///
103    /// This returns `Ok(None)` when the task is not in `TASK_STATE_AUTH_REQUIRED`.
104    /// When the task is in that state, unrelated metadata on the candidate
105    /// message is treated as an error so callers can distinguish malformed
106    /// auth-required payloads from the absence of auth metadata.
107    pub fn auth_required_metadata(&self) -> Result<Option<AuthRequiredMetadata>, A2AError> {
108        if self.status.state != TaskState::AuthRequired {
109            return Ok(None);
110        }
111
112        if let Some(metadata) = self.status.auth_required_metadata()? {
113            return Ok(Some(metadata));
114        }
115
116        self.history
117            .last()
118            .map(Message::auth_required_metadata)
119            .transpose()
120            .map(|metadata| metadata.flatten())
121    }
122
123    /// Validate the repository's `TASK_STATE_AUTH_REQUIRED` metadata convention.
124    pub fn validate_auth_required_convention(&self) -> Result<(), A2AError> {
125        if self.status.state != TaskState::AuthRequired {
126            return Ok(());
127        }
128
129        if self.auth_required_metadata()?.is_none() {
130            return Err(A2AError::InvalidRequest(
131                "TASK_STATE_AUTH_REQUIRED requires auth metadata on the status message or last task message"
132                    .to_owned(),
133            ));
134        }
135
136        Ok(())
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::AuthRequiredMetadata;
    use crate::types::{Message, Part, Role, Task, TaskState, TaskStatus};
    use serde_json::{Map, Value};

    /// Build an agent message for `task-1`/`ctx-1` with one text part and no metadata.
    fn agent_message(text: &str) -> Message {
        Message {
            message_id: "msg-auth-1".to_owned(),
            context_id: Some("ctx-1".to_owned()),
            task_id: Some("task-1".to_owned()),
            role: Role::Agent,
            parts: vec![Part {
                text: Some(text.to_owned()),
                raw: None,
                url: None,
                data: None,
                metadata: None,
                filename: None,
                media_type: None,
            }],
            metadata: None,
            extensions: Vec::new(),
            reference_task_ids: Vec::new(),
        }
    }

    /// Build an OAuth2 auth-required payload requesting a single scope.
    fn oauth_metadata(scope: &str, description: &str) -> AuthRequiredMetadata {
        AuthRequiredMetadata {
            auth_url: "https://example.com/oauth/authorize".to_owned(),
            auth_scheme: "oauth2".to_owned(),
            scopes: vec![scope.to_owned()],
            description: description.to_owned(),
        }
    }

    /// Build `task-1` in `state` with the given status message and history.
    fn task_with(state: TaskState, message: Option<Message>, history: Vec<Message>) -> Task {
        Task {
            id: "task-1".to_owned(),
            context_id: Some("ctx-1".to_owned()),
            status: TaskStatus {
                state,
                message,
                timestamp: Some("2026-03-13T12:00:00Z".to_owned()),
            },
            artifacts: Vec::new(),
            history,
            metadata: None,
        }
    }

    #[test]
    fn auth_required_metadata_round_trips_through_message_metadata() {
        let expected = oauth_metadata("calendar.read", "Grant calendar access");
        let mut message = agent_message("Please authorize access.");
        message
            .set_auth_required_metadata(expected.clone())
            .expect("metadata should set");

        let metadata = message
            .auth_required_metadata()
            .expect("metadata should parse")
            .expect("metadata should exist");

        assert_eq!(metadata.auth_scheme, "oauth2");
        assert_eq!(metadata.scopes, vec!["calendar.read"]);
        assert_eq!(metadata, expected);
    }

    #[test]
    fn scopes_are_optional_and_omitted_when_empty() {
        let mut object = Map::new();
        object.insert(
            "authUrl".to_owned(),
            Value::from("https://example.com/oauth/authorize"),
        );
        object.insert("authScheme".to_owned(), Value::from("oauth2"));
        object.insert("description".to_owned(), Value::from("Grant access"));

        let metadata =
            AuthRequiredMetadata::from_metadata(&object).expect("scopes should be optional");
        assert!(metadata.scopes.is_empty());

        let serialized = metadata.into_metadata().expect("metadata should serialize");
        assert!(!serialized.contains_key("scopes"));
        assert_eq!(serialized, object);
    }

    #[test]
    fn task_validates_auth_required_convention() {
        let mut message = agent_message("Authorize to continue.");
        message
            .set_auth_required_metadata(oauth_metadata("drive.readonly", "Grant drive access"))
            .expect("metadata should set");

        let task = task_with(TaskState::AuthRequired, Some(message), Vec::new());

        task.validate_auth_required_convention()
            .expect("convention should validate");
    }

    #[test]
    fn task_falls_back_to_last_history_message() {
        let mut message = agent_message("Authorize to continue.");
        message
            .set_auth_required_metadata(oauth_metadata("drive.readonly", "Grant drive access"))
            .expect("metadata should set");

        let task = task_with(TaskState::AuthRequired, None, vec![message]);

        let metadata = task
            .auth_required_metadata()
            .expect("metadata should parse")
            .expect("metadata should come from history");
        assert_eq!(metadata.scopes, vec!["drive.readonly"]);
        task.validate_auth_required_convention()
            .expect("convention should validate from history");
    }

    #[test]
    fn task_rejects_auth_required_without_metadata() {
        let task = task_with(
            TaskState::AuthRequired,
            Some(agent_message("Authorize to continue.")),
            Vec::new(),
        );

        let error = task
            .validate_auth_required_convention()
            .expect_err("convention should fail");
        assert!(error.to_string().contains("TASK_STATE_AUTH_REQUIRED"));
    }

    #[test]
    fn unrelated_metadata_is_reported_as_error() {
        let mut message = agent_message("Authorize to continue.");
        let mut unrelated = Map::new();
        unrelated.insert("traceId".to_owned(), Value::from("abc123"));
        message.metadata = Some(unrelated);

        assert!(message.auth_required_metadata().is_err());

        let task = task_with(TaskState::AuthRequired, Some(message), Vec::new());
        assert!(task.auth_required_metadata().is_err());
        assert!(task.validate_auth_required_convention().is_err());
    }

    #[test]
    fn status_validation_requires_status_message() {
        let status = TaskStatus {
            state: TaskState::AuthRequired,
            message: None,
            timestamp: None,
        };

        let error = status
            .validate_auth_required_metadata()
            .expect_err("missing status message should fail");
        assert!(error.to_string().contains("TASK_STATE_AUTH_REQUIRED"));
    }
}