Skip to main content

do_memory_mcp/server/tools/
field_projection.rs

//! Field projection and selection for query result optimization
//!
//! This module provides field-level filtering for query results. Clients can
//! request only the fields they need, which significantly reduces output token usage.
//!
//! ## Features
//!
//! - **Generic Field Selection**: Works with any `Serialize` type
//! - **Nested Field Support**: Select nested fields with dot-separated paths
//!   (e.g. `"episode.id"`, `"episode.task_description"`)
//! - **Array Support**: Paths pass through arrays, so `"episodes.id"` selects `id`
//!   from every element of the `episodes` array
//! - **Whitelist-Based**: Only explicitly requested fields are returned
//! - **Backward Compatible**: If the `fields` parameter is missing, empty, or not an
//!   array, all fields are returned
//! ## Usage
//!
//! ```rust
//! # use anyhow::Result;
//! # fn main() -> Result<()> {
//! use do_memory_mcp::server::tools::field_projection::FieldSelector;
//! use serde_json::json;
//!
//! // Create a selector from the request arguments
//! let args = json!({"fields": ["episode.id", "episode.task_description"]});
//! let selector = FieldSelector::from_request(&args);
//!
//! // Apply it to a result
//! let result = json!({"episode": {"id": "123", "task_description": "test", "steps": ["s1"]}});
//! let filtered = selector.apply(&result)?;
//!
//! // Fields that were not requested, such as `episode.steps`, are dropped
//! assert_eq!(filtered, json!({"episode": {"id": "123", "task_description": "test"}}));
//! # Ok(())
//! # }
//! ```
//!
//! ## Token Savings
//!
//! Illustrative example (actual savings depend on the payload and selected fields):
//!
//! - **Before**: Full object returned (~500 tokens)
//! - **After**: Selected fields only (~200 tokens, a 60% reduction)

