Skip to main content

mcpr_core/protocol/schema_manager/
manager.rs

//! `SchemaManager` — top-level per-upstream view of an MCP server's schema.
//!
//! Callers feed schema-method responses in via [`SchemaManager::ingest`].
//! The manager handles pagination buffering, change detection (by content
//! hash), and version assignment, persisting new versions to a
//! [`SchemaStore`]. Query methods read back the latest merged payload
//! with one store lookup per call rather than one per item.
8
9use std::sync::Arc;
10use std::sync::atomic::{AtomicUsize, Ordering};
11
12use chrono::{DateTime, Utc};
13use dashmap::DashMap;
14use serde_json::Value;
15use tokio::sync::Notify;
16
17use super::store::SchemaStore;
18use super::version::{SchemaVersion, SchemaVersionId, hash_payload};
19use crate::protocol::schema::{PageStatus, detect_page_status, merge_pages};
20
21/// Tracks in-flight `spawn_ingest` tasks so callers (shutdown handlers,
22/// tests) can wait until the async ingest queue has drained.
23#[derive(Default)]
24struct PendingTracker {
25    count: AtomicUsize,
26    notify: Notify,
27}
28
29impl PendingTracker {
30    fn begin(&self) {
31        self.count.fetch_add(1, Ordering::SeqCst);
32    }
33    fn end(&self) {
34        if self.count.fetch_sub(1, Ordering::SeqCst) == 1 {
35            self.notify.notify_waiters();
36        }
37    }
38    async fn wait_idle(&self) {
39        while self.count.load(Ordering::SeqCst) > 0 {
40            let notified = self.notify.notified();
41            if self.count.load(Ordering::SeqCst) == 0 {
42                return;
43            }
44            notified.await;
45        }
46    }
47}
48
/// In-memory runtime state kept per schema method.
///
/// Lives outside the `SchemaStore` because the hot path (change
/// detection, pagination buffering, the stale flag) touches these fields
/// on every ingest, and reading them through the store each time would
/// be expensive.
#[derive(Default)]
struct MethodState {
    /// `result` payloads of the pages received so far for a listing that
    /// has not completed yet.
    page_buffer: Vec<Value>,
    /// Content hash of the current version; `None` until seeded by
    /// `warm`, `preload`, or a completed ingest.
    current_hash: Option<String>,
    /// Version number the next new version will receive. The default of
    /// `0` means "not seeded"; `ingest` never assigns a number below 1.
    next_version_number: u32,
    /// Whether the current version has been marked stale.
    stale: bool,
    /// When the stale flag was first raised, if it is raised.
    stale_since: Option<DateTime<Utc>>,
}
61
62/// Top-level handle for one upstream MCP server's schema view.
63///
64/// Generic over the store backend; downstream typically uses
65/// `SchemaManager<MemorySchemaStore>` for the OSS proxy and swaps in a
66/// database-backed store for cloud deployments.
67pub struct SchemaManager<S: SchemaStore> {
68    upstream_id: String,
69    store: S,
70    state: Arc<DashMap<String, MethodState>>,
71    pending: Arc<PendingTracker>,
72}
73
74impl<S: SchemaStore> SchemaManager<S> {
75    pub fn new(upstream_id: impl Into<String>, store: S) -> Self {
76        Self {
77            upstream_id: upstream_id.into(),
78            store,
79            state: Arc::new(DashMap::new()),
80            pending: Arc::new(PendingTracker::default()),
81        }
82    }
83
84    /// Wait until every task spawned via [`spawn_ingest`] has finished.
85    ///
86    /// Used by shutdown/test code so the bus sees every
87    /// `SchemaVersionCreated` event before it drains.
88    pub async fn wait_idle(&self) {
89        self.pending.wait_idle().await;
90    }
91
92    pub fn upstream_id(&self) -> &str {
93        &self.upstream_id
94    }
95
96    /// Seed the in-memory state for `method` from the store.
97    ///
98    /// Callers normally don't need to invoke this directly — `ingest`
99    /// lazy-warms on the first call for a method. Exposed for explicit
100    /// startup warm-up when desired.
101    pub async fn warm(&self, method: &str) {
102        let latest = self
103            .store
104            .latest_version_for_method(&self.upstream_id, method)
105            .await;
106        if let Some(latest) = latest {
107            let mut entry = self.state.entry(method.to_string()).or_default();
108            if entry.current_hash.is_none() {
109                entry.current_hash = Some(latest.content_hash.clone());
110                entry.next_version_number = latest.version + 1;
111            }
112        }
113    }
114
115    /// Bootstrap in-memory state from a pre-existing `SchemaVersion`
116    /// (typically loaded from an external persistent store at startup).
117    ///
118    /// Seeds `current_hash` + `next_version_number` so subsequent
119    /// `ingest` calls with matching content return `None` (no phantom
120    /// new version) and non-matching content increments from
121    /// `version.version + 1`. Also writes the version into the
122    /// manager's in-process store so `latest` / `list_tools` /
123    /// `get_tool` / etc. see it without needing the first live request.
124    ///
125    /// Idempotent per method: if `current_hash` is already set (either
126    /// from a prior preload or a completed ingest), this is a no-op.
127    pub async fn preload(&self, version: SchemaVersion) {
128        {
129            let mut entry = self.state.entry(version.method.clone()).or_default();
130            if entry.current_hash.is_some() {
131                return;
132            }
133            entry.current_hash = Some(version.content_hash.clone());
134            entry.next_version_number = version.version.saturating_add(1);
135        }
136        self.store.put_version(version).await;
137    }
138
139    /// Feed a schema-method response through the manager.
140    ///
141    /// Returns `Some(version)` when a new `SchemaVersion` was created
142    /// (pagination complete AND content differs from the current
143    /// version). Returns `None` when:
144    ///
145    /// - The response is not a complete page (still buffering).
146    /// - The content hash matches the current version.
147    /// - The response has no `result` field.
148    pub async fn ingest(
149        &self,
150        method: &str,
151        request_body: &Value,
152        response_body: &Value,
153    ) -> Option<SchemaVersion> {
154        let result = response_body.get("result")?;
155        let status = detect_page_status(request_body, response_body);
156
157        let merged = {
158            let mut entry = self.state.entry(method.to_string()).or_default();
159            entry.page_buffer.push(result.clone());
160            match status {
161                PageStatus::Complete | PageStatus::LastPage => {
162                    let pages = std::mem::take(&mut entry.page_buffer);
163                    merge_pages(method, &pages)
164                        .unwrap_or_else(|| pages.into_iter().next().unwrap_or(Value::Null))
165                }
166                PageStatus::FirstPage | PageStatus::MiddlePage => return None,
167            }
168        };
169
170        let hash = hash_payload(&merged);
171
172        let needs_warm = self
173            .state
174            .get(method)
175            .map(|e| e.current_hash.is_none() && e.next_version_number == 0)
176            .unwrap_or(true);
177        if needs_warm {
178            self.warm(method).await;
179        }
180
181        let (same, version_number) = {
182            let mut entry = self.state.entry(method.to_string()).or_default();
183            if entry.current_hash.as_deref() == Some(hash.as_str()) {
184                (true, 0)
185            } else {
186                let num = entry.next_version_number.max(1);
187                entry.current_hash = Some(hash.clone());
188                entry.next_version_number = num.saturating_add(1);
189                entry.stale = false;
190                entry.stale_since = None;
191                (false, num)
192            }
193        };
194
195        if same {
196            return None;
197        }
198
199        let id = SchemaVersionId(hash.chars().take(16).collect());
200        let version = SchemaVersion {
201            id,
202            upstream_id: self.upstream_id.clone(),
203            method: method.to_string(),
204            version: version_number,
205            payload: Arc::new(merged),
206            content_hash: hash,
207            captured_at: Utc::now(),
208        };
209        Some(self.store.put_version(version).await)
210    }
211
212    /// Spawn an async ingest task so the caller's hot path does not
213    /// pay for merge/hash/store work.
214    ///
215    /// Returns immediately after spawning. Use [`wait_idle`] to block
216    /// until every spawned task (including this one) has completed.
217    ///
218    /// The caller provides a sink closure that receives the new
219    /// [`SchemaVersion`] when one is produced (for emitting events).
220    /// The closure runs on the spawned task, not on the caller.
221    pub fn spawn_ingest<F>(
222        self: &Arc<Self>,
223        method: String,
224        request_body: Value,
225        response_body: Value,
226        on_version: F,
227    ) where
228        F: FnOnce(&SchemaVersion) + Send + 'static,
229    {
230        self.pending.begin();
231        let manager = Arc::clone(self);
232        tokio::spawn(async move {
233            let result = manager.ingest(&method, &request_body, &response_body).await;
234            if let Some(version) = result.as_ref() {
235                on_version(version);
236            }
237            manager.pending.end();
238        });
239    }
240
241    /// Latest stored version for `method`, or `None` if nothing has
242    /// been ingested yet.
243    pub async fn latest(&self, method: &str) -> Option<SchemaVersion> {
244        self.store
245            .latest_version_for_method(&self.upstream_id, method)
246            .await
247    }
248
249    pub async fn list_tools(&self) -> Vec<Value> {
250        self.list_items("tools/list", "tools").await
251    }
252
253    pub async fn list_resources(&self) -> Vec<Value> {
254        self.list_items("resources/list", "resources").await
255    }
256
257    pub async fn list_resource_templates(&self) -> Vec<Value> {
258        self.list_items("resources/templates/list", "resourceTemplates")
259            .await
260    }
261
262    pub async fn list_prompts(&self) -> Vec<Value> {
263        self.list_items("prompts/list", "prompts").await
264    }
265
266    pub async fn get_tool(&self, name: &str) -> Option<Value> {
267        self.find_item_by_field("tools/list", "tools", "name", name)
268            .await
269    }
270
271    pub async fn get_resource(&self, uri: &str) -> Option<Value> {
272        self.find_item_by_field("resources/list", "resources", "uri", uri)
273            .await
274    }
275
276    pub async fn get_prompt(&self, name: &str) -> Option<Value> {
277        self.find_item_by_field("prompts/list", "prompts", "name", name)
278            .await
279    }
280
281    /// Mark the current version for `method` as stale. Idempotent.
282    ///
283    /// Sync on purpose — the stale flag is used by the hot request
284    /// path (observing `notifications/tools/list_changed`) where a
285    /// round-trip to async code would be overkill.
286    pub fn mark_stale(&self, method: &str) {
287        let mut entry = self.state.entry(method.to_string()).or_default();
288        if !entry.stale {
289            entry.stale = true;
290            entry.stale_since = Some(Utc::now());
291        }
292    }
293
294    pub fn is_stale(&self, method: &str) -> bool {
295        self.state.get(method).map(|e| e.stale).unwrap_or(false)
296    }
297
298    pub fn stale_since(&self, method: &str) -> Option<DateTime<Utc>> {
299        self.state.get(method).and_then(|e| e.stale_since)
300    }
301
302    // ── internals ──
303
304    async fn list_items(&self, method: &str, array_key: &str) -> Vec<Value> {
305        let Some(latest) = self.latest(method).await else {
306            return Vec::new();
307        };
308        latest
309            .payload
310            .get(array_key)
311            .and_then(|v| v.as_array())
312            .cloned()
313            .unwrap_or_default()
314    }
315
316    async fn find_item_by_field(
317        &self,
318        method: &str,
319        array_key: &str,
320        field: &str,
321        needle: &str,
322    ) -> Option<Value> {
323        let latest = self.latest(method).await?;
324        let arr = latest.payload.get(array_key).and_then(|v| v.as_array())?;
325        arr.iter()
326            .find(|item| item.get(field).and_then(|v| v.as_str()) == Some(needle))
327            .cloned()
328    }
329}
330
#[cfg(test)]
// Test names follow `subject__scenario`, which trips the snake-case lint.
#[allow(non_snake_case)]
mod tests {
    use super::*;
    use crate::protocol::schema_manager::store::MemorySchemaStore;
    use serde_json::json;

