Skip to main content

mcpr_core/protocol/schema_manager/
manager.rs

1//! `SchemaManager` — top-level per-upstream view of an MCP server's schema.
2//!
3//! Callers feed schema-method responses in via [`SchemaManager::ingest`].
4//! The manager handles pagination buffering, change detection (by content
5//! hash), and version assignment, persisting new versions to a
6//! [`SchemaStore`]. Query methods read back the latest merged payload
7//! without re-hitting the store per item.
8
9use std::sync::Arc;
10
11use chrono::{DateTime, Utc};
12use dashmap::DashMap;
13use serde_json::Value;
14
15use super::store::SchemaStore;
16use super::version::{SchemaVersion, SchemaVersionId, hash_payload};
17use crate::protocol::schema::{PageStatus, detect_page_status, merge_pages};
18
19/// Per-method runtime state held in memory. Separate from the
20/// `SchemaStore` because these fields serve the hot path — change
21/// detection, pagination buffering, stale flag — and would be expensive
22/// to read-through on every ingest.
23#[derive(Default)]
24struct MethodState {
25    page_buffer: Vec<Value>,
26    current_hash: Option<String>,
27    next_version_number: u32,
28    stale: bool,
29    stale_since: Option<DateTime<Utc>>,
30}
31
32/// Top-level handle for one upstream MCP server's schema view.
33///
34/// Generic over the store backend; downstream typically uses
35/// `SchemaManager<MemorySchemaStore>` for the OSS proxy and swaps in a
36/// database-backed store for cloud deployments.
37pub struct SchemaManager<S: SchemaStore> {
38    upstream_id: String,
39    store: S,
40    state: Arc<DashMap<String, MethodState>>,
41}
42
43impl<S: SchemaStore> SchemaManager<S> {
44    pub fn new(upstream_id: impl Into<String>, store: S) -> Self {
45        Self {
46            upstream_id: upstream_id.into(),
47            store,
48            state: Arc::new(DashMap::new()),
49        }
50    }
51
52    pub fn upstream_id(&self) -> &str {
53        &self.upstream_id
54    }
55
56    /// Seed the in-memory state for `method` from the store.
57    ///
58    /// Callers normally don't need to invoke this directly — `ingest`
59    /// lazy-warms on the first call for a method. Exposed for explicit
60    /// startup warm-up when desired.
61    pub async fn warm(&self, method: &str) {
62        let latest = self
63            .store
64            .latest_version_for_method(&self.upstream_id, method)
65            .await;
66        if let Some(latest) = latest {
67            let mut entry = self.state.entry(method.to_string()).or_default();
68            if entry.current_hash.is_none() {
69                entry.current_hash = Some(latest.content_hash.clone());
70                entry.next_version_number = latest.version + 1;
71            }
72        }
73    }
74
75    /// Feed a schema-method response through the manager.
76    ///
77    /// Returns `Some(version)` when a new `SchemaVersion` was created
78    /// (pagination complete AND content differs from the current
79    /// version). Returns `None` when:
80    ///
81    /// - The response is not a complete page (still buffering).
82    /// - The content hash matches the current version.
83    /// - The response has no `result` field.
84    pub async fn ingest(
85        &self,
86        method: &str,
87        request_body: &Value,
88        response_body: &Value,
89    ) -> Option<SchemaVersion> {
90        let result = response_body.get("result")?;
91        let status = detect_page_status(request_body, response_body);
92
93        let merged = {
94            let mut entry = self.state.entry(method.to_string()).or_default();
95            entry.page_buffer.push(result.clone());
96            match status {
97                PageStatus::Complete | PageStatus::LastPage => {
98                    let pages = std::mem::take(&mut entry.page_buffer);
99                    merge_pages(method, &pages)
100                        .unwrap_or_else(|| pages.into_iter().next().unwrap_or(Value::Null))
101                }
102                PageStatus::FirstPage | PageStatus::MiddlePage => return None,
103            }
104        };
105
106        let hash = hash_payload(&merged);
107
108        let needs_warm = self
109            .state
110            .get(method)
111            .map(|e| e.current_hash.is_none() && e.next_version_number == 0)
112            .unwrap_or(true);
113        if needs_warm {
114            self.warm(method).await;
115        }
116
117        let (same, version_number) = {
118            let mut entry = self.state.entry(method.to_string()).or_default();
119            if entry.current_hash.as_deref() == Some(hash.as_str()) {
120                (true, 0)
121            } else {
122                let num = entry.next_version_number.max(1);
123                entry.current_hash = Some(hash.clone());
124                entry.next_version_number = num.saturating_add(1);
125                entry.stale = false;
126                entry.stale_since = None;
127                (false, num)
128            }
129        };
130
131        if same {
132            return None;
133        }
134
135        let id = SchemaVersionId(hash.chars().take(16).collect());
136        let version = SchemaVersion {
137            id,
138            upstream_id: self.upstream_id.clone(),
139            method: method.to_string(),
140            version: version_number,
141            payload: Arc::new(merged),
142            content_hash: hash,
143            captured_at: Utc::now(),
144        };
145        Some(self.store.put_version(version).await)
146    }
147
148    /// Latest stored version for `method`, or `None` if nothing has
149    /// been ingested yet.
150    pub async fn latest(&self, method: &str) -> Option<SchemaVersion> {
151        self.store
152            .latest_version_for_method(&self.upstream_id, method)
153            .await
154    }
155
156    pub async fn list_tools(&self) -> Vec<Value> {
157        self.list_items("tools/list", "tools").await
158    }
159
160    pub async fn list_resources(&self) -> Vec<Value> {
161        self.list_items("resources/list", "resources").await
162    }
163
164    pub async fn list_resource_templates(&self) -> Vec<Value> {
165        self.list_items("resources/templates/list", "resourceTemplates")
166            .await
167    }
168
169    pub async fn list_prompts(&self) -> Vec<Value> {
170        self.list_items("prompts/list", "prompts").await
171    }
172
173    pub async fn get_tool(&self, name: &str) -> Option<Value> {
174        self.find_item_by_field("tools/list", "tools", "name", name)
175            .await
176    }
177
178    pub async fn get_resource(&self, uri: &str) -> Option<Value> {
179        self.find_item_by_field("resources/list", "resources", "uri", uri)
180            .await
181    }
182
183    pub async fn get_prompt(&self, name: &str) -> Option<Value> {
184        self.find_item_by_field("prompts/list", "prompts", "name", name)
185            .await
186    }
187
188    /// Mark the current version for `method` as stale. Idempotent.
189    ///
190    /// Sync on purpose — the stale flag is used by the hot request
191    /// path (observing `notifications/tools/list_changed`) where a
192    /// round-trip to async code would be overkill.
193    pub fn mark_stale(&self, method: &str) {
194        let mut entry = self.state.entry(method.to_string()).or_default();
195        if !entry.stale {
196            entry.stale = true;
197            entry.stale_since = Some(Utc::now());
198        }
199    }
200
201    pub fn is_stale(&self, method: &str) -> bool {
202        self.state.get(method).map(|e| e.stale).unwrap_or(false)
203    }
204
205    pub fn stale_since(&self, method: &str) -> Option<DateTime<Utc>> {
206        self.state.get(method).and_then(|e| e.stale_since)
207    }
208
209    // ── internals ──
210
211    async fn list_items(&self, method: &str, array_key: &str) -> Vec<Value> {
212        let Some(latest) = self.latest(method).await else {
213            return Vec::new();
214        };
215        latest
216            .payload
217            .get(array_key)
218            .and_then(|v| v.as_array())
219            .cloned()
220            .unwrap_or_default()
221    }
222
223    async fn find_item_by_field(
224        &self,
225        method: &str,
226        array_key: &str,
227        field: &str,
228        needle: &str,
229    ) -> Option<Value> {
230        let latest = self.latest(method).await?;
231        let arr = latest.payload.get(array_key).and_then(|v| v.as_array())?;
232        arr.iter()
233            .find(|item| item.get(field).and_then(|v| v.as_str()) == Some(needle))
234            .cloned()
235    }
236}
237
#[cfg(test)]
// Test names follow a `subject__behavior` convention; the double
// underscore would otherwise trip the `non_snake_case` lint.
#[allow(non_snake_case)]
mod tests {
    use super::*;
    use crate::protocol::schema_manager::store::MemorySchemaStore;
    use serde_json::json;

