1#[cfg(feature = "mcp-client")]
7use crate::core::McpClientManager;
8use crate::core::config::{AgentConfig, AuthConfig, StorageConfig};
9use a2a_rs::adapter::{
10 BearerTokenAuthenticator, DefaultRequestProcessor, HttpServer, SimpleAgentInfo, WebSocketServer,
11};
12use a2a_rs::port::{
13 AsyncMessageHandler, AsyncNotificationManager, AsyncStreamingHandler, AsyncTaskManager,
14};
15use std::sync::Arc;
16use tracing::{info, warn};
17
18#[cfg(feature = "auth")]
19use a2a_rs::adapter::{JwtAuthenticator, OAuth2Authenticator};
20#[cfg(feature = "auth")]
21use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
22#[cfg(feature = "auth")]
23use std::collections::HashMap;
24
/// Runtime that exposes an agent over the servers selected by its configuration.
///
/// `H` is the message-handler type and `S` is the storage type; both are held
/// behind an [`Arc`]. The trait bounds they must satisfy are declared on the
/// `impl` block rather than on the struct itself.
pub struct AgentRuntime<H, S> {
    /// Agent configuration (agent identity, server settings, features, skills).
    config: AgentConfig,
    /// Shared message handler; cloned out of the `Arc` when a server is started.
    handler: Arc<H>,
    /// Shared storage backend; cloned out of the `Arc` when a server is started.
    storage: Arc<S>,
    /// Optional MCP client manager. The field only exists when the
    /// `mcp-client` feature is enabled.
    #[cfg(feature = "mcp-client")]
    mcp_client: Option<McpClientManager>,
}
33
34impl<H, S> AgentRuntime<H, S>
35where
36 H: AsyncMessageHandler + Clone + Send + Sync + 'static,
37 S: AsyncTaskManager + AsyncNotificationManager + Clone + Send + Sync + 'static,
38{
39 pub fn new(config: AgentConfig, handler: Arc<H>, storage: Arc<S>) -> Self {
41 Self {
42 config,
43 handler,
44 storage,
45 #[cfg(feature = "mcp-client")]
46 mcp_client: None,
47 }
48 }
49
/// Creates a runtime that also carries the given MCP client manager.
///
/// Only available when the `mcp-client` feature is enabled.
#[cfg(feature = "mcp-client")]
pub fn with_mcp_client(
    config: AgentConfig,
    handler: Arc<H>,
    storage: Arc<S>,
    mcp_client: McpClientManager,
) -> Self {
    // Wrap first so the struct literal can use field-init shorthand throughout.
    let mcp_client = Some(mcp_client);
    AgentRuntime {
        config,
        handler,
        storage,
        mcp_client,
    }
}
65
/// Returns a reference to the MCP client manager, or `None` if the runtime
/// was created without one.
///
/// Only available when the `mcp-client` feature is enabled.
#[cfg(feature = "mcp-client")]
pub fn mcp_client(&self) -> Option<&McpClientManager> {
    Option::as_ref(&self.mcp_client)
}
71
72 fn build_agent_info(&self, base_url: String) -> SimpleAgentInfo {
74 let mut agent_info = SimpleAgentInfo::new(self.config.agent.name.clone(), base_url);
75
76 if let Some(ref description) = self.config.agent.description {
77 agent_info = agent_info.with_description(description.clone());
78 }
79
80 if let Some(ref provider) = self.config.agent.provider {
81 agent_info = agent_info.with_provider(provider.name.clone(), provider.url.clone());
82 }
83
84 if let Some(ref doc_url) = self.config.agent.documentation_url {
85 agent_info = agent_info.with_documentation_url(doc_url.clone());
86 }
87
88 if self.config.features.streaming {
90 agent_info = agent_info.with_streaming();
91 }
92
93 if self.config.features.push_notifications {
94 agent_info = agent_info.with_push_notifications();
95 }
96
97 if self.config.features.state_history {
98 agent_info = agent_info.with_state_transition_history();
99 }
100
101 if self.config.features.authenticated_card {
102 agent_info = agent_info.with_authenticated_extended_card();
103 }
104
105 if let Some(ref ap2_config) = self.config.features.extensions.ap2 {
107 let roles_json: Vec<serde_json::Value> = ap2_config
108 .roles
109 .iter()
110 .map(|r| serde_json::Value::String(r.clone()))
111 .collect();
112
113 let mut params = std::collections::HashMap::new();
114 params.insert("roles".to_string(), serde_json::Value::Array(roles_json));
115
116 let ext = a2a_rs::domain::AgentExtension {
117 uri: "https://github.com/google-agentic-commerce/ap2/tree/v0.1".to_string(),
118 description: Some("Agent Payments Protocol (AP2) v0.1".to_string()),
119 required: Some(ap2_config.required),
120 params: Some(params),
121 };
122
123 agent_info = agent_info.add_extension(ext);
124 info!("💳 AP2 extension enabled (roles: {:?})", ap2_config.roles);
125 }
126
127 for skill in &self.config.skills {
129 agent_info = agent_info.add_comprehensive_skill(
130 skill.id.clone(),
131 skill.name.clone(),
132 skill.description.clone(),
133 if skill.keywords.is_empty() {
134 None
135 } else {
136 Some(skill.keywords.clone())
137 },
138 if skill.examples.is_empty() {
139 None
140 } else {
141 Some(skill.examples.clone())
142 },
143 Some(skill.input_formats.clone()),
144 Some(skill.output_formats.clone()),
145 );
146 }
147
148 agent_info
149 }
150
151 pub async fn start_http(&self) -> Result<(), RuntimeError> {
153 if self.config.server.http_port == 0 {
154 return Err(RuntimeError::ServerNotConfigured(
155 "HTTP port is 0".to_string(),
156 ));
157 }
158
159 let base_url = format!(
160 "http://{}:{}",
161 self.config.server.host, self.config.server.http_port
162 );
163 let agent_info = self.build_agent_info(base_url);
164
165 let processor = DefaultRequestProcessor::new(
166 (*self.handler).clone(),
167 (*self.storage).clone(),
168 (*self.storage).clone(),
169 agent_info.clone(),
170 );
171
172 let bind_address = format!(
173 "{}:{}",
174 self.config.server.host, self.config.server.http_port
175 );
176
177 info!("🌐 Starting HTTP server on {}", bind_address);
178 self.print_agent_info("HTTP", &self.config.server.http_port.to_string());
179
180 match &self.config.server.auth {
181 AuthConfig::None => {
182 let server = HttpServer::new(processor, agent_info, bind_address);
183 server
184 .start()
185 .await
186 .map_err(|e| RuntimeError::ServerError(e.to_string()))
187 }
188 AuthConfig::Bearer { tokens, format } => {
189 info!(
190 "🔐 Authentication: Bearer token ({} token(s){})",
191 tokens.len(),
192 format
193 .as_ref()
194 .map(|f| format!(", format: {}", f))
195 .unwrap_or_default()
196 );
197 let authenticator = BearerTokenAuthenticator::new(tokens.clone());
198 let server =
199 HttpServer::with_auth(processor, agent_info, bind_address, authenticator);
200 server
201 .start()
202 .await
203 .map_err(|e| RuntimeError::ServerError(e.to_string()))
204 }
205 AuthConfig::ApiKey {
206 keys,
207 location,
208 name,
209 } => {
210 warn!(
211 "🔐 API key authentication configured ({} {}, {} key(s)) but not yet supported, using no auth",
212 location,
213 name,
214 keys.len()
215 );
216 let server = HttpServer::new(processor, agent_info, bind_address);
217 server
218 .start()
219 .await
220 .map_err(|e| RuntimeError::ServerError(e.to_string()))
221 }
222 #[cfg(feature = "auth")]
223 AuthConfig::Jwt {
224 secret,
225 rsa_pem_path,
226 algorithm,
227 issuer,
228 audience,
229 } => {
230 info!("🔐 Authentication: JWT (algorithm: {})", algorithm);
231
232 let mut authenticator = if let Some(secret) = secret {
233 JwtAuthenticator::new_with_secret(secret.as_bytes())
234 } else if let Some(pem_path) = rsa_pem_path {
235 let pem_data = std::fs::read(pem_path).map_err(|e| {
236 RuntimeError::ServerError(format!("Failed to read RSA PEM file: {}", e))
237 })?;
238 JwtAuthenticator::new_with_rsa_pem(&pem_data).map_err(|e| {
239 RuntimeError::ServerError(format!(
240 "Failed to create JWT authenticator: {}",
241 e
242 ))
243 })?
244 } else {
245 return Err(RuntimeError::ServerError(
246 "JWT authentication requires either 'secret' or 'rsa_pem_path'".to_string(),
247 ));
248 };
249
250 if let Some(iss) = issuer {
251 authenticator = authenticator.with_issuer(iss.clone());
252 info!(" Issuer: {}", iss);
253 }
254 if let Some(aud) = audience {
255 authenticator = authenticator.with_audience(aud.clone());
256 info!(" Audience: {}", aud);
257 }
258
259 let server =
260 HttpServer::with_auth(processor, agent_info, bind_address, authenticator);
261 server
262 .start()
263 .await
264 .map_err(|e| RuntimeError::ServerError(e.to_string()))
265 }
266 #[cfg(not(feature = "auth"))]
267 AuthConfig::Jwt { .. } => Err(RuntimeError::ServerError(
268 "JWT authentication requires the 'auth' feature to be enabled".to_string(),
269 )),
270 #[cfg(feature = "auth")]
271 AuthConfig::OAuth2 {
272 client_id,
273 client_secret,
274 authorization_url,
275 token_url,
276 redirect_url,
277 flow,
278 scopes,
279 } => {
280 info!("🔐 Authentication: OAuth2 (flow: {})", flow);
281 info!(" Authorization URL: {}", authorization_url);
282 info!(" Token URL: {}", token_url);
283
284 let client_id = ClientId::new(client_id.clone());
285 let client_secret = ClientSecret::new(client_secret.clone());
286 let auth_url = AuthUrl::new(authorization_url.clone()).map_err(|e| {
287 RuntimeError::ServerError(format!("Invalid authorization URL: {}", e))
288 })?;
289 let token_url = TokenUrl::new(token_url.clone())
290 .map_err(|e| RuntimeError::ServerError(format!("Invalid token URL: {}", e)))?;
291
292 let scopes_map: HashMap<String, String> =
293 scopes.iter().map(|s| (s.clone(), s.clone())).collect();
294
295 let authenticator = if flow == "client_credentials" {
296 OAuth2Authenticator::new_client_credentials(
297 client_id,
298 client_secret,
299 token_url,
300 scopes_map,
301 )
302 } else {
303 let redirect_url = RedirectUrl::new(
305 redirect_url
306 .clone()
307 .unwrap_or_else(|| "http://localhost:8080/callback".to_string()),
308 )
309 .map_err(|e| {
310 RuntimeError::ServerError(format!("Invalid redirect URL: {}", e))
311 })?;
312
313 info!(" Redirect URL: {}", redirect_url.as_str());
314
315 OAuth2Authenticator::new_authorization_code(
316 client_id,
317 Some(client_secret),
318 auth_url,
319 token_url,
320 redirect_url,
321 scopes_map,
322 )
323 };
324
325 let server =
326 HttpServer::with_auth(processor, agent_info, bind_address, authenticator);
327 server
328 .start()
329 .await
330 .map_err(|e| RuntimeError::ServerError(e.to_string()))
331 }
332 #[cfg(not(feature = "auth"))]
333 AuthConfig::OAuth2 { .. } => Err(RuntimeError::ServerError(
334 "OAuth2 authentication requires the 'auth' feature to be enabled".to_string(),
335 )),
336 }
337 }
338
339 pub async fn start_websocket(&self) -> Result<(), RuntimeError>
341 where
342 S: AsyncStreamingHandler,
343 {
344 if self.config.server.ws_port == 0 {
345 return Err(RuntimeError::ServerNotConfigured(
346 "WebSocket port is 0".to_string(),
347 ));
348 }
349
350 let base_url = format!(
351 "ws://{}:{}",
352 self.config.server.host, self.config.server.ws_port
353 );
354 let agent_info = self.build_agent_info(base_url);
355
356 let processor = DefaultRequestProcessor::new(
357 (*self.handler).clone(),
358 (*self.storage).clone(),
359 (*self.storage).clone(),
360 agent_info.clone(),
361 );
362
363 let bind_address = format!("{}:{}", self.config.server.host, self.config.server.ws_port);
364
365 info!("🔌 Starting WebSocket server on {}", bind_address);
366 self.print_agent_info("WebSocket", &self.config.server.ws_port.to_string());
367
368 match &self.config.server.auth {
369 AuthConfig::None => {
370 let server = WebSocketServer::new(
371 processor,
372 agent_info,
373 (*self.storage).clone(),
374 bind_address,
375 );
376 server
377 .start()
378 .await
379 .map_err(|e| RuntimeError::ServerError(e.to_string()))
380 }
381 AuthConfig::Bearer { tokens, format } => {
382 info!(
383 "🔐 Authentication: Bearer token ({} token(s){})",
384 tokens.len(),
385 format
386 .as_ref()
387 .map(|f| format!(", format: {}", f))
388 .unwrap_or_default()
389 );
390 let authenticator = BearerTokenAuthenticator::new(tokens.clone());
391 let server = WebSocketServer::with_auth(
392 processor,
393 agent_info,
394 (*self.storage).clone(),
395 bind_address,
396 authenticator,
397 );
398 server
399 .start()
400 .await
401 .map_err(|e| RuntimeError::ServerError(e.to_string()))
402 }
403 AuthConfig::ApiKey {
404 keys,
405 location,
406 name,
407 } => {
408 warn!(
409 "🔐 API key authentication configured ({} {}, {} key(s)) but not yet supported, using no auth",
410 location,
411 name,
412 keys.len()
413 );
414 let server = WebSocketServer::new(
415 processor,
416 agent_info,
417 (*self.storage).clone(),
418 bind_address,
419 );
420 server
421 .start()
422 .await
423 .map_err(|e| RuntimeError::ServerError(e.to_string()))
424 }
425 #[cfg(feature = "auth")]
426 AuthConfig::Jwt {
427 secret,
428 rsa_pem_path,
429 algorithm,
430 issuer,
431 audience,
432 } => {
433 info!("🔐 Authentication: JWT (algorithm: {})", algorithm);
434
435 let mut authenticator = if let Some(secret) = secret {
436 JwtAuthenticator::new_with_secret(secret.as_bytes())
437 } else if let Some(pem_path) = rsa_pem_path {
438 let pem_data = std::fs::read(pem_path).map_err(|e| {
439 RuntimeError::ServerError(format!("Failed to read RSA PEM file: {}", e))
440 })?;
441 JwtAuthenticator::new_with_rsa_pem(&pem_data).map_err(|e| {
442 RuntimeError::ServerError(format!(
443 "Failed to create JWT authenticator: {}",
444 e
445 ))
446 })?
447 } else {
448 return Err(RuntimeError::ServerError(
449 "JWT authentication requires either 'secret' or 'rsa_pem_path'".to_string(),
450 ));
451 };
452
453 if let Some(iss) = issuer {
454 authenticator = authenticator.with_issuer(iss.clone());
455 info!(" Issuer: {}", iss);
456 }
457 if let Some(aud) = audience {
458 authenticator = authenticator.with_audience(aud.clone());
459 info!(" Audience: {}", aud);
460 }
461
462 let server = WebSocketServer::with_auth(
463 processor,
464 agent_info,
465 (*self.storage).clone(),
466 bind_address,
467 authenticator,
468 );
469 server
470 .start()
471 .await
472 .map_err(|e| RuntimeError::ServerError(e.to_string()))
473 }
474 #[cfg(not(feature = "auth"))]
475 AuthConfig::Jwt { .. } => Err(RuntimeError::ServerError(
476 "JWT authentication requires the 'auth' feature to be enabled".to_string(),
477 )),
478 #[cfg(feature = "auth")]
479 AuthConfig::OAuth2 {
480 client_id,
481 client_secret,
482 authorization_url,
483 token_url,
484 redirect_url,
485 flow,
486 scopes,
487 } => {
488 info!("🔐 Authentication: OAuth2 (flow: {})", flow);
489
490 let client_id = ClientId::new(client_id.clone());
491 let client_secret = ClientSecret::new(client_secret.clone());
492 let auth_url = AuthUrl::new(authorization_url.clone()).map_err(|e| {
493 RuntimeError::ServerError(format!("Invalid authorization URL: {}", e))
494 })?;
495 let token_url = TokenUrl::new(token_url.clone())
496 .map_err(|e| RuntimeError::ServerError(format!("Invalid token URL: {}", e)))?;
497
498 let scopes_map: HashMap<String, String> =
499 scopes.iter().map(|s| (s.clone(), s.clone())).collect();
500
501 let authenticator = if flow == "client_credentials" {
502 OAuth2Authenticator::new_client_credentials(
503 client_id,
504 client_secret,
505 token_url,
506 scopes_map,
507 )
508 } else {
509 let redirect_url = RedirectUrl::new(
510 redirect_url
511 .clone()
512 .unwrap_or_else(|| "http://localhost:8080/callback".to_string()),
513 )
514 .map_err(|e| {
515 RuntimeError::ServerError(format!("Invalid redirect URL: {}", e))
516 })?;
517
518 OAuth2Authenticator::new_authorization_code(
519 client_id,
520 Some(client_secret),
521 auth_url,
522 token_url,
523 redirect_url,
524 scopes_map,
525 )
526 };
527
528 let server = WebSocketServer::with_auth(
529 processor,
530 agent_info,
531 (*self.storage).clone(),
532 bind_address,
533 authenticator,
534 );
535 server
536 .start()
537 .await
538 .map_err(|e| RuntimeError::ServerError(e.to_string()))
539 }
540 #[cfg(not(feature = "auth"))]
541 AuthConfig::OAuth2 { .. } => Err(RuntimeError::ServerError(
542 "OAuth2 authentication requires the 'auth' feature to be enabled".to_string(),
543 )),
544 }
545 }
546
547 pub async fn start_all(&self) -> Result<(), RuntimeError>
549 where
550 S: AsyncStreamingHandler,
551 {
552 info!("🚀 Starting {} agent...", self.config.agent.name);
553 info!("🔄 Starting both HTTP and WebSocket servers");
554
555 if self.config.server.http_port == 0 && self.config.server.ws_port == 0 {
556 return Err(RuntimeError::ServerNotConfigured(
557 "Both HTTP and WebSocket ports are 0".to_string(),
558 ));
559 }
560
561 let http_runtime = Self {
563 config: self.config.clone(),
564 handler: Arc::clone(&self.handler),
565 storage: Arc::clone(&self.storage),
566 #[cfg(feature = "mcp-client")]
567 mcp_client: self.mcp_client.clone(),
568 };
569
570 let ws_runtime = Self {
571 config: self.config.clone(),
572 handler: Arc::clone(&self.handler),
573 storage: Arc::clone(&self.storage),
574 #[cfg(feature = "mcp-client")]
575 mcp_client: self.mcp_client.clone(),
576 };
577
578 let http_handle = if self.config.server.http_port > 0 {
580 Some(tokio::spawn(async move {
581 if let Err(e) = http_runtime.start_http().await {
582 tracing::error!("❌ HTTP server error: {}", e);
583 }
584 }))
585 } else {
586 None
587 };
588
589 let ws_handle = if self.config.server.ws_port > 0 {
590 Some(tokio::spawn(async move {
591 if let Err(e) = ws_runtime.start_websocket().await {
592 tracing::error!("❌ WebSocket server error: {}", e);
593 }
594 }))
595 } else {
596 None
597 };
598
599 match (http_handle, ws_handle) {
601 (Some(http), Some(ws)) => {
602 tokio::select! {
603 _ = http => info!("HTTP server stopped"),
604 _ = ws => info!("WebSocket server stopped"),
605 }
606 }
607 (Some(http), None) => {
608 let _ = http.await;
609 info!("HTTP server stopped");
610 }
611 (None, Some(ws)) => {
612 let _ = ws.await;
613 info!("WebSocket server stopped");
614 }
615 (None, None) => {
616 return Err(RuntimeError::ServerNotConfigured(
617 "No servers configured".to_string(),
618 ));
619 }
620 }
621
622 Ok(())
623 }
624
625 pub async fn run(self) -> Result<(), RuntimeError>
627 where
628 S: AsyncStreamingHandler,
629 {
630 if self.config.features.mcp_server.enabled {
632 return self.run_as_mcp_server().await;
633 }
634
635 if self.config.server.http_port > 0 && self.config.server.ws_port > 0 {
637 self.start_all().await
638 } else if self.config.server.http_port > 0 {
639 self.start_http().await
640 } else if self.config.server.ws_port > 0 {
641 self.start_websocket().await
642 } else {
643 Err(RuntimeError::ServerNotConfigured(
644 "No servers configured".to_string(),
645 ))
646 }
647 }
648
649 async fn run_as_mcp_server(self) -> Result<(), RuntimeError> {
651 use crate::core::mcp;
652 use a2a_rs::services::AgentInfoProvider;
653
654 info!("🔌 Running agent in MCP server mode");
655
656 let base_url = format!(
658 "http://{}:{}",
659 self.config.server.host, self.config.server.http_port
660 );
661 let agent_info = self.build_agent_info(base_url.clone());
662 let agent_card = agent_info
663 .get_agent_card()
664 .await
665 .map_err(|e| RuntimeError::ServerError(format!("Failed to get agent card: {}", e)))?;
666
667 mcp::run_mcp_server(&self.config.features.mcp_server, agent_card, base_url)
669 .await
670 .map_err(|e| RuntimeError::ServerError(format!("MCP server error: {}", e)))
671 }
672
673 fn print_agent_info(&self, server_type: &str, port: &str) {
675 info!("📋 Agent: {}", self.config.agent.name);
676 if let Some(ref desc) = self.config.agent.description {
677 info!(" Description: {}", desc);
678 }
679 info!(" {} port: {}", server_type, port);
680
681 match &self.config.server.storage {
682 StorageConfig::InMemory => info!("💾 Storage: In-memory (non-persistent)"),
683 StorageConfig::Sqlx { url, .. } => info!("💾 Storage: SQLx ({})", url),
684 }
685
686 if !self.config.skills.is_empty() {
687 info!("🛠️ Skills: {}", self.config.skills.len());
688 for skill in &self.config.skills {
689 info!(" - {} ({})", skill.name, skill.id);
690 }
691 }
692 }
693}
694
/// Errors reported by [`AgentRuntime`].
///
/// Every variant carries a human-readable message that is embedded in the
/// `Display` output generated by `thiserror`.
#[derive(Debug, thiserror::Error)]
pub enum RuntimeError {
    /// No usable server is configured, e.g. the relevant port is `0`.
    #[error("Server not configured: {0}")]
    ServerNotConfigured(String),

    /// A server could not be set up or failed; also used for invalid
    /// authentication settings and MCP server failures.
    #[error("Server error: {0}")]
    ServerError(String),

    /// A storage-related failure.
    // NOTE(review): not constructed anywhere in the visible part of this file;
    // presumably used by other modules — confirm before removing.
    #[error("Storage error: {0}")]
    StorageError(String),
}