Skip to main content

adk_server/a2a/
executor.rs

1use crate::a2a::{
2    Message, TaskState, TaskStatus, TaskStatusUpdateEvent, UpdateEvent, events::message_to_event,
3    metadata::to_invocation_meta, processor::EventProcessor,
4};
5use adk_core::{Result, SessionId, UserId};
6use adk_runner::{Runner, RunnerConfig};
7use adk_session::{CreateRequest, GetRequest};
8use futures::StreamExt;
9use std::sync::Arc;
10use tokio_util::sync::CancellationToken;
11
12pub struct ExecutorConfig {
13    pub app_name: String,
14    pub runner_config: Arc<RunnerConfig>,
15    pub cancellation_token: Option<CancellationToken>,
16}
17
18pub struct Executor {
19    config: ExecutorConfig,
20}
21
22impl Executor {
23    pub fn new(config: ExecutorConfig) -> Self {
24        Self { config }
25    }
26
27    pub async fn execute(
28        &self,
29        context_id: &str,
30        task_id: &str,
31        message: &Message,
32    ) -> Result<Vec<UpdateEvent>> {
33        let meta = to_invocation_meta(&self.config.app_name, context_id, None);
34        let cancellation_token = self.config.cancellation_token.clone();
35
36        // Prepare session
37        self.prepare_session(&meta.user_id, &meta.session_id).await?;
38
39        // Convert message to event
40        let invocation_id = uuid::Uuid::new_v4().to_string();
41        let event = message_to_event(message, invocation_id)?;
42
43        // Create runner
44        let runner = Runner::new(RunnerConfig {
45            app_name: self.config.runner_config.app_name.clone(),
46            agent: self.config.runner_config.agent.clone(),
47            session_service: self.config.runner_config.session_service.clone(),
48            artifact_service: self.config.runner_config.artifact_service.clone(),
49            memory_service: self.config.runner_config.memory_service.clone(),
50            plugin_manager: self.config.runner_config.plugin_manager.clone(),
51            run_config: self.config.runner_config.run_config.clone(),
52            compaction_config: self.config.runner_config.compaction_config.clone(),
53            context_cache_config: self.config.runner_config.context_cache_config.clone(),
54            cache_capable: self.config.runner_config.cache_capable.clone(),
55            request_context: self.config.runner_config.request_context.clone(),
56            cancellation_token: cancellation_token.clone(),
57            intra_compaction_config: None,
58            intra_compaction_summarizer: None,
59        })?;
60
61        // Create processor
62        let mut processor =
63            EventProcessor::new(context_id.to_string(), task_id.to_string(), meta.clone());
64
65        let mut results = vec![];
66
67        // Send submitted event
68        results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
69            task_id: task_id.to_string(),
70            context_id: Some(context_id.to_string()),
71            status: TaskStatus { state: TaskState::Submitted, message: None },
72            final_update: false,
73        }));
74
75        // Send working event
76        results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
77            task_id: task_id.to_string(),
78            context_id: Some(context_id.to_string()),
79            status: TaskStatus { state: TaskState::Working, message: None },
80            final_update: false,
81        }));
82
83        // Run agent
84        let content = event
85            .llm_response
86            .content
87            .ok_or_else(|| adk_core::AdkError::agent("Event has no content"))?;
88
89        let mut event_stream = runner
90            .run(
91                UserId::new(meta.user_id.clone())?,
92                SessionId::new(meta.session_id.clone())?,
93                content,
94            )
95            .await?;
96
97        // Process events
98        while let Some(result) = event_stream.next().await {
99            if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
100                results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
101                    task_id: task_id.to_string(),
102                    context_id: Some(context_id.to_string()),
103                    status: TaskStatus { state: TaskState::Canceled, message: None },
104                    final_update: true,
105                }));
106                return Ok(results);
107            }
108
109            match result {
110                Ok(adk_event) => {
111                    if let Some(artifact_event) = processor.process(&adk_event)? {
112                        results.push(UpdateEvent::TaskArtifactUpdate(artifact_event));
113                    }
114                }
115                Err(e) => {
116                    // Send failed event
117                    results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
118                        task_id: task_id.to_string(),
119                        context_id: Some(context_id.to_string()),
120                        status: TaskStatus {
121                            state: TaskState::Failed,
122                            message: Some(e.to_string()),
123                        },
124                        final_update: true,
125                    }));
126                    return Ok(results);
127                }
128            }
129        }
130
131        if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
132            results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
133                task_id: task_id.to_string(),
134                context_id: Some(context_id.to_string()),
135                status: TaskStatus { state: TaskState::Canceled, message: None },
136                final_update: true,
137            }));
138            return Ok(results);
139        }
140
141        // Send terminal events
142        for terminal_event in processor.make_terminal_events() {
143            results.push(UpdateEvent::TaskStatusUpdate(terminal_event));
144        }
145
146        Ok(results)
147    }
148
149    pub async fn cancel(&self, context_id: &str, task_id: &str) -> Result<TaskStatusUpdateEvent> {
150        Ok(TaskStatusUpdateEvent {
151            task_id: task_id.to_string(),
152            context_id: Some(context_id.to_string()),
153            status: TaskStatus { state: TaskState::Canceled, message: None },
154            final_update: true,
155        })
156    }
157
158    async fn prepare_session(&self, user_id: &str, session_id: &str) -> Result<()> {
159        let session_service = &self.config.runner_config.session_service;
160
161        // Try to get existing session
162        let get_result = session_service
163            .get(GetRequest {
164                app_name: self.config.app_name.clone(),
165                user_id: user_id.to_string(),
166                session_id: session_id.to_string(),
167                num_recent_events: None,
168                after: None,
169            })
170            .await;
171
172        if get_result.is_ok() {
173            return Ok(());
174        }
175
176        // Create new session
177        session_service
178            .create(CreateRequest {
179                app_name: self.config.app_name.clone(),
180                user_id: user_id.to_string(),
181                session_id: Some(session_id.to_string()),
182                state: std::collections::HashMap::new(),
183            })
184            .await?;
185
186        Ok(())
187    }
188}