1use std::sync::Arc;
2
3use serde_json::{Map, Value, json};
4use tokio::sync::RwLock;
5
6use crate::KontextDevError;
7use crate::mcp::{
8 GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
9 KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
10};
11
/// Name of the gateway meta tool used to enumerate the tools reachable through
/// the gateway (called with an optional `limit`).
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Name of the gateway meta tool that runs another tool by id; it is called
/// with `tool_id` and `tool_arguments`.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// Name of the gateway meta tool for requesting a capability; in gateway mode
/// it is called directly instead of being wrapped in `EXECUTE_TOOL`.
const META_REQUEST_CAPABILITY: &str = "REQUEST_CAPABILITY";
15
/// Connection lifecycle of a [`KontextClient`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// Initial state, and the state restored by `disconnect`.
    Idle,
    /// A `connect` call is listing the server's tools.
    Connecting,
    /// The most recent `connect` listed the server's tools successfully.
    Ready,
    /// The most recent `connect` failed with an error classified as an
    /// authentication error.
    NeedsAuth,
    /// The most recent `connect` failed for any other reason.
    Failed,
}
24
/// Settings used to build a [`KontextClient`].
///
/// Every field is forwarded to the underlying MCP configuration by
/// `KontextClient::new`; `server_url` is passed on as the MCP `server` setting.
///
/// `Debug` is implemented by hand so that the value of `client_secret` is never
/// printed. This keeps the config safe to log with `{:?}`.
#[derive(Clone)]
pub struct KontextClientConfig {
    /// Client session identifier, forwarded unchanged.
    pub client_session_id: String,
    /// Client identifier, forwarded unchanged.
    pub client_id: String,
    /// Redirect URI, forwarded unchanged.
    pub redirect_uri: String,
    /// Optional URL, forwarded unchanged.
    pub url: Option<String>,
    /// Optional server URL, forwarded as the MCP `server` setting.
    pub server_url: Option<String>,
    /// Optional client secret. Redacted in `Debug` output.
    pub client_secret: Option<String>,
    /// Optional scope, forwarded unchanged.
    pub scope: Option<String>,
    /// Optional resource, forwarded unchanged.
    pub resource: Option<String>,
    /// Optional integration UI base URL, forwarded unchanged.
    pub integration_ui_url: Option<String>,
    /// Optional integration return-to value, forwarded unchanged.
    pub integration_return_to: Option<String>,
    /// Optional authentication timeout in seconds, forwarded unchanged.
    pub auth_timeout_seconds: Option<i64>,
    /// Optional token cache location, forwarded unchanged.
    pub token_cache_path: Option<String>,
}

