1use std::sync::Arc;
10use std::sync::atomic::{AtomicUsize, Ordering};
11
12use chrono::{DateTime, Utc};
13use dashmap::DashMap;
14use serde_json::Value;
15use tokio::sync::Notify;
16
17use super::store::SchemaStore;
18use super::version::{SchemaVersion, SchemaVersionId, hash_payload};
19use crate::protocol::schema::{PageStatus, detect_page_status, merge_pages};
20
21#[derive(Default)]
24struct PendingTracker {
25 count: AtomicUsize,
26 notify: Notify,
27}
28
29impl PendingTracker {
30 fn begin(&self) {
31 self.count.fetch_add(1, Ordering::SeqCst);
32 }
33 fn end(&self) {
34 if self.count.fetch_sub(1, Ordering::SeqCst) == 1 {
35 self.notify.notify_waiters();
36 }
37 }
38 async fn wait_idle(&self) {
39 while self.count.load(Ordering::SeqCst) > 0 {
40 let notified = self.notify.notified();
41 if self.count.load(Ordering::SeqCst) == 0 {
42 return;
43 }
44 notified.await;
45 }
46 }
47}
48
49#[derive(Default)]
54struct MethodState {
55 page_buffer: Vec<Value>,
56 current_hash: Option<String>,
57 next_version_number: u32,
58 stale: bool,
59 stale_since: Option<DateTime<Utc>>,
60}
61
62pub struct SchemaManager<S: SchemaStore> {
68 upstream_id: String,
69 store: S,
70 state: Arc<DashMap<String, MethodState>>,
71 pending: Arc<PendingTracker>,
72}
73
74impl<S: SchemaStore> SchemaManager<S> {
75 pub fn new(upstream_id: impl Into<String>, store: S) -> Self {
76 Self {
77 upstream_id: upstream_id.into(),
78 store,
79 state: Arc::new(DashMap::new()),
80 pending: Arc::new(PendingTracker::default()),
81 }
82 }
83
84 pub async fn wait_idle(&self) {
89 self.pending.wait_idle().await;
90 }
91
92 pub fn upstream_id(&self) -> &str {
93 &self.upstream_id
94 }
95
96 pub async fn warm(&self, method: &str) {
102 let latest = self
103 .store
104 .latest_version_for_method(&self.upstream_id, method)
105 .await;
106 if let Some(latest) = latest {
107 let mut entry = self.state.entry(method.to_string()).or_default();
108 if entry.current_hash.is_none() {
109 entry.current_hash = Some(latest.content_hash.clone());
110 entry.next_version_number = latest.version + 1;
111 }
112 }
113 }
114
115 pub async fn preload(&self, version: SchemaVersion) {
128 {
129 let mut entry = self.state.entry(version.method.clone()).or_default();
130 if entry.current_hash.is_some() {
131 return;
132 }
133 entry.current_hash = Some(version.content_hash.clone());
134 entry.next_version_number = version.version.saturating_add(1);
135 }
136 self.store.put_version(version).await;
137 }
138
139 pub async fn ingest(
149 &self,
150 method: &str,
151 request_body: &Value,
152 response_body: &Value,
153 ) -> Option<SchemaVersion> {
154 let result = response_body.get("result")?;
155 let status = detect_page_status(request_body, response_body);
156
157 let merged = {
158 let mut entry = self.state.entry(method.to_string()).or_default();
159 entry.page_buffer.push(result.clone());
160 match status {
161 PageStatus::Complete | PageStatus::LastPage => {
162 let pages = std::mem::take(&mut entry.page_buffer);
163 merge_pages(method, &pages)
164 .unwrap_or_else(|| pages.into_iter().next().unwrap_or(Value::Null))
165 }
166 PageStatus::FirstPage | PageStatus::MiddlePage => return None,
167 }
168 };
169
170 let hash = hash_payload(&merged);
171
172 let needs_warm = self
173 .state
174 .get(method)
175 .map(|e| e.current_hash.is_none() && e.next_version_number == 0)
176 .unwrap_or(true);
177 if needs_warm {
178 self.warm(method).await;
179 }
180
181 let (same, version_number) = {
182 let mut entry = self.state.entry(method.to_string()).or_default();
183 if entry.current_hash.as_deref() == Some(hash.as_str()) {
184 (true, 0)
185 } else {
186 let num = entry.next_version_number.max(1);
187 entry.current_hash = Some(hash.clone());
188 entry.next_version_number = num.saturating_add(1);
189 entry.stale = false;
190 entry.stale_since = None;
191 (false, num)
192 }
193 };
194
195 if same {
196 return None;
197 }
198
199 let id = SchemaVersionId(hash.chars().take(16).collect());
200 let version = SchemaVersion {
201 id,
202 upstream_id: self.upstream_id.clone(),
203 method: method.to_string(),
204 version: version_number,
205 payload: Arc::new(merged),
206 content_hash: hash,
207 captured_at: Utc::now(),
208 };
209 Some(self.store.put_version(version).await)
210 }
211
212 pub fn spawn_ingest<F>(
222 self: &Arc<Self>,
223 method: String,
224 request_body: Value,
225 response_body: Value,
226 on_version: F,
227 ) where
228 F: FnOnce(&SchemaVersion) + Send + 'static,
229 {
230 self.pending.begin();
231 let manager = Arc::clone(self);
232 tokio::spawn(async move {
233 let result = manager.ingest(&method, &request_body, &response_body).await;
234 if let Some(version) = result.as_ref() {
235 on_version(version);
236 }
237 manager.pending.end();
238 });
239 }
240
241 pub async fn latest(&self, method: &str) -> Option<SchemaVersion> {
244 self.store
245 .latest_version_for_method(&self.upstream_id, method)
246 .await
247 }
248
249 pub async fn list_tools(&self) -> Vec<Value> {
250 self.list_items("tools/list", "tools").await
251 }
252
253 pub async fn list_resources(&self) -> Vec<Value> {
254 self.list_items("resources/list", "resources").await
255 }
256
257 pub async fn list_resource_templates(&self) -> Vec<Value> {
258 self.list_items("resources/templates/list", "resourceTemplates")
259 .await
260 }
261
262 pub async fn list_prompts(&self) -> Vec<Value> {
263 self.list_items("prompts/list", "prompts").await
264 }
265
266 pub async fn get_tool(&self, name: &str) -> Option<Value> {
267 self.find_item_by_field("tools/list", "tools", "name", name)
268 .await
269 }
270
271 pub async fn get_resource(&self, uri: &str) -> Option<Value> {
272 self.find_item_by_field("resources/list", "resources", "uri", uri)
273 .await
274 }
275
276 pub async fn get_prompt(&self, name: &str) -> Option<Value> {
277 self.find_item_by_field("prompts/list", "prompts", "name", name)
278 .await
279 }
280
281 pub fn mark_stale(&self, method: &str) {
287 let mut entry = self.state.entry(method.to_string()).or_default();
288 if !entry.stale {
289 entry.stale = true;
290 entry.stale_since = Some(Utc::now());
291 }
292 }
293
294 pub fn is_stale(&self, method: &str) -> bool {
295 self.state.get(method).map(|e| e.stale).unwrap_or(false)
296 }
297
298 pub fn stale_since(&self, method: &str) -> Option<DateTime<Utc>> {
299 self.state.get(method).and_then(|e| e.stale_since)
300 }
301
302 async fn list_items(&self, method: &str, array_key: &str) -> Vec<Value> {
305 let Some(latest) = self.latest(method).await else {
306 return Vec::new();
307 };
308 latest
309 .payload
310 .get(array_key)
311 .and_then(|v| v.as_array())
312 .cloned()
313 .unwrap_or_default()
314 }
315
316 async fn find_item_by_field(
317 &self,
318 method: &str,
319 array_key: &str,
320 field: &str,
321 needle: &str,
322 ) -> Option<Value> {
323 let latest = self.latest(method).await?;
324 let arr = latest.payload.get(array_key).and_then(|v| v.as_array())?;
325 arr.iter()
326 .find(|item| item.get(field).and_then(|v| v.as_str()) == Some(needle))
327 .cloned()
328 }
329}
330
#[cfg(test)]
// Test names use a `subject__scenario` convention; the double underscore is
// rejected by the `non_snake_case` lint, hence the allow.
#[allow(non_snake_case)]
mod tests {
    use super::*;
    use crate::protocol::schema_manager::store::MemorySchemaStore;
    use serde_json::json;

