1#[cfg(feature = "mcp-client")]
7use crate::core::McpClientManager;
8use crate::core::config::{AgentConfig, ConfigError, StorageConfig};
9use crate::core::runtime::AgentRuntime;
10use a2a_rs::domain::{
11 A2AError, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, TaskState,
12 TaskStatusUpdateEvent,
13};
14use a2a_rs::port::{
15 AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager,
16 StreamingSubscriber, UpdateEvent,
17};
18use a2a_rs::{HttpPushNotificationSender, InMemoryTaskStorage};
19use async_trait::async_trait;
20use futures::Stream;
21use std::path::Path;
22use std::pin::Pin;
23use std::sync::Arc;
24#[cfg(feature = "mcp-client")]
25use tracing::info;
26
27#[cfg(feature = "sqlx")]
28use a2a_rs::adapter::storage::SqlxTaskStorage;
29
30#[derive(Clone)]
33pub enum AutoStorage {
34 InMemory(InMemoryTaskStorage),
35 #[cfg(feature = "sqlx")]
36 Sqlx(SqlxTaskStorage),
37}
38
39#[async_trait]
40impl AsyncTaskManager for AutoStorage {
41 async fn create_task(&self, task_id: &str, context_id: &str) -> Result<Task, A2AError> {
42 match self {
43 AutoStorage::InMemory(s) => s.create_task(task_id, context_id).await,
44 #[cfg(feature = "sqlx")]
45 AutoStorage::Sqlx(s) => s.create_task(task_id, context_id).await,
46 }
47 }
48
49 async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
50 match self {
51 AutoStorage::InMemory(s) => s.get_task(task_id, history_length).await,
52 #[cfg(feature = "sqlx")]
53 AutoStorage::Sqlx(s) => s.get_task(task_id, history_length).await,
54 }
55 }
56
57 async fn update_task_status(
58 &self,
59 task_id: &str,
60 state: TaskState,
61 message: Option<a2a_rs::domain::Message>,
62 ) -> Result<Task, A2AError> {
63 match self {
64 AutoStorage::InMemory(s) => s.update_task_status(task_id, state, message).await,
65 #[cfg(feature = "sqlx")]
66 AutoStorage::Sqlx(s) => s.update_task_status(task_id, state, message).await,
67 }
68 }
69
70 async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
71 match self {
72 AutoStorage::InMemory(s) => s.cancel_task(task_id).await,
73 #[cfg(feature = "sqlx")]
74 AutoStorage::Sqlx(s) => s.cancel_task(task_id).await,
75 }
76 }
77
78 async fn task_exists(&self, task_id: &str) -> Result<bool, A2AError> {
79 match self {
80 AutoStorage::InMemory(s) => s.task_exists(task_id).await,
81 #[cfg(feature = "sqlx")]
82 AutoStorage::Sqlx(s) => s.task_exists(task_id).await,
83 }
84 }
85}
86
87#[async_trait]
88impl AsyncNotificationManager for AutoStorage {
89 async fn set_task_notification(
90 &self,
91 config: &TaskPushNotificationConfig,
92 ) -> Result<TaskPushNotificationConfig, A2AError> {
93 match self {
94 AutoStorage::InMemory(s) => s.set_task_notification(config).await,
95 #[cfg(feature = "sqlx")]
96 AutoStorage::Sqlx(s) => s.set_task_notification(config).await,
97 }
98 }
99
100 async fn get_task_notification(
101 &self,
102 task_id: &str,
103 ) -> Result<TaskPushNotificationConfig, A2AError> {
104 match self {
105 AutoStorage::InMemory(s) => s.get_task_notification(task_id).await,
106 #[cfg(feature = "sqlx")]
107 AutoStorage::Sqlx(s) => s.get_task_notification(task_id).await,
108 }
109 }
110
111 async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> {
112 match self {
113 AutoStorage::InMemory(s) => s.remove_task_notification(task_id).await,
114 #[cfg(feature = "sqlx")]
115 AutoStorage::Sqlx(s) => s.remove_task_notification(task_id).await,
116 }
117 }
118}
119
120#[async_trait]
121impl AsyncStreamingHandler for AutoStorage {
122 async fn add_status_subscriber(
123 &self,
124 task_id: &str,
125 subscriber: Box<dyn StreamingSubscriber<TaskStatusUpdateEvent> + Send + Sync>,
126 ) -> Result<String, A2AError> {
127 match self {
128 AutoStorage::InMemory(s) => s.add_status_subscriber(task_id, subscriber).await,
129 #[cfg(feature = "sqlx")]
130 AutoStorage::Sqlx(s) => s.add_status_subscriber(task_id, subscriber).await,
131 }
132 }
133
134 async fn add_artifact_subscriber(
135 &self,
136 task_id: &str,
137 subscriber: Box<dyn StreamingSubscriber<TaskArtifactUpdateEvent> + Send + Sync>,
138 ) -> Result<String, A2AError> {
139 match self {
140 AutoStorage::InMemory(s) => s.add_artifact_subscriber(task_id, subscriber).await,
141 #[cfg(feature = "sqlx")]
142 AutoStorage::Sqlx(s) => s.add_artifact_subscriber(task_id, subscriber).await,
143 }
144 }
145
146 async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> {
147 match self {
148 AutoStorage::InMemory(s) => s.remove_subscription(subscription_id).await,
149 #[cfg(feature = "sqlx")]
150 AutoStorage::Sqlx(s) => s.remove_subscription(subscription_id).await,
151 }
152 }
153
154 async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
155 match self {
156 AutoStorage::InMemory(s) => s.remove_task_subscribers(task_id).await,
157 #[cfg(feature = "sqlx")]
158 AutoStorage::Sqlx(s) => s.remove_task_subscribers(task_id).await,
159 }
160 }
161
162 async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
163 match self {
164 AutoStorage::InMemory(s) => s.get_subscriber_count(task_id).await,
165 #[cfg(feature = "sqlx")]
166 AutoStorage::Sqlx(s) => s.get_subscriber_count(task_id).await,
167 }
168 }
169
170 async fn broadcast_status_update(
171 &self,
172 task_id: &str,
173 update: TaskStatusUpdateEvent,
174 ) -> Result<(), A2AError> {
175 match self {
176 AutoStorage::InMemory(s) => s.broadcast_status_update(task_id, update).await,
177 #[cfg(feature = "sqlx")]
178 AutoStorage::Sqlx(s) => s.broadcast_status_update(task_id, update).await,
179 }
180 }
181
182 async fn broadcast_artifact_update(
183 &self,
184 task_id: &str,
185 update: TaskArtifactUpdateEvent,
186 ) -> Result<(), A2AError> {
187 match self {
188 AutoStorage::InMemory(s) => s.broadcast_artifact_update(task_id, update).await,
189 #[cfg(feature = "sqlx")]
190 AutoStorage::Sqlx(s) => s.broadcast_artifact_update(task_id, update).await,
191 }
192 }
193
194 async fn status_update_stream(
195 &self,
196 task_id: &str,
197 ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>
198 {
199 match self {
200 AutoStorage::InMemory(s) => s.status_update_stream(task_id).await,
201 #[cfg(feature = "sqlx")]
202 AutoStorage::Sqlx(s) => s.status_update_stream(task_id).await,
203 }
204 }
205
206 async fn artifact_update_stream(
207 &self,
208 task_id: &str,
209 ) -> Result<
210 Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
211 A2AError,
212 > {
213 match self {
214 AutoStorage::InMemory(s) => s.artifact_update_stream(task_id).await,
215 #[cfg(feature = "sqlx")]
216 AutoStorage::Sqlx(s) => s.artifact_update_stream(task_id).await,
217 }
218 }
219
220 async fn combined_update_stream(
221 &self,
222 task_id: &str,
223 ) -> Result<Pin<Box<dyn Stream<Item = Result<UpdateEvent, A2AError>> + Send>>, A2AError> {
224 match self {
225 AutoStorage::InMemory(s) => s.combined_update_stream(task_id).await,
226 #[cfg(feature = "sqlx")]
227 AutoStorage::Sqlx(s) => s.combined_update_stream(task_id).await,
228 }
229 }
230}
231
232pub struct AgentBuilder<H = (), S = ()> {
234 config: AgentConfig,
235 handler: Option<H>,
236 storage: Option<S>,
237}
238
239impl AgentBuilder<(), ()> {
240 pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
242 let config = AgentConfig::from_file(path)?;
243 Ok(Self {
244 config,
245 handler: None,
246 storage: None,
247 })
248 }
249
250 pub fn from_toml(toml: &str) -> Result<Self, ConfigError> {
252 let config = AgentConfig::from_toml(toml)?;
253 Ok(Self {
254 config,
255 handler: None,
256 storage: None,
257 })
258 }
259
260 pub fn new(config: AgentConfig) -> Self {
262 Self {
263 config,
264 handler: None,
265 storage: None,
266 }
267 }
268}
269
270impl<H, S> AgentBuilder<H, S> {
271 pub fn with_handler<NewH>(self, handler: NewH) -> AgentBuilder<NewH, S>
273 where
274 NewH: AsyncMessageHandler + Clone + Send + Sync + 'static,
275 {
276 AgentBuilder {
277 config: self.config,
278 handler: Some(handler),
279 storage: self.storage,
280 }
281 }
282
283 pub fn with_storage<NewS>(self, storage: NewS) -> AgentBuilder<H, NewS>
285 where
286 NewS: AsyncTaskManager + AsyncNotificationManager + Clone + Send + Sync + 'static,
287 {
288 AgentBuilder {
289 config: self.config,
290 handler: self.handler,
291 storage: Some(storage),
292 }
293 }
294
295 pub fn config(&self) -> &AgentConfig {
297 &self.config
298 }
299
300 pub fn with_config<F>(mut self, f: F) -> Self
302 where
303 F: FnOnce(&mut AgentConfig),
304 {
305 f(&mut self.config);
306 self
307 }
308}
309
310impl<H, S> AgentBuilder<H, S>
311where
312 H: AsyncMessageHandler + Clone + Send + Sync + 'static,
313 S: AsyncTaskManager + AsyncNotificationManager + Clone + Send + Sync + 'static,
314{
315 pub fn build(self) -> Result<AgentRuntime<H, S>, BuildError> {
317 let handler = self.handler.ok_or(BuildError::MissingHandler)?;
318 let storage = self.storage.ok_or(BuildError::MissingStorage)?;
319
320 Ok(AgentRuntime::new(
321 self.config,
322 Arc::new(handler),
323 Arc::new(storage),
324 ))
325 }
326}
327
328impl<H> AgentBuilder<H, ()>
329where
330 H: AsyncMessageHandler + Clone + Send + Sync + 'static,
331{
332 pub async fn build_with_auto_storage(self) -> Result<AgentRuntime<H, AutoStorage>, BuildError> {
336 let handler = self.handler.ok_or(BuildError::MissingHandler)?;
337
338 let storage = match &self.config.server.storage {
339 StorageConfig::InMemory => {
340 let push_sender = HttpPushNotificationSender::new()
341 .with_timeout(30)
342 .with_max_retries(3);
343 let storage = InMemoryTaskStorage::with_push_sender(push_sender);
344 AutoStorage::InMemory(storage)
345 }
346 #[cfg(feature = "sqlx")]
347 StorageConfig::Sqlx {
348 url,
349 enable_logging,
350 ..
351 } => {
352 if *enable_logging {
353 tracing::info!("SQL query logging enabled");
354 }
355
356 let storage = SqlxTaskStorage::new(url).await.map_err(|e| {
357 BuildError::StorageError(format!("Failed to create SQLx storage: {}", e))
358 })?;
359
360 AutoStorage::Sqlx(storage)
361 }
362 #[cfg(not(feature = "sqlx"))]
363 StorageConfig::Sqlx { .. } => {
364 return Err(BuildError::StorageError(
365 "SQLx storage requested but 'sqlx' feature is not enabled".to_string(),
366 ));
367 }
368 };
369
370 #[cfg(feature = "mcp-client")]
372 if self.config.features.mcp_client.enabled {
373 info!("Initializing MCP client...");
374 let mcp_client = McpClientManager::new();
375
376 if let Err(e) = mcp_client
378 .initialize(&self.config.features.mcp_client)
379 .await
380 {
381 return Err(BuildError::RuntimeError(format!(
382 "Failed to initialize MCP client: {}",
383 e
384 )));
385 }
386
387 return Ok(AgentRuntime::with_mcp_client(
388 self.config,
389 Arc::new(handler),
390 Arc::new(storage),
391 mcp_client,
392 ));
393 }
394
395 Ok(AgentRuntime::new(
396 self.config,
397 Arc::new(handler),
398 Arc::new(storage),
399 ))
400 }
401
402 #[cfg(feature = "sqlx")]
405 pub async fn build_with_auto_storage_and_migrations(
406 self,
407 migrations: &'static [&'static str],
408 ) -> Result<AgentRuntime<H, AutoStorage>, BuildError> {
409 let handler = self.handler.ok_or(BuildError::MissingHandler)?;
410
411 let storage = match &self.config.server.storage {
412 StorageConfig::InMemory => {
413 tracing::warn!(
414 "Migrations provided but using in-memory storage - migrations ignored"
415 );
416 let push_sender = HttpPushNotificationSender::new()
417 .with_timeout(30)
418 .with_max_retries(3);
419 let storage = InMemoryTaskStorage::with_push_sender(push_sender);
420 AutoStorage::InMemory(storage)
421 }
422 StorageConfig::Sqlx {
423 url,
424 enable_logging,
425 ..
426 } => {
427 if *enable_logging {
428 tracing::info!("SQL query logging enabled");
429 }
430
431 let storage = SqlxTaskStorage::with_migrations(url, migrations)
432 .await
433 .map_err(|e| {
434 BuildError::StorageError(format!("Failed to create SQLx storage: {}", e))
435 })?;
436
437 AutoStorage::Sqlx(storage)
438 }
439 };
440
441 #[cfg(feature = "mcp-client")]
443 if self.config.features.mcp_client.enabled {
444 info!("Initializing MCP client...");
445 let mcp_client = McpClientManager::new();
446
447 if let Err(e) = mcp_client
449 .initialize(&self.config.features.mcp_client)
450 .await
451 {
452 return Err(BuildError::RuntimeError(format!(
453 "Failed to initialize MCP client: {}",
454 e
455 )));
456 }
457
458 return Ok(AgentRuntime::with_mcp_client(
459 self.config,
460 Arc::new(handler),
461 Arc::new(storage),
462 mcp_client,
463 ));
464 }
465
466 Ok(AgentRuntime::new(
467 self.config,
468 Arc::new(handler),
469 Arc::new(storage),
470 ))
471 }
472}
473
474#[derive(Debug, thiserror::Error)]
476pub enum BuildError {
477 #[error("Handler must be set before building")]
478 MissingHandler,
479
480 #[error("Storage must be set before building")]
481 MissingStorage,
482
483 #[error("Configuration error: {0}")]
484 ConfigError(#[from] ConfigError),
485
486 #[error("Storage error: {0}")]
487 StorageError(String),
488
489 #[error("Runtime error: {0}")]
490 RuntimeError(String),
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Values parsed from TOML are exposed through `config()`.
    #[test]
    fn test_builder_from_toml() {
        let source = r#"
        [agent]
        name = "Test Agent"

        [server]
        http_port = 9000
        "#;

        let built = AgentBuilder::from_toml(source).unwrap();
        let cfg = built.config();
        assert_eq!(cfg.agent.name, "Test Agent");
        assert_eq!(cfg.server.http_port, 9000);
    }

    /// `with_config` applies its closure to the stored configuration.
    #[test]
    fn test_builder_config_modification() {
        let source = r#"
        [agent]
        name = "Test Agent"
        "#;

        let set_port = |cfg: &mut AgentConfig| {
            cfg.server.http_port = 7000;
        };
        let built = AgentBuilder::from_toml(source).unwrap().with_config(set_port);

        assert_eq!(built.config().server.http_port, 7000);
    }
}