1use crate::a2a::{
2 Message, TaskState, TaskStatus, TaskStatusUpdateEvent, UpdateEvent, events::message_to_event,
3 metadata::to_invocation_meta, processor::EventProcessor,
4};
5use adk_core::{Result, SessionId, UserId};
6use adk_runner::{Runner, RunnerConfig};
7use adk_session::{CreateRequest, GetRequest};
8use futures::StreamExt;
9use std::sync::Arc;
10use tokio_util::sync::CancellationToken;
11
12#[cfg(feature = "a2a-interceptors")]
13use crate::a2a::interceptor::{A2aDelegationContext, InterceptorChain, InterceptorDecision};
14
15pub struct ExecutorConfig {
16 pub app_name: String,
17 pub runner_config: Arc<RunnerConfig>,
18 pub cancellation_token: Option<CancellationToken>,
19 #[cfg(feature = "a2a-interceptors")]
24 pub interceptor_chain: Option<Arc<InterceptorChain>>,
25}
26
27pub struct Executor {
28 config: ExecutorConfig,
29}
30
31impl Executor {
32 pub fn new(config: ExecutorConfig) -> Self {
33 Self { config }
34 }
35
36 pub async fn execute(
37 &self,
38 context_id: &str,
39 task_id: &str,
40 message: &Message,
41 ) -> Result<Vec<UpdateEvent>> {
42 #[cfg(feature = "a2a-interceptors")]
44 let interceptor_ctx = {
45 if let Some(chain) = &self.config.interceptor_chain {
46 let params = serde_json::to_value(message).unwrap_or(serde_json::Value::Null);
47
48 let metadata_map = message
49 .metadata
50 .as_ref()
51 .map(|m| {
52 m.iter()
53 .filter_map(|(k, v)| v.as_str().map(|s| (k.clone(), s.to_string())))
54 .collect()
55 })
56 .unwrap_or_default();
57
58 let mut ctx = A2aDelegationContext {
59 method: "message/send".to_string(),
60 params,
61 caller_id: message
62 .metadata
63 .as_ref()
64 .and_then(|m| m.get("caller_id"))
65 .and_then(|v| v.as_str())
66 .map(String::from),
67 metadata: metadata_map,
68 };
69
70 let decision = chain
71 .run_before(&mut ctx)
72 .await
73 .map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
74
75 match decision {
76 InterceptorDecision::Continue => {}
77 InterceptorDecision::ShortCircuit(response) => {
78 let results = vec![UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
80 task_id: task_id.to_string(),
81 context_id: Some(context_id.to_string()),
82 status: TaskStatus {
83 state: TaskState::Completed,
84 message: response.as_str().map(String::from),
85 },
86 final_update: true,
87 })];
88 return Ok(results);
89 }
90 InterceptorDecision::Reject { code, message: msg } => {
91 return Err(adk_core::AdkError::agent(format!(
92 "A2A request rejected (code {code}): {msg}"
93 )));
94 }
95 }
96
97 Some(ctx)
98 } else {
99 None
100 }
101 };
102
103 let meta = to_invocation_meta(&self.config.app_name, context_id, None);
104 let cancellation_token = self.config.cancellation_token.clone();
105
106 self.prepare_session(&meta.user_id, &meta.session_id).await?;
108
109 let invocation_id = uuid::Uuid::new_v4().to_string();
111 let event = message_to_event(message, invocation_id)?;
112
113 let mut runner_builder = Runner::builder()
115 .app_name(self.config.runner_config.app_name.clone())
116 .agent(self.config.runner_config.agent.clone())
117 .session_service(self.config.runner_config.session_service.clone());
118 if let Some(ref artifact_service) = self.config.runner_config.artifact_service {
119 runner_builder = runner_builder.artifact_service(artifact_service.clone());
120 }
121 if let Some(ref memory_service) = self.config.runner_config.memory_service {
122 runner_builder = runner_builder.memory_service(memory_service.clone());
123 }
124 if let Some(ref plugin_manager) = self.config.runner_config.plugin_manager {
125 runner_builder = runner_builder.plugin_manager(plugin_manager.clone());
126 }
127 if let Some(ref run_config) = self.config.runner_config.run_config {
128 runner_builder = runner_builder.run_config(run_config.clone());
129 }
130 if let Some(ref compaction_config) = self.config.runner_config.compaction_config {
131 runner_builder = runner_builder.compaction_config(compaction_config.clone());
132 }
133 if let Some(ref context_cache_config) = self.config.runner_config.context_cache_config {
134 runner_builder = runner_builder.context_cache_config(context_cache_config.clone());
135 }
136 if let Some(ref cache_capable) = self.config.runner_config.cache_capable {
137 runner_builder = runner_builder.cache_capable(cache_capable.clone());
138 }
139 if let Some(ref request_context) = self.config.runner_config.request_context {
140 runner_builder = runner_builder.request_context(request_context.clone());
141 }
142 if let Some(cancellation_token) = cancellation_token.clone() {
143 runner_builder = runner_builder.cancellation_token(cancellation_token);
144 }
145 let runner = runner_builder.build()?;
146
147 let mut processor =
149 EventProcessor::new(context_id.to_string(), task_id.to_string(), meta.clone());
150
151 let mut results = vec![];
152
153 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
155 task_id: task_id.to_string(),
156 context_id: Some(context_id.to_string()),
157 status: TaskStatus { state: TaskState::Submitted, message: None },
158 final_update: false,
159 }));
160
161 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
163 task_id: task_id.to_string(),
164 context_id: Some(context_id.to_string()),
165 status: TaskStatus { state: TaskState::Working, message: None },
166 final_update: false,
167 }));
168
169 let content = event
171 .llm_response
172 .content
173 .ok_or_else(|| adk_core::AdkError::agent("Event has no content"))?;
174
175 let mut event_stream = runner
176 .run(
177 UserId::new(meta.user_id.clone())?,
178 SessionId::new(meta.session_id.clone())?,
179 content,
180 )
181 .await?;
182
183 while let Some(result) = event_stream.next().await {
185 if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
186 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
187 task_id: task_id.to_string(),
188 context_id: Some(context_id.to_string()),
189 status: TaskStatus { state: TaskState::Canceled, message: None },
190 final_update: true,
191 }));
192 return Ok(results);
193 }
194
195 match result {
196 Ok(adk_event) => {
197 if let Some(artifact_event) = processor.process(&adk_event)? {
198 results.push(UpdateEvent::TaskArtifactUpdate(artifact_event));
199 }
200 }
201 Err(e) => {
202 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
204 task_id: task_id.to_string(),
205 context_id: Some(context_id.to_string()),
206 status: TaskStatus {
207 state: TaskState::Failed,
208 message: Some(e.to_string()),
209 },
210 final_update: true,
211 }));
212 return Ok(results);
213 }
214 }
215 }
216
217 if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
218 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
219 task_id: task_id.to_string(),
220 context_id: Some(context_id.to_string()),
221 status: TaskStatus { state: TaskState::Canceled, message: None },
222 final_update: true,
223 }));
224 return Ok(results);
225 }
226
227 for terminal_event in processor.make_terminal_events() {
229 results.push(UpdateEvent::TaskStatusUpdate(terminal_event));
230 }
231
232 #[cfg(feature = "a2a-interceptors")]
234 if let Some(chain) = &self.config.interceptor_chain {
235 if let Some(ctx) = &interceptor_ctx {
236 let mut response_value =
237 serde_json::to_value(&results).unwrap_or(serde_json::Value::Null);
238 chain
239 .run_after(ctx, &mut response_value)
240 .await
241 .map_err(|e| adk_core::AdkError::agent(e.to_string()))?;
242 }
243 }
244
245 Ok(results)
246 }
247
248 pub async fn cancel(&self, context_id: &str, task_id: &str) -> Result<TaskStatusUpdateEvent> {
249 Ok(TaskStatusUpdateEvent {
250 task_id: task_id.to_string(),
251 context_id: Some(context_id.to_string()),
252 status: TaskStatus { state: TaskState::Canceled, message: None },
253 final_update: true,
254 })
255 }
256
257 async fn prepare_session(&self, user_id: &str, session_id: &str) -> Result<()> {
258 let session_service = &self.config.runner_config.session_service;
259
260 let get_result = session_service
262 .get(GetRequest {
263 app_name: self.config.app_name.clone(),
264 user_id: user_id.to_string(),
265 session_id: session_id.to_string(),
266 num_recent_events: None,
267 after: None,
268 })
269 .await;
270
271 if get_result.is_ok() {
272 return Ok(());
273 }
274
275 session_service
277 .create(CreateRequest {
278 app_name: self.config.app_name.clone(),
279 user_id: user_id.to_string(),
280 session_id: Some(session_id.to_string()),
281 state: std::collections::HashMap::new(),
282 })
283 .await?;
284
285 Ok(())
286 }
287}