1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub mod client;
21pub mod errors;
22mod http_client;
23pub mod management;
24pub mod mcp;
25pub mod oauth;
26pub mod orchestrator;
27pub mod server;
28pub mod verify;
29
30pub use client::ClientState;
31pub use client::ConnectSessionResult;
32pub use client::IntegrationInfo;
33pub use client::KontextClient as SdkKontextClient;
34pub use client::KontextClientConfig;
35pub use client::ToolResult;
36pub use client::create_kontext_client;
37pub use errors::*;
38pub use management::KontextManagementClient;
39pub use management::KontextManagementClientConfig;
40pub use management::*;
41pub use mcp::KontextMcp;
42pub use mcp::KontextMcpConfig;
43pub use mcp::RuntimeIntegrationCategory;
44pub use mcp::RuntimeIntegrationConnectType;
45pub use mcp::RuntimeIntegrationRecord;
46pub use oauth::KontextOAuthProvider;
47pub use oauth::KontextOAuthProviderConfig;
48pub use oauth::ParsedOAuthCallback;
49pub use oauth::TokenExchangeConfig;
50pub use oauth::exchange_token;
51pub use oauth::parse_oauth_callback;
52pub use orchestrator::KontextOrchestrator;
53pub use orchestrator::KontextOrchestratorConfig;
54pub use orchestrator::KontextOrchestratorState;
55pub use orchestrator::create_kontext_orchestrator;
56pub use server::IntegrationCredential;
57pub use server::IntegrationName;
58pub use server::IntegrationResolvedCredentials;
59pub use server::KnownIntegration;
60pub use server::Kontext;
61pub use server::KontextOptions;
62pub use server::MiddlewareOptions;
63pub use verify::JwksClient;
64pub use verify::KontextTokenVerifier;
65pub use verify::KontextTokenVerifierConfig;
66pub use verify::TokenVerificationError;
67pub use verify::TokenVerificationErrorCode;
68pub use verify::VerifiedTokenClaims;
69pub use verify::VerifyResult;
70
71pub use kontext_dev_sdk_core::AccessToken;
72pub use kontext_dev_sdk_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
73pub use kontext_dev_sdk_core::DEFAULT_RESOURCE;
74pub use kontext_dev_sdk_core::DEFAULT_SCOPE;
75pub use kontext_dev_sdk_core::DEFAULT_SERVER;
76pub use kontext_dev_sdk_core::DEFAULT_SERVER_NAME;
77pub use kontext_dev_sdk_core::KontextDevConfig;
78pub use kontext_dev_sdk_core::KontextDevCoreError;
79pub use kontext_dev_sdk_core::TokenExchangeToken;
80pub use kontext_dev_sdk_core::normalize_server_url;
81pub use kontext_dev_sdk_core::resolve_authorize_url;
82pub use kontext_dev_sdk_core::resolve_connect_session_url;
83pub use kontext_dev_sdk_core::resolve_integration_connection_url;
84pub use kontext_dev_sdk_core::resolve_integration_oauth_init_url;
85pub use kontext_dev_sdk_core::resolve_mcp_url;
86pub use kontext_dev_sdk_core::resolve_server_base_url;
87pub use kontext_dev_sdk_core::resolve_token_url;
88
/// OAuth 2.0 token-exchange grant type URN; sent as `grant_type` when a subject
/// token is exchanged for a resource-scoped token.
89const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// Access-token type URN; sent as `subject_token_type` during token exchange and
/// used as the `issued_token_type` of gateway tokens rebuilt from the cache.
90const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
/// Safety margin in seconds: a cached token is only considered valid while the
/// current time plus this buffer is still before its recorded expiry.
91const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
92
93#[derive(Debug, thiserror::Error)]
94pub enum KontextDevError {
95 #[error(transparent)]
96 Core(#[from] kontext_dev_sdk_core::KontextDevCoreError),
97 #[error("failed to parse URL `{url}`")]
98 InvalidUrl {
99 url: String,
100 source: url::ParseError,
101 },
102 #[error("failed to open browser for OAuth authorization")]
103 BrowserOpenFailed,
104 #[error("OAuth callback timed out after {timeout_seconds}s")]
105 OAuthCallbackTimeout { timeout_seconds: i64 },
106 #[error("OAuth callback channel was unexpectedly cancelled")]
107 OAuthCallbackCancelled,
108 #[error("OAuth callback is missing the authorization code")]
109 MissingAuthorizationCode,
110 #[error("OAuth callback returned an error: {error}")]
111 OAuthCallbackError { error: String },
112 #[error("OAuth state mismatch")]
113 InvalidOAuthState,
114 #[error("Kontext-Dev token request failed for {token_url}: {message}")]
115 TokenRequest { token_url: String, message: String },
116 #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
117 TokenExchange { resource: String, message: String },
118 #[error("Kontext-Dev connect session request failed: {message}")]
119 ConnectSession { message: String },
120 #[error("Kontext-Dev integration OAuth init failed: {message}")]
121 IntegrationOAuthInit { message: String },
122 #[error("failed to read token cache at `{path}`: {source}")]
123 TokenCacheRead {
124 path: String,
125 source: std::io::Error,
126 },
127 #[error("failed to write token cache at `{path}`: {source}")]
128 TokenCacheWrite {
129 path: String,
130 source: std::io::Error,
131 },
132 #[error("failed to deserialize token cache at `{path}`: {source}")]
133 TokenCacheDeserialize {
134 path: String,
135 source: serde_json::Error,
136 },
137 #[error("failed to serialize token cache: {source}")]
138 TokenCacheSerialize { source: serde_json::Error },
139 #[error("Kontext-Dev access token is empty")]
140 EmptyAccessToken,
141 #[error("missing integration UI URL; set `integration_ui_url` in config")]
142 MissingIntegrationUiUrl,
143}
144
145#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
146pub struct ConnectSession {
147 #[serde(rename = "sessionId")]
148 pub session_id: String,
149 #[serde(rename = "expiresAt")]
150 pub expires_at: String,
151}
152
153#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
154pub struct IntegrationOAuthInitResponse {
155 #[serde(rename = "authorizationUrl")]
156 pub authorization_url: String,
157 #[serde(default)]
158 pub state: Option<String>,
159}
160
161#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
162pub struct IntegrationConnectionStatus {
163 pub connected: bool,
164 #[serde(default)]
165 pub expires_at: Option<String>,
166}
167
168#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
169pub struct KontextAuthSession {
170 pub identity_token: AccessToken,
171 pub gateway_token: TokenExchangeToken,
172 pub browser_auth_performed: bool,
173}
174
175#[derive(Clone, Debug, Deserialize, Serialize)]
176struct CachedAccessToken {
177 access_token: String,
178 token_type: String,
179 refresh_token: Option<String>,
180 scope: Option<String>,
181 expires_at_unix_ms: Option<u64>,
182}
183
184#[derive(Clone, Debug, Deserialize, Serialize)]
185struct TokenCacheFile {
186 client_id: String,
187 resource: String,
188 identity: CachedAccessToken,
189 gateway: CachedAccessToken,
190}
191
192impl CachedAccessToken {
193 fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
194 if token.access_token.is_empty() {
195 return Err(KontextDevError::EmptyAccessToken);
196 }
197
198 Ok(Self {
199 access_token: token.access_token.clone(),
200 token_type: token.token_type.clone(),
201 refresh_token: token.refresh_token.clone(),
202 scope: token.scope.clone(),
203 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
204 })
205 }
206
207 fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
208 if token.access_token.is_empty() {
209 return Err(KontextDevError::EmptyAccessToken);
210 }
211
212 Ok(Self {
213 access_token: token.access_token.clone(),
214 token_type: token.token_type.clone(),
215 refresh_token: token.refresh_token.clone(),
216 scope: token.scope.clone(),
217 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
218 })
219 }
220
221 fn is_valid(&self) -> bool {
222 match self.expires_at_unix_ms {
223 Some(expires_at) => {
224 let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
225 now_unix_ms().saturating_add(buffer_ms) < expires_at
226 }
227 None => true,
228 }
229 }
230
231 fn to_access_token(&self) -> AccessToken {
232 AccessToken {
233 access_token: self.access_token.clone(),
234 token_type: self.token_type.clone(),
235 expires_in: self
236 .expires_at_unix_ms
237 .and_then(unix_ms_to_relative_seconds),
238 refresh_token: self.refresh_token.clone(),
239 scope: self.scope.clone(),
240 }
241 }
242
243 fn to_token_exchange_token(&self) -> TokenExchangeToken {
244 TokenExchangeToken {
245 access_token: self.access_token.clone(),
246 issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
247 token_type: self.token_type.clone(),
248 expires_in: self
249 .expires_at_unix_ms
250 .and_then(unix_ms_to_relative_seconds),
251 scope: self.scope.clone(),
252 refresh_token: self.refresh_token.clone(),
253 }
254 }
255}
256
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reports a time before the epoch.
fn now_unix_ms() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::from_secs(0),
    };

    since_epoch.as_millis() as u64
}
263
264fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
265 if unix_ms <= now_unix_ms() {
266 return Some(0);
267 }
268
269 let delta_ms = unix_ms - now_unix_ms();
270 let secs = delta_ms / 1000;
271 i64::try_from(secs).ok()
272}
273
274fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
275 if let Some(expires_in) = expires_in
276 && expires_in > 0
277 {
278 let expires_in_u64 = u64::try_from(expires_in).ok()?;
279 return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
280 }
281
282 decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
283}
284
285fn decode_jwt_exp(token: &str) -> Option<u64> {
286 let mut parts = token.split('.');
287 let _header = parts.next()?;
288 let payload = parts.next()?;
289
290 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
291 .decode(payload)
292 .ok()?;
293 let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
294 value.get("exp").and_then(|exp| {
295 exp.as_u64()
296 .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
297 })
298}
299
300#[derive(Clone, Debug, Deserialize)]
301struct OAuthErrorBody {
302 error: Option<String>,
303 error_description: Option<String>,
304 message: Option<String>,
305}
306
307#[derive(Clone, Debug, Deserialize)]
308struct OAuthAuthorizationServerMetadata {
309 authorization_endpoint: Option<String>,
310}
311
312#[derive(Clone, Debug, Deserialize)]
313struct OAuthCallbackPayload {
314 code: Option<String>,
315 state: Option<String>,
316 error: Option<String>,
317 error_description: Option<String>,
318}
319
/// A PKCE code verifier together with the challenge derived from it.
#[derive(Debug, Clone)]
struct PkcePair {
    /// Secret verifier, sent with the token request.
    verifier: String,
    /// Challenge derived from the verifier, sent with the authorization request.
    challenge: String,
}
325
326fn generate_pkce_pair() -> PkcePair {
327 let mut raw = [0u8; 64];
328 rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
329
330 let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
331 let digest = sha2::Sha256::digest(verifier.as_bytes());
332 let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
333
334 PkcePair {
335 verifier,
336 challenge,
337 }
338}
339
340fn generate_state() -> String {
341 let mut bytes = [0u8; 16];
342 rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
343 bytes.iter().map(|b| format!("{b:02x}")).collect()
344}
345
/// Trims surrounding whitespace from `scope`, returning `None` when nothing is
/// left.
fn normalized_scope(scope: &str) -> Option<&str> {
    Some(scope.trim()).filter(|trimmed| !trimmed.is_empty())
}
354
355struct CallbackServerGuard {
356 server: Arc<Server>,
357}
358
359impl Drop for CallbackServerGuard {
360 fn drop(&mut self) {
361 self.server.unblock();
362 }
363}
364
365fn parse_callback(url_path: &str, callback_path: &str) -> Option<OAuthCallbackPayload> {
366 let full = format!("http://localhost{url_path}");
367 let parsed = Url::parse(&full).ok()?;
368
369 if parsed.path() != callback_path {
370 return None;
371 }
372
373 let mut payload = OAuthCallbackPayload {
374 code: None,
375 state: None,
376 error: None,
377 error_description: None,
378 };
379
380 for (key, value) in parsed.query_pairs() {
381 match key.as_ref() {
382 "code" => payload.code = Some(value.to_string()),
383 "state" => payload.state = Some(value.to_string()),
384 "error" => payload.error = Some(value.to_string()),
385 "error_description" => payload.error_description = Some(value.to_string()),
386 _ => {}
387 }
388 }
389
390 Some(payload)
391}
392
393fn spawn_callback_server(
394 server: Arc<Server>,
395 callback_path: String,
396 tx: oneshot::Sender<OAuthCallbackPayload>,
397) -> tokio::task::JoinHandle<()> {
398 tokio::task::spawn_blocking(move || {
399 while let Ok(request) = server.recv() {
400 let path = request.url().to_string();
401 if let Some(payload) = parse_callback(&path, &callback_path) {
402 let response = Response::from_string(
403 "Authentication complete. You can return to your terminal.",
404 );
405 let _ = request.respond(response);
406 let _ = tx.send(payload);
407 break;
408 }
409
410 let response = Response::from_string("Invalid callback").with_status_code(400);
411 let _ = request.respond(response);
412 }
413 })
414}
415
416#[derive(Clone, Debug)]
417pub struct KontextDevClient {
418 config: KontextDevConfig,
419 http: Client,
420}
421
422impl KontextDevClient {
423 pub fn new(config: KontextDevConfig) -> Self {
424 Self {
425 config,
426 http: Client::new(),
427 }
428 }
429
430 pub fn config(&self) -> &KontextDevConfig {
431 &self.config
432 }
433
434 pub fn server_base_url(&self) -> Result<String, KontextDevError> {
435 resolve_server_base_url(&self.config).map_err(KontextDevError::from)
436 }
437
438 pub fn mcp_url(&self) -> Result<String, KontextDevError> {
439 resolve_mcp_url(&self.config).map_err(KontextDevError::from)
440 }
441
442 pub fn token_url(&self) -> Result<String, KontextDevError> {
443 resolve_token_url(&self.config).map_err(KontextDevError::from)
444 }
445
446 pub fn authorize_url(&self) -> Result<String, KontextDevError> {
447 resolve_authorize_url(&self.config).map_err(KontextDevError::from)
448 }
449
450 pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
451 resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
452 }
453
454 pub fn clear_token_cache(&self) -> Result<(), KontextDevError> {
455 let Some(path) = self.token_cache_path() else {
456 return Ok(());
457 };
458
459 if !path.exists() {
460 return Ok(());
461 }
462
463 fs::remove_file(&path).map_err(|source| KontextDevError::TokenCacheWrite {
464 path: path.display().to_string(),
465 source,
466 })
467 }
468
469 pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
472 let resource = self.config.resource.clone();
473
474 if let Some(cache) = self.read_cache()?
475 && cache.client_id == self.config.client_id
476 && cache.resource == resource
477 && cache.gateway.is_valid()
478 && cache.identity.is_valid()
479 {
480 return Ok(KontextAuthSession {
481 identity_token: cache.identity.to_access_token(),
482 gateway_token: cache.gateway.to_token_exchange_token(),
483 browser_auth_performed: false,
484 });
485 }
486
487 if let Some(cache) = self.read_cache()?
488 && cache.client_id == self.config.client_id
489 && cache.resource == resource
490 {
491 let maybe_refreshed = if cache.identity.is_valid() {
492 Some(cache.identity.to_access_token())
493 } else if let Some(refresh_token) = &cache.identity.refresh_token {
494 Some(self.refresh_identity_token(refresh_token).await?)
495 } else {
496 None
497 };
498
499 if let Some(identity_token) = maybe_refreshed {
500 let gateway_token = self
501 .exchange_for_resource(&identity_token.access_token, &resource, None)
502 .await?;
503 self.write_cache(&identity_token, &gateway_token)?;
504 return Ok(KontextAuthSession {
505 identity_token,
506 gateway_token,
507 browser_auth_performed: false,
508 });
509 }
510 }
511
512 let identity_token = self.authorize_with_browser_pkce().await?;
513 let gateway_token = self
514 .exchange_for_resource(&identity_token.access_token, &resource, None)
515 .await?;
516
517 self.write_cache(&identity_token, &gateway_token)?;
518
519 Ok(KontextAuthSession {
520 identity_token,
521 gateway_token,
522 browser_auth_performed: true,
523 })
524 }
525
526 pub async fn create_integration_connect_url(
529 &self,
530 gateway_access_token: &str,
531 ) -> Result<String, KontextDevError> {
532 let session = self.create_connect_session(gateway_access_token).await?;
533 self.integration_connect_url(&session.session_id)
534 }
535
536 pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
537 if session_id.trim().is_empty() {
538 return Err(KontextDevError::ConnectSession {
539 message: "connect session id is empty".to_string(),
540 });
541 }
542
543 let base = if let Some(explicit) = &self.config.integration_ui_url {
544 explicit.trim_end_matches('/').to_string()
545 } else {
546 let server = resolve_server_base_url(&self.config)?;
547 if server.contains("api.kontext.dev") {
548 "https://app.kontext.dev".to_string()
549 } else {
550 server
551 }
552 };
553
554 let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
555 url: base.clone(),
556 source,
557 })?;
558 url.set_path("/oauth/connect");
559 url.query_pairs_mut().append_pair("session", session_id);
560 Ok(url.to_string())
561 }
562
563 pub async fn open_integration_connect_page(
565 &self,
566 gateway_access_token: &str,
567 ) -> Result<String, KontextDevError> {
568 let url = self
569 .create_integration_connect_url(gateway_access_token)
570 .await?;
571
572 if webbrowser::open(&url).is_err() {
573 return Err(KontextDevError::BrowserOpenFailed);
574 }
575
576 Ok(url)
577 }
578
579 pub async fn create_connect_session(
580 &self,
581 gateway_access_token: &str,
582 ) -> Result<ConnectSession, KontextDevError> {
583 let url = self.connect_session_url()?;
584 let response = self
585 .http
586 .post(&url)
587 .header("Authorization", format!("Bearer {gateway_access_token}"))
588 .json(&serde_json::json!({}))
589 .send()
590 .await
591 .map_err(|err| KontextDevError::ConnectSession {
592 message: err.to_string(),
593 })?;
594
595 if !response.status().is_success() {
596 let message = build_error_message(response).await;
597 return Err(KontextDevError::ConnectSession { message });
598 }
599
600 response
601 .json::<ConnectSession>()
602 .await
603 .map_err(|err| KontextDevError::ConnectSession {
604 message: err.to_string(),
605 })
606 }
607
608 pub async fn initiate_integration_oauth(
609 &self,
610 gateway_access_token: &str,
611 integration_id: &str,
612 return_to: Option<&str>,
613 ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
614 let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
615
616 let payload = return_to
617 .map(|value| serde_json::json!({ "returnTo": value }))
618 .unwrap_or_else(|| serde_json::json!({}));
619
620 let request = self
621 .http
622 .post(&url)
623 .header("Authorization", format!("Bearer {gateway_access_token}"))
624 .json(&payload);
625
626 let response =
627 request
628 .send()
629 .await
630 .map_err(|err| KontextDevError::IntegrationOAuthInit {
631 message: err.to_string(),
632 })?;
633
634 if !response.status().is_success() {
635 let message = build_error_message(response).await;
636 return Err(KontextDevError::IntegrationOAuthInit { message });
637 }
638
639 response
640 .json::<IntegrationOAuthInitResponse>()
641 .await
642 .map_err(|err| KontextDevError::IntegrationOAuthInit {
643 message: err.to_string(),
644 })
645 }
646
647 pub async fn integration_connection_status(
648 &self,
649 gateway_access_token: &str,
650 integration_id: &str,
651 ) -> Result<IntegrationConnectionStatus, KontextDevError> {
652 let url = resolve_integration_connection_url(&self.config, integration_id)?;
653 let response = self
654 .http
655 .get(url)
656 .header("Authorization", format!("Bearer {gateway_access_token}"))
657 .send()
658 .await
659 .map_err(|err| KontextDevError::IntegrationOAuthInit {
660 message: err.to_string(),
661 })?;
662
663 if !response.status().is_success() {
664 let message = build_error_message(response).await;
665 return Err(KontextDevError::IntegrationOAuthInit { message });
666 }
667
668 response
669 .json::<IntegrationConnectionStatus>()
670 .await
671 .map_err(|err| KontextDevError::IntegrationOAuthInit {
672 message: err.to_string(),
673 })
674 }
675
676 pub async fn wait_for_integration_connection(
677 &self,
678 gateway_access_token: &str,
679 integration_id: &str,
680 timeout_ms: u64,
681 interval_ms: u64,
682 ) -> Result<bool, KontextDevError> {
683 let started = now_unix_ms();
684
685 loop {
686 let status = self
687 .integration_connection_status(gateway_access_token, integration_id)
688 .await?;
689 if status.connected {
690 return Ok(true);
691 }
692
693 if now_unix_ms().saturating_sub(started) >= timeout_ms {
694 return Ok(false);
695 }
696
697 tokio::time::sleep(Duration::from_millis(interval_ms)).await;
698 }
699 }
700
701 async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
702 let auth_url = self.resolve_authorization_url().await?;
703 let token_url = self.token_url()?;
704 let pkce = generate_pkce_pair();
705 let state = generate_state();
706
707 let (callback_url, callback_payload) = self.listen_for_callback().await?;
708
709 let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
710 url: auth_url.clone(),
711 source,
712 })?;
713
714 {
715 let mut query = url.query_pairs_mut();
716 query
717 .append_pair("client_id", &self.config.client_id)
718 .append_pair("response_type", "code")
719 .append_pair("redirect_uri", &callback_url)
720 .append_pair("state", &state);
721
722 if let Some(scope) = normalized_scope(&self.config.scope) {
723 query.append_pair("scope", scope);
724 }
725
726 query
727 .append_pair("code_challenge_method", "S256")
728 .append_pair("code_challenge", &pkce.challenge);
729 }
730
731 if webbrowser::open(url.as_str()).is_err() {
732 return Err(KontextDevError::BrowserOpenFailed);
733 }
734
735 let payload = callback_payload.await?;
736
737 if let Some(error) = payload.error {
738 let with_details = payload
739 .error_description
740 .map(|description| format!("{error}: {description}"))
741 .unwrap_or(error);
742 return Err(KontextDevError::OAuthCallbackError {
743 error: with_details,
744 });
745 }
746
747 if payload.state.as_deref() != Some(state.as_str()) {
748 return Err(KontextDevError::InvalidOAuthState);
749 }
750
751 let code = payload
752 .code
753 .ok_or(KontextDevError::MissingAuthorizationCode)?;
754
755 let mut body = vec![
756 ("grant_type", "authorization_code".to_string()),
757 ("code", code),
758 ("redirect_uri", callback_url),
759 ("client_id", self.config.client_id.clone()),
760 ("code_verifier", pkce.verifier),
761 ];
762
763 if let Some(client_secret) = &self.config.client_secret {
764 body.push(("client_secret", client_secret.clone()));
765 }
766
767 post_token(&self.http, &token_url, None, &body).await
768 }
769
770 async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
771 if let Some(discovered) = self.discover_authorization_endpoint().await {
772 return Ok(discovered);
773 }
774
775 let authorize_url = self.authorize_url()?;
776 if !self.endpoint_is_missing(&authorize_url).await {
777 return Ok(authorize_url);
778 }
779
780 let server_base = resolve_server_base_url(&self.config)?;
781 Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
782 }
783
784 async fn discover_authorization_endpoint(&self) -> Option<String> {
785 let base = resolve_server_base_url(&self.config).ok()?;
786 let base = base.trim_end_matches('/');
787
788 let candidates = [
789 format!("{base}/.well-known/oauth-authorization-server/mcp"),
790 format!("{base}/.well-known/oauth-authorization-server"),
791 ];
792
793 for url in candidates {
794 let response = match self.http.get(&url).send().await {
795 Ok(response) => response,
796 Err(_) => continue,
797 };
798
799 if !response.status().is_success() {
800 continue;
801 }
802
803 let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
804 Ok(metadata) => metadata,
805 Err(_) => continue,
806 };
807
808 let Some(endpoint) = metadata.authorization_endpoint else {
809 continue;
810 };
811
812 if Url::parse(&endpoint).is_ok() {
813 return Some(endpoint);
814 }
815 }
816
817 None
818 }
819
820 async fn endpoint_is_missing(&self, url: &str) -> bool {
821 let probe_client = match reqwest::Client::builder()
822 .redirect(reqwest::redirect::Policy::none())
823 .build()
824 {
825 Ok(client) => client,
826 Err(_) => return false,
827 };
828
829 match probe_client.get(url).send().await {
830 Ok(response) => response.status() == StatusCode::NOT_FOUND,
831 Err(_) => false,
832 }
833 }
834
835 async fn refresh_identity_token(
836 &self,
837 refresh_token: &str,
838 ) -> Result<AccessToken, KontextDevError> {
839 let token_url = self.token_url()?;
840 let mut body = vec![
841 ("grant_type", "refresh_token".to_string()),
842 ("refresh_token", refresh_token.to_string()),
843 ("client_id", self.config.client_id.clone()),
844 ];
845
846 if let Some(client_secret) = &self.config.client_secret {
847 body.push(("client_secret", client_secret.clone()));
848 }
849
850 post_token(&self.http, &token_url, None, &body).await
851 }
852
853 pub async fn exchange_for_resource(
854 &self,
855 subject_token: &str,
856 resource: &str,
857 scope: Option<&str>,
858 ) -> Result<TokenExchangeToken, KontextDevError> {
859 let token_url = self.token_url()?;
860
861 let mut body = vec![
862 ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
863 ("subject_token", subject_token.to_string()),
864 ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
865 ("resource", resource.to_string()),
866 ];
867
868 if let Some(scope) = scope {
869 body.push(("scope", scope.to_string()));
870 }
871
872 let auth_header = self.config.client_secret.as_ref().map(|secret| {
873 let raw = format!("{}:{}", self.config.client_id, secret);
874 format!(
875 "Basic {}",
876 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
877 )
878 });
879
880 if self.config.client_secret.is_none() {
881 body.push(("client_id", self.config.client_id.clone()));
882 }
883
884 let response = post_form_with_optional_auth::<TokenExchangeToken>(
885 &self.http,
886 &token_url,
887 auth_header,
888 &body,
889 )
890 .await
891 .map_err(|message| KontextDevError::TokenExchange {
892 resource: resource.to_string(),
893 message,
894 })?;
895
896 if response.access_token.is_empty() {
897 return Err(KontextDevError::EmptyAccessToken);
898 }
899
900 Ok(response)
901 }
902
903 async fn listen_for_callback(
904 &self,
905 ) -> Result<
906 (
907 String,
908 impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
909 ),
910 KontextDevError,
911 > {
912 let redirect_uri = self.config.redirect_uri.trim().to_string();
913 let parsed = Url::parse(&redirect_uri).map_err(|source| KontextDevError::InvalidUrl {
914 url: redirect_uri.clone(),
915 source,
916 })?;
917
918 if parsed.scheme() != "http" {
919 return Err(KontextDevError::OAuthCallbackError {
920 error: "redirect_uri must use http".to_string(),
921 });
922 }
923
924 if parsed.query().is_some() || parsed.fragment().is_some() {
925 return Err(KontextDevError::OAuthCallbackError {
926 error: "redirect_uri must not include query parameters or fragments".to_string(),
927 });
928 }
929
930 let host = parsed
931 .host_str()
932 .ok_or_else(|| KontextDevError::OAuthCallbackError {
933 error: "redirect_uri host is missing".to_string(),
934 })?;
935
936 let port = parsed
937 .port()
938 .ok_or_else(|| KontextDevError::OAuthCallbackError {
939 error: "redirect_uri must include an explicit port".to_string(),
940 })?;
941
942 let callback_path = parsed.path().to_string();
943 let bind_addr = if host.contains(':') {
944 format!("[{host}]:{port}")
945 } else {
946 format!("{host}:{port}")
947 };
948
949 let server = Arc::new(Server::http(&bind_addr).map_err(|err| {
950 KontextDevError::OAuthCallbackError {
951 error: format!("failed to start callback server at {bind_addr}: {err}"),
952 }
953 })?);
954
955 let (tx, rx) = oneshot::channel();
956 let _join = spawn_callback_server(server.clone(), callback_path, tx);
957 let _guard = CallbackServerGuard { server };
958 let timeout_seconds = self.config.auth_timeout_seconds.max(1);
959
960 let fut = async move {
961 let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
962 .await
963 .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
964 .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
965 drop(_guard);
966 Ok(payload)
967 };
968
969 Ok((redirect_uri, fut))
970 }
971
972 fn token_cache_path(&self) -> Option<PathBuf> {
973 if let Some(explicit) = &self.config.token_cache_path {
974 return Some(PathBuf::from(explicit));
975 }
976
977 let home = dirs::home_dir()?;
978 let mut path = home;
979 path.push(".kontext-dev");
980 path.push("tokens");
981
982 let sanitized_client_id: String = self
983 .config
984 .client_id
985 .chars()
986 .map(|ch| {
987 if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
988 ch
989 } else {
990 '_'
991 }
992 })
993 .collect();
994
995 path.push(format!("{sanitized_client_id}.json"));
996 Some(path)
997 }
998
999 fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
1000 let Some(path) = self.token_cache_path() else {
1001 return Ok(None);
1002 };
1003
1004 let raw = match fs::read_to_string(&path) {
1005 Ok(raw) => raw,
1006 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
1007 Err(source) => {
1008 return Err(KontextDevError::TokenCacheRead {
1009 path: path.display().to_string(),
1010 source,
1011 });
1012 }
1013 };
1014
1015 serde_json::from_str(&raw).map(Some).map_err(|source| {
1016 KontextDevError::TokenCacheDeserialize {
1017 path: path.display().to_string(),
1018 source,
1019 }
1020 })
1021 }
1022
1023 fn write_cache(
1024 &self,
1025 identity: &AccessToken,
1026 gateway: &TokenExchangeToken,
1027 ) -> Result<(), KontextDevError> {
1028 let Some(path) = self.token_cache_path() else {
1029 return Ok(());
1030 };
1031
1032 if let Some(parent) = path.parent() {
1033 fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
1034 path: parent.display().to_string(),
1035 source,
1036 })?;
1037 }
1038
1039 let payload = TokenCacheFile {
1040 client_id: self.config.client_id.clone(),
1041 resource: self.config.resource.clone(),
1042 identity: CachedAccessToken::from_access_token(identity)?,
1043 gateway: CachedAccessToken::from_token_exchange(gateway)?,
1044 };
1045
1046 let serialized = serde_json::to_string_pretty(&payload)
1047 .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
1048
1049 fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
1050 path: path.display().to_string(),
1051 source,
1052 })
1053 }
1054}
1055
1056pub async fn request_access_token(
1060 config: &KontextDevConfig,
1061) -> Result<AccessToken, KontextDevError> {
1062 let token_url = resolve_token_url(config)?;
1063
1064 let mut body = vec![("grant_type", "client_credentials".to_string())];
1065
1066 if let Some(scope) = normalized_scope(&config.scope) {
1067 body.push(("scope", scope.to_string()));
1068 }
1069
1070 let auth_header = if let Some(secret) = &config.client_secret {
1071 let raw = format!("{}:{}", config.client_id, secret);
1072 Some(format!(
1073 "Basic {}",
1074 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
1075 ))
1076 } else {
1077 body.push(("client_id", config.client_id.clone()));
1078 None
1079 };
1080
1081 post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
1082}
1083
1084async fn post_token(
1085 http: &Client,
1086 token_url: &str,
1087 authorization: Option<&str>,
1088 body: &[(impl AsRef<str>, String)],
1089) -> Result<AccessToken, KontextDevError> {
1090 let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
1091
1092 let response = post_form_with_optional_auth::<AccessToken>(
1093 http,
1094 token_url,
1095 authorization.map(ToString::to_string),
1096 &body_vec,
1097 )
1098 .await
1099 .map_err(|message| KontextDevError::TokenRequest {
1100 token_url: token_url.to_string(),
1101 message,
1102 })?;
1103
1104 Ok(response)
1105}
1106
1107async fn post_form_with_optional_auth<T>(
1108 http: &Client,
1109 url: &str,
1110 authorization: Option<String>,
1111 body: &[(&str, String)],
1112) -> Result<T, String>
1113where
1114 T: DeserializeOwned,
1115{
1116 let mut request = http
1117 .post(url)
1118 .header("Content-Type", "application/x-www-form-urlencoded");
1119
1120 if let Some(header) = authorization {
1121 request = request.header("Authorization", header);
1122 }
1123
1124 let form = body
1125 .iter()
1126 .map(|(k, v)| (k.to_string(), v.to_string()))
1127 .collect::<Vec<(String, String)>>();
1128
1129 let response = request
1130 .form(&form)
1131 .send()
1132 .await
1133 .map_err(|err| err.to_string())?;
1134
1135 if !response.status().is_success() {
1136 return Err(build_error_message(response).await);
1137 }
1138
1139 response.json::<T>().await.map_err(|err| err.to_string())
1140}
1141
1142async fn build_error_message(response: reqwest::Response) -> String {
1143 let status = response.status();
1144 let fallback = format!(
1145 "{} {}",
1146 status.as_u16(),
1147 status.canonical_reason().unwrap_or("")
1148 );
1149
1150 let body = response.text().await.unwrap_or_default();
1151 if body.is_empty() {
1152 return fallback.trim().to_string();
1153 }
1154
1155 if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1156 if let Some(description) = parsed.error_description {
1157 return description;
1158 }
1159 if let Some(message) = parsed.message {
1160 return message;
1161 }
1162 if let Some(error) = parsed.error {
1163 return error;
1164 }
1165 }
1166
1167 format!("{fallback}: {body}")
1168}
1169
#[cfg(test)]
mod tests {
    use super::*;