    /// Method name used by most tests.
    const METHOD: &str = "tools/list";
    /// Upstream id every test manager is created with.
    const UPSTREAM: &str = "proxy-1";

    /// Fresh manager backed by an empty in-memory store.
    fn manager() -> SchemaManager<MemorySchemaStore> {
        SchemaManager::new(UPSTREAM, MemorySchemaStore::new())
    }

    /// `tools/list` request, carrying `params.cursor` only when given.
    fn tools_list_req(cursor: Option<&str>) -> Value {
        let mut request = json!({"jsonrpc": "2.0", "id": 1, "method": "tools/list"});
        if let Some(c) = cursor {
            request["params"] = json!({"cursor": c});
        }
        request
    }

    /// `tools/list` response, carrying `nextCursor` only when given.
    fn tools_list_resp(tools: Value, next_cursor: Option<&str>) -> Value {
        let result = match next_cursor {
            Some(c) => json!({"tools": tools, "nextCursor": c}),
            None => json!({"tools": tools}),
        };
        json!({"jsonrpc": "2.0", "id": 1, "result": result})
    }

    /// Hand-built `tools/list` version for seeding a manager or store.
    fn seed_version(id: &str, number: u32, payload: Value, content_hash: String) -> SchemaVersion {
        SchemaVersion {
            id: SchemaVersionId(id.to_string()),
            upstream_id: UPSTREAM.to_string(),
            method: METHOD.to_string(),
            version: number,
            payload: Arc::new(payload),
            content_hash,
            captured_at: Utc::now(),
        }
    }

