1use std::sync::Arc;
2
3use serde_json::{Map, Value, json};
4use tokio::sync::RwLock;
5
6use crate::KontextDevError;
7use crate::mcp::{
8 GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
9 KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
10};
11use crate::prompt_guidance::{KontextPromptGuidance, build_kontext_prompt_guidance};
12
/// Gateway meta-tool used to discover the real tools (see `fetch_gateway_tools`).
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Gateway meta-tool through which real tools are invoked in meta-tool mode.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// Gateway meta-tool that is called directly (not wrapped in `EXECUTE_TOOL`).
const META_REQUEST_CAPABILITY: &str = "REQUEST_CAPABILITY";
16
/// Connection lifecycle of a [`KontextClient`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ClientState {
    /// Not connected: the initial state, and the state after `disconnect`.
    Idle,
    /// A `connect` call is in progress.
    Connecting,
    /// The most recent `connect` successfully listed the server's tools.
    Ready,
    /// The most recent `connect` failed with an error that `is_auth_error`
    /// classifies as an authentication problem.
    NeedsAuth,
    /// The most recent `connect` failed with any other error.
    Failed,
}
25
/// Settings for [`KontextClient::new`].
///
/// Every field is forwarded to [`KontextMcpConfig`] under the same name, except
/// `server_url`, which is passed as `KontextMcpConfig::server`.
/// NOTE(review): the defaults applied to `None` fields live in `KontextMcp` and
/// are not visible from this module.
#[derive(Clone, Debug)]
pub struct KontextClientConfig {
    pub client_session_id: String,
    pub client_id: String,
    pub redirect_uri: String,
    /// MCP endpoint URL (the tests in this module use `{server}/mcp`).
    pub url: Option<String>,
    /// Forwarded as `KontextMcpConfig::server`.
    pub server_url: Option<String>,
    pub client_secret: Option<String>,
    pub scope: Option<String>,
    pub resource: Option<String>,
    pub integration_ui_url: Option<String>,
    pub integration_return_to: Option<String>,
    pub auth_timeout_seconds: Option<i64>,
    pub token_cache_path: Option<String>,
}
41
/// Status of one integration as reported by [`KontextClient::integrations_list`].
#[derive(Clone, Debug)]
pub struct IntegrationInfo {
    /// Integration / server identifier.
    pub id: String,
    /// Display name; in gateway mode it falls back to `id` when no name is reported.
    pub name: String,
    /// Whether the integration is currently connected.
    pub connected: bool,
    /// Where the user can connect the integration. Only populated in gateway
    /// mode, for a disconnected integration that has a matching elicitation.
    pub connect_url: Option<String>,
    /// Why the integration is unavailable, when the gateway reports a reason.
    pub reason: Option<String>,
}
50
/// Result of [`KontextClient::tools_execute`].
#[derive(Clone, Debug)]
pub struct ToolResult {
    /// Text extracted from `raw` with `extract_text_content`.
    pub content: String,
    /// The JSON value returned by the MCP `call_tool` request, unmodified.
    pub raw: serde_json::Value,
}
56
/// A newly created integration connect session, returned by
/// [`KontextClient::get_connect_page_url`].
#[derive(Clone, Debug)]
pub struct ConnectSessionResult {
    /// Connect page URL built from `session_id`.
    pub connect_url: String,
    /// Identifier of the connect session.
    pub session_id: String,
    /// Expiry as returned by the server; NOTE(review): the tests use an
    /// RFC 3339 timestamp string, but the format is not enforced here.
    pub expires_at: String,
}
63
/// High-level Kontext client wrapping a [`KontextMcp`].
///
/// Besides the MCP handle it tracks the connection state and whether the
/// server is driven through the gateway meta-tools. Both of those live behind
/// `Arc`, so clones of a client share them. NOTE(review): whether clones also
/// share the MCP session depends on `KontextMcp`'s `Clone` impl — confirm.
#[derive(Clone)]
pub struct KontextClient {
    /// Current lifecycle state; see [`ClientState`].
    state: Arc<RwLock<ClientState>>,
    /// Wrapped MCP client used for every tool and auth call.
    mcp: KontextMcp,
    /// Cached mode detection: `None` until detected, `Some(true)` when tools are
    /// reached via `SEARCH_TOOLS`/`EXECUTE_TOOL`, `Some(false)` for direct calls.
    meta_tool_mode: Arc<RwLock<Option<bool>>>,
}
70
71impl KontextClient {
72 pub fn new(config: KontextClientConfig) -> Self {
73 let mcp = KontextMcp::new(KontextMcpConfig {
74 client_session_id: config.client_session_id,
75 client_id: config.client_id,
76 redirect_uri: config.redirect_uri,
77 url: config.url,
78 server: config.server_url,
79 client_secret: config.client_secret,
80 scope: config.scope,
81 resource: config.resource,
82 session_key: None,
83 integration_ui_url: config.integration_ui_url,
84 integration_return_to: config.integration_return_to,
85 auth_timeout_seconds: config.auth_timeout_seconds,
86 open_connect_page_on_login: Some(true),
87 token_cache_path: config.token_cache_path,
88 });
89
90 Self {
91 state: Arc::new(RwLock::new(ClientState::Idle)),
92 mcp,
93 meta_tool_mode: Arc::new(RwLock::new(None)),
94 }
95 }
96
97 pub async fn state(&self) -> ClientState {
98 *self.state.read().await
99 }
100
101 pub fn mcp(&self) -> &KontextMcp {
102 &self.mcp
103 }
104
105 pub async fn connect(&self) -> Result<(), KontextDevError> {
106 {
107 let mut state = self.state.write().await;
108 *state = ClientState::Connecting;
109 }
110
111 match self.mcp.list_tools().await {
112 Ok(tools) => {
113 let mut mode = self.meta_tool_mode.write().await;
114 *mode = Some(has_meta_gateway_tools(&tools));
115 let mut state = self.state.write().await;
116 *state = ClientState::Ready;
117 Ok(())
118 }
119 Err(err) => {
120 let mut state = self.state.write().await;
121 *state = if is_auth_error(&err) {
122 ClientState::NeedsAuth
123 } else {
124 ClientState::Failed
125 };
126 Err(err)
127 }
128 }
129 }
130
131 pub async fn disconnect(&self) {
132 self.mcp.clear_cached_session().await;
133 let mut mode = self.meta_tool_mode.write().await;
134 *mode = None;
135 let mut state = self.state.write().await;
136 *state = ClientState::Idle;
137 }
138
139 pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
140 let session = self.mcp.authenticate_mcp().await?;
141 build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
142 }
143
144 pub async fn sign_in(&self) -> Result<(), KontextDevError> {
145 self.connect().await
146 }
147
148 pub async fn sign_out(&self) -> Result<(), KontextDevError> {
149 self.mcp.client().clear_token_cache()?;
150 self.disconnect().await;
151 Ok(())
152 }
153
154 pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
155 self.ensure_connected().await?;
156
157 if self.is_meta_tool_mode().await? {
158 let payload = self.fetch_gateway_tools(Some(100)).await?;
159 return Ok(parse_integration_status(&payload));
160 }
161
162 let records = self.mcp.list_integrations().await?;
163 Ok(records
164 .into_iter()
165 .map(|record| IntegrationInfo {
166 id: record.id,
167 name: record.name,
168 connected: record
169 .connection
170 .as_ref()
171 .map(|c| c.connected)
172 .unwrap_or(false),
173 connect_url: None,
174 reason: None,
175 })
176 .collect())
177 }
178
179 pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
180 self.ensure_connected().await?;
181
182 let mcp_tools = self.mcp.list_tools().await?;
183 let non_meta = mcp_tools
184 .iter()
185 .filter(|tool| !is_gateway_meta_tool(tool.name.as_str()))
186 .cloned()
187 .collect::<Vec<_>>();
188
189 if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
190 let mut mode = self.meta_tool_mode.write().await;
191 *mode = Some(false);
192 return Ok(non_meta);
193 }
194
195 let mut mode = self.meta_tool_mode.write().await;
196 *mode = Some(true);
197 drop(mode);
198
199 let payload = self.fetch_gateway_tools(Some(100)).await?;
200 let mut tools = payload.tools;
201 append_request_capability_tool(&mut tools, mcp_tools.as_slice());
202 Ok(tools)
203 }
204
205 pub async fn tools_execute(
206 &self,
207 tool_id: &str,
208 args: Option<serde_json::Map<String, serde_json::Value>>,
209 ) -> Result<ToolResult, KontextDevError> {
210 self.ensure_connected().await?;
211
212 let raw = if self.is_meta_tool_mode().await? {
213 if tool_id == META_REQUEST_CAPABILITY {
214 self.mcp.call_tool(META_REQUEST_CAPABILITY, args).await?
215 } else {
216 let mut execute_args = Map::new();
217 execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
218 execute_args.insert(
219 "tool_arguments".to_string(),
220 Value::Object(args.unwrap_or_default()),
221 );
222 self.mcp
223 .call_tool(META_EXECUTE_TOOL, Some(execute_args))
224 .await?
225 }
226 } else {
227 self.mcp.call_tool(tool_id, args).await?
228 };
229
230 Ok(ToolResult {
231 content: extract_text_content(&raw),
232 raw,
233 })
234 }
235
236 pub async fn prompt_guidance(&self) -> Result<KontextPromptGuidance, KontextDevError> {
237 let tools = self.tools_list().await?;
238 let integrations = self.integrations_list().await?;
239 let tool_names = tools
240 .into_iter()
241 .map(|tool| tool.name)
242 .collect::<Vec<String>>();
243
244 Ok(build_kontext_prompt_guidance(
245 tool_names.as_slice(),
246 integrations.as_slice(),
247 ))
248 }
249
250 async fn ensure_connected(&self) -> Result<(), KontextDevError> {
251 let state = self.state().await;
252 if state == ClientState::Ready {
253 return Ok(());
254 }
255 self.connect().await
256 }
257
258 async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
259 if let Some(mode) = *self.meta_tool_mode.read().await {
260 return Ok(mode);
261 }
262
263 let tools = self.mcp.list_tools().await?;
264 let mode = has_meta_gateway_tools(&tools);
265 let mut lock = self.meta_tool_mode.write().await;
266 *lock = Some(mode);
267 Ok(mode)
268 }
269
270 async fn fetch_gateway_tools(
271 &self,
272 limit: Option<u32>,
273 ) -> Result<GatewayToolsPayload, KontextDevError> {
274 let result = self
275 .mcp
276 .call_tool(
277 META_SEARCH_TOOLS,
278 Some({
279 let mut args = Map::new();
280 if let Some(limit) = limit {
281 args.insert("limit".to_string(), json!(limit));
282 }
283 args
284 }),
285 )
286 .await?;
287 parse_gateway_tools_payload(&result)
288 }
289}
290
291fn is_gateway_meta_tool(tool_name: &str) -> bool {
292 matches!(
293 tool_name,
294 META_SEARCH_TOOLS | META_EXECUTE_TOOL | META_REQUEST_CAPABILITY
295 )
296}
297
298fn append_request_capability_tool(tools: &mut Vec<KontextTool>, mcp_tools: &[KontextTool]) {
299 if tools
300 .iter()
301 .any(|tool| tool.name == META_REQUEST_CAPABILITY)
302 {
303 return;
304 }
305 let Some(capability_tool) = mcp_tools
306 .iter()
307 .find(|tool| tool.name == META_REQUEST_CAPABILITY)
308 else {
309 return;
310 };
311
312 tools.push(capability_tool.clone());
313}
314
/// Convenience constructor; equivalent to [`KontextClient::new`].
pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
    KontextClient::new(config)
}
318
319async fn build_connect_session_result(
320 client: &crate::KontextDevClient,
321 gateway_access_token: &str,
322) -> Result<ConnectSessionResult, KontextDevError> {
323 let connect_session = client.create_connect_session(gateway_access_token).await?;
324 let connect_url = client.integration_connect_url(&connect_session.session_id)?;
325
326 Ok(ConnectSessionResult {
327 connect_url,
328 session_id: connect_session.session_id,
329 expires_at: connect_session.expires_at,
330 })
331}
332
333fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
334 let mut seen = std::collections::HashSet::<String>::new();
335 let mut out = Vec::new();
336
337 for tool in &payload.tools {
338 let Some(server) = tool.server.as_ref() else {
339 continue;
340 };
341 if !seen.insert(server.id.clone()) {
342 continue;
343 }
344 out.push(IntegrationInfo {
345 id: server.id.clone(),
346 name: server.name.clone().unwrap_or_else(|| server.id.clone()),
347 connected: true,
348 connect_url: None,
349 reason: None,
350 });
351 }
352
353 for GatewayToolError {
354 server_id,
355 server_name,
356 reason,
357 } in &payload.errors
358 {
359 if !seen.insert(server_id.clone()) {
360 continue;
361 }
362 let connect_url = payload.elicitations.iter().find_map(
363 |GatewayElicitation {
364 url,
365 integration_id,
366 ..
367 }| {
368 if integration_id.as_deref() == Some(server_id.as_str()) {
369 Some(url.clone())
370 } else {
371 None
372 }
373 },
374 );
375 out.push(IntegrationInfo {
376 id: server_id.clone(),
377 name: server_name.clone().unwrap_or_else(|| server_id.clone()),
378 connected: false,
379 connect_url,
380 reason: reason.clone(),
381 });
382 }
383
384 out
385}
386
/// Classifies `err` as an authentication failure.
///
/// Matches OAuth callback timeout, cancellation, or error; a missing
/// authorization code; an invalid OAuth state; and token request/exchange
/// errors. `connect` uses this to choose between [`ClientState::NeedsAuth`] and
/// [`ClientState::Failed`]. Any variant not listed here counts as non-auth.
fn is_auth_error(err: &KontextDevError) -> bool {
    matches!(
        err,
        KontextDevError::OAuthCallbackTimeout { .. }
            | KontextDevError::OAuthCallbackCancelled
            | KontextDevError::MissingAuthorizationCode
            | KontextDevError::OAuthCallbackError { .. }
            | KontextDevError::InvalidOAuthState
            | KontextDevError::TokenRequest { .. }
            | KontextDevError::TokenExchange { .. }
    )
}
399
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::Mock;
    use wiremock::MockServer;
    use wiremock::ResponseTemplate;
    use wiremock::matchers::body_partial_json;
    use wiremock::matchers::header;
    use wiremock::matchers::method;
    use wiremock::matchers::path;

    /// `build_connect_session_result` must create exactly one connect session
    /// (the mock is registered with `.expect(1)`) and build the connect URL
    /// from that same session's id.
    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        let server = MockServer::start().await;
        let session_id = "session-123";
        let expires_at = "2030-01-01T00:00:00Z";
        let access_token = "test-gateway-token";

        // The connect-session endpoint only matches when the gateway token is
        // sent as a bearer token.
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {access_token}")))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "sessionId": session_id,
                "expiresAt": expires_at
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = crate::KontextDevClient::new(crate::KontextDevConfig {
            server: server.uri(),
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        });

        let result = build_connect_session_result(&client, access_token)
            .await
            .expect("connect session result should be built");

        assert_eq!(result.session_id, session_id);
        assert_eq!(result.expires_at, expires_at);
        // The URL is built from `integration_ui_url` and the returned session id.
        assert_eq!(
            result.connect_url,
            format!("https://app.kontext.dev/oauth/connect?session={session_id}")
        );
    }

    /// After `disconnect`, the next MCP call must re-run the handshake: every
    /// mock below expects exactly two hits, one for each `tools/list` call.
    #[tokio::test]
    async fn disconnect_clears_cached_mcp_session_state() {
        let server = MockServer::start().await;
        let access_token = "test-gateway-token";

        // MCP `initialize`; hands back the session id header.
        Mock::given(method("POST"))
            .and(path("/mcp"))
            .and(body_partial_json(serde_json::json!({
                "method": "initialize"
            })))
            .respond_with(
                ResponseTemplate::new(200)
                    .append_header("Mcp-Session-Id", "session-123")
                    .set_body_json(serde_json::json!({
                        "jsonrpc": "2.0",
                        "id": "initialize",
                        "result": {
                            "sessionId": "session-123"
                        }
                    })),
            )
            .expect(2)
            .mount(&server)
            .await;

        // Follow-up `notifications/initialized` sent after `initialize`.
        Mock::given(method("POST"))
            .and(path("/mcp"))
            .and(body_partial_json(serde_json::json!({
                "method": "notifications/initialized"
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "jsonrpc": "2.0",
                "result": {}
            })))
            .expect(2)
            .mount(&server)
            .await;

        // The actual `tools/list` request; an empty tool list is enough here.
        Mock::given(method("POST"))
            .and(path("/mcp"))
            .and(body_partial_json(serde_json::json!({
                "method": "tools/list"
            })))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "jsonrpc": "2.0",
                "id": "list-tools",
                "result": {
                    "tools": []
                }
            })))
            .expect(2)
            .mount(&server)
            .await;

        let client = KontextClient::new(KontextClientConfig {
            client_session_id: "client-session".to_string(),
            client_id: "client-id".to_string(),
            redirect_uri: "http://localhost:3333/callback".to_string(),
            url: Some(format!("{}/mcp", server.uri())),
            server_url: Some(server.uri()),
            client_secret: None,
            scope: None,
            resource: None,
            integration_ui_url: None,
            integration_return_to: None,
            auth_timeout_seconds: None,
            token_cache_path: None,
        });

        // An explicit access token bypasses the OAuth flow.
        client
            .mcp()
            .list_tools_with_access_token(access_token)
            .await
            .expect("first tools/list should initialize and succeed");
        client.disconnect().await;
        client
            .mcp()
            .list_tools_with_access_token(access_token)
            .await
            .expect("second tools/list should re-initialize after disconnect");
    }
}