Skip to main content

a2a_rs/adapter/business/
request_processor.rs

1//! A default request processor implementation
2
3// This module is already conditionally compiled with #[cfg(feature = "server")] in mod.rs
4
5use std::sync::Arc;
6
7use async_trait::async_trait;
8
9use crate::{
10    application::{
11        JSONRPCError, JSONRPCResponse,
12        json_rpc::{
13            self, A2ARequest, CancelTaskRequest, GetExtendedCardRequest,
14            GetTaskPushNotificationRequest, GetTaskRequest, SendTaskRequest,
15            SendTaskStreamingRequest, SetTaskPushNotificationRequest, TaskResubscriptionRequest,
16        },
17    },
18    domain::A2AError,
19    port::{AsyncMessageHandler, AsyncNotificationManager, AsyncTaskManager},
20    services::server::{AgentInfoProvider, AsyncA2ARequestProcessor},
21};
22
23/// Default implementation of a request processor that routes requests to business handlers
24#[derive(Clone)]
25pub struct DefaultRequestProcessor<M, T, N, A = crate::adapter::SimpleAgentInfo>
26where
27    M: AsyncMessageHandler + Send + Sync + 'static,
28    T: AsyncTaskManager + Send + Sync + 'static,
29    N: AsyncNotificationManager + Send + Sync + 'static,
30    A: AgentInfoProvider + Send + Sync + 'static,
31{
32    /// Message handler
33    message_handler: Arc<M>,
34    /// Task manager
35    task_manager: Arc<T>,
36    /// Notification manager
37    notification_manager: Arc<N>,
38    /// Agent info provider
39    agent_info: Arc<A>,
40}
41
42impl<M, T, N, A> DefaultRequestProcessor<M, T, N, A>
43where
44    M: AsyncMessageHandler + Send + Sync + 'static,
45    T: AsyncTaskManager + Send + Sync + 'static,
46    N: AsyncNotificationManager + Send + Sync + 'static,
47    A: AgentInfoProvider + Send + Sync + 'static,
48{
49    /// Create a new request processor with the given handlers
50    pub fn new(
51        message_handler: M,
52        task_manager: T,
53        notification_manager: N,
54        agent_info: A,
55    ) -> Self {
56        Self {
57            message_handler: Arc::new(message_handler),
58            task_manager: Arc::new(task_manager),
59            notification_manager: Arc::new(notification_manager),
60            agent_info: Arc::new(agent_info),
61        }
62    }
63}
64
65impl<H, A> DefaultRequestProcessor<H, H, H, A>
66where
67    H: AsyncMessageHandler + AsyncTaskManager + AsyncNotificationManager + Send + Sync + 'static,
68    A: AgentInfoProvider + Send + Sync + 'static,
69{
70    /// Create a new request processor with a single handler that implements all traits
71    pub fn with_handler(handler: H, agent_info: A) -> Self {
72        let handler_arc = Arc::new(handler);
73        Self {
74            message_handler: handler_arc.clone(),
75            task_manager: handler_arc.clone(),
76            notification_manager: handler_arc,
77            agent_info: Arc::new(agent_info),
78        }
79    }
80}
81
82impl<M, T, N, A> DefaultRequestProcessor<M, T, N, A>
83where
84    M: AsyncMessageHandler + Send + Sync + 'static,
85    T: AsyncTaskManager + Send + Sync + 'static,
86    N: AsyncNotificationManager + Send + Sync + 'static,
87    A: AgentInfoProvider + Send + Sync + 'static,
88{
89    /// Process a send task request
90    async fn process_send_task(
91        &self,
92        request: &SendTaskRequest,
93    ) -> Result<JSONRPCResponse, A2AError> {
94        let params = &request.params;
95        let session_id = params.session_id.as_deref();
96
97        tracing::info!(
98            task_id = %params.id,
99            message_id = %params.message.message_id,
100            "🔄 DefaultRequestProcessor: About to call message_handler.process_message"
101        );
102
103        // Process the message through the handler
104        // The handler is responsible for managing history
105        let task = self
106            .message_handler
107            .process_message(&params.id, &params.message, session_id)
108            .await?;
109
110        tracing::info!(
111            task_id = %params.id,
112            "✅ DefaultRequestProcessor: Message handler returned successfully"
113        );
114
115        Ok(JSONRPCResponse::success(
116            request.id.clone(),
117            serde_json::to_value(task)?,
118        ))
119    }
120
121    /// Process a get task request
122    async fn process_get_task(
123        &self,
124        request: &GetTaskRequest,
125    ) -> Result<JSONRPCResponse, A2AError> {
126        let params = &request.params;
127        let task = self
128            .task_manager
129            .get_task(&params.id, params.history_length)
130            .await?;
131
132        Ok(JSONRPCResponse::success(
133            request.id.clone(),
134            serde_json::to_value(task)?,
135        ))
136    }
137
138    /// Process a cancel task request
139    async fn process_cancel_task(
140        &self,
141        request: &CancelTaskRequest,
142    ) -> Result<JSONRPCResponse, A2AError> {
143        let params = &request.params;
144        let task = self.task_manager.cancel_task(&params.id).await?;
145
146        Ok(JSONRPCResponse::success(
147            request.id.clone(),
148            serde_json::to_value(task)?,
149        ))
150    }
151
152    /// Process a set task push notification request
153    async fn process_set_push_notification(
154        &self,
155        request: &SetTaskPushNotificationRequest,
156    ) -> Result<JSONRPCResponse, A2AError> {
157        let config = self
158            .notification_manager
159            .set_task_notification(&request.params)
160            .await?;
161
162        Ok(JSONRPCResponse::success(
163            request.id.clone(),
164            serde_json::to_value(config)?,
165        ))
166    }
167
168    /// Process a get task push notification request
169    async fn process_get_push_notification(
170        &self,
171        request: &GetTaskPushNotificationRequest,
172    ) -> Result<JSONRPCResponse, A2AError> {
173        let params = &request.params;
174        let config = self
175            .notification_manager
176            .get_task_notification(&params.id)
177            .await?;
178
179        Ok(JSONRPCResponse::success(
180            request.id.clone(),
181            serde_json::to_value(config)?,
182        ))
183    }
184
185    /// Process a task resubscription request
186    async fn process_task_resubscription(
187        &self,
188        request: &TaskResubscriptionRequest,
189    ) -> Result<JSONRPCResponse, A2AError> {
190        // For resubscription, we return an initial success response,
191        // and then the streaming updates are handled separately
192        let params = &request.params;
193
194        // Try to get the task, but don't fail if it doesn't exist
195        // This allows clients to subscribe to tasks before they're created
196        match self
197            .task_manager
198            .get_task(&params.id, params.history_length)
199            .await
200        {
201            Ok(task) => {
202                // Task exists, return it
203                Ok(JSONRPCResponse::success(
204                    request.id.clone(),
205                    serde_json::to_value(task)?,
206                ))
207            }
208            Err(A2AError::TaskNotFound(_)) => {
209                // Task doesn't exist yet, return null result
210                // The WebSocket server will still set up subscriptions
211                // and send updates when the task is created
212                Ok(JSONRPCResponse::success(
213                    request.id.clone(),
214                    serde_json::Value::Null,
215                ))
216            }
217            Err(e) => {
218                // Other errors should still be propagated
219                Err(e)
220            }
221        }
222    }
223
224    /// Process a send task streaming request
225    async fn process_send_task_streaming(
226        &self,
227        request: &SendTaskStreamingRequest,
228    ) -> Result<JSONRPCResponse, A2AError> {
229        // For streaming, we process the message and return an initial success response,
230        // and then the streaming updates are handled separately
231        let params = &request.params;
232        let session_id = params.session_id.as_deref();
233
234        // Process the message through the handler
235        // The handler is responsible for managing history
236        let task = self
237            .message_handler
238            .process_message(&params.id, &params.message, session_id)
239            .await?;
240
241        Ok(JSONRPCResponse::success(
242            request.id.clone(),
243            serde_json::to_value(task)?,
244        ))
245    }
246
247    /// Process a get extended card request (v0.3.0)
248    async fn process_get_extended_card(
249        &self,
250        request: &GetExtendedCardRequest,
251    ) -> Result<JSONRPCResponse, A2AError> {
252        // Get the agent card from the agent info provider
253        // For v0.3.0, this method should return extended information
254        // that may only be available to authenticated clients.
255        // Authentication checking should be handled by middleware.
256        let card = self.agent_info.get_agent_card().await?;
257
258        Ok(JSONRPCResponse::success(
259            request.id.clone(),
260            serde_json::to_value(card)?,
261        ))
262    }
263
264    // ===== v0.3.0 New Methods =====
265
266    async fn process_list_tasks(
267        &self,
268        request: &crate::application::handlers::task::ListTasksRequest,
269    ) -> Result<JSONRPCResponse, A2AError> {
270        let default_params = crate::domain::ListTasksParams::default();
271        let params = request.params.as_ref().unwrap_or(&default_params);
272        let result = self.task_manager.list_tasks_v3(params).await?;
273
274        Ok(JSONRPCResponse::success(
275            request.id.clone(),
276            serde_json::to_value(result)?,
277        ))
278    }
279
280    async fn process_get_push_notification_config(
281        &self,
282        request: &crate::application::handlers::task::GetTaskPushNotificationConfigRequest,
283    ) -> Result<JSONRPCResponse, A2AError> {
284        if let Some(ref params) = request.params {
285            let result = self
286                .task_manager
287                .get_push_notification_config(params)
288                .await?;
289
290            Ok(JSONRPCResponse::success(
291                request.id.clone(),
292                serde_json::to_value(result)?,
293            ))
294        } else {
295            Err(A2AError::InvalidParams(
296                "Missing params for get push notification config".to_string(),
297            ))
298        }
299    }
300
301    async fn process_list_push_notification_configs(
302        &self,
303        request: &crate::application::handlers::task::ListTaskPushNotificationConfigRequest,
304    ) -> Result<JSONRPCResponse, A2AError> {
305        let result = self
306            .task_manager
307            .list_push_notification_configs(&request.params)
308            .await?;
309
310        Ok(JSONRPCResponse::success(
311            request.id.clone(),
312            serde_json::to_value(result)?,
313        ))
314    }
315
316    async fn process_delete_push_notification_config(
317        &self,
318        request: &crate::application::handlers::task::DeleteTaskPushNotificationConfigRequest,
319    ) -> Result<JSONRPCResponse, A2AError> {
320        self.task_manager
321            .delete_push_notification_config(&request.params)
322            .await?;
323
324        // Return null on success
325        Ok(JSONRPCResponse::success(
326            request.id.clone(),
327            serde_json::Value::Null,
328        ))
329    }
330
331    async fn process_get_authenticated_extended_card(
332        &self,
333        request: &crate::application::handlers::agent::GetAuthenticatedExtendedCardRequest,
334    ) -> Result<JSONRPCResponse, A2AError> {
335        // Get the authenticated extended card from the agent info provider
336        // Authentication checking should be handled by middleware before this point
337        let card = self.agent_info.get_authenticated_extended_card().await?;
338
339        Ok(JSONRPCResponse::success(
340            request.id.clone(),
341            serde_json::to_value(card)?,
342        ))
343    }
344}
345
346#[async_trait]
347impl<M, T, N, A> AsyncA2ARequestProcessor for DefaultRequestProcessor<M, T, N, A>
348where
349    M: AsyncMessageHandler + Send + Sync + 'static,
350    T: AsyncTaskManager + Send + Sync + 'static,
351    N: AsyncNotificationManager + Send + Sync + 'static,
352    A: AgentInfoProvider + Send + Sync + 'static,
353{
354    async fn process_raw_request(&self, request: &str) -> Result<String, A2AError> {
355        // Parse the request
356        let request = match json_rpc::parse_request(request) {
357            Ok(req) => req,
358            Err(e) => {
359                // Return a JSON-RPC error response
360                let error = JSONRPCError::from(e);
361                let response = JSONRPCResponse::error(None, error);
362                return Ok(serde_json::to_string(&response)?);
363            }
364        };
365
366        // Process the request
367        let response = match self.process_request(&request).await {
368            Ok(resp) => resp,
369            Err(e) => {
370                // Return a JSON-RPC error response
371                let error = JSONRPCError::from(e);
372                let response = JSONRPCResponse::error(request.id().cloned(), error);
373                return Ok(serde_json::to_string(&response)?);
374            }
375        };
376
377        // Serialize the response
378        Ok(serde_json::to_string(&response)?)
379    }
380
381    async fn process_request(&self, request: &A2ARequest) -> Result<JSONRPCResponse, A2AError> {
382        match request {
383            A2ARequest::SendTask(req) => self.process_send_task(req).await,
384            A2ARequest::SendMessage(_req) => {
385                // Convert MessageSendParams to TaskSendParams for backwards compatibility
386                // TODO: Implement proper message handling
387                Err(A2AError::UnsupportedOperation(
388                    "Message sending not yet implemented".to_string(),
389                ))
390            }
391            A2ARequest::GetTask(req) => self.process_get_task(req).await,
392            A2ARequest::CancelTask(req) => self.process_cancel_task(req).await,
393            A2ARequest::SetTaskPushNotification(req) => {
394                self.process_set_push_notification(req).await
395            }
396            A2ARequest::GetTaskPushNotification(req) => {
397                self.process_get_push_notification(req).await
398            }
399            A2ARequest::TaskResubscription(req) => self.process_task_resubscription(req).await,
400            A2ARequest::SendTaskStreaming(req) => self.process_send_task_streaming(req).await,
401            A2ARequest::SendMessageStreaming(_req) => {
402                // Convert MessageSendParams to TaskSendParams for backwards compatibility
403                // TODO: Implement proper message streaming
404                Err(A2AError::UnsupportedOperation(
405                    "Message streaming not yet implemented".to_string(),
406                ))
407            }
408            A2ARequest::GetExtendedCard(req) => self.process_get_extended_card(req).await,
409            // v0.3.0 new methods
410            A2ARequest::ListTasks(req) => self.process_list_tasks(req).await,
411            A2ARequest::GetTaskPushNotificationConfig(req) => {
412                self.process_get_push_notification_config(req).await
413            }
414            A2ARequest::ListTaskPushNotificationConfigs(req) => {
415                self.process_list_push_notification_configs(req).await
416            }
417            A2ARequest::DeleteTaskPushNotificationConfig(req) => {
418                self.process_delete_push_notification_config(req).await
419            }
420            A2ARequest::GetAuthenticatedExtendedCard(req) => {
421                self.process_get_authenticated_extended_card(req).await
422            }
423            A2ARequest::Generic(req) => {
424                // Handle unknown method
425                Err(A2AError::MethodNotFound(format!(
426                    "Method '{}' not found",
427                    req.method
428                )))
429            }
430        }
431    }
432}