Skip to main content

adk_server/a2a/
executor.rs

1use crate::a2a::{
2    Message, TaskState, TaskStatus, TaskStatusUpdateEvent, UpdateEvent, events::message_to_event,
3    metadata::to_invocation_meta, processor::EventProcessor,
4};
5use adk_core::{Result, SessionId, UserId};
6use adk_runner::{Runner, RunnerConfig};
7use adk_session::{CreateRequest, GetRequest};
8use futures::StreamExt;
9use std::sync::Arc;
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "a2a-interceptors")]
13use crate::a2a::interceptor::{A2aDelegationContext, InterceptorChain, InterceptorDecision};
14
15pub struct ExecutorConfig {
16    pub app_name: String,
17    pub runner_config: Arc<RunnerConfig>,
18    pub cancellation_token: Option<CancellationToken>,
19    /// Optional interceptor chain for A2A request/response middleware.
20    ///
21    /// When set, the chain's `run_before` is called before processing a request,
22    /// and `run_after` is called after the executor produces a response.
23    #[cfg(feature = "a2a-interceptors")]
24    pub interceptor_chain: Option<Arc<InterceptorChain>>,
25}
26
27pub struct Executor {
28    config: ExecutorConfig,
29}
30
31impl Executor {
32    pub fn new(config: ExecutorConfig) -> Self {
33        Self { config }
34    }
35
36    pub async fn execute(
37        &self,
38        context_id: &str,
39        task_id: &str,
40        message: &Message,
41    ) -> Result<Vec<UpdateEvent>> {
42        // --- Interceptor: before delegation ---
43        #[cfg(feature = "a2a-interceptors")]
44        let interceptor_ctx = {
45            if let Some(chain) = &self.config.interceptor_chain {
46                let params = serde_json::to_value(message).unwrap_or(serde_json::Value::Null);
47
48                let metadata_map = message
49                    .metadata
50                    .as_ref()
51                    .map(|m| {
52                        m.iter()
53                            .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
54                            .collect()
55                    })
56                    .unwrap_or_default();
57
58                let mut ctx = A2aDelegationContext {
59                    method: "message/send".to_string(),
60                    params,
61                    caller_id: message
62                        .metadata
63                        .as_ref()
64                        .and_then(|m| m.get("caller_id"))
65                        .and_then(|v| v.as_str())
66                        .map(String::from),
67                    metadata: metadata_map,
68                };
69
70                let decision = chain
71                    .run_before(&mut ctx)
72                    .await
73                    .map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
74
75                match decision {
76                    InterceptorDecision::Continue => {}
77                    InterceptorDecision::ShortCircuit(response) => {
78                        // Return the short-circuited response as a completed task
79                        let results = vec![UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
80                            task_id: task_id.to_string(),
81                            context_id: Some(context_id.to_string()),
82                            status: TaskStatus {
83                                state: TaskState::Completed,
84                                message: response.as_str().map(String::from),
85                            },
86                            final_update: true,
87                        })];
88                        return Ok(results);
89                    }
90                    InterceptorDecision::Reject { code, message: msg } => {
91                        return Err(adk_core::AdkError::agent(format!(
92                            "A2A request rejected (code {code}): {msg}"
93                        )));
94                    }
95                }
96
97                Some(ctx)
98            } else {
99                None
100            }
101        };
102
103        let meta = to_invocation_meta(&self.config.app_name, context_id, None);
104        let cancellation_token = self.config.cancellation_token.clone();
105
106        // Prepare session
107        self.prepare_session(&meta.user_id, &meta.session_id).await?;
108
109        // Convert message to event
110        let invocation_id = uuid::Uuid::new_v4().to_string();
111        let event = message_to_event(message, invocation_id)?;
112
113        // Create runner
114        let mut runner_builder = Runner::builder()
115            .app_name(self.config.runner_config.app_name.clone())
116            .agent(self.config.runner_config.agent.clone())
117            .session_service(self.config.runner_config.session_service.clone());
118        if let Some(ref artifact_service) = self.config.runner_config.artifact_service {
119            runner_builder = runner_builder.artifact_service(artifact_service.clone());
120        }
121        if let Some(ref memory_service) = self.config.runner_config.memory_service {
122            runner_builder = runner_builder.memory_service(memory_service.clone());
123        }
124        if let Some(ref plugin_manager) = self.config.runner_config.plugin_manager {
125            runner_builder = runner_builder.plugin_manager(plugin_manager.clone());
126        }
127        if let Some(ref run_config) = self.config.runner_config.run_config {
128            runner_builder = runner_builder.run_config(run_config.clone());
129        }
130        if let Some(ref compaction_config) = self.config.runner_config.compaction_config {
131            runner_builder = runner_builder.compaction_config(compaction_config.clone());
132        }
133        if let Some(ref context_cache_config) = self.config.runner_config.context_cache_config {
134            runner_builder = runner_builder.context_cache_config(context_cache_config.clone());
135        }
136        if let Some(ref cache_capable) = self.config.runner_config.cache_capable {
137            runner_builder = runner_builder.cache_capable(cache_capable.clone());
138        }
139        if let Some(ref request_context) = self.config.runner_config.request_context {
140            runner_builder = runner_builder.request_context(request_context.clone());
141        }
142        if let Some(cancellation_token) = cancellation_token.clone() {
143            runner_builder = runner_builder.cancellation_token(cancellation_token);
144        }
145        let runner = runner_builder.build()?;
146
147        // Create processor
148        let mut processor =
149            EventProcessor::new(context_id.to_string(), task_id.to_string(), meta.clone());
150
151        let mut results = vec![];
152
153        // Send submitted event
154        results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
155            task_id: task_id.to_string(),
156            context_id: Some(context_id.to_string()),
157            status: TaskStatus { state: TaskState::Submitted, message: None },
158            final_update: false,
159        }));
160
161        // Send working event
162        results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
163            task_id: task_id.to_string(),
164            context_id: Some(context_id.to_string()),
165            status: TaskStatus { state: TaskState::Working, message: None },
166            final_update: false,
167        }));
168
169        // Run agent
170        let content = event
171            .llm_response
172            .content
173            .ok_or_else(|| adk_core::AdkError::agent("Event has no content"))?;
174
175        let mut event_stream = runner
176            .run(
177                UserId::new(meta.user_id.clone())?,
178                SessionId::new(meta.session_id.clone())?,
179                content,
180            )
181            .await?;
182
183        // Process events
184        while let Some(result) = event_stream.next().await {
185            if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
186                results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
187                    task_id: task_id.to_string(),
188                    context_id: Some(context_id.to_string()),
189                    status: TaskStatus { state: TaskState::Canceled, message: None },
190                    final_update: true,
191                }));
192                return Ok(results);
193            }
194
195            match result {
196                Ok(adk_event) => {
197                    if let Some(artifact_event) = processor.process(&adk_event)? {
198                        results.push(UpdateEvent::TaskArtifactUpdate(artifact_event));
199                    }
200                }
201                Err(e) => {
202                    // Send failed event
203                    results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
204                        task_id: task_id.to_string(),
205                        context_id: Some(context_id.to_string()),
206                        status: TaskStatus {
207                            state: TaskState::Failed,
208                            message: Some(e.to_string()),
209                        },
210                        final_update: true,
211                    }));
212                    return Ok(results);
213                }
214            }
215        }
216
217        if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
218            results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
219                task_id: task_id.to_string(),
220                context_id: Some(context_id.to_string()),
221                status: TaskStatus { state: TaskState::Canceled, message: None },
222                final_update: true,
223            }));
224            return Ok(results);
225        }
226
227        // Send terminal events
228        for terminal_event in processor.make_terminal_events() {
229            results.push(UpdateEvent::TaskStatusUpdate(terminal_event));
230        }
231
232        // --- Interceptor: after delegation ---
233        #[cfg(feature = "a2a-interceptors")]
234        if let Some(chain) = &self.config.interceptor_chain {
235            if let Some(ctx) = &interceptor_ctx {
236                let mut response_value =
237                    serde_json::to_value(&results).unwrap_or(serde_json::Value::Null);
238                chain
239                    .run_after(ctx, &mut response_value)
240                    .await
241                    .map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
242            }
243        }
244
245        Ok(results)
246    }
247
248    pub async fn cancel(&self, context_id: &str, task_id: &str) -> Result<TaskStatusUpdateEvent> {
249        Ok(TaskStatusUpdateEvent {
250            task_id: task_id.to_string(),
251            context_id: Some(context_id.to_string()),
252            status: TaskStatus { state: TaskState::Canceled, message: None },
253            final_update: true,
254        })
255    }
256
257    async fn prepare_session(&self, user_id: &str, session_id: &str) -> Result<()> {
258        let session_service = &self.config.runner_config.session_service;
259
260        // Try to get existing session
261        let get_result = session_service
262            .get(GetRequest {
263                app_name: self.config.app_name.clone(),
264                user_id: user_id.to_string(),
265                session_id: session_id.to_string(),
266                num_recent_events: None,
267                after: None,
268            })
269            .await;
270
271        if get_result.is_ok() {
272            return Ok(());
273        }
274
275        // Create new session
276        session_service
277            .create(CreateRequest {
278                app_name: self.config.app_name.clone(),
279                user_id: user_id.to_string(),
280                session_id: Some(session_id.to_string()),
281                state: std::collections::HashMap::new(),
282            })
283            .await?;
284
285        Ok(())
286    }
287}