1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub mod client;
21pub mod errors;
22mod http_client;
23pub mod management;
24pub mod mcp;
25pub mod oauth;
26pub mod orchestrator;
27pub mod prompt_guidance;
28pub mod server;
29pub mod verify;
30
31pub use client::ClientState;
32pub use client::ConnectSessionResult;
33pub use client::IntegrationInfo;
34pub use client::KontextClient as SdkKontextClient;
35pub use client::KontextClientConfig;
36pub use client::ToolResult;
37pub use client::create_kontext_client;
38pub use errors::*;
39pub use management::KontextManagementClient;
40pub use management::KontextManagementClientConfig;
41pub use management::*;
42pub use mcp::KontextMcp;
43pub use mcp::KontextMcpConfig;
44pub use mcp::RuntimeIntegrationCategory;
45pub use mcp::RuntimeIntegrationConnectType;
46pub use mcp::RuntimeIntegrationRecord;
47pub use oauth::KontextOAuthProvider;
48pub use oauth::KontextOAuthProviderConfig;
49pub use oauth::ParsedOAuthCallback;
50pub use oauth::TokenExchangeConfig;
51pub use oauth::exchange_token;
52pub use oauth::parse_oauth_callback;
53pub use orchestrator::KontextOrchestrator;
54pub use orchestrator::KontextOrchestratorConfig;
55pub use orchestrator::KontextOrchestratorState;
56pub use orchestrator::create_kontext_orchestrator;
57pub use prompt_guidance::KontextPromptGuidance;
58pub use prompt_guidance::build_kontext_prompt_guidance;
59pub use server::IntegrationCredential;
60pub use server::IntegrationName;
61pub use server::IntegrationResolvedCredentials;
62pub use server::KnownIntegration;
63pub use server::Kontext;
64pub use server::KontextOptions;
65pub use server::MiddlewareOptions;
66pub use verify::JwksClient;
67pub use verify::KontextTokenVerifier;
68pub use verify::KontextTokenVerifierConfig;
69pub use verify::TokenVerificationError;
70pub use verify::TokenVerificationErrorCode;
71pub use verify::VerifiedTokenClaims;
72pub use verify::VerifyResult;
73
74pub use kontext_dev_sdk_core::AccessToken;
75pub use kontext_dev_sdk_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
76pub use kontext_dev_sdk_core::DEFAULT_RESOURCE;
77pub use kontext_dev_sdk_core::DEFAULT_SCOPE;
78pub use kontext_dev_sdk_core::DEFAULT_SERVER;
79pub use kontext_dev_sdk_core::DEFAULT_SERVER_NAME;
80pub use kontext_dev_sdk_core::KontextDevConfig;
81pub use kontext_dev_sdk_core::KontextDevCoreError;
82pub use kontext_dev_sdk_core::TokenExchangeToken;
83pub use kontext_dev_sdk_core::normalize_server_url;
84pub use kontext_dev_sdk_core::resolve_authorize_url;
85pub use kontext_dev_sdk_core::resolve_connect_session_url;
86pub use kontext_dev_sdk_core::resolve_integration_connection_url;
87pub use kontext_dev_sdk_core::resolve_integration_oauth_init_url;
88pub use kontext_dev_sdk_core::resolve_mcp_url;
89pub use kontext_dev_sdk_core::resolve_server_base_url;
90pub use kontext_dev_sdk_core::resolve_token_url;
91
/// Token-type URN (RFC 8693) sent as `subject_token_type` in exchange requests
/// and reported as `issued_token_type` for gateway tokens rebuilt from cache.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";

/// `grant_type` value selecting the OAuth 2.0 token-exchange flow (RFC 8693).
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";

