// mpl_proxy/traffic.rs

//! Traffic recording for schema inference.
//!
//! Records MCP/A2A traffic samples for schema generation.

use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use tracing::{debug, warn};

13/// Traffic record for a single request/response
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TrafficRecord {
16    /// Unique ID for this record
17    pub id: String,
18    /// Timestamp of the request
19    pub timestamp: String,
20    /// Inferred or declared SType
21    pub stype: String,
22    /// HTTP method
23    pub method: String,
24    /// Request path
25    pub path: String,
26    /// Request payload
27    pub payload: serde_json::Value,
28    /// Response payload (if captured)
29    #[serde(skip_serializing_if = "Option::is_none")]
30    pub response: Option<serde_json::Value>,
31    /// Response status code
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub status_code: Option<u16>,
34    /// Request duration in milliseconds
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub duration_ms: Option<u64>,
37    /// Whether validation passed
38    #[serde(default)]
39    pub validation_passed: bool,
40    /// Validation errors if any
41    #[serde(default, skip_serializing_if = "Vec::is_empty")]
42    pub validation_errors: Vec<String>,
43}
44
45/// Traffic recorder that stores samples for schema inference
46pub struct TrafficRecorder {
47    /// Directory to store traffic samples
48    data_dir: PathBuf,
49    /// In-memory samples by SType (for quick access)
50    samples: Arc<RwLock<HashMap<String, Vec<TrafficRecord>>>>,
51    /// Counter for unique IDs
52    counter: AtomicU64,
53    /// Maximum samples to keep per SType
54    max_samples_per_stype: usize,
55    /// Whether recording is enabled
56    enabled: bool,
57}
58
59impl TrafficRecorder {
60    /// Create a new traffic recorder
61    pub fn new(data_dir: &Path, enabled: bool) -> Self {
62        let traffic_dir = data_dir.join("traffic");
63
64        // Create traffic directory if it doesn't exist
65        if enabled {
66            if let Err(e) = fs::create_dir_all(&traffic_dir) {
67                warn!("Failed to create traffic directory: {}", e);
68            }
69        }
70
71        Self {
72            data_dir: traffic_dir,
73            samples: Arc::new(RwLock::new(HashMap::new())),
74            counter: AtomicU64::new(0),
75            max_samples_per_stype: 1000,
76            enabled,
77        }
78    }
79
80    /// Check if recording is enabled
81    pub fn is_enabled(&self) -> bool {
82        self.enabled
83    }
84
85    /// Record a traffic sample
86    pub fn record(&self, record: TrafficRecord) {
87        if !self.enabled {
88            return;
89        }
90
91        let stype = record.stype.clone();
92
93        // Add to in-memory cache
94        if let Ok(mut samples) = self.samples.write() {
95            let stype_samples = samples.entry(stype.clone()).or_default();
96
97            // Keep max samples
98            if stype_samples.len() >= self.max_samples_per_stype {
99                stype_samples.remove(0);
100            }
101
102            stype_samples.push(record.clone());
103        }
104
105        // Write to disk
106        let filename = format!("{}_{}.json", record.stype.replace('.', "_"), record.id);
107        let filepath = self.data_dir.join(&filename);
108
109        if let Ok(content) = serde_json::to_string_pretty(&record) {
110            if let Err(e) = fs::write(&filepath, content) {
111                warn!("Failed to write traffic record: {}", e);
112            } else {
113                debug!("Recorded traffic sample: {}", filepath.display());
114            }
115        }
116    }
117
118    /// Generate a unique ID for a record
119    pub fn next_id(&self) -> String {
120        let count = self.counter.fetch_add(1, Ordering::SeqCst);
121        format!("{:08x}", count)
122    }
123
124    /// Get sample count for an SType
125    pub fn sample_count(&self, stype: &str) -> usize {
126        self.samples
127            .read()
128            .ok()
129            .and_then(|s| s.get(stype).map(|v| v.len()))
130            .unwrap_or(0)
131    }
132
133    /// Get all STypes with sample counts
134    pub fn get_stats(&self) -> HashMap<String, usize> {
135        self.samples
136            .read()
137            .ok()
138            .map(|s| s.iter().map(|(k, v)| (k.clone(), v.len())).collect())
139            .unwrap_or_default()
140    }
141
142    /// Get samples for an SType
143    pub fn get_samples(&self, stype: &str) -> Vec<TrafficRecord> {
144        self.samples
145            .read()
146            .ok()
147            .and_then(|s| s.get(stype).cloned())
148            .unwrap_or_default()
149    }
150
151    /// Load samples from disk
152    pub fn load_from_disk(&self) -> anyhow::Result<usize> {
153        if !self.data_dir.exists() {
154            return Ok(0);
155        }
156
157        let mut loaded = 0;
158
159        for entry in fs::read_dir(&self.data_dir)? {
160            let entry = entry?;
161            let path = entry.path();
162
163            if path.extension().map(|e| e == "json").unwrap_or(false) {
164                if let Ok(content) = fs::read_to_string(&path) {
165                    if let Ok(record) = serde_json::from_str::<TrafficRecord>(&content) {
166                        if let Ok(mut samples) = self.samples.write() {
167                            samples
168                                .entry(record.stype.clone())
169                                .or_default()
170                                .push(record);
171                            loaded += 1;
172                        }
173                    }
174                }
175            }
176        }
177
178        debug!("Loaded {} traffic samples from disk", loaded);
179        Ok(loaded)
180    }
181}
182
183/// SType inference from payload structure
184pub struct StypeInferrer;
185
186impl StypeInferrer {
187    /// Infer an SType from a JSON payload and request context
188    pub fn infer(
189        path: &str,
190        method: &str,
191        payload: &serde_json::Value,
192    ) -> String {
193        // Check for A2A task patterns first
194        if let Some(stype) = Self::infer_a2a(path, payload) {
195            return stype;
196        }
197
198        // Try to extract from JSON-RPC method
199        if let Some(rpc_method) = payload.get("method").and_then(|m| m.as_str()) {
200            return Self::method_to_stype(rpc_method);
201        }
202
203        // Try to extract from MCP tools/call
204        if let Some(params) = payload.get("params") {
205            if let Some(name) = params.get("name").and_then(|n| n.as_str()) {
206                return format!("inferred.tool.{}.v1", Self::normalize_name(name));
207            }
208        }
209
210        // Infer from path
211        let path_parts: Vec<&str> = path
212            .trim_matches('/')
213            .split('/')
214            .filter(|p| !p.is_empty())
215            .collect();
216
217        if !path_parts.is_empty() {
218            let name = path_parts.last().unwrap_or(&"unknown");
219            return format!(
220                "inferred.{}.{}.v1",
221                method.to_lowercase(),
222                Self::normalize_name(name)
223            );
224        }
225
226        // Fallback: hash-based inference
227        format!("inferred.unknown.payload.v1")
228    }
229
230    /// Infer SType for A2A protocol patterns
231    fn infer_a2a(path: &str, payload: &serde_json::Value) -> Option<String> {
232        // A2A task endpoints: /tasks, /tasks/{id}, /tasks/{id}/send, etc.
233        if path.contains("/tasks") {
234            // Check for task_id in payload or path
235            if let Some(task_id) = payload.get("task_id").or(payload.get("id")) {
236                if task_id.is_string() {
237                    // Determine operation from path
238                    if path.contains("/send") {
239                        return Some("a2a.task.SendMessage.v1".to_string());
240                    } else if path.contains("/cancel") {
241                        return Some("a2a.task.Cancel.v1".to_string());
242                    } else if path.ends_with("/tasks") || path.contains("/tasks/") {
243                        return Some("a2a.task.Task.v1".to_string());
244                    }
245                }
246            }
247
248            // A2A message pattern
249            if payload.get("message").is_some() || payload.get("messages").is_some() {
250                return Some("a2a.task.Message.v1".to_string());
251            }
252
253            return Some("a2a.task.Request.v1".to_string());
254        }
255
256        // A2A agent info endpoint
257        if path.contains("/agent") || path.contains("/.well-known/agent") {
258            return Some("a2a.agent.Info.v1".to_string());
259        }
260
261        // A2A streaming/push notifications
262        if path.contains("/subscribe") || path.contains("/notifications") {
263            return Some("a2a.notification.Subscribe.v1".to_string());
264        }
265
266        None
267    }
268
269    /// Convert JSON-RPC method to SType
270    fn method_to_stype(method: &str) -> String {
271        // e.g., "tools/call" -> "mcp.tools.call.v1"
272        let normalized = method.replace('/', ".");
273        format!("mcp.{}.v1", normalized)
274    }
275
276    /// Normalize a name for use in SType
277    fn normalize_name(name: &str) -> String {
278        name.chars()
279            .map(|c| if c.is_alphanumeric() { c } else { '_' })
280            .collect::<String>()
281            .to_lowercase()
282    }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::fs;
    use std::path::PathBuf;

    /// Per-test scratch directory under the system temp dir, removed if left over.
    fn temp_data_dir(tag: &str) -> PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "mpl_proxy_traffic_{}_{}",
            tag,
            std::process::id()
        ));
        let _ = fs::remove_dir_all(&dir);
        dir
    }

    /// Minimal record for `stype`, with an ID drawn from `recorder`.
    fn sample_record(recorder: &TrafficRecorder, stype: &str) -> TrafficRecord {
        TrafficRecord {
            id: recorder.next_id(),
            timestamp: "2024-01-01T00:00:00Z".to_string(),
            stype: stype.to_string(),
            method: "POST".to_string(),
            path: "/".to_string(),
            payload: json!({"k": 1}),
            response: None,
            status_code: None,
            duration_ms: None,
            validation_passed: true,
            validation_errors: Vec::new(),
        }
    }

    #[test]
    fn test_infer_from_jsonrpc() {
        let payload = json!({
            "jsonrpc": "2.0",
            "method": "tools/call",
            "params": {"name": "test"}
        });

        let stype = StypeInferrer::infer("/", "POST", &payload);
        assert_eq!(stype, "mcp.tools.call.v1");
    }

    #[test]
    fn test_infer_from_tool_name() {
        let payload = json!({
            "jsonrpc": "2.0",
            "params": {"name": "calendar_create"}
        });

        let stype = StypeInferrer::infer("/", "POST", &payload);
        assert_eq!(stype, "inferred.tool.calendar_create.v1");
    }

    #[test]
    fn test_infer_from_path() {
        let payload = json!({});
        let stype = StypeInferrer::infer("/api/events", "POST", &payload);
        assert_eq!(stype, "inferred.post.events.v1");
    }

    #[test]
    fn test_infer_path_segment_is_normalized() {
        let stype = StypeInferrer::infer("/api/Calendar-Events/", "GET", &json!({}));
        assert_eq!(stype, "inferred.get.calendar_events.v1");
    }

    #[test]
    fn test_infer_a2a_send() {
        let payload = json!({"id": "task-1", "message": {"role": "user"}});
        let stype = StypeInferrer::infer("/tasks/task-1/send", "POST", &payload);
        assert_eq!(stype, "a2a.task.SendMessage.v1");
    }

    #[test]
    fn test_infer_a2a_message_without_task_id() {
        let payload = json!({"message": {"role": "user"}});
        let stype = StypeInferrer::infer("/tasks", "POST", &payload);
        assert_eq!(stype, "a2a.task.Message.v1");
    }

    #[test]
    fn test_infer_agent_card() {
        let stype = StypeInferrer::infer("/.well-known/agent.json", "GET", &json!({}));
        assert_eq!(stype, "a2a.agent.Info.v1");
    }

    #[test]
    fn test_infer_unknown_fallback() {
        assert_eq!(
            StypeInferrer::infer("/", "GET", &json!({})),
            "inferred.unknown.payload.v1"
        );
        assert_eq!(
            StypeInferrer::infer("", "GET", &json!(null)),
            "inferred.unknown.payload.v1"
        );
    }

    #[test]
    fn test_record_and_reload() {
        let dir = temp_data_dir("reload");
        let recorder = TrafficRecorder::new(&dir, true);
        recorder.record(sample_record(&recorder, "mcp.tools.call.v1"));
        recorder.record(sample_record(&recorder, "mcp.tools.call.v1"));
        assert_eq!(recorder.sample_count("mcp.tools.call.v1"), 2);

        let reloaded = TrafficRecorder::new(&dir, true);
        assert_eq!(reloaded.load_from_disk().unwrap(), 2);
        assert_eq!(reloaded.get_samples("mcp.tools.call.v1").len(), 2);

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_disabled_recorder_ignores_records() {
        let dir = temp_data_dir("disabled");
        let recorder = TrafficRecorder::new(&dir, false);
        recorder.record(sample_record(&recorder, "x.v1"));
        assert_eq!(recorder.sample_count("x.v1"), 0);
        assert!(recorder.get_stats().is_empty());
        assert!(!dir.join("traffic").exists());
    }
}
319}