// kontext_dev_sdk/client.rs

use std::sync::Arc;

use serde_json::{Map, Value, json};
use tokio::sync::RwLock;

use crate::KontextDevError;
use crate::mcp::{
    GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
    KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
};
11
/// Name of the gateway meta tool this client calls to list the tools behind the gateway.
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Name of the gateway meta tool this client calls to run another tool by id.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// Name of the gateway meta tool that is invoked directly rather than through
/// `EXECUTE_TOOL`, and that is re-exposed alongside the gateway tools.
const META_REQUEST_CAPABILITY: &str = "REQUEST_CAPABILITY";
15
/// Lifecycle state tracked by [`KontextClient`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ClientState {
    /// Initial state of a new client, and the state after `disconnect`.
    Idle,
    /// Set at the start of `connect`, before the tool listing request.
    Connecting,
    /// The tool listing request issued by `connect` succeeded.
    Ready,
    /// `connect` failed with an error classified as authentication-related.
    NeedsAuth,
    /// `connect` failed with any other error.
    Failed,
}
24
/// Settings consumed by [`KontextClient::new`].
///
/// Every field is forwarded unchanged to the underlying MCP configuration
/// (`server_url` is forwarded as its `server` field).
#[derive(Clone)]
pub struct KontextClientConfig {
    pub client_id: String,
    pub redirect_uri: String,
    pub url: Option<String>,
    pub server_url: Option<String>,
    pub client_secret: Option<String>,
    pub scope: Option<String>,
    pub resource: Option<String>,
    pub integration_ui_url: Option<String>,
    pub integration_return_to: Option<String>,
    pub auth_timeout_seconds: Option<i64>,
    pub token_cache_path: Option<String>,
}