    const UPSTREAM: &str = "proxy-1";
    const TOOLS_LIST: &str = "tools/list";

    fn manager() -> SchemaManager<MemorySchemaStore> {
        SchemaManager::new(UPSTREAM, MemorySchemaStore::new())
    }

    /// A `tools/list` request, optionally continuing from `cursor`.
    fn tools_list_req(cursor: Option<&str>) -> Value {
        let mut request = json!({"jsonrpc": "2.0", "id": 1, "method": "tools/list"});
        if let Some(c) = cursor {
            request["params"] = json!({"cursor": c});
        }
        request
    }

    /// A `tools/list` response page, optionally advertising `next_cursor`.
    fn tools_list_resp(tools: Value, next_cursor: Option<&str>) -> Value {
        let result = match next_cursor {
            Some(c) => json!({"tools": tools, "nextCursor": c}),
            None => json!({"tools": tools}),
        };
        json!({"jsonrpc": "2.0", "id": 1, "result": result})
    }

    /// A `tools/list` version for `UPSTREAM` with the given fields.
    fn tools_version(number: u32, id: &str, payload: Value, content_hash: String) -> SchemaVersion {
        SchemaVersion {
            id: SchemaVersionId(id.to_string()),
            upstream_id: UPSTREAM.to_string(),
            method: TOOLS_LIST.to_string(),
            version: number,
            payload: Arc::new(payload),
            content_hash,
            captured_at: Utc::now(),
        }
    }

