adk_server/a2a/
executor.rs1use crate::a2a::{
2 Message, TaskState, TaskStatus, TaskStatusUpdateEvent, UpdateEvent, events::message_to_event,
3 metadata::to_invocation_meta, processor::EventProcessor,
4};
5use adk_core::{Result, SessionId, UserId};
6use adk_runner::{Runner, RunnerConfig};
7use adk_session::{CreateRequest, GetRequest};
8use futures::StreamExt;
9use std::sync::Arc;
10use tokio_util::sync::CancellationToken;
11
12pub struct ExecutorConfig {
13 pub app_name: String,
14 pub runner_config: Arc<RunnerConfig>,
15 pub cancellation_token: Option<CancellationToken>,
16}
17
18pub struct Executor {
19 config: ExecutorConfig,
20}
21
22impl Executor {
23 pub fn new(config: ExecutorConfig) -> Self {
24 Self { config }
25 }
26
27 pub async fn execute(
28 &self,
29 context_id: &str,
30 task_id: &str,
31 message: &Message,
32 ) -> Result<Vec<UpdateEvent>> {
33 let meta = to_invocation_meta(&self.config.app_name, context_id, None);
34 let cancellation_token = self.config.cancellation_token.clone();
35
36 self.prepare_session(&meta.user_id, &meta.session_id).await?;
38
39 let invocation_id = uuid::Uuid::new_v4().to_string();
41 let event = message_to_event(message, invocation_id)?;
42
43 let runner = Runner::new(RunnerConfig {
45 app_name: self.config.runner_config.app_name.clone(),
46 agent: self.config.runner_config.agent.clone(),
47 session_service: self.config.runner_config.session_service.clone(),
48 artifact_service: self.config.runner_config.artifact_service.clone(),
49 memory_service: self.config.runner_config.memory_service.clone(),
50 plugin_manager: self.config.runner_config.plugin_manager.clone(),
51 run_config: self.config.runner_config.run_config.clone(),
52 compaction_config: self.config.runner_config.compaction_config.clone(),
53 context_cache_config: self.config.runner_config.context_cache_config.clone(),
54 cache_capable: self.config.runner_config.cache_capable.clone(),
55 request_context: self.config.runner_config.request_context.clone(),
56 cancellation_token: cancellation_token.clone(),
57 })?;
58
59 let mut processor =
61 EventProcessor::new(context_id.to_string(), task_id.to_string(), meta.clone());
62
63 let mut results = vec![];
64
65 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
67 task_id: task_id.to_string(),
68 context_id: Some(context_id.to_string()),
69 status: TaskStatus { state: TaskState::Submitted, message: None },
70 final_update: false,
71 }));
72
73 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
75 task_id: task_id.to_string(),
76 context_id: Some(context_id.to_string()),
77 status: TaskStatus { state: TaskState::Working, message: None },
78 final_update: false,
79 }));
80
81 let content = event
83 .llm_response
84 .content
85 .ok_or_else(|| adk_core::AdkError::agent("Event has no content"))?;
86
87 let mut event_stream = runner
88 .run(
89 UserId::new(meta.user_id.clone())?,
90 SessionId::new(meta.session_id.clone())?,
91 content,
92 )
93 .await?;
94
95 while let Some(result) = event_stream.next().await {
97 if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
98 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
99 task_id: task_id.to_string(),
100 context_id: Some(context_id.to_string()),
101 status: TaskStatus { state: TaskState::Canceled, message: None },
102 final_update: true,
103 }));
104 return Ok(results);
105 }
106
107 match result {
108 Ok(adk_event) => {
109 if let Some(artifact_event) = processor.process(&adk_event)? {
110 results.push(UpdateEvent::TaskArtifactUpdate(artifact_event));
111 }
112 }
113 Err(e) => {
114 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
116 task_id: task_id.to_string(),
117 context_id: Some(context_id.to_string()),
118 status: TaskStatus {
119 state: TaskState::Failed,
120 message: Some(e.to_string()),
121 },
122 final_update: true,
123 }));
124 return Ok(results);
125 }
126 }
127 }
128
129 if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
130 results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
131 task_id: task_id.to_string(),
132 context_id: Some(context_id.to_string()),
133 status: TaskStatus { state: TaskState::Canceled, message: None },
134 final_update: true,
135 }));
136 return Ok(results);
137 }
138
139 for terminal_event in processor.make_terminal_events() {
141 results.push(UpdateEvent::TaskStatusUpdate(terminal_event));
142 }
143
144 Ok(results)
145 }
146
147 pub async fn cancel(&self, context_id: &str, task_id: &str) -> Result<TaskStatusUpdateEvent> {
148 Ok(TaskStatusUpdateEvent {
149 task_id: task_id.to_string(),
150 context_id: Some(context_id.to_string()),
151 status: TaskStatus { state: TaskState::Canceled, message: None },
152 final_update: true,
153 })
154 }
155
156 async fn prepare_session(&self, user_id: &str, session_id: &str) -> Result<()> {
157 let session_service = &self.config.runner_config.session_service;
158
159 let get_result = session_service
161 .get(GetRequest {
162 app_name: self.config.app_name.clone(),
163 user_id: user_id.to_string(),
164 session_id: session_id.to_string(),
165 num_recent_events: None,
166 after: None,
167 })
168 .await;
169
170 if get_result.is_ok() {
171 return Ok(());
172 }
173
174 session_service
176 .create(CreateRequest {
177 app_name: self.config.app_name.clone(),
178 user_id: user_id.to_string(),
179 session_id: Some(session_id.to_string()),
180 state: std::collections::HashMap::new(),
181 })
182 .await?;
183
184 Ok(())
185 }
186}