1use std::sync::Arc;
2
3use serde_json::{Map, Value, json};
4use tokio::sync::RwLock;
5
6use crate::KontextDevError;
7use crate::mcp::{
8 GatewayElicitation, GatewayToolError, GatewayToolsPayload, KontextMcp, KontextMcpConfig,
9 KontextTool, extract_text_content, has_meta_gateway_tools, parse_gateway_tools_payload,
10};
11
/// Name of the gateway meta tool used to search/list the tools behind the
/// gateway; its payload is also used to derive integration status.
const META_SEARCH_TOOLS: &str = "SEARCH_TOOLS";
/// Name of the gateway meta tool that invokes another tool, given its
/// `tool_id` and `tool_arguments`.
const META_EXECUTE_TOOL: &str = "EXECUTE_TOOL";
14
/// Lifecycle of a `KontextClient` connection.
///
/// A client starts out `Idle`. `connect` moves it to `Connecting`, then to
/// `Ready` on success, or to `NeedsAuth` / `Failed` depending on the error kind.
/// `disconnect` returns it to `Idle`.
///
/// `Default` yields `Idle`, matching the state a freshly built client starts in.
/// `Hash` allows states to be used as set/map keys.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub enum ClientState {
    /// Not connected: the initial state, and the state after `disconnect`.
    #[default]
    Idle,
    /// `connect` is in progress (the initial tool listing has not finished).
    Connecting,
    /// The initial tool listing succeeded.
    Ready,
    /// The last `connect` failed with an OAuth/token-related error.
    NeedsAuth,
    /// The last `connect` failed with any other error.
    Failed,
}
23
/// Settings used to build a `KontextClient`.
///
/// `KontextClient::new` forwards every field to the underlying MCP
/// configuration. `server_url` is forwarded under the name `server`.
#[derive(Debug, Clone)]
pub struct KontextClientConfig {
    /// Client identifier forwarded to the MCP layer.
    pub client_id: String,
    /// Redirect URI forwarded to the MCP layer.
    pub redirect_uri: String,
    /// Optional URL override forwarded to the MCP layer.
    pub url: Option<String>,
    /// Optional server base URL (becomes the MCP config's `server`).
    pub server_url: Option<String>,
    /// Optional client secret forwarded to the MCP layer.
    pub client_secret: Option<String>,
    /// Optional scope forwarded to the MCP layer.
    pub scope: Option<String>,
    /// Optional resource forwarded to the MCP layer.
    pub resource: Option<String>,
    /// Optional base URL of the integration UI.
    pub integration_ui_url: Option<String>,
    /// Optional return-to target for the integration UI.
    pub integration_return_to: Option<String>,
    /// Optional auth timeout (presumably seconds, per the name — confirm in the MCP layer).
    pub auth_timeout_seconds: Option<i64>,
    /// Optional path for the token cache, which `sign_out` clears.
    pub token_cache_path: Option<String>,
}
38
/// Connection status of a single integration, as returned by
/// `KontextClient::integrations_list`.
#[derive(Debug, Clone)]
pub struct IntegrationInfo {
    /// Integration (server) identifier.
    pub id: String,
    /// Display name. When derived from the gateway payload, this falls back to
    /// `id` if no name was reported.
    pub name: String,
    /// Whether the integration is currently connected.
    pub connected: bool,
    /// URL the user can open to connect this integration, when one is known.
    /// This is only ever populated from gateway elicitations.
    pub connect_url: Option<String>,
    /// Reason the integration is unavailable, when the gateway reported one.
    pub reason: Option<String>,
}
47
48#[derive(Clone, Debug)]
49pub struct ToolResult {
50 pub content: String,
51 pub raw: serde_json::Value,
52}
53
/// A freshly created connect session, plus the URL for opening it.
#[derive(Debug, Clone)]
pub struct ConnectSessionResult {
    /// Integration connect page URL built from `session_id`.
    pub connect_url: String,
    /// Identifier of the connect session created on the server.
    pub session_id: String,
    /// Expiry of the session exactly as the server returned it (the tests use an
    /// RFC 3339 timestamp; the format is not validated here).
    pub expires_at: String,
}
60
61#[derive(Clone)]
62pub struct KontextClient {
63 state: Arc<RwLock<ClientState>>,
64 mcp: KontextMcp,
65 meta_tool_mode: Arc<RwLock<Option<bool>>>,
66}
67
68impl KontextClient {
69 pub fn new(config: KontextClientConfig) -> Self {
70 let mcp = KontextMcp::new(KontextMcpConfig {
71 client_id: config.client_id,
72 redirect_uri: config.redirect_uri,
73 url: config.url,
74 server: config.server_url,
75 client_secret: config.client_secret,
76 scope: config.scope,
77 resource: config.resource,
78 session_key: None,
79 integration_ui_url: config.integration_ui_url,
80 integration_return_to: config.integration_return_to,
81 auth_timeout_seconds: config.auth_timeout_seconds,
82 open_connect_page_on_login: Some(true),
83 token_cache_path: config.token_cache_path,
84 });
85
86 Self {
87 state: Arc::new(RwLock::new(ClientState::Idle)),
88 mcp,
89 meta_tool_mode: Arc::new(RwLock::new(None)),
90 }
91 }
92
93 pub async fn state(&self) -> ClientState {
94 *self.state.read().await
95 }
96
97 pub fn mcp(&self) -> &KontextMcp {
98 &self.mcp
99 }
100
101 pub async fn connect(&self) -> Result<(), KontextDevError> {
102 {
103 let mut state = self.state.write().await;
104 *state = ClientState::Connecting;
105 }
106
107 match self.mcp.list_tools().await {
108 Ok(tools) => {
109 let mut mode = self.meta_tool_mode.write().await;
110 *mode = Some(has_meta_gateway_tools(&tools));
111 let mut state = self.state.write().await;
112 *state = ClientState::Ready;
113 Ok(())
114 }
115 Err(err) => {
116 let mut state = self.state.write().await;
117 *state = if is_auth_error(&err) {
118 ClientState::NeedsAuth
119 } else {
120 ClientState::Failed
121 };
122 Err(err)
123 }
124 }
125 }
126
127 pub async fn disconnect(&self) {
128 let mut mode = self.meta_tool_mode.write().await;
129 *mode = None;
130 let mut state = self.state.write().await;
131 *state = ClientState::Idle;
132 }
133
134 pub async fn get_connect_page_url(&self) -> Result<ConnectSessionResult, KontextDevError> {
135 let session = self.mcp.authenticate_mcp().await?;
136 build_connect_session_result(self.mcp.client(), &session.gateway_token.access_token).await
137 }
138
139 pub async fn sign_in(&self) -> Result<(), KontextDevError> {
140 self.connect().await
141 }
142
143 pub async fn sign_out(&self) -> Result<(), KontextDevError> {
144 self.mcp.client().clear_token_cache()?;
145 self.disconnect().await;
146 Ok(())
147 }
148
149 pub async fn integrations_list(&self) -> Result<Vec<IntegrationInfo>, KontextDevError> {
150 self.ensure_connected().await?;
151
152 if self.is_meta_tool_mode().await? {
153 let payload = self.fetch_gateway_tools(Some(100)).await?;
154 return Ok(parse_integration_status(&payload));
155 }
156
157 let records = self.mcp.list_integrations().await?;
158 Ok(records
159 .into_iter()
160 .map(|record| IntegrationInfo {
161 id: record.id,
162 name: record.name,
163 connected: record
164 .connection
165 .as_ref()
166 .map(|c| c.connected)
167 .unwrap_or(false),
168 connect_url: None,
169 reason: None,
170 })
171 .collect())
172 }
173
174 pub async fn tools_list(&self) -> Result<Vec<KontextTool>, KontextDevError> {
175 self.ensure_connected().await?;
176
177 let mcp_tools = self.mcp.list_tools().await?;
178 let non_meta = mcp_tools
179 .iter()
180 .filter(|tool| tool.name != META_SEARCH_TOOLS && tool.name != META_EXECUTE_TOOL)
181 .cloned()
182 .collect::<Vec<_>>();
183
184 if !non_meta.is_empty() || !has_meta_gateway_tools(&mcp_tools) {
185 let mut mode = self.meta_tool_mode.write().await;
186 *mode = Some(false);
187 return Ok(non_meta);
188 }
189
190 let mut mode = self.meta_tool_mode.write().await;
191 *mode = Some(true);
192 drop(mode);
193
194 let payload = self.fetch_gateway_tools(Some(100)).await?;
195 Ok(payload.tools)
196 }
197
198 pub async fn tools_execute(
199 &self,
200 tool_id: &str,
201 args: Option<serde_json::Map<String, serde_json::Value>>,
202 ) -> Result<ToolResult, KontextDevError> {
203 self.ensure_connected().await?;
204
205 let raw = if self.is_meta_tool_mode().await? {
206 let mut execute_args = Map::new();
207 execute_args.insert("tool_id".to_string(), Value::String(tool_id.to_string()));
208 execute_args.insert(
209 "tool_arguments".to_string(),
210 Value::Object(args.unwrap_or_default()),
211 );
212 self.mcp
213 .call_tool(META_EXECUTE_TOOL, Some(execute_args))
214 .await?
215 } else {
216 self.mcp.call_tool(tool_id, args).await?
217 };
218
219 Ok(ToolResult {
220 content: extract_text_content(&raw),
221 raw,
222 })
223 }
224
225 async fn ensure_connected(&self) -> Result<(), KontextDevError> {
226 let state = self.state().await;
227 if state == ClientState::Ready {
228 return Ok(());
229 }
230 self.connect().await
231 }
232
233 async fn is_meta_tool_mode(&self) -> Result<bool, KontextDevError> {
234 if let Some(mode) = *self.meta_tool_mode.read().await {
235 return Ok(mode);
236 }
237
238 let tools = self.mcp.list_tools().await?;
239 let mode = has_meta_gateway_tools(&tools);
240 let mut lock = self.meta_tool_mode.write().await;
241 *lock = Some(mode);
242 Ok(mode)
243 }
244
245 async fn fetch_gateway_tools(
246 &self,
247 limit: Option<u32>,
248 ) -> Result<GatewayToolsPayload, KontextDevError> {
249 let result = self
250 .mcp
251 .call_tool(
252 META_SEARCH_TOOLS,
253 Some({
254 let mut args = Map::new();
255 if let Some(limit) = limit {
256 args.insert("limit".to_string(), json!(limit));
257 }
258 args
259 }),
260 )
261 .await?;
262 parse_gateway_tools_payload(&result)
263 }
264}
265
266pub fn create_kontext_client(config: KontextClientConfig) -> KontextClient {
267 KontextClient::new(config)
268}
269
270async fn build_connect_session_result(
271 client: &crate::KontextDevClient,
272 gateway_access_token: &str,
273) -> Result<ConnectSessionResult, KontextDevError> {
274 let connect_session = client.create_connect_session(gateway_access_token).await?;
275 let connect_url = client.integration_connect_url(&connect_session.session_id)?;
276
277 Ok(ConnectSessionResult {
278 connect_url,
279 session_id: connect_session.session_id,
280 expires_at: connect_session.expires_at,
281 })
282}
283
284fn parse_integration_status(payload: &GatewayToolsPayload) -> Vec<IntegrationInfo> {
285 let mut seen = std::collections::HashSet::<String>::new();
286 let mut out = Vec::new();
287
288 for tool in &payload.tools {
289 let Some(server) = tool.server.as_ref() else {
290 continue;
291 };
292 if !seen.insert(server.id.clone()) {
293 continue;
294 }
295 out.push(IntegrationInfo {
296 id: server.id.clone(),
297 name: server.name.clone().unwrap_or_else(|| server.id.clone()),
298 connected: true,
299 connect_url: None,
300 reason: None,
301 });
302 }
303
304 for GatewayToolError {
305 server_id,
306 server_name,
307 reason,
308 } in &payload.errors
309 {
310 if !seen.insert(server_id.clone()) {
311 continue;
312 }
313 let connect_url = payload.elicitations.iter().find_map(
314 |GatewayElicitation {
315 url,
316 integration_id,
317 ..
318 }| {
319 if integration_id.as_deref() == Some(server_id.as_str()) {
320 Some(url.clone())
321 } else {
322 None
323 }
324 },
325 );
326 out.push(IntegrationInfo {
327 id: server_id.clone(),
328 name: server_name.clone().unwrap_or_else(|| server_id.clone()),
329 connected: false,
330 connect_url,
331 reason: reason.clone(),
332 });
333 }
334
335 out
336}
337
338fn is_auth_error(err: &KontextDevError) -> bool {
339 matches!(
340 err,
341 KontextDevError::OAuthCallbackTimeout { .. }
342 | KontextDevError::OAuthCallbackCancelled
343 | KontextDevError::MissingAuthorizationCode
344 | KontextDevError::OAuthCallbackError { .. }
345 | KontextDevError::InvalidOAuthState
346 | KontextDevError::TokenRequest { .. }
347 | KontextDevError::TokenExchange { .. }
348 )
349}
350
#[cfg(test)]
mod tests {
    use super::*;
    use wiremock::matchers::{header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a `KontextDevClient` that talks to `server_uri` and whose
    /// integration UI lives at `https://app.kontext.dev`.
    fn dev_client_for(server_uri: String) -> crate::KontextDevClient {
        crate::KontextDevClient::new(crate::KontextDevConfig {
            server: server_uri,
            client_id: "client-id".to_string(),
            client_secret: None,
            scope: "".to_string(),
            server_name: "kontext-dev".to_string(),
            resource: "mcp-gateway".to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: 300,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        })
    }

    /// The connect URL must reference the one session created (exactly one POST).
    #[tokio::test]
    async fn build_connect_session_result_uses_one_session_and_matching_connect_url() {
        let (session_id, expires_at, access_token) =
            ("session-123", "2030-01-01T00:00:00Z", "test-gateway-token");

        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/mcp/connect-session"))
            .and(header("authorization", format!("Bearer {access_token}")))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "sessionId": session_id,
                "expiresAt": expires_at
            })))
            .expect(1)
            .mount(&mock_server)
            .await;

        let dev_client = dev_client_for(mock_server.uri());
        let result = build_connect_session_result(&dev_client, access_token)
            .await
            .expect("connect session result should be built");

        let expected_url = format!("https://app.kontext.dev/oauth/connect?session={session_id}");
        assert_eq!(result.session_id, session_id);
        assert_eq!(result.expires_at, expires_at);
        assert_eq!(result.connect_url, expected_url);
    }
}