Skip to main content

kontext_dev_sdk/
client.rs

1use std::sync::Arc;
2
3use serde_json::{Map, Value, json};
4use tokio::sync::RwLock;
5
6use crate::KontextDevError;
7use crate::mcp::{
8    GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
9    KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
10};
11
/// Gateway meta tool used to list the tools reachable behind the gateway
/// (called with an optional `limit` argument).
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Gateway meta tool that runs another tool, addressed by `tool_id`, with `tool_arguments`.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// Gateway meta tool that is called directly (never wrapped in `EXECUTE_TOOL`) in
/// meta-tool mode, and is surfaced to callers alongside the gateway tool list.
const META_REQUEST_CAPABILITY: &str = "REQUEST_CAPABILITY";
15
/// Connection lifecycle of a `KontextClient`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// Freshly constructed or explicitly disconnected.
    Idle,
    /// A connection attempt (tool listing) is in flight.
    Connecting,
    /// The most recent connection attempt listed tools successfully.
    Ready,
    /// The most recent connection attempt failed with an authentication-flow error.
    NeedsAuth,
    /// The most recent connection attempt failed for any other reason.
    Failed,
}
24
/// Settings used to build a `KontextClient`.
///
/// `KontextClient::new` forwards every field to the MCP layer's configuration;
/// `server_url` is passed on as that configuration's `server` field.
#[derive(Debug, Clone)]
pub struct KontextClientConfig {
    /// Identifier for this client session (forwarded as `client_session_id`).
    pub client_session_id: String,
    /// OAuth client identifier.
    pub client_id: String,
    /// Redirect URI used by the authorization flow.
    pub redirect_uri: String,
    /// Optional URL (forwarded unchanged as `url`).
    pub url: Option<String>,
    /// Optional server location (forwarded as `server`).
    pub server_url: Option<String>,
    /// Optional OAuth client secret.
    pub client_secret: Option<String>,
    /// Optional requested scope.
    pub scope: Option<String>,
    /// Optional resource identifier.
    pub resource: Option<String>,
    /// Optional integration UI base URL.
    pub integration_ui_url: Option<String>,
    /// Optional location to return to after connecting an integration.
    pub integration_return_to: Option<String>,
    /// Optional authentication timeout, in seconds.
    pub auth_timeout_seconds: Option<i64>,
    /// Optional token cache location.
    pub token_cache_path: Option<String>,
}
40
/// Connection status of one integration as reported by the client.
#[derive(Debug, Clone)]
pub struct IntegrationInfo {
    /// Integration (server) identifier.
    pub id: String,
    /// Human-readable name; falls back to `id` when the gateway reports none.
    pub name: String,
    /// Whether the integration is reported as connected.
    pub connected: bool,
    /// URL the user can visit to connect the integration, when one was supplied.
    pub connect_url: Option<String>,
    /// Reason the integration is unavailable, when one was supplied.
    pub reason: Option<String>,
}
49
50#[derive(Clone, Debug)]
51pub struct ToolResult {
52    pub content: String,
53    pub raw: serde_json::Value,
54}
55
/// A freshly created integration connect session and the URL that opens it.
#[derive(Debug, Clone)]
pub struct ConnectSessionResult {
    /// Connect page URL built for `session_id`.
    pub connect_url: String,
    /// Identifier of the connect session created on the server.
    pub session_id: String,
    /// Expiry timestamp exactly as returned by the server (not parsed here).
    pub expires_at: String,
}
62
63#[derive(Clone)]
64pub struct KontextClient {
65    state: Arc<RwLock<ClientState>>,
66    mcp: KontextMcp,
67    meta_tool_mode: Arc<RwLock<Option<bool>>>,
68}
69
70impl KontextClient {
71    pub fn new(config: KontextClientConfig) -> Self {
72        let mcp = KontextMcp::new(KontextMcpConfig {
73            client_session_id: config.client_session_id,
74            client_id: config.client_id,
75            redirect_uri: config.redirect_uri,
76            url: config.url,
77            server: config.server_url,
78            client_secret: config.client_secret,
79            scope: config.scope,
80            resource: config.resource,
81            session_key: None,
82            integration_ui_url: config.integration_ui_url,
83            integration_return_to: config.integration_return_to,
84            auth_timeout_seconds: config.auth_timeout_seconds,
85            open_connect_page_on_login: Some(true),
86            token_cache_path: config.token_cache_path,
87        });
88
89        Self {
90            state: Arc::new(RwLock::new(ClientState::Idle)),
91            mcp,
92            meta_tool_mode: Arc::new(RwLock::new(None)),
93        }
94    }
95
96    pub async fn state(&self) -> ClientState {
97        *self.state.read().await
98    }
99
100    pub fn mcp(&self) -> &KontextMcp {
101        &self.mcp
102    }
103
104    pub async fn connect(&self) -> Result<(), KontextDevError> {
105        {
106            let mut state = self.state.write().await;
107            *state = ClientState::Connecting;
108        }
109
110        match self.mcp.list_tools().await {
111            Ok(tools) => {
112                let mut mode = self.meta_tool_mode.write().await;
113                *mode = Some(has_meta_gateway_tools(&tools));
114                let mut state = self.state.write().await;
115                *state = ClientState::Ready;
116                Ok(())
117            }
118            Err(err) => {
119                let mut state = self.state.write().await;
120                *state = if is_auth_error(&err) {
121                    ClientState::NeedsAuth
122                } else {
123                    ClientState::Failed
124                };
125                Err(err)
126            }
127        }
128    }
129
130    pub async fn disconnect(&self) {
131        let mut mode = self.meta_tool_mode.write().await;
132        *mode = None;
133        let mut state = self.state.write().await;
134        *state = ClientState::Idle;
135    }
136
137    pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
138        let session = self.mcp.authenticate_mcp().await?;
139        build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
140    }
141
142    pub async fn sign_in(&self) -> Result<(), KontextDevError> {
143        self.connect().await
144    }
145
146    pub async fn sign_out(&self) -> Result<(), KontextDevError> {
147        self.mcp.client().clear_token_cache()?;
148        self.disconnect().await;
149        Ok(())
150    }
151
152    pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
153        self.ensure_connected().await?;
154
155        if self.is_meta_tool_mode().await? {
156            let payload = self.fetch_gateway_tools(Some(100)).await?;
157            return Ok(parse_integration_status(&payload));
158        }
159
160        let records = self.mcp.list_integrations().await?;
161        Ok(records
162            .into_iter()
163            .map(|record| IntegrationInfo {
164                id: record.id,
165                name: record.name,
166                connected: record
167                    .connection
168                    .as_ref()
169                    .map(|c| c.connected)
170                    .unwrap_or(false),
171                connect_url: None,
172                reason: None,
173            })
174            .collect())
175    }
176
177    pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
178        self.ensure_connected().await?;
179
180        let mcp_tools = self.mcp.list_tools().await?;
181        let non_meta = mcp_tools
182            .iter()
183            .filter(|tool| !is_gateway_meta_tool(tool.name.as_str()))
184            .cloned()
185            .collect::<Vec<_>>();
186
187        if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
188            let mut mode = self.meta_tool_mode.write().await;
189            *mode = Some(false);
190            return Ok(non_meta);
191        }
192
193        let mut mode = self.meta_tool_mode.write().await;
194        *mode = Some(true);
195        drop(mode);
196
197        let payload = self.fetch_gateway_tools(Some(100)).await?;
198        let mut tools = payload.tools;
199        append_request_capability_tool(&mut tools, mcp_tools.as_slice());
200        Ok(tools)
201    }
202
203    pub async fn tools_execute(
204        &self,
205        tool_id: &str,
206        args: Option<serde_json::Map<String, serde_json::Value>>,
207    ) -> Result<ToolResult, KontextDevError> {
208        self.ensure_connected().await?;
209
210        let raw = if self.is_meta_tool_mode().await? {
211            if tool_id == META_REQUEST_CAPABILITY {
212                self.mcp.call_tool(META_REQUEST_CAPABILITY, args).await?
213            } else {
214                let mut execute_args = Map::new();
215                execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
216                execute_args.insert(
217                    "tool_arguments".to_string(),
218                    Value::Object(args.unwrap_or_default()),
219                );
220                self.mcp
221                    .call_tool(META_EXECUTE_TOOL, Some(execute_args))
222                    .await?
223            }
224        } else {
225            self.mcp.call_tool(tool_id, args).await?
226        };
227
228        Ok(ToolResult {
229            content: extract_text_content(&raw),
230            raw,
231        })
232    }
233
234    async fn ensure_connected(&self) -> Result<(), KontextDevError> {
235        let state = self.state().await;
236        if state == ClientState::Ready {
237            return Ok(());
238        }
239        self.connect().await
240    }
241
242    async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
243        if let Some(mode) = *self.meta_tool_mode.read().await {
244            return Ok(mode);
245        }
246
247        let tools = self.mcp.list_tools().await?;
248        let mode = has_meta_gateway_tools(&tools);
249        let mut lock = self.meta_tool_mode.write().await;
250        *lock = Some(mode);
251        Ok(mode)
252    }
253
254    async fn fetch_gateway_tools(
255        &self,
256        limit: Option<u32>,
257    ) -> Result<GatewayToolsPayload, KontextDevError> {
258        let result = self
259            .mcp
260            .call_tool(
261                META_SEARCH_TOOLS,
262                Some({
263                    let mut args = Map::new();
264                    if let Some(limit) = limit {
265                        args.insert("limit".to_string(), json!(limit));
266                    }
267                    args
268                }),
269            )
270            .await?;
271        parse_gateway_tools_payload(&result)
272    }
273}
274
275fn is_gateway_meta_tool(tool_name: &str) -> bool {
276    matches!(
277        tool_name,
278        META_SEARCH_TOOLS | META_EXECUTE_TOOL | META_REQUEST_CAPABILITY
279    )
280}
281
282fn append_request_capability_tool(tools: &mut Vec<KontextTool>, mcp_tools: &[KontextTool]) {
283    if tools
284        .iter()
285        .any(|tool| tool.name == META_REQUEST_CAPABILITY)
286    {
287        return;
288    }
289    let Some(capability_tool) = mcp_tools
290        .iter()
291        .find(|tool| tool.name == META_REQUEST_CAPABILITY)
292    else {
293        return;
294    };
295
296    tools.push(capability_tool.clone());
297}
298
299pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
300    KontextClient::new(config)
301}
302
303async fn build_connect_session_result(
304    client: &crate::KontextDevClient,
305    gateway_access_token: &str,
306) -> Result<ConnectSessionResult, KontextDevError> {
307    let connect_session = client.create_connect_session(gateway_access_token).await?;
308    let connect_url = client.integration_connect_url(&connect_session.session_id)?;
309
310    Ok(ConnectSessionResult {
311        connect_url,
312        session_id: connect_session.session_id,
313        expires_at: connect_session.expires_at,
314    })
315}
316
317fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
318    let mut seen = std::collections::HashSet::<String>::new();
319    let mut out = Vec::new();
320
321    for tool in &payload.tools {
322        let Some(server) = tool.server.as_ref() else {
323            continue;
324        };
325        if !seen.insert(server.id.clone()) {
326            continue;
327        }
328        out.push(IntegrationInfo {
329            id: server.id.clone(),
330            name: server.name.clone().unwrap_or_else(|| server.id.clone()),
331            connected: true,
332            connect_url: None,
333            reason: None,
334        });
335    }
336
337    for GatewayToolError {
338        server_id,
339        server_name,
340        reason,
341    } in &payload.errors
342    {
343        if !seen.insert(server_id.clone()) {
344            continue;
345        }
346        let connect_url = payload.elicitations.iter().find_map(
347            |GatewayElicitation {
348                 url,
349                 integration_id,
350                 ..
351             }| {
352                if integration_id.as_deref() == Some(server_id.as_str()) {
353                    Some(url.clone())
354                } else {
355                    None
356                }
357            },
358        );
359        out.push(IntegrationInfo {
360            id: server_id.clone(),
361            name: server_name.clone().unwrap_or_else(|| server_id.clone()),
362            connected: false,
363            connect_url,
364            reason: reason.clone(),
365        });
366    }
367
368    out
369}
370
371fn is_auth_error(err: &KontextDevError) -> bool {
372    matches!(
373        err,
374        KontextDevError::OAuthCallbackTimeout { .. }
375            | KontextDevError::OAuthCallbackCancelled
376            | KontextDevError::MissingAuthorizationCode
377            | KontextDevError::OAuthCallbackError { .. }
378            | KontextDevError::InvalidOAuthState
379            | KontextDevError::TokenRequest { .. }
380            | KontextDevError::TokenExchange { .. }
381    )
382}
383
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Exactly one connect session must be created (one POST hits the mock). The
    /// returned connect URL must embed that same session id.
    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        const SESSION_ID: &str = "session-123";
        const EXPIRES_AT: &str = "2030-01-01T00:00:00Z";
        const ACCESS_TOKEN: &str = "test-gateway-token";

        let mock_server = MockServer::start().await;

        // The request must carry the gateway token as a bearer credential.
        let session_body = serde_json::json!({
            "sessionId": SESSION_ID,
            "expiresAt": EXPIRES_AT
        });
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {ACCESS_TOKEN}")))
            .respond_with(ResponseTemplate::new(200).set_body_json(session_body))
            .expect(1)
            .mount(&mock_server)
            .await;

        let dev_config = crate::KontextDevConfig {
            server: mock_server.uri(),
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        };
        let dev_client = crate::KontextDevClient::new(dev_config);

        let built = build_connect_session_result(&dev_client, ACCESS_TOKEN)
            .await
            .expect("connect session result should be built");

        let expected_url = format!("https://app.kontext.dev/oauth/connect?session={SESSION_ID}");
        assert_eq!(built.session_id, SESSION_ID);
        assert_eq!(built.expires_at, EXPIRES_AT);
        assert_eq!(built.connect_url, expected_url);
    }
}