1use std::sync::Arc;
10
11use chrono::{DateTime, Utc};
12use dashmap::DashMap;
13use serde_json::Value;
14
15use super::store::SchemaStore;
16use super::version::{SchemaVersion, SchemaVersionId, hash_payload};
17use crate::protocol::schema::{PageStatus, detect_page_status, merge_pages};
18
19#[derive(Default)]
24struct MethodState {
25 page_buffer: Vec<Value>,
26 current_hash: Option<String>,
27 next_version_number: u32,
28 stale: bool,
29 stale_since: Option<DateTime<Utc>>,
30}
31
32pub struct SchemaManager<S: SchemaStore> {
38 upstream_id: String,
39 store: S,
40 state: Arc<DashMap<String, MethodState>>,
41}
42
43impl<S: SchemaStore> SchemaManager<S> {
44 pub fn new(upstream_id: impl Into<String>, store: S) -> Self {
45 Self {
46 upstream_id: upstream_id.into(),
47 store,
48 state: Arc::new(DashMap::new()),
49 }
50 }
51
52 pub fn upstream_id(&self) -> &str {
53 &self.upstream_id
54 }
55
56 pub async fn warm(&self, method: &str) {
62 let latest = self
63 .store
64 .latest_version_for_method(&self.upstream_id, method)
65 .await;
66 if let Some(latest) = latest {
67 let mut entry = self.state.entry(method.to_string()).or_default();
68 if entry.current_hash.is_none() {
69 entry.current_hash = Some(latest.content_hash.clone());
70 entry.next_version_number = latest.version + 1;
71 }
72 }
73 }
74
75 pub async fn ingest(
85 &self,
86 method: &str,
87 request_body: &Value,
88 response_body: &Value,
89 ) -> Option<SchemaVersion> {
90 let result = response_body.get("result")?;
91 let status = detect_page_status(request_body, response_body);
92
93 let merged = {
94 let mut entry = self.state.entry(method.to_string()).or_default();
95 entry.page_buffer.push(result.clone());
96 match status {
97 PageStatus::Complete | PageStatus::LastPage => {
98 let pages = std::mem::take(&mut entry.page_buffer);
99 merge_pages(method, &pages)
100 .unwrap_or_else(|| pages.into_iter().next().unwrap_or(Value::Null))
101 }
102 PageStatus::FirstPage | PageStatus::MiddlePage => return None,
103 }
104 };
105
106 let hash = hash_payload(&merged);
107
108 let needs_warm = self
109 .state
110 .get(method)
111 .map(|e| e.current_hash.is_none() && e.next_version_number == 0)
112 .unwrap_or(true);
113 if needs_warm {
114 self.warm(method).await;
115 }
116
117 let (same, version_number) = {
118 let mut entry = self.state.entry(method.to_string()).or_default();
119 if entry.current_hash.as_deref() == Some(hash.as_str()) {
120 (true, 0)
121 } else {
122 let num = entry.next_version_number.max(1);
123 entry.current_hash = Some(hash.clone());
124 entry.next_version_number = num.saturating_add(1);
125 entry.stale = false;
126 entry.stale_since = None;
127 (false, num)
128 }
129 };
130
131 if same {
132 return None;
133 }
134
135 let id = SchemaVersionId(hash.chars().take(16).collect());
136 let version = SchemaVersion {
137 id,
138 upstream_id: self.upstream_id.clone(),
139 method: method.to_string(),
140 version: version_number,
141 payload: Arc::new(merged),
142 content_hash: hash,
143 captured_at: Utc::now(),
144 };
145 Some(self.store.put_version(version).await)
146 }
147
148 pub async fn latest(&self, method: &str) -> Option<SchemaVersion> {
151 self.store
152 .latest_version_for_method(&self.upstream_id, method)
153 .await
154 }
155
156 pub async fn list_tools(&self) -> Vec<Value> {
157 self.list_items("tools/list", "tools").await
158 }
159
160 pub async fn list_resources(&self) -> Vec<Value> {
161 self.list_items("resources/list", "resources").await
162 }
163
164 pub async fn list_resource_templates(&self) -> Vec<Value> {
165 self.list_items("resources/templates/list", "resourceTemplates")
166 .await
167 }
168
169 pub async fn list_prompts(&self) -> Vec<Value> {
170 self.list_items("prompts/list", "prompts").await
171 }
172
173 pub async fn get_tool(&self, name: &str) -> Option<Value> {
174 self.find_item_by_field("tools/list", "tools", "name", name)
175 .await
176 }
177
178 pub async fn get_resource(&self, uri: &str) -> Option<Value> {
179 self.find_item_by_field("resources/list", "resources", "uri", uri)
180 .await
181 }
182
183 pub async fn get_prompt(&self, name: &str) -> Option<Value> {
184 self.find_item_by_field("prompts/list", "prompts", "name", name)
185 .await
186 }
187
188 pub fn mark_stale(&self, method: &str) {
194 let mut entry = self.state.entry(method.to_string()).or_default();
195 if !entry.stale {
196 entry.stale = true;
197 entry.stale_since = Some(Utc::now());
198 }
199 }
200
201 pub fn is_stale(&self, method: &str) -> bool {
202 self.state.get(method).map(|e| e.stale).unwrap_or(false)
203 }
204
205 pub fn stale_since(&self, method: &str) -> Option<DateTime<Utc>> {
206 self.state.get(method).and_then(|e| e.stale_since)
207 }
208
209 async fn list_items(&self, method: &str, array_key: &str) -> Vec<Value> {
212 let Some(latest) = self.latest(method).await else {
213 return Vec::new();
214 };
215 latest
216 .payload
217 .get(array_key)
218 .and_then(|v| v.as_array())
219 .cloned()
220 .unwrap_or_default()
221 }
222
223 async fn find_item_by_field(
224 &self,
225 method: &str,
226 array_key: &str,
227 field: &str,
228 needle: &str,
229 ) -> Option<Value> {
230 let latest = self.latest(method).await?;
231 let arr = latest.payload.get(array_key).and_then(|v| v.as_array())?;
232 arr.iter()
233 .find(|item| item.get(field).and_then(|v| v.as_str()) == Some(needle))
234 .cloned()
235 }
236}
237
#[cfg(test)]
// Test names follow a `subject__expected_behavior` pattern, hence the double underscores.
#[allow(non_snake_case)]
mod tests {
    use super::*;
    use crate::protocol::schema_manager::store::MemorySchemaStore;
    use serde_json::json;

