Skip to main content

adk_server/a2a/
executor.rs

1use crate::a2a::{
2    Message, TaskState, TaskStatus, TaskStatusUpdateEvent, UpdateEvent, events::message_to_event,
3    metadata::to_invocation_meta, processor::EventProcessor,
4};
5use adk_core::{Result, SessionId, UserId};
6use adk_runner::{Runner, RunnerConfig};
7use adk_session::{CreateRequest, GetRequest};
8use futures::StreamExt;
9use std::sync::Arc;
10use tokio_util::sync::CancellationToken;
11
12pub struct ExecutorConfig {
13    pub app_name: String,
14    pub runner_config: Arc<RunnerConfig>,
15    pub cancellation_token: Option<CancellationToken>,
16}
17
18pub struct Executor {
19    config: ExecutorConfig,
20}
21
22impl Executor {
23    pub fn new(config: ExecutorConfig) -> Self {
24        Self { config }
25    }
26
27    pub async fn execute(
28        &self,
29        context_id: &str,
30        task_id: &str,
31        message: &Message,
32    ) -> Result<Vec<UpdateEvent>> {
33        let meta = to_invocation_meta(&self.config.app_name, context_id, None);
34        let cancellation_token = self.config.cancellation_token.clone();
35
36        // Prepare session
37        self.prepare_session(&meta.user_id, &meta.session_id).await?;
38
39        // Convert message to event
40        let invocation_id = uuid::Uuid::new_v4().to_string();
41        let event = message_to_event(message, invocation_id)?;
42
43        // Create runner
44        let runner = Runner::new(RunnerConfig {
45            app_name: self.config.runner_config.app_name.clone(),
46            agent: self.config.runner_config.agent.clone(),
47            session_service: self.config.runner_config.session_service.clone(),
48            artifact_service: self.config.runner_config.artifact_service.clone(),
49            memory_service: self.config.runner_config.memory_service.clone(),
50            plugin_manager: self.config.runner_config.plugin_manager.clone(),
51            run_config: self.config.runner_config.run_config.clone(),
52            compaction_config: self.config.runner_config.compaction_config.clone(),
53            context_cache_config: self.config.runner_config.context_cache_config.clone(),
54            cache_capable: self.config.runner_config.cache_capable.clone(),
55            request_context: self.config.runner_config.request_context.clone(),
56            cancellation_token: cancellation_token.clone(),
57        })?;
58
59        // Create processor
60        let mut processor =
61            EventProcessor::new(context_id.to_string(), task_id.to_string(), meta.clone());
62
63        let mut results = vec![];
64
65        // Send submitted event
66        results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
67            task_id: task_id.to_string(),
68            context_id: Some(context_id.to_string()),
69            status: TaskStatus { state: TaskState::Submitted, message: None },
70            final_update: false,
71        }));
72
73        // Send working event
74        results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
75            task_id: task_id.to_string(),
76            context_id: Some(context_id.to_string()),
77            status: TaskStatus { state: TaskState::Working, message: None },
78            final_update: false,
79        }));
80
81        // Run agent
82        let content = event
83            .llm_response
84            .content
85            .ok_or_else(|| adk_core::AdkError::agent("Event has no content"))?;
86
87        let mut event_stream = runner
88            .run(
89                UserId::new(meta.user_id.clone())?,
90                SessionId::new(meta.session_id.clone())?,
91                content,
92            )
93            .await?;
94
95        // Process events
96        while let Some(result) = event_stream.next().await {
97            if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
98                results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
99                    task_id: task_id.to_string(),
100                    context_id: Some(context_id.to_string()),
101                    status: TaskStatus { state: TaskState::Canceled, message: None },
102                    final_update: true,
103                }));
104                return Ok(results);
105            }
106
107            match result {
108                Ok(adk_event) => {
109                    if let Some(artifact_event) = processor.process(&adk_event)? {
110                        results.push(UpdateEvent::TaskArtifactUpdate(artifact_event));
111                    }
112                }
113                Err(e) => {
114                    // Send failed event
115                    results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
116                        task_id: task_id.to_string(),
117                        context_id: Some(context_id.to_string()),
118                        status: TaskStatus {
119                            state: TaskState::Failed,
120                            message: Some(e.to_string()),
121                        },
122                        final_update: true,
123                    }));
124                    return Ok(results);
125                }
126            }
127        }
128
129        if cancellation_token.as_ref().is_some_and(CancellationToken::is_cancelled) {
130            results.push(UpdateEvent::TaskStatusUpdate(TaskStatusUpdateEvent {
131                task_id: task_id.to_string(),
132                context_id: Some(context_id.to_string()),
133                status: TaskStatus { state: TaskState::Canceled, message: None },
134                final_update: true,
135            }));
136            return Ok(results);
137        }
138
139        // Send terminal events
140        for terminal_event in processor.make_terminal_events() {
141            results.push(UpdateEvent::TaskStatusUpdate(terminal_event));
142        }
143
144        Ok(results)
145    }
146
147    pub async fn cancel(&self, context_id: &str, task_id: &str) -> Result<TaskStatusUpdateEvent> {
148        Ok(TaskStatusUpdateEvent {
149            task_id: task_id.to_string(),
150            context_id: Some(context_id.to_string()),
151            status: TaskStatus { state: TaskState::Canceled, message: None },
152            final_update: true,
153        })
154    }
155
156    async fn prepare_session(&self, user_id: &str, session_id: &str) -> Result<()> {
157        let session_service = &self.config.runner_config.session_service;
158
159        // Try to get existing session
160        let get_result = session_service
161            .get(GetRequest {
162                app_name: self.config.app_name.clone(),
163                user_id: user_id.to_string(),
164                session_id: session_id.to_string(),
165                num_recent_events: None,
166                after: None,
167            })
168            .await;
169
170        if get_result.is_ok() {
171            return Ok(());
172        }
173
174        // Create new session
175        session_service
176            .create(CreateRequest {
177                app_name: self.config.app_name.clone(),
178                user_id: user_id.to_string(),
179                session_id: Some(session_id.to_string()),
180                state: std::collections::HashMap::new(),
181            })
182            .await?;
183
184        Ok(())
185    }
186}