Skip to main content

do_memory_mcp/server/
mod.rs

//! MCP server for memory integration
//!
//! This module provides the MCP (Model Context Protocol) server that integrates
//! the self-learning memory system with sandboxed code execution.
//!
//! ## Features
//!
//! - Tool definitions for memory queries and code execution
//! - Progressive tool disclosure: core tools are advertised up front, the rest
//!   are loaded on demand to reduce token usage
//! - Integration with the `SelfLearningMemory` system
//! - Secure code execution via a sandbox
//! - Execution statistics and monitoring
//! - Audit logging of security events and per-client rate limiting
//!
//! ## Example
//!
//! ```no_run
//! use do_memory_mcp::server::MemoryMCPServer;
//! use do_memory_mcp::types::SandboxConfig;
//! use do_memory_core::SelfLearningMemory;
//! use std::sync::Arc;
//!
//! #[tokio::main]
//! async fn main() -> anyhow::Result<()> {
//!     let memory = Arc::new(SelfLearningMemory::new());
//!     let server = MemoryMCPServer::new(SandboxConfig::default(), memory).await?;
//!
//!     // List available tools
//!     let tools = server.list_tools().await;
//!     println!("Available tools: {}", tools.len());
//!
//!     Ok(())
//! }
//! ```

35// Submodules
36pub mod audit;
37pub mod cache_warming;
38pub mod rate_limiter;
39pub mod sandbox;
40#[cfg(test)]
41mod tests;
42pub mod tool_definitions;
43pub mod tool_definitions_extended;
44pub mod tools;
45
46use crate::cache::QueryCache;
47use crate::monitoring::{MonitoringConfig, MonitoringEndpoints, MonitoringSystem};
48use crate::server::audit::{AuditConfig, AuditLogger};
49use crate::server::rate_limiter::{ClientId, OperationType, RateLimiter};
50use crate::server::tools::registry::ToolRegistry;
51use crate::types::{ExecutionStats, SandboxConfig};
52use crate::unified_sandbox::UnifiedSandbox;
53use anyhow::Result;
54use do_memory_core::SelfLearningMemory;
55use parking_lot::RwLock;
56use serde_json::Value;
57use std::collections::HashMap;
58use std::sync::Arc;
59use tracing::info;
60
61/// MCP server for memory integration
62pub struct MemoryMCPServer {
63    /// Code execution sandbox
64    sandbox: Arc<UnifiedSandbox>,
65    /// Tool registry for lazy-loading tools
66    tool_registry: Arc<ToolRegistry>,
67    /// Execution statistics
68    stats: Arc<RwLock<ExecutionStats>>,
69    /// Tool usage tracking for progressive disclosure (kept for compatibility)
70    tool_usage: Arc<RwLock<HashMap<String, usize>>>,
71    /// Self-learning memory system
72    memory: Arc<SelfLearningMemory>,
73    /// Monitoring system
74    monitoring: Arc<MonitoringSystem>,
75    /// Monitoring endpoints
76    monitoring_endpoints: Arc<MonitoringEndpoints>,
77    /// Query result cache
78    #[allow(dead_code)]
79    cache: Arc<QueryCache>,
80    /// Audit logger for security events
81    audit_logger: Arc<AuditLogger>,
82    /// Rate limiter for DoS protection
83    rate_limiter: RateLimiter,
84}
85
86impl MemoryMCPServer {
87    /// Create a new MCP server
88    ///
89    /// # Arguments
90    ///
91    /// * `config` - Sandbox configuration for code execution
92    /// * `memory` - Self-learning memory system
93    ///
94    /// # Returns
95    ///
96    /// Returns a new `MemoryMCPServer` instance
97    pub async fn new(config: SandboxConfig, memory: Arc<SelfLearningMemory>) -> Result<Self> {
98        let sandbox_backend = sandbox::determine_sandbox_backend();
99        let sandbox = Arc::new(UnifiedSandbox::new(config.clone(), sandbox_backend).await?);
100
101        // Use tool registry for lazy loading
102        let tool_registry = Arc::new(tools::registry::create_default_registry());
103
104        let monitoring = Self::initialize_monitoring();
105        let monitoring_endpoints = Arc::new(MonitoringEndpoints::new(Arc::clone(&monitoring)));
106
107        // Initialize audit logger
108        let audit_config = AuditConfig::from_env();
109        let audit_logger = Arc::new(AuditLogger::new(audit_config).await?);
110
111        let core_count = tool_registry.get_core_tools().len();
112        let total_count = tool_registry.total_tool_count();
113
114        let server = Self {
115            sandbox,
116            tool_registry,
117            stats: Arc::new(RwLock::new(ExecutionStats::default())),
118            tool_usage: Arc::new(RwLock::new(HashMap::new())),
119            memory,
120            monitoring,
121            monitoring_endpoints,
122            cache: Arc::new(QueryCache::new()),
123            audit_logger,
124            rate_limiter: RateLimiter::from_env(),
125        };
126
127        info!(
128            "MCP server initialized with {} core tools ({} total tools available)",
129            core_count, total_count
130        );
131        info!("Tools loaded on-demand to reduce token usage (lazy loading enabled)");
132        info!(
133            "Monitoring system initialized (enabled: {})",
134            server.monitoring.config().enabled
135        );
136        info!("Audit logging system initialized");
137        info!(
138            "Rate limiter initialized (enabled: {})",
139            server.rate_limiter.is_enabled()
140        );
141
142        // Perform cache warming if enabled
143        if cache_warming::is_cache_warming_enabled() {
144            info!("Starting cache warming process...");
145            if let Err(e) = cache_warming::warm_cache(
146                &server.memory,
147                &cache_warming::CacheWarmingConfig::from_env(),
148            )
149            .await
150            {
151                tracing::warn!(
152                    "Cache warming failed, but continuing with server startup: {}",
153                    e
154                );
155            } else {
156                info!("Cache warming completed successfully");
157            }
158        } else {
159            info!("Cache warming disabled, skipping");
160        }
161
162        Ok(server)
163    }
164
165    fn initialize_monitoring() -> Arc<MonitoringSystem> {
166        let monitoring_config = MonitoringConfig::default();
167        Arc::new(MonitoringSystem::new(monitoring_config))
168    }
169
170    /// Get a reference to the memory system
171    ///
172    /// # Returns
173    ///
174    /// Returns a clone of the `Arc<SelfLearningMemory>`
175    pub fn memory(&self) -> Arc<SelfLearningMemory> {
176        Arc::clone(&self.memory)
177    }
178
179    /// Get a reference to the audit logger
180    ///
181    /// # Returns
182    ///
183    /// Returns a clone of the `Arc<AuditLogger>`
184    pub fn audit_logger(&self) -> Arc<AuditLogger> {
185        Arc::clone(&self.audit_logger)
186    }
187
188    /// Get a reference to the rate limiter
189    ///
190    /// # Returns
191    ///
192    /// Returns a reference to the `RateLimiter`
193    pub fn rate_limiter(&self) -> &RateLimiter {
194        &self.rate_limiter
195    }
196
197    /// Extract client ID from tool arguments
198    ///
199    /// # Arguments
200    ///
201    /// * `args` - Tool arguments JSON value
202    ///
203    /// # Returns
204    ///
205    /// Returns a `ClientId` for rate limiting
206    pub fn client_id_from_args(&self, args: &Value) -> ClientId {
207        args.get("client_id")
208            .and_then(|v| v.as_str())
209            .filter(|s| !s.is_empty())
210            .map(ClientId::from_string)
211            .unwrap_or(ClientId::Unknown)
212    }
213
214    /// Check rate limit for a client
215    ///
216    /// # Arguments
217    ///
218    /// * `client_id` - Client identifier
219    /// * `operation` - Type of operation (read or write)
220    ///
221    /// # Returns
222    ///
223    /// Returns the rate limit check result
224    pub fn check_rate_limit(
225        &self,
226        client_id: &ClientId,
227        operation: OperationType,
228    ) -> crate::server::rate_limiter::RateLimitResult {
229        self.rate_limiter.check_rate_limit(client_id, operation)
230    }
231
232    /// Get rate limit headers for a response
233    ///
234    /// # Arguments
235    ///
236    /// * `result` - Rate limit check result
237    ///
238    /// # Returns
239    ///
240    /// Returns vector of rate limit header tuples
241    pub fn rate_limit_headers(
242        &self,
243        result: &crate::server::rate_limiter::RateLimitResult,
244    ) -> Vec<(String, String)> {
245        self.rate_limiter.get_headers(result)
246    }
247
248    /// Get rate limit headers for a rate-limited response
249    ///
250    /// # Arguments
251    ///
252    /// * `result` - Rate limit check result
253    ///
254    /// # Returns
255    ///
256    /// Returns vector of rate limit header tuples including Retry-After
257    pub fn rate_limited_headers(
258        &self,
259        result: &crate::server::rate_limiter::RateLimitResult,
260    ) -> Vec<(String, String)> {
261        self.rate_limiter.get_rate_limited_headers(result)
262    }
263}