do_memory_mcp/server/tools/
field_projection.rs1use anyhow::Result;
38use serde::Serialize;
39use serde_json::Value;
40use std::collections::HashSet;
41use tracing::{debug, trace};
42
43#[derive(Debug, Clone)]
45pub struct FieldSelector {
46 allowed_fields: Option<HashSet<String>>,
48 return_all: bool,
50}
51
52impl FieldSelector {
53 pub fn new(fields: HashSet<String>) -> Self {
59 if fields.is_empty() {
60 Self {
61 allowed_fields: None,
62 return_all: true,
63 }
64 } else {
65 Self {
66 allowed_fields: Some(fields),
67 return_all: false,
68 }
69 }
70 }
71
72 pub fn from_request(args: &Value) -> Self {
83 match args.get("fields") {
84 Some(Value::Array(fields)) => {
85 let field_set: HashSet<String> = fields
86 .iter()
87 .filter_map(|v| v.as_str().map(|s| s.to_string()))
88 .collect();
89
90 if field_set.is_empty() {
91 debug!("Empty fields array, returning all fields");
92 Self {
93 allowed_fields: None,
94 return_all: true,
95 }
96 } else {
97 debug!(
98 "Field selector created with {} fields: {:?}",
99 field_set.len(),
100 field_set
101 );
102 Self {
103 allowed_fields: Some(field_set),
104 return_all: false,
105 }
106 }
107 }
108 Some(_) => {
109 debug!("Invalid fields parameter type, returning all fields");
110 Self {
111 allowed_fields: None,
112 return_all: true,
113 }
114 }
115 None => {
116 trace!("No fields parameter, returning all fields (backward compatible)");
117 Self {
118 allowed_fields: None,
119 return_all: true,
120 }
121 }
122 }
123 }
124
125 pub fn apply<T: Serialize>(&self, value: &T) -> Result<Value> {
135 let full = serde_json::to_value(value)?;
136
137 if self.return_all {
138 return Ok(full);
139 }
140
141 let allowed = match self.allowed_fields.as_ref() {
142 Some(fields) => fields,
143 None => return Ok(full),
144 };
145 Ok(self.filter_value(&full, allowed, ""))
146 }
147
148 fn filter_value(&self, value: &Value, allowed: &HashSet<String>, prefix: &str) -> Value {
156 match value {
157 Value::Object(map) => self.filter_object(map, allowed, prefix),
158 Value::Array(arr) => self.filter_array(arr, allowed, prefix),
159 _ => value.clone(),
160 }
161 }
162
163 fn filter_object(
165 &self,
166 map: &serde_json::Map<String, Value>,
167 allowed: &HashSet<String>,
168 prefix: &str,
169 ) -> Value {
170 let mut result = serde_json::Map::new();
171
172 for (key, value) in map.iter() {
173 let full_path = if prefix.is_empty() {
174 key.clone()
175 } else {
176 format!("{}.{}", prefix, key)
177 };
178
179 let field_allowed = allowed.contains(&full_path);
181 let child_allowed = allowed
182 .iter()
183 .any(|f| f.starts_with(&format!("{}.", full_path)));
184
185 if field_allowed || child_allowed {
186 if field_allowed {
187 result.insert(key.clone(), value.clone());
189 } else if child_allowed {
190 let filtered = self.filter_value(value, allowed, &full_path);
192 if !filtered.is_null() {
193 result.insert(key.clone(), filtered);
194 }
195 }
196 }
197 }
198
199 Value::Object(result)
200 }
201
202 fn filter_array(&self, arr: &[Value], allowed: &HashSet<String>, prefix: &str) -> Value {
204 let filtered: Vec<Value> = arr
205 .iter()
206 .map(|item| self.filter_value(item, allowed, prefix))
207 .collect();
208
209 Value::Array(filtered)
210 }
211
212 pub fn is_field_allowed(&self, path: &str) -> bool {
214 if self.return_all {
215 return true;
216 }
217
218 match &self.allowed_fields {
219 None => true,
220 Some(allowed) => {
221 allowed.contains(path)
223 || allowed.iter().any(|f| path.starts_with(&format!("{}.", f)))
224 }
225 }
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_no_fields_returns_all() {
        let args = json!({});
        let selector = FieldSelector::from_request(&args);

        let result = json!({"id": "123", "name": "test", "nested": {"value": 42}});
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered, result);
    }

    #[test]
    fn test_simple_field_selection() {
        let args = json!({"fields": ["id", "name"]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({"id": "123", "name": "test", "extra": "ignored"});
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered["id"], "123");
        assert_eq!(filtered["name"], "test");
        assert!(!filtered.as_object().unwrap().contains_key("extra"));
    }

    #[test]
    fn test_nested_field_selection() {
        let args = json!({"fields": ["episode.id", "episode.task_description"]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({
            "episode": {
                "id": "123",
                "task_description": "test task",
                "steps": ["step1", "step2"],
                "outcome": {"type": "success"}
            }
        });
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered["episode"]["id"], "123");
        assert_eq!(filtered["episode"]["task_description"], "test task");
        assert!(
            !filtered["episode"]
                .as_object()
                .unwrap()
                .contains_key("steps")
        );
    }

    #[test]
    fn test_array_field_selection() {
        let args = json!({"fields": ["episodes.id", "episodes.task_description"]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({
            "episodes": [
                {"id": "1", "task_description": "task1", "extra": "data1"},
                {"id": "2", "task_description": "task2", "extra": "data2"}
            ]
        });
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered["episodes"].as_array().unwrap().len(), 2);
        assert_eq!(filtered["episodes"][0]["id"], "1");
        assert_eq!(filtered["episodes"][0]["task_description"], "task1");
        assert!(
            !filtered["episodes"][0]
                .as_object()
                .unwrap()
                .contains_key("extra")
        );
    }

    #[test]
    fn test_empty_fields_array_returns_all() {
        let args = json!({"fields": []});
        let selector = FieldSelector::from_request(&args);

        assert!(selector.return_all);
    }

    #[test]
    fn test_is_field_allowed() {
        let selector = FieldSelector::new(
            vec![
                "episode.id".to_string(),
                "episode.task_description".to_string(),
            ]
            .into_iter()
            .collect(),
        );

        assert!(selector.is_field_allowed("episode.id"));
        assert!(selector.is_field_allowed("episode.task_description"));
        assert!(!selector.is_field_allowed("episode.steps"));
    }

    #[test]
    fn test_complex_nested_structure() {
        let args = json!({"fields": [
            "episodes.id",
            "episodes.task_description",
            "episodes.outcome.type",
            "patterns.success_rate"
        ]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({
            "episodes": [
                {
                    "id": "1",
                    "task_description": "task1",
                    "steps": ["s1", "s2"],
                    "outcome": {"type": "success", "verdict": "good"}
                }
            ],
            "patterns": [
                {"success_rate": 0.9, "description": "pattern1"}
            ],
            "insights": {"total": 10}
        });
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered["episodes"][0]["id"], "1");
        assert_eq!(filtered["episodes"][0]["task_description"], "task1");
        assert_eq!(filtered["episodes"][0]["outcome"]["type"], "success");
        assert!(
            !filtered["episodes"][0]
                .as_object()
                .unwrap()
                .contains_key("steps")
        );
        assert!(
            !filtered["episodes"][0]["outcome"]
                .as_object()
                .unwrap()
                .contains_key("verdict")
        );

        assert_eq!(filtered["patterns"][0]["success_rate"], 0.9);
        assert!(
            !filtered["patterns"][0]
                .as_object()
                .unwrap()
                .contains_key("description")
        );

        assert!(!filtered.as_object().unwrap().contains_key("insights"));
    }

    #[test]
    fn test_invalid_fields_type_returns_all() {
        // A non-array `fields` value falls back to returning everything.
        let args = json!({"fields": "id"});
        let selector = FieldSelector::from_request(&args);
        assert!(selector.return_all);

        let result = json!({"id": "123", "name": "test"});
        assert_eq!(selector.apply(&result).unwrap(), result);
    }

    #[test]
    fn test_non_string_field_entries_are_ignored() {
        let args = json!({"fields": [1, null, "id"]});
        let selector = FieldSelector::from_request(&args);

        let filtered = selector
            .apply(&json!({"id": "123", "name": "test"}))
            .unwrap();
        assert_eq!(filtered, json!({"id": "123"}));
    }

    #[test]
    fn test_only_non_string_entries_returns_all() {
        let args = json!({"fields": [1, 2]});
        let selector = FieldSelector::from_request(&args);

        assert!(selector.return_all);
    }

    #[test]
    fn test_selecting_parent_keeps_whole_subtree() {
        let args = json!({"fields": ["episode"]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({"episode": {"id": "1", "steps": ["s1"]}, "extra": true});
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered, json!({"episode": {"id": "1", "steps": ["s1"]}}));
    }

    #[test]
    fn test_similar_prefix_key_is_not_selected() {
        // "episodes" shares a textual prefix with "episode.id" but is not on its path.
        let args = json!({"fields": ["episode.id"]});
        let selector = FieldSelector::from_request(&args);

        let result = json!({"episode": {"id": "1"}, "episodes": [{"id": "2"}]});
        let filtered = selector.apply(&result).unwrap();

        assert_eq!(filtered, json!({"episode": {"id": "1"}}));
    }

    #[test]
    fn test_is_field_allowed_via_ancestor() {
        let selector = FieldSelector::new(vec!["episode".to_string()].into_iter().collect());

        assert!(selector.is_field_allowed("episode"));
        assert!(selector.is_field_allowed("episode.id"));
        assert!(selector.is_field_allowed("episode.outcome.type"));
        assert!(!selector.is_field_allowed("episodes.id"));
        assert!(!selector.is_field_allowed("other"));
    }

    #[test]
    fn test_new_with_empty_set_returns_all() {
        let selector = FieldSelector::new(std::collections::HashSet::new());

        assert!(selector.return_all);
        assert!(selector.is_field_allowed("anything.at.all"));
    }
}
386}