// custom_mcp_server_example/custom_mcp_server_example.rs

//! Custom MCP Server Example with OAuth 2.0 and DynamoDB
//!
//! This example demonstrates how to create a custom MCP server implementation
//! and use it with the microkernel architecture. It shows how users can extend
//! the framework by implementing their own MCP servers with custom tools and
//! capabilities while leveraging the existing OAuth and storage infrastructure.
//!
//! # Features
//! - Custom MCP server implementation with specialized tools
//! - AWS Cognito OAuth 2.0 authentication
//! - DynamoDB persistent storage for OAuth tokens and client data
//! - MCP over HTTP (streamable)
//! - MCP over Server-Sent Events (SSE)
//! - Microkernel architecture with independent, composable handlers
//!
//! # Custom MCP Server Features
//! The custom server implements:
//! - File system operations (list, read, write files)
//! - Basic system information (timestamp, working directory, environment
//!   variable count, OS, and architecture)
//! - Text processing utilities (word count, text search)
//! - Time and date utilities
//!
//! **Security note:** the file system tools accept arbitrary paths and act with
//! the server process's permissions. Restrict or sandbox them before exposing
//! this server beyond local testing.
//!
//! # Environment Variables
//! ## Cognito Configuration
//! - `COGNITO_CLIENT_ID`: Your Cognito app client ID
//! - `COGNITO_CLIENT_SECRET`: Your Cognito app client secret (optional for public clients)
//! - `COGNITO_DOMAIN`: Your Cognito domain (e.g., mydomain.auth.us-east-1.amazoncognito.com)
//! - `COGNITO_REGION`: AWS region (e.g., us-east-1)
//! - `COGNITO_USER_POOL_ID`: Your Cognito user pool ID (e.g., us-east-1_XXXXXXXXX)
//! - `COGNITO_SCOPE`: OAuth scopes (default: 'openid email profile phone')
//!
//! ## AWS Configuration (for DynamoDB)
//! - `AWS_ACCESS_KEY_ID`: Your AWS access key ID
//! - `AWS_SECRET_ACCESS_KEY`: Your AWS secret access key
//! - `AWS_REGION`: AWS region (should match `COGNITO_REGION`)
//!
//! ## Server Configuration
//! - `MCP_HOST`: Server host (default: localhost)
//! - `MCP_PORT`: Server port (default: 8080)
//! - `DYNAMODB_TABLE_NAME`: DynamoDB table name (default: oauth-storage)
//! - `DYNAMODB_CREATE_TABLE`: Whether to auto-create the table (default: true)
//!
//! # Usage
//! ```bash
//! # Set environment variables
//! export COGNITO_CLIENT_ID="your_client_id"
//! export COGNITO_CLIENT_SECRET="your_client_secret"
//! export COGNITO_DOMAIN="mydomain.auth.us-east-1.amazoncognito.com"
//! export COGNITO_REGION="us-east-1"
//! export COGNITO_USER_POOL_ID="us-east-1_XXXXXXXXX"
//! export AWS_ACCESS_KEY_ID="your_aws_access_key"
//! export AWS_SECRET_ACCESS_KEY="your_aws_secret_key"
//! export AWS_REGION="us-east-1"
//!
//! # Run the server
//! cargo run --example custom_mcp_server_example
//! ```

use std::env;

use oauth_provider_rs::storage::create_dynamodb_storage;
use oauth_provider_rs::OAuthProvider;
use remote_mcp_kernel::{
    config::{
        get_bind_socket_addr, get_cognito_client_id, get_cognito_client_secret,
        get_cognito_domain, get_cognito_oauth_provider_config, get_cognito_region,
        get_cognito_scope, get_cognito_user_pool_id, get_logging_level, get_server_host,
        get_server_port, get_server_version,
    },
    error::{AppError, AppResult},
    handlers::SseHandlerConfig,
    microkernel::MicrokernelServer,
};
use rmcp::{
    handler::server::router::tool::ToolRouter,
    handler::server::tool::Parameters,
    model::{CallToolResult, Content, Implementation, ServerCapabilities, ServerInfo},
    tool, tool_handler, tool_router, Error as McpError, ServerHandler,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};