    #[tokio::test]
    async fn ingest__complete_page_creates_version_one() {
        let mgr = manager();
        let request = tools_list_req(None);
        let response = tools_list_resp(json!([{"name": "search"}]), None);

        let created = mgr.ingest(METHOD, &request, &response).await.unwrap();

        assert_eq!(created.version, 1);
        assert_eq!(created.method, "tools/list");
        assert_eq!(created.upstream_id, "proxy-1");
    }

    #[tokio::test]
    async fn ingest__first_page_buffers_returns_none() {
        let mgr = manager();
        let request = tools_list_req(None);
        let response = tools_list_resp(json!([{"name": "a"}]), Some("cur1"));

        let outcome = mgr.ingest(METHOD, &request, &response).await;

        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn ingest__first_middle_last_chain_merges_once() {
        let mgr = manager();

        let first = tools_list_resp(json!([{"name": "a"}]), Some("c1"));
        let after_first = mgr.ingest(METHOD, &tools_list_req(None), &first).await;
        assert!(after_first.is_none());

        let middle = tools_list_resp(json!([{"name": "b"}]), Some("c2"));
        let after_middle = mgr
            .ingest(METHOD, &tools_list_req(Some("c1")), &middle)
            .await;
        assert!(after_middle.is_none());

        let last = tools_list_resp(json!([{"name": "c"}]), None);
        let merged = mgr
            .ingest(METHOD, &tools_list_req(Some("c2")), &last)
            .await
            .unwrap();

        let names: Vec<&str> = merged.payload["tools"]
            .as_array()
            .unwrap()
            .iter()
            .map(|tool| tool["name"].as_str().unwrap())
            .collect();
        assert_eq!(names, vec!["a", "b", "c"]);
        assert_eq!(merged.version, 1);
    }

    #[tokio::test]
    async fn ingest__unchanged_payload_returns_none() {
        let mgr = manager();
        let request = tools_list_req(None);
        let response = tools_list_resp(json!([{"name": "a"}]), None);

        mgr.ingest(METHOD, &request, &response).await.unwrap();
        let repeat = mgr.ingest(METHOD, &request, &response).await;

        assert!(repeat.is_none());
    }

    #[tokio::test]
    async fn preload__seeds_hash_and_version_counter() {
        // Startup hydration: the manager is handed a v3 that was
        // persisted before a restart. Re-ingesting the same content must
        // not mint v4; ingesting different content must mint v4, not v1.
        let mgr = manager();
        let request = tools_list_req(None);
        let stored = json!({"tools": [{"name": "a"}]});
        let stored_hash = hash_payload(&stored);
        mgr.preload(seed_version("preload-seed-123", 3, stored, stored_hash))
            .await;

        // Same content: nothing new.
        let same = tools_list_resp(json!([{"name": "a"}]), None);
        assert!(mgr.ingest(METHOD, &request, &same).await.is_none());

        // Different content: continues at v4.
        let changed = tools_list_resp(json!([{"name": "a"}, {"name": "b"}]), None);
        let next = mgr.ingest(METHOD, &request, &changed).await.unwrap();
        assert_eq!(next.version, 4);
    }

    #[tokio::test]
    async fn preload__idempotent_second_call_noop() {
        let mgr = manager();
        let stored = json!({"tools": [{"name": "a"}]});

        mgr.preload(seed_version(
            "id-first",
            3,
            stored.clone(),
            "hash-first".to_string(),
        ))
        .await;
        mgr.preload(seed_version(
            "id-second",
            99,
            stored.clone(),
            "hash-second".to_string(),
        ))
        .await;

        // The second preload was skipped because a hash was already
        // present, so the counter sits at 4 rather than 100.
        let request = tools_list_req(None);
        let changed = tools_list_resp(json!([{"name": "b"}]), None);
        let next = mgr.ingest(METHOD, &request, &changed).await.unwrap();
        assert_eq!(next.version, 4);
    }

    #[tokio::test]
    async fn preload__makes_list_tools_visible_without_ingest() {
        let mgr = manager();
        let stored = json!({"tools": [{"name": "a"}, {"name": "b"}]});
        let stored_hash = hash_payload(&stored);
        mgr.preload(seed_version("preload-list", 1, stored, stored_hash))
            .await;

        let tools = mgr.list_tools().await;

        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0]["name"], "a");
    }

    #[tokio::test]
    async fn ingest__volatile_meta_does_not_create_new_version() {
        // Regression: dashboards showed 138 versions for a server whose
        // tools had not changed in weeks, because the server regenerated
        // `_meta` on every request. Only the item array may feed the hash.
        let mgr = manager();
        let request = tools_list_req(None);
        let with_meta = |request_id: &str| {
            json!({
                "jsonrpc": "2.0", "id": 1,
                "result": {
                    "tools": [{"name": "a"}],
                    "_meta": {"requestId": request_id}
                }
            })
        };
        let first = with_meta("uuid-1");
        let second = with_meta("uuid-2");

        let created = mgr.ingest(METHOD, &request, &first).await.unwrap();
        assert_eq!(created.version, 1);

        let repeat = mgr.ingest(METHOD, &request, &second).await;
        assert!(
            repeat.is_none(),
            "different _meta with identical tools must not mint a new version"
        );
    }

    #[tokio::test]
    async fn ingest__changed_payload_increments_version() {
        let mgr = manager();
        let request = tools_list_req(None);

        let before = tools_list_resp(json!([{"name": "a"}]), None);
        let v1 = mgr.ingest(METHOD, &request, &before).await.unwrap();
        assert_eq!(v1.version, 1);

        let after = tools_list_resp(json!([{"name": "a"}, {"name": "b"}]), None);
        let v2 = mgr.ingest(METHOD, &request, &after).await.unwrap();
        assert_eq!(v2.version, 2);
    }

    #[tokio::test]
    async fn ingest__clears_stale_on_new_version() {
        let mgr = manager();
        let request = tools_list_req(None);
        let before = tools_list_resp(json!([{"name": "a"}]), None);
        mgr.ingest(METHOD, &request, &before).await.unwrap();

        mgr.mark_stale(METHOD);
        assert!(mgr.is_stale(METHOD));

        let after = tools_list_resp(json!([{"name": "a"}, {"name": "b"}]), None);
        mgr.ingest(METHOD, &request, &after).await.unwrap();
        assert!(!mgr.is_stale(METHOD));
    }

    #[tokio::test]
    async fn ingest__no_result_returns_none() {
        let mgr = manager();
        let request = tools_list_req(None);
        let error_response =
            json!({"jsonrpc": "2.0", "id": 1, "error": {"code": -32603, "message": "x"}});

        let outcome = mgr.ingest(METHOD, &request, &error_response).await;

        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn mark_stale__and_is_stale_idempotent() {
        let mgr = manager();
        assert!(!mgr.is_stale(METHOD));

        mgr.mark_stale(METHOD);
        let first_stamp = mgr.stale_since(METHOD);
        mgr.mark_stale(METHOD);
        let second_stamp = mgr.stale_since(METHOD);

        assert!(mgr.is_stale(METHOD));
        assert_eq!(first_stamp, second_stamp);
    }

    #[tokio::test]
    async fn list_tools__empty_when_no_version() {
        let mgr = manager();
        let tools = mgr.list_tools().await;
        assert!(tools.is_empty());
    }

    #[tokio::test]
    async fn list_tools__returns_items_from_latest() {
        let mgr = manager();
        let request = tools_list_req(None);
        let response = tools_list_resp(json!([{"name": "a"}, {"name": "b"}]), None);
        mgr.ingest(METHOD, &request, &response).await.unwrap();

        let tools = mgr.list_tools().await;

        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0]["name"], "a");
        assert_eq!(tools[1]["name"], "b");
    }

    #[tokio::test]
    async fn get_tool__by_name_hit_and_miss() {
        let mgr = manager();
        let request = tools_list_req(None);
        let response = tools_list_resp(json!([{"name": "search", "description": "find"}]), None);
        mgr.ingest(METHOD, &request, &response).await.unwrap();

        let found = mgr.get_tool("search").await.unwrap();
        assert_eq!(found["description"], "find");

        let absent = mgr.get_tool("missing").await;
        assert!(absent.is_none());
    }

    #[tokio::test]
    async fn get_resource__by_uri() {
        let mgr = manager();
        let request = json!({"jsonrpc": "2.0", "id": 1, "method": "resources/list"});
        let response = json!({
            "jsonrpc": "2.0", "id": 1,
            "result": {"resources": [{"uri": "file://a", "name": "A"}]}
        });
        mgr.ingest("resources/list", &request, &response)
            .await
            .unwrap();

        let resource = mgr.get_resource("file://a").await.unwrap();

        assert_eq!(resource["name"], "A");
    }

    #[tokio::test]
    async fn warm__seeds_counter_from_store() {
        let store = MemorySchemaStore::new();
        let prior = seed_version(
            "abc",
            5,
            json!({"tools": [{"name": "x"}]}),
            "prior-hash".to_string(),
        );
        store.put_version(prior).await;

        let mgr = SchemaManager::new(UPSTREAM, store);
        let request = tools_list_req(None);
        let response = tools_list_resp(json!([{"name": "y"}]), None);
        let created = mgr.ingest(METHOD, &request, &response).await.unwrap();

        assert_eq!(created.version, 6);
    }

    #[tokio::test]
    async fn latest__returns_current_version() {
        let mgr = manager();
        let request = tools_list_req(None);
        let response = tools_list_resp(json!([{"name": "a"}]), None);
        mgr.ingest(METHOD, &request, &response).await.unwrap();

        let current = mgr.latest(METHOD).await.unwrap();

        assert_eq!(current.version, 1);
    }

    #[tokio::test]
    async fn list_resource_templates__walks_template_key() {
        let mgr = manager();
        let request = json!({"jsonrpc": "2.0", "id": 1, "method": "resources/templates/list"});
        let response = json!({
            "jsonrpc": "2.0", "id": 1,
            "result": {"resourceTemplates": [{"uriTemplate": "file://{id}", "name": "f"}]}
        });
        mgr.ingest("resources/templates/list", &request, &response)
            .await
            .unwrap();

        let templates = mgr.list_resource_templates().await;

        assert_eq!(templates.len(), 1);
        assert_eq!(templates[0]["name"], "f");
    }

    #[tokio::test]
    async fn upstream_id__accessor() {
        let mgr = manager();
        assert_eq!(mgr.upstream_id(), "proxy-1");
    }
}