Skip to main content

kontext_dev_sdk/
client.rs

1use std::sync::Arc;
2
3use serde_json::{Map, Value, json};
4use tokio::sync::RwLock;
5
6use crate::KontextDevError;
7use crate::mcp::{
8    GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
9    KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
10};
11use crate::prompt_guidance::{KontextPromptGuidance, build_kontext_prompt_guidance};
12
/// Name of the gateway meta tool used to discover tools (see `fetch_gateway_tools`).
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Name of the gateway meta tool that runs another tool by `tool_id` with `tool_arguments`.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// Name of the gateway meta tool that `tools_execute` forwards as-is in meta-tool mode.
const META_REQUEST_CAPABILITY: &str = "REQUEST_CAPABILITY";
16
/// Lifecycle state of a `KontextClient` connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// Freshly constructed, or reset by `disconnect`.
    Idle,
    /// A `connect` call is in progress.
    Connecting,
    /// The most recent `connect` listed the server's tools successfully.
    Ready,
    /// The most recent `connect` failed with an OAuth/token error.
    NeedsAuth,
    /// The most recent `connect` failed with any other error.
    Failed,
}
25
/// Settings used by `KontextClient::new` to build the underlying MCP client.
///
/// Every field is forwarded to `KontextMcpConfig` under the same name, except
/// `server_url`, which becomes `KontextMcpConfig::server`.
#[derive(Debug, Clone)]
pub struct KontextClientConfig {
    /// Identifier for this client session.
    pub client_session_id: String,
    /// Client identifier presented during authentication.
    pub client_id: String,
    /// Redirect URI used by the authentication flow.
    pub redirect_uri: String,
    /// Explicit MCP endpoint URL, if any.
    pub url: Option<String>,
    /// Base server URL (forwarded as `server`).
    pub server_url: Option<String>,
    /// Optional client secret.
    pub client_secret: Option<String>,
    /// Optional requested scope.
    pub scope: Option<String>,
    /// Optional resource identifier.
    pub resource: Option<String>,
    /// Optional base URL of the integration UI.
    pub integration_ui_url: Option<String>,
    /// Optional return-to target for the integration UI.
    pub integration_return_to: Option<String>,
    /// Optional authentication timeout; presumably in seconds, per the name.
    pub auth_timeout_seconds: Option<i64>,
    /// Optional filesystem path for the token cache.
    pub token_cache_path: Option<String>,
}
41
/// Connection status of one integration, as returned by `integrations_list`.
#[derive(Debug, Clone)]
pub struct IntegrationInfo {
    /// Integration (server) identifier.
    pub id: String,
    /// Display name; in meta-tool mode, `id` is used when the gateway reports no name.
    pub name: String,
    /// Whether the integration is currently usable.
    pub connected: bool,
    /// Connect URL from a matching gateway elicitation, when one exists.
    pub connect_url: Option<String>,
    /// Gateway-reported reason the integration is unavailable, when given.
    pub reason: Option<String>,
}
50
51#[derive(Clone, Debug)]
52pub struct ToolResult {
53    pub content: String,
54    pub raw: serde_json::Value,
55}
56
/// A freshly created connect session together with the URL that opens it.
#[derive(Debug, Clone)]
pub struct ConnectSessionResult {
    /// URL built by `integration_connect_url` for `session_id`.
    pub connect_url: String,
    /// Identifier of the connect session.
    pub session_id: String,
    /// Expiry timestamp exactly as returned by the server.
    pub expires_at: String,
}
63
64#[derive(Clone)]
65pub struct KontextClient {
66    state: Arc<RwLock<ClientState>>,
67    mcp: KontextMcp,
68    meta_tool_mode: Arc<RwLock<Option<bool>>>,
69}
70
71impl KontextClient {
72    pub fn new(config: KontextClientConfig) -> Self {
73        let mcp = KontextMcp::new(KontextMcpConfig {
74            client_session_id: config.client_session_id,
75            client_id: config.client_id,
76            redirect_uri: config.redirect_uri,
77            url: config.url,
78            server: config.server_url,
79            client_secret: config.client_secret,
80            scope: config.scope,
81            resource: config.resource,
82            session_key: None,
83            integration_ui_url: config.integration_ui_url,
84            integration_return_to: config.integration_return_to,
85            auth_timeout_seconds: config.auth_timeout_seconds,
86            open_connect_page_on_login: Some(true),
87            token_cache_path: config.token_cache_path,
88        });
89
90        Self {
91            state: Arc::new(RwLock::new(ClientState::Idle)),
92            mcp,
93            meta_tool_mode: Arc::new(RwLock::new(None)),
94        }
95    }
96
97    pub async fn state(&self) -> ClientState {
98        *self.state.read().await
99    }
100
101    pub fn mcp(&self) -> &KontextMcp {
102        &self.mcp
103    }
104
105    pub async fn connect(&self) -> Result<(), KontextDevError> {
106        {
107            let mut state = self.state.write().await;
108            *state = ClientState::Connecting;
109        }
110
111        match self.mcp.list_tools().await {
112            Ok(tools) => {
113                let mut mode = self.meta_tool_mode.write().await;
114                *mode = Some(has_meta_gateway_tools(&tools));
115                let mut state = self.state.write().await;
116                *state = ClientState::Ready;
117                Ok(())
118            }
119            Err(err) => {
120                let mut state = self.state.write().await;
121                *state = if is_auth_error(&err) {
122                    ClientState::NeedsAuth
123                } else {
124                    ClientState::Failed
125                };
126                Err(err)
127            }
128        }
129    }
130
131    pub async fn disconnect(&self) {
132        self.mcp.clear_cached_session().await;
133        let mut mode = self.meta_tool_mode.write().await;
134        *mode = None;
135        let mut state = self.state.write().await;
136        *state = ClientState::Idle;
137    }
138
139    pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
140        let session = self.mcp.authenticate_mcp().await?;
141        build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
142    }
143
144    pub async fn sign_in(&self) -> Result<(), KontextDevError> {
145        self.connect().await
146    }
147
148    pub async fn sign_out(&self) -> Result<(), KontextDevError> {
149        self.mcp.client().clear_token_cache()?;
150        self.disconnect().await;
151        Ok(())
152    }
153
154    pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
155        self.ensure_connected().await?;
156
157        if self.is_meta_tool_mode().await? {
158            let payload = self.fetch_gateway_tools(Some(100)).await?;
159            return Ok(parse_integration_status(&payload));
160        }
161
162        let records = self.mcp.list_integrations().await?;
163        Ok(records
164            .into_iter()
165            .map(|record| IntegrationInfo {
166                id: record.id,
167                name: record.name,
168                connected: record
169                    .connection
170                    .as_ref()
171                    .map(|c| c.connected)
172                    .unwrap_or(false),
173                connect_url: None,
174                reason: None,
175            })
176            .collect())
177    }
178
179    pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
180        self.ensure_connected().await?;
181
182        let mcp_tools = self.mcp.list_tools().await?;
183        let non_meta = mcp_tools
184            .iter()
185            .filter(|tool| !is_gateway_meta_tool(tool.name.as_str()))
186            .cloned()
187            .collect::<Vec<_>>();
188
189        if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
190            let mut mode = self.meta_tool_mode.write().await;
191            *mode = Some(false);
192            return Ok(non_meta);
193        }
194
195        let mut mode = self.meta_tool_mode.write().await;
196        *mode = Some(true);
197        drop(mode);
198
199        let payload = self.fetch_gateway_tools(Some(100)).await?;
200        let mut tools = payload.tools;
201        append_request_capability_tool(&mut tools, mcp_tools.as_slice());
202        Ok(tools)
203    }
204
205    pub async fn tools_execute(
206        &self,
207        tool_id: &str,
208        args: Option<serde_json::Map<String, serde_json::Value>>,
209    ) -> Result<ToolResult, KontextDevError> {
210        self.ensure_connected().await?;
211
212        let raw = if self.is_meta_tool_mode().await? {
213            if tool_id == META_REQUEST_CAPABILITY {
214                self.mcp.call_tool(META_REQUEST_CAPABILITY, args).await?
215            } else {
216                let mut execute_args = Map::new();
217                execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
218                execute_args.insert(
219                    "tool_arguments".to_string(),
220                    Value::Object(args.unwrap_or_default()),
221                );
222                self.mcp
223                    .call_tool(META_EXECUTE_TOOL, Some(execute_args))
224                    .await?
225            }
226        } else {
227            self.mcp.call_tool(tool_id, args).await?
228        };
229
230        Ok(ToolResult {
231            content: extract_text_content(&raw),
232            raw,
233        })
234    }
235
236    pub async fn prompt_guidance(&self) -> Result<KontextPromptGuidance, KontextDevError> {
237        let tools = self.tools_list().await?;
238        let integrations = self.integrations_list().await?;
239        let tool_names = tools
240            .into_iter()
241            .map(|tool| tool.name)
242            .collect::<Vec<String>>();
243
244        Ok(build_kontext_prompt_guidance(
245            tool_names.as_slice(),
246            integrations.as_slice(),
247        ))
248    }
249
250    async fn ensure_connected(&self) -> Result<(), KontextDevError> {
251        let state = self.state().await;
252        if state == ClientState::Ready {
253            return Ok(());
254        }
255        self.connect().await
256    }
257
258    async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
259        if let Some(mode) = *self.meta_tool_mode.read().await {
260            return Ok(mode);
261        }
262
263        let tools = self.mcp.list_tools().await?;
264        let mode = has_meta_gateway_tools(&tools);
265        let mut lock = self.meta_tool_mode.write().await;
266        *lock = Some(mode);
267        Ok(mode)
268    }
269
270    async fn fetch_gateway_tools(
271        &self,
272        limit: Option<u32>,
273    ) -> Result<GatewayToolsPayload, KontextDevError> {
274        let result = self
275            .mcp
276            .call_tool(
277                META_SEARCH_TOOLS,
278                Some({
279                    let mut args = Map::new();
280                    if let Some(limit) = limit {
281                        args.insert("limit".to_string(), json!(limit));
282                    }
283                    args
284                }),
285            )
286            .await?;
287        parse_gateway_tools_payload(&result)
288    }
289}
290
291fn is_gateway_meta_tool(tool_name: &str) -> bool {
292    matches!(
293        tool_name,
294        META_SEARCH_TOOLS | META_EXECUTE_TOOL | META_REQUEST_CAPABILITY
295    )
296}
297
298fn append_request_capability_tool(tools: &mut Vec<KontextTool>, mcp_tools: &[KontextTool]) {
299    if tools
300        .iter()
301        .any(|tool| tool.name == META_REQUEST_CAPABILITY)
302    {
303        return;
304    }
305    let Some(capability_tool) = mcp_tools
306        .iter()
307        .find(|tool| tool.name == META_REQUEST_CAPABILITY)
308    else {
309        return;
310    };
311
312    tools.push(capability_tool.clone());
313}
314
315pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
316    KontextClient::new(config)
317}
318
319async fn build_connect_session_result(
320    client: &crate::KontextDevClient,
321    gateway_access_token: &str,
322) -> Result<ConnectSessionResult, KontextDevError> {
323    let connect_session = client.create_connect_session(gateway_access_token).await?;
324    let connect_url = client.integration_connect_url(&connect_session.session_id)?;
325
326    Ok(ConnectSessionResult {
327        connect_url,
328        session_id: connect_session.session_id,
329        expires_at: connect_session.expires_at,
330    })
331}
332
333fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
334    let mut seen = std::collections::HashSet::<String>::new();
335    let mut out = Vec::new();
336
337    for tool in &payload.tools {
338        let Some(server) = tool.server.as_ref() else {
339            continue;
340        };
341        if !seen.insert(server.id.clone()) {
342            continue;
343        }
344        out.push(IntegrationInfo {
345            id: server.id.clone(),
346            name: server.name.clone().unwrap_or_else(|| server.id.clone()),
347            connected: true,
348            connect_url: None,
349            reason: None,
350        });
351    }
352
353    for GatewayToolError {
354        server_id,
355        server_name,
356        reason,
357    } in &payload.errors
358    {
359        if !seen.insert(server_id.clone()) {
360            continue;
361        }
362        let connect_url = payload.elicitations.iter().find_map(
363            |GatewayElicitation {
364                 url,
365                 integration_id,
366                 ..
367             }| {
368                if integration_id.as_deref() == Some(server_id.as_str()) {
369                    Some(url.clone())
370                } else {
371                    None
372                }
373            },
374        );
375        out.push(IntegrationInfo {
376            id: server_id.clone(),
377            name: server_name.clone().unwrap_or_else(|| server_id.clone()),
378            connected: false,
379            connect_url,
380            reason: reason.clone(),
381        });
382    }
383
384    out
385}
386
387fn is_auth_error(err: &KontextDevError) -> bool {
388    matches!(
389        err,
390        KontextDevError::OAuthCallbackTimeout { .. }
391            | KontextDevError::OAuthCallbackCancelled
392            | KontextDevError::MissingAuthorizationCode
393            | KontextDevError::OAuthCallbackError { .. }
394            | KontextDevError::InvalidOAuthState
395            | KontextDevError::TokenRequest { .. }
396            | KontextDevError::TokenExchange { .. }
397    )
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{body_partial_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Mounts a `POST /mcp` mock for JSON-RPC requests whose body carries
    /// `"method": rpc_method`. The mock must be hit exactly `times` times.
    async fn mount_mcp_method(
        server: &MockServer,
        rpc_method: &str,
        response: ResponseTemplate,
        times: u64,
    ) {
        Mock::given(method("POST"))
            .and(path("/mcp"))
            .and(body_partial_json(serde_json::json!({ "method": rpc_method })))
            .respond_with(response)
            .expect(times)
            .mount(server)
            .await;
    }

    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        let mock_server = MockServer::start().await;
        let expected_session = "session-123";
        let expected_expiry = "2030-01-01T00:00:00Z";
        let gateway_token = "test-gateway-token";

        // Exactly one connect session may be created, authorised with the token.
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {gateway_token}")))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "sessionId": expected_session,
                "expiresAt": expected_expiry
            })))
            .expect(1)
            .mount(&mock_server)
            .await;

        let dev_client = crate::KontextDevClient::new(crate::KontextDevConfig {
            server: mock_server.uri(),
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        });

        let outcome = build_connect_session_result(&dev_client, gateway_token)
            .await
            .expect("connect session result should be built");

        assert_eq!(outcome.session_id, expected_session);
        assert_eq!(outcome.expires_at, expected_expiry);
        assert_eq!(
            outcome.connect_url,
            format!("https://app.kontext.dev/oauth/connect?session={expected_session}")
        );
    }

    #[tokio::test]
    async fn disconnect_clears_cached_mcp_session_state() {
        let mock_server = MockServer::start().await;
        let gateway_token = "test-gateway-token";

        // Every handshake step must run twice: once for the first `tools/list`,
        // and again after `disconnect` has dropped the cached session.
        mount_mcp_method(
            &mock_server,
            "initialize",
            ResponseTemplate::new(200)
                .append_header("Mcp-Session-Id", "session-123")
                .set_body_json(serde_json::json!({
                    "jsonrpc": "2.0",
                    "id": "initialize",
                    "result": {
                        "sessionId": "session-123"
                    }
                })),
            2,
        )
        .await;

        mount_mcp_method(
            &mock_server,
            "notifications/initialized",
            ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "jsonrpc": "2.0",
                "result": {}
            })),
            2,
        )
        .await;

        mount_mcp_method(
            &mock_server,
            "tools/list",
            ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "jsonrpc": "2.0",
                "id": "list-tools",
                "result": {
                    "tools": []
                }
            })),
            2,
        )
        .await;

        let kontext = KontextClient::new(KontextClientConfig {
            client_session_id: "client-session".to_string(),
            client_id: "client-id".to_string(),
            redirect_uri: "http://localhost:3333/callback".to_string(),
            url: Some(format!("{}/mcp", mock_server.uri())),
            server_url: Some(mock_server.uri()),
            client_secret: None,
            scope: None,
            resource: None,
            integration_ui_url: None,
            integration_return_to: None,
            auth_timeout_seconds: None,
            token_cache_path: None,
        });

        kontext
            .mcp()
            .list_tools_with_access_token(gateway_token)
            .await
            .expect("first tools/list should initialize and succeed");
        kontext.disconnect().await;
        kontext
            .mcp()
            .list_tools_with_access_token(gateway_token)
            .await
            .expect("second tools/list should re-initialize after disconnect");
    }
}