Skip to main content

do_memory_mcp/server/
mod.rs

1//! MCP server for memory integration
2//!
//! This module provides the MCP (Model Context Protocol) server that exposes the
//! self-learning memory system to MCP clients through tools for memory queries and pattern analysis.
5//!
6//! ## Features
7//!
8//! - Tool definitions for memory queries and pattern analysis
9//! - Progressive tool disclosure based on usage patterns
10//! - Integration with SelfLearningMemory system
11//! - Execution statistics and monitoring
12//!
13//! ## Example
14//!
15//! ```no_run
16//! use do_memory_mcp::server::MemoryMCPServer;
17//! use do_memory_mcp::types::SandboxConfig;
18//! use do_memory_core::SelfLearningMemory;
19//! use std::sync::Arc;
20//!
21//! #[tokio::main]
22//! async fn main() -> anyhow::Result<()> {
23//!     let memory = Arc::new(SelfLearningMemory::new());
24//!     let server = MemoryMCPServer::new(SandboxConfig::default(), memory).await?;
25//!
26//!     // List available tools
27//!     let tools = server.list_tools().await;
28//!     println!("Available tools: {}", tools.len());
29//!
30//!     Ok(())
31//! }
32//! ```
33
34// Submodules
35pub mod audit;
36pub mod cache_warming;
37pub mod rate_limiter;
38#[cfg(test)]
39mod tests;
40pub mod tool_definitions;
41pub mod tool_definitions_extended;
42pub mod tools;
43
44use crate::cache::QueryCache;
45use crate::monitoring::{MonitoringConfig, MonitoringEndpoints, MonitoringSystem};
46use crate::server::audit::{AuditConfig, AuditLogger};
47use crate::server::rate_limiter::{ClientId, OperationType, RateLimiter};
48use crate::server::tools::registry::ToolRegistry;
49use crate::types::{ExecutionStats, SandboxConfig};
50use anyhow::Result;
51use do_memory_core::SelfLearningMemory;
52use parking_lot::RwLock;
53use serde_json::Value;
54use std::collections::HashMap;
55use std::sync::Arc;
56use tracing::info;
57
58/// MCP server for memory integration
59#[allow(dead_code)] // Server struct instantiated via new(), compiler sees fields as unused
60pub struct MemoryMCPServer {
61    /// Tool registry for lazy-loading tools
62    tool_registry: Arc<ToolRegistry>,
63    /// Execution statistics
64    stats: Arc<RwLock<ExecutionStats>>,
65    /// Tool usage tracking for progressive disclosure (kept for compatibility)
66    tool_usage: Arc<RwLock<HashMap<String, usize>>>,
67    /// Self-learning memory system
68    memory: Arc<SelfLearningMemory>,
69    /// Monitoring system
70    monitoring: Arc<MonitoringSystem>,
71    /// Monitoring endpoints
72    monitoring_endpoints: Arc<MonitoringEndpoints>,
73    /// Query result cache
74    cache: Arc<QueryCache>,
75    /// Audit logger for security events
76    audit_logger: Arc<AuditLogger>,
77    /// Rate limiter for DoS protection
78    rate_limiter: RateLimiter,
79}
80
81impl MemoryMCPServer {
82    /// Create a new MCP server
83    ///
84    /// # Arguments
85    ///
86    /// * `config` - Sandbox configuration (kept for API compatibility)
87    /// * `memory` - Self-learning memory system
88    ///
89    /// # Returns
90    ///
91    /// Returns a new `MemoryMCPServer` instance
92    pub async fn new(_config: SandboxConfig, memory: Arc<SelfLearningMemory>) -> Result<Self> {
93        // Use tool registry for lazy loading
94        let tool_registry = Arc::new(tools::registry::create_default_registry());
95
96        let monitoring = Self::initialize_monitoring();
97        let monitoring_endpoints = Arc::new(MonitoringEndpoints::new(Arc::clone(&monitoring)));
98
99        // Initialize audit logger
100        let audit_config = AuditConfig::from_env();
101        let audit_logger = Arc::new(AuditLogger::new(audit_config).await?);
102
103        let core_count = tool_registry.get_core_tools().len();
104        let total_count = tool_registry.total_tool_count();
105
106        let server = Self {
107            tool_registry,
108            stats: Arc::new(RwLock::new(ExecutionStats::default())),
109            tool_usage: Arc::new(RwLock::new(HashMap::new())),
110            memory,
111            monitoring,
112            monitoring_endpoints,
113            cache: Arc::new(QueryCache::new()),
114            audit_logger,
115            rate_limiter: RateLimiter::from_env(),
116        };
117
118        info!(
119            "MCP server initialized with {} core tools ({} total tools available)",
120            core_count, total_count
121        );
122        info!("Tools loaded on-demand to reduce token usage (lazy loading enabled)");
123        info!(
124            "Monitoring system initialized (enabled: {})",
125            server.monitoring.config().enabled
126        );
127        info!("Audit logging system initialized");
128        info!(
129            "Rate limiter initialized (enabled: {})",
130            server.rate_limiter.is_enabled()
131        );
132
133        // Perform cache warming if enabled
134        if cache_warming::is_cache_warming_enabled() {
135            info!("Starting cache warming process...");
136            if let Err(e) = cache_warming::warm_cache(
137                &server.memory,
138                &cache_warming::CacheWarmingConfig::from_env(),
139            )
140            .await
141            {
142                tracing::warn!(
143                    "Cache warming failed, but continuing with server startup: {}",
144                    e
145                );
146            } else {
147                info!("Cache warming completed successfully");
148            }
149        } else {
150            info!("Cache warming disabled, skipping");
151        }
152
153        Ok(server)
154    }
155
156    fn initialize_monitoring() -> Arc<MonitoringSystem> {
157        let monitoring_config = MonitoringConfig::default();
158        Arc::new(MonitoringSystem::new(monitoring_config))
159    }
160
161    /// Get a reference to the memory system
162    ///
163    /// # Returns
164    ///
165    /// Returns a clone of the `Arc<SelfLearningMemory>`
166    pub fn memory(&self) -> Arc<SelfLearningMemory> {
167        Arc::clone(&self.memory)
168    }
169
170    /// Get a reference to the audit logger
171    ///
172    /// # Returns
173    ///
174    /// Returns a clone of the `Arc<AuditLogger>`
175    pub fn audit_logger(&self) -> Arc<AuditLogger> {
176        Arc::clone(&self.audit_logger)
177    }
178
179    /// Get a reference to the rate limiter
180    ///
181    /// # Returns
182    ///
183    /// Returns a reference to the `RateLimiter`
184    pub fn rate_limiter(&self) -> &RateLimiter {
185        &self.rate_limiter
186    }
187
188    /// Extract client ID from tool arguments
189    ///
190    /// # Arguments
191    ///
192    /// * `args` - Tool arguments JSON value
193    ///
194    /// # Returns
195    ///
196    /// Returns a `ClientId` for rate limiting
197    pub fn client_id_from_args(&self, args: &Value) -> ClientId {
198        args.get("client_id")
199            .and_then(|v| v.as_str())
200            .filter(|s| !s.is_empty())
201            .map(ClientId::from_string)
202            .unwrap_or(ClientId::Unknown)
203    }
204
205    /// Check rate limit for a client
206    ///
207    /// # Arguments
208    ///
209    /// * `client_id` - Client identifier
210    /// * `operation` - Type of operation (read or write)
211    ///
212    /// # Returns
213    ///
214    /// Returns the rate limit check result
215    pub fn check_rate_limit(
216        &self,
217        client_id: &ClientId,
218        operation: OperationType,
219    ) -> crate::server::rate_limiter::RateLimitResult {
220        self.rate_limiter.check_rate_limit(client_id, operation)
221    }
222
223    /// Get rate limit headers for a response
224    ///
225    /// # Arguments
226    ///
227    /// * `result` - Rate limit check result
228    ///
229    /// # Returns
230    ///
231    /// Returns vector of rate limit header tuples
232    pub fn rate_limit_headers(
233        &self,
234        result: &crate::server::rate_limiter::RateLimitResult,
235    ) -> Vec<(String, String)> {
236        self.rate_limiter.get_headers(result)
237    }
238
239    /// Get rate limit headers for a rate-limited response
240    ///
241    /// # Arguments
242    ///
243    /// * `result` - Rate limit check result
244    ///
245    /// # Returns
246    ///
247    /// Returns vector of rate limit header tuples including Retry-After
248    pub fn rate_limited_headers(
249        &self,
250        result: &crate::server::rate_limiter::RateLimitResult,
251    ) -> Vec<(String, String)> {
252        self.rate_limiter.get_rate_limited_headers(result)
253    }
254}