tower_a2a/layer/
validation.rs

1//! Validation layer for A2A protocol requests and responses
2
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7};
8
9use tower_layer::Layer;
10use tower_service::Service;
11
12use crate::{
13    prelude::{MessagePart, TaskStatus},
14    protocol::{error::A2AError, operation::A2AOperation},
15    service::{A2ARequest, A2AResponse},
16};
17
/// Tower [`Layer`] that wraps a service in [`A2AValidationService`], so malformed A2A requests
/// are rejected before reaching the service and malformed responses before reaching the caller.
///
/// The layer carries no configuration of its own.
#[derive(Debug, Default, Clone)]
pub struct A2AValidationLayer;
21
22impl A2AValidationLayer {
23    /// Create a new validation layer
24    pub fn new() -> Self {
25        Self
26    }
27}
28
29impl<S> Layer<S> for A2AValidationLayer {
30    type Service = A2AValidationService<S>;
31
32    fn layer(&self, inner: S) -> Self::Service {
33        A2AValidationService { inner }
34    }
35}
36
/// Tower service produced by `A2AValidationLayer`.
///
/// Requests are validated before they are handed to `inner`, and responses produced by `inner`
/// are validated before they are returned to the caller.
#[derive(Clone, Debug)]
pub struct A2AValidationService<S> {
    /// The wrapped service that handles requests which pass validation.
    inner: S,
}
42
43impl<S> A2AValidationService<S> {
44    /// Validate an A2A request
45    fn validate_request(req: &A2ARequest) -> Result<(), A2AError> {
46        match &req.operation {
47            A2AOperation::SendMessage { message, .. } => {
48                // Message must have at least one part
49                if message.parts.is_empty() {
50                    return Err(A2AError::Validation(
51                        "Message must have at least one part".into(),
52                    ));
53                }
54
55                // Validate each part (basic checks)
56                for part in &message.parts {
57                    match part {
58                        MessagePart::Text { text } => {
59                            if text.is_empty() {
60                                return Err(A2AError::Validation(
61                                    "Text part cannot be empty".into(),
62                                ));
63                            }
64                        }
65                        MessagePart::File { file } => {
66                            if file.name.is_empty() {
67                                return Err(A2AError::Validation(
68                                    "File name cannot be empty".into(),
69                                ));
70                            }
71                            if file.file_with_uri.is_none() && file.file_with_bytes.is_none() {
72                                return Err(A2AError::Validation(
73                                    "File must have either URI or bytes content".into(),
74                                ));
75                            }
76                        }
77                        MessagePart::Data { .. } => {
78                            // Data validation could be more specific
79                        }
80                    }
81                }
82            }
83            A2AOperation::GetTask { task_id } => {
84                if task_id.is_empty() {
85                    return Err(A2AError::Validation("Task ID cannot be empty".into()));
86                }
87            }
88            A2AOperation::CancelTask { task_id } => {
89                if task_id.is_empty() {
90                    return Err(A2AError::Validation("Task ID cannot be empty".into()));
91                }
92            }
93            A2AOperation::ListTasks { limit, offset, .. } => {
94                if let Some(limit_val) = limit {
95                    if *limit_val == 0 {
96                        return Err(A2AError::Validation("Limit must be greater than 0".into()));
97                    }
98                    if *limit_val > 1000 {
99                        return Err(A2AError::Validation("Limit cannot exceed 1000".into()));
100                    }
101                }
102
103                if let Some(offset_val) = offset {
104                    if *offset_val > 1000000 {
105                        return Err(A2AError::Validation("Offset is too large".into()));
106                    }
107                }
108            }
109            A2AOperation::RegisterWebhook { url, events, .. } => {
110                if url.is_empty() {
111                    return Err(A2AError::Validation("Webhook URL cannot be empty".into()));
112                }
113                if events.is_empty() {
114                    return Err(A2AError::Validation(
115                        "Webhook must subscribe to at least one event".into(),
116                    ));
117                }
118            }
119            _ => {}
120        }
121
122        // Validate agent URL
123        if req.context.agent_url.is_empty() {
124            return Err(A2AError::Validation("Agent URL cannot be empty".into()));
125        }
126
127        Ok(())
128    }
129
130    /// Validate an A2A response
131    fn validate_response(resp: &A2AResponse) -> Result<(), A2AError> {
132        match resp {
133            A2AResponse::Task(task) => {
134                if task.id.is_empty() {
135                    return Err(A2AError::Validation("Task ID cannot be empty".into()));
136                }
137
138                // Validate task has input
139                if task.input.parts.is_empty() {
140                    return Err(A2AError::Validation(
141                        "Task input must have at least one part".into(),
142                    ));
143                }
144
145                // If task is completed, it should have artifacts or error
146                if task.status == TaskStatus::Completed
147                    && task.artifacts.is_empty()
148                    && task.error.is_none()
149                {
150                    return Err(A2AError::Validation(
151                        "Completed task must have artifacts or error".into(),
152                    ));
153                }
154
155                // If task is failed, it should have an error
156                if task.status == TaskStatus::Failed && task.error.is_none() {
157                    return Err(A2AError::Validation(
158                        "Failed task must have an error".into(),
159                    ));
160                }
161            }
162            A2AResponse::AgentCard(card) => {
163                if card.name.is_empty() {
164                    return Err(A2AError::Validation("Agent name cannot be empty".into()));
165                }
166                if card.endpoints.is_empty() {
167                    return Err(A2AError::Validation(
168                        "Agent card must have at least one endpoint".into(),
169                    ));
170                }
171            }
172            _ => {}
173        }
174
175        Ok(())
176    }
177}
178
179impl<S> Service<A2ARequest> for A2AValidationService<S>
180where
181    S: Service<A2ARequest, Response = A2AResponse, Error = A2AError> + Clone + Send + 'static,
182    S::Future: Send,
183{
184    type Response = A2AResponse;
185    type Error = A2AError;
186    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
187
188    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
189        self.inner.poll_ready(cx)
190    }
191
192    fn call(&mut self, req: A2ARequest) -> Self::Future {
193        // Validate request before passing to inner service
194        if let Err(e) = Self::validate_request(&req) {
195            return Box::pin(async move { Err(e) });
196        }
197
198        let mut inner = self.inner.clone();
199        Box::pin(async move {
200            let response = inner.call(req).await?;
201
202            // Validate response
203            Self::validate_response(&response)?;
204
205            Ok(response)
206        })
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        protocol::{message::Message, task::Task},
        service::RequestContext,
    };

    /// Build a non-streaming `SendMessage` request for `message`, addressed to a fixed agent URL.
    fn send_message_request(message: Message) -> A2ARequest {
        A2ARequest::new(
            A2AOperation::SendMessage {
                message,
                stream: false,
                context_id: None,
                task_id: None,
            },
            RequestContext::new("https://example.com"),
        )
    }

    /// Run the request validator; the inner service type is irrelevant, so `()` is used.
    fn check_request(request: &A2ARequest) -> Result<(), A2AError> {
        A2AValidationService::<()>::validate_request(request)
    }

    #[test]
    fn test_validate_send_message() {
        let request = send_message_request(Message::user("Hello"));
        assert!(check_request(&request).is_ok());
    }

    #[test]
    fn test_validate_empty_message() {
        let mut message = Message::user("Test");
        message.parts.clear();
        assert!(check_request(&send_message_request(message)).is_err());
    }

    #[test]
    fn test_validate_task_response() {
        let response = A2AResponse::Task(Box::new(Task::new("task-123", Message::user("Test"))));
        assert!(A2AValidationService::<()>::validate_response(&response).is_ok());
    }
}