    /// Method name exercised by most tests.
    const TOOLS_LIST: &str = "tools/list";

    /// Builds a manager for upstream `proxy-1` over a fresh in-memory store.
    fn manager() -> SchemaManager<MemorySchemaStore> {
        SchemaManager::new("proxy-1", MemorySchemaStore::new())
    }

    /// Builds a `tools/list` request, adding `params.cursor` when given.
    fn tools_list_req(cursor: Option<&str>) -> Value {
        let mut req = json!({"jsonrpc": "2.0", "id": 1, "method": "tools/list"});
        if let Some(c) = cursor {
            req["params"] = json!({"cursor": c});
        }
        req
    }

    /// Builds a `tools/list` response, adding `result.nextCursor` when given.
    fn tools_list_resp(tools: Value, next_cursor: Option<&str>) -> Value {
        let result = match next_cursor {
            Some(c) => json!({"tools": tools, "nextCursor": c}),
            None => json!({"tools": tools}),
        };
        json!({"jsonrpc": "2.0", "id": 1, "result": result})
    }

    /// Ingests `tools` as a single, un-paginated `tools/list` response.
    async fn ingest_single_page(
        m: &SchemaManager<MemorySchemaStore>,
        tools: Value,
    ) -> Option<SchemaVersion> {
        let req = tools_list_req(None);
        let resp = tools_list_resp(tools, None);
        m.ingest(TOOLS_LIST, &req, &resp).await
    }

    #[tokio::test]
    async fn ingest__complete_page_creates_version_one() {
        let m = manager();
        let v = ingest_single_page(&m, json!([{"name": "search"}]))
            .await
            .unwrap();
        assert_eq!(v.version, 1);
        assert_eq!(v.method, TOOLS_LIST);
        assert_eq!(v.upstream_id, "proxy-1");
    }

