1use std::sync::Arc;
2
3use serde_json::{Map, Value, json};
4use tokio::sync::RwLock;
5
6use crate::KontextDevError;
7use crate::mcp::{
8 GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
9 KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
10};
11
/// Name of the gateway meta tool queried to fetch the tools available behind the
/// gateway; it is called with an optional `limit` argument and its result is parsed
/// as a `GatewayToolsPayload`.
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Name of the gateway meta tool that runs another tool, given its `tool_id` and
/// `tool_arguments`.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
/// Name of the gateway meta tool that is called directly (never wrapped in
/// `EXECUTE_TOOL`) and is surfaced to callers in the meta-mode tool list.
const META_REQUEST_CAPABILITY: &str = "REQUEST_CAPABILITY";
15
/// Connection lifecycle of a [`KontextClient`].
///
/// The client starts `Idle`, moves to `Connecting` while the tool list is being
/// fetched, and ends up `Ready`, `NeedsAuth` (OAuth/token failure) or `Failed`
/// (any other failure). `disconnect` returns it to `Idle`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientState {
    /// Not connected; nothing has been attempted since creation or `disconnect`.
    Idle,
    /// A connection attempt is in progress.
    Connecting,
    /// The last connection attempt succeeded.
    Ready,
    /// The last connection attempt failed with an authentication-related error.
    NeedsAuth,
    /// The last connection attempt failed for a non-authentication reason.
    Failed,
}
24
/// User-facing configuration for [`KontextClient::new`].
///
/// Every field is forwarded to the underlying `KontextMcpConfig`; `server_url` is
/// forwarded as its `server` field.
#[derive(Debug, Clone)]
pub struct KontextClientConfig {
    /// Client identifier used for authentication.
    pub client_id: String,
    /// Redirect URI used for the auth callback.
    pub redirect_uri: String,
    /// Optional URL forwarded as `KontextMcpConfig::url`.
    pub url: Option<String>,
    /// Optional server URL forwarded as `KontextMcpConfig::server`.
    pub server_url: Option<String>,
    /// Optional client secret.
    pub client_secret: Option<String>,
    /// Optional scope string.
    pub scope: Option<String>,
    /// Optional resource identifier.
    pub resource: Option<String>,
    /// Optional base URL of the integration UI.
    pub integration_ui_url: Option<String>,
    /// Optional location to return to after connecting an integration.
    pub integration_return_to: Option<String>,
    /// Optional auth timeout, in seconds (by name; forwarded unchanged).
    pub auth_timeout_seconds: Option<i64>,
    /// Optional path for the on-disk token cache.
    pub token_cache_path: Option<String>,
}
39
/// Connection status of a single integration, as reported by
/// [`KontextClient::integrations_list`].
#[derive(Debug, Clone)]
pub struct IntegrationInfo {
    /// Integration (server) identifier.
    pub id: String,
    /// Display name; falls back to `id` when the server reports no name.
    pub name: String,
    /// Whether the integration is currently connected.
    pub connected: bool,
    /// URL the user can visit to connect the integration, when one is known.
    pub connect_url: Option<String>,
    /// Reason the integration is not connected, when one was reported.
    pub reason: Option<String>,
}
48
49#[derive(Clone, Debug)]
50pub struct ToolResult {
51 pub content: String,
52 pub raw: serde_json::Value,
53}
54
/// A connect-page session, as returned by [`KontextClient::get_connect_page_url`].
#[derive(Debug, Clone)]
pub struct ConnectSessionResult {
    /// Integration connect-page URL built for `session_id`.
    pub connect_url: String,
    /// Identifier of the connect session created on the server.
    pub session_id: String,
    /// Expiry of the session exactly as the server returned it (not parsed here).
    pub expires_at: String,
}
61
/// High-level client for the Kontext MCP gateway.
///
/// Clones share `state` and the cached meta-tool mode (both are behind `Arc`).
/// Whether clones also share the underlying MCP session depends on `KontextMcp`'s
/// `Clone` impl — NOTE(review): not visible here.
#[derive(Clone)]
pub struct KontextClient {
    /// Current connection lifecycle state.
    state: Arc<RwLock<ClientState>>,
    /// Underlying MCP wrapper that performs the actual requests.
    mcp: KontextMcp,
    /// Cached meta-tool detection result: `None` until detected, then whether the
    /// server is driven through the gateway meta tools (`SEARCH_TOOLS`/`EXECUTE_TOOL`).
    meta_tool_mode: Arc<RwLock<Option<bool>>>,
}
68
69impl KontextClient {
70 pub fn new(config: KontextClientConfig) -> Self {
71 let mcp = KontextMcp::new(KontextMcpConfig {
72 client_id: config.client_id,
73 redirect_uri: config.redirect_uri,
74 url: config.url,
75 server: config.server_url,
76 client_secret: config.client_secret,
77 scope: config.scope,
78 resource: config.resource,
79 session_key: None,
80 integration_ui_url: config.integration_ui_url,
81 integration_return_to: config.integration_return_to,
82 auth_timeout_seconds: config.auth_timeout_seconds,
83 open_connect_page_on_login: Some(true),
84 token_cache_path: config.token_cache_path,
85 });
86
87 Self {
88 state: Arc::new(RwLock::new(ClientState::Idle)),
89 mcp,
90 meta_tool_mode: Arc::new(RwLock::new(None)),
91 }
92 }
93
94 pub async fn state(&self) -> ClientState {
95 *self.state.read().await
96 }
97
98 pub fn mcp(&self) -> &KontextMcp {
99 &self.mcp
100 }
101
102 pub async fn connect(&self) -> Result<(), KontextDevError> {
103 {
104 let mut state = self.state.write().await;
105 *state = ClientState::Connecting;
106 }
107
108 match self.mcp.list_tools().await {
109 Ok(tools) => {
110 let mut mode = self.meta_tool_mode.write().await;
111 *mode = Some(has_meta_gateway_tools(&tools));
112 let mut state = self.state.write().await;
113 *state = ClientState::Ready;
114 Ok(())
115 }
116 Err(err) => {
117 let mut state = self.state.write().await;
118 *state = if is_auth_error(&err) {
119 ClientState::NeedsAuth
120 } else {
121 ClientState::Failed
122 };
123 Err(err)
124 }
125 }
126 }
127
128 pub async fn disconnect(&self) {
129 let mut mode = self.meta_tool_mode.write().await;
130 *mode = None;
131 let mut state = self.state.write().await;
132 *state = ClientState::Idle;
133 }
134
135 pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
136 let session = self.mcp.authenticate_mcp().await?;
137 build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
138 }
139
140 pub async fn sign_in(&self) -> Result<(), KontextDevError> {
141 self.connect().await
142 }
143
144 pub async fn sign_out(&self) -> Result<(), KontextDevError> {
145 self.mcp.client().clear_token_cache()?;
146 self.disconnect().await;
147 Ok(())
148 }
149
150 pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
151 self.ensure_connected().await?;
152
153 if self.is_meta_tool_mode().await? {
154 let payload = self.fetch_gateway_tools(Some(100)).await?;
155 return Ok(parse_integration_status(&payload));
156 }
157
158 let records = self.mcp.list_integrations().await?;
159 Ok(records
160 .into_iter()
161 .map(|record| IntegrationInfo {
162 id: record.id,
163 name: record.name,
164 connected: record
165 .connection
166 .as_ref()
167 .map(|c| c.connected)
168 .unwrap_or(false),
169 connect_url: None,
170 reason: None,
171 })
172 .collect())
173 }
174
175 pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
176 self.ensure_connected().await?;
177
178 let mcp_tools = self.mcp.list_tools().await?;
179 let non_meta = mcp_tools
180 .iter()
181 .filter(|tool| !is_gateway_meta_tool(tool.name.as_str()))
182 .cloned()
183 .collect::<Vec<_>>();
184
185 if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
186 let mut mode = self.meta_tool_mode.write().await;
187 *mode = Some(false);
188 return Ok(non_meta);
189 }
190
191 let mut mode = self.meta_tool_mode.write().await;
192 *mode = Some(true);
193 drop(mode);
194
195 let payload = self.fetch_gateway_tools(Some(100)).await?;
196 let mut tools = payload.tools;
197 append_request_capability_tool(&mut tools, mcp_tools.as_slice());
198 Ok(tools)
199 }
200
201 pub async fn tools_execute(
202 &self,
203 tool_id: &str,
204 args: Option<serde_json::Map<String, serde_json::Value>>,
205 ) -> Result<ToolResult, KontextDevError> {
206 self.ensure_connected().await?;
207
208 let raw = if self.is_meta_tool_mode().await? {
209 if tool_id == META_REQUEST_CAPABILITY {
210 self.mcp.call_tool(META_REQUEST_CAPABILITY, args).await?
211 } else {
212 let mut execute_args = Map::new();
213 execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
214 execute_args.insert(
215 "tool_arguments".to_string(),
216 Value::Object(args.unwrap_or_default()),
217 );
218 self.mcp
219 .call_tool(META_EXECUTE_TOOL, Some(execute_args))
220 .await?
221 }
222 } else {
223 self.mcp.call_tool(tool_id, args).await?
224 };
225
226 Ok(ToolResult {
227 content: extract_text_content(&raw),
228 raw,
229 })
230 }
231
232 async fn ensure_connected(&self) -> Result<(), KontextDevError> {
233 let state = self.state().await;
234 if state == ClientState::Ready {
235 return Ok(());
236 }
237 self.connect().await
238 }
239
240 async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
241 if let Some(mode) = *self.meta_tool_mode.read().await {
242 return Ok(mode);
243 }
244
245 let tools = self.mcp.list_tools().await?;
246 let mode = has_meta_gateway_tools(&tools);
247 let mut lock = self.meta_tool_mode.write().await;
248 *lock = Some(mode);
249 Ok(mode)
250 }
251
252 async fn fetch_gateway_tools(
253 &self,
254 limit: Option<u32>,
255 ) -> Result<GatewayToolsPayload, KontextDevError> {
256 let result = self
257 .mcp
258 .call_tool(
259 META_SEARCH_TOOLS,
260 Some({
261 let mut args = Map::new();
262 if let Some(limit) = limit {
263 args.insert("limit".to_string(), json!(limit));
264 }
265 args
266 }),
267 )
268 .await?;
269 parse_gateway_tools_payload(&result)
270 }
271}
272
273fn is_gateway_meta_tool(tool_name: &str) -> bool {
274 matches!(
275 tool_name,
276 META_SEARCH_TOOLS | META_EXECUTE_TOOL | META_REQUEST_CAPABILITY
277 )
278}
279
280fn append_request_capability_tool(tools: &mut Vec<KontextTool>, mcp_tools: &[KontextTool]) {
281 if tools
282 .iter()
283 .any(|tool| tool.name == META_REQUEST_CAPABILITY)
284 {
285 return;
286 }
287 let Some(capability_tool) = mcp_tools
288 .iter()
289 .find(|tool| tool.name == META_REQUEST_CAPABILITY)
290 else {
291 return;
292 };
293
294 tools.push(capability_tool.clone());
295}
296
/// Convenience constructor; equivalent to calling [`KontextClient::new`] with `config`.
pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
    KontextClient::new(config)
}
300
301async fn build_connect_session_result(
302 client: &crate::KontextDevClient,
303 gateway_access_token: &str,
304) -> Result<ConnectSessionResult, KontextDevError> {
305 let connect_session = client.create_connect_session(gateway_access_token).await?;
306 let connect_url = client.integration_connect_url(&connect_session.session_id)?;
307
308 Ok(ConnectSessionResult {
309 connect_url,
310 session_id: connect_session.session_id,
311 expires_at: connect_session.expires_at,
312 })
313}
314
315fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
316 let mut seen = std::collections::HashSet::<String>::new();
317 let mut out = Vec::new();
318
319 for tool in &payload.tools {
320 let Some(server) = tool.server.as_ref() else {
321 continue;
322 };
323 if !seen.insert(server.id.clone()) {
324 continue;
325 }
326 out.push(IntegrationInfo {
327 id: server.id.clone(),
328 name: server.name.clone().unwrap_or_else(|| server.id.clone()),
329 connected: true,
330 connect_url: None,
331 reason: None,
332 });
333 }
334
335 for GatewayToolError {
336 server_id,
337 server_name,
338 reason,
339 } in &payload.errors
340 {
341 if !seen.insert(server_id.clone()) {
342 continue;
343 }
344 let connect_url = payload.elicitations.iter().find_map(
345 |GatewayElicitation {
346 url,
347 integration_id,
348 ..
349 }| {
350 if integration_id.as_deref() == Some(server_id.as_str()) {
351 Some(url.clone())
352 } else {
353 None
354 }
355 },
356 );
357 out.push(IntegrationInfo {
358 id: server_id.clone(),
359 name: server_name.clone().unwrap_or_else(|| server_id.clone()),
360 connected: false,
361 connect_url,
362 reason: reason.clone(),
363 });
364 }
365
366 out
367}
368
369fn is_auth_error(err: &KontextDevError) -> bool {
370 matches!(
371 err,
372 KontextDevError::OAuthCallbackTimeout { .. }
373 | KontextDevError::OAuthCallbackCancelled
374 | KontextDevError::MissingAuthorizationCode
375 | KontextDevError::OAuthCallbackError { .. }
376 | KontextDevError::InvalidOAuthState
377 | KontextDevError::TokenRequest { .. }
378 | KontextDevError::TokenExchange { .. }
379 )
380}
381
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// The helper must create exactly one connect session and build the connect URL
    /// from that same session's id.
    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        let mock_server = MockServer::start().await;
        let session_id = "session-123";
        let expires_at = "2030-01-01T00:00:00Z";
        let access_token = "test-gateway-token";

        let session_body = serde_json::json!({
            "sessionId": session_id,
            "expiresAt": expires_at
        });
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {access_token}")))
            .respond_with(ResponseTemplate::new(200).set_body_json(session_body))
            .expect(1)
            .mount(&mock_server)
            .await;

        let dev_client = crate::KontextDevClient::new(crate::KontextDevConfig {
            server: mock_server.uri(),
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        });

        let outcome = build_connect_session_result(&dev_client, access_token)
            .await
            .expect("connect session result should be built");

        let expected_url = format!("https://app.kontext.dev/oauth/connect?session={session_id}");
        assert_eq!(outcome.session_id, session_id);
        assert_eq!(outcome.expires_at, expires_at);
        assert_eq!(outcome.connect_url, expected_url);
    }
}