    const TOOLS_LIST: &str = "tools/list";

    /// Fresh manager for upstream `proxy-1` over an empty in-memory store.
    fn manager() -> SchemaManager<MemorySchemaStore> {
        SchemaManager::new("proxy-1", MemorySchemaStore::new())
    }

    /// `tools/list` request; `cursor` continues a paginated listing.
    fn tools_list_req(cursor: Option<&str>) -> Value {
        let mut req = json!({"jsonrpc": "2.0", "id": 1, "method": "tools/list"});
        if let Some(c) = cursor {
            req["params"] = json!({"cursor": c});
        }
        req
    }

    /// `tools/list` success response; `next_cursor` signals more pages.
    fn tools_list_resp(tools: Value, next_cursor: Option<&str>) -> Value {
        let result = match next_cursor {
            Some(c) => json!({"tools": tools, "nextCursor": c}),
            None => json!({"tools": tools}),
        };
        json!({"jsonrpc": "2.0", "id": 1, "result": result})
    }

    #[tokio::test]
    async fn ingest__complete_page_creates_version_one() {
        let m = manager();
        let resp = tools_list_resp(json!([{"name": "search"}]), None);
        let v = m
            .ingest(TOOLS_LIST, &tools_list_req(None), &resp)
            .await
            .expect("a single complete page creates a version");
        assert_eq!(v.version, 1);
        assert_eq!(v.method, TOOLS_LIST);
        assert_eq!(v.upstream_id, "proxy-1");
    }