// Hand-written instead of derived so that `client_secret` never ends up in
// logs or panic messages through `{:?}` formatting.
impl std::fmt::Debug for KontextClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("KontextClientConfig")
            .field("client_id", &self.client_id)
            .field("redirect_uri", &self.redirect_uri)
            .field("url", &self.url)
            .field("server_url", &self.server_url)
            .field(
                "client_secret",
                // Keep the Some/None distinction visible, hide the value.
                &self.client_secret.as_ref().map(|_| "<redacted>"),
            )
            .field("scope", &self.scope)
            .field("resource", &self.resource)
            .field("integration_ui_url", &self.integration_ui_url)
            .field("integration_return_to", &self.integration_return_to)
            .field("auth_timeout_seconds", &self.auth_timeout_seconds)
            .field("token_cache_path", &self.token_cache_path)
            .finish()
    }
}
39
/// One integration as reported by [`KontextClient::integrations_list`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IntegrationInfo {
    /// Integration (server) identifier.
    pub id: String,
    /// Display name; the gateway-derived listing falls back to `id` when no name is given.
    pub name: String,
    /// Whether the integration is currently connected.
    pub connected: bool,
    /// URL for connecting the integration; only populated for gateway-reported
    /// errors that have a matching elicitation.
    pub connect_url: Option<String>,
    /// Reason attached to a gateway-reported error, when one was given.
    pub reason: Option<String>,
}
48
49#[derive(Clone, Debug)]
50pub struct ToolResult {
51    pub content: String,
52    pub raw: serde_json::Value,
53}
54
/// A freshly created connect session together with the URL that opens it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConnectSessionResult {
    /// Integration connect URL built for `session_id`.
    pub connect_url: String,
    /// Identifier of the created connect session.
    pub session_id: String,
    /// Expiry of the session, passed through as the string the server returned.
    pub expires_at: String,
}
61
62#[derive(Clone)]
63pub struct KontextClient {
64    state: Arc<RwLock<ClientState>>,
65    mcp: KontextMcp,
66    meta_tool_mode: Arc<RwLock<Option<bool>>>,
67}
68
69impl KontextClient {
70    pub fn new(config: KontextClientConfig) -> Self {
71        let mcp = KontextMcp::new(KontextMcpConfig {
72            client_id: config.client_id,
73            redirect_uri: config.redirect_uri,
74            url: config.url,
75            server: config.server_url,
76            client_secret: config.client_secret,
77            scope: config.scope,
78            resource: config.resource,
79            session_key: None,
80            integration_ui_url: config.integration_ui_url,
81            integration_return_to: config.integration_return_to,
82            auth_timeout_seconds: config.auth_timeout_seconds,
83            open_connect_page_on_login: Some(true),
84            token_cache_path: config.token_cache_path,
85        });
86
87        Self {
88            state: Arc::new(RwLock::new(ClientState::Idle)),
89            mcp,
90            meta_tool_mode: Arc::new(RwLock::new(None)),
91        }
92    }
93
94    pub async fn state(&self) -> ClientState {
95        *self.state.read().await
96    }
97
98    pub fn mcp(&self) -> &KontextMcp {
99        &self.mcp
100    }
101
102    pub async fn connect(&self) -> Result<(), KontextDevError> {
103        {
104            let mut state = self.state.write().await;
105            *state = ClientState::Connecting;
106        }
107
108        match self.mcp.list_tools().await {
109            Ok(tools) => {
110                let mut mode = self.meta_tool_mode.write().await;
111                *mode = Some(has_meta_gateway_tools(&tools));
112                let mut state = self.state.write().await;
113                *state = ClientState::Ready;
114                Ok(())
115            }
116            Err(err) => {
117                let mut state = self.state.write().await;
118                *state = if is_auth_error(&err) {
119                    ClientState::NeedsAuth
120                } else {
121                    ClientState::Failed
122                };
123                Err(err)
124            }
125        }
126    }
127
128    pub async fn disconnect(&self) {
129        let mut mode = self.meta_tool_mode.write().await;
130        *mode = None;
131        let mut state = self.state.write().await;
132        *state = ClientState::Idle;
133    }
134
135    pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
136        let session = self.mcp.authenticate_mcp().await?;
137        build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
138    }
139
140    pub async fn sign_in(&self) -> Result<(), KontextDevError> {
141        self.connect().await
142    }
143
144    pub async fn sign_out(&self) -> Result<(), KontextDevError> {
145        self.mcp.client().clear_token_cache()?;
146        self.disconnect().await;
147        Ok(())
148    }
149
150    pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
151        self.ensure_connected().await?;
152
153        if self.is_meta_tool_mode().await? {
154            let payload = self.fetch_gateway_tools(Some(100)).await?;
155            return Ok(parse_integration_status(&payload));
156        }
157
158        let records = self.mcp.list_integrations().await?;
159        Ok(records
160            .into_iter()
161            .map(|record| IntegrationInfo {
162                id: record.id,
163                name: record.name,
164                connected: record
165                    .connection
166                    .as_ref()
167                    .map(|c| c.connected)
168                    .unwrap_or(false),
169                connect_url: None,
170                reason: None,
171            })
172            .collect())
173    }
174
175    pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
176        self.ensure_connected().await?;
177
178        let mcp_tools = self.mcp.list_tools().await?;
179        let non_meta = mcp_tools
180            .iter()
181            .filter(|tool| !is_gateway_meta_tool(tool.name.as_str()))
182            .cloned()
183            .collect::<Vec<_>>();
184
185        if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
186            let mut mode = self.meta_tool_mode.write().await;
187            *mode = Some(false);
188            return Ok(non_meta);
189        }
190
191        let mut mode = self.meta_tool_mode.write().await;
192        *mode = Some(true);
193        drop(mode);
194
195        let payload = self.fetch_gateway_tools(Some(100)).await?;
196        let mut tools = payload.tools;
197        append_request_capability_tool(&mut tools, mcp_tools.as_slice());
198        Ok(tools)
199    }
200
201    pub async fn tools_execute(
202        &self,
203        tool_id: &str,
204        args: Option<serde_json::Map<String, serde_json::Value>>,
205    ) -> Result<ToolResult, KontextDevError> {
206        self.ensure_connected().await?;
207
208        let raw = if self.is_meta_tool_mode().await? {
209            if tool_id == META_REQUEST_CAPABILITY {
210                self.mcp.call_tool(META_REQUEST_CAPABILITY, args).await?
211            } else {
212                let mut execute_args = Map::new();
213                execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
214                execute_args.insert(
215                    "tool_arguments".to_string(),
216                    Value::Object(args.unwrap_or_default()),
217                );
218                self.mcp
219                    .call_tool(META_EXECUTE_TOOL, Some(execute_args))
220                    .await?
221            }
222        } else {
223            self.mcp.call_tool(tool_id, args).await?
224        };
225
226        Ok(ToolResult {
227            content: extract_text_content(&raw),
228            raw,
229        })
230    }
231
232    async fn ensure_connected(&self) -> Result<(), KontextDevError> {
233        let state = self.state().await;
234        if state == ClientState::Ready {
235            return Ok(());
236        }
237        self.connect().await
238    }
239
240    async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
241        if let Some(mode) = *self.meta_tool_mode.read().await {
242            return Ok(mode);
243        }
244
245        let tools = self.mcp.list_tools().await?;
246        let mode = has_meta_gateway_tools(&tools);
247        let mut lock = self.meta_tool_mode.write().await;
248        *lock = Some(mode);
249        Ok(mode)
250    }
251
252    async fn fetch_gateway_tools(
253        &self,
254        limit: Option<u32>,
255    ) -> Result<GatewayToolsPayload, KontextDevError> {
256        let result = self
257            .mcp
258            .call_tool(
259                META_SEARCH_TOOLS,
260                Some({
261                    let mut args = Map::new();
262                    if let Some(limit) = limit {
263                        args.insert("limit".to_string(), json!(limit));
264                    }
265                    args
266                }),
267            )
268            .await?;
269        parse_gateway_tools_payload(&result)
270    }
271}
272
273fn is_gateway_meta_tool(tool_name: &str) -> bool {
274    matches!(
275        tool_name,
276        META_SEARCH_TOOLS | META_EXECUTE_TOOL | META_REQUEST_CAPABILITY
277    )
278}
279
280fn append_request_capability_tool(tools: &mut Vec<KontextTool>, mcp_tools: &[KontextTool]) {
281    if tools
282        .iter()
283        .any(|tool| tool.name == META_REQUEST_CAPABILITY)
284    {
285        return;
286    }
287    let Some(capability_tool) = mcp_tools
288        .iter()
289        .find(|tool| tool.name == META_REQUEST_CAPABILITY)
290    else {
291        return;
292    };
293
294    tools.push(capability_tool.clone());
295}
296
297pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
298    KontextClient::new(config)
299}
300
301async fn build_connect_session_result(
302    client: &crate::KontextDevClient,
303    gateway_access_token: &str,
304) -> Result<ConnectSessionResult, KontextDevError> {
305    let connect_session = client.create_connect_session(gateway_access_token).await?;
306    let connect_url = client.integration_connect_url(&connect_session.session_id)?;
307
308    Ok(ConnectSessionResult {
309        connect_url,
310        session_id: connect_session.session_id,
311        expires_at: connect_session.expires_at,
312    })
313}
314
315fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
316    let mut seen = std::collections::HashSet::<String>::new();
317    let mut out = Vec::new();
318
319    for tool in &payload.tools {
320        let Some(server) = tool.server.as_ref() else {
321            continue;
322        };
323        if !seen.insert(server.id.clone()) {
324            continue;
325        }
326        out.push(IntegrationInfo {
327            id: server.id.clone(),
328            name: server.name.clone().unwrap_or_else(|| server.id.clone()),
329            connected: true,
330            connect_url: None,
331            reason: None,
332        });
333    }
334
335    for GatewayToolError {
336        server_id,
337        server_name,
338        reason,
339    } in &payload.errors
340    {
341        if !seen.insert(server_id.clone()) {
342            continue;
343        }
344        let connect_url = payload.elicitations.iter().find_map(
345            |GatewayElicitation {
346                 url,
347                 integration_id,
348                 ..
349             }| {
350                if integration_id.as_deref() == Some(server_id.as_str()) {
351                    Some(url.clone())
352                } else {
353                    None
354                }
355            },
356        );
357        out.push(IntegrationInfo {
358            id: server_id.clone(),
359            name: server_name.clone().unwrap_or_else(|| server_id.clone()),
360            connected: false,
361            connect_url,
362            reason: reason.clone(),
363        });
364    }
365
366    out
367}
368
369fn is_auth_error(err: &KontextDevError) -> bool {
370    matches!(
371        err,
372        KontextDevError::OAuthCallbackTimeout { .. }
373            | KontextDevError::OAuthCallbackCancelled
374            | KontextDevError::MissingAuthorizationCode
375            | KontextDevError::OAuthCallbackError { .. }
376            | KontextDevError::InvalidOAuthState
377            | KontextDevError::TokenRequest { .. }
378            | KontextDevError::TokenExchange { .. }
379    )
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// One POST to the connect-session endpoint must yield a result whose
    /// connect URL embeds the session id returned by that same request.
    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        let server = MockServer::start().await;
        let session_id = "session-123";
        let expires_at = "2030-01-01T00:00:00Z";
        let access_token = "test-gateway-token";

        // `.expect(1)` makes the mock server fail the test unless the
        // endpoint is hit exactly once with the bearer token.
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {access_token}")))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "sessionId": session_id,
                "expiresAt": expires_at
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = crate::KontextDevClient::new(crate::KontextDevConfig {
            server: server.uri(),
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        });

        let result = build_connect_session_result(&client, access_token)
            .await
            .expect("connect session result should be built");

        assert_eq!(result.session_id, session_id);
        assert_eq!(result.expires_at, expires_at);
        assert_eq!(
            result.connect_url,
            format!("https://app.kontext.dev/oauth/connect?session={session_id}")
        );
    }
}