// a2a_agents/core/runtime.rs

//! Agent runtime that manages the server lifecycle.
//!
//! The runtime starts the HTTP and/or WebSocket servers (or an MCP server),
//! wires the message handler and storage into them, and manages the agent
//! lifecycle based on configuration.
5
6#[cfg(feature = "mcp-client")]
7use crate::core::McpClientManager;
8use crate::core::config::{AgentConfig, AuthConfig, StorageConfig};
9use a2a_rs::adapter::{
10    BearerTokenAuthenticator, DefaultRequestProcessor, HttpServer, SimpleAgentInfo, WebSocketServer,
11};
12use a2a_rs::port::{
13    AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager,
14};
15use std::sync::Arc;
16use tracing::{info, warn};
17
18#[cfg(feature = "auth")]
19use a2a_rs::adapter::{JwtAuthenticator, OAuth2Authenticator};
20#[cfg(feature = "auth")]
21use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
22#[cfg(feature = "auth")]
23use std::collections::HashMap;
24
25/// Agent runtime that manages the server lifecycle
26pub struct AgentRuntime<H, S> {
27    config: AgentConfig,
28    handler: Arc<H>,
29    storage: Arc<S>,
30    #[cfg(feature = "mcp-client")]
31    mcp_client: Option<McpClientManager>,
32}
33
34impl<H, S> AgentRuntime<H, S>
35where
36    H: AsyncMessageHandler + Clone + Send + Sync + 'static,
37    S: AsyncTaskManager + AsyncNotificationManager + Clone + Send + Sync + 'static,
38{
39    /// Create a new runtime
40    pub fn new(config: AgentConfig, handler: Arc<H>, storage: Arc<S>) -> Self {
41        Self {
42            config,
43            handler,
44            storage,
45            #[cfg(feature = "mcp-client")]
46            mcp_client: None,
47        }
48    }
49
50    /// Create a new runtime with MCP client
51    #[cfg(feature = "mcp-client")]
52    pub fn with_mcp_client(
53        config: AgentConfig,
54        handler: Arc<H>,
55        storage: Arc<S>,
56        mcp_client: McpClientManager,
57    ) -> Self {
58        Self {
59            config,
60            handler,
61            storage,
62            mcp_client: Some(mcp_client),
63        }
64    }
65
66    /// Get the MCP client manager (if enabled)
67    #[cfg(feature = "mcp-client")]
68    pub fn mcp_client(&self) -> Option<&McpClientManager> {
69        self.mcp_client.as_ref()
70    }
71
72    /// Build agent info from configuration
73    fn build_agent_info(&self, base_url: String) -> SimpleAgentInfo {
74        let mut agent_info = SimpleAgentInfo::new(self.config.agent.name.clone(), base_url);
75
76        if let Some(ref description) = self.config.agent.description {
77            agent_info = agent_info.with_description(description.clone());
78        }
79
80        if let Some(ref provider) = self.config.agent.provider {
81            agent_info = agent_info.with_provider(provider.name.clone(), provider.url.clone());
82        }
83
84        if let Some(ref doc_url) = self.config.agent.documentation_url {
85            agent_info = agent_info.with_documentation_url(doc_url.clone());
86        }
87
88        // Add features
89        if self.config.features.streaming {
90            agent_info = agent_info.with_streaming();
91        }
92
93        if self.config.features.push_notifications {
94            agent_info = agent_info.with_push_notifications();
95        }
96
97        if self.config.features.state_history {
98            agent_info = agent_info.with_state_transition_history();
99        }
100
101        if self.config.features.authenticated_card {
102            agent_info = agent_info.with_authenticated_extended_card();
103        }
104
105        // Add extensions
106        if let Some(ref ap2_config) = self.config.features.extensions.ap2 {
107            let roles_json: Vec<serde_json::Value> = ap2_config
108                .roles
109                .iter()
110                .map(|r| serde_json::Value::String(r.clone()))
111                .collect();
112
113            let mut params = std::collections::HashMap::new();
114            params.insert("roles".to_string(), serde_json::Value::Array(roles_json));
115
116            let ext = a2a_rs::domain::AgentExtension {
117                uri: "https://github.com/google-agentic-commerce/ap2/tree/v0.1".to_string(),
118                description: Some("Agent Payments Protocol (AP2) v0.1".to_string()),
119                required: Some(ap2_config.required),
120                params: Some(params),
121            };
122
123            agent_info = agent_info.add_extension(ext);
124            info!("💳 AP2 extension enabled (roles: {:?})", ap2_config.roles);
125        }
126
127        // Add skills
128        for skill in &self.config.skills {
129            agent_info = agent_info.add_comprehensive_skill(
130                skill.id.clone(),
131                skill.name.clone(),
132                skill.description.clone(),
133                if skill.keywords.is_empty() {
134                    None
135                } else {
136                    Some(skill.keywords.clone())
137                },
138                if skill.examples.is_empty() {
139                    None
140                } else {
141                    Some(skill.examples.clone())
142                },
143                Some(skill.input_formats.clone()),
144                Some(skill.output_formats.clone()),
145            );
146        }
147
148        agent_info
149    }
150
151    /// Start HTTP server
152    pub async fn start_http(&self) -> Result<(), RuntimeError> {
153        if self.config.server.http_port == 0 {
154            return Err(RuntimeError::ServerNotConfigured(
155                "HTTP port is 0".to_string(),
156            ));
157        }
158
159        let base_url = format!(
160            "http://{}:{}",
161            self.config.server.host, self.config.server.http_port
162        );
163        let agent_info = self.build_agent_info(base_url);
164
165        let processor = DefaultRequestProcessor::new(
166            (*self.handler).clone(),
167            (*self.storage).clone(),
168            (*self.storage).clone(),
169            agent_info.clone(),
170        );
171
172        let bind_address = format!(
173            "{}:{}",
174            self.config.server.host, self.config.server.http_port
175        );
176
177        info!("🌐 Starting HTTP server on {}", bind_address);
178        self.print_agent_info("HTTP", &self.config.server.http_port.to_string());
179
180        match &self.config.server.auth {
181            AuthConfig::None => {
182                let server = HttpServer::new(processor, agent_info, bind_address);
183                server
184                    .start()
185                    .await
186                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
187            }
188            AuthConfig::Bearer { tokens, format } => {
189                info!(
190                    "🔐 Authentication: Bearer token ({} token(s){})",
191                    tokens.len(),
192                    format
193                        .as_ref()
194                        .map(|f| format!(", format: {}", f))
195                        .unwrap_or_default()
196                );
197                let authenticator = BearerTokenAuthenticator::new(tokens.clone());
198                let server =
199                    HttpServer::with_auth(processor, agent_info, bind_address, authenticator);
200                server
201                    .start()
202                    .await
203                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
204            }
205            AuthConfig::ApiKey {
206                keys,
207                location,
208                name,
209            } => {
210                warn!(
211                    "🔐 API key authentication configured ({} {}, {} key(s)) but not yet supported, using no auth",
212                    location,
213                    name,
214                    keys.len()
215                );
216                let server = HttpServer::new(processor, agent_info, bind_address);
217                server
218                    .start()
219                    .await
220                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
221            }
222            #[cfg(feature = "auth")]
223            AuthConfig::Jwt {
224                secret,
225                rsa_pem_path,
226                algorithm,
227                issuer,
228                audience,
229            } => {
230                info!("🔐 Authentication: JWT (algorithm: {})", algorithm);
231
232                let mut authenticator = if let Some(secret) = secret {
233                    JwtAuthenticator::new_with_secret(secret.as_bytes())
234                } else if let Some(pem_path) = rsa_pem_path {
235                    let pem_data = std::fs::read(pem_path).map_err(|e| {
236                        RuntimeError::ServerError(format!("Failed to read RSA PEM file: {}", e))
237                    })?;
238                    JwtAuthenticator::new_with_rsa_pem(&pem_data).map_err(|e| {
239                        RuntimeError::ServerError(format!(
240                            "Failed to create JWT authenticator: {}",
241                            e
242                        ))
243                    })?
244                } else {
245                    return Err(RuntimeError::ServerError(
246                        "JWT authentication requires either 'secret' or 'rsa_pem_path'".to_string(),
247                    ));
248                };
249
250                if let Some(iss) = issuer {
251                    authenticator = authenticator.with_issuer(iss.clone());
252                    info!("   Issuer: {}", iss);
253                }
254                if let Some(aud) = audience {
255                    authenticator = authenticator.with_audience(aud.clone());
256                    info!("   Audience: {}", aud);
257                }
258
259                let server =
260                    HttpServer::with_auth(processor, agent_info, bind_address, authenticator);
261                server
262                    .start()
263                    .await
264                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
265            }
266            #[cfg(not(feature = "auth"))]
267            AuthConfig::Jwt { .. } => Err(RuntimeError::ServerError(
268                "JWT authentication requires the 'auth' feature to be enabled".to_string(),
269            )),
270            #[cfg(feature = "auth")]
271            AuthConfig::OAuth2 {
272                client_id,
273                client_secret,
274                authorization_url,
275                token_url,
276                redirect_url,
277                flow,
278                scopes,
279            } => {
280                info!("🔐 Authentication: OAuth2 (flow: {})", flow);
281                info!("   Authorization URL: {}", authorization_url);
282                info!("   Token URL: {}", token_url);
283
284                let client_id = ClientId::new(client_id.clone());
285                let client_secret = ClientSecret::new(client_secret.clone());
286                let auth_url = AuthUrl::new(authorization_url.clone()).map_err(|e| {
287                    RuntimeError::ServerError(format!("Invalid authorization URL: {}", e))
288                })?;
289                let token_url = TokenUrl::new(token_url.clone())
290                    .map_err(|e| RuntimeError::ServerError(format!("Invalid token URL: {}", e)))?;
291
292                let scopes_map: HashMap<String, String> =
293                    scopes.iter().map(|s| (s.clone(), s.clone())).collect();
294
295                let authenticator = if flow == "client_credentials" {
296                    OAuth2Authenticator::new_client_credentials(
297                        client_id,
298                        client_secret,
299                        token_url,
300                        scopes_map,
301                    )
302                } else {
303                    // Authorization code flow
304                    let redirect_url = RedirectUrl::new(
305                        redirect_url
306                            .clone()
307                            .unwrap_or_else(|| "http://localhost:8080/callback".to_string()),
308                    )
309                    .map_err(|e| {
310                        RuntimeError::ServerError(format!("Invalid redirect URL: {}", e))
311                    })?;
312
313                    info!("   Redirect URL: {}", redirect_url.as_str());
314
315                    OAuth2Authenticator::new_authorization_code(
316                        client_id,
317                        Some(client_secret),
318                        auth_url,
319                        token_url,
320                        redirect_url,
321                        scopes_map,
322                    )
323                };
324
325                let server =
326                    HttpServer::with_auth(processor, agent_info, bind_address, authenticator);
327                server
328                    .start()
329                    .await
330                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
331            }
332            #[cfg(not(feature = "auth"))]
333            AuthConfig::OAuth2 { .. } => Err(RuntimeError::ServerError(
334                "OAuth2 authentication requires the 'auth' feature to be enabled".to_string(),
335            )),
336        }
337    }
338
339    /// Start WebSocket server
340    pub async fn start_websocket(&self) -> Result<(), RuntimeError>
341    where
342        S: AsyncStreamingHandler,
343    {
344        if self.config.server.ws_port == 0 {
345            return Err(RuntimeError::ServerNotConfigured(
346                "WebSocket port is 0".to_string(),
347            ));
348        }
349
350        let base_url = format!(
351            "ws://{}:{}",
352            self.config.server.host, self.config.server.ws_port
353        );
354        let agent_info = self.build_agent_info(base_url);
355
356        let processor = DefaultRequestProcessor::new(
357            (*self.handler).clone(),
358            (*self.storage).clone(),
359            (*self.storage).clone(),
360            agent_info.clone(),
361        );
362
363        let bind_address = format!("{}:{}", self.config.server.host, self.config.server.ws_port);
364
365        info!("🔌 Starting WebSocket server on {}", bind_address);
366        self.print_agent_info("WebSocket", &self.config.server.ws_port.to_string());
367
368        match &self.config.server.auth {
369            AuthConfig::None => {
370                let server = WebSocketServer::new(
371                    processor,
372                    agent_info,
373                    (*self.storage).clone(),
374                    bind_address,
375                );
376                server
377                    .start()
378                    .await
379                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
380            }
381            AuthConfig::Bearer { tokens, format } => {
382                info!(
383                    "🔐 Authentication: Bearer token ({} token(s){})",
384                    tokens.len(),
385                    format
386                        .as_ref()
387                        .map(|f| format!(", format: {}", f))
388                        .unwrap_or_default()
389                );
390                let authenticator = BearerTokenAuthenticator::new(tokens.clone());
391                let server = WebSocketServer::with_auth(
392                    processor,
393                    agent_info,
394                    (*self.storage).clone(),
395                    bind_address,
396                    authenticator,
397                );
398                server
399                    .start()
400                    .await
401                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
402            }
403            AuthConfig::ApiKey {
404                keys,
405                location,
406                name,
407            } => {
408                warn!(
409                    "🔐 API key authentication configured ({} {}, {} key(s)) but not yet supported, using no auth",
410                    location,
411                    name,
412                    keys.len()
413                );
414                let server = WebSocketServer::new(
415                    processor,
416                    agent_info,
417                    (*self.storage).clone(),
418                    bind_address,
419                );
420                server
421                    .start()
422                    .await
423                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
424            }
425            #[cfg(feature = "auth")]
426            AuthConfig::Jwt {
427                secret,
428                rsa_pem_path,
429                algorithm,
430                issuer,
431                audience,
432            } => {
433                info!("🔐 Authentication: JWT (algorithm: {})", algorithm);
434
435                let mut authenticator = if let Some(secret) = secret {
436                    JwtAuthenticator::new_with_secret(secret.as_bytes())
437                } else if let Some(pem_path) = rsa_pem_path {
438                    let pem_data = std::fs::read(pem_path).map_err(|e| {
439                        RuntimeError::ServerError(format!("Failed to read RSA PEM file: {}", e))
440                    })?;
441                    JwtAuthenticator::new_with_rsa_pem(&pem_data).map_err(|e| {
442                        RuntimeError::ServerError(format!(
443                            "Failed to create JWT authenticator: {}",
444                            e
445                        ))
446                    })?
447                } else {
448                    return Err(RuntimeError::ServerError(
449                        "JWT authentication requires either 'secret' or 'rsa_pem_path'".to_string(),
450                    ));
451                };
452
453                if let Some(iss) = issuer {
454                    authenticator = authenticator.with_issuer(iss.clone());
455                    info!("   Issuer: {}", iss);
456                }
457                if let Some(aud) = audience {
458                    authenticator = authenticator.with_audience(aud.clone());
459                    info!("   Audience: {}", aud);
460                }
461
462                let server = WebSocketServer::with_auth(
463                    processor,
464                    agent_info,
465                    (*self.storage).clone(),
466                    bind_address,
467                    authenticator,
468                );
469                server
470                    .start()
471                    .await
472                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
473            }
474            #[cfg(not(feature = "auth"))]
475            AuthConfig::Jwt { .. } => Err(RuntimeError::ServerError(
476                "JWT authentication requires the 'auth' feature to be enabled".to_string(),
477            )),
478            #[cfg(feature = "auth")]
479            AuthConfig::OAuth2 {
480                client_id,
481                client_secret,
482                authorization_url,
483                token_url,
484                redirect_url,
485                flow,
486                scopes,
487            } => {
488                info!("🔐 Authentication: OAuth2 (flow: {})", flow);
489
490                let client_id = ClientId::new(client_id.clone());
491                let client_secret = ClientSecret::new(client_secret.clone());
492                let auth_url = AuthUrl::new(authorization_url.clone()).map_err(|e| {
493                    RuntimeError::ServerError(format!("Invalid authorization URL: {}", e))
494                })?;
495                let token_url = TokenUrl::new(token_url.clone())
496                    .map_err(|e| RuntimeError::ServerError(format!("Invalid token URL: {}", e)))?;
497
498                let scopes_map: HashMap<String, String> =
499                    scopes.iter().map(|s| (s.clone(), s.clone())).collect();
500
501                let authenticator = if flow == "client_credentials" {
502                    OAuth2Authenticator::new_client_credentials(
503                        client_id,
504                        client_secret,
505                        token_url,
506                        scopes_map,
507                    )
508                } else {
509                    let redirect_url = RedirectUrl::new(
510                        redirect_url
511                            .clone()
512                            .unwrap_or_else(|| "http://localhost:8080/callback".to_string()),
513                    )
514                    .map_err(|e| {
515                        RuntimeError::ServerError(format!("Invalid redirect URL: {}", e))
516                    })?;
517
518                    OAuth2Authenticator::new_authorization_code(
519                        client_id,
520                        Some(client_secret),
521                        auth_url,
522                        token_url,
523                        redirect_url,
524                        scopes_map,
525                    )
526                };
527
528                let server = WebSocketServer::with_auth(
529                    processor,
530                    agent_info,
531                    (*self.storage).clone(),
532                    bind_address,
533                    authenticator,
534                );
535                server
536                    .start()
537                    .await
538                    .map_err(|e| RuntimeError::ServerError(e.to_string()))
539            }
540            #[cfg(not(feature = "auth"))]
541            AuthConfig::OAuth2 { .. } => Err(RuntimeError::ServerError(
542                "OAuth2 authentication requires the 'auth' feature to be enabled".to_string(),
543            )),
544        }
545    }
546
547    /// Start both HTTP and WebSocket servers
548    pub async fn start_all(&self) -> Result<(), RuntimeError>
549    where
550        S: AsyncStreamingHandler,
551    {
552        info!("🚀 Starting {} agent...", self.config.agent.name);
553        info!("🔄 Starting both HTTP and WebSocket servers");
554
555        if self.config.server.http_port == 0 && self.config.server.ws_port == 0 {
556            return Err(RuntimeError::ServerNotConfigured(
557                "Both HTTP and WebSocket ports are 0".to_string(),
558            ));
559        }
560
561        // Clone what we need for the tasks
562        let http_runtime = Self {
563            config: self.config.clone(),
564            handler: Arc::clone(&self.handler),
565            storage: Arc::clone(&self.storage),
566            #[cfg(feature = "mcp-client")]
567            mcp_client: self.mcp_client.clone(),
568        };
569
570        let ws_runtime = Self {
571            config: self.config.clone(),
572            handler: Arc::clone(&self.handler),
573            storage: Arc::clone(&self.storage),
574            #[cfg(feature = "mcp-client")]
575            mcp_client: self.mcp_client.clone(),
576        };
577
578        // Start both servers concurrently
579        let http_handle = if self.config.server.http_port > 0 {
580            Some(tokio::spawn(async move {
581                if let Err(e) = http_runtime.start_http().await {
582                    tracing::error!("❌ HTTP server error: {}", e);
583                }
584            }))
585        } else {
586            None
587        };
588
589        let ws_handle = if self.config.server.ws_port > 0 {
590            Some(tokio::spawn(async move {
591                if let Err(e) = ws_runtime.start_websocket().await {
592                    tracing::error!("❌ WebSocket server error: {}", e);
593                }
594            }))
595        } else {
596            None
597        };
598
599        // Wait for either server to complete (they should run indefinitely)
600        match (http_handle, ws_handle) {
601            (Some(http), Some(ws)) => {
602                tokio::select! {
603                    _ = http => info!("HTTP server stopped"),
604                    _ = ws => info!("WebSocket server stopped"),
605                }
606            }
607            (Some(http), None) => {
608                let _ = http.await;
609                info!("HTTP server stopped");
610            }
611            (None, Some(ws)) => {
612                let _ = ws.await;
613                info!("WebSocket server stopped");
614            }
615            (None, None) => {
616                return Err(RuntimeError::ServerNotConfigured(
617                    "No servers configured".to_string(),
618                ));
619            }
620        }
621
622        Ok(())
623    }
624
625    /// Start the appropriate server(s) based on configuration
626    pub async fn run(self) -> Result<(), RuntimeError>
627    where
628        S: AsyncStreamingHandler,
629    {
630        // Check if MCP server mode is enabled
631        if self.config.features.mcp_server.enabled {
632            return self.run_as_mcp_server().await;
633        }
634
635        // Normal A2A server mode
636        if self.config.server.http_port > 0 && self.config.server.ws_port > 0 {
637            self.start_all().await
638        } else if self.config.server.http_port > 0 {
639            self.start_http().await
640        } else if self.config.server.ws_port > 0 {
641            self.start_websocket().await
642        } else {
643            Err(RuntimeError::ServerNotConfigured(
644                "No servers configured".to_string(),
645            ))
646        }
647    }
648
649    /// Run agent as MCP server
650    async fn run_as_mcp_server(self) -> Result<(), RuntimeError> {
651        use crate::core::mcp;
652        use a2a_rs::services::AgentInfoProvider;
653
654        info!("🔌 Running agent in MCP server mode");
655
656        // Build agent card
657        let base_url = format!(
658            "http://{}:{}",
659            self.config.server.host, self.config.server.http_port
660        );
661        let agent_info = self.build_agent_info(base_url.clone());
662        let agent_card = agent_info
663            .get_agent_card()
664            .await
665            .map_err(|e| RuntimeError::ServerError(format!("Failed to get agent card: {}", e)))?;
666
667        // Run MCP server
668        mcp::run_mcp_server(&self.config.features.mcp_server, agent_card, base_url)
669            .await
670            .map_err(|e| RuntimeError::ServerError(format!("MCP server error: {}", e)))
671    }
672
673    /// Print agent information
674    fn print_agent_info(&self, server_type: &str, port: &str) {
675        info!("📋 Agent: {}", self.config.agent.name);
676        if let Some(ref desc) = self.config.agent.description {
677            info!("   Description: {}", desc);
678        }
679        info!("   {} port: {}", server_type, port);
680
681        match &self.config.server.storage {
682            StorageConfig::InMemory => info!("💾 Storage: In-memory (non-persistent)"),
683            StorageConfig::Sqlx { url, .. } => info!("💾 Storage: SQLx ({})", url),
684        }
685
686        if !self.config.skills.is_empty() {
687            info!("🛠️  Skills: {}", self.config.skills.len());
688            for skill in &self.config.skills {
689                info!("   - {} ({})", skill.name, skill.id);
690            }
691        }
692    }
693}
694
695/// Runtime errors
696#[derive(Debug, thiserror::Error)]
697pub enum RuntimeError {
698    #[error("Server not configured: {0}")]
699    ServerNotConfigured(String),
700
701    #[error("Server error: {0}")]
702    ServerError(String),
703
704    #[error("Storage error: {0}")]
705    StorageError(String),
706}