Skip to main content

kontext_dev_sdk/
client.rs

1use std::sync::Arc;
2
3use serde_json::{Map, Value, json};
4use tokio::sync::RwLock;
5
6use crate::KontextDevError;
7use crate::mcp::{
8    GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
9    KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
10};
11use crate::prompt_guidance::{KontextPromptGuidance, build_kontext_prompt_guidance};
12
/// Name of the gateway meta tool that searches for (and lists) available tools.
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Name of the gateway meta tool that executes another tool by id.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// Name of the gateway meta tool that requests an additional capability.
const META_REQUEST_CAPABILITY: &str = "REQUEST_CAPABILITY";
16
/// Connection lifecycle of a `KontextClient`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// Not connected: the initial state, and the state after `disconnect`.
    Idle,
    /// A `connect` call is in flight.
    Connecting,
    /// The most recent `connect` succeeded.
    Ready,
    /// The most recent `connect` failed with an OAuth/token error.
    NeedsAuth,
    /// The most recent `connect` failed with any other error.
    Failed,
}
25
/// Settings used to build a `KontextClient`.
///
/// Every field is forwarded to the underlying MCP client configuration by
/// `KontextClient::new`.
#[derive(Debug, Clone)]
pub struct KontextClientConfig {
    /// Forwarded as the MCP config's `client_session_id`.
    pub client_session_id: String,
    /// OAuth client id.
    pub client_id: String,
    /// OAuth redirect URI.
    pub redirect_uri: String,
    /// Optional URL, forwarded as the MCP config's `url`.
    pub url: Option<String>,
    /// Optional server URL, forwarded as the MCP config's `server`.
    pub server_url: Option<String>,
    /// Optional OAuth client secret.
    pub client_secret: Option<String>,
    /// Optional OAuth scope.
    pub scope: Option<String>,
    /// Optional OAuth resource.
    pub resource: Option<String>,
    /// Optional integration UI URL.
    pub integration_ui_url: Option<String>,
    /// Optional return-to location for the integration UI.
    pub integration_return_to: Option<String>,
    /// Optional auth timeout; presumably seconds for the OAuth flow — confirm in `KontextMcp`.
    pub auth_timeout_seconds: Option<i64>,
    /// Optional path of the on-disk token cache.
    pub token_cache_path: Option<String>,
}
41
/// Connection status of a single integration.
#[derive(Debug, Clone)]
pub struct IntegrationInfo {
    /// Integration (server) identifier.
    pub id: String,
    /// Display name; callers in this file fall back to `id` when no name is known.
    pub name: String,
    /// Whether the integration is currently connected.
    pub connected: bool,
    /// URL the user can open to connect the integration, when one is known.
    pub connect_url: Option<String>,
    /// Reason the integration is unavailable, when one was reported.
    pub reason: Option<String>,
}
50
51#[derive(Clone, Debug)]
52pub struct ToolResult {
53    pub content: String,
54    pub raw: serde_json::Value,
55}
56
/// A freshly created connect session together with the URL that opens it.
#[derive(Debug, Clone)]
pub struct ConnectSessionResult {
    /// URL of the connect page for `session_id`.
    pub connect_url: String,
    /// Identifier of the connect session.
    pub session_id: String,
    /// Session expiry, exactly as returned by the server.
    pub expires_at: String,
}
63
64#[derive(Clone)]
65pub struct KontextClient {
66    state: Arc<RwLock<ClientState>>,
67    mcp: KontextMcp,
68    meta_tool_mode: Arc<RwLock<Option<bool>>>,
69}
70
71impl KontextClient {
72    pub fn new(config: KontextClientConfig) -> Self {
73        let mcp = KontextMcp::new(KontextMcpConfig {
74            client_session_id: config.client_session_id,
75            client_id: config.client_id,
76            redirect_uri: config.redirect_uri,
77            url: config.url,
78            server: config.server_url,
79            client_secret: config.client_secret,
80            scope: config.scope,
81            resource: config.resource,
82            session_key: None,
83            integration_ui_url: config.integration_ui_url,
84            integration_return_to: config.integration_return_to,
85            auth_timeout_seconds: config.auth_timeout_seconds,
86            open_connect_page_on_login: Some(true),
87            token_cache_path: config.token_cache_path,
88        });
89
90        Self {
91            state: Arc::new(RwLock::new(ClientState::Idle)),
92            mcp,
93            meta_tool_mode: Arc::new(RwLock::new(None)),
94        }
95    }
96
97    pub async fn state(&self) -> ClientState {
98        *self.state.read().await
99    }
100
101    pub fn mcp(&self) -> &KontextMcp {
102        &self.mcp
103    }
104
105    pub async fn connect(&self) -> Result<(), KontextDevError> {
106        {
107            let mut state = self.state.write().await;
108            *state = ClientState::Connecting;
109        }
110
111        match self.mcp.list_tools().await {
112            Ok(tools) => {
113                let mut mode = self.meta_tool_mode.write().await;
114                *mode = Some(has_meta_gateway_tools(&tools));
115                let mut state = self.state.write().await;
116                *state = ClientState::Ready;
117                Ok(())
118            }
119            Err(err) => {
120                let mut state = self.state.write().await;
121                *state = if is_auth_error(&err) {
122                    ClientState::NeedsAuth
123                } else {
124                    ClientState::Failed
125                };
126                Err(err)
127            }
128        }
129    }
130
131    pub async fn disconnect(&self) {
132        let mut mode = self.meta_tool_mode.write().await;
133        *mode = None;
134        let mut state = self.state.write().await;
135        *state = ClientState::Idle;
136    }
137
138    pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
139        let session = self.mcp.authenticate_mcp().await?;
140        build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
141    }
142
143    pub async fn sign_in(&self) -> Result<(), KontextDevError> {
144        self.connect().await
145    }
146
147    pub async fn sign_out(&self) -> Result<(), KontextDevError> {
148        self.mcp.client().clear_token_cache()?;
149        self.disconnect().await;
150        Ok(())
151    }
152
153    pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
154        self.ensure_connected().await?;
155
156        if self.is_meta_tool_mode().await? {
157            let payload = self.fetch_gateway_tools(Some(100)).await?;
158            return Ok(parse_integration_status(&payload));
159        }
160
161        let records = self.mcp.list_integrations().await?;
162        Ok(records
163            .into_iter()
164            .map(|record| IntegrationInfo {
165                id: record.id,
166                name: record.name,
167                connected: record
168                    .connection
169                    .as_ref()
170                    .map(|c| c.connected)
171                    .unwrap_or(false),
172                connect_url: None,
173                reason: None,
174            })
175            .collect())
176    }
177
178    pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
179        self.ensure_connected().await?;
180
181        let mcp_tools = self.mcp.list_tools().await?;
182        let non_meta = mcp_tools
183            .iter()
184            .filter(|tool| !is_gateway_meta_tool(tool.name.as_str()))
185            .cloned()
186            .collect::<Vec<_>>();
187
188        if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
189            let mut mode = self.meta_tool_mode.write().await;
190            *mode = Some(false);
191            return Ok(non_meta);
192        }
193
194        let mut mode = self.meta_tool_mode.write().await;
195        *mode = Some(true);
196        drop(mode);
197
198        let payload = self.fetch_gateway_tools(Some(100)).await?;
199        let mut tools = payload.tools;
200        append_request_capability_tool(&mut tools, mcp_tools.as_slice());
201        Ok(tools)
202    }
203
204    pub async fn tools_execute(
205        &self,
206        tool_id: &str,
207        args: Option<serde_json::Map<String, serde_json::Value>>,
208    ) -> Result<ToolResult, KontextDevError> {
209        self.ensure_connected().await?;
210
211        let raw = if self.is_meta_tool_mode().await? {
212            if tool_id == META_REQUEST_CAPABILITY {
213                self.mcp.call_tool(META_REQUEST_CAPABILITY, args).await?
214            } else {
215                let mut execute_args = Map::new();
216                execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
217                execute_args.insert(
218                    "tool_arguments".to_string(),
219                    Value::Object(args.unwrap_or_default()),
220                );
221                self.mcp
222                    .call_tool(META_EXECUTE_TOOL, Some(execute_args))
223                    .await?
224            }
225        } else {
226            self.mcp.call_tool(tool_id, args).await?
227        };
228
229        Ok(ToolResult {
230            content: extract_text_content(&raw),
231            raw,
232        })
233    }
234
235    pub async fn prompt_guidance(&self) -> Result<KontextPromptGuidance, KontextDevError> {
236        let tools = self.tools_list().await?;
237        let integrations = self.integrations_list().await?;
238        let tool_names = tools
239            .into_iter()
240            .map(|tool| tool.name)
241            .collect::<Vec<String>>();
242
243        Ok(build_kontext_prompt_guidance(
244            tool_names.as_slice(),
245            integrations.as_slice(),
246        ))
247    }
248
249    async fn ensure_connected(&self) -> Result<(), KontextDevError> {
250        let state = self.state().await;
251        if state == ClientState::Ready {
252            return Ok(());
253        }
254        self.connect().await
255    }
256
257    async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
258        if let Some(mode) = *self.meta_tool_mode.read().await {
259            return Ok(mode);
260        }
261
262        let tools = self.mcp.list_tools().await?;
263        let mode = has_meta_gateway_tools(&tools);
264        let mut lock = self.meta_tool_mode.write().await;
265        *lock = Some(mode);
266        Ok(mode)
267    }
268
269    async fn fetch_gateway_tools(
270        &self,
271        limit: Option<u32>,
272    ) -> Result<GatewayToolsPayload, KontextDevError> {
273        let result = self
274            .mcp
275            .call_tool(
276                META_SEARCH_TOOLS,
277                Some({
278                    let mut args = Map::new();
279                    if let Some(limit) = limit {
280                        args.insert("limit".to_string(), json!(limit));
281                    }
282                    args
283                }),
284            )
285            .await?;
286        parse_gateway_tools_payload(&result)
287    }
288}
289
290fn is_gateway_meta_tool(tool_name: &str) -> bool {
291    matches!(
292        tool_name,
293        META_SEARCH_TOOLS | META_EXECUTE_TOOL | META_REQUEST_CAPABILITY
294    )
295}
296
297fn append_request_capability_tool(tools: &mut Vec<KontextTool>, mcp_tools: &[KontextTool]) {
298    if tools
299        .iter()
300        .any(|tool| tool.name == META_REQUEST_CAPABILITY)
301    {
302        return;
303    }
304    let Some(capability_tool) = mcp_tools
305        .iter()
306        .find(|tool| tool.name == META_REQUEST_CAPABILITY)
307    else {
308        return;
309    };
310
311    tools.push(capability_tool.clone());
312}
313
314pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
315    KontextClient::new(config)
316}
317
318async fn build_connect_session_result(
319    client: &crate::KontextDevClient,
320    gateway_access_token: &str,
321) -> Result<ConnectSessionResult, KontextDevError> {
322    let connect_session = client.create_connect_session(gateway_access_token).await?;
323    let connect_url = client.integration_connect_url(&connect_session.session_id)?;
324
325    Ok(ConnectSessionResult {
326        connect_url,
327        session_id: connect_session.session_id,
328        expires_at: connect_session.expires_at,
329    })
330}
331
332fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
333    let mut seen = std::collections::HashSet::<String>::new();
334    let mut out = Vec::new();
335
336    for tool in &payload.tools {
337        let Some(server) = tool.server.as_ref() else {
338            continue;
339        };
340        if !seen.insert(server.id.clone()) {
341            continue;
342        }
343        out.push(IntegrationInfo {
344            id: server.id.clone(),
345            name: server.name.clone().unwrap_or_else(|| server.id.clone()),
346            connected: true,
347            connect_url: None,
348            reason: None,
349        });
350    }
351
352    for GatewayToolError {
353        server_id,
354        server_name,
355        reason,
356    } in &payload.errors
357    {
358        if !seen.insert(server_id.clone()) {
359            continue;
360        }
361        let connect_url = payload.elicitations.iter().find_map(
362            |GatewayElicitation {
363                 url,
364                 integration_id,
365                 ..
366             }| {
367                if integration_id.as_deref() == Some(server_id.as_str()) {
368                    Some(url.clone())
369                } else {
370                    None
371                }
372            },
373        );
374        out.push(IntegrationInfo {
375            id: server_id.clone(),
376            name: server_name.clone().unwrap_or_else(|| server_id.clone()),
377            connected: false,
378            connect_url,
379            reason: reason.clone(),
380        });
381    }
382
383    out
384}
385
386fn is_auth_error(err: &KontextDevError) -> bool {
387    matches!(
388        err,
389        KontextDevError::OAuthCallbackTimeout { .. }
390            | KontextDevError::OAuthCallbackCancelled
391            | KontextDevError::MissingAuthorizationCode
392            | KontextDevError::OAuthCallbackError { .. }
393            | KontextDevError::InvalidOAuthState
394            | KontextDevError::TokenRequest { .. }
395            | KontextDevError::TokenExchange { .. }
396    )
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a `KontextDevClient` pointed at `server_uri`, with the integration UI at
    /// `https://app.kontext.dev`.
    fn dev_client_for(server_uri: String) -> crate::KontextDevClient {
        crate::KontextDevClient::new(crate::KontextDevConfig {
            server: server_uri,
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        })
    }

    /// `build_connect_session_result` must create exactly one connect session (the mock
    /// expects a single POST) and derive the connect URL from that same session id.
    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        let expected_session = "session-123";
        let expected_expiry = "2030-01-01T00:00:00Z";
        let gateway_token = "test-gateway-token";

        let mock_server = MockServer::start().await;

        let session_body = serde_json::json!({
            "sessionId": expected_session,
            "expiresAt": expected_expiry
        });
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {gateway_token}")))
            .respond_with(ResponseTemplate::new(200).set_body_json(session_body))
            .expect(1)
            .mount(&mock_server)
            .await;

        let dev_client = dev_client_for(mock_server.uri());

        let outcome = build_connect_session_result(&dev_client, gateway_token)
            .await
            .expect("connect session result should be built");

        assert_eq!(outcome.session_id, expected_session);
        assert_eq!(outcome.expires_at, expected_expiry);
        let expected_url =
            format!("https://app.kontext.dev/oauth/connect?session={expected_session}");
        assert_eq!(outcome.connect_url, expected_url);
    }
}