Skip to main content

a2a_agents/core/
builder.rs

1//! Agent builder for declarative agent construction
2//!
3//! Provides a fluent API for building agents from configuration files
4//! or programmatically with minimal boilerplate.
5
6#[cfg(feature = "mcp-client")]
7use crate::core::McpClientManager;
8use crate::core::config::{AgentConfig, ConfigError, StorageConfig};
9use crate::core::runtime::AgentRuntime;
10use a2a_rs::domain::{
11    A2AError, Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, TaskState,
12    TaskStatusUpdateEvent,
13};
14use a2a_rs::port::{
15    AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager,
16    StreamingSubscriber, UpdateEvent,
17};
18use a2a_rs::{HttpPushNotificationSender, InMemoryTaskStorage};
19use async_trait::async_trait;
20use futures::Stream;
21use std::path::Path;
22use std::pin::Pin;
23use std::sync::Arc;
24#[cfg(feature = "mcp-client")]
25use tracing::info;
26
27#[cfg(feature = "sqlx")]
28use a2a_rs::adapter::storage::SqlxTaskStorage;
29
30/// Storage wrapper that can hold either in-memory or SQLx storage
31/// This allows us to return different storage types from the builder
32#[derive(Clone)]
33pub enum AutoStorage {
34    InMemory(InMemoryTaskStorage),
35    #[cfg(feature = "sqlx")]
36    Sqlx(SqlxTaskStorage),
37}
38
39#[async_trait]
40impl AsyncTaskManager for AutoStorage {
41    async fn create_task(&self, task_id: &str, context_id: &str) -> Result<Task, A2AError> {
42        match self {
43            AutoStorage::InMemory(s) => s.create_task(task_id, context_id).await,
44            #[cfg(feature = "sqlx")]
45            AutoStorage::Sqlx(s) => s.create_task(task_id, context_id).await,
46        }
47    }
48
49    async fn get_task(&self, task_id: &str, history_length: Option<u32>) -> Result<Task, A2AError> {
50        match self {
51            AutoStorage::InMemory(s) => s.get_task(task_id, history_length).await,
52            #[cfg(feature = "sqlx")]
53            AutoStorage::Sqlx(s) => s.get_task(task_id, history_length).await,
54        }
55    }
56
57    async fn update_task_status(
58        &self,
59        task_id: &str,
60        state: TaskState,
61        message: Option<a2a_rs::domain::Message>,
62    ) -> Result<Task, A2AError> {
63        match self {
64            AutoStorage::InMemory(s) => s.update_task_status(task_id, state, message).await,
65            #[cfg(feature = "sqlx")]
66            AutoStorage::Sqlx(s) => s.update_task_status(task_id, state, message).await,
67        }
68    }
69
70    async fn cancel_task(&self, task_id: &str) -> Result<Task, A2AError> {
71        match self {
72            AutoStorage::InMemory(s) => s.cancel_task(task_id).await,
73            #[cfg(feature = "sqlx")]
74            AutoStorage::Sqlx(s) => s.cancel_task(task_id).await,
75        }
76    }
77
78    async fn task_exists(&self, task_id: &str) -> Result<bool, A2AError> {
79        match self {
80            AutoStorage::InMemory(s) => s.task_exists(task_id).await,
81            #[cfg(feature = "sqlx")]
82            AutoStorage::Sqlx(s) => s.task_exists(task_id).await,
83        }
84    }
85}
86
87#[async_trait]
88impl AsyncNotificationManager for AutoStorage {
89    async fn set_task_notification(
90        &self,
91        config: &TaskPushNotificationConfig,
92    ) -> Result<TaskPushNotificationConfig, A2AError> {
93        match self {
94            AutoStorage::InMemory(s) => s.set_task_notification(config).await,
95            #[cfg(feature = "sqlx")]
96            AutoStorage::Sqlx(s) => s.set_task_notification(config).await,
97        }
98    }
99
100    async fn get_task_notification(
101        &self,
102        task_id: &str,
103    ) -> Result<TaskPushNotificationConfig, A2AError> {
104        match self {
105            AutoStorage::InMemory(s) => s.get_task_notification(task_id).await,
106            #[cfg(feature = "sqlx")]
107            AutoStorage::Sqlx(s) => s.get_task_notification(task_id).await,
108        }
109    }
110
111    async fn remove_task_notification(&self, task_id: &str) -> Result<(), A2AError> {
112        match self {
113            AutoStorage::InMemory(s) => s.remove_task_notification(task_id).await,
114            #[cfg(feature = "sqlx")]
115            AutoStorage::Sqlx(s) => s.remove_task_notification(task_id).await,
116        }
117    }
118}
119
120#[async_trait]
121impl AsyncStreamingHandler for AutoStorage {
122    async fn add_status_subscriber(
123        &self,
124        task_id: &str,
125        subscriber: Box<dyn StreamingSubscriber<TaskStatusUpdateEvent> + Send + Sync>,
126    ) -> Result<String, A2AError> {
127        match self {
128            AutoStorage::InMemory(s) => s.add_status_subscriber(task_id, subscriber).await,
129            #[cfg(feature = "sqlx")]
130            AutoStorage::Sqlx(s) => s.add_status_subscriber(task_id, subscriber).await,
131        }
132    }
133
134    async fn add_artifact_subscriber(
135        &self,
136        task_id: &str,
137        subscriber: Box<dyn StreamingSubscriber<TaskArtifactUpdateEvent> + Send + Sync>,
138    ) -> Result<String, A2AError> {
139        match self {
140            AutoStorage::InMemory(s) => s.add_artifact_subscriber(task_id, subscriber).await,
141            #[cfg(feature = "sqlx")]
142            AutoStorage::Sqlx(s) => s.add_artifact_subscriber(task_id, subscriber).await,
143        }
144    }
145
146    async fn remove_subscription(&self, subscription_id: &str) -> Result<(), A2AError> {
147        match self {
148            AutoStorage::InMemory(s) => s.remove_subscription(subscription_id).await,
149            #[cfg(feature = "sqlx")]
150            AutoStorage::Sqlx(s) => s.remove_subscription(subscription_id).await,
151        }
152    }
153
154    async fn remove_task_subscribers(&self, task_id: &str) -> Result<(), A2AError> {
155        match self {
156            AutoStorage::InMemory(s) => s.remove_task_subscribers(task_id).await,
157            #[cfg(feature = "sqlx")]
158            AutoStorage::Sqlx(s) => s.remove_task_subscribers(task_id).await,
159        }
160    }
161
162    async fn get_subscriber_count(&self, task_id: &str) -> Result<usize, A2AError> {
163        match self {
164            AutoStorage::InMemory(s) => s.get_subscriber_count(task_id).await,
165            #[cfg(feature = "sqlx")]
166            AutoStorage::Sqlx(s) => s.get_subscriber_count(task_id).await,
167        }
168    }
169
170    async fn broadcast_status_update(
171        &self,
172        task_id: &str,
173        update: TaskStatusUpdateEvent,
174    ) -> Result<(), A2AError> {
175        match self {
176            AutoStorage::InMemory(s) => s.broadcast_status_update(task_id, update).await,
177            #[cfg(feature = "sqlx")]
178            AutoStorage::Sqlx(s) => s.broadcast_status_update(task_id, update).await,
179        }
180    }
181
182    async fn broadcast_artifact_update(
183        &self,
184        task_id: &str,
185        update: TaskArtifactUpdateEvent,
186    ) -> Result<(), A2AError> {
187        match self {
188            AutoStorage::InMemory(s) => s.broadcast_artifact_update(task_id, update).await,
189            #[cfg(feature = "sqlx")]
190            AutoStorage::Sqlx(s) => s.broadcast_artifact_update(task_id, update).await,
191        }
192    }
193
194    async fn status_update_stream(
195        &self,
196        task_id: &str,
197    ) -> Result<Pin<Box<dyn Stream<Item = Result<TaskStatusUpdateEvent, A2AError>> + Send>>, A2AError>
198    {
199        match self {
200            AutoStorage::InMemory(s) => s.status_update_stream(task_id).await,
201            #[cfg(feature = "sqlx")]
202            AutoStorage::Sqlx(s) => s.status_update_stream(task_id).await,
203        }
204    }
205
206    async fn artifact_update_stream(
207        &self,
208        task_id: &str,
209    ) -> Result<
210        Pin<Box<dyn Stream<Item = Result<TaskArtifactUpdateEvent, A2AError>> + Send>>,
211        A2AError,
212    > {
213        match self {
214            AutoStorage::InMemory(s) => s.artifact_update_stream(task_id).await,
215            #[cfg(feature = "sqlx")]
216            AutoStorage::Sqlx(s) => s.artifact_update_stream(task_id).await,
217        }
218    }
219
220    async fn combined_update_stream(
221        &self,
222        task_id: &str,
223    ) -> Result<Pin<Box<dyn Stream<Item = Result<UpdateEvent, A2AError>> + Send>>, A2AError> {
224        match self {
225            AutoStorage::InMemory(s) => s.combined_update_stream(task_id).await,
226            #[cfg(feature = "sqlx")]
227            AutoStorage::Sqlx(s) => s.combined_update_stream(task_id).await,
228        }
229    }
230}
231
232/// Builder for creating A2A agents with declarative configuration
233pub struct AgentBuilder<H = (), S = ()> {
234    config: AgentConfig,
235    handler: Option<H>,
236    storage: Option<S>,
237}
238
239impl AgentBuilder<(), ()> {
240    /// Create a new builder from a TOML configuration file
241    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, ConfigError> {
242        let config = AgentConfig::from_file(path)?;
243        Ok(Self {
244            config,
245            handler: None,
246            storage: None,
247        })
248    }
249
250    /// Create a new builder from a TOML string
251    pub fn from_toml(toml: &str) -> Result<Self, ConfigError> {
252        let config = AgentConfig::from_toml(toml)?;
253        Ok(Self {
254            config,
255            handler: None,
256            storage: None,
257        })
258    }
259
260    /// Create a new builder with programmatic configuration
261    pub fn new(config: AgentConfig) -> Self {
262        Self {
263            config,
264            handler: None,
265            storage: None,
266        }
267    }
268}
269
270impl<H, S> AgentBuilder<H, S> {
271    /// Set the message handler for this agent
272    pub fn with_handler<NewH>(self, handler: NewH) -> AgentBuilder<NewH, S>
273    where
274        NewH: AsyncMessageHandler + Clone + Send + Sync + 'static,
275    {
276        AgentBuilder {
277            config: self.config,
278            handler: Some(handler),
279            storage: self.storage,
280        }
281    }
282
283    /// Set custom storage for this agent
284    pub fn with_storage<NewS>(self, storage: NewS) -> AgentBuilder<H, NewS>
285    where
286        NewS: AsyncTaskManager + AsyncNotificationManager + Clone + Send + Sync + 'static,
287    {
288        AgentBuilder {
289            config: self.config,
290            handler: self.handler,
291            storage: Some(storage),
292        }
293    }
294
295    /// Access the configuration
296    pub fn config(&self) -> &AgentConfig {
297        &self.config
298    }
299
300    /// Modify the configuration
301    pub fn with_config<F>(mut self, f: F) -> Self
302    where
303        F: FnOnce(&mut AgentConfig),
304    {
305        f(&mut self.config);
306        self
307    }
308}
309
310impl<H, S> AgentBuilder<H, S>
311where
312    H: AsyncMessageHandler + Clone + Send + Sync + 'static,
313    S: AsyncTaskManager + AsyncNotificationManager + Clone + Send + Sync + 'static,
314{
315    /// Build the agent runtime
316    pub fn build(self) -> Result<AgentRuntime<H, S>, BuildError> {
317        let handler = self.handler.ok_or(BuildError::MissingHandler)?;
318        let storage = self.storage.ok_or(BuildError::MissingStorage)?;
319
320        Ok(AgentRuntime::new(
321            self.config,
322            Arc::new(handler),
323            Arc::new(storage),
324        ))
325    }
326}
327
328impl<H> AgentBuilder<H, ()>
329where
330    H: AsyncMessageHandler + Clone + Send + Sync + 'static,
331{
332    /// Create storage from the configuration
333    /// This is a convenience method that automatically creates the appropriate storage
334    /// based on what's configured in the TOML file
335    pub async fn build_with_auto_storage(self) -> Result<AgentRuntime<H, AutoStorage>, BuildError> {
336        let handler = self.handler.ok_or(BuildError::MissingHandler)?;
337
338        let storage = match &self.config.server.storage {
339            StorageConfig::InMemory => {
340                let push_sender = HttpPushNotificationSender::new()
341                    .with_timeout(30)
342                    .with_max_retries(3);
343                let storage = InMemoryTaskStorage::with_push_sender(push_sender);
344                AutoStorage::InMemory(storage)
345            }
346            #[cfg(feature = "sqlx")]
347            StorageConfig::Sqlx {
348                url,
349                enable_logging,
350                ..
351            } => {
352                if *enable_logging {
353                    tracing::info!("SQL query logging enabled");
354                }
355
356                let storage = SqlxTaskStorage::new(url).await.map_err(|e| {
357                    BuildError::StorageError(format!("Failed to create SQLx storage: {}", e))
358                })?;
359
360                AutoStorage::Sqlx(storage)
361            }
362            #[cfg(not(feature = "sqlx"))]
363            StorageConfig::Sqlx { .. } => {
364                return Err(BuildError::StorageError(
365                    "SQLx storage requested but 'sqlx' feature is not enabled".to_string(),
366                ));
367            }
368        };
369
370        // Initialize MCP client if configured
371        #[cfg(feature = "mcp-client")]
372        if self.config.features.mcp_client.enabled {
373            info!("Initializing MCP client...");
374            let mcp_client = McpClientManager::new();
375
376            // Initialize connections to configured servers
377            if let Err(e) = mcp_client
378                .initialize(&self.config.features.mcp_client)
379                .await
380            {
381                return Err(BuildError::RuntimeError(format!(
382                    "Failed to initialize MCP client: {}",
383                    e
384                )));
385            }
386
387            return Ok(AgentRuntime::with_mcp_client(
388                self.config,
389                Arc::new(handler),
390                Arc::new(storage),
391                mcp_client,
392            ));
393        }
394
395        Ok(AgentRuntime::new(
396            self.config,
397            Arc::new(handler),
398            Arc::new(storage),
399        ))
400    }
401
402    /// Create storage from configuration with custom migrations
403    /// This is useful when you need to run agent-specific database migrations
404    #[cfg(feature = "sqlx")]
405    pub async fn build_with_auto_storage_and_migrations(
406        self,
407        migrations: &'static [&'static str],
408    ) -> Result<AgentRuntime<H, AutoStorage>, BuildError> {
409        let handler = self.handler.ok_or(BuildError::MissingHandler)?;
410
411        let storage = match &self.config.server.storage {
412            StorageConfig::InMemory => {
413                tracing::warn!(
414                    "Migrations provided but using in-memory storage - migrations ignored"
415                );
416                let push_sender = HttpPushNotificationSender::new()
417                    .with_timeout(30)
418                    .with_max_retries(3);
419                let storage = InMemoryTaskStorage::with_push_sender(push_sender);
420                AutoStorage::InMemory(storage)
421            }
422            StorageConfig::Sqlx {
423                url,
424                enable_logging,
425                ..
426            } => {
427                if *enable_logging {
428                    tracing::info!("SQL query logging enabled");
429                }
430
431                let storage = SqlxTaskStorage::with_migrations(url, migrations)
432                    .await
433                    .map_err(|e| {
434                        BuildError::StorageError(format!("Failed to create SQLx storage: {}", e))
435                    })?;
436
437                AutoStorage::Sqlx(storage)
438            }
439        };
440
441        // Initialize MCP client if configured
442        #[cfg(feature = "mcp-client")]
443        if self.config.features.mcp_client.enabled {
444            info!("Initializing MCP client...");
445            let mcp_client = McpClientManager::new();
446
447            // Initialize connections to configured servers
448            if let Err(e) = mcp_client
449                .initialize(&self.config.features.mcp_client)
450                .await
451            {
452                return Err(BuildError::RuntimeError(format!(
453                    "Failed to initialize MCP client: {}",
454                    e
455                )));
456            }
457
458            return Ok(AgentRuntime::with_mcp_client(
459                self.config,
460                Arc::new(handler),
461                Arc::new(storage),
462                mcp_client,
463            ));
464        }
465
466        Ok(AgentRuntime::new(
467            self.config,
468            Arc::new(handler),
469            Arc::new(storage),
470        ))
471    }
472}
473
474/// Errors that can occur during agent building
475#[derive(Debug, thiserror::Error)]
476pub enum BuildError {
477    #[error("Handler must be set before building")]
478    MissingHandler,
479
480    #[error("Storage must be set before building")]
481    MissingStorage,
482
483    #[error("Configuration error: {0}")]
484    ConfigError(#[from] ConfigError),
485
486    #[error("Storage error: {0}")]
487    StorageError(String),
488
489    #[error("Runtime error: {0}")]
490    RuntimeError(String),
491}
492
#[cfg(test)]
mod tests {
    use super::*;

    /// Values written in the TOML text are readable through `config()`.
    #[test]
    fn test_builder_from_toml() {
        let source = r#"
            [agent]
            name = "Test Agent"

            [server]
            http_port = 9000
        "#;

        let builder = AgentBuilder::from_toml(source).unwrap();
        let config = builder.config();

        assert_eq!(config.agent.name, "Test Agent");
        assert_eq!(config.server.http_port, 9000);
    }

    /// A change made inside `with_config` is visible on the returned builder.
    #[test]
    fn test_builder_config_modification() {
        let source = r#"
            [agent]
            name = "Test Agent"
        "#;

        let parsed = AgentBuilder::from_toml(source).unwrap();
        let builder = parsed.with_config(|cfg| {
            cfg.server.http_port = 7000;
        });

        assert_eq!(builder.config().server.http_port, 7000);
    }
}
523            });
524
525        assert_eq!(builder.config().server.http_port, 7000);
526    }
527}