Skip to main content

kontext_dev_sdk/
client.rs

1use std::sync::Arc;
2
3use serde_json::{Map, Value, json};
4use tokio::sync::RwLock;
5
6use crate::KontextDevError;
7use crate::mcp::{
8    GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
9    KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
10};
11
/// Name of the gateway meta tool that is called to list the tools behind the gateway.
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Name of the gateway meta tool that is called to run another tool by its id.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
14
/// Coarse connection state tracked by `KontextClient`.
///
/// The default is [`ClientState::Idle`], the state a freshly constructed client starts in.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ClientState {
    /// No connection attempt is in progress; set at construction and by `disconnect`.
    #[default]
    Idle,
    /// `connect` has started and has not yet finished listing tools.
    Connecting,
    /// The last `connect` listed tools successfully.
    Ready,
    /// The last `connect` failed with an error classified as authentication-related.
    NeedsAuth,
    /// The last `connect` failed with any other error.
    Failed,
}
23
/// User-facing configuration for `KontextClient`.
///
/// `KontextClient::new` forwards every field to the underlying MCP configuration
/// (`server_url` is passed on as `server`).
///
/// `Debug` is implemented by hand so that `client_secret` is never written to logs; only
/// whether a secret is present is shown.
#[derive(Clone)]
pub struct KontextClientConfig {
    /// OAuth client identifier.
    pub client_id: String,
    /// OAuth redirect URI.
    pub redirect_uri: String,
    /// Optional endpoint URL, forwarded as the MCP config's `url`.
    pub url: Option<String>,
    /// Optional server URL, forwarded as the MCP config's `server`.
    pub server_url: Option<String>,
    /// Optional OAuth client secret. Redacted in `Debug` output.
    pub client_secret: Option<String>,
    /// Optional OAuth scope.
    pub scope: Option<String>,
    /// Optional OAuth resource.
    pub resource: Option<String>,
    /// Optional base URL of the integration UI.
    pub integration_ui_url: Option<String>,
    /// Optional return target for the integration UI.
    pub integration_return_to: Option<String>,
    /// Optional authentication timeout, in seconds.
    pub auth_timeout_seconds: Option<i64>,
    /// Optional path of the token cache.
    pub token_cache_path: Option<String>,
}

