tower_a2a/layer/
validation.rs1use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use tower_layer::Layer;
10use tower_service::Service;
11
12use crate::{
13 prelude::{MessagePart, TaskStatus},
14 protocol::{error::A2AError, operation::A2AOperation},
15 service::{A2ARequest, A2AResponse},
16};
17
/// Tower layer that wraps a service with A2A request/response validation.
///
/// The layer carries no configuration: every service it produces applies the
/// same rule set.
#[derive(Clone, Debug, Default)]
pub struct A2AValidationLayer;

impl A2AValidationLayer {
    /// Creates the layer; equivalent to [`A2AValidationLayer::default`].
    pub fn new() -> Self {
        A2AValidationLayer
    }
}
28
29impl<S> Layer<S> for A2AValidationLayer {
30 type Service = A2AValidationService<S>;
31
32 fn layer(&self, inner: S) -> Self::Service {
33 A2AValidationService { inner }
34 }
35}
36
/// Tower service produced by `A2AValidationLayer`.
///
/// Each incoming request is validated before it is handed to the wrapped
/// service, and each response from the wrapped service is validated before it
/// is returned to the caller.
#[derive(Clone, Debug)]
pub struct A2AValidationService<S> {
    /// The wrapped service that handles requests which pass validation.
    inner: S,
}
42
43impl<S> A2AValidationService<S> {
44 fn validate_request(req: &A2ARequest) -> Result<(), A2AError> {
46 match &req.operation {
47 A2AOperation::SendMessage { message, .. } => {
48 if message.parts.is_empty() {
50 return Err(A2AError::Validation(
51 "Message must have at least one part".into(),
52 ));
53 }
54
55 for part in &message.parts {
57 match part {
58 MessagePart::Text { text } => {
59 if text.is_empty() {
60 return Err(A2AError::Validation(
61 "Text part cannot be empty".into(),
62 ));
63 }
64 }
65 MessagePart::File { file } => {
66 if file.name.is_empty() {
67 return Err(A2AError::Validation(
68 "File name cannot be empty".into(),
69 ));
70 }
71 if file.file_with_uri.is_none() && file.file_with_bytes.is_none() {
72 return Err(A2AError::Validation(
73 "File must have either URI or bytes content".into(),
74 ));
75 }
76 }
77 MessagePart::Data { .. } => {
78 }
80 }
81 }
82 }
83 A2AOperation::GetTask { task_id } => {
84 if task_id.is_empty() {
85 return Err(A2AError::Validation("Task ID cannot be empty".into()));
86 }
87 }
88 A2AOperation::CancelTask { task_id } => {
89 if task_id.is_empty() {
90 return Err(A2AError::Validation("Task ID cannot be empty".into()));
91 }
92 }
93 A2AOperation::ListTasks { limit, offset, .. } => {
94 if let Some(limit_val) = limit {
95 if *limit_val == 0 {
96 return Err(A2AError::Validation("Limit must be greater than 0".into()));
97 }
98 if *limit_val > 1000 {
99 return Err(A2AError::Validation("Limit cannot exceed 1000".into()));
100 }
101 }
102
103 if let Some(offset_val) = offset {
104 if *offset_val > 1000000 {
105 return Err(A2AError::Validation("Offset is too large".into()));
106 }
107 }
108 }
109 A2AOperation::RegisterWebhook { url, events, .. } => {
110 if url.is_empty() {
111 return Err(A2AError::Validation("Webhook URL cannot be empty".into()));
112 }
113 if events.is_empty() {
114 return Err(A2AError::Validation(
115 "Webhook must subscribe to at least one event".into(),
116 ));
117 }
118 }
119 _ => {}
120 }
121
122 if req.context.agent_url.is_empty() {
124 return Err(A2AError::Validation("Agent URL cannot be empty".into()));
125 }
126
127 Ok(())
128 }
129
130 fn validate_response(resp: &A2AResponse) -> Result<(), A2AError> {
132 match resp {
133 A2AResponse::Task(task) => {
134 if task.id.is_empty() {
135 return Err(A2AError::Validation("Task ID cannot be empty".into()));
136 }
137
138 if task.input.parts.is_empty() {
140 return Err(A2AError::Validation(
141 "Task input must have at least one part".into(),
142 ));
143 }
144
145 if task.status == TaskStatus::Completed
147 && task.artifacts.is_empty()
148 && task.error.is_none()
149 {
150 return Err(A2AError::Validation(
151 "Completed task must have artifacts or error".into(),
152 ));
153 }
154
155 if task.status == TaskStatus::Failed && task.error.is_none() {
157 return Err(A2AError::Validation(
158 "Failed task must have an error".into(),
159 ));
160 }
161 }
162 A2AResponse::AgentCard(card) => {
163 if card.name.is_empty() {
164 return Err(A2AError::Validation("Agent name cannot be empty".into()));
165 }
166 if card.endpoints.is_empty() {
167 return Err(A2AError::Validation(
168 "Agent card must have at least one endpoint".into(),
169 ));
170 }
171 }
172 _ => {}
173 }
174
175 Ok(())
176 }
177}
178
179impl<S> Service<A2ARequest> for A2AValidationService<S>
180where
181 S: Service<A2ARequest, Response = A2AResponse, Error = A2AError> + Clone + Send + 'static,
182 S::Future: Send,
183{
184 type Response = A2AResponse;
185 type Error = A2AError;
186 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
187
188 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189 self.inner.poll_ready(cx)
190 }
191
192 fn call(&mut self, req: A2ARequest) -> Self::Future {
193 if let Err(e) = Self::validate_request(&req) {
195 return Box::pin(async move { Err(e) });
196 }
197
198 let mut inner = self.inner.clone();
199 Box::pin(async move {
200 let response = inner.call(req).await?;
201
202 Self::validate_response(&response)?;
204
205 Ok(response)
206 })
207 }
208}
209
/// Unit tests for the request/response validation rules.
#[cfg(test)]
mod tests {
    use crate::{
        protocol::{message::Message, task::Task},
        service::RequestContext,
    };