    #[tokio::test]
    async fn ingest__first_page_buffers_returns_none() {
        let m = manager();
        let resp = tools_list_resp(json!([{"name": "a"}]), Some("cur1"));
        let out = m.ingest(TOOLS_LIST, &tools_list_req(None), &resp).await;
        assert!(out.is_none());
    }

    #[tokio::test]
    async fn ingest__first_middle_last_chain_merges_once() {
        let m = manager();

        let leading_pages = [
            (None, json!([{"name": "a"}]), Some("c1")),
            (Some("c1"), json!([{"name": "b"}]), Some("c2")),
        ];
        for (cursor, tools, next) in leading_pages {
            let out = m
                .ingest(TOOLS_LIST, &tools_list_req(cursor), &tools_list_resp(tools, next))
                .await;
            assert!(out.is_none());
        }

        let last = tools_list_resp(json!([{"name": "c"}]), None);
        let v = m
            .ingest(TOOLS_LIST, &tools_list_req(Some("c2")), &last)
            .await
            .expect("the last page yields the merged version");

        let names: Vec<&str> = v.payload["tools"]
            .as_array()
            .expect("merged payload has a tools array")
            .iter()
            .map(|t| t["name"].as_str().expect("every tool has a name"))
            .collect();
        assert_eq!(names, ["a", "b", "c"]);
        assert_eq!(v.version, 1);
    }

    #[tokio::test]
    async fn ingest__unchanged_payload_returns_none() {
        let m = manager();
        let req = tools_list_req(None);
        let resp = tools_list_resp(json!([{"name": "a"}]), None);
        assert!(m.ingest(TOOLS_LIST, &req, &resp).await.is_some());
        assert!(m.ingest(TOOLS_LIST, &req, &resp).await.is_none());
    }

    #[tokio::test]
    async fn ingest__changed_payload_increments_version() {
        let m = manager();
        let req = tools_list_req(None);
        let mut versions = Vec::new();
        for tools in [json!([{"name": "a"}]), json!([{"name": "a"}, {"name": "b"}])] {
            let v = m
                .ingest(TOOLS_LIST, &req, &tools_list_resp(tools, None))
                .await
                .expect("changed content creates a version");
            versions.push(v.version);
        }
        assert_eq!(versions, [1, 2]);
    }

    #[tokio::test]
    async fn ingest__clears_stale_on_new_version() {
        let m = manager();
        let req = tools_list_req(None);
        let before = tools_list_resp(json!([{"name": "a"}]), None);
        m.ingest(TOOLS_LIST, &req, &before)
            .await
            .expect("first version");

        m.mark_stale(TOOLS_LIST);
        assert!(m.is_stale(TOOLS_LIST));

        let after = tools_list_resp(json!([{"name": "a"}, {"name": "b"}]), None);
        m.ingest(TOOLS_LIST, &req, &after)
            .await
            .expect("second version");
        assert!(!m.is_stale(TOOLS_LIST));
    }