impl std::fmt::Debug for KontextClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the `Some`/`None` shape so logs still show whether a secret was configured.
        let redacted_secret = self.client_secret.as_ref().map(|_| "<redacted>");

        f.debug_struct("KontextClientConfig")
            .field("client_id", &self.client_id)
            .field("redirect_uri", &self.redirect_uri)
            .field("url", &self.url)
            .field("server_url", &self.server_url)
            .field("client_secret", &redacted_secret)
            .field("scope", &self.scope)
            .field("resource", &self.resource)
            .field("integration_ui_url", &self.integration_ui_url)
            .field("integration_return_to", &self.integration_return_to)
            .field("auth_timeout_seconds", &self.auth_timeout_seconds)
            .field("token_cache_path", &self.token_cache_path)
            .finish()
    }
}
38
/// Status of one integration as reported by `KontextClient::integrations_list`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct IntegrationInfo {
    /// Identifier of the integration (the server id in gateway mode).
    pub id: String,
    /// Display name; in gateway mode this falls back to the id when no name is reported.
    pub name: String,
    /// Whether the integration is currently connected.
    pub connected: bool,
    /// URL to connect the integration, when the gateway supplied a matching elicitation.
    pub connect_url: Option<String>,
    /// Reason reported by the gateway for an integration that is not connected.
    pub reason: Option<String>,
}
47
48#[derive(Clone, Debug)]
49pub struct ToolResult {
50    pub content: String,
51    pub raw: serde_json::Value,
52}
53
/// A freshly created connect session together with the URL that opens it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ConnectSessionResult {
    /// Integration connect URL built for `session_id`.
    pub connect_url: String,
    /// Identifier of the connect session returned by the server.
    pub session_id: String,
    /// Expiry of the session as returned by the server (presumably an RFC 3339 timestamp —
    /// confirm against the server contract).
    pub expires_at: String,
}
60
61#[derive(Clone)]
62pub struct KontextClient {
63    state: Arc<RwLock<ClientState>>,
64    mcp: KontextMcp,
65    meta_tool_mode: Arc<RwLock<Option<bool>>>,
66}
67
68impl KontextClient {
69    pub fn new(config: KontextClientConfig) -> Self {
70        let mcp = KontextMcp::new(KontextMcpConfig {
71            client_id: config.client_id,
72            redirect_uri: config.redirect_uri,
73            url: config.url,
74            server: config.server_url,
75            client_secret: config.client_secret,
76            scope: config.scope,
77            resource: config.resource,
78            session_key: None,
79            integration_ui_url: config.integration_ui_url,
80            integration_return_to: config.integration_return_to,
81            auth_timeout_seconds: config.auth_timeout_seconds,
82            open_connect_page_on_login: Some(true),
83            token_cache_path: config.token_cache_path,
84        });
85
86        Self {
87            state: Arc::new(RwLock::new(ClientState::Idle)),
88            mcp,
89            meta_tool_mode: Arc::new(RwLock::new(None)),
90        }
91    }
92
93    pub async fn state(&self) -> ClientState {
94        *self.state.read().await
95    }
96
97    pub fn mcp(&self) -> &KontextMcp {
98        &self.mcp
99    }
100
101    pub async fn connect(&self) -> Result<(), KontextDevError> {
102        {
103            let mut state = self.state.write().await;
104            *state = ClientState::Connecting;
105        }
106
107        match self.mcp.list_tools().await {
108            Ok(tools) => {
109                let mut mode = self.meta_tool_mode.write().await;
110                *mode = Some(has_meta_gateway_tools(&tools));
111                let mut state = self.state.write().await;
112                *state = ClientState::Ready;
113                Ok(())
114            }
115            Err(err) => {
116                let mut state = self.state.write().await;
117                *state = if is_auth_error(&err) {
118                    ClientState::NeedsAuth
119                } else {
120                    ClientState::Failed
121                };
122                Err(err)
123            }
124        }
125    }
126
127    pub async fn disconnect(&self) {
128        let mut mode = self.meta_tool_mode.write().await;
129        *mode = None;
130        let mut state = self.state.write().await;
131        *state = ClientState::Idle;
132    }
133
134    pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
135        let session = self.mcp.authenticate_mcp().await?;
136        build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
137    }
138
139    pub async fn sign_in(&self) -> Result<(), KontextDevError> {
140        self.connect().await
141    }
142
143    pub async fn sign_out(&self) -> Result<(), KontextDevError> {
144        self.mcp.client().clear_token_cache()?;
145        self.disconnect().await;
146        Ok(())
147    }
148
149    pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
150        self.ensure_connected().await?;
151
152        if self.is_meta_tool_mode().await? {
153            let payload = self.fetch_gateway_tools(Some(100)).await?;
154            return Ok(parse_integration_status(&payload));
155        }
156
157        let records = self.mcp.list_integrations().await?;
158        Ok(records
159            .into_iter()
160            .map(|record| IntegrationInfo {
161                id: record.id,
162                name: record.name,
163                connected: record
164                    .connection
165                    .as_ref()
166                    .map(|c| c.connected)
167                    .unwrap_or(false),
168                connect_url: None,
169                reason: None,
170            })
171            .collect())
172    }
173
174    pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
175        self.ensure_connected().await?;
176
177        let mcp_tools = self.mcp.list_tools().await?;
178        let non_meta = mcp_tools
179            .iter()
180            .filter(|tool| tool.name != META_SEARCH_TOOLS && tool.name != META_EXECUTE_TOOL)
181            .cloned()
182            .collect::<Vec<_>>();
183
184        if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
185            let mut mode = self.meta_tool_mode.write().await;
186            *mode = Some(false);
187            return Ok(non_meta);
188        }
189
190        let mut mode = self.meta_tool_mode.write().await;
191        *mode = Some(true);
192        drop(mode);
193
194        let payload = self.fetch_gateway_tools(Some(100)).await?;
195        Ok(payload.tools)
196    }
197
198    pub async fn tools_execute(
199        &self,
200        tool_id: &str,
201        args: Option<serde_json::Map<String, serde_json::Value>>,
202    ) -> Result<ToolResult, KontextDevError> {
203        self.ensure_connected().await?;
204
205        let raw = if self.is_meta_tool_mode().await? {
206            let mut execute_args = Map::new();
207            execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
208            execute_args.insert(
209                "tool_arguments".to_string(),
210                Value::Object(args.unwrap_or_default()),
211            );
212            self.mcp
213                .call_tool(META_EXECUTE_TOOL, Some(execute_args))
214                .await?
215        } else {
216            self.mcp.call_tool(tool_id, args).await?
217        };
218
219        Ok(ToolResult {
220            content: extract_text_content(&raw),
221            raw,
222        })
223    }
224
225    async fn ensure_connected(&self) -> Result<(), KontextDevError> {
226        let state = self.state().await;
227        if state == ClientState::Ready {
228            return Ok(());
229        }
230        self.connect().await
231    }
232
233    async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
234        if let Some(mode) = *self.meta_tool_mode.read().await {
235            return Ok(mode);
236        }
237
238        let tools = self.mcp.list_tools().await?;
239        let mode = has_meta_gateway_tools(&tools);
240        let mut lock = self.meta_tool_mode.write().await;
241        *lock = Some(mode);
242        Ok(mode)
243    }
244
245    async fn fetch_gateway_tools(
246        &self,
247        limit: Option<u32>,
248    ) -> Result<GatewayToolsPayload, KontextDevError> {
249        let result = self
250            .mcp
251            .call_tool(
252                META_SEARCH_TOOLS,
253                Some({
254                    let mut args = Map::new();
255                    if let Some(limit) = limit {
256                        args.insert("limit".to_string(), json!(limit));
257                    }
258                    args
259                }),
260            )
261            .await?;
262        parse_gateway_tools_payload(&result)
263    }
264}
265
266pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
267    KontextClient::new(config)
268}
269
270async fn build_connect_session_result(
271    client: &crate::KontextDevClient,
272    gateway_access_token: &str,
273) -> Result<ConnectSessionResult, KontextDevError> {
274    let connect_session = client.create_connect_session(gateway_access_token).await?;
275    let connect_url = client.integration_connect_url(&connect_session.session_id)?;
276
277    Ok(ConnectSessionResult {
278        connect_url,
279        session_id: connect_session.session_id,
280        expires_at: connect_session.expires_at,
281    })
282}
283
284fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
285    let mut seen = std::collections::HashSet::<String>::new();
286    let mut out = Vec::new();
287
288    for tool in &payload.tools {
289        let Some(server) = tool.server.as_ref() else {
290            continue;
291        };
292        if !seen.insert(server.id.clone()) {
293            continue;
294        }
295        out.push(IntegrationInfo {
296            id: server.id.clone(),
297            name: server.name.clone().unwrap_or_else(|| server.id.clone()),
298            connected: true,
299            connect_url: None,
300            reason: None,
301        });
302    }
303
304    for GatewayToolError {
305        server_id,
306        server_name,
307        reason,
308    } in &payload.errors
309    {
310        if !seen.insert(server_id.clone()) {
311            continue;
312        }
313        let connect_url = payload.elicitations.iter().find_map(
314            |GatewayElicitation {
315                 url,
316                 integration_id,
317                 ..
318             }| {
319                if integration_id.as_deref() == Some(server_id.as_str()) {
320                    Some(url.clone())
321                } else {
322                    None
323                }
324            },
325        );
326        out.push(IntegrationInfo {
327            id: server_id.clone(),
328            name: server_name.clone().unwrap_or_else(|| server_id.clone()),
329            connected: false,
330            connect_url,
331            reason: reason.clone(),
332        });
333    }
334
335    out
336}
337
338fn is_auth_error(err: &KontextDevError) -> bool {
339    matches!(
340        err,
341        KontextDevError::OAuthCallbackTimeout { .. }
342            | KontextDevError::OAuthCallbackCancelled
343            | KontextDevError::MissingAuthorizationCode
344            | KontextDevError::OAuthCallbackError { .. }
345            | KontextDevError::InvalidOAuthState
346            | KontextDevError::TokenRequest { .. }
347            | KontextDevError::TokenExchange { .. }
348    )
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// The connect-session endpoint must be hit exactly once, with the bearer token, and the
    /// returned URL must embed the session id from that single response.
    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        let expected_session_id = "session-123";
        let expected_expiry = "2030-01-01T00:00:00Z";
        let gateway_token = "test-gateway-token";

        let mock_gateway = MockServer::start().await;

        let session_response = ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "sessionId": expected_session_id,
            "expiresAt": expected_expiry
        }));
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {gateway_token}")))
            .respond_with(session_response)
            .expect(1)
            .mount(&mock_gateway)
            .await;

        let config = crate::KontextDevConfig {
            server: mock_gateway.uri(),
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        };
        let client = crate::KontextDevClient::new(config);

        let ConnectSessionResult {
            connect_url,
            session_id,
            expires_at,
        } = build_connect_session_result(&client, gateway_token)
            .await
            .expect("connect session result should be built");

        assert_eq!(session_id, expected_session_id);
        assert_eq!(expires_at, expected_expiry);
        assert_eq!(
            connect_url,
            format!("https://app.kontext.dev/oauth/connect?session={expected_session_id}")
        );
    }
}
405}