    // Baseline configuration shared by the tests below: a public client
    // (no secret) pointed at a local server with the hosted integration UI.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: String::from("http://localhost:4000"),
            client_id: String::from("client_123"),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some(String::from("https://app.kontext.dev")),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
            redirect_uri: String::from("http://localhost:3333/callback"),
        }
    }

    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let dev_client = KontextDevClient::new(config());

        let connect_url = dev_client
            .integration_connect_url("session-123")
            .expect("url should be built");

        assert_eq!(
            connect_url,
            "https://app.kontext.dev/oauth/connect?session=session-123"
        );
    }

    #[test]
    fn jwt_exp_decode_reads_exp() {
        let engine = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let header = engine.encode(r#"{"alg":"none"}"#);
        let payload = engine.encode(r#"{"exp":4070908800}"#);

        // The signature segment is a placeholder.
        let token = [header.as_str(), payload.as_str(), "sig"].join(".");

        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }

    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        let raw = r#"{
            "issuer": "https://issuer.example.com",
            "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
        }"#;

        let metadata: OAuthAuthorizationServerMetadata =
            serde_json::from_str(raw).expect("metadata should parse");

        assert_eq!(
            metadata.authorization_endpoint.as_deref(),
            Some("https://issuer.example.com/oauth2/auth")
        );
    }

    #[test]
    fn normalized_scope_omits_blank_values() {
        for blank in ["", " "] {
            assert_eq!(normalized_scope(blank), None);
        }
    }

    #[test]
    fn normalized_scope_preserves_non_empty_values() {
        let cases = [
            ("mcp:invoke", "mcp:invoke"),
            (" mcp:invoke openid ", "mcp:invoke openid"),
        ];

        for (input, expected) in cases {
            assert_eq!(normalized_scope(input), Some(expected));
        }
    }
}