    #[tokio::test]
    async fn ingest__no_result_returns_none() {
        let m = manager();
        let err_resp =
            json!({"jsonrpc": "2.0", "id": 1, "error": {"code": -32603, "message": "x"}});
        let out = m.ingest(TOOLS_LIST, &tools_list_req(None), &err_resp).await;
        assert!(out.is_none());
    }

    #[tokio::test]
    async fn mark_stale__and_is_stale_idempotent() {
        let m = manager();
        assert!(!m.is_stale(TOOLS_LIST));

        m.mark_stale(TOOLS_LIST);
        let first_mark = m.stale_since(TOOLS_LIST);
        m.mark_stale(TOOLS_LIST);

        assert!(m.is_stale(TOOLS_LIST));
        assert_eq!(m.stale_since(TOOLS_LIST), first_mark);
    }

    #[tokio::test]
    async fn list_tools__empty_when_no_version() {
        let tools = manager().list_tools().await;
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn list_tools__returns_items_from_latest() {
        let m = manager();
        let resp = tools_list_resp(json!([{"name": "a"}, {"name": "b"}]), None);
        m.ingest(TOOLS_LIST, &tools_list_req(None), &resp)
            .await
            .expect("version created");

        let tools = m.list_tools().await;
        assert_eq!(tools.len(), 2);
        for (tool, expected) in tools.iter().zip(["a", "b"]) {
            assert_eq!(tool["name"], expected);
        }
    }

    #[tokio::test]
    async fn get_tool__by_name_hit_and_miss() {
        let m = manager();
        let resp = tools_list_resp(json!([{"name": "search", "description": "find"}]), None);
        m.ingest(TOOLS_LIST, &tools_list_req(None), &resp)
            .await
            .expect("version created");

        let hit = m.get_tool("search").await.expect("`search` is listed");
        assert_eq!(hit["description"], "find");
        assert!(m.get_tool("missing").await.is_none());
    }

    #[tokio::test]
    async fn get_resource__by_uri() {
        let m = manager();
        let method = "resources/list";
        let req = json!({"jsonrpc": "2.0", "id": 1, "method": "resources/list"});
        let resp = json!({
            "jsonrpc": "2.0", "id": 1,
            "result": {"resources": [{"uri": "file://a", "name": "A"}]}
        });
        m.ingest(method, &req, &resp).await.expect("version created");

        let resource = m
            .get_resource("file://a")
            .await
            .expect("`file://a` is listed");
        assert_eq!(resource["name"], "A");
    }

    #[tokio::test]
    async fn warm__seeds_counter_from_store() {
        let store = MemorySchemaStore::new();
        store
            .put_version(SchemaVersion {
                id: SchemaVersionId("abc".to_string()),
                upstream_id: "proxy-1".to_string(),
                method: TOOLS_LIST.to_string(),
                version: 5,
                payload: Arc::new(json!({"tools": [{"name": "x"}]})),
                content_hash: "prior-hash".to_string(),
                captured_at: Utc::now(),
            })
            .await;

        let m = SchemaManager::new("proxy-1", store);
        let resp = tools_list_resp(json!([{"name": "y"}]), None);
        let v = m
            .ingest(TOOLS_LIST, &tools_list_req(None), &resp)
            .await
            .expect("new content on top of a stored version");
        assert_eq!(v.version, 6);
    }

    #[tokio::test]
    async fn latest__returns_current_version() {
        let m = manager();
        let resp = tools_list_resp(json!([{"name": "a"}]), None);
        m.ingest(TOOLS_LIST, &tools_list_req(None), &resp)
            .await
            .expect("version created");

        let latest = m.latest(TOOLS_LIST).await.expect("a version is stored");
        assert_eq!(latest.version, 1);
    }

    #[tokio::test]
    async fn list_resource_templates__walks_template_key() {
        let m = manager();
        let method = "resources/templates/list";
        let req = json!({"jsonrpc": "2.0", "id": 1, "method": "resources/templates/list"});
        let resp = json!({
            "jsonrpc": "2.0", "id": 1,
            "result": {"resourceTemplates": [{"uriTemplate": "file://{id}", "name": "f"}]}
        });
        m.ingest(method, &req, &resp).await.expect("version created");

        let templates = m.list_resource_templates().await;
        assert_eq!(templates.len(), 1);
        assert_eq!(templates[0]["name"], "f");
    }

    #[tokio::test]
    async fn upstream_id__accessor() {
        assert_eq!(manager().upstream_id(), "proxy-1");
    }
}