37use anyhow::Result;
38use serde::Serialize;
39use serde_json::Value;
40use std::collections::HashSet;
41use tracing::{debug, trace};
42
/// Decides which parts of a JSON query result are kept.
///
/// The constructors leave a selector in one of two modes:
/// pass-through (`return_all == true`, `allowed_fields == None`) or
/// filtering (`return_all == false`, `allowed_fields == Some(paths)`).
#[derive(Clone, Debug)]
pub struct FieldSelector {
    /// Dot-separated paths to keep (e.g. "episode.id"); `None` means no filtering.
    allowed_fields: Option<HashSet<String>>,
    /// `true` when every field is returned unchanged (backward-compatible mode).
    return_all: bool,
}
51
52impl FieldSelector {
53    /// Create a new field selector with specific allowed fields
54    ///
55    /// # Arguments
56    ///
57    /// * `fields` - Set of field paths to allow (e.g., "episode.id", "episode.task_description")
58    pub fn new(fields: HashSet<String>) -> Self {
59        if fields.is_empty() {
60            Self {
61                allowed_fields: None,
62                return_all: true,
63            }
64        } else {
65            Self {
66                allowed_fields: Some(fields),
67                return_all: false,
68            }
69        }
70    }
71
72    /// Create a field selector from request arguments
73    ///
74    /// # Arguments
75    ///
76    /// * `args` - Request arguments that may contain a `fields` parameter
77    ///
78    /// # Returns
79    ///
80    /// Returns a FieldSelector. If no `fields` parameter is present,
81    /// returns a selector that allows all fields (backward compatible).
82    pub fn from_request(args: &Value) -> Self {
83        match args.get("fields") {
84            Some(Value::Array(fields)) => {
85                let field_set: HashSet<String> = fields
86                    .iter()
87                    .filter_map(|v| v.as_str().map(|s| s.to_string()))
88                    .collect();
89
90                if field_set.is_empty() {
91                    debug!("Empty fields array, returning all fields");
92                    Self {
93                        allowed_fields: None,
94                        return_all: true,
95                    }
96                } else {
97                    debug!(
98                        "Field selector created with {} fields: {:?}",
99                        field_set.len(),
100                        field_set
101                    );
102                    Self {
103                        allowed_fields: Some(field_set),
104                        return_all: false,
105                    }
106                }
107            }
108            Some(_) => {
109                debug!("Invalid fields parameter type, returning all fields");
110                Self {
111                    allowed_fields: None,
112                    return_all: true,
113                }
114            }
115            None => {
116                trace!("No fields parameter, returning all fields (backward compatible)");
117                Self {
118                    allowed_fields: None,
119                    return_all: true,
120                }
121            }
122        }
123    }
124
125    /// Apply field selection to a serializable value
126    ///
127    /// # Arguments
128    ///
129    /// * `value` - Value to filter (must implement Serialize)
130    ///
131    /// # Returns
132    ///
133    /// Returns filtered JSON value
134    pub fn apply<T: Serialize>(&self, value: &T) -> Result<Value> {
135        let full = serde_json::to_value(value)?;
136
137        if self.return_all {
138            return Ok(full);
139        }
140
141        let allowed = match self.allowed_fields.as_ref() {
142            Some(fields) => fields,
143            None => return Ok(full),
144        };
145        Ok(self.filter_value(&full, allowed, ""))
146    }
147
148    /// Filter a JSON value based on allowed field paths
149    ///
150    /// # Arguments
151    ///
152    /// * `value` - JSON value to filter
153    /// * `allowed` - Set of allowed field paths
154    /// * `prefix` - Current path prefix for nested fields
155    fn filter_value(&self, value: &Value, allowed: &HashSet<String>, prefix: &str) -> Value {
156        match value {
157            Value::Object(map) => self.filter_object(map, allowed, prefix),
158            Value::Array(arr) => self.filter_array(arr, allowed, prefix),
159            _ => value.clone(),
160        }
161    }
162
163    /// Filter an object based on allowed field paths
164    fn filter_object(
165        &self,
166        map: &serde_json::Map<String, Value>,
167        allowed: &HashSet<String>,
168        prefix: &str,
169    ) -> Value {
170        let mut result = serde_json::Map::new();
171
172        for (key, value) in map.iter() {
173            let full_path = if prefix.is_empty() {
174                key.clone()
175            } else {
176                format!("{}.{}", prefix, key)
177            };
178
179            // Check if this field or any of its children are allowed
180            let field_allowed = allowed.contains(&full_path);
181            let child_allowed = allowed
182                .iter()
183                .any(|f| f.starts_with(&format!("{}.", full_path)));
184
185            if field_allowed || child_allowed {
186                if field_allowed {
187                    // Field is explicitly allowed, include it as-is
188                    result.insert(key.clone(), value.clone());
189                } else if child_allowed {
190                    // Child fields are allowed, recurse
191                    let filtered = self.filter_value(value, allowed, &full_path);
192                    if !filtered.is_null() {
193                        result.insert(key.clone(), filtered);
194                    }
195                }
196            }
197        }
198
199        Value::Object(result)
200    }
201
202    /// Filter an array based on allowed field paths
203    fn filter_array(&self, arr: &[Value], allowed: &HashSet<String>, prefix: &str) -> Value {
204        let filtered: Vec<Value> = arr
205            .iter()
206            .map(|item| self.filter_value(item, allowed, prefix))
207            .collect();
208
209        Value::Array(filtered)
210    }
211
212    /// Check if a specific field path is allowed
213    pub fn is_field_allowed(&self, path: &str) -> bool {
214        if self.return_all {
215            return true;
216        }
217
218        match &self.allowed_fields {
219            None => true,
220            Some(allowed) => {
221                // Exact match or parent path match
222                allowed.contains(path)
223                    || allowed.iter().any(|f| path.starts_with(&format!("{}.", f)))
224            }
225        }
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_no_fields_returns_all() {
        let args = json!({});
        let selector = FieldSelector::from_request(&args);

        let result = json!({"id": "123", "name": "test", "nested": {"value": 42}});
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered, result);
    }

    #[test]
    fn test_simple_field_selection() {
        let args = json!({"fields": ["id", "name"]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({"id": "123", "name": "test", "extra": "ignored"});
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered["id"], "123");
        assert_eq!(filtered["name"], "test");
        assert!(!filtered.as_object().unwrap().contains_key("extra"));
    }

    #[test]
    fn test_nested_field_selection() {
        let args = json!({"fields": ["episode.id", "episode.task_description"]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({
            "episode": {
                "id": "123",
                "task_description": "test task",
                "steps": ["step1", "step2"],
                "outcome": {"type": "success"}
            }
        });
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered["episode"]["id"], "123");
        assert_eq!(filtered["episode"]["task_description"], "test task");
        assert!(
            !filtered["episode"]
                .as_object()
                .unwrap()
                .contains_key("steps")
        );
    }

    #[test]
    fn test_array_field_selection() {
        let args = json!({"fields": ["episodes.id", "episodes.task_description"]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({
            "episodes": [
                {"id": "1", "task_description": "task1", "extra": "data1"},
                {"id": "2", "task_description": "task2", "extra": "data2"}
            ]
        });
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered["episodes"].as_array().unwrap().len(), 2);
        assert_eq!(filtered["episodes"][0]["id"], "1");
        assert_eq!(filtered["episodes"][0]["task_description"], "task1");
        assert!(
            !filtered["episodes"][0]
                .as_object()
                .unwrap()
                .contains_key("extra")
        );
    }

    #[test]
    fn test_empty_fields_array_returns_all() {
        let args = json!({"fields": []});
        let selector = FieldSelector::from_request(&args);

        assert!(selector.return_all);
    }

    #[test]
    fn test_invalid_fields_type_returns_all() {
        // A non-array `fields` value falls back to returning everything.
        let args = json!({"fields": "id"});
        let selector = FieldSelector::from_request(&args);

        assert!(selector.return_all);
        let result = json!({"id": "123", "name": "test"});
        assert_eq!(selector.apply(&result).unwrap(), result);
    }

    #[test]
    fn test_non_string_field_entries_are_ignored() {
        let args = json!({"fields": ["id", 42, null, {"x": 1}]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({"id": "123", "name": "test"});
        assert_eq!(selector.apply(&result).unwrap(), json!({"id": "123"}));
    }

    #[test]
    fn test_new_with_empty_set_returns_all() {
        let selector = FieldSelector::new(std::collections::HashSet::new());

        assert!(selector.return_all);
        assert!(selector.is_field_allowed("anything.at.all"));
    }

    #[test]
    fn test_parent_field_selects_whole_subtree() {
        let args = json!({"fields": ["episode"]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({"episode": {"id": "1", "steps": ["a"]}, "other": 1});
        assert_eq!(
            selector.apply(&result).unwrap(),
            json!({"episode": {"id": "1", "steps": ["a"]}})
        );

        // Descendants of an allowed path are allowed; similarly-prefixed siblings are not.
        assert!(selector.is_field_allowed("episode.steps"));
        assert!(!selector.is_field_allowed("episodes.steps"));
        assert!(!selector.is_field_allowed("other"));
    }

    #[test]
    fn test_is_field_allowed() {
        let selector = FieldSelector::new(
            vec![
                "episode.id".to_string(),
                "episode.task_description".to_string(),
            ]
            .into_iter()
            .collect(),
        );

        assert!(selector.is_field_allowed("episode.id"));
        assert!(selector.is_field_allowed("episode.task_description"));
        assert!(!selector.is_field_allowed("episode.steps"));
    }

    #[test]
    fn test_complex_nested_structure() {
        let args = json!({"fields": [
            "episodes.id",
            "episodes.task_description",
            "episodes.outcome.type",
            "patterns.success_rate"
        ]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({
            "episodes": [
                {
                    "id": "1",
                    "task_description": "task1",
                    "steps": ["s1", "s2"],
                    "outcome": {"type": "success", "verdict": "good"}
                }
            ],
            "patterns": [
                {"success_rate": 0.9, "description": "pattern1"}
            ],
            "insights": {"total": 10}
        });
        let filtered = selector.apply(&result).unwrap();

        // Verify episodes filtered correctly
        assert_eq!(filtered["episodes"][0]["id"], "1");
        assert_eq!(filtered["episodes"][0]["task_description"], "task1");
        assert_eq!(filtered["episodes"][0]["outcome"]["type"], "success");
        assert!(
            !filtered["episodes"][0]
                .as_object()
                .unwrap()
                .contains_key("steps")
        );
        assert!(
            !filtered["episodes"][0]["outcome"]
                .as_object()
                .unwrap()
                .contains_key("verdict")
        );

        // Verify patterns filtered correctly
        assert_eq!(filtered["patterns"][0]["success_rate"], 0.9);
        assert!(
            !filtered["patterns"][0]
                .as_object()
                .unwrap()
                .contains_key("description")
        );

        // Verify insights not included
        assert!(!filtered.as_object().unwrap().contains_key("insights"));
    }
}