Skip to main content

kontext_dev_sdk/
mcp.rs

1use std::sync::Arc;
2
3use serde::{Deserialize, Serialize};
4use serde_json::{Value, json};
5use tokio::sync::Mutex;
6
7use crate::KontextAuthSession;
8use crate::KontextDevClient;
9use crate::KontextDevConfig;
10use crate::KontextDevError;
11
/// Server base URL used when the configuration does not name one.
pub const DEFAULT_SERVER: &str = "https://api.kontext.dev";
/// HTTP header that carries the MCP session id on requests and on the
/// `initialize` response.
const MCP_SESSION_HEADER: &str = "Mcp-Session-Id";
/// Tool name that `has_meta_gateway_tools` looks for (search half of the pair).
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Tool name that `has_meta_gateway_tools` looks for (execute half of the pair).
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// `protocolVersion` sent in the MCP `initialize` request.
const DEFAULT_MCP_PROTOCOL_VERSION: &str = "2025-06-18";
/// `Accept` header value sent with every MCP POST: plain JSON or an SSE stream.
const STREAMABLE_HTTP_ACCEPT: &str = "application/json, text/event-stream";
/// Content-type fragment that marks a response body as server-sent events.
const STREAM_CONTENT_TYPE: &str = "text/event-stream";
19
/// Reduces a user-supplied server URL to its bare base form.
///
/// Trailing slashes are removed first; then a trailing `/api/v1` and, after
/// that, a trailing `/mcp` are each stripped at most once (in that order).
/// Any slashes left dangling at the end are removed as well.
pub fn normalize_kontext_server_url(server: &str) -> String {
    // Work on borrowed slices and allocate only once for the final result.
    let base = server.trim_end_matches('/');
    let base = base.strip_suffix("/api/v1").unwrap_or(base);
    let base = base.strip_suffix("/mcp").unwrap_or(base);
    base.trim_end_matches('/').to_owned()
}
30
/// User-facing settings for [`KontextMcp`].
///
/// Most fields are forwarded to the underlying SDK client configuration by
/// `KontextMcp::new`; the defaults applied there are noted per field.
#[derive(Debug, Clone)]
pub struct KontextMcpConfig {
    /// Sent as `clientInfo.sessionId` in the MCP `initialize` request.
    pub client_session_id: String,
    /// Client id forwarded to the SDK client configuration.
    pub client_id: String,
    /// Redirect URI forwarded to the SDK client configuration.
    pub redirect_uri: String,
    /// Explicit MCP endpoint; when set, `KontextMcp::mcp_url` returns it as-is.
    pub url: Option<String>,
    /// Server base URL; normalized before use, `DEFAULT_SERVER` when `None`.
    pub server: Option<String>,
    /// Client secret forwarded to the SDK client configuration.
    pub client_secret: Option<String>,
    /// Scope forwarded to the SDK client; an empty string when `None`.
    pub scope: Option<String>,
    /// Resource forwarded to the SDK client; `"mcp-gateway"` when `None`.
    pub resource: Option<String>,
    /// NOTE(review): not read by `KontextMcp::new` — confirm where this is used.
    pub session_key: Option<String>,
    /// Forwarded unchanged to the SDK client configuration.
    pub integration_ui_url: Option<String>,
    /// Forwarded unchanged to the SDK client configuration.
    pub integration_return_to: Option<String>,
    /// Forwarded to the SDK client; `300` when `None`.
    pub auth_timeout_seconds: Option<i64>,
    /// Forwarded to the SDK client; `true` when `None`.
    pub open_connect_page_on_login: Option<bool>,
    /// Forwarded unchanged to the SDK client configuration.
    pub token_cache_path: Option<String>,
}
48
49impl Default for KontextMcpConfig {
50    fn default() -> Self {
51        Self {
52            client_session_id: String::new(),
53            client_id: String::new(),
54            redirect_uri: "http://localhost:3333/callback".to_string(),
55            url: None,
56            server: Some(DEFAULT_SERVER.to_string()),
57            client_secret: None,
58            scope: None,
59            resource: None,
60            session_key: None,
61            integration_ui_url: None,
62            integration_return_to: None,
63            auth_timeout_seconds: None,
64            open_connect_page_on_login: None,
65            token_cache_path: None,
66        }
67    }
68}
69
70#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
71#[serde(rename_all = "snake_case")]
72pub enum RuntimeIntegrationCategory {
73    GatewayRemoteMcp,
74    InternalMcpCredentials,
75}
76
77#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
78#[serde(rename_all = "snake_case")]
79pub enum RuntimeIntegrationConnectType {
80    Oauth,
81    UserToken,
82    Credentials,
83    None,
84}
85
86#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
87#[serde(rename_all = "camelCase")]
88pub struct RuntimeIntegrationRecord {
89    pub id: String,
90    pub name: String,
91    pub url: String,
92    pub category: RuntimeIntegrationCategory,
93    pub connect_type: RuntimeIntegrationConnectType,
94    #[serde(skip_serializing_if = "Option::is_none")]
95    pub auth_mode: Option<String>,
96    #[serde(skip_serializing_if = "Option::is_none")]
97    pub credential_schema: Option<serde_json::Value>,
98    #[serde(skip_serializing_if = "Option::is_none")]
99    pub requires_oauth: Option<bool>,
100    #[serde(skip_serializing_if = "Option::is_none")]
101    pub connection: Option<RuntimeIntegrationConnection>,
102}
103
104#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
105#[serde(rename_all = "camelCase")]
106pub struct RuntimeIntegrationConnection {
107    pub connected: bool,
108    pub status: String,
109    #[serde(skip_serializing_if = "Option::is_none")]
110    pub expires_at: Option<String>,
111    #[serde(skip_serializing_if = "Option::is_none")]
112    pub display_name: Option<String>,
113}
114
115#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)]
116#[serde(rename_all = "camelCase")]
117pub struct KontextTool {
118    pub id: String,
119    pub name: String,
120    #[serde(skip_serializing_if = "Option::is_none")]
121    pub description: Option<String>,
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub input_schema: Option<serde_json::Value>,
124    #[serde(skip_serializing_if = "Option::is_none")]
125    pub server: Option<KontextToolServer>,
126}
127
128#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
129#[serde(rename_all = "camelCase")]
130pub struct KontextToolServer {
131    pub id: String,
132    #[serde(skip_serializing_if = "Option::is_none")]
133    pub name: Option<String>,
134}
135
/// Cached MCP session, tied to the access token it was negotiated with.
#[derive(Debug, Default, Clone)]
struct McpSessionState {
    /// Session id obtained from the last successful `initialize` handshake.
    session_id: Option<String>,
    /// Access token the cached session id belongs to; a different token
    /// triggers a fresh handshake.
    access_token: Option<String>,
}
141
142#[derive(Clone, Debug)]
143pub struct KontextMcp {
144    config: KontextMcpConfig,
145    client: KontextDevClient,
146    http: reqwest::Client,
147    session: Arc<Mutex<McpSessionState>>,
148}
149
150impl KontextMcp {
151    pub fn new(config: KontextMcpConfig) -> Self {
152        let server =
153            normalize_kontext_server_url(config.server.as_deref().unwrap_or(DEFAULT_SERVER));
154        let sdk_config = KontextDevConfig {
155            server,
156            client_id: config.client_id.clone(),
157            client_secret: config.client_secret.clone(),
158            scope: config.scope.clone().unwrap_or_default(),
159            server_name: "kontext-dev".to_string(),
160            resource: config
161                .resource
162                .clone()
163                .unwrap_or_else(|| "mcp-gateway".to_string()),
164            integration_ui_url: config.integration_ui_url.clone(),
165            integration_return_to: config.integration_return_to.clone(),
166            open_connect_page_on_login: config.open_connect_page_on_login.unwrap_or(true),
167            auth_timeout_seconds: config.auth_timeout_seconds.unwrap_or(300),
168            token_cache_path: config.token_cache_path.clone(),
169            redirect_uri: config.redirect_uri.clone(),
170        };
171
172        Self {
173            config,
174            client: KontextDevClient::new(sdk_config),
175            http: reqwest::Client::new(),
176            session: Arc::new(Mutex::new(McpSessionState::default())),
177        }
178    }
179
180    pub fn client(&self) -> &KontextDevClient {
181        &self.client
182    }
183
184    pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
185        self.client.authenticate_mcp().await
186    }
187
188    pub fn mcp_url(&self) -> Result<String, KontextDevError> {
189        if let Some(url) = &self.config.url {
190            return Ok(url.clone());
191        }
192        self.client.mcp_url()
193    }
194
195    pub async fn clear_cached_session(&self) {
196        self.invalidate_session().await;
197    }
198
199    pub async fn list_integrations(
200        &self,
201    ) -> Result<Vec<RuntimeIntegrationRecord>, KontextDevError> {
202        let session = self.authenticate_mcp().await?;
203        let base = self.client.server_base_url()?;
204        let response = self
205            .http
206            .get(format!("{}/mcp/integrations", base.trim_end_matches('/')))
207            .bearer_auth(session.gateway_token.access_token)
208            .send()
209            .await
210            .map_err(|err| KontextDevError::ConnectSession {
211                message: err.to_string(),
212            })?;
213
214        if !response.status().is_success() {
215            let status = response.status();
216            let body = response.text().await.unwrap_or_default();
217            return Err(KontextDevError::ConnectSession {
218                message: format!("{status}: {body}"),
219            });
220        }
221
222        #[derive(Deserialize)]
223        struct IntegrationsResponse {
224            items: Vec<RuntimeIntegrationRecord>,
225        }
226
227        let payload = response
228            .json::<IntegrationsResponse>()
229            .await
230            .map_err(|err| KontextDevError::ConnectSession {
231                message: err.to_string(),
232            })?;
233
234        Ok(payload.items)
235    }
236
237    pub async fn list_tools(&self) -> Result<Vec<KontextTool>, KontextDevError> {
238        let session = self.authenticate_mcp().await?;
239        self.list_tools_with_access_token(&session.gateway_token.access_token)
240            .await
241    }
242
243    pub async fn list_tools_with_access_token(
244        &self,
245        access_token: &str,
246    ) -> Result<Vec<KontextTool>, KontextDevError> {
247        let result = self
248            .json_rpc_with_session(
249                access_token,
250                "tools/list",
251                json!({}),
252                Some("list-tools"),
253                true,
254            )
255            .await?;
256        parse_tools_list_result(&result)
257    }
258
259    pub async fn call_tool(
260        &self,
261        tool_id: &str,
262        args: Option<serde_json::Map<String, serde_json::Value>>,
263    ) -> Result<serde_json::Value, KontextDevError> {
264        let session = self.authenticate_mcp().await?;
265        self.call_tool_with_access_token(&session.gateway_token.access_token, tool_id, args)
266            .await
267    }
268
269    pub async fn call_tool_with_access_token(
270        &self,
271        access_token: &str,
272        tool_id: &str,
273        args: Option<serde_json::Map<String, serde_json::Value>>,
274    ) -> Result<serde_json::Value, KontextDevError> {
275        self.json_rpc_with_session(
276            access_token,
277            "tools/call",
278            json!({ "name": tool_id, "arguments": args.unwrap_or_default() }),
279            Some("call-tool"),
280            true,
281        )
282        .await
283    }
284
285    async fn json_rpc_with_session(
286        &self,
287        access_token: &str,
288        method: &str,
289        params: Value,
290        id: Option<&str>,
291        allow_session_retry: bool,
292    ) -> Result<Value, KontextDevError> {
293        let max_attempts = if allow_session_retry { 2 } else { 1 };
294        for attempt in 0..max_attempts {
295            let session_id = self.ensure_mcp_session(access_token).await?;
296
297            let response = self
298                .http
299                .post(self.mcp_url()?)
300                .bearer_auth(access_token)
301                .header(reqwest::header::ACCEPT, STREAMABLE_HTTP_ACCEPT)
302                .header(MCP_SESSION_HEADER, &session_id)
303                .json(&json!({
304                    "jsonrpc": "2.0",
305                    "id": id.unwrap_or("1"),
306                    "method": method,
307                    "params": params,
308                }))
309                .send()
310                .await
311                .map_err(|err| KontextDevError::ConnectSession {
312                    message: err.to_string(),
313                })?;
314
315            if !response.status().is_success() {
316                let status = response.status();
317                let body = response.text().await.unwrap_or_default();
318                let retryable =
319                    attempt + 1 < max_attempts && is_invalid_session_response_body(body.as_str());
320                if retryable {
321                    self.invalidate_session().await;
322                    continue;
323                }
324                return Err(KontextDevError::ConnectSession {
325                    message: format!("{status}: {body}"),
326                });
327            }
328
329            let payload = parse_json_or_streamable_response(response).await?;
330
331            if let Some(error) = payload.get("error") {
332                let message = extract_jsonrpc_error_message(error);
333                let retryable =
334                    attempt + 1 < max_attempts && is_invalid_session_jsonrpc_error(error);
335                if retryable {
336                    self.invalidate_session().await;
337                    continue;
338                }
339                return Err(KontextDevError::ConnectSession { message });
340            }
341
342            return Ok(payload.get("result").cloned().unwrap_or(Value::Null));
343        }
344
345        Err(KontextDevError::ConnectSession {
346            message: "MCP request failed after session retry".to_string(),
347        })
348    }
349
350    async fn ensure_mcp_session(&self, access_token: &str) -> Result<String, KontextDevError> {
351        {
352            let guard = self.session.lock().await;
353            if guard.access_token.as_deref() == Some(access_token)
354                && let Some(session_id) = guard.session_id.clone()
355            {
356                return Ok(session_id);
357            }
358        }
359
360        let initialize_response = self
361            .http
362            .post(self.mcp_url()?)
363            .bearer_auth(access_token)
364            .header(reqwest::header::ACCEPT, STREAMABLE_HTTP_ACCEPT)
365            .json(&json!({
366                "jsonrpc": "2.0",
367                "id": "initialize",
368                "method": "initialize",
369                "params": {
370                    "protocolVersion": DEFAULT_MCP_PROTOCOL_VERSION,
371                    "capabilities": {
372                        "tools": {}
373                    },
374                    "clientInfo": {
375                        "name": "kontext-dev-sdk-rs",
376                        "version": env!("CARGO_PKG_VERSION"),
377                        "sessionId": self.config.client_session_id
378                    }
379                }
380            }))
381            .send()
382            .await
383            .map_err(|err| KontextDevError::ConnectSession {
384                message: err.to_string(),
385            })?;
386
387        if !initialize_response.status().is_success() {
388            let status = initialize_response.status();
389            let body = initialize_response.text().await.unwrap_or_default();
390            return Err(KontextDevError::ConnectSession {
391                message: format!("{status}: {body}"),
392            });
393        }
394
395        let session_header = initialize_response
396            .headers()
397            .get(MCP_SESSION_HEADER)
398            .or_else(|| initialize_response.headers().get("mcp-session-id"))
399            .and_then(|value| value.to_str().ok())
400            .map(|value| value.trim().to_string());
401
402        let initialize_payload = parse_json_or_streamable_response(initialize_response).await?;
403
404        if let Some(error) = initialize_payload.get("error") {
405            return Err(KontextDevError::ConnectSession {
406                message: extract_jsonrpc_error_message(error),
407            });
408        }
409
410        let session_id = session_header
411            .or_else(|| {
412                initialize_payload
413                    .get("result")
414                    .and_then(|result| result.get("sessionId"))
415                    .and_then(|value| value.as_str())
416                    .map(|value| value.to_string())
417            })
418            .or_else(|| {
419                initialize_payload
420                    .get("result")
421                    .and_then(|result| result.get("session_id"))
422                    .and_then(|value| value.as_str())
423                    .map(|value| value.to_string())
424            })
425            .ok_or_else(|| KontextDevError::ConnectSession {
426                message: "MCP initialize did not return a session id".to_string(),
427            })?;
428
429        // Best-effort initialized notification. Many servers accept requests
430        // without it, but send it to follow the Streamable HTTP MCP handshake.
431        let _ = self
432            .http
433            .post(self.mcp_url()?)
434            .bearer_auth(access_token)
435            .header(reqwest::header::ACCEPT, STREAMABLE_HTTP_ACCEPT)
436            .header(MCP_SESSION_HEADER, &session_id)
437            .json(&json!({
438                "jsonrpc": "2.0",
439                "method": "notifications/initialized",
440                "params": {}
441            }))
442            .send()
443            .await;
444
445        {
446            let mut guard = self.session.lock().await;
447            guard.session_id = Some(session_id.clone());
448            guard.access_token = Some(access_token.to_string());
449        }
450
451        Ok(session_id)
452    }
453
454    async fn invalidate_session(&self) {
455        let mut guard = self.session.lock().await;
456        guard.session_id = None;
457        guard.access_token = None;
458    }
459}
460
461fn extract_jsonrpc_error_message(error: &Value) -> String {
462    error
463        .get("message")
464        .and_then(|value| value.as_str())
465        .map(ToString::to_string)
466        .or_else(|| {
467            error
468                .get("error_description")
469                .and_then(|value| value.as_str())
470                .map(ToString::to_string)
471        })
472        .unwrap_or_else(|| error.to_string())
473}
474
/// Reports whether `message` contains one of the known "invalid session"
/// phrases, compared ASCII-case-insensitively.
fn is_invalid_session_error(message: &str) -> bool {
    const MARKERS: [&str; 3] = [
        "no valid session id",
        "no valid session-id",
        "invalid session",
    ];

    let normalized = message.to_ascii_lowercase();
    MARKERS.iter().any(|marker| normalized.contains(*marker))
}
481
/// Reports whether `message` mentions both "session" and "not found"
/// (ASCII-case-insensitive, in any order).
fn is_session_not_found_error(message: &str) -> bool {
    let normalized = message.to_ascii_lowercase();
    ["session", "not found"]
        .iter()
        .all(|needle| normalized.contains(*needle))
}
486
487fn is_invalid_session_jsonrpc_error(error: &Value) -> bool {
488    let message = extract_jsonrpc_error_message(error);
489    if is_invalid_session_error(message.as_str()) {
490        return true;
491    }
492
493    let code = error.get("code").and_then(Value::as_i64);
494    code == Some(-32000) && is_session_not_found_error(message.as_str())
495}
496
497fn is_invalid_session_response_body(body: &str) -> bool {
498    if is_invalid_session_error(body) || is_session_not_found_error(body) {
499        return true;
500    }
501
502    if let Ok(payload) = serde_json::from_str::<Value>(body)
503        && let Some(error) = payload.get("error")
504    {
505        return is_invalid_session_jsonrpc_error(error);
506    }
507
508    false
509}
510
511async fn parse_json_or_streamable_response(
512    response: reqwest::Response,
513) -> Result<Value, KontextDevError> {
514    let content_type = response
515        .headers()
516        .get(reqwest::header::CONTENT_TYPE)
517        .and_then(|value| value.to_str().ok())
518        .map(|value| value.to_ascii_lowercase())
519        .unwrap_or_default();
520    let body = response
521        .text()
522        .await
523        .map_err(|err| KontextDevError::ConnectSession {
524            message: err.to_string(),
525        })?;
526
527    parse_json_or_streamable_body(&body, &content_type)
528        .map_err(|message| KontextDevError::ConnectSession { message })
529}
530
531fn parse_json_or_streamable_body(body: &str, content_type: &str) -> Result<Value, String> {
532    let parse_json = || serde_json::from_str::<Value>(body).map_err(|err| err.to_string());
533    let parse_sse = || parse_sse_last_json_event(body);
534
535    if content_type.contains(STREAM_CONTENT_TYPE) {
536        return parse_sse().ok_or_else(|| {
537            "failed to parse streamable MCP response as SSE JSON events".to_string()
538        });
539    }
540
541    parse_json().or_else(|json_err| {
542        parse_sse().ok_or_else(|| format!("failed to decode response body: {json_err}"))
543    })
544}
545
546fn parse_sse_last_json_event(body: &str) -> Option<Value> {
547    let mut current_data = Vec::<String>::new();
548    let mut last_json = None;
549
550    let flush_data = |current_data: &mut Vec<String>, last_json: &mut Option<Value>| {
551        if current_data.is_empty() {
552            return;
553        }
554        let data = current_data.join("\n");
555        current_data.clear();
556        let trimmed = data.trim();
557        if trimmed.is_empty() || trimmed == "[DONE]" {
558            return;
559        }
560        if let Ok(value) = serde_json::from_str::<Value>(trimmed) {
561            *last_json = Some(value);
562        }
563    };
564
565    for line in body.lines() {
566        let line = line.trim_end_matches('\r');
567        if line.is_empty() {
568            flush_data(&mut current_data, &mut last_json);
569            continue;
570        }
571        if let Some(data) = line.strip_prefix("data:") {
572            current_data.push(data.trim_start().to_string());
573            continue;
574        }
575        if let Ok(value) = serde_json::from_str::<Value>(line) {
576            last_json = Some(value);
577        }
578    }
579    flush_data(&mut current_data, &mut last_json);
580
581    last_json
582}
583
584pub(crate) fn has_meta_gateway_tools(tools: &[KontextTool]) -> bool {
585    let mut has_search = false;
586    let mut has_execute = false;
587    for tool in tools {
588        if tool.name == META_SEARCH_TOOLS {
589            has_search = true;
590        } else if tool.name == META_EXECUTE_TOOL {
591            has_execute = true;
592        }
593    }
594    has_search && has_execute
595}
596
597pub(crate) fn extract_json_resource_text(result: &Value) -> Option<String> {
598    let content = result.get("content")?.as_array()?;
599    for item in content {
600        if item.get("type").and_then(Value::as_str) != Some("resource") {
601            continue;
602        }
603        let Some(resource) = item.get("resource") else {
604            continue;
605        };
606        if resource.get("mimeType").and_then(Value::as_str) != Some("application/json") {
607            continue;
608        }
609        if let Some(text) = resource.get("text").and_then(Value::as_str) {
610            return Some(text.to_string());
611        }
612    }
613    None
614}
615
616pub(crate) fn extract_text_content(result: &Value) -> String {
617    let Some(content) = result.get("content").and_then(Value::as_array) else {
618        return result.to_string();
619    };
620
621    let mut text_items = Vec::new();
622    for item in content {
623        if item.get("type").and_then(Value::as_str) == Some("text")
624            && let Some(text) = item.get("text").and_then(Value::as_str)
625        {
626            text_items.push(text.to_string());
627        }
628    }
629    if !text_items.is_empty() {
630        return text_items.join("\n");
631    }
632
633    let mut resource_items = Vec::new();
634    for item in content {
635        if item.get("type").and_then(Value::as_str) != Some("resource") {
636            continue;
637        }
638        let Some(resource_text) = item
639            .get("resource")
640            .and_then(|resource| resource.get("text"))
641            .and_then(Value::as_str)
642        else {
643            continue;
644        };
645
646        let parsed = serde_json::from_str::<Value>(resource_text)
647            .ok()
648            .map(|value| extract_text_content(&value))
649            .unwrap_or_else(|| resource_text.to_string());
650        resource_items.push(parsed);
651    }
652
653    if !resource_items.is_empty() {
654        return resource_items.join("\n");
655    }
656
657    content
658        .iter()
659        .map(Value::to_string)
660        .collect::<Vec<_>>()
661        .join("\n")
662}
663
664#[derive(Clone, Debug)]
665pub(crate) struct GatewayToolsPayload {
666    pub tools: Vec<KontextTool>,
667    pub errors: Vec<GatewayToolError>,
668    pub elicitations: Vec<GatewayElicitation>,
669}
670
671#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
672#[serde(rename_all = "camelCase")]
673pub struct GatewayToolError {
674    pub server_id: String,
675    #[serde(default)]
676    pub server_name: Option<String>,
677    #[serde(default)]
678    pub reason: Option<String>,
679}
680
681#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
682#[serde(rename_all = "camelCase")]
683pub struct GatewayElicitation {
684    pub url: String,
685    #[serde(default)]
686    pub message: Option<String>,
687    #[serde(default)]
688    pub integration_id: Option<String>,
689    #[serde(default)]
690    pub integration_name: Option<String>,
691}
692
693#[derive(Clone, Debug, Deserialize)]
694#[serde(rename_all = "camelCase")]
695struct GatewayToolSummary {
696    id: String,
697    name: String,
698    #[serde(default)]
699    description: Option<String>,
700    #[serde(default)]
701    input_schema: Option<Value>,
702    #[serde(default)]
703    server: Option<GatewayToolServer>,
704}
705
706#[derive(Clone, Debug, Deserialize)]
707#[serde(rename_all = "camelCase")]
708struct GatewayToolServer {
709    #[serde(default)]
710    id: Option<String>,
711    #[serde(default)]
712    name: Option<String>,
713}
714
715#[derive(Debug, Deserialize)]
716#[serde(rename_all = "camelCase")]
717struct RawTool {
718    name: String,
719    #[serde(default)]
720    description: Option<String>,
721    #[serde(default)]
722    input_schema: Option<serde_json::Value>,
723}
724
725fn parse_tools_list_result(result: &Value) -> Result<Vec<KontextTool>, KontextDevError> {
726    let tools = result
727        .get("tools")
728        .and_then(|value| value.as_array())
729        .cloned()
730        .unwrap_or_default();
731
732    tools
733        .into_iter()
734        .map(|tool| {
735            let raw: RawTool =
736                serde_json::from_value(tool).map_err(|err| KontextDevError::ConnectSession {
737                    message: format!("invalid tool payload: {err}"),
738                })?;
739
740            Ok(KontextTool {
741                id: raw.name.clone(),
742                name: raw.name,
743                description: raw.description,
744                input_schema: raw.input_schema,
745                server: None,
746            })
747        })
748        .collect()
749}
750
751pub(crate) fn parse_gateway_tools_payload(
752    raw: &Value,
753) -> Result<GatewayToolsPayload, KontextDevError> {
754    let json_text =
755        extract_json_resource_text(raw).ok_or_else(|| KontextDevError::ConnectSession {
756            message: "SEARCH_TOOLS did not return JSON resource content".to_string(),
757        })?;
758
759    let parsed = serde_json::from_str::<Value>(&json_text).map_err(|err| {
760        KontextDevError::ConnectSession {
761            message: format!("SEARCH_TOOLS returned invalid JSON: {err}"),
762        }
763    })?;
764
765    if let Some(items) = parsed.as_array() {
766        let tools = items
767            .iter()
768            .cloned()
769            .map(serde_json::from_value::<GatewayToolSummary>)
770            .collect::<Result<Vec<_>, _>>()
771            .map_err(|err| KontextDevError::ConnectSession {
772                message: format!("SEARCH_TOOLS returned invalid tool entry: {err}"),
773            })?
774            .into_iter()
775            .map(to_kontext_gateway_tool)
776            .collect();
777        return Ok(GatewayToolsPayload {
778            tools,
779            errors: Vec::new(),
780            elicitations: Vec::new(),
781        });
782    }
783
784    let Some(obj) = parsed.as_object() else {
785        return Err(KontextDevError::ConnectSession {
786            message: "SEARCH_TOOLS response was not a JSON array or object".to_string(),
787        });
788    };
789
790    let tools = obj
791        .get("items")
792        .and_then(Value::as_array)
793        .cloned()
794        .unwrap_or_default()
795        .into_iter()
796        .map(serde_json::from_value::<GatewayToolSummary>)
797        .collect::<Result<Vec<_>, _>>()
798        .map_err(|err| KontextDevError::ConnectSession {
799            message: format!("SEARCH_TOOLS items contained invalid tool data: {err}"),
800        })?
801        .into_iter()
802        .map(to_kontext_gateway_tool)
803        .collect::<Vec<_>>();
804
805    let errors = obj
806        .get("errors")
807        .and_then(Value::as_array)
808        .cloned()
809        .unwrap_or_default()
810        .into_iter()
811        .filter_map(|value| serde_json::from_value::<GatewayToolError>(value).ok())
812        .collect::<Vec<_>>();
813
814    let elicitations = obj
815        .get("elicitations")
816        .and_then(Value::as_array)
817        .cloned()
818        .unwrap_or_default()
819        .into_iter()
820        .filter_map(|value| serde_json::from_value::<GatewayElicitation>(value).ok())
821        .collect::<Vec<_>>();
822
823    Ok(GatewayToolsPayload {
824        tools,
825        errors,
826        elicitations,
827    })
828}
829
830fn to_kontext_gateway_tool(summary: GatewayToolSummary) -> KontextTool {
831    let server = summary.server.and_then(|server| {
832        server.id.map(|id| KontextToolServer {
833            id,
834            name: server.name,
835        })
836    });
837
838    KontextTool {
839        id: summary.id,
840        name: summary.name,
841        description: summary.description,
842        input_schema: summary.input_schema,
843        server,
844    }
845}
846
847#[cfg(test)]
848mod tests {
849    use super::*;
850    use std::sync::Arc;
851    use std::sync::atomic::AtomicUsize;
852    use std::sync::atomic::Ordering;
853    use wiremock::Mock;
854    use wiremock::MockServer;
855    use wiremock::ResponseTemplate;
856    use wiremock::matchers::method;
857    use wiremock::matchers::path;
858
859    #[test]
860    fn parse_json_or_streamable_body_parses_json_payload() {
861        let parsed = parse_json_or_streamable_body(
862            r#"{"jsonrpc":"2.0","result":{"ok":true}}"#,
863            "application/json",
864        )
865        .expect("json should parse");
866        assert_eq!(parsed["result"]["ok"], Value::Bool(true));
867    }
868
869    #[test]
870    fn parse_json_or_streamable_body_parses_sse_payload() {
871        let parsed = parse_json_or_streamable_body(
872            "event: message\ndata: {\"jsonrpc\":\"2.0\",\"result\":{\"sessionId\":\"abc\"}}\n\n",
873            "text/event-stream",
874        )
875        .expect("sse should parse");
876        assert_eq!(
877            parsed["result"]["sessionId"],
878            Value::String("abc".to_string())
879        );
880    }
881
882    #[test]
883    fn parse_json_or_streamable_body_falls_back_to_sse_when_content_type_is_json() {
884        let parsed = parse_json_or_streamable_body(
885            "data: {\"jsonrpc\":\"2.0\",\"result\":{\"tools\":[]}}\n\n",
886            "application/json",
887        )
888        .expect("sse fallback should parse");
889        assert_eq!(parsed["result"]["tools"], Value::Array(Vec::new()));
890    }
891
892    #[test]
893    fn raw_tool_parses_input_schema_from_camel_case_key() {
894        let parsed: RawTool = serde_json::from_value(serde_json::json!({
895            "name": "SEARCH_TOOLS",
896            "description": "Search available tools",
897            "inputSchema": { "type": "object", "properties": { "limit": { "type": "number" } } }
898        }))
899        .expect("raw tool should deserialize");
900
901        assert_eq!(parsed.name, "SEARCH_TOOLS");
902        assert_eq!(
903            parsed
904                .input_schema
905                .as_ref()
906                .and_then(|value| value.get("type"))
907                .and_then(Value::as_str),
908            Some("object")
909        );
910    }
911
/// A `resource` content item with no `resource` object is passed over, and the text of
/// the next JSON resource item is returned.
#[test]
fn extract_json_resource_text_skips_resource_items_without_resource_payload() {
    let tool_result = serde_json::json!({
        "content": [
            { "type": "resource" },
            {
                "type": "resource",
                "resource": {
                    "mimeType": "application/json",
                    "text": "{\"ok\":true}"
                }
            }
        ]
    });

    let extracted = extract_json_resource_text(&tool_result);

    let expected = Some("{\"ok\":true}".to_string());
    assert_eq!(extracted, expected);
}
932
/// The wire value `"user_token"` for `connectType` deserializes to
/// `RuntimeIntegrationConnectType::UserToken`.
#[test]
fn runtime_integration_record_parses_user_token_connect_type() {
    let wire_record = serde_json::json!({
        "id": "convex-int",
        "name": "Convex",
        "url": "https://convex.example.com/mcp",
        "category": "gateway_remote_mcp",
        "connectType": "user_token"
    });

    let record: RuntimeIntegrationRecord =
        serde_json::from_value(wire_record).expect("record should deserialize");

    assert_eq!(record.connect_type, RuntimeIntegrationConnectType::UserToken);
}
949
/// A `connectType` value that is not a known variant (`"api_key"` here) makes
/// deserialization fail with an "unknown variant" error.
#[test]
fn runtime_integration_record_rejects_unknown_connect_type() {
    let wire_record = serde_json::json!({
        "id": "convex-int",
        "name": "Convex",
        "url": "https://convex.example.com/mcp",
        "category": "gateway_remote_mcp",
        "connectType": "api_key"
    });

    let failure = serde_json::from_value::<RuntimeIntegrationRecord>(wire_record)
        .expect_err("record should reject unknown connect type");

    let message = failure.to_string();
    assert!(message.contains("unknown variant"));
}
963
/// Selects how the mock MCP server reports a missing session on `tools/list`.
#[derive(Debug, Clone, Copy)]
enum SessionFailureKind {
    /// Fail at the HTTP layer: status 400 with a plain-text
    /// "Session stale-session not found" body.
    HttpNotFound,
    /// Fail at the JSON-RPC layer: HTTP 200 carrying an `error` object with
    /// code -32000 and a "Session stale-session not found" message.
    JsonRpcNotFound,
}
969
970    fn create_test_mcp(server: &MockServer) -> KontextMcp {
971        KontextMcp::new(KontextMcpConfig {
972            client_session_id: "client-session".to_string(),
973            client_id: "client-id".to_string(),
974            redirect_uri: "http://localhost:3333/callback".to_string(),
975            url: Some(format!("{}/mcp", server.uri())),
976            server: Some(server.uri()),
977            client_secret: None,
978            scope: None,
979            resource: None,
980            session_key: None,
981            integration_ui_url: None,
982            integration_return_to: None,
983            auth_timeout_seconds: None,
984            open_connect_page_on_login: None,
985            token_cache_path: None,
986        })
987    }
988
989    async fn mount_retrying_tools_list_server(
990        server: &MockServer,
991        failure_kind: SessionFailureKind,
992        recover_on_retry: bool,
993    ) -> (Arc<AtomicUsize>, Arc<AtomicUsize>) {
994        let initialize_calls = Arc::new(AtomicUsize::new(0));
995        let tools_list_calls = Arc::new(AtomicUsize::new(0));
996
997        let initialize_calls_for_mock = Arc::clone(&initialize_calls);
998        let tools_list_calls_for_mock = Arc::clone(&tools_list_calls);
999        Mock::given(method("POST"))
1000            .and(path("/mcp"))
1001            .respond_with(move |request: &wiremock::Request| {
1002                let payload: Value = serde_json::from_slice(&request.body)
1003                    .expect("MCP requests should be valid JSON payloads");
1004                let method = payload
1005                    .get("method")
1006                    .and_then(Value::as_str)
1007                    .expect("MCP requests should include method");
1008
1009                match method {
1010                    "initialize" => {
1011                        let initialize_call =
1012                            initialize_calls_for_mock.fetch_add(1, Ordering::SeqCst);
1013                        let session_id = if initialize_call == 0 {
1014                            "stale-session"
1015                        } else {
1016                            "fresh-session"
1017                        };
1018
1019                        ResponseTemplate::new(200)
1020                            .append_header("Mcp-Session-Id", session_id)
1021                            .set_body_json(json!({
1022                                "jsonrpc": "2.0",
1023                                "id": "initialize",
1024                                "result": {
1025                                    "sessionId": session_id
1026                                }
1027                            }))
1028                    }
1029                    "notifications/initialized" => {
1030                        ResponseTemplate::new(200).set_body_json(json!({
1031                            "jsonrpc": "2.0",
1032                            "result": {}
1033                        }))
1034                    }
1035                    "tools/list" => {
1036                        let tools_list_call =
1037                            tools_list_calls_for_mock.fetch_add(1, Ordering::SeqCst);
1038                        if tools_list_call == 0 || !recover_on_retry {
1039                            return match failure_kind {
1040                                SessionFailureKind::HttpNotFound => ResponseTemplate::new(400)
1041                                    .set_body_string(
1042                                        "Request rejected: Session stale-session not found",
1043                                    ),
1044                                SessionFailureKind::JsonRpcNotFound => ResponseTemplate::new(200)
1045                                    .set_body_json(json!({
1046                                        "jsonrpc": "2.0",
1047                                        "id": "list-tools",
1048                                        "error": {
1049                                            "code": -32000,
1050                                            "message": "Session stale-session not found"
1051                                        }
1052                                    })),
1053                            };
1054                        }
1055
1056                        ResponseTemplate::new(200).set_body_json(json!({
1057                            "jsonrpc": "2.0",
1058                            "id": "list-tools",
1059                            "result": {
1060                                "tools": [{
1061                                    "name": "github.search",
1062                                    "description": "Search GitHub",
1063                                    "inputSchema": {
1064                                        "type": "object"
1065                                    }
1066                                }]
1067                            }
1068                        }))
1069                    }
1070                    _ => ResponseTemplate::new(500),
1071                }
1072            })
1073            .mount(server)
1074            .await;
1075
1076        (initialize_calls, tools_list_calls)
1077    }
1078
1079    #[tokio::test]
1080    async fn list_tools_recovers_from_http_session_not_found() {
1081        let server = MockServer::start().await;
1082        let (initialize_calls, tools_list_calls) =
1083            mount_retrying_tools_list_server(&server, SessionFailureKind::HttpNotFound, true).await;
1084
1085        let mcp = create_test_mcp(&server);
1086        let tools = mcp
1087            .list_tools_with_access_token("access-token")
1088            .await
1089            .expect("HTTP session-not-found should recover");
1090
1091        assert_eq!(
1092            tools,
1093            vec![KontextTool {
1094                id: "github.search".to_string(),
1095                name: "github.search".to_string(),
1096                description: Some("Search GitHub".to_string()),
1097                input_schema: Some(json!({
1098                    "type": "object"
1099                })),
1100                server: None,
1101            }]
1102        );
1103        assert_eq!(initialize_calls.load(Ordering::SeqCst), 2);
1104        assert_eq!(tools_list_calls.load(Ordering::SeqCst), 2);
1105    }
1106
1107    #[tokio::test]
1108    async fn list_tools_recovers_from_jsonrpc_session_not_found() {
1109        let server = MockServer::start().await;
1110        let (initialize_calls, tools_list_calls) =
1111            mount_retrying_tools_list_server(&server, SessionFailureKind::JsonRpcNotFound, true)
1112                .await;
1113
1114        let mcp = create_test_mcp(&server);
1115        let tools = mcp
1116            .list_tools_with_access_token("access-token")
1117            .await
1118            .expect("JSON-RPC session-not-found should recover");
1119
1120        assert_eq!(tools.len(), 1);
1121        assert_eq!(initialize_calls.load(Ordering::SeqCst), 2);
1122        assert_eq!(tools_list_calls.load(Ordering::SeqCst), 2);
1123    }
1124
1125    #[tokio::test]
1126    async fn list_tools_stale_session_retry_happens_once() {
1127        let server = MockServer::start().await;
1128        let (initialize_calls, tools_list_calls) =
1129            mount_retrying_tools_list_server(&server, SessionFailureKind::HttpNotFound, false)
1130                .await;
1131
1132        let mcp = create_test_mcp(&server);
1133        let err = mcp
1134            .list_tools_with_access_token("access-token")
1135            .await
1136            .expect_err("recovery should fail when stale session persists");
1137
1138        assert!(err.to_string().contains("Session stale-session not found"));
1139        assert_eq!(initialize_calls.load(Ordering::SeqCst), 2);
1140        assert_eq!(tools_list_calls.load(Ordering::SeqCst), 2);
1141    }
1142}