1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tokio::sync::Mutex;
6
7use crate::KontextAuthSession;
8use crate::KontextDevClient;
9use crate::KontextDevConfig;
10use crate::KontextDevError;
11
/// Default Kontext API origin, used when [`KontextMcpConfig::server`] is not set.
pub const DEFAULT_SERVER: &str = "https://api.kontext.dev";
/// HTTP header that carries the MCP session id on every request after `initialize`.
const MCP_SESSION_HEADER: &str = "Mcp-Session-Id";
/// Name of the gateway meta-tool that searches for available tools.
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Name of the gateway meta-tool paired with `SEARCH_TOOLS` (presumably executes a tool found
/// via search — confirm against the gateway API).
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// MCP protocol revision advertised in the `initialize` request.
const DEFAULT_MCP_PROTOCOL_VERSION: &str = "2025-06-18";
/// `Accept` header value: the server may answer either with plain JSON or with an SSE stream.
const STREAMABLE_HTTP_ACCEPT: &str = "application/json, text/event-stream";
/// Content-type fragment that marks a response body as a Server-Sent Events stream.
const STREAM_CONTENT_TYPE: &str = "text/event-stream";
19
/// Reduces a configured Kontext server URL to its bare origin.
///
/// Trailing slashes are dropped. Then at most one `/api/v1` suffix is removed, followed by
/// at most one `/mcp` suffix, and any slashes exposed by that are trimmed again. The order
/// matters: `.../mcp/api/v1` collapses fully, while `.../api/v1/mcp` keeps `/api/v1`.
pub fn normalize_kontext_server_url(server: &str) -> String {
    let base = server.trim_end_matches('/');
    let base = base.strip_suffix("/api/v1").unwrap_or(base);
    let base = base.strip_suffix("/mcp").unwrap_or(base);
    String::from(base.trim_end_matches('/'))
}
30
/// Configuration for [`KontextMcp`].
///
/// [`KontextMcp::new`] forwards most fields into the underlying `KontextDevConfig` and
/// supplies fallbacks for the optional ones that are left unset.
#[derive(Clone, Debug)]
pub struct KontextMcpConfig {
    /// Sent to the server as `clientInfo.sessionId` in the MCP `initialize` request.
    pub client_session_id: String,
    /// Client id forwarded to the underlying client (presumably the OAuth client id).
    pub client_id: String,
    /// Redirect URI forwarded to the underlying client; `KontextMcpConfig::default()` uses
    /// `http://localhost:3333/callback`.
    pub redirect_uri: String,
    /// Explicit MCP endpoint URL. When set, [`KontextMcp::mcp_url`] returns it verbatim
    /// instead of asking the underlying client.
    pub url: Option<String>,
    /// Kontext API origin. Falls back to [`DEFAULT_SERVER`] and is always normalized with
    /// [`normalize_kontext_server_url`].
    pub server: Option<String>,
    /// Optional client secret forwarded to the underlying client.
    pub client_secret: Option<String>,
    /// Requested scope; an empty string is forwarded when unset.
    pub scope: Option<String>,
    /// Resource forwarded to the underlying client; `"mcp-gateway"` when unset.
    pub resource: Option<String>,
    /// NOTE(review): not read anywhere in this file — confirm whether it is used elsewhere.
    pub session_key: Option<String>,
    /// Forwarded unchanged to the underlying client.
    pub integration_ui_url: Option<String>,
    /// Forwarded unchanged to the underlying client.
    pub integration_return_to: Option<String>,
    /// Authentication timeout in seconds; `300` when unset.
    pub auth_timeout_seconds: Option<i64>,
    /// Whether the underlying client should open the connect page on login; `true` when unset.
    pub open_connect_page_on_login: Option<bool>,
    /// Token cache location forwarded unchanged to the underlying client.
    pub token_cache_path: Option<String>,
}
48
49impl Default for KontextMcpConfig {
50 fn default() -> Self {
51 Self {
52 client_session_id: String::new(),
53 client_id: String::new(),
54 redirect_uri: "http://localhost:3333/callback".to_string(),
55 url: None,
56 server: Some(DEFAULT_SERVER.to_string()),
57 client_secret: None,
58 scope: None,
59 resource: None,
60 session_key: None,
61 integration_ui_url: None,
62 integration_return_to: None,
63 auth_timeout_seconds: None,
64 open_connect_page_on_login: None,
65 token_cache_path: None,
66 }
67 }
68}
69
/// Category of a runtime integration, as reported by `/mcp/integrations`.
///
/// Serialized in snake_case: `"gateway_remote_mcp"`, `"internal_mcp_credentials"`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeIntegrationCategory {
    GatewayRemoteMcp,
    InternalMcpCredentials,
}
76
/// How a runtime integration is connected.
///
/// Serialized in snake_case (`"oauth"`, `"user_token"`, `"credentials"`, `"none"`). There is
/// no catch-all variant, so any other value fails deserialization.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum RuntimeIntegrationConnectType {
    Oauth,
    UserToken,
    Credentials,
    None,
}
85
/// One integration entry from `GET {server}/mcp/integrations` (see
/// `KontextMcp::list_integrations`).
///
/// Keys are camelCase on the wire (e.g. `connectType`). Optional fields are omitted from
/// serialized output when `None`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RuntimeIntegrationRecord {
    pub id: String,
    pub name: String,
    pub url: String,
    pub category: RuntimeIntegrationCategory,
    pub connect_type: RuntimeIntegrationConnectType,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auth_mode: Option<String>,
    // Kept as raw JSON; its schema is not interpreted in this file.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub credential_schema: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub requires_oauth: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub connection: Option<RuntimeIntegrationConnection>,
}
103
/// Connection state attached to a [`RuntimeIntegrationRecord`] (camelCase on the wire).
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct RuntimeIntegrationConnection {
    pub connected: bool,
    // Free-form status string; its allowed values are not defined in this file.
    pub status: String,
    // NOTE(review): presumably an RFC 3339 timestamp — confirm against the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub display_name: Option<String>,
}
114
/// A tool exposed through the Kontext MCP gateway.
///
/// Tools parsed from an MCP `tools/list` result use the tool name as `id` and have no
/// `server`. Tools parsed from a `SEARCH_TOOLS` payload carry the gateway-provided `id` and,
/// when known, the backing server.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct KontextTool {
    pub id: String,
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    // Raw JSON schema of the tool arguments (`inputSchema` on the wire).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub input_schema: Option<serde_json::Value>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server: Option<KontextToolServer>,
}
127
/// Server that provides a [`KontextTool`], as reported by the gateway.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct KontextToolServer {
    pub id: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
