Skip to main content

mcp_rtk/filter/
mod.rs

1//! The 8-stage filter pipeline and generic JSON compression functions.
2//!
3//! [`FilterEngine`] is the central entry point: given a tool name and its raw
4//! JSON output, it resolves the merged filter rules from the configuration and
5//! applies the following stages in order:
6//!
7//! 1. **keep_fields** — whitelist of JSON field names to retain.
8//! 2. **strip_fields** — blacklist of JSON field names to remove recursively.
9//! 3. **condense_users** — replace user objects with bare usernames.
10//! 4. **strip_nulls** — remove `null` and empty-string fields.
11//! 5. **flatten** — unwrap single-key wrapper objects.
12//! 6. **truncate_strings** — cap string values at a maximum length.
13//! 7. **collapse_arrays** — limit array sizes with a summary entry.
14//! 8. **custom_transforms** — regex-based string replacements.
15//!
16//! The low-level JSON manipulation functions live in the [`json`] submodule.
17
18pub mod json;
19
20use crate::config::{Config, CustomTransform, MergedRules};
21use regex::Regex;
22use serde_json::Value;
23use std::collections::HashMap;
24use std::sync::Arc;
25
26/// Engine that applies the 8-stage JSON compression pipeline.
27///
28/// The engine holds a reference to the shared [`Config`] and resolves per-tool
29/// filter rules on each call via [`Config::get_tool_rules`].
30///
31/// # Examples
32///
33/// ```no_run
34/// # use std::sync::Arc;
35/// # use mcp_rtk::config::Config;
36/// # use mcp_rtk::filter::FilterEngine;
37/// let config = Arc::new(Config::from_upstream(&["npx", "some-mcp"], None).unwrap());
38/// let engine = FilterEngine::new(config);
39/// let filtered = engine.filter("list_merge_requests", r#"[{"iid":1,"title":"Fix"}]"#);
40/// ```
41pub struct FilterEngine {
42    config: Arc<Config>,
43    /// Pre-compiled regex transforms, keyed by tool name.
44    /// `""` key holds the default transforms.
45    compiled_transforms: HashMap<String, Vec<(Regex, String)>>,
46}
47
48impl FilterEngine {
49    /// Create a new filter engine with the given configuration.
50    ///
51    /// Precompiles all regex transforms from defaults and per-tool rules.
52    pub fn new(config: Arc<Config>) -> Self {
53        let mut compiled_transforms = HashMap::new();
54
55        // Compile merged transforms for each known tool
56        for tool_name in config.filters.tools.keys() {
57            let rules = config.get_tool_rules(tool_name);
58            if !rules.custom_transforms.is_empty() {
59                compiled_transforms.insert(
60                    tool_name.clone(),
61                    compile_transforms(&rules.custom_transforms),
62                );
63            }
64        }
65
66        // Compile default-only transforms for unknown tools (key = "")
67        if !config.filters.default.custom_transforms.is_empty() {
68            compiled_transforms.insert(
69                String::new(),
70                compile_transforms(&config.filters.default.custom_transforms),
71            );
72        }
73
74        Self {
75            config,
76            compiled_transforms,
77        }
78    }
79
80    /// Access the underlying configuration.
81    pub fn config(&self) -> &Config {
82        &self.config
83    }
84
85    /// Maximum raw response size (10 MB) to prevent OOM from malicious upstreams.
86    const MAX_RESPONSE_BYTES: usize = 10 * 1024 * 1024;
87
88    /// Apply the full filter pipeline to a tool's raw output string.
89    ///
90    /// If `raw` is valid JSON, it is parsed and run through all 8 pipeline
91    /// stages. If parsing fails, only plain-text truncation is applied.
92    /// Responses exceeding 10 MB are truncated before parsing.
93    pub fn filter(&self, tool_name: &str, raw: &str) -> String {
94        let rules = self.config.get_tool_rules(tool_name);
95
96        // Guard against oversized upstream responses (OOM protection)
97        if raw.len() > Self::MAX_RESPONSE_BYTES {
98            tracing::warn!(
99                tool = tool_name,
100                size = raw.len(),
101                "Response exceeds {} bytes, applying plain-text truncation only",
102                Self::MAX_RESPONSE_BYTES,
103            );
104            return self.filter_plain_text(raw, &rules);
105        }
106
107        let parsed = serde_json::from_str::<Value>(raw);
108        let mut value = match parsed {
109            Ok(v) => v,
110            Err(_) => {
111                // If not valid JSON, apply string-level truncation only
112                return self.filter_plain_text(raw, &rules);
113            }
114        };
115
116        self.apply_pipeline(tool_name, &mut value, &rules);
117        serde_json::to_string(&value).unwrap_or_else(|_| raw.to_string())
118    }
119
120    fn apply_pipeline(&self, tool_name: &str, value: &mut Value, rules: &MergedRules) {
121        // 1. Keep fields (whitelist) — must come first
122        if !rules.keep_fields.is_empty() {
123            json::keep_fields(value, &rules.keep_fields);
124        }
125
126        // 2. Strip fields (blacklist)
127        if !rules.strip_fields.is_empty() {
128            json::strip_fields(value, &rules.strip_fields);
129        }
130
131        // 3. Condense user objects
132        if rules.condense_users {
133            json::condense_user_objects(value);
134        }
135
136        // 4. Strip nulls
137        if rules.strip_nulls {
138            json::strip_null_fields(value);
139        }
140
141        // 5. Flatten single-key wrappers
142        if rules.flatten {
143            json::flatten_single_key_objects(value);
144        }
145
146        // 6. Truncate strings
147        json::truncate_strings(value, rules.truncate_strings_at);
148
149        // 7. Collapse arrays
150        json::collapse_arrays(value, rules.max_array_items);
151
152        // 8. Custom transforms (pre-compiled at engine creation)
153        if !rules.custom_transforms.is_empty() {
154            if let Some(compiled) = self.compiled_transforms.get(tool_name) {
155                json::apply_custom_transforms(value, compiled);
156            } else if let Some(compiled) = self.compiled_transforms.get("") {
157                // Fall back to default transforms
158                json::apply_custom_transforms(value, compiled);
159            }
160        }
161    }
162
163    fn filter_plain_text(&self, text: &str, rules: &MergedRules) -> String {
164        // Use the configured limit, but never exceed the OOM safety cap
165        let limit = rules.truncate_strings_at.min(Self::MAX_RESPONSE_BYTES);
166        if limit < text.len() {
167            let mut end = limit;
168            while end > 0 && !text.is_char_boundary(end) {
169                end -= 1;
170            }
171            let mut truncated = text[..end].to_string();
172            truncated.push_str("...[truncated]");
173            truncated
174        } else {
175            text.to_string()
176        }
177    }
178}
179
180/// Compile [`CustomTransform`] patterns into regex objects.
181///
182/// Invalid patterns are silently skipped.
183fn compile_transforms(transforms: &[CustomTransform]) -> Vec<(Regex, String)> {
184    transforms
185        .iter()
186        .filter_map(|t| {
187            Regex::new(&t.pattern)
188                .ok()
189                .map(|re| (re, t.replacement.clone()))
190        })
191        .collect()
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use serde_json::json;

    /// Engine built from the default configuration for the GitLab MCP upstream.
    fn test_engine() -> FilterEngine {
        let config = Config::from_upstream(&["npx", "@nicepkg/gitlab-mcp"], None).unwrap();
        FilterEngine::new(Arc::new(config))
    }

    #[test]
    fn test_filter_list_merge_requests() {
        let engine = test_engine();

        let input = json!([{
            "iid": 42,
            "title": "Fix login",
            "state": "opened",
            "author": {"id": 1, "name": "John", "username": "john", "avatar_url": "http://..."},
            "source_branch": "fix-login",
            "target_branch": "main",
            "web_url": "https://gitlab.com/mr/42",
            "description": "A very long description that should not appear",
            "created_at": "2024-01-01",
            "updated_at": "2024-01-02",
            "_links": {"self": "..."},
            "task_completion_status": {"count": 0},
            "time_stats": {},
            "extra_field": true
        }]);

        let result = engine.filter("list_merge_requests", &input.to_string());
        let parsed: Value = serde_json::from_str(&result).unwrap();
        let mr = &parsed[0];

        // Whitelisted fields survive.
        for key in ["iid", "title", "state"] {
            assert!(mr.get(key).is_some(), "missing kept field `{key}`");
        }
        // Author is condensed to {id, username}.
        assert_eq!(mr["author"], json!({"id": 1, "username": "john"}));
        // Stripped / non-whitelisted fields are gone.
        for key in ["description", "_links", "extra_field"] {
            assert!(mr.get(key).is_none(), "unexpected field `{key}`");
        }
    }

    #[test]
    fn test_filter_plain_text_truncation() {
        let engine = test_engine();

        let long_text = "x".repeat(10000);
        let result = engine.filter("get_job_log", &long_text);
        assert!(result.len() < 10000);
        assert!(result.ends_with("...[truncated]"));
    }

    #[test]
    fn test_filter_plain_text_truncation_respects_utf8_boundaries() {
        let engine = test_engine();

        // Two-byte characters (20000 bytes total, not JSON). Any odd byte limit
        // would land mid-character, so the truncation must back off to a char
        // boundary instead of panicking.
        let long_text = "é".repeat(10000);
        let result = engine.filter("get_job_log", &long_text);

        assert!(result.len() < long_text.len());
        let kept = result
            .strip_suffix("...[truncated]")
            .expect("truncated output must end with the marker");
        assert!(kept.chars().all(|c| c == 'é'));
    }
}