a2a_protocol_server/dispatch/
cors.rs1use std::convert::Infallible;
14
15use bytes::Bytes;
16use http_body_util::combinators::BoxBody;
17use http_body_util::{BodyExt, Full};
18
/// Cross-Origin Resource Sharing (CORS) settings applied to HTTP responses.
///
/// Each string field is used as the value of the corresponding
/// `access-control-*` response header by [`CorsConfig::apply_headers`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CorsConfig {
    /// Value for `access-control-allow-origin` (e.g. `"*"` or `"https://app.example.com"`).
    pub allow_origin: String,
    /// Value for `access-control-allow-methods`, as a comma-separated list.
    pub allow_methods: String,
    /// Value for `access-control-allow-headers`, as a comma-separated list.
    pub allow_headers: String,
    /// Value for `access-control-max-age`, in seconds.
    pub max_age_secs: u32,
}
43
44impl CorsConfig {
45 #[must_use]
47 pub fn new(allow_origin: impl Into<String>) -> Self {
48 Self {
49 allow_origin: allow_origin.into(),
50 allow_methods: "GET, POST, PUT, DELETE, OPTIONS".into(),
51 allow_headers: "content-type, authorization, a2a-notification-token".into(),
52 max_age_secs: 86400,
53 }
54 }
55
56 #[must_use]
61 pub fn permissive() -> Self {
62 Self::new("*")
63 }
64
65 pub fn apply_headers<B>(&self, resp: &mut hyper::Response<B>) {
67 let headers = resp.headers_mut();
68 if let Ok(v) = self.allow_origin.parse() {
71 headers.insert("access-control-allow-origin", v);
72 }
73 if let Ok(v) = self.allow_methods.parse() {
74 headers.insert("access-control-allow-methods", v);
75 }
76 if let Ok(v) = self.allow_headers.parse() {
77 headers.insert("access-control-allow-headers", v);
78 }
79 if let Ok(v) = self.max_age_secs.to_string().parse() {
80 headers.insert("access-control-max-age", v);
81 }
82 }
83
84 #[must_use]
86 pub fn preflight_response(&self) -> hyper::Response<BoxBody<Bytes, Infallible>> {
87 let mut resp = hyper::Response::builder()
88 .status(204)
89 .body(Full::new(Bytes::new()).boxed())
90 .unwrap_or_else(|_| hyper::Response::new(Full::new(Bytes::new()).boxed()));
91 self.apply_headers(&mut resp);
92 resp
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty response to pass to `apply_headers`.
    fn empty_response() -> hyper::Response<BoxBody<Bytes, Infallible>> {
        hyper::Response::new(Full::new(Bytes::new()).boxed())
    }

    /// Returns header `name` as `&str`. Panics if it is missing or not visible ASCII.
    fn header<'a, B>(resp: &'a hyper::Response<B>, name: &str) -> &'a str {
        resp.headers()
            .get(name)
            .unwrap_or_else(|| panic!("missing header {}", name))
            .to_str()
            .unwrap_or_else(|_| panic!("header {} is not visible ASCII", name))
    }

    #[test]
    fn new_sets_origin_and_defaults() {
        let cors = CorsConfig::new("https://example.com");

        assert_eq!(cors.allow_origin, "https://example.com");
        assert_eq!(
            cors.allow_methods, "GET, POST, PUT, DELETE, OPTIONS",
            "default methods should include common HTTP verbs"
        );
        assert_eq!(
            cors.allow_headers,
            "content-type, authorization, a2a-notification-token",
            "default headers should include content-type, authorization, and a2a-notification-token"
        );
        assert_eq!(
            cors.max_age_secs, 86400,
            "default max-age should be 24 hours"
        );
    }

    #[test]
    fn new_accepts_string_and_str() {
        let from_str = CorsConfig::new("https://a.com");
        let from_string = CorsConfig::new(String::from("https://b.com"));

        assert_eq!(from_str.allow_origin, "https://a.com");
        assert_eq!(from_string.allow_origin, "https://b.com");
    }

    #[test]
    fn permissive_allows_all_origins() {
        let cors = CorsConfig::permissive();
        assert_eq!(
            cors.allow_origin, "*",
            "permissive config should use wildcard origin"
        );
    }

    #[test]
    fn apply_headers_sets_all_cors_headers() {
        let cors = CorsConfig::new("https://app.example.com");
        let mut resp = empty_response();
        cors.apply_headers(&mut resp);

        assert_eq!(
            header(&resp, "access-control-allow-origin"),
            "https://app.example.com"
        );
        assert_eq!(
            header(&resp, "access-control-allow-methods"),
            "GET, POST, PUT, DELETE, OPTIONS"
        );
        assert_eq!(
            header(&resp, "access-control-allow-headers"),
            "content-type, authorization, a2a-notification-token"
        );
        assert_eq!(header(&resp, "access-control-max-age"), "86400");
    }

    #[test]
    fn apply_headers_with_custom_config() {
        let mut cors = CorsConfig::new("https://custom.dev");
        cors.allow_methods = "POST, OPTIONS".into();
        cors.allow_headers = "content-type".into();
        cors.max_age_secs = 3600;

        let mut resp = empty_response();
        cors.apply_headers(&mut resp);

        assert_eq!(
            header(&resp, "access-control-allow-origin"),
            "https://custom.dev"
        );
        assert_eq!(
            header(&resp, "access-control-allow-methods"),
            "POST, OPTIONS",
            "custom methods should be applied"
        );
        assert_eq!(
            header(&resp, "access-control-allow-headers"),
            "content-type",
            "custom headers should be applied"
        );
        assert_eq!(
            header(&resp, "access-control-max-age"),
            "3600",
            "custom max-age should be applied"
        );
    }

    #[test]
    fn apply_headers_overwrites_existing_cors_headers() {
        let cors = CorsConfig::new("https://second.com");
        let mut resp = hyper::Response::builder()
            .header("access-control-allow-origin", "https://first.com")
            .body(Full::new(Bytes::new()).boxed())
            .unwrap();

        cors.apply_headers(&mut resp);

        assert_eq!(
            header(&resp, "access-control-allow-origin"),
            "https://second.com",
            "apply_headers should overwrite pre-existing CORS headers"
        );
        assert_eq!(
            resp.headers()
                .get_all("access-control-allow-origin")
                .iter()
                .count(),
            1,
            "overwriting must replace, not append, the existing value"
        );
    }

    #[test]
    fn apply_headers_preserves_unrelated_headers() {
        let mut resp = hyper::Response::builder()
            .header("content-type", "application/json")
            .body(Full::new(Bytes::new()).boxed())
            .unwrap();

        CorsConfig::permissive().apply_headers(&mut resp);

        assert_eq!(header(&resp, "content-type"), "application/json");
        assert_eq!(header(&resp, "access-control-allow-origin"), "*");
    }

    #[test]
    fn apply_headers_skips_values_that_are_not_valid_header_values() {
        let mut cors = CorsConfig::new("https://ok.test\r\nx-injected: 1");
        cors.allow_methods = "GET\n".into();

        let mut resp = empty_response();
        cors.apply_headers(&mut resp);

        let headers = resp.headers();
        assert!(
            headers.get("access-control-allow-origin").is_none(),
            "an origin containing CR/LF must not be emitted"
        );
        assert!(
            headers.get("access-control-allow-methods").is_none(),
            "methods containing LF must not be emitted"
        );
        assert!(headers.get("x-injected").is_none());
        assert_eq!(
            header(&resp, "access-control-allow-headers"),
            "content-type, authorization, a2a-notification-token",
            "valid headers must still be applied when others are skipped"
        );
        assert_eq!(header(&resp, "access-control-max-age"), "86400");
    }

    #[test]
    fn preflight_response_returns_204_no_content() {
        let cors = CorsConfig::permissive();
        let resp = cors.preflight_response();

        assert_eq!(
            resp.status().as_u16(),
            204,
            "preflight response should have 204 No Content status"
        );
    }

    #[test]
    fn preflight_response_includes_cors_headers() {
        let cors = CorsConfig::new("https://preflight.test");
        let resp = cors.preflight_response();

        assert_eq!(
            header(&resp, "access-control-allow-origin"),
            "https://preflight.test"
        );
        assert_eq!(
            header(&resp, "access-control-allow-methods"),
            "GET, POST, PUT, DELETE, OPTIONS",
            "preflight response must include allow-methods header"
        );
        assert_eq!(
            header(&resp, "access-control-allow-headers"),
            "content-type, authorization, a2a-notification-token",
            "preflight response must include allow-headers header"
        );
        assert_eq!(
            header(&resp, "access-control-max-age"),
            "86400",
            "preflight response must include max-age header"
        );
    }

    #[test]
    fn config_is_cloneable() {
        let cors = CorsConfig::new("https://clone.test");
        let cloned = cors.clone();
        assert_eq!(cors.allow_origin, cloned.allow_origin);
        assert_eq!(cors.allow_methods, cloned.allow_methods);
        assert_eq!(cors.allow_headers, cloned.allow_headers);
        assert_eq!(cors.max_age_secs, cloned.max_age_secs);
    }

    #[test]
    fn max_age_zero_is_valid() {
        let mut cors = CorsConfig::permissive();
        cors.max_age_secs = 0;

        let mut resp = empty_response();
        cors.apply_headers(&mut resp);

        assert_eq!(
            header(&resp, "access-control-max-age"),
            "0",
            "max-age of 0 should be set correctly"
        );
    }
}