    #[tokio::test]
    async fn ingest__complete_page_creates_version_one() {
        let m = manager();
        let page = tools_list_resp(json!([{"name": "search"}]), None);
        let created = m
            .ingest(TOOLS_LIST, &tools_list_req(None), &page)
            .await
            .expect("a complete page mints a version");
        assert_eq!(created.version, 1);
        assert_eq!(created.method, TOOLS_LIST);
        assert_eq!(created.upstream_id, UPSTREAM);
    }

    #[tokio::test]
    async fn ingest__first_page_buffers_returns_none() {
        let m = manager();
        let first_page = tools_list_resp(json!([{"name": "a"}]), Some("cur1"));
        let outcome = m.ingest(TOOLS_LIST, &tools_list_req(None), &first_page).await;
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn ingest__first_middle_last_chain_merges_once() {
        let m = manager();

        let buffered_pages = [
            (None, json!([{"name": "a"}]), Some("c1")),
            (Some("c1"), json!([{"name": "b"}]), Some("c2")),
        ];
        for (cursor, tools, next) in buffered_pages {
            let page = tools_list_resp(tools, next);
            let outcome = m.ingest(TOOLS_LIST, &tools_list_req(cursor), &page).await;
            assert!(outcome.is_none());
        }

        let last_page = tools_list_resp(json!([{"name": "c"}]), None);
        let merged = m
            .ingest(TOOLS_LIST, &tools_list_req(Some("c2")), &last_page)
            .await
            .expect("the last page mints a version");

        let names: Vec<&str> = merged.payload["tools"]
            .as_array()
            .expect("merged payload has a tools array")
            .iter()
            .filter_map(|tool| tool["name"].as_str())
            .collect();
        assert_eq!(names, ["a", "b", "c"]);
        assert_eq!(merged.version, 1);
    }

    #[tokio::test]
    async fn ingest__unchanged_payload_returns_none() {
        let m = manager();
        let req = tools_list_req(None);
        let resp = tools_list_resp(json!([{"name": "a"}]), None);
        assert!(m.ingest(TOOLS_LIST, &req, &resp).await.is_some());
        assert!(m.ingest(TOOLS_LIST, &req, &resp).await.is_none());
    }

    #[tokio::test]
    async fn preload__seeds_hash_and_version_counter() {
        let m = manager();
        let stored = json!({"tools": [{"name": "a"}]});
        let stored_hash = hash_payload(&stored);
        m.preload(tools_version(3, "preload-seed-123", stored, stored_hash))
            .await;

        let req = tools_list_req(None);

        // Same tool set as the seeded version: nothing new.
        let same = tools_list_resp(json!([{"name": "a"}]), None);
        assert!(m.ingest(TOOLS_LIST, &req, &same).await.is_none());

        // A changed tool set continues numbering after the seeded version.
        let changed = tools_list_resp(json!([{"name": "a"}, {"name": "b"}]), None);
        let next = m
            .ingest(TOOLS_LIST, &req, &changed)
            .await
            .expect("a changed payload mints a version");
        assert_eq!(next.version, 4);
    }

    #[tokio::test]
    async fn preload__idempotent_second_call_noop() {
        let m = manager();
        let stored = json!({"tools": [{"name": "a"}]});

        for (number, tag) in [(3, "first"), (99, "second")] {
            let seed = tools_version(
                number,
                &format!("id-{tag}"),
                stored.clone(),
                format!("hash-{tag}"),
            );
            m.preload(seed).await;
        }

        let changed = tools_list_resp(json!([{"name": "b"}]), None);
        let next = m
            .ingest(TOOLS_LIST, &tools_list_req(None), &changed)
            .await
            .expect("a changed payload mints a version");
        assert_eq!(next.version, 4);
    }

    #[tokio::test]
    async fn preload__makes_list_tools_visible_without_ingest() {
        let m = manager();
        let stored = json!({"tools": [{"name": "a"}, {"name": "b"}]});
        let stored_hash = hash_payload(&stored);
        m.preload(tools_version(1, "preload-list", stored, stored_hash))
            .await;

        let tools = m.list_tools().await;
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0]["name"], "a");
    }

    #[tokio::test]
    async fn ingest__volatile_meta_does_not_create_new_version() {
        let m = manager();
        let req = tools_list_req(None);
        let response_with = |request_id: &str| {
            json!({
                "jsonrpc": "2.0", "id": 1,
                "result": {
                    "tools": [{"name": "a"}],
                    "_meta": {"requestId": request_id}
                }
            })
        };

        let first = m
            .ingest(TOOLS_LIST, &req, &response_with("uuid-1"))
            .await
            .expect("the first payload mints a version");
        assert_eq!(first.version, 1);
        assert!(
            m.ingest(TOOLS_LIST, &req, &response_with("uuid-2"))
                .await
                .is_none(),
            "different _meta with identical tools must not mint a new version"
        );
    }

    #[tokio::test]
    async fn ingest__changed_payload_increments_version() {
        let m = manager();
        let req = tools_list_req(None);
        let tool_sets = [
            json!([{"name": "a"}]),
            json!([{"name": "a"}, {"name": "b"}]),
        ];
        for (expected, tools) in (1u32..).zip(tool_sets) {
            let minted = m
                .ingest(TOOLS_LIST, &req, &tools_list_resp(tools, None))
                .await
                .expect("each distinct payload mints a version");
            assert_eq!(minted.version, expected);
        }
    }

    #[tokio::test]
    async fn ingest__clears_stale_on_new_version() {
        let m = manager();
        let req = tools_list_req(None);

        let original = tools_list_resp(json!([{"name": "a"}]), None);
        m.ingest(TOOLS_LIST, &req, &original)
            .await
            .expect("first payload mints a version");

        m.mark_stale(TOOLS_LIST);
        assert!(m.is_stale(TOOLS_LIST));

        let updated = tools_list_resp(json!([{"name": "a"}, {"name": "b"}]), None);
        m.ingest(TOOLS_LIST, &req, &updated)
            .await
            .expect("changed payload mints a version");
        assert!(!m.is_stale(TOOLS_LIST));
    }

    #[tokio::test]
    async fn ingest__no_result_returns_none() {
        let m = manager();
        let error_response =
            json!({"jsonrpc": "2.0", "id": 1, "error": {"code": -32603, "message": "x"}});
        let outcome = m
            .ingest(TOOLS_LIST, &tools_list_req(None), &error_response)
            .await;
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn mark_stale__and_is_stale_idempotent() {
        let m = manager();
        assert!(!m.is_stale(TOOLS_LIST));

        m.mark_stale(TOOLS_LIST);
        let first_mark = m.stale_since(TOOLS_LIST);
        m.mark_stale(TOOLS_LIST);
        let second_mark = m.stale_since(TOOLS_LIST);

        assert!(m.is_stale(TOOLS_LIST));
        assert_eq!(first_mark, second_mark);
    }

    #[tokio::test]
    async fn list_tools__empty_when_no_version() {
        assert!(manager().list_tools().await.is_empty());
    }

    #[tokio::test]
    async fn list_tools__returns_items_from_latest() {
        let m = manager();
        let resp = tools_list_resp(json!([{"name": "a"}, {"name": "b"}]), None);
        m.ingest(TOOLS_LIST, &tools_list_req(None), &resp)
            .await
            .expect("payload mints a version");

        let tools = m.list_tools().await;
        assert_eq!(tools.len(), 2);
        for (tool, expected) in tools.iter().zip(["a", "b"]) {
            assert_eq!(tool["name"], expected);
        }
    }

    #[tokio::test]
    async fn get_tool__by_name_hit_and_miss() {
        let m = manager();
        let resp = tools_list_resp(json!([{"name": "search", "description": "find"}]), None);
        m.ingest(TOOLS_LIST, &tools_list_req(None), &resp)
            .await
            .expect("payload mints a version");

        let found = m.get_tool("search").await.expect("search is listed");
        assert_eq!(found["description"], "find");
        assert!(m.get_tool("missing").await.is_none());
    }

    #[tokio::test]
    async fn get_resource__by_uri() {
        let m = manager();
        let req = json!({"jsonrpc": "2.0", "id": 1, "method": "resources/list"});
        let resp = json!({
            "jsonrpc": "2.0", "id": 1,
            "result": {"resources": [{"uri": "file://a", "name": "A"}]}
        });
        m.ingest("resources/list", &req, &resp)
            .await
            .expect("payload mints a version");

        let resource = m.get_resource("file://a").await.expect("file://a is listed");
        assert_eq!(resource["name"], "A");
    }

    #[tokio::test]
    async fn warm__seeds_counter_from_store() {
        let store = MemorySchemaStore::new();
        let prior = tools_version(
            5,
            "abc",
            json!({"tools": [{"name": "x"}]}),
            "prior-hash".to_string(),
        );
        store.put_version(prior).await;

        let m = SchemaManager::new(UPSTREAM, store);
        let resp = tools_list_resp(json!([{"name": "y"}]), None);
        let minted = m
            .ingest(TOOLS_LIST, &tools_list_req(None), &resp)
            .await
            .expect("a new payload mints a version");
        assert_eq!(minted.version, 6);
    }

    #[tokio::test]
    async fn latest__returns_current_version() {
        let m = manager();
        let resp = tools_list_resp(json!([{"name": "a"}]), None);
        m.ingest(TOOLS_LIST, &tools_list_req(None), &resp)
            .await
            .expect("payload mints a version");

        let current = m.latest(TOOLS_LIST).await.expect("a version is stored");
        assert_eq!(current.version, 1);
    }

    #[tokio::test]
    async fn list_resource_templates__walks_template_key() {
        let m = manager();
        let req = json!({"jsonrpc": "2.0", "id": 1, "method": "resources/templates/list"});
        let resp = json!({
            "jsonrpc": "2.0", "id": 1,
            "result": {"resourceTemplates": [{"uriTemplate": "file://{id}", "name": "f"}]}
        });
        m.ingest("resources/templates/list", &req, &resp)
            .await
            .expect("payload mints a version");

        let templates = m.list_resource_templates().await;
        assert_eq!(templates.len(), 1);
        assert_eq!(templates[0]["name"], "f");
    }

    #[tokio::test]
    async fn upstream_id__accessor() {
        assert_eq!(manager().upstream_id(), UPSTREAM);
    }
}