    #[tokio::test]
    async fn ingest__first_page_buffers_returns_none() {
        let m = manager();
        let first_page = tools_list_resp(json!([{"name": "a"}]), Some("cur1"));
        let outcome = m
            .ingest(TOOLS_LIST, &tools_list_req(None), &first_page)
            .await;
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn ingest__first_middle_last_chain_merges_once() {
        let m = manager();

        // (request cursor, tool name on the page, next cursor in the response)
        let buffered_pages = [(None, "a", "c1"), (Some("c1"), "b", "c2")];
        for (cursor, tool, next) in buffered_pages {
            let resp = tools_list_resp(json!([{"name": tool}]), Some(next));
            let outcome = m.ingest(TOOLS_LIST, &tools_list_req(cursor), &resp).await;
            assert!(outcome.is_none());
        }

        let last_page = tools_list_resp(json!([{"name": "c"}]), None);
        let v = m
            .ingest(TOOLS_LIST, &tools_list_req(Some("c2")), &last_page)
            .await
            .unwrap();

        let merged_tools = v.payload["tools"].as_array().unwrap();
        let names: Vec<&str> = merged_tools
            .iter()
            .map(|t| t["name"].as_str().unwrap())
            .collect();
        assert_eq!(names, vec!["a", "b", "c"]);
        assert_eq!(v.version, 1);
    }

    #[tokio::test]
    async fn ingest__unchanged_payload_returns_none() {
        let m = manager();
        let tools = json!([{"name": "a"}]);
        ingest_single_page(&m, tools.clone()).await.unwrap();
        assert!(ingest_single_page(&m, tools).await.is_none());
    }

    #[tokio::test]
    async fn ingest__changed_payload_increments_version() {
        let m = manager();

        let v1 = ingest_single_page(&m, json!([{"name": "a"}]))
            .await
            .unwrap();
        assert_eq!(v1.version, 1);

        let v2 = ingest_single_page(&m, json!([{"name": "a"}, {"name": "b"}]))
            .await
            .unwrap();
        assert_eq!(v2.version, 2);
    }

    #[tokio::test]
    async fn ingest__clears_stale_on_new_version() {
        let m = manager();
        ingest_single_page(&m, json!([{"name": "a"}]))
            .await
            .unwrap();

        m.mark_stale(TOOLS_LIST);
        assert!(m.is_stale(TOOLS_LIST));

        ingest_single_page(&m, json!([{"name": "a"}, {"name": "b"}]))
            .await
            .unwrap();
        assert!(!m.is_stale(TOOLS_LIST));
    }

    #[tokio::test]
    async fn ingest__no_result_returns_none() {
        let m = manager();
        let err_resp =
            json!({"jsonrpc": "2.0", "id": 1, "error": {"code": -32603, "message": "x"}});
        let outcome = m
            .ingest(TOOLS_LIST, &tools_list_req(None), &err_resp)
            .await;
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn mark_stale__and_is_stale_idempotent() {
        let m = manager();
        assert!(!m.is_stale(TOOLS_LIST));

        m.mark_stale(TOOLS_LIST);
        let first = m.stale_since(TOOLS_LIST);
        m.mark_stale(TOOLS_LIST);
        let second = m.stale_since(TOOLS_LIST);

        assert!(m.is_stale(TOOLS_LIST));
        assert_eq!(first, second);
    }

    #[tokio::test]
    async fn list_tools__empty_when_no_version() {
        let m = manager();
        assert!(m.list_tools().await.is_empty());
    }

    #[tokio::test]
    async fn list_tools__returns_items_from_latest() {
        let m = manager();
        ingest_single_page(&m, json!([{"name": "a"}, {"name": "b"}]))
            .await
            .unwrap();

        let tools = m.list_tools().await;
        assert_eq!(tools.len(), 2);
        assert_eq!(tools[0]["name"], "a");
        assert_eq!(tools[1]["name"], "b");
    }

    #[tokio::test]
    async fn get_tool__by_name_hit_and_miss() {
        let m = manager();
        ingest_single_page(&m, json!([{"name": "search", "description": "find"}]))
            .await
            .unwrap();

        let hit = m.get_tool("search").await.unwrap();
        assert_eq!(hit["description"], "find");
        assert!(m.get_tool("missing").await.is_none());
    }

    #[tokio::test]
    async fn get_resource__by_uri() {
        let m = manager();
        let req = json!({"jsonrpc": "2.0", "id": 1, "method": "resources/list"});
        let resp = json!({
            "jsonrpc": "2.0", "id": 1,
            "result": {"resources": [{"uri": "file://a", "name": "A"}]}
        });
        m.ingest("resources/list", &req, &resp).await.unwrap();

        let found = m.get_resource("file://a").await.unwrap();
        assert_eq!(found["name"], "A");
    }

    #[tokio::test]
    async fn warm__seeds_counter_from_store() {
        // Pre-populate the store with version 5 so the manager has to pick up from there.
        let store = MemorySchemaStore::new();
        store
            .put_version(SchemaVersion {
                id: SchemaVersionId("abc".to_string()),
                upstream_id: "proxy-1".to_string(),
                method: "tools/list".to_string(),
                version: 5,
                payload: Arc::new(json!({"tools": [{"name": "x"}]})),
                content_hash: "prior-hash".to_string(),
                captured_at: Utc::now(),
            })
            .await;

        let m = SchemaManager::new("proxy-1", store);
        let v = ingest_single_page(&m, json!([{"name": "y"}]))
            .await
            .unwrap();
        assert_eq!(v.version, 6);
    }

    #[tokio::test]
    async fn latest__returns_current_version() {
        let m = manager();
        ingest_single_page(&m, json!([{"name": "a"}]))
            .await
            .unwrap();

        let latest = m.latest(TOOLS_LIST).await.unwrap();
        assert_eq!(latest.version, 1);
    }

    #[tokio::test]
    async fn list_resource_templates__walks_template_key() {
        let m = manager();
        let req = json!({"jsonrpc": "2.0", "id": 1, "method": "resources/templates/list"});
        let resp = json!({
            "jsonrpc": "2.0", "id": 1,
            "result": {"resourceTemplates": [{"uriTemplate": "file://{id}", "name": "f"}]}
        });
        m.ingest("resources/templates/list", &req, &resp)
            .await
            .unwrap();

        let items = m.list_resource_templates().await;
        assert_eq!(items.len(), 1);
        assert_eq!(items[0]["name"], "f");
    }

    #[tokio::test]
    async fn upstream_id__accessor() {
        assert_eq!(manager().upstream_id(), "proxy-1");
    }
}