135
/// MCP session cached by [`KontextMcp`] between requests.
///
/// Both fields are set together after a successful `initialize` handshake and cleared
/// together when the session is invalidated.
#[derive(Clone, Debug, Default)]
struct McpSessionState {
    // Session id obtained from the server's `initialize` response.
    session_id: Option<String>,
    // Access token the session was created with; a cached session is reused only when the
    // caller presents this same token.
    access_token: Option<String>,
}
141
/// Client for a Kontext MCP gateway endpoint.
///
/// Wraps a [`KontextDevClient`], which handles authentication, MCP URL discovery and the
/// server base URL, plus a `reqwest` client for the HTTP and JSON-RPC calls. The MCP session
/// lives behind an `Arc`, so clones of a `KontextMcp` share one cached session.
#[derive(Clone, Debug)]
pub struct KontextMcp {
    // Configuration this client was built from (e.g. the explicit MCP `url` override).
    config: KontextMcpConfig,
    // Underlying dev client.
    client: KontextDevClient,
    // HTTP client for `/mcp/integrations` and MCP JSON-RPC requests.
    http: reqwest::Client,
    // Cached MCP session id and the access token it belongs to.
    session: Arc<Mutex<McpSessionState>>,
}
149
150impl KontextMcp {
151 pub fn new(config: KontextMcpConfig) -> Self {
152 let server =
153 normalize_kontext_server_url(config.server.as_deref().unwrap_or(DEFAULT_SERVER));
154 let sdk_config = KontextDevConfig {
155 server,
156 client_id: config.client_id.clone(),
157 client_secret: config.client_secret.clone(),
158 scope: config.scope.clone().unwrap_or_default(),
159 server_name: "kontext-dev".to_string(),
160 resource: config
161 .resource
162 .clone()
163 .unwrap_or_else(|| "mcp-gateway".to_string()),
164 integration_ui_url: config.integration_ui_url.clone(),
165 integration_return_to: config.integration_return_to.clone(),
166 open_connect_page_on_login: config.open_connect_page_on_login.unwrap_or(true),
167 auth_timeout_seconds: config.auth_timeout_seconds.unwrap_or(300),
168 token_cache_path: config.token_cache_path.clone(),
169 redirect_uri: config.redirect_uri.clone(),
170 };
171
172 Self {
173 config,
174 client: KontextDevClient::new(sdk_config),
175 http: reqwest::Client::new(),
176 session: Arc::new(Mutex::new(McpSessionState::default())),
177 }
178 }
179
180 pub fn client(&self) -> &KontextDevClient {
181 &self.client
182 }
183
184 pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
185 self.client.authenticate_mcp().await
186 }
187
188 pub fn mcp_url(&self) -> Result<String, KontextDevError> {
189 if let Some(url) = &self.config.url {
190 return Ok(url.clone());
191 }
192 self.client.mcp_url()
193 }
194
195 pub async fn clear_cached_session(&self) {
196 self.invalidate_session().await;
197 }
198
199 pub async fn list_integrations(
200 &self,
201 ) -> Result<Vec<RuntimeIntegrationRecord>, KontextDevError> {
202 let session = self.authenticate_mcp().await?;
203 let base = self.client.server_base_url()?;
204 let response = self
205 .http
206 .get(format!("{}/mcp/integrations", base.trim_end_matches('/')))
207 .bearer_auth(session.gateway_token.access_token)
208 .send()
209 .await
210 .map_err(|err| KontextDevError::ConnectSession {
211 message: err.to_string(),
212 })?;
213
214 if !response.status().is_success() {
215 let status = response.status();
216 let body = response.text().await.unwrap_or_default();
217 return Err(KontextDevError::ConnectSession {
218 message: format!("{status}: {body}"),
219 });
220 }
221
222 #[derive(Deserialize)]
223 struct IntegrationsResponse {
224 items: Vec<RuntimeIntegrationRecord>,
225 }
226
227 let payload = response
228 .json::<IntegrationsResponse>()
229 .await
230 .map_err(|err| KontextDevError::ConnectSession {
231 message: err.to_string(),
232 })?;
233
234 Ok(payload.items)
235 }
236
237 pub async fn list_tools(&self) -> Result<Vec<KontextTool>, KontextDevError> {
238 let session = self.authenticate_mcp().await?;
239 self.list_tools_with_access_token(&session.gateway_token.access_token)
240 .await
241 }
242
243 pub async fn list_tools_with_access_token(
244 &self,
245 access_token: &str,
246 ) -> Result<Vec<KontextTool>, KontextDevError> {
247 let result = self
248 .json_rpc_with_session(
249 access_token,
250 "tools/list",
251 json!({}),
252 Some("list-tools"),
253 true,
254 )
255 .await?;
256 parse_tools_list_result(&result)
257 }
258
259 pub async fn call_tool(
260 &self,
261 tool_id: &str,
262 args: Option<serde_json::Map<String, serde_json::Value>>,
263 ) -> Result<serde_json::Value, KontextDevError> {
264 let session = self.authenticate_mcp().await?;
265 self.call_tool_with_access_token(&session.gateway_token.access_token, tool_id, args)
266 .await
267 }
268
269 pub async fn call_tool_with_access_token(
270 &self,
271 access_token: &str,
272 tool_id: &str,
273 args: Option<serde_json::Map<String, serde_json::Value>>,
274 ) -> Result<serde_json::Value, KontextDevError> {
275 self.json_rpc_with_session(
276 access_token,
277 "tools/call",
278 json!({ "name": tool_id, "arguments": args.unwrap_or_default() }),
279 Some("call-tool"),
280 true,
281 )
282 .await
283 }
284
285 async fn json_rpc_with_session(
286 &self,
287 access_token: &str,
288 method: &str,
289 params: Value,
290 id: Option<&str>,
291 allow_session_retry: bool,
292 ) -> Result<Value, KontextDevError> {
293 let max_attempts = if allow_session_retry { 2 } else { 1 };
294 for attempt in 0..max_attempts {
295 let session_id = self.ensure_mcp_session(access_token).await?;
296
297 let response = self
298 .http
299 .post(self.mcp_url()?)
300 .bearer_auth(access_token)
301 .header(reqwest::header::ACCEPT, STREAMABLE_HTTP_ACCEPT)
302 .header(MCP_SESSION_HEADER, &session_id)
303 .json(&json!({
304 "jsonrpc": "2.0",
305 "id": id.unwrap_or("1"),
306 "method": method,
307 "params": params,
308 }))
309 .send()
310 .await
311 .map_err(|err| KontextDevError::ConnectSession {
312 message: err.to_string(),
313 })?;
314
315 if !response.status().is_success() {
316 let status = response.status();
317 let body = response.text().await.unwrap_or_default();
318 let retryable =
319 attempt + 1 < max_attempts && is_invalid_session_response_body(body.as_str());
320 if retryable {
321 self.invalidate_session().await;
322 continue;
323 }
324 return Err(KontextDevError::ConnectSession {
325 message: format!("{status}: {body}"),
326 });
327 }
328
329 let payload = parse_json_or_streamable_response(response).await?;
330
331 if let Some(error) = payload.get("error") {
332 let message = extract_jsonrpc_error_message(error);
333 let retryable =
334 attempt + 1 < max_attempts && is_invalid_session_jsonrpc_error(error);
335 if retryable {
336 self.invalidate_session().await;
337 continue;
338 }
339 return Err(KontextDevError::ConnectSession { message });
340 }
341
342 return Ok(payload.get("result").cloned().unwrap_or(Value::Null));
343 }
344
345 Err(KontextDevError::ConnectSession {
346 message: "MCP request failed after session retry".to_string(),
347 })
348 }
349
350 async fn ensure_mcp_session(&self, access_token: &str) -> Result<String, KontextDevError> {
351 {
352 let guard = self.session.lock().await;
353 if guard.access_token.as_deref() == Some(access_token)
354 && let Some(session_id) = guard.session_id.clone()
355 {
356 return Ok(session_id);
357 }
358 }
359
360 let initialize_response = self
361 .http
362 .post(self.mcp_url()?)
363 .bearer_auth(access_token)
364 .header(reqwest::header::ACCEPT, STREAMABLE_HTTP_ACCEPT)
365 .json(&json!({
366 "jsonrpc": "2.0",
367 "id": "initialize",
368 "method": "initialize",
369 "params": {
370 "protocolVersion": DEFAULT_MCP_PROTOCOL_VERSION,
371 "capabilities": {
372 "tools": {}
373 },
374 "clientInfo": {
375 "name": "kontext-dev-sdk-rs",
376 "version": env!("CARGO_PKG_VERSION"),
377 "sessionId": self.config.client_session_id
378 }
379 }
380 }))
381 .send()
382 .await
383 .map_err(|err| KontextDevError::ConnectSession {
384 message: err.to_string(),
385 })?;
386
387 if !initialize_response.status().is_success() {
388 let status = initialize_response.status();
389 let body = initialize_response.text().await.unwrap_or_default();
390 return Err(KontextDevError::ConnectSession {
391 message: format!("{status}: {body}"),
392 });
393 }
394
395 let session_header = initialize_response
396 .headers()
397 .get(MCP_SESSION_HEADER)
398 .or_else(|| initialize_response.headers().get("mcp-session-id"))
399 .and_then(|value| value.to_str().ok())
400 .map(|value| value.trim().to_string());
401
402 let initialize_payload = parse_json_or_streamable_response(initialize_response).await?;
403
404 if let Some(error) = initialize_payload.get("error") {
405 return Err(KontextDevError::ConnectSession {
406 message: extract_jsonrpc_error_message(error),
407 });
408 }
409
410 let session_id = session_header
411 .or_else(|| {
412 initialize_payload
413 .get("result")
414 .and_then(|result| result.get("sessionId"))
415 .and_then(|value| value.as_str())
416 .map(|value| value.to_string())
417 })
418 .or_else(|| {
419 initialize_payload
420 .get("result")
421 .and_then(|result| result.get("session_id"))
422 .and_then(|value| value.as_str())
423 .map(|value| value.to_string())
424 })
425 .ok_or_else(|| KontextDevError::ConnectSession {
426 message: "MCP initialize did not return a session id".to_string(),
427 })?;
428
429 let _ = self
432 .http
433 .post(self.mcp_url()?)
434 .bearer_auth(access_token)
435 .header(reqwest::header::ACCEPT, STREAMABLE_HTTP_ACCEPT)
436 .header(MCP_SESSION_HEADER, &session_id)
437 .json(&json!({
438 "jsonrpc": "2.0",
439 "method": "notifications/initialized",
440 "params": {}
441 }))
442 .send()
443 .await;
444
445 {
446 let mut guard = self.session.lock().await;
447 guard.session_id = Some(session_id.clone());
448 guard.access_token = Some(access_token.to_string());
449 }
450
451 Ok(session_id)
452 }
453
454 async fn invalidate_session(&self) {
455 let mut guard = self.session.lock().await;
456 guard.session_id = None;
457 guard.access_token = None;
458 }
459}
460
461fn extract_jsonrpc_error_message(error: &Value) -> String {
462 error
463 .get("message")
464 .and_then(|value| value.as_str())
465 .map(ToString::to_string)
466 .or_else(|| {
467 error
468 .get("error_description")
469 .and_then(|value| value.as_str())
470 .map(ToString::to_string)
471 })
472 .unwrap_or_else(|| error.to_string())
473}
474
/// Returns `true` when `message` (compared case-insensitively) contains one of the phrases
/// servers use to reject a missing or invalid MCP session id.
fn is_invalid_session_error(message: &str) -> bool {
    const MARKERS: [&str; 3] = ["no valid session id", "no valid session-id", "invalid session"];
    let haystack = message.to_ascii_lowercase();
    MARKERS.iter().any(|&marker| haystack.contains(marker))
}
481
/// Returns `true` when `message` (compared case-insensitively) mentions both "session" and
/// "not found", in any order.
fn is_session_not_found_error(message: &str) -> bool {
    let haystack = message.to_ascii_lowercase();
    ["session", "not found"]
        .iter()
        .all(|&needle| haystack.contains(needle))
}
486
487fn is_invalid_session_jsonrpc_error(error: &Value) -> bool {
488 let message = extract_jsonrpc_error_message(error);
489 if is_invalid_session_error(message.as_str()) {
490 return true;
491 }
492
493 let code = error.get("code").and_then(Value::as_i64);
494 code == Some(-32000) && is_session_not_found_error(message.as_str())
495}
496
497fn is_invalid_session_response_body(body: &str) -> bool {
498 if is_invalid_session_error(body) || is_session_not_found_error(body) {
499 return true;
500 }
501
502 if let Ok(payload) = serde_json::from_str::<Value>(body)
503 && let Some(error) = payload.get("error")
504 {
505 return is_invalid_session_jsonrpc_error(error);
506 }
507
508 false
509}
510
511async fn parse_json_or_streamable_response(
512 response: reqwest::Response,
513) -> Result<Value, KontextDevError> {
514 let content_type = response
515 .headers()
516 .get(reqwest::header::CONTENT_TYPE)
517 .and_then(|value| value.to_str().ok())
518 .map(|value| value.to_ascii_lowercase())
519 .unwrap_or_default();
520 let body = response
521 .text()
522 .await
523 .map_err(|err| KontextDevError::ConnectSession {
524 message: err.to_string(),
525 })?;
526
527 parse_json_or_streamable_body(&body, &content_type)
528 .map_err(|message| KontextDevError::ConnectSession { message })
529}
530
531fn parse_json_or_streamable_body(body: &str, content_type: &str) -> Result<Value, String> {
532 let parse_json = || serde_json::from_str::<Value>(body).map_err(|err| err.to_string());
533 let parse_sse = || parse_sse_last_json_event(body);
534
535 if content_type.contains(STREAM_CONTENT_TYPE) {
536 return parse_sse().ok_or_else(|| {
537 "failed to parse streamable MCP response as SSE JSON events".to_string()
538 });
539 }
540
541 parse_json().or_else(|json_err| {
542 parse_sse().ok_or_else(|| format!("failed to decode response body: {json_err}"))
543 })
544}
545
546fn parse_sse_last_json_event(body: &str) -> Option<Value> {
547 let mut current_data = Vec::<String>::new();
548 let mut last_json = None;
549
550 let flush_data = |current_data: &mut Vec<String>, last_json: &mut Option<Value>| {
551 if current_data.is_empty() {
552 return;
553 }
554 let data = current_data.join("\n");
555 current_data.clear();
556 let trimmed = data.trim();
557 if trimmed.is_empty() || trimmed == "[DONE]" {
558 return;
559 }
560 if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
561 *last_json = Some(value);
562 }
563 };
564
565 for line in body.lines() {
566 let line = line.trim_end_matches('\r');
567 if line.is_empty() {
568 flush_data(&mut current_data, &mut last_json);
569 continue;
570 }
571 if let Some(data) = line.strip_prefix("data:") {
572 current_data.push(data.trim_start().to_string());
573 continue;
574 }
575 if let Ok(value) = serde_json::from_str::<Value>(line) {
576 last_json = Some(value);
577 }
578 }
579 flush_data(&mut current_data, &mut last_json);
580
581 last_json
582}
583
584pub(crate) fn has_meta_gateway_tools(tools: &[KontextTool]) -> bool {
585 let mut has_search = false;
586 let mut has_execute = false;
587 for tool in tools {
588 if tool.name == META_SEARCH_TOOLS {
589 has_search = true;
590 } else if tool.name == META_EXECUTE_TOOL {
591 has_execute = true;
592 }
593 }
594 has_search && has_execute
595}
596
597pub(crate) fn extract_json_resource_text(result: &Value) -> Option<String> {
598 let content = result.get("content")?.as_array()?;
599 for item in content {
600 if item.get("type").and_then(Value::as_str) != Some("resource") {
601 continue;
602 }
603 let Some(resource) = item.get("resource") else {
604 continue;
605 };
606 if resource.get("mimeType").and_then(Value::as_str) != Some("application/json") {
607 continue;
608 }
609 if let Some(text) = resource.get("text").and_then(Value::as_str) {
610 return Some(text.to_string());
611 }
612 }
613 None
614}
615
616pub(crate) fn extract_text_content(result: &Value) -> String {
617 let Some(content) = result.get("content").and_then(Value::as_array) else {
618 return result.to_string();
619 };
620
621 let mut text_items = Vec::new();
622 for item in content {
623 if item.get("type").and_then(Value::as_str) == Some("text")
624 && let Some(text) = item.get("text").and_then(Value::as_str)
625 {
626 text_items.push(text.to_string());
627 }
628 }
629 if !text_items.is_empty() {
630 return text_items.join("\n");
631 }
632
633 let mut resource_items = Vec::new();
634 for item in content {
635 if item.get("type").and_then(Value::as_str) != Some("resource") {
636 continue;
637 }
638 let Some(resource_text) = item
639 .get("resource")
640 .and_then(|resource| resource.get("text"))
641 .and_then(Value::as_str)
642 else {
643 continue;
644 };
645
646 let parsed = serde_json::from_str::<Value>(resource_text)
647 .ok()
648 .map(|value| extract_text_content(&value))
649 .unwrap_or_else(|| resource_text.to_string());
650 resource_items.push(parsed);
651 }
652
653 if !resource_items.is_empty() {
654 return resource_items.join("\n");
655 }
656
657 content
658 .iter()
659 .map(Value::to_string)
660 .collect::<Vec<_>>()
661 .join("\n")
662}
663
/// Decoded result of the `SEARCH_TOOLS` gateway meta-tool (built by
/// `parse_gateway_tools_payload`).
#[derive(Clone, Debug)]
pub(crate) struct GatewayToolsPayload {
    // Tools returned by the search, converted to `KontextTool`.
    pub tools: Vec<KontextTool>,
    // Entries from the payload's `errors` array; empty when the payload was a bare array.
    pub errors: Vec<GatewayToolError>,
    // Entries from the payload's `elicitations` array; empty when the payload was a bare array.
    pub elicitations: Vec<GatewayElicitation>,
}
670
/// One entry of the `errors` array in a `SEARCH_TOOLS` payload (camelCase on the wire).
///
/// NOTE(review): presumably describes a backing server whose tools could not be listed —
/// confirm against the gateway API.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct GatewayToolError {
    pub server_id: String,
    #[serde(default)]
    pub server_name: Option<String>,
    #[serde(default)]
    pub reason: Option<String>,
}
680
/// One entry of the `elicitations` array in a `SEARCH_TOOLS` payload (camelCase on the wire).
///
/// NOTE(review): `url` is presumably a page the user must visit (e.g. to connect an
/// integration) — confirm against the gateway API.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct GatewayElicitation {
    pub url: String,
    #[serde(default)]
    pub message: Option<String>,
    #[serde(default)]
    pub integration_id: Option<String>,
    #[serde(default)]
    pub integration_name: Option<String>,
}
692
/// Wire shape of one tool in a `SEARCH_TOOLS` payload (camelCase keys); converted to
/// `KontextTool` by `to_kontext_gateway_tool`.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GatewayToolSummary {
    id: String,
    name: String,
    #[serde(default)]
    description: Option<String>,
    #[serde(default)]
    input_schema: Option<Value>,
    #[serde(default)]
    server: Option<GatewayToolServer>,
}
705
/// Server reference inside a `GatewayToolSummary`. Both fields are optional on the wire; a
/// reference without `id` is dropped when the summary is converted to a `KontextTool`.
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GatewayToolServer {
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    name: Option<String>,
}
714
/// One entry of an MCP `tools/list` result. `inputSchema` is read from its camelCase key;
/// other keys are ignored.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct RawTool {
    name: String,
    #[serde(default)]
    description: Option<String>,
    #[serde(default)]
    input_schema: Option<serde_json::Value>,
}
724
725fn parse_tools_list_result(result: &Value) -> Result<Vec<KontextTool>, KontextDevError> {
726 let tools = result
727 .get("tools")
728 .and_then(|value| value.as_array())
729 .cloned()
730 .unwrap_or_default();
731
732 tools
733 .into_iter()
734 .map(|tool| {
735 let raw: RawTool =
736 serde_json::from_value(tool).map_err(|err| KontextDevError::ConnectSession {
737 message: format!("invalid tool payload: {err}"),
738 })?;
739
740 Ok(KontextTool {
741 id: raw.name.clone(),
742 name: raw.name,
743 description: raw.description,
744 input_schema: raw.input_schema,
745 server: None,
746 })
747 })
748 .collect()
749}
750
751pub(crate) fn parse_gateway_tools_payload(
752 raw: &Value,
753) -> Result<GatewayToolsPayload, KontextDevError> {
754 let json_text =
755 extract_json_resource_text(raw).ok_or_else(|| KontextDevError::ConnectSession {
756 message: "SEARCH_TOOLS did not return JSON resource content".to_string(),
757 })?;
758
759 let parsed = serde_json::from_str::<Value>(&json_text).map_err(|err| {
760 KontextDevError::ConnectSession {
761 message: format!("SEARCH_TOOLS returned invalid JSON: {err}"),
762 }
763 })?;
764
765 if let Some(items) = parsed.as_array() {
766 let tools = items
767 .iter()
768 .cloned()
769 .map(serde_json::from_value::<GatewayToolSummary>)
770 .collect::<Result<Vec<_>, _>>()
771 .map_err(|err| KontextDevError::ConnectSession {
772 message: format!("SEARCH_TOOLS returned invalid tool entry: {err}"),
773 })?
774 .into_iter()
775 .map(to_kontext_gateway_tool)
776 .collect();
777 return Ok(GatewayToolsPayload {
778 tools,
779 errors: Vec::new(),
780 elicitations: Vec::new(),
781 });
782 }
783
784 let Some(obj) = parsed.as_object() else {
785 return Err(KontextDevError::ConnectSession {
786 message: "SEARCH_TOOLS response was not a JSON array or object".to_string(),
787 });
788 };
789
790 let tools = obj
791 .get("items")
792 .and_then(Value::as_array)
793 .cloned()
794 .unwrap_or_default()
795 .into_iter()
796 .map(serde_json::from_value::<GatewayToolSummary>)
797 .collect::<Result<Vec<_>, _>>()
798 .map_err(|err| KontextDevError::ConnectSession {
799 message: format!("SEARCH_TOOLS items contained invalid tool data: {err}"),
800 })?
801 .into_iter()
802 .map(to_kontext_gateway_tool)
803 .collect::<Vec<_>>();
804
805 let errors = obj
806 .get("errors")
807 .and_then(Value::as_array)
808 .cloned()
809 .unwrap_or_default()
810 .into_iter()
811 .filter_map(|value| serde_json::from_value::<GatewayToolError>(value).ok())
812 .collect::<Vec<_>>();
813
814 let elicitations = obj
815 .get("elicitations")
816 .and_then(Value::as_array)
817 .cloned()
818 .unwrap_or_default()
819 .into_iter()
820 .filter_map(|value| serde_json::from_value::<GatewayElicitation>(value).ok())
821 .collect::<Vec<_>>();
822
823 Ok(GatewayToolsPayload {
824 tools,
825 errors,
826 elicitations,
827 })
828}
829
830fn to_kontext_gateway_tool(summary: GatewayToolSummary) -> KontextTool {
831 let server = summary.server.and_then(|server| {
832 server.id.map(|id| KontextToolServer {
833 id,
834 name: server.name,
835 })
836 });
837
838 KontextTool {
839 id: summary.id,
840 name: summary.name,
841 description: summary.description,
842 input_schema: summary.input_schema,
843 server,
844 }
845}
846
// Unit tests for response decoding and integration parsing, plus wiremock-backed tests of the
// stale-session retry path in `KontextMcp::json_rpc_with_session`.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;
    use wiremock::Mock;
    use wiremock::MockServer;
    use wiremock::ResponseTemplate;
    use wiremock::matchers::method;
    use wiremock::matchers::path;

    // A plain JSON body with a JSON content type is parsed directly.
    #[test]
    fn parse_json_or_streamable_body_parses_json_payload() {
        let parsed = parse_json_or_streamable_body(
            r#"{"jsonrpc":"2.0","result":{"ok":true}}"#,
            "application/json",
        )
        .expect("json should parse");
        assert_eq!(parsed["result"]["ok"], Value::Bool(true));
    }

    // An SSE body (with an `event:` line) is decoded from its `data:` payload.
    #[test]
    fn parse_json_or_streamable_body_parses_sse_payload() {
        let parsed = parse_json_or_streamable_body(
            "event: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{\"sessionId\":\"abc\"}}\n\n",
            "text/event-stream",
        )
        .expect("sse should parse");
        assert_eq!(
            parsed["result"]["sessionId"],
            Value::String("abc".to_string())
        );
    }

    // A server that streams SSE while claiming `application/json` is still understood.
    #[test]
    fn parse_json_or_streamable_body_falls_back_to_sse_when_content_type_is_json() {
        let parsed = parse_json_or_streamable_body(
            "data: {\"jsonrpc\":\"2.0\",\"result\":{\"tools\":[]}}\n\n",
            "application/json",
        )
        .expect("sse fallback should parse");
        assert_eq!(parsed["result"]["tools"], Value::Array(Vec::new()));
    }

    // `inputSchema` (camelCase on the wire) lands in `RawTool::input_schema`.
    #[test]
    fn raw_tool_parses_input_schema_from_camel_case_key() {
        let parsed: RawTool = serde_json::from_value(serde_json::json!({
            "name": "SEARCH_TOOLS",
            "description": "Search available tools",
            "inputSchema": { "type": "object", "properties": { "limit": { "type": "number" } } }
        }))
        .expect("raw tool should deserialize");

        assert_eq!(parsed.name, "SEARCH_TOOLS");
        assert_eq!(
            parsed
                .input_schema
                .as_ref()
                .and_then(|value| value.get("type"))
                .and_then(Value::as_str),
            Some("object")
        );
    }

    // A `resource` item without a `resource` object is skipped rather than ending the search.
    #[test]
    fn extract_json_resource_text_skips_resource_items_without_resource_payload() {
        let payload = serde_json::json!({
            "content": [
                { "type": "resource" },
                {
                    "type": "resource",
                    "resource": {
                        "mimeType": "application/json",
                        "text": "{\"ok\":true}"
                    }
                }
            ]
        });

        assert_eq!(
            extract_json_resource_text(&payload),
            Some("{\"ok\":true}".to_string())
        );
    }

    #[test]
    fn runtime_integration_record_parses_user_token_connect_type() {
        let parsed: RuntimeIntegrationRecord = serde_json::from_value(serde_json::json!({
            "id": "convex-int",
            "name": "Convex",
            "url": "https://convex.example.com/mcp",
            "category": "gateway_remote_mcp",
            "connectType": "user_token"
        }))
        .expect("record should deserialize");

        assert_eq!(
            parsed.connect_type,
            RuntimeIntegrationConnectType::UserToken
        );
    }

    // Pins the strict enum: unknown connect types are an error, not silently mapped.
    #[test]
    fn runtime_integration_record_rejects_unknown_connect_type() {
        let err = serde_json::from_value::<RuntimeIntegrationRecord>(serde_json::json!({
            "id": "convex-int",
            "name": "Convex",
            "url": "https://convex.example.com/mcp",
            "category": "gateway_remote_mcp",
            "connectType": "api_key"
        }))
        .expect_err("record should reject unknown connect type");

        assert!(err.to_string().contains("unknown variant"));
    }

    // How the mock server reports a stale session on `tools/list`.
    #[derive(Clone, Copy, Debug)]
    enum SessionFailureKind {
        // Non-2xx HTTP status (400) with a "Session ... not found" text body.
        HttpNotFound,
        // HTTP 200 carrying a JSON-RPC error with code -32000.
        JsonRpcNotFound,
    }

    // Builds a client whose MCP URL points at the mock server's `/mcp` path.
    fn create_test_mcp(server: &MockServer) -> KontextMcp {
        KontextMcp::new(KontextMcpConfig {
            client_session_id: "client-session".to_string(),
            client_id: "client-id".to_string(),
            redirect_uri: "http://localhost:3333/callback".to_string(),
            url: Some(format!("{}/mcp", server.uri())),
            server: Some(server.uri()),
            client_secret: None,
            scope: None,
            resource: None,
            session_key: None,
            integration_ui_url: None,
            integration_return_to: None,
            auth_timeout_seconds: None,
            open_connect_page_on_login: None,
            token_cache_path: None,
        })
    }

    // Mounts a fake MCP endpoint:
    // - the first `initialize` hands out "stale-session", later ones "fresh-session";
    // - the first `tools/list` fails with `failure_kind`; later calls succeed only when
    //   `recover_on_retry` is true.
    // Returns the counters of `initialize` and `tools/list` calls.
    async fn mount_retrying_tools_list_server(
        server: &MockServer,
        failure_kind: SessionFailureKind,
        recover_on_retry: bool,
    ) -> (Arc<AtomicUsize>, Arc<AtomicUsize>) {
        let initialize_calls = Arc::new(AtomicUsize::new(0));
        let tools_list_calls = Arc::new(AtomicUsize::new(0));

        let initialize_calls_for_mock = Arc::clone(&initialize_calls);
        let tools_list_calls_for_mock = Arc::clone(&tools_list_calls);
        Mock::given(method("POST"))
            .and(path("/mcp"))
            .respond_with(move |request: &wiremock::Request| {
                let payload: Value = serde_json::from_slice(&request.body)
                    .expect("MCP requests should be valid JSON payloads");
                let method = payload
                    .get("method")
                    .and_then(Value::as_str)
                    .expect("MCP requests should include method");

                match method {
                    "initialize" => {
                        let initialize_call =
                            initialize_calls_for_mock.fetch_add(1, Ordering::SeqCst);
                        let session_id = if initialize_call == 0 {
                            "stale-session"
                        } else {
                            "fresh-session"
                        };

                        ResponseTemplate::new(200)
                            .append_header("Mcp-Session-Id", session_id)
                            .set_body_json(json!({
                                "jsonrpc": "2.0",
                                "id": "initialize",
                                "result": {
                                    "sessionId": session_id
                                }
                            }))
                    }
                    "notifications/initialized" => {
                        ResponseTemplate::new(200).set_body_json(json!({
                            "jsonrpc": "2.0",
                            "result": {}
                        }))
                    }
                    "tools/list" => {
                        let tools_list_call =
                            tools_list_calls_for_mock.fetch_add(1, Ordering::SeqCst);
                        if tools_list_call == 0 || !recover_on_retry {
                            return match failure_kind {
                                SessionFailureKind::HttpNotFound => ResponseTemplate::new(400)
                                    .set_body_string(
                                        "Request rejected: Session stale-session not found",
                                    ),
                                SessionFailureKind::JsonRpcNotFound => ResponseTemplate::new(200)
                                    .set_body_json(json!({
                                        "jsonrpc": "2.0",
                                        "id": "list-tools",
                                        "error": {
                                            "code": -32000,
                                            "message": "Session stale-session not found"
                                        }
                                    })),
                            };
                        }

                        ResponseTemplate::new(200).set_body_json(json!({
                            "jsonrpc": "2.0",
                            "id": "list-tools",
                            "result": {
                                "tools": [{
                                    "name": "github.search",
                                    "description": "Search GitHub",
                                    "inputSchema": {
                                        "type": "object"
                                    }
                                }]
                            }
                        }))
                    }
                    _ => ResponseTemplate::new(500),
                }
            })
            .mount(server)
            .await;

        (initialize_calls, tools_list_calls)
    }

    // HTTP-level stale session: one re-initialize, one retried `tools/list`, then success.
    #[tokio::test]
    async fn list_tools_recovers_from_http_session_not_found() {
        let server = MockServer::start().await;
        let (initialize_calls, tools_list_calls) =
            mount_retrying_tools_list_server(&server, SessionFailureKind::HttpNotFound, true).await;

        let mcp = create_test_mcp(&server);
        let tools = mcp
            .list_tools_with_access_token("access-token")
            .await
            .expect("HTTP session-not-found should recover");

        assert_eq!(
            tools,
            vec![KontextTool {
                id: "github.search".to_string(),
                name: "github.search".to_string(),
                description: Some("Search GitHub".to_string()),
                input_schema: Some(json!({
                    "type": "object"
                })),
                server: None,
            }]
        );
        assert_eq!(initialize_calls.load(Ordering::SeqCst), 2);
        assert_eq!(tools_list_calls.load(Ordering::SeqCst), 2);
    }

    // JSON-RPC-level stale session (code -32000) recovers the same way.
    #[tokio::test]
    async fn list_tools_recovers_from_jsonrpc_session_not_found() {
        let server = MockServer::start().await;
        let (initialize_calls, tools_list_calls) =
            mount_retrying_tools_list_server(&server, SessionFailureKind::JsonRpcNotFound, true)
                .await;

        let mcp = create_test_mcp(&server);
        let tools = mcp
            .list_tools_with_access_token("access-token")
            .await
            .expect("JSON-RPC session-not-found should recover");

        assert_eq!(tools.len(), 1);
        assert_eq!(initialize_calls.load(Ordering::SeqCst), 2);
        assert_eq!(tools_list_calls.load(Ordering::SeqCst), 2);
    }

    // A persistently stale session is retried exactly once, and the server error is surfaced.
    #[tokio::test]
    async fn list_tools_stale_session_retry_happens_once() {
        let server = MockServer::start().await;
        let (initialize_calls, tools_list_calls) =
            mount_retrying_tools_list_server(&server, SessionFailureKind::HttpNotFound, false)
                .await;

        let mcp = create_test_mcp(&server);
        let err = mcp
            .list_tools_with_access_token("access-token")
            .await
            .expect_err("recovery should fail when stale session persists");

        assert!(err.to_string().contains("Session stale-session not found"));
        assert_eq!(initialize_calls.load(Ordering::SeqCst), 2);
        assert_eq!(tools_list_calls.load(Ordering::SeqCst), 2);
    }
}