1use std::sync::Arc;
2
3use serde_json::{Map, Value, json};
4use tokio::sync::RwLock;
5
6use crate::KontextDevError;
7use crate::mcp::{
8 GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
9 KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
10};
11use crate::prompt_guidance::{KontextPromptGuidance, build_kontext_prompt_guidance};
12
/// Gateway meta tool that returns the tools available behind the gateway
/// (called by `fetch_gateway_tools`, optionally with a `limit` argument).
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Gateway meta tool that runs another tool on the caller's behalf; it receives the target
/// tool's id as `tool_id` and its arguments as `tool_arguments`.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// Gateway meta tool invoked directly (not wrapped in `EXECUTE_TOOL`) in meta-tool mode;
/// `tools_list` re-exposes it alongside the gateway's searched tools.
const META_REQUEST_CAPABILITY: &str = "REQUEST_CAPABILITY";
16
/// Connection lifecycle of a [`KontextClient`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// Initial state, and the state restored by `disconnect`.
    Idle,
    /// `connect` is waiting for the gateway's tool list.
    Connecting,
    /// The most recent `connect` succeeded.
    Ready,
    /// The most recent `connect` failed with an OAuth/token error (see `is_auth_error`).
    NeedsAuth,
    /// The most recent `connect` failed for any other reason.
    Failed,
}
25
/// Settings consumed by [`KontextClient::new`].
///
/// Every field is forwarded unchanged to the underlying MCP client configuration; the only
/// rename is `server_url`, which is passed on as `server`.
#[derive(Debug, Clone)]
pub struct KontextClientConfig {
    /// Forwarded as the MCP config's `client_session_id`.
    pub client_session_id: String,
    /// Forwarded as the MCP config's `client_id`.
    pub client_id: String,
    /// Forwarded as the MCP config's `redirect_uri`.
    pub redirect_uri: String,
    /// Forwarded as the MCP config's `url`.
    pub url: Option<String>,
    /// Forwarded as the MCP config's `server`.
    pub server_url: Option<String>,
    /// Forwarded as the MCP config's `client_secret`.
    pub client_secret: Option<String>,
    /// Forwarded as the MCP config's `scope`.
    pub scope: Option<String>,
    /// Forwarded as the MCP config's `resource`.
    pub resource: Option<String>,
    /// Forwarded as the MCP config's `integration_ui_url`.
    pub integration_ui_url: Option<String>,
    /// Forwarded as the MCP config's `integration_return_to`.
    pub integration_return_to: Option<String>,
    /// Forwarded as the MCP config's `auth_timeout_seconds`.
    pub auth_timeout_seconds: Option<i64>,
    /// Forwarded as the MCP config's `token_cache_path`.
    pub token_cache_path: Option<String>,
}
41
/// Connection status of one integration, as returned by `KontextClient::integrations_list`.
#[derive(Debug, Clone)]
pub struct IntegrationInfo {
    /// Integration (server) identifier.
    pub id: String,
    /// Display name; falls back to `id` when the gateway supplies no name.
    pub name: String,
    /// Whether the integration is currently usable.
    pub connected: bool,
    /// URL to connect the integration, when the gateway offered one via an elicitation.
    pub connect_url: Option<String>,
    /// Gateway-reported reason the integration is unavailable, if any.
    pub reason: Option<String>,
}
50
51#[derive(Clone, Debug)]
52pub struct ToolResult {
53 pub content: String,
54 pub raw: serde_json::Value,
55}
56
/// A freshly created connect session plus the URL the user should open for it.
#[derive(Debug, Clone)]
pub struct ConnectSessionResult {
    /// Integration connect page URL derived from `session_id`.
    pub connect_url: String,
    /// Identifier of the connect session created on the server.
    pub session_id: String,
    /// Expiry timestamp exactly as returned by the server (kept as an unparsed string).
    pub expires_at: String,
}
63
64#[derive(Clone)]
65pub struct KontextClient {
66 state: Arc<RwLock<ClientState>>,
67 mcp: KontextMcp,
68 meta_tool_mode: Arc<RwLock<Option<bool>>>,
69}
70
71impl KontextClient {
72 pub fn new(config: KontextClientConfig) -> Self {
73 let mcp = KontextMcp::new(KontextMcpConfig {
74 client_session_id: config.client_session_id,
75 client_id: config.client_id,
76 redirect_uri: config.redirect_uri,
77 url: config.url,
78 server: config.server_url,
79 client_secret: config.client_secret,
80 scope: config.scope,
81 resource: config.resource,
82 session_key: None,
83 integration_ui_url: config.integration_ui_url,
84 integration_return_to: config.integration_return_to,
85 auth_timeout_seconds: config.auth_timeout_seconds,
86 open_connect_page_on_login: Some(true),
87 token_cache_path: config.token_cache_path,
88 });
89
90 Self {
91 state: Arc::new(RwLock::new(ClientState::Idle)),
92 mcp,
93 meta_tool_mode: Arc::new(RwLock::new(None)),
94 }
95 }
96
97 pub async fn state(&self) -> ClientState {
98 *self.state.read().await
99 }
100
101 pub fn mcp(&self) -> &KontextMcp {
102 &self.mcp
103 }
104
105 pub async fn connect(&self) -> Result<(), KontextDevError> {
106 {
107 let mut state = self.state.write().await;
108 *state = ClientState::Connecting;
109 }
110
111 match self.mcp.list_tools().await {
112 Ok(tools) => {
113 let mut mode = self.meta_tool_mode.write().await;
114 *mode = Some(has_meta_gateway_tools(&tools));
115 let mut state = self.state.write().await;
116 *state = ClientState::Ready;
117 Ok(())
118 }
119 Err(err) => {
120 let mut state = self.state.write().await;
121 *state = if is_auth_error(&err) {
122 ClientState::NeedsAuth
123 } else {
124 ClientState::Failed
125 };
126 Err(err)
127 }
128 }
129 }
130
131 pub async fn disconnect(&self) {
132 let mut mode = self.meta_tool_mode.write().await;
133 *mode = None;
134 let mut state = self.state.write().await;
135 *state = ClientState::Idle;
136 }
137
138 pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
139 let session = self.mcp.authenticate_mcp().await?;
140 build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
141 }
142
143 pub async fn sign_in(&self) -> Result<(), KontextDevError> {
144 self.connect().await
145 }
146
147 pub async fn sign_out(&self) -> Result<(), KontextDevError> {
148 self.mcp.client().clear_token_cache()?;
149 self.disconnect().await;
150 Ok(())
151 }
152
153 pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
154 self.ensure_connected().await?;
155
156 if self.is_meta_tool_mode().await? {
157 let payload = self.fetch_gateway_tools(Some(100)).await?;
158 return Ok(parse_integration_status(&payload));
159 }
160
161 let records = self.mcp.list_integrations().await?;
162 Ok(records
163 .into_iter()
164 .map(|record| IntegrationInfo {
165 id: record.id,
166 name: record.name,
167 connected: record
168 .connection
169 .as_ref()
170 .map(|c| c.connected)
171 .unwrap_or(false),
172 connect_url: None,
173 reason: None,
174 })
175 .collect())
176 }
177
178 pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
179 self.ensure_connected().await?;
180
181 let mcp_tools = self.mcp.list_tools().await?;
182 let non_meta = mcp_tools
183 .iter()
184 .filter(|tool| !is_gateway_meta_tool(tool.name.as_str()))
185 .cloned()
186 .collect::<Vec<_>>();
187
188 if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
189 let mut mode = self.meta_tool_mode.write().await;
190 *mode = Some(false);
191 return Ok(non_meta);
192 }
193
194 let mut mode = self.meta_tool_mode.write().await;
195 *mode = Some(true);
196 drop(mode);
197
198 let payload = self.fetch_gateway_tools(Some(100)).await?;
199 let mut tools = payload.tools;
200 append_request_capability_tool(&mut tools, mcp_tools.as_slice());
201 Ok(tools)
202 }
203
204 pub async fn tools_execute(
205 &self,
206 tool_id: &str,
207 args: Option<serde_json::Map<String, serde_json::Value>>,
208 ) -> Result<ToolResult, KontextDevError> {
209 self.ensure_connected().await?;
210
211 let raw = if self.is_meta_tool_mode().await? {
212 if tool_id == META_REQUEST_CAPABILITY {
213 self.mcp.call_tool(META_REQUEST_CAPABILITY, args).await?
214 } else {
215 let mut execute_args = Map::new();
216 execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
217 execute_args.insert(
218 "tool_arguments".to_string(),
219 Value::Object(args.unwrap_or_default()),
220 );
221 self.mcp
222 .call_tool(META_EXECUTE_TOOL, Some(execute_args))
223 .await?
224 }
225 } else {
226 self.mcp.call_tool(tool_id, args).await?
227 };
228
229 Ok(ToolResult {
230 content: extract_text_content(&raw),
231 raw,
232 })
233 }
234
235 pub async fn prompt_guidance(&self) -> Result<KontextPromptGuidance, KontextDevError> {
236 let tools = self.tools_list().await?;
237 let integrations = self.integrations_list().await?;
238 let tool_names = tools
239 .into_iter()
240 .map(|tool| tool.name)
241 .collect::<Vec<String>>();
242
243 Ok(build_kontext_prompt_guidance(
244 tool_names.as_slice(),
245 integrations.as_slice(),
246 ))
247 }
248
249 async fn ensure_connected(&self) -> Result<(), KontextDevError> {
250 let state = self.state().await;
251 if state == ClientState::Ready {
252 return Ok(());
253 }
254 self.connect().await
255 }
256
257 async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
258 if let Some(mode) = *self.meta_tool_mode.read().await {
259 return Ok(mode);
260 }
261
262 let tools = self.mcp.list_tools().await?;
263 let mode = has_meta_gateway_tools(&tools);
264 let mut lock = self.meta_tool_mode.write().await;
265 *lock = Some(mode);
266 Ok(mode)
267 }
268
269 async fn fetch_gateway_tools(
270 &self,
271 limit: Option<u32>,
272 ) -> Result<GatewayToolsPayload, KontextDevError> {
273 let result = self
274 .mcp
275 .call_tool(
276 META_SEARCH_TOOLS,
277 Some({
278 let mut args = Map::new();
279 if let Some(limit) = limit {
280 args.insert("limit".to_string(), json!(limit));
281 }
282 args
283 }),
284 )
285 .await?;
286 parse_gateway_tools_payload(&result)
287 }
288}
289
290fn is_gateway_meta_tool(tool_name: &str) -> bool {
291 matches!(
292 tool_name,
293 META_SEARCH_TOOLS | META_EXECUTE_TOOL | META_REQUEST_CAPABILITY
294 )
295}
296
297fn append_request_capability_tool(tools: &mut Vec<KontextTool>, mcp_tools: &[KontextTool]) {
298 if tools
299 .iter()
300 .any(|tool| tool.name == META_REQUEST_CAPABILITY)
301 {
302 return;
303 }
304 let Some(capability_tool) = mcp_tools
305 .iter()
306 .find(|tool| tool.name == META_REQUEST_CAPABILITY)
307 else {
308 return;
309 };
310
311 tools.push(capability_tool.clone());
312}
313
314pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
315 KontextClient::new(config)
316}
317
318async fn build_connect_session_result(
319 client: &crate::KontextDevClient,
320 gateway_access_token: &str,
321) -> Result<ConnectSessionResult, KontextDevError> {
322 let connect_session = client.create_connect_session(gateway_access_token).await?;
323 let connect_url = client.integration_connect_url(&connect_session.session_id)?;
324
325 Ok(ConnectSessionResult {
326 connect_url,
327 session_id: connect_session.session_id,
328 expires_at: connect_session.expires_at,
329 })
330}
331
332fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
333 let mut seen = std::collections::HashSet::<String>::new();
334 let mut out = Vec::new();
335
336 for tool in &payload.tools {
337 let Some(server) = tool.server.as_ref() else {
338 continue;
339 };
340 if !seen.insert(server.id.clone()) {
341 continue;
342 }
343 out.push(IntegrationInfo {
344 id: server.id.clone(),
345 name: server.name.clone().unwrap_or_else(|| server.id.clone()),
346 connected: true,
347 connect_url: None,
348 reason: None,
349 });
350 }
351
352 for GatewayToolError {
353 server_id,
354 server_name,
355 reason,
356 } in &payload.errors
357 {
358 if !seen.insert(server_id.clone()) {
359 continue;
360 }
361 let connect_url = payload.elicitations.iter().find_map(
362 |GatewayElicitation {
363 url,
364 integration_id,
365 ..
366 }| {
367 if integration_id.as_deref() == Some(server_id.as_str()) {
368 Some(url.clone())
369 } else {
370 None
371 }
372 },
373 );
374 out.push(IntegrationInfo {
375 id: server_id.clone(),
376 name: server_name.clone().unwrap_or_else(|| server_id.clone()),
377 connected: false,
378 connect_url,
379 reason: reason.clone(),
380 });
381 }
382
383 out
384}
385
386fn is_auth_error(err: &KontextDevError) -> bool {
387 matches!(
388 err,
389 KontextDevError::OAuthCallbackTimeout { .. }
390 | KontextDevError::OAuthCallbackCancelled
391 | KontextDevError::MissingAuthorizationCode
392 | KontextDevError::OAuthCallbackError { .. }
393 | KontextDevError::InvalidOAuthState
394 | KontextDevError::TokenRequest { .. }
395 | KontextDevError::TokenExchange { .. }
396 )
397}
398
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// `build_connect_session_result` must create exactly one connect session and derive the
    /// connect URL from that same session's id.
    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        let access_token = "test-gateway-token";
        let session_id = "session-123";
        let expires_at = "2030-01-01T00:00:00Z";

        let mock_server = MockServer::start().await;
        let session_body = serde_json::json!({
            "sessionId": session_id,
            "expiresAt": expires_at
        });
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {access_token}")))
            .respond_with(ResponseTemplate::new(200).set_body_json(session_body))
            // More than one session creation would fail the test on drop.
            .expect(1)
            .mount(&mock_server)
            .await;

        let dev_client = crate::KontextDevClient::new(crate::KontextDevConfig {
            server: mock_server.uri(),
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        });

        let outcome = build_connect_session_result(&dev_client, access_token)
            .await
            .expect("connect session result should be built");

        let expected_url = format!("https://app.kontext.dev/oauth/connect?session={session_id}");
        assert_eq!(outcome.session_id, session_id);
        assert_eq!(outcome.expires_at, expires_at);
        assert_eq!(outcome.connect_url, expected_url);
    }
}