    use super::*;

    #[test]
    fn test_validate_send_message() {
        let operation = A2AOperation::SendMessage {
            message: Message::user("Hello"),
            stream: false,
            context_id: None,
            task_id: None,
        };

        let context = RequestContext::new("https://example.com");
        let request = A2ARequest::new(operation, context);

        assert!(A2AValidationService::<()>::validate_request(&request).is_ok());
    }

    #[test]
    fn test_validate_empty_message() {
        let mut message = Message::user("Test");
        message.parts.clear();

        let operation = A2AOperation::SendMessage {
            message,
            stream: false,
            context_id: None,
            task_id: None,
        };

        let context = RequestContext::new("https://example.com");
        let request = A2ARequest::new(operation, context);

        assert!(A2AValidationService::<()>::validate_request(&request).is_err());
    }

    #[test]
    fn test_validate_empty_text_part() {
        // NOTE(review): assumes `Message::user` produces a single text part.
        let operation = A2AOperation::SendMessage {
            message: Message::user(""),
            stream: false,
            context_id: None,
            task_id: None,
        };

        let context = RequestContext::new("https://example.com");
        let request = A2ARequest::new(operation, context);

        assert!(A2AValidationService::<()>::validate_request(&request).is_err());
    }

    #[test]
    fn test_validate_empty_agent_url() {
        let operation = A2AOperation::SendMessage {
            message: Message::user("Hello"),
            stream: false,
            context_id: None,
            task_id: None,
        };

        let context = RequestContext::new("");
        let request = A2ARequest::new(operation, context);

        assert!(A2AValidationService::<()>::validate_request(&request).is_err());
    }

    #[test]
    fn test_validate_empty_task_id() {
        let operation = A2AOperation::GetTask { task_id: "".into() };

        let context = RequestContext::new("https://example.com");
        let request = A2ARequest::new(operation, context);

        assert!(A2AValidationService::<()>::validate_request(&request).is_err());
    }

    #[test]
    fn test_validate_task_response() {
        let task = Task::new("task-123", Message::user("Test"));
        let response = A2AResponse::Task(Box::new(task));

        assert!(A2AValidationService::<()>::validate_response(&response).is_ok());
    }

    #[test]
    fn test_validate_failed_task_without_error() {
        let mut task = Task::new("task-123", Message::user("Test"));
        task.status = TaskStatus::Failed;
        task.error = None;
        let response = A2AResponse::Task(Box::new(task));

        assert!(A2AValidationService::<()>::validate_response(&response).is_err());
    }
}