/// Safety margin: a cached token counts as expired this many seconds before
/// its recorded expiry instant.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
95
96#[derive(Debug, thiserror::Error)]
97pub enum KontextDevError {
98 #[error(transparent)]
99 Core(#[from] kontext_dev_sdk_core::KontextDevCoreError),
100 #[error("failed to parse URL `{url}`")]
101 InvalidUrl {
102 url: String,
103 source: url::ParseError,
104 },
105 #[error("failed to open browser for OAuth authorization")]
106 BrowserOpenFailed,
107 #[error("OAuth callback timed out after {timeout_seconds}s")]
108 OAuthCallbackTimeout { timeout_seconds: i64 },
109 #[error("OAuth callback channel was unexpectedly cancelled")]
110 OAuthCallbackCancelled,
111 #[error("OAuth callback is missing the authorization code")]
112 MissingAuthorizationCode,
113 #[error("OAuth callback returned an error: {error}")]
114 OAuthCallbackError { error: String },
115 #[error("OAuth state mismatch")]
116 InvalidOAuthState,
117 #[error("Kontext-Dev token request failed for {token_url}: {message}")]
118 TokenRequest { token_url: String, message: String },
119 #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
120 TokenExchange { resource: String, message: String },
121 #[error("Kontext-Dev connect session request failed: {message}")]
122 ConnectSession { message: String },
123 #[error("Kontext-Dev integration OAuth init failed: {message}")]
124 IntegrationOAuthInit { message: String },
125 #[error("failed to read token cache at `{path}`: {source}")]
126 TokenCacheRead {
127 path: String,
128 source: std::io::Error,
129 },
130 #[error("failed to write token cache at `{path}`: {source}")]
131 TokenCacheWrite {
132 path: String,
133 source: std::io::Error,
134 },
135 #[error("failed to deserialize token cache at `{path}`: {source}")]
136 TokenCacheDeserialize {
137 path: String,
138 source: serde_json::Error,
139 },
140 #[error("failed to serialize token cache: {source}")]
141 TokenCacheSerialize { source: serde_json::Error },
142 #[error("Kontext-Dev access token is empty")]
143 EmptyAccessToken,
144 #[error("missing integration UI URL; set `integration_ui_url` in config")]
145 MissingIntegrationUiUrl,
146}
147
148#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
149pub struct ConnectSession {
150 #[serde(rename = "sessionId")]
151 pub session_id: String,
152 #[serde(rename = "expiresAt")]
153 pub expires_at: String,
154}
155
156#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
157pub struct IntegrationOAuthInitResponse {
158 #[serde(rename = "authorizationUrl")]
159 pub authorization_url: String,
160 #[serde(default)]
161 pub state: Option<String>,
162}
163
164#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
165pub struct IntegrationConnectionStatus {
166 pub connected: bool,
167 #[serde(default)]
168 pub expires_at: Option<String>,
169}
170
171#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
172pub struct KontextAuthSession {
173 pub identity_token: AccessToken,
174 pub gateway_token: TokenExchangeToken,
175 pub browser_auth_performed: bool,
176}
177
178#[derive(Clone, Debug, Deserialize, Serialize)]
179struct CachedAccessToken {
180 access_token: String,
181 token_type: String,
182 refresh_token: Option<String>,
183 scope: Option<String>,
184 expires_at_unix_ms: Option<u64>,
185}
186
187#[derive(Clone, Debug, Deserialize, Serialize)]
188struct TokenCacheFile {
189 client_id: String,
190 resource: String,
191 identity: CachedAccessToken,
192 gateway: CachedAccessToken,
193}
194
195impl CachedAccessToken {
196 fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
197 if token.access_token.is_empty() {
198 return Err(KontextDevError::EmptyAccessToken);
199 }
200
201 Ok(Self {
202 access_token: token.access_token.clone(),
203 token_type: token.token_type.clone(),
204 refresh_token: token.refresh_token.clone(),
205 scope: token.scope.clone(),
206 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
207 })
208 }
209
210 fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
211 if token.access_token.is_empty() {
212 return Err(KontextDevError::EmptyAccessToken);
213 }
214
215 Ok(Self {
216 access_token: token.access_token.clone(),
217 token_type: token.token_type.clone(),
218 refresh_token: token.refresh_token.clone(),
219 scope: token.scope.clone(),
220 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
221 })
222 }
223
224 fn is_valid(&self) -> bool {
225 match self.expires_at_unix_ms {
226 Some(expires_at) => {
227 let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
228 now_unix_ms().saturating_add(buffer_ms) < expires_at
229 }
230 None => true,
231 }
232 }
233
234 fn to_access_token(&self) -> AccessToken {
235 AccessToken {
236 access_token: self.access_token.clone(),
237 token_type: self.token_type.clone(),
238 expires_in: self
239 .expires_at_unix_ms
240 .and_then(unix_ms_to_relative_seconds),
241 refresh_token: self.refresh_token.clone(),
242 scope: self.scope.clone(),
243 }
244 }
245
246 fn to_token_exchange_token(&self) -> TokenExchangeToken {
247 TokenExchangeToken {
248 access_token: self.access_token.clone(),
249 issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
250 token_type: self.token_type.clone(),
251 expires_in: self
252 .expires_at_unix_ms
253 .and_then(unix_ms_to_relative_seconds),
254 scope: self.scope.clone(),
255 refresh_token: self.refresh_token.clone(),
256 }
257 }
258}
259
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of an error.
fn now_unix_ms() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::from_secs(0),
    };
    since_epoch.as_millis() as u64
}
266
267fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
268 if unix_ms <= now_unix_ms() {
269 return Some(0);
270 }
271
272 let delta_ms = unix_ms - now_unix_ms();
273 let secs = delta_ms / 1000;
274 i64::try_from(secs).ok()
275}
276
277fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
278 if let Some(expires_in) = expires_in
279 && expires_in > 0
280 {
281 let expires_in_u64 = u64::try_from(expires_in).ok()?;
282 return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
283 }
284
285 decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
286}
287
288fn decode_jwt_exp(token: &str) -> Option<u64> {
289 let mut parts = token.split('.');
290 let _header = parts.next()?;
291 let payload = parts.next()?;
292
293 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
294 .decode(payload)
295 .ok()?;
296 let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
297 value.get("exp").and_then(|exp| {
298 exp.as_u64()
299 .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
300 })
301}
302
303#[derive(Clone, Debug, Deserialize)]
304struct OAuthErrorBody {
305 error: Option<String>,
306 error_description: Option<String>,
307 message: Option<String>,
308}
309
310#[derive(Clone, Debug, Deserialize)]
311struct OAuthAuthorizationServerMetadata {
312 authorization_endpoint: Option<String>,
313}
314
315#[derive(Clone, Debug, Deserialize)]
316struct OAuthCallbackPayload {
317 code: Option<String>,
318 state: Option<String>,
319 error: Option<String>,
320 error_description: Option<String>,
321}
322
/// PKCE (RFC 7636) code verifier and its `S256` code challenge.
#[derive(Clone)]
#[derive(Debug)]
struct PkcePair {
    /// Secret sent only to the token endpoint.
    verifier: String,
    /// Base64url SHA-256 of `verifier`, sent to the authorization endpoint.
    challenge: String,
}
328
329fn generate_pkce_pair() -> PkcePair {
330 let mut raw = [0u8; 64];
331 rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
332
333 let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
334 let digest = sha2::Sha256::digest(verifier.as_bytes());
335 let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
336
337 PkcePair {
338 verifier,
339 challenge,
340 }
341}
342
343fn generate_state() -> String {
344 let mut bytes = [0u8; 16];
345 rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
346 bytes.iter().map(|b| format!("{b:02x}")).collect()
347}
348
/// Trims surrounding whitespace from a configured scope string, returning
/// `None` when nothing remains so the `scope` parameter can be omitted.
fn normalized_scope(scope: &str) -> Option<&str> {
    let scope = scope.trim();
    (!scope.is_empty()).then_some(scope)
}
357
358struct CallbackServerGuard {
359 server: Arc<Server>,
360}
361
362impl Drop for CallbackServerGuard {
363 fn drop(&mut self) {
364 self.server.unblock();
365 }
366}
367
368fn parse_callback(url_path: &str, callback_path: &str) -> Option<OAuthCallbackPayload> {
369 let full = format!("http://localhost{url_path}");
370 let parsed = Url::parse(&full).ok()?;
371
372 if parsed.path() != callback_path {
373 return None;
374 }
375
376 let mut payload = OAuthCallbackPayload {
377 code: None,
378 state: None,
379 error: None,
380 error_description: None,
381 };
382
383 for (key, value) in parsed.query_pairs() {
384 match key.as_ref() {
385 "code" => payload.code = Some(value.to_string()),
386 "state" => payload.state = Some(value.to_string()),
387 "error" => payload.error = Some(value.to_string()),
388 "error_description" => payload.error_description = Some(value.to_string()),
389 _ => {}
390 }
391 }
392
393 Some(payload)
394}
395
396fn spawn_callback_server(
397 server: Arc<Server>,
398 callback_path: String,
399 tx: oneshot::Sender<OAuthCallbackPayload>,
400) -> tokio::task::JoinHandle<()> {
401 tokio::task::spawn_blocking(move || {
402 while let Ok(request) = server.recv() {
403 let path = request.url().to_string();
404 if let Some(payload) = parse_callback(&path, &callback_path) {
405 let response = Response::from_string(
406 "Authentication complete. You can return to your terminal.",
407 );
408 let _ = request.respond(response);
409 let _ = tx.send(payload);
410 break;
411 }
412
413 let response = Response::from_string("Invalid callback").with_status_code(400);
414 let _ = request.respond(response);
415 }
416 })
417}
418
419#[derive(Clone, Debug)]
420pub struct KontextDevClient {
421 config: KontextDevConfig,
422 http: Client,
423}
424
425impl KontextDevClient {
426 pub fn new(config: KontextDevConfig) -> Self {
427 Self {
428 config,
429 http: Client::new(),
430 }
431 }
432
433 pub fn config(&self) -> &KontextDevConfig {
434 &self.config
435 }
436
437 pub fn server_base_url(&self) -> Result<String, KontextDevError> {
438 resolve_server_base_url(&self.config).map_err(KontextDevError::from)
439 }
440
441 pub fn mcp_url(&self) -> Result<String, KontextDevError> {
442 resolve_mcp_url(&self.config).map_err(KontextDevError::from)
443 }
444
445 pub fn token_url(&self) -> Result<String, KontextDevError> {
446 resolve_token_url(&self.config).map_err(KontextDevError::from)
447 }
448
449 pub fn authorize_url(&self) -> Result<String, KontextDevError> {
450 resolve_authorize_url(&self.config).map_err(KontextDevError::from)
451 }
452
453 pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
454 resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
455 }
456
457 pub fn clear_token_cache(&self) -> Result<(), KontextDevError> {
458 let Some(path) = self.token_cache_path() else {
459 return Ok(());
460 };
461
462 if !path.exists() {
463 return Ok(());
464 }
465
466 fs::remove_file(&path).map_err(|source| KontextDevError::TokenCacheWrite {
467 path: path.display().to_string(),
468 source,
469 })
470 }
471
472 pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
475 let resource = self.config.resource.clone();
476
477 if let Some(cache) = self.read_cache()?
478 && cache.client_id == self.config.client_id
479 && cache.resource == resource
480 && cache.gateway.is_valid()
481 && cache.identity.is_valid()
482 {
483 return Ok(KontextAuthSession {
484 identity_token: cache.identity.to_access_token(),
485 gateway_token: cache.gateway.to_token_exchange_token(),
486 browser_auth_performed: false,
487 });
488 }
489
490 if let Some(cache) = self.read_cache()?
491 && cache.client_id == self.config.client_id
492 && cache.resource == resource
493 {
494 let maybe_refreshed = if cache.identity.is_valid() {
495 Some(cache.identity.to_access_token())
496 } else if let Some(refresh_token) = &cache.identity.refresh_token {
497 Some(self.refresh_identity_token(refresh_token).await?)
498 } else {
499 None
500 };
501
502 if let Some(identity_token) = maybe_refreshed {
503 let gateway_token = self
504 .exchange_for_resource(&identity_token.access_token, &resource, None)
505 .await?;
506 self.write_cache(&identity_token, &gateway_token)?;
507 return Ok(KontextAuthSession {
508 identity_token,
509 gateway_token,
510 browser_auth_performed: false,
511 });
512 }
513 }
514
515 let identity_token = self.authorize_with_browser_pkce().await?;
516 let gateway_token = self
517 .exchange_for_resource(&identity_token.access_token, &resource, None)
518 .await?;
519
520 self.write_cache(&identity_token, &gateway_token)?;
521
522 Ok(KontextAuthSession {
523 identity_token,
524 gateway_token,
525 browser_auth_performed: true,
526 })
527 }
528
529 pub async fn create_integration_connect_url(
532 &self,
533 gateway_access_token: &str,
534 ) -> Result<String, KontextDevError> {
535 let session = self.create_connect_session(gateway_access_token).await?;
536 self.integration_connect_url(&session.session_id)
537 }
538
539 pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
540 if session_id.trim().is_empty() {
541 return Err(KontextDevError::ConnectSession {
542 message: "connect session id is empty".to_string(),
543 });
544 }
545
546 let base = if let Some(explicit) = &self.config.integration_ui_url {
547 explicit.trim_end_matches('/').to_string()
548 } else {
549 let server = resolve_server_base_url(&self.config)?;
550 if server.contains("api.kontext.dev") {
551 "https://app.kontext.dev".to_string()
552 } else {
553 server
554 }
555 };
556
557 let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
558 url: base.clone(),
559 source,
560 })?;
561 url.set_path("/oauth/connect");
562 url.query_pairs_mut().append_pair("session", session_id);
563 Ok(url.to_string())
564 }
565
566 pub async fn open_integration_connect_page(
568 &self,
569 gateway_access_token: &str,
570 ) -> Result<String, KontextDevError> {
571 let url = self
572 .create_integration_connect_url(gateway_access_token)
573 .await?;
574
575 if webbrowser::open(&url).is_err() {
576 return Err(KontextDevError::BrowserOpenFailed);
577 }
578
579 Ok(url)
580 }
581
582 pub async fn create_connect_session(
583 &self,
584 gateway_access_token: &str,
585 ) -> Result<ConnectSession, KontextDevError> {
586 let url = self.connect_session_url()?;
587 let response = self
588 .http
589 .post(&url)
590 .header("Authorization", format!("Bearer {gateway_access_token}"))
591 .json(&serde_json::json!({}))
592 .send()
593 .await
594 .map_err(|err| KontextDevError::ConnectSession {
595 message: err.to_string(),
596 })?;
597
598 if !response.status().is_success() {
599 let message = build_error_message(response).await;
600 return Err(KontextDevError::ConnectSession { message });
601 }
602
603 response
604 .json::<ConnectSession>()
605 .await
606 .map_err(|err| KontextDevError::ConnectSession {
607 message: err.to_string(),
608 })
609 }
610
611 pub async fn initiate_integration_oauth(
612 &self,
613 gateway_access_token: &str,
614 integration_id: &str,
615 return_to: Option<&str>,
616 ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
617 let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
618
619 let payload = return_to
620 .map(|value| serde_json::json!({ "returnTo": value }))
621 .unwrap_or_else(|| serde_json::json!({}));
622
623 let request = self
624 .http
625 .post(&url)
626 .header("Authorization", format!("Bearer {gateway_access_token}"))
627 .json(&payload);
628
629 let response =
630 request
631 .send()
632 .await
633 .map_err(|err| KontextDevError::IntegrationOAuthInit {
634 message: err.to_string(),
635 })?;
636
637 if !response.status().is_success() {
638 let message = build_error_message(response).await;
639 return Err(KontextDevError::IntegrationOAuthInit { message });
640 }
641
642 response
643 .json::<IntegrationOAuthInitResponse>()
644 .await
645 .map_err(|err| KontextDevError::IntegrationOAuthInit {
646 message: err.to_string(),
647 })
648 }
649
650 pub async fn integration_connection_status(
651 &self,
652 gateway_access_token: &str,
653 integration_id: &str,
654 ) -> Result<IntegrationConnectionStatus, KontextDevError> {
655 let url = resolve_integration_connection_url(&self.config, integration_id)?;
656 let response = self
657 .http
658 .get(url)
659 .header("Authorization", format!("Bearer {gateway_access_token}"))
660 .send()
661 .await
662 .map_err(|err| KontextDevError::IntegrationOAuthInit {
663 message: err.to_string(),
664 })?;
665
666 if !response.status().is_success() {
667 let message = build_error_message(response).await;
668 return Err(KontextDevError::IntegrationOAuthInit { message });
669 }
670
671 response
672 .json::<IntegrationConnectionStatus>()
673 .await
674 .map_err(|err| KontextDevError::IntegrationOAuthInit {
675 message: err.to_string(),
676 })
677 }
678
679 pub async fn wait_for_integration_connection(
680 &self,
681 gateway_access_token: &str,
682 integration_id: &str,
683 timeout_ms: u64,
684 interval_ms: u64,
685 ) -> Result<bool, KontextDevError> {
686 let started = now_unix_ms();
687
688 loop {
689 let status = self
690 .integration_connection_status(gateway_access_token, integration_id)
691 .await?;
692 if status.connected {
693 return Ok(true);
694 }
695
696 if now_unix_ms().saturating_sub(started) >= timeout_ms {
697 return Ok(false);
698 }
699
700 tokio::time::sleep(Duration::from_millis(interval_ms)).await;
701 }
702 }
703
704 async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
705 let auth_url = self.resolve_authorization_url().await?;
706 let token_url = self.token_url()?;
707 let pkce = generate_pkce_pair();
708 let state = generate_state();
709
710 let (callback_url, callback_payload) = self.listen_for_callback().await?;
711
712 let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
713 url: auth_url.clone(),
714 source,
715 })?;
716
717 {
718 let mut query = url.query_pairs_mut();
719 query
720 .append_pair("client_id", &self.config.client_id)
721 .append_pair("response_type", "code")
722 .append_pair("redirect_uri", &callback_url)
723 .append_pair("state", &state);
724
725 if let Some(scope) = normalized_scope(&self.config.scope) {
726 query.append_pair("scope", scope);
727 }
728
729 query
730 .append_pair("code_challenge_method", "S256")
731 .append_pair("code_challenge", &pkce.challenge);
732 }
733
734 if webbrowser::open(url.as_str()).is_err() {
735 return Err(KontextDevError::BrowserOpenFailed);
736 }
737
738 let payload = callback_payload.await?;
739
740 if let Some(error) = payload.error {
741 let with_details = payload
742 .error_description
743 .map(|description| format!("{error}: {description}"))
744 .unwrap_or(error);
745 return Err(KontextDevError::OAuthCallbackError {
746 error: with_details,
747 });
748 }
749
750 if payload.state.as_deref() != Some(state.as_str()) {
751 return Err(KontextDevError::InvalidOAuthState);
752 }
753
754 let code = payload
755 .code
756 .ok_or(KontextDevError::MissingAuthorizationCode)?;
757
758 let mut body = vec![
759 ("grant_type", "authorization_code".to_string()),
760 ("code", code),
761 ("redirect_uri", callback_url),
762 ("client_id", self.config.client_id.clone()),
763 ("code_verifier", pkce.verifier),
764 ];
765
766 if let Some(client_secret) = &self.config.client_secret {
767 body.push(("client_secret", client_secret.clone()));
768 }
769
770 post_token(&self.http, &token_url, None, &body).await
771 }
772
773 async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
774 if let Some(discovered) = self.discover_authorization_endpoint().await {
775 return Ok(discovered);
776 }
777
778 let authorize_url = self.authorize_url()?;
779 if !self.endpoint_is_missing(&authorize_url).await {
780 return Ok(authorize_url);
781 }
782
783 let server_base = resolve_server_base_url(&self.config)?;
784 Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
785 }
786
787 async fn discover_authorization_endpoint(&self) -> Option<String> {
788 let base = resolve_server_base_url(&self.config).ok()?;
789 let base = base.trim_end_matches('/');
790
791 let candidates = [
792 format!("{base}/.well-known/oauth-authorization-server/mcp"),
793 format!("{base}/.well-known/oauth-authorization-server"),
794 ];
795
796 for url in candidates {
797 let response = match self.http.get(&url).send().await {
798 Ok(response) => response,
799 Err(_) => continue,
800 };
801
802 if !response.status().is_success() {
803 continue;
804 }
805
806 let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
807 Ok(metadata) => metadata,
808 Err(_) => continue,
809 };
810
811 let Some(endpoint) = metadata.authorization_endpoint else {
812 continue;
813 };
814
815 if Url::parse(&endpoint).is_ok() {
816 return Some(endpoint);
817 }
818 }
819
820 None
821 }
822
823 async fn endpoint_is_missing(&self, url: &str) -> bool {
824 let probe_client = match reqwest::Client::builder()
825 .redirect(reqwest::redirect::Policy::none())
826 .build()
827 {
828 Ok(client) => client,
829 Err(_) => return false,
830 };
831
832 match probe_client.get(url).send().await {
833 Ok(response) => response.status() == StatusCode::NOT_FOUND,
834 Err(_) => false,
835 }
836 }
837
838 async fn refresh_identity_token(
839 &self,
840 refresh_token: &str,
841 ) -> Result<AccessToken, KontextDevError> {
842 let token_url = self.token_url()?;
843 let mut body = vec![
844 ("grant_type", "refresh_token".to_string()),
845 ("refresh_token", refresh_token.to_string()),
846 ("client_id", self.config.client_id.clone()),
847 ];
848
849 if let Some(client_secret) = &self.config.client_secret {
850 body.push(("client_secret", client_secret.clone()));
851 }
852
853 post_token(&self.http, &token_url, None, &body).await
854 }
855
856 pub async fn exchange_for_resource(
857 &self,
858 subject_token: &str,
859 resource: &str,
860 scope: Option<&str>,
861 ) -> Result<TokenExchangeToken, KontextDevError> {
862 let token_url = self.token_url()?;
863
864 let mut body = vec![
865 ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
866 ("subject_token", subject_token.to_string()),
867 ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
868 ("resource", resource.to_string()),
869 ];
870
871 if let Some(scope) = scope {
872 body.push(("scope", scope.to_string()));
873 }
874
875 let auth_header = self.config.client_secret.as_ref().map(|secret| {
876 let raw = format!("{}:{}", self.config.client_id, secret);
877 format!(
878 "Basic {}",
879 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
880 )
881 });
882
883 if self.config.client_secret.is_none() {
884 body.push(("client_id", self.config.client_id.clone()));
885 }
886
887 let response = post_form_with_optional_auth::<TokenExchangeToken>(
888 &self.http,
889 &token_url,
890 auth_header,
891 &body,
892 )
893 .await
894 .map_err(|message| KontextDevError::TokenExchange {
895 resource: resource.to_string(),
896 message,
897 })?;
898
899 if response.access_token.is_empty() {
900 return Err(KontextDevError::EmptyAccessToken);
901 }
902
903 Ok(response)
904 }
905
906 async fn listen_for_callback(
907 &self,
908 ) -> Result<
909 (
910 String,
911 impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
912 ),
913 KontextDevError,
914 > {
915 let redirect_uri = self.config.redirect_uri.trim().to_string();
916 let parsed = Url::parse(&redirect_uri).map_err(|source| KontextDevError::InvalidUrl {
917 url: redirect_uri.clone(),
918 source,
919 })?;
920
921 if parsed.scheme() != "http" {
922 return Err(KontextDevError::OAuthCallbackError {
923 error: "redirect_uri must use http".to_string(),
924 });
925 }
926
927 if parsed.query().is_some() || parsed.fragment().is_some() {
928 return Err(KontextDevError::OAuthCallbackError {
929 error: "redirect_uri must not include query parameters or fragments".to_string(),
930 });
931 }
932
933 let host = parsed
934 .host_str()
935 .ok_or_else(|| KontextDevError::OAuthCallbackError {
936 error: "redirect_uri host is missing".to_string(),
937 })?;
938
939 let port = parsed
940 .port()
941 .ok_or_else(|| KontextDevError::OAuthCallbackError {
942 error: "redirect_uri must include an explicit port".to_string(),
943 })?;
944
945 let callback_path = parsed.path().to_string();
946 let bind_addr = if host.contains(':') {
947 format!("[{host}]:{port}")
948 } else {
949 format!("{host}:{port}")
950 };
951
952 let server = Arc::new(Server::http(&bind_addr).map_err(|err| {
953 KontextDevError::OAuthCallbackError {
954 error: format!("failed to start callback server at {bind_addr}: {err}"),
955 }
956 })?);
957
958 let (tx, rx) = oneshot::channel();
959 let _join = spawn_callback_server(server.clone(), callback_path, tx);
960 let _guard = CallbackServerGuard { server };
961 let timeout_seconds = self.config.auth_timeout_seconds.max(1);
962
963 let fut = async move {
964 let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
965 .await
966 .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
967 .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
968 drop(_guard);
969 Ok(payload)
970 };
971
972 Ok((redirect_uri, fut))
973 }
974
975 fn token_cache_path(&self) -> Option<PathBuf> {
976 if let Some(explicit) = &self.config.token_cache_path {
977 return Some(PathBuf::from(explicit));
978 }
979
980 let home = dirs::home_dir()?;
981 let mut path = home;
982 path.push(".kontext-dev");
983 path.push("tokens");
984
985 let sanitized_client_id: String = self
986 .config
987 .client_id
988 .chars()
989 .map(|ch| {
990 if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
991 ch
992 } else {
993 '_'
994 }
995 })
996 .collect();
997
998 path.push(format!("{sanitized_client_id}.json"));
999 Some(path)
1000 }
1001
1002 fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
1003 let Some(path) = self.token_cache_path() else {
1004 return Ok(None);
1005 };
1006
1007 let raw = match fs::read_to_string(&path) {
1008 Ok(raw) => raw,
1009 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
1010 Err(source) => {
1011 return Err(KontextDevError::TokenCacheRead {
1012 path: path.display().to_string(),
1013 source,
1014 });
1015 }
1016 };
1017
1018 serde_json::from_str(&raw).map(Some).map_err(|source| {
1019 KontextDevError::TokenCacheDeserialize {
1020 path: path.display().to_string(),
1021 source,
1022 }
1023 })
1024 }
1025
1026 fn write_cache(
1027 &self,
1028 identity: &AccessToken,
1029 gateway: &TokenExchangeToken,
1030 ) -> Result<(), KontextDevError> {
1031 let Some(path) = self.token_cache_path() else {
1032 return Ok(());
1033 };
1034
1035 if let Some(parent) = path.parent() {
1036 fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
1037 path: parent.display().to_string(),
1038 source,
1039 })?;
1040 }
1041
1042 let payload = TokenCacheFile {
1043 client_id: self.config.client_id.clone(),
1044 resource: self.config.resource.clone(),
1045 identity: CachedAccessToken::from_access_token(identity)?,
1046 gateway: CachedAccessToken::from_token_exchange(gateway)?,
1047 };
1048
1049 let serialized = serde_json::to_string_pretty(&payload)
1050 .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
1051
1052 fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
1053 path: path.display().to_string(),
1054 source,
1055 })
1056 }
1057}
1058
1059pub async fn request_access_token(
1063 config: &KontextDevConfig,
1064) -> Result<AccessToken, KontextDevError> {
1065 let token_url = resolve_token_url(config)?;
1066
1067 let mut body = vec![("grant_type", "client_credentials".to_string())];
1068
1069 if let Some(scope) = normalized_scope(&config.scope) {
1070 body.push(("scope", scope.to_string()));
1071 }
1072
1073 let auth_header = if let Some(secret) = &config.client_secret {
1074 let raw = format!("{}:{}", config.client_id, secret);
1075 Some(format!(
1076 "Basic {}",
1077 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
1078 ))
1079 } else {
1080 body.push(("client_id", config.client_id.clone()));
1081 None
1082 };
1083
1084 post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
1085}
1086
1087async fn post_token(
1088 http: &Client,
1089 token_url: &str,
1090 authorization: Option<&str>,
1091 body: &[(impl AsRef<str>, String)],
1092) -> Result<AccessToken, KontextDevError> {
1093 let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
1094
1095 let response = post_form_with_optional_auth::<AccessToken>(
1096 http,
1097 token_url,
1098 authorization.map(ToString::to_string),
1099 &body_vec,
1100 )
1101 .await
1102 .map_err(|message| KontextDevError::TokenRequest {
1103 token_url: token_url.to_string(),
1104 message,
1105 })?;
1106
1107 Ok(response)
1108}
1109
1110async fn post_form_with_optional_auth<T>(
1111 http: &Client,
1112 url: &str,
1113 authorization: Option<String>,
1114 body: &[(&str, String)],
1115) -> Result<T, String>
1116where
1117 T: DeserializeOwned,
1118{
1119 let mut request = http
1120 .post(url)
1121 .header("Content-Type", "application/x-www-form-urlencoded");
1122
1123 if let Some(header) = authorization {
1124 request = request.header("Authorization", header);
1125 }
1126
1127 let form = body
1128 .iter()
1129 .map(|(k, v)| (k.to_string(), v.to_string()))
1130 .collect::<Vec<(String, String)>>();
1131
1132 let response = request
1133 .form(&form)
1134 .send()
1135 .await
1136 .map_err(|err| err.to_string())?;
1137
1138 if !response.status().is_success() {
1139 return Err(build_error_message(response).await);
1140 }
1141
1142 response.json::<T>().await.map_err(|err| err.to_string())
1143}
1144
1145async fn build_error_message(response: reqwest::Response) -> String {
1146 let status = response.status();
1147 let fallback = format!(
1148 "{} {}",
1149 status.as_u16(),
1150 status.canonical_reason().unwrap_or("")
1151 );
1152
1153 let body = response.text().await.unwrap_or_default();
1154 if body.is_empty() {
1155 return fallback.trim().to_string();
1156 }
1157
1158 if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1159 if let Some(description) = parsed.error_description {
1160 return description;
1161 }
1162 if let Some(message) = parsed.message {
1163 return message;
1164 }
1165 if let Some(error) = parsed.error {
1166 return error;
1167 }
1168 }
1169
1170 format!("{fallback}: {body}")
1171}
1172
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline test configuration.
    ///
    /// Points at a local server, sets no client secret, and enables the hosted
    /// integration UI.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: "http://localhost:4000".to_string(),
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        }
    }

    /// Encodes one JWT segment as unpadded base64url.
    fn b64url(segment: &str) -> String {
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(segment)
    }

    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let connect_url = KontextDevClient::new(config())
            .integration_connect_url("session-123")
            .expect("url should be built");

        assert_eq!(
            connect_url,
            "https://app.kontext.dev/oauth/connect?session=session-123"
        );
    }

    #[test]
    fn jwt_exp_decode_reads_exp() {
        let jwt = [
            b64url(r#"{"alg":"none"}"#),
            b64url(r#"{"exp":4070908800}"#),
            "sig".to_string(),
        ]
        .join(".");

        assert_eq!(decode_jwt_exp(&jwt), Some(4_070_908_800));
    }

    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        const METADATA_JSON: &str = r#"{
            "issuer": "https://issuer.example.com",
            "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
        }"#;

        let parsed: OAuthAuthorizationServerMetadata =
            serde_json::from_str(METADATA_JSON).expect("metadata should parse");

        assert_eq!(
            parsed.authorization_endpoint.as_deref(),
            Some("https://issuer.example.com/oauth2/auth")
        );
    }

    #[test]
    fn normalized_scope_omits_blank_values() {
        for blank in ["", " "] {
            assert_eq!(normalized_scope(blank), None);
        }
    }

    #[test]
    fn normalized_scope_preserves_non_empty_values() {
        let cases = [
            ("mcp:invoke", "mcp:invoke"),
            (" mcp:invoke openid ", "mcp:invoke openid"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalized_scope(input), Some(expected));
        }
    }
}