79// Custom MCP Server Implementation
80// =============================================================================
81
82/// Custom MCP server with specialized tools for file operations and system utilities
83#[derive(Debug, Clone)]
84pub struct CustomMcpServer {
85    /// MCP tool router for handling tool calls
86    tool_router: ToolRouter<Self>,
87    /// Server name for identification
88    name: String,
89}
90
91// Tool parameter definitions
92#[derive(Debug, Deserialize, Serialize, JsonSchema)]
93pub struct ListFilesRequest {
94    #[schemars(description = "Directory path to list files from")]
95    pub path: String,
96    #[schemars(description = "Include hidden files")]
97    pub include_hidden: Option<bool>,
98}
99
100#[derive(Debug, Deserialize, Serialize, JsonSchema)]
101pub struct ReadFileRequest {
102    #[schemars(description = "File path to read")]
103    pub path: String,
104    #[schemars(description = "Maximum number of lines to read")]
105    pub max_lines: Option<usize>,
106}
107
108#[derive(Debug, Deserialize, Serialize, JsonSchema)]
109pub struct WriteFileRequest {
110    #[schemars(description = "File path to write to")]
111    pub path: String,
112    #[schemars(description = "Content to write")]
113    pub content: String,
114    #[schemars(description = "Append to file instead of overwriting")]
115    pub append: Option<bool>,
116}
117
118#[derive(Debug, Deserialize, Serialize, JsonSchema)]
119pub struct WordCountRequest {
120    #[schemars(description = "Text to count words in")]
121    pub text: String,
122}
123
124#[derive(Debug, Deserialize, Serialize, JsonSchema)]
125pub struct TextSearchRequest {
126    #[schemars(description = "Text to search in")]
127    pub text: String,
128    #[schemars(description = "Pattern to search for")]
129    pub pattern: String,
130    #[schemars(description = "Case sensitive search")]
131    pub case_sensitive: Option<bool>,
132}
133
134// Tool implementations using the tool_router macro
135#[tool_router]
136impl CustomMcpServer {
137    /// Create a new custom MCP server
138    pub fn new(name: String) -> Self {
139        Self {
140            tool_router: Self::tool_router(),
141            name,
142        }
143    }
144
145    /// List files in a directory
146    #[tool(description = "List files and directories in the specified path")]
147    async fn list_files(
148        &self,
149        Parameters(req): Parameters<ListFilesRequest>,
150    ) -> Result<CallToolResult, McpError> {
151        let path = std::path::Path::new(&req.path);
152
153        if !path.exists() {
154            return Ok(CallToolResult::error(vec![Content::text(format!(
155                "Path does not exist: {}",
156                req.path
157            ))]));
158        }
159
160        if !path.is_dir() {
161            return Ok(CallToolResult::error(vec![Content::text(format!(
162                "Path is not a directory: {}",
163                req.path
164            ))]));
165        }
166
167        let mut files = Vec::new();
168        let include_hidden = req.include_hidden.unwrap_or(false);
169
170        match std::fs::read_dir(path) {
171            Ok(entries) => {
172                for entry in entries {
173                    match entry {
174                        Ok(entry) => {
175                            let file_name = entry.file_name().to_string_lossy().to_string();
176                            let is_hidden = file_name.starts_with('.');
177
178                            if include_hidden || !is_hidden {
179                                let file_type = if entry.path().is_dir() {
180                                    "directory"
181                                } else {
182                                    "file"
183                                };
184                                files.push(format!("{} ({})", file_name, file_type));
185                            }
186                        }
187                        Err(e) => {
188                            return Ok(CallToolResult::error(vec![Content::text(format!(
189                                "Error reading directory entry: {}",
190                                e
191                            ))]));
192                        }
193                    }
194                }
195            }
196            Err(e) => {
197                return Ok(CallToolResult::error(vec![Content::text(format!(
198                    "Error reading directory: {}",
199                    e
200                ))]));
201            }
202        }
203
204        files.sort();
205        let result = files.join("\n");
206        Ok(CallToolResult::success(vec![Content::text(result)]))
207    }
208
209    /// Read contents of a file
210    #[tool(description = "Read the contents of a file")]
211    async fn read_file(
212        &self,
213        Parameters(req): Parameters<ReadFileRequest>,
214    ) -> Result<CallToolResult, McpError> {
215        let path = std::path::Path::new(&req.path);
216
217        if !path.exists() {
218            return Ok(CallToolResult::error(vec![Content::text(format!(
219                "File does not exist: {}",
220                req.path
221            ))]));
222        }
223
224        if !path.is_file() {
225            return Ok(CallToolResult::error(vec![Content::text(format!(
226                "Path is not a file: {}",
227                req.path
228            ))]));
229        }
230
231        match std::fs::read_to_string(path) {
232            Ok(content) => {
233                let result = if let Some(max_lines) = req.max_lines {
234                    content
235                        .lines()
236                        .take(max_lines)
237                        .collect::<Vec<_>>()
238                        .join("\n")
239                } else {
240                    content
241                };
242                Ok(CallToolResult::success(vec![Content::text(result)]))
243            }
244            Err(e) => Ok(CallToolResult::error(vec![Content::text(format!(
245                "Error reading file: {}",
246                e
247            ))])),
248        }
249    }
250
251    /// Write content to a file
252    #[tool(description = "Write content to a file")]
253    async fn write_file(
254        &self,
255        Parameters(req): Parameters<WriteFileRequest>,
256    ) -> Result<CallToolResult, McpError> {
257        let path = std::path::Path::new(&req.path);
258
259        // Create parent directories if they don't exist
260        if let Some(parent) = path.parent() {
261            if !parent.exists() {
262                if let Err(e) = std::fs::create_dir_all(parent) {
263                    return Ok(CallToolResult::error(vec![Content::text(format!(
264                        "Error creating parent directories: {}",
265                        e
266                    ))]));
267                }
268            }
269        }
270
271        let result = if req.append.unwrap_or(false) {
272            std::fs::write(path, &req.content)
273        } else {
274            std::fs::write(path, &req.content)
275        };
276
277        match result {
278            Ok(()) => Ok(CallToolResult::success(vec![Content::text(format!(
279                "Successfully wrote {} bytes to {}",
280                req.content.len(),
281                req.path
282            ))])),
283            Err(e) => Ok(CallToolResult::error(vec![Content::text(format!(
284                "Error writing file: {}",
285                e
286            ))])),
287        }
288    }
289
290    /// Get system information
291    #[tool(description = "Get system information including CPU, memory, and disk usage")]
292    async fn get_system_info(&self) -> Result<CallToolResult, McpError> {
293        let mut info = Vec::new();
294
295        // Get current timestamp
296        let now = std::time::SystemTime::now()
297            .duration_since(std::time::UNIX_EPOCH)
298            .unwrap()
299            .as_secs();
300        info.push(format!("Timestamp: {}", now));
301
302        // Get current working directory
303        if let Ok(cwd) = std::env::current_dir() {
304            info.push(format!("Working Directory: {}", cwd.display()));
305        }
306
307        // Get environment variables count
308        let env_count = std::env::vars().count();
309        info.push(format!("Environment Variables: {}", env_count));
310
311        // Get OS information
312        info.push(format!("OS: {}", std::env::consts::OS));
313        info.push(format!("Architecture: {}", std::env::consts::ARCH));
314
315        let result = info.join("\n");
316        Ok(CallToolResult::success(vec![Content::text(result)]))
317    }
318
319    /// Count words in text
320    #[tool(description = "Count words, lines, and characters in text")]
321    async fn count_words(
322        &self,
323        Parameters(req): Parameters<WordCountRequest>,
324    ) -> Result<CallToolResult, McpError> {
325        let text = &req.text;
326        let lines = text.lines().count();
327        let words = text.split_whitespace().count();
328        let chars = text.chars().count();
329        let bytes = text.len();
330
331        let result = format!(
332            "Lines: {}\nWords: {}\nCharacters: {}\nBytes: {}",
333            lines, words, chars, bytes
334        );
335
336        Ok(CallToolResult::success(vec![Content::text(result)]))
337    }
338
339    /// Search for text patterns
340    #[tool(description = "Search for patterns in text")]
341    async fn search_text(
342        &self,
343        Parameters(req): Parameters<TextSearchRequest>,
344    ) -> Result<CallToolResult, McpError> {
345        let text = &req.text;
346        let pattern = &req.pattern;
347        let case_sensitive = req.case_sensitive.unwrap_or(false);
348
349        let (search_text, search_pattern) = if case_sensitive {
350            (text.to_string(), pattern.to_string())
351        } else {
352            (text.to_lowercase(), pattern.to_lowercase())
353        };
354
355        let mut matches = Vec::new();
356        for (line_num, line) in search_text.lines().enumerate() {
357            if line.contains(&search_pattern) {
358                matches.push(format!(
359                    "Line {}: {}",
360                    line_num + 1,
361                    text.lines().nth(line_num).unwrap_or("")
362                ));
363            }
364        }
365
366        let result = if matches.is_empty() {
367            format!("No matches found for pattern: {}", pattern)
368        } else {
369            format!("Found {} matches:\n{}", matches.len(), matches.join("\n"))
370        };
371
372        Ok(CallToolResult::success(vec![Content::text(result)]))
373    }
374
375    /// Get current date and time
376    #[tool(description = "Get current date and time in various formats")]
377    async fn get_datetime(&self) -> Result<CallToolResult, McpError> {
378        let now = std::time::SystemTime::now();
379        let unix_timestamp = now.duration_since(std::time::UNIX_EPOCH).unwrap().as_secs();
380
381        let result = format!(
382            "Unix Timestamp: {}\nISO 8601 (approx): {}",
383            unix_timestamp,
384            // Simple ISO 8601 approximation (not perfect but sufficient for demo)
385            chrono::DateTime::from_timestamp(unix_timestamp as i64, 0)
386                .unwrap_or_default()
387                .format("%Y-%m-%dT%H:%M:%SZ")
388        );
389
390        Ok(CallToolResult::success(vec![Content::text(result)]))
391    }
392}
393
394// Implement the ServerHandler trait for MCP protocol support
395#[tool_handler]
396impl ServerHandler for CustomMcpServer {
397    fn get_info(&self) -> ServerInfo {
398        ServerInfo {
399            protocol_version: Default::default(),
400            capabilities: ServerCapabilities::builder()
401                .enable_tools()
402                .build(),
403            server_info: Implementation {
404                name: self.name.clone(),
405                version: "1.0.0".to_string(),
406            },
407            instructions: Some("A custom MCP server with file operations, system utilities, and text processing tools".to_string()),
408        }
409    }
410}
411
412// Main application
413// =============================================================================
414
415#[tokio::main]
416async fn main() -> AppResult<()> {
417    // Load environment variables
418    dotenv::dotenv().ok();
419
420    // Initialize tracing
421    init_tracing()?;
422
423    tracing::info!("Starting Custom MCP Server example with Cognito and DynamoDB storage...");
424
425    // Create Cognito OAuth configuration
426    let cognito_config = get_cognito_oauth_provider_config()?;
427
428    // Get DynamoDB configuration
429    let table_name =
430        env::var("DYNAMODB_TABLE_NAME").unwrap_or_else(|_| "oauth-storage".to_string());
431    let create_table = env::var("DYNAMODB_CREATE_TABLE")
432        .unwrap_or_else(|_| "true".to_string())
433        .parse::<bool>()
434        .unwrap_or(true);
435
436    // Log configuration
437    log_startup_info(&table_name, create_table);
438
439    // Create DynamoDB storage
440    let (storage, client_manager) = create_dynamodb_storage(
441        table_name.clone(),
442        create_table,
443        Some("expires_at".to_string()),
444    )
445    .await
446    .map_err(|e| {
447        remote_mcp_kernel::error::AppError::Internal(format!(
448            "Failed to create DynamoDB storage: {}",
449            e
450        ))
451    })?;
452
453    // Create Cognito OAuth provider with DynamoDB storage
454    let oauth_handler = oauth_provider_rs::CognitoOAuthHandler::new_simple(
455        storage,
456        client_manager,
457        cognito_config,
458        get_cognito_domain()?,
459        get_cognito_region()?,
460        get_cognito_user_pool_id()?,
461    );
462
463    let oauth_provider = OAuthProvider::new(oauth_handler, oauth_provider_rs::http_integration::config::OAuthProviderConfig::default());
464
465    // Create custom MCP server
466    let custom_mcp_server = CustomMcpServer::new("Custom File & System MCP Server".to_string());
467
468    // Build microkernel with custom MCP server using convenience methods
469    let microkernel = MicrokernelServer::new()
470        .with_oauth_provider(oauth_provider)
471        .with_mcp_streamable_handler(custom_mcp_server.clone())
472        .with_mcp_sse_handler(custom_mcp_server, SseHandlerConfig::default());
473
474    // Start the microkernel server
475    let bind_address = get_bind_socket_addr()?;
476    tracing::info!("🚀 Starting microkernel server on {}", bind_address);
477    microkernel.serve(bind_address).await?;
478
479    Ok(())
480}
481
482fn init_tracing() -> AppResult<()> {
483    tracing_subscriber::registry()
484        .with(
485            tracing_subscriber::EnvFilter::try_from_default_env()
486                .unwrap_or_else(|_| get_logging_level().as_str().into()),
487        )
488        .with(tracing_subscriber::fmt::layer())
489        .init();
490
491    Ok(())
492}
493
494fn log_startup_info(table_name: &str, create_table: bool) {
495    println!("🚀 Starting Custom MCP Server example with Cognito and DynamoDB storage...");
496    println!("📋 Configuration:");
497    println!("  - Architecture: Microkernel (independent handlers)");
498    println!("  - MCP Server: Custom implementation with specialized tools");
499    println!("  - OAuth Provider: AWS Cognito");
500    println!("  - Storage Backend: DynamoDB");
501    println!("  - Server: {}:{}", get_server_host(), get_server_port().unwrap_or(8080));
502    println!("  - Version: {}", get_server_version());
503    println!();
504
505    println!("🔧 Custom MCP Server Tools:");
506    println!("  - list_files: List files and directories");
507    println!("  - read_file: Read file contents");
508    println!("  - write_file: Write content to files");
509    println!("  - get_system_info: Get system information");
510    println!("  - count_words: Count words, lines, and characters");
511    println!("  - search_text: Search for patterns in text");
512    println!("  - get_datetime: Get current date and time");
513    println!();
514
515    println!("🔐 AWS Cognito Configuration:");
516    println!(
517        "  - Client ID: {}",
518        if get_cognito_client_id().is_ok() {
519            "Configured"
520        } else {
521            "Not configured"
522        }
523    );
524    println!(
525        "  - Client Secret: {}",
526        match get_cognito_client_secret() {
527            Some(secret) if !secret.is_empty() => "Configured",
528            _ => "Not configured (Public Client)",
529        }
530    );
531    println!(
532        "  - Domain: {}",
533        get_cognito_domain().unwrap_or_else(|_| "Not configured".to_string())
534    );
535    println!(
536        "  - Region: {}",
537        get_cognito_region().unwrap_or_else(|_| "Not configured".to_string())
538    );
539    println!(
540        "  - User Pool ID: {}",
541        get_cognito_user_pool_id().unwrap_or_else(|_| "Not configured".to_string())
542    );
543    println!("  - Scopes: {}", get_cognito_scope());
544    println!();
545
546    println!("🗄️  DynamoDB Storage Configuration:");
547    println!("  - Table Name: {}", table_name);
548    println!("  - Auto-create Table: {}", create_table);
549    println!("  - TTL Attribute: expires_at");
550    println!();
551
552    println!("🔧 Handlers:");
553    println!("  - OAuth Provider (Cognito authentication & authorization)");
554    println!("  - Streamable HTTP Handler (MCP over HTTP with custom server)");
555    println!("  - SSE Handler (MCP over SSE with custom server)");
556    println!();
557
558    println!("🏗️  Microkernel Architecture:");
559    println!("  - Custom MCP server with specialized tools");
560    println!("  - Independent handlers that can operate standalone");
561    println!("  - Runtime composition of services");
562    println!("  - Single responsibility per handler");
563    println!("  - Easy testing and maintenance");
564    println!();
565
566    println!("🌐 MCP Protocol Endpoints:");
567    let host = get_server_host();
568    let port = get_server_port().unwrap_or(8080);
569    println!(
570        "  - HTTP (streamable): http://{}:{}/mcp/http",
571        host, port
572    );
573    println!(
574        "  - SSE: http://{}:{}/mcp/sse",
575        host, port
576    );
577    println!(
578        "  - SSE Messages: http://{}:{}/mcp/message",
579        host, port
580    );
581    println!();
582
583    println!("🔐 OAuth 2.0 Endpoints:");
584    let cognito_domain = get_cognito_domain().unwrap_or_else(|_| "Not configured".to_string());
585    println!(
586        "  - Authorization: https://{}/oauth2/authorize",
587        cognito_domain
588    );
589    println!(
590        "  - Token: https://{}/oauth2/token",
591        cognito_domain
592    );
593    println!(
594        "  - JWKS: https://{}/oauth2/jwks",
595        cognito_domain
596    );
597    println!(
598        "  - UserInfo: https://{}/oauth2/userInfo",
599        cognito_domain
600    );
601    println!();
602}