impl std::fmt::Debug for KontextClientConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Reveal only whether a secret is configured, never its value; a derived
        // `Debug` would leak it into any log line that formats this config.
        let client_secret = self.client_secret.as_ref().map(|_| "<redacted>");
        f.debug_struct("KontextClientConfig")
            .field("client_session_id", &self.client_session_id)
            .field("client_id", &self.client_id)
            .field("redirect_uri", &self.redirect_uri)
            .field("url", &self.url)
            .field("server_url", &self.server_url)
            .field("client_secret", &client_secret)
            .field("scope", &self.scope)
            .field("resource", &self.resource)
            .field("integration_ui_url", &self.integration_ui_url)
            .field("integration_return_to", &self.integration_return_to)
            .field("auth_timeout_seconds", &self.auth_timeout_seconds)
            .field("token_cache_path", &self.token_cache_path)
            .finish()
    }
}
40
/// Connection status of a single integration as reported by
/// [`KontextClient::integrations_list`].
#[derive(Debug, Clone)]
pub struct IntegrationInfo {
    /// Integration (server) identifier.
    pub id: String,
    /// Display name; in gateway mode this falls back to `id` when the gateway
    /// supplies no name.
    pub name: String,
    /// Whether the integration is currently usable.
    pub connected: bool,
    /// Page the user can visit to connect the integration, when the gateway
    /// supplied one (always `None` for non-gateway listings).
    pub connect_url: Option<String>,
    /// Gateway-reported reason the integration is unavailable, if any.
    pub reason: Option<String>,
}
49
50#[derive(Clone, Debug)]
51pub struct ToolResult {
52 pub content: String,
53 pub raw: serde_json::Value,
54}
55
/// A freshly created integration connect session and the page that serves it.
#[derive(Debug, Clone)]
pub struct ConnectSessionResult {
    /// Connect-page URL built from `session_id`.
    pub connect_url: String,
    /// Identifier of the connect session.
    pub session_id: String,
    /// Expiry of the session, exactly as returned by the server.
    pub expires_at: String,
}
62
/// High-level client that wraps a [`KontextMcp`] and tracks connection state.
///
/// `state` and `meta_tool_mode` live behind `Arc`, so clones of a client share
/// and update the same values. How `mcp` behaves when cloned depends on
/// `KontextMcp`'s own `Clone` impl (NOTE(review): confirm it shares its session).
#[derive(Clone)]
pub struct KontextClient {
    /// Current lifecycle state.
    state: Arc<RwLock<ClientState>>,
    /// Underlying MCP client used for every tool and integration call.
    mcp: KontextMcp,
    /// Cached answer to "should calls go through the gateway meta tools?";
    /// `None` until determined, and reset to `None` by `disconnect`.
    meta_tool_mode: Arc<RwLock<Option<bool>>>,
}
69
70impl KontextClient {
71 pub fn new(config: KontextClientConfig) -> Self {
72 let mcp = KontextMcp::new(KontextMcpConfig {
73 client_session_id: config.client_session_id,
74 client_id: config.client_id,
75 redirect_uri: config.redirect_uri,
76 url: config.url,
77 server: config.server_url,
78 client_secret: config.client_secret,
79 scope: config.scope,
80 resource: config.resource,
81 session_key: None,
82 integration_ui_url: config.integration_ui_url,
83 integration_return_to: config.integration_return_to,
84 auth_timeout_seconds: config.auth_timeout_seconds,
85 open_connect_page_on_login: Some(true),
86 token_cache_path: config.token_cache_path,
87 });
88
89 Self {
90 state: Arc::new(RwLock::new(ClientState::Idle)),
91 mcp,
92 meta_tool_mode: Arc::new(RwLock::new(None)),
93 }
94 }
95
96 pub async fn state(&self) -> ClientState {
97 *self.state.read().await
98 }
99
100 pub fn mcp(&self) -> &KontextMcp {
101 &self.mcp
102 }
103
104 pub async fn connect(&self) -> Result<(), KontextDevError> {
105 {
106 let mut state = self.state.write().await;
107 *state = ClientState::Connecting;
108 }
109
110 match self.mcp.list_tools().await {
111 Ok(tools) => {
112 let mut mode = self.meta_tool_mode.write().await;
113 *mode = Some(has_meta_gateway_tools(&tools));
114 let mut state = self.state.write().await;
115 *state = ClientState::Ready;
116 Ok(())
117 }
118 Err(err) => {
119 let mut state = self.state.write().await;
120 *state = if is_auth_error(&err) {
121 ClientState::NeedsAuth
122 } else {
123 ClientState::Failed
124 };
125 Err(err)
126 }
127 }
128 }
129
130 pub async fn disconnect(&self) {
131 let mut mode = self.meta_tool_mode.write().await;
132 *mode = None;
133 let mut state = self.state.write().await;
134 *state = ClientState::Idle;
135 }
136
137 pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
138 let session = self.mcp.authenticate_mcp().await?;
139 build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
140 }
141
142 pub async fn sign_in(&self) -> Result<(), KontextDevError> {
143 self.connect().await
144 }
145
146 pub async fn sign_out(&self) -> Result<(), KontextDevError> {
147 self.mcp.client().clear_token_cache()?;
148 self.disconnect().await;
149 Ok(())
150 }
151
152 pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
153 self.ensure_connected().await?;
154
155 if self.is_meta_tool_mode().await? {
156 let payload = self.fetch_gateway_tools(Some(100)).await?;
157 return Ok(parse_integration_status(&payload));
158 }
159
160 let records = self.mcp.list_integrations().await?;
161 Ok(records
162 .into_iter()
163 .map(|record| IntegrationInfo {
164 id: record.id,
165 name: record.name,
166 connected: record
167 .connection
168 .as_ref()
169 .map(|c| c.connected)
170 .unwrap_or(false),
171 connect_url: None,
172 reason: None,
173 })
174 .collect())
175 }
176
177 pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
178 self.ensure_connected().await?;
179
180 let mcp_tools = self.mcp.list_tools().await?;
181 let non_meta = mcp_tools
182 .iter()
183 .filter(|tool| !is_gateway_meta_tool(tool.name.as_str()))
184 .cloned()
185 .collect::<Vec<_>>();
186
187 if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
188 let mut mode = self.meta_tool_mode.write().await;
189 *mode = Some(false);
190 return Ok(non_meta);
191 }
192
193 let mut mode = self.meta_tool_mode.write().await;
194 *mode = Some(true);
195 drop(mode);
196
197 let payload = self.fetch_gateway_tools(Some(100)).await?;
198 let mut tools = payload.tools;
199 append_request_capability_tool(&mut tools, mcp_tools.as_slice());
200 Ok(tools)
201 }
202
203 pub async fn tools_execute(
204 &self,
205 tool_id: &str,
206 args: Option<serde_json::Map<String, serde_json::Value>>,
207 ) -> Result<ToolResult, KontextDevError> {
208 self.ensure_connected().await?;
209
210 let raw = if self.is_meta_tool_mode().await? {
211 if tool_id == META_REQUEST_CAPABILITY {
212 self.mcp.call_tool(META_REQUEST_CAPABILITY, args).await?
213 } else {
214 let mut execute_args = Map::new();
215 execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
216 execute_args.insert(
217 "tool_arguments".to_string(),
218 Value::Object(args.unwrap_or_default()),
219 );
220 self.mcp
221 .call_tool(META_EXECUTE_TOOL, Some(execute_args))
222 .await?
223 }
224 } else {
225 self.mcp.call_tool(tool_id, args).await?
226 };
227
228 Ok(ToolResult {
229 content: extract_text_content(&raw),
230 raw,
231 })
232 }
233
234 async fn ensure_connected(&self) -> Result<(), KontextDevError> {
235 let state = self.state().await;
236 if state == ClientState::Ready {
237 return Ok(());
238 }
239 self.connect().await
240 }
241
242 async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
243 if let Some(mode) = *self.meta_tool_mode.read().await {
244 return Ok(mode);
245 }
246
247 let tools = self.mcp.list_tools().await?;
248 let mode = has_meta_gateway_tools(&tools);
249 let mut lock = self.meta_tool_mode.write().await;
250 *lock = Some(mode);
251 Ok(mode)
252 }
253
254 async fn fetch_gateway_tools(
255 &self,
256 limit: Option<u32>,
257 ) -> Result<GatewayToolsPayload, KontextDevError> {
258 let result = self
259 .mcp
260 .call_tool(
261 META_SEARCH_TOOLS,
262 Some({
263 let mut args = Map::new();
264 if let Some(limit) = limit {
265 args.insert("limit".to_string(), json!(limit));
266 }
267 args
268 }),
269 )
270 .await?;
271 parse_gateway_tools_payload(&result)
272 }
273}
274
275fn is_gateway_meta_tool(tool_name: &str) -> bool {
276 matches!(
277 tool_name,
278 META_SEARCH_TOOLS | META_EXECUTE_TOOL | META_REQUEST_CAPABILITY
279 )
280}
281
282fn append_request_capability_tool(tools: &mut Vec<KontextTool>, mcp_tools: &[KontextTool]) {
283 if tools
284 .iter()
285 .any(|tool| tool.name == META_REQUEST_CAPABILITY)
286 {
287 return;
288 }
289 let Some(capability_tool) = mcp_tools
290 .iter()
291 .find(|tool| tool.name == META_REQUEST_CAPABILITY)
292 else {
293 return;
294 };
295
296 tools.push(capability_tool.clone());
297}
298
/// Convenience constructor; equivalent to [`KontextClient::new`].
pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
    KontextClient::new(config)
}
302
303async fn build_connect_session_result(
304 client: &crate::KontextDevClient,
305 gateway_access_token: &str,
306) -> Result<ConnectSessionResult, KontextDevError> {
307 let connect_session = client.create_connect_session(gateway_access_token).await?;
308 let connect_url = client.integration_connect_url(&connect_session.session_id)?;
309
310 Ok(ConnectSessionResult {
311 connect_url,
312 session_id: connect_session.session_id,
313 expires_at: connect_session.expires_at,
314 })
315}
316
317fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
318 let mut seen = std::collections::HashSet::<String>::new();
319 let mut out = Vec::new();
320
321 for tool in &payload.tools {
322 let Some(server) = tool.server.as_ref() else {
323 continue;
324 };
325 if !seen.insert(server.id.clone()) {
326 continue;
327 }
328 out.push(IntegrationInfo {
329 id: server.id.clone(),
330 name: server.name.clone().unwrap_or_else(|| server.id.clone()),
331 connected: true,
332 connect_url: None,
333 reason: None,
334 });
335 }
336
337 for GatewayToolError {
338 server_id,
339 server_name,
340 reason,
341 } in &payload.errors
342 {
343 if !seen.insert(server_id.clone()) {
344 continue;
345 }
346 let connect_url = payload.elicitations.iter().find_map(
347 |GatewayElicitation {
348 url,
349 integration_id,
350 ..
351 }| {
352 if integration_id.as_deref() == Some(server_id.as_str()) {
353 Some(url.clone())
354 } else {
355 None
356 }
357 },
358 );
359 out.push(IntegrationInfo {
360 id: server_id.clone(),
361 name: server_name.clone().unwrap_or_else(|| server_id.clone()),
362 connected: false,
363 connect_url,
364 reason: reason.clone(),
365 });
366 }
367
368 out
369}
370
/// Classifies the OAuth-callback and token request/exchange variants of
/// `KontextDevError` as authentication errors.
///
/// `connect` uses this to choose between `ClientState::NeedsAuth` (these
/// variants) and `ClientState::Failed` (every other error).
fn is_auth_error(err: &KontextDevError) -> bool {
    matches!(
        err,
        KontextDevError::OAuthCallbackTimeout { .. }
            | KontextDevError::OAuthCallbackCancelled
            | KontextDevError::MissingAuthorizationCode
            | KontextDevError::OAuthCallbackError { .. }
            | KontextDevError::InvalidOAuthState
            | KontextDevError::TokenRequest { .. }
            | KontextDevError::TokenExchange { .. }
    )
}
383
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// `build_connect_session_result` must request exactly one connect session
    /// and build the connect URL from that same session's id.
    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        const SESSION_ID: &str = "session-123";
        const EXPIRES_AT: &str = "2030-01-01T00:00:00Z";
        const ACCESS_TOKEN: &str = "test-gateway-token";

        let mock_server = MockServer::start().await;

        let session_body = serde_json::json!({
            "sessionId": SESSION_ID,
            "expiresAt": EXPIRES_AT
        });
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {ACCESS_TOKEN}")))
            .respond_with(ResponseTemplate::new(200).set_body_json(session_body))
            // wiremock verifies this call-count expectation when the server is dropped.
            .expect(1)
            .mount(&mock_server)
            .await;

        let dev_client = crate::KontextDevClient::new(crate::KontextDevConfig {
            server: mock_server.uri(),
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        });

        let result = build_connect_session_result(&dev_client, ACCESS_TOKEN)
            .await
            .expect("connect session result should be built");

        let expected_url = format!("https://app.kontext.dev/oauth/connect?session={SESSION_ID}");
        assert_eq!(result.session_id, SESSION_ID);
        assert_eq!(result.expires_at, EXPIRES_AT);
        assert_eq!(result.connect_url, expected_url);
    }
}