1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub use kontext_dev_core::AccessToken;
21pub use kontext_dev_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
22pub use kontext_dev_core::DEFAULT_RESOURCE;
23pub use kontext_dev_core::DEFAULT_SCOPE;
24pub use kontext_dev_core::DEFAULT_SERVER_NAME;
25pub use kontext_dev_core::KontextDevConfig;
26pub use kontext_dev_core::KontextDevCoreError;
27pub use kontext_dev_core::TokenExchangeToken;
28pub use kontext_dev_core::build_mcp_url;
29pub use kontext_dev_core::normalize_server_url;
30pub use kontext_dev_core::resolve_authorize_url;
31pub use kontext_dev_core::resolve_connect_session_url;
32pub use kontext_dev_core::resolve_integration_connection_url;
33pub use kontext_dev_core::resolve_integration_oauth_init_url;
34pub use kontext_dev_core::resolve_mcp_url;
35pub use kontext_dev_core::resolve_server_base_url;
36pub use kontext_dev_core::resolve_token_url;
37
/// Value sent as `grant_type` when exchanging a subject token for a resource token.
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// Token-type URN sent as `subject_token_type` and reported as the
/// `issued_token_type` of gateway tokens rebuilt from the cache.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
/// Safety margin in seconds: a cached token counts as expired this long before
/// its recorded expiry.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
41
/// Errors returned by the Kontext-Dev client defined in this file.
///
/// Variants with a `source` field carry the underlying error that caused them.
#[derive(Debug, thiserror::Error)]
pub enum KontextDevError {
    /// Error forwarded unchanged from `kontext_dev_core` (converted via `From`).
    #[error(transparent)]
    Core(#[from] kontext_dev_core::KontextDevCoreError),
    /// A URL string could not be parsed.
    #[error("failed to parse URL `{url}`")]
    InvalidUrl {
        url: String,
        source: url::ParseError,
    },
    /// The system browser could not be launched.
    #[error("failed to open browser for OAuth authorization")]
    BrowserOpenFailed,
    /// No OAuth callback arrived within the configured timeout.
    #[error("OAuth callback timed out after {timeout_seconds}s")]
    OAuthCallbackTimeout { timeout_seconds: i64 },
    /// The channel delivering the callback payload was closed before a payload was sent.
    #[error("OAuth callback channel was unexpectedly cancelled")]
    OAuthCallbackCancelled,
    /// The callback carried neither an error nor a `code` parameter.
    #[error("OAuth callback is missing the authorization code")]
    MissingAuthorizationCode,
    /// The callback reported an error; also used for invalid `redirect_uri`
    /// settings and callback-server start-up failures.
    #[error("OAuth callback returned an error: {error}")]
    OAuthCallbackError { error: String },
    /// The `state` returned in the callback did not match the one that was sent.
    #[error("OAuth state mismatch")]
    InvalidOAuthState,
    /// A request to the token endpoint failed.
    #[error("Kontext-Dev token request failed for {token_url}: {message}")]
    TokenRequest { token_url: String, message: String },
    /// A token-exchange request for `resource` failed.
    #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
    TokenExchange { resource: String, message: String },
    /// Creating a connect session failed, or a connect session id was empty.
    #[error("Kontext-Dev connect session request failed: {message}")]
    ConnectSession { message: String },
    /// An integration OAuth init request failed.
    ///
    /// NOTE(review): `integration_connection_status` also reports its failures
    /// through this variant.
    #[error("Kontext-Dev integration OAuth init failed: {message}")]
    IntegrationOAuthInit { message: String },
    /// The token cache file exists but could not be read.
    #[error("failed to read token cache at `{path}`: {source}")]
    TokenCacheRead {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file or its parent directory could not be written.
    #[error("failed to write token cache at `{path}`: {source}")]
    TokenCacheWrite {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file is not valid JSON for the expected schema.
    #[error("failed to deserialize token cache at `{path}`: {source}")]
    TokenCacheDeserialize {
        path: String,
        source: serde_json::Error,
    },
    /// The token cache contents could not be serialized to JSON.
    #[error("failed to serialize token cache: {source}")]
    TokenCacheSerialize { source: serde_json::Error },
    /// A token with an empty `access_token` string was received or about to be cached.
    #[error("Kontext-Dev access token is empty")]
    EmptyAccessToken,
    /// NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("missing integration UI URL; set `integration_ui_url` in config")]
    MissingIntegrationUiUrl,
}
93
/// Response body returned by the connect-session endpoint.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct ConnectSession {
    /// Session identifier (JSON key `sessionId`); used as the `session` query
    /// parameter of the integration connect URL.
    #[serde(rename = "sessionId")]
    pub session_id: String,
    /// Expiry of the session (JSON key `expiresAt`), kept as the raw string.
    /// NOTE(review): the timestamp format is not visible here — confirm with the API.
    #[serde(rename = "expiresAt")]
    pub expires_at: String,
}
101
/// Response body returned when an integration OAuth flow is initiated.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct IntegrationOAuthInitResponse {
    /// URL from the JSON key `authorizationUrl`.
    #[serde(rename = "authorizationUrl")]
    pub authorization_url: String,
    /// Optional `state` value; `None` when the key is absent from the response.
    #[serde(default)]
    pub state: Option<String>,
}
109
/// Response body describing whether an integration is currently connected.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct IntegrationConnectionStatus {
    /// `true` when the integration reports a connection.
    pub connected: bool,
    /// Optional expiry string; `None` when the key is absent.
    /// NOTE(review): unlike `ConnectSession`, this field has no camelCase rename,
    /// so it is read from the JSON key `expires_at` — confirm the server's key name.
    #[serde(default)]
    pub expires_at: Option<String>,
}
116
/// Result of `KontextDevClient::authenticate_mcp`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct KontextAuthSession {
    /// Token obtained from the cache, a refresh, or the browser PKCE flow.
    pub identity_token: AccessToken,
    /// Token obtained by exchanging the identity token for the configured resource.
    pub gateway_token: TokenExchangeToken,
    /// `true` only when the interactive browser flow ran during this call.
    pub browser_auth_performed: bool,
}
123
/// Serialized form of one token inside the on-disk token cache.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CachedAccessToken {
    /// Raw token string; never empty for entries built by the constructors below.
    access_token: String,
    token_type: String,
    refresh_token: Option<String>,
    scope: Option<String>,
    /// Absolute expiry in milliseconds since the Unix epoch, or `None` when no
    /// expiry could be derived (such entries are treated as always valid).
    expires_at_unix_ms: Option<u64>,
}
132
/// Top-level JSON document stored in the token cache file.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TokenCacheFile {
    /// Client id the tokens were issued for; compared against the config before reuse.
    client_id: String,
    /// Resource the gateway token was exchanged for; compared against the config before reuse.
    resource: String,
    /// Cached identity token.
    identity: CachedAccessToken,
    /// Cached gateway (token-exchange) token.
    gateway: CachedAccessToken,
}
140
141impl CachedAccessToken {
142 fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
143 if token.access_token.is_empty() {
144 return Err(KontextDevError::EmptyAccessToken);
145 }
146
147 Ok(Self {
148 access_token: token.access_token.clone(),
149 token_type: token.token_type.clone(),
150 refresh_token: token.refresh_token.clone(),
151 scope: token.scope.clone(),
152 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
153 })
154 }
155
156 fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
157 if token.access_token.is_empty() {
158 return Err(KontextDevError::EmptyAccessToken);
159 }
160
161 Ok(Self {
162 access_token: token.access_token.clone(),
163 token_type: token.token_type.clone(),
164 refresh_token: token.refresh_token.clone(),
165 scope: token.scope.clone(),
166 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
167 })
168 }
169
170 fn is_valid(&self) -> bool {
171 match self.expires_at_unix_ms {
172 Some(expires_at) => {
173 let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
174 now_unix_ms().saturating_add(buffer_ms) < expires_at
175 }
176 None => true,
177 }
178 }
179
180 fn to_access_token(&self) -> AccessToken {
181 AccessToken {
182 access_token: self.access_token.clone(),
183 token_type: self.token_type.clone(),
184 expires_in: self
185 .expires_at_unix_ms
186 .and_then(unix_ms_to_relative_seconds),
187 refresh_token: self.refresh_token.clone(),
188 scope: self.scope.clone(),
189 }
190 }
191
192 fn to_token_exchange_token(&self) -> TokenExchangeToken {
193 TokenExchangeToken {
194 access_token: self.access_token.clone(),
195 issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
196 token_type: self.token_type.clone(),
197 expires_in: self
198 .expires_at_unix_ms
199 .and_then(unix_ms_to_relative_seconds),
200 scope: self.scope.clone(),
201 refresh_token: self.refresh_token.clone(),
202 }
203 }
204}
205
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch.
fn now_unix_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
212
213fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
214 if unix_ms <= now_unix_ms() {
215 return Some(0);
216 }
217
218 let delta_ms = unix_ms - now_unix_ms();
219 let secs = delta_ms / 1000;
220 i64::try_from(secs).ok()
221}
222
223fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
224 if let Some(expires_in) = expires_in
225 && expires_in > 0
226 {
227 let expires_in_u64 = u64::try_from(expires_in).ok()?;
228 return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
229 }
230
231 decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
232}
233
234fn decode_jwt_exp(token: &str) -> Option<u64> {
235 let mut parts = token.split('.');
236 let _header = parts.next()?;
237 let payload = parts.next()?;
238
239 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
240 .decode(payload)
241 .ok()?;
242 let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
243 value.get("exp").and_then(|exp| {
244 exp.as_u64()
245 .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
246 })
247}
248
/// Shape of a JSON error body; `build_error_message` picks the first present
/// field in the order `error_description`, `message`, `error`.
#[derive(Clone, Debug, Deserialize)]
struct OAuthErrorBody {
    error: Option<String>,
    error_description: Option<String>,
    message: Option<String>,
}
255
/// Subset of the authorization-server metadata document fetched from the
/// `.well-known/oauth-authorization-server` URLs; all other keys are ignored.
#[derive(Clone, Debug, Deserialize)]
struct OAuthAuthorizationServerMetadata {
    /// Authorization endpoint URL, when the document provides one.
    authorization_endpoint: Option<String>,
}
260
/// Query parameters captured from the OAuth redirect back to the local
/// callback server. Each field is `None` when its parameter was absent.
#[derive(Clone, Debug, Deserialize)]
struct OAuthCallbackPayload {
    code: Option<String>,
    state: Option<String>,
    error: Option<String>,
    error_description: Option<String>,
}
268
/// A PKCE code verifier together with its derived challenge.
#[derive(Clone, Debug)]
struct PkcePair {
    /// Random string sent as `code_verifier` when redeeming the authorization code.
    verifier: String,
    /// Base64url-encoded SHA-256 digest of `verifier`, sent as `code_challenge`.
    challenge: String,
}
274
275fn generate_pkce_pair() -> PkcePair {
276 let mut raw = [0u8; 64];
277 rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
278
279 let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
280 let digest = sha2::Sha256::digest(verifier.as_bytes());
281 let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
282
283 PkcePair {
284 verifier,
285 challenge,
286 }
287}
288
289fn generate_state() -> String {
290 let mut bytes = [0u8; 16];
291 rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
292 bytes.iter().map(|b| format!("{b:02x}")).collect()
293}
294
/// Trims surrounding whitespace from `scope` and returns `None` when nothing
/// remains, so blank scopes can be omitted from requests.
fn normalized_scope(scope: &str) -> Option<&str> {
    Some(scope.trim()).filter(|trimmed| !trimmed.is_empty())
}
303
/// Holds a handle to the local callback server and calls `unblock()` on it
/// when dropped (see the `Drop` impl).
struct CallbackServerGuard {
    /// Shared handle to the same server the blocking accept loop receives from.
    server: Arc<Server>,
}
307
impl Drop for CallbackServerGuard {
    /// Calls `Server::unblock` when the guard goes away.
    ///
    /// NOTE(review): presumably this makes the blocked `recv()` in
    /// `spawn_callback_server` return so its loop can exit — confirm against
    /// the tiny_http documentation.
    fn drop(&mut self) {
        self.server.unblock();
    }
}
313
314fn parse_callback(url_path: &str, callback_path: &str) -> Option<OAuthCallbackPayload> {
315 let full = format!("http://localhost{url_path}");
316 let parsed = Url::parse(&full).ok()?;
317
318 if parsed.path() != callback_path {
319 return None;
320 }
321
322 let mut payload = OAuthCallbackPayload {
323 code: None,
324 state: None,
325 error: None,
326 error_description: None,
327 };
328
329 for (key, value) in parsed.query_pairs() {
330 match key.as_ref() {
331 "code" => payload.code = Some(value.to_string()),
332 "state" => payload.state = Some(value.to_string()),
333 "error" => payload.error = Some(value.to_string()),
334 "error_description" => payload.error_description = Some(value.to_string()),
335 _ => {}
336 }
337 }
338
339 Some(payload)
340}
341
342fn spawn_callback_server(
343 server: Arc<Server>,
344 callback_path: String,
345 tx: oneshot::Sender<OAuthCallbackPayload>,
346) -> tokio::task::JoinHandle<()> {
347 tokio::task::spawn_blocking(move || {
348 while let Ok(request) = server.recv() {
349 let path = request.url().to_string();
350 if let Some(payload) = parse_callback(&path, &callback_path) {
351 let response = Response::from_string(
352 "Authentication complete. You can return to your terminal.",
353 );
354 let _ = request.respond(response);
355 let _ = tx.send(payload);
356 break;
357 }
358
359 let response = Response::from_string("Invalid callback").with_status_code(400);
360 let _ = request.respond(response);
361 }
362 })
363}
364
/// Client for the Kontext-Dev authentication and integration endpoints.
pub struct KontextDevClient {
    /// Configuration supplied at construction; all endpoint URLs are resolved from it.
    config: KontextDevConfig,
    /// HTTP client reused for every request issued by this instance.
    http: Client,
}
369
370impl KontextDevClient {
371 pub fn new(config: KontextDevConfig) -> Self {
372 Self {
373 config,
374 http: Client::new(),
375 }
376 }
377
378 pub fn config(&self) -> &KontextDevConfig {
379 &self.config
380 }
381
382 pub fn mcp_url(&self) -> Result<String, KontextDevError> {
383 resolve_mcp_url(&self.config).map_err(KontextDevError::from)
384 }
385
386 pub fn token_url(&self) -> Result<String, KontextDevError> {
387 resolve_token_url(&self.config).map_err(KontextDevError::from)
388 }
389
390 pub fn authorize_url(&self) -> Result<String, KontextDevError> {
391 resolve_authorize_url(&self.config).map_err(KontextDevError::from)
392 }
393
394 pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
395 resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
396 }
397
398 pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
401 let resource = self.config.resource.clone();
402
403 if let Some(cache) = self.read_cache()?
404 && cache.client_id == self.config.client_id
405 && cache.resource == resource
406 && cache.gateway.is_valid()
407 && cache.identity.is_valid()
408 {
409 return Ok(KontextAuthSession {
410 identity_token: cache.identity.to_access_token(),
411 gateway_token: cache.gateway.to_token_exchange_token(),
412 browser_auth_performed: false,
413 });
414 }
415
416 if let Some(cache) = self.read_cache()?
417 && cache.client_id == self.config.client_id
418 && cache.resource == resource
419 {
420 let maybe_refreshed = if cache.identity.is_valid() {
421 Some(cache.identity.to_access_token())
422 } else if let Some(refresh_token) = &cache.identity.refresh_token {
423 Some(self.refresh_identity_token(refresh_token).await?)
424 } else {
425 None
426 };
427
428 if let Some(identity_token) = maybe_refreshed {
429 let gateway_token = self
430 .exchange_for_resource(&identity_token.access_token, &resource, None)
431 .await?;
432 self.write_cache(&identity_token, &gateway_token)?;
433 return Ok(KontextAuthSession {
434 identity_token,
435 gateway_token,
436 browser_auth_performed: false,
437 });
438 }
439 }
440
441 let identity_token = self.authorize_with_browser_pkce().await?;
442 let gateway_token = self
443 .exchange_for_resource(&identity_token.access_token, &resource, None)
444 .await?;
445
446 self.write_cache(&identity_token, &gateway_token)?;
447
448 Ok(KontextAuthSession {
449 identity_token,
450 gateway_token,
451 browser_auth_performed: true,
452 })
453 }
454
455 pub async fn create_integration_connect_url(
458 &self,
459 gateway_access_token: &str,
460 ) -> Result<String, KontextDevError> {
461 let session = self.create_connect_session(gateway_access_token).await?;
462 self.integration_connect_url(&session.session_id)
463 }
464
465 pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
466 if session_id.trim().is_empty() {
467 return Err(KontextDevError::ConnectSession {
468 message: "connect session id is empty".to_string(),
469 });
470 }
471
472 let base = if let Some(explicit) = &self.config.integration_ui_url {
473 explicit.trim_end_matches('/').to_string()
474 } else {
475 let server = resolve_server_base_url(&self.config)?;
476 if server.contains("api.kontext.dev") {
477 "https://app.kontext.dev".to_string()
478 } else {
479 server
480 }
481 };
482
483 let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
484 url: base.clone(),
485 source,
486 })?;
487 url.set_path("/oauth/connect");
488 url.query_pairs_mut().append_pair("session", session_id);
489 Ok(url.to_string())
490 }
491
492 pub async fn open_integration_connect_page(
494 &self,
495 gateway_access_token: &str,
496 ) -> Result<String, KontextDevError> {
497 let url = self
498 .create_integration_connect_url(gateway_access_token)
499 .await?;
500
501 if webbrowser::open(&url).is_err() {
502 return Err(KontextDevError::BrowserOpenFailed);
503 }
504
505 Ok(url)
506 }
507
508 pub async fn create_connect_session(
509 &self,
510 gateway_access_token: &str,
511 ) -> Result<ConnectSession, KontextDevError> {
512 let url = self.connect_session_url()?;
513 let response = self
514 .http
515 .post(&url)
516 .header("Authorization", format!("Bearer {gateway_access_token}"))
517 .json(&serde_json::json!({}))
518 .send()
519 .await
520 .map_err(|err| KontextDevError::ConnectSession {
521 message: err.to_string(),
522 })?;
523
524 if !response.status().is_success() {
525 let message = build_error_message(response).await;
526 return Err(KontextDevError::ConnectSession { message });
527 }
528
529 response
530 .json::<ConnectSession>()
531 .await
532 .map_err(|err| KontextDevError::ConnectSession {
533 message: err.to_string(),
534 })
535 }
536
537 pub async fn initiate_integration_oauth(
538 &self,
539 gateway_access_token: &str,
540 integration_id: &str,
541 return_to: Option<&str>,
542 ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
543 let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
544
545 let payload = return_to
546 .map(|value| serde_json::json!({ "returnTo": value }))
547 .unwrap_or_else(|| serde_json::json!({}));
548
549 let request = self
550 .http
551 .post(&url)
552 .header("Authorization", format!("Bearer {gateway_access_token}"))
553 .json(&payload);
554
555 let response =
556 request
557 .send()
558 .await
559 .map_err(|err| KontextDevError::IntegrationOAuthInit {
560 message: err.to_string(),
561 })?;
562
563 if !response.status().is_success() {
564 let message = build_error_message(response).await;
565 return Err(KontextDevError::IntegrationOAuthInit { message });
566 }
567
568 response
569 .json::<IntegrationOAuthInitResponse>()
570 .await
571 .map_err(|err| KontextDevError::IntegrationOAuthInit {
572 message: err.to_string(),
573 })
574 }
575
576 pub async fn integration_connection_status(
577 &self,
578 gateway_access_token: &str,
579 integration_id: &str,
580 ) -> Result<IntegrationConnectionStatus, KontextDevError> {
581 let url = resolve_integration_connection_url(&self.config, integration_id)?;
582 let response = self
583 .http
584 .get(url)
585 .header("Authorization", format!("Bearer {gateway_access_token}"))
586 .send()
587 .await
588 .map_err(|err| KontextDevError::IntegrationOAuthInit {
589 message: err.to_string(),
590 })?;
591
592 if !response.status().is_success() {
593 let message = build_error_message(response).await;
594 return Err(KontextDevError::IntegrationOAuthInit { message });
595 }
596
597 response
598 .json::<IntegrationConnectionStatus>()
599 .await
600 .map_err(|err| KontextDevError::IntegrationOAuthInit {
601 message: err.to_string(),
602 })
603 }
604
605 pub async fn wait_for_integration_connection(
606 &self,
607 gateway_access_token: &str,
608 integration_id: &str,
609 timeout_ms: u64,
610 interval_ms: u64,
611 ) -> Result<bool, KontextDevError> {
612 let started = now_unix_ms();
613
614 loop {
615 let status = self
616 .integration_connection_status(gateway_access_token, integration_id)
617 .await?;
618 if status.connected {
619 return Ok(true);
620 }
621
622 if now_unix_ms().saturating_sub(started) >= timeout_ms {
623 return Ok(false);
624 }
625
626 tokio::time::sleep(Duration::from_millis(interval_ms)).await;
627 }
628 }
629
630 async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
631 let auth_url = self.resolve_authorization_url().await?;
632 let token_url = self.token_url()?;
633 let pkce = generate_pkce_pair();
634 let state = generate_state();
635
636 let (callback_url, callback_payload) = self.listen_for_callback().await?;
637
638 let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
639 url: auth_url.clone(),
640 source,
641 })?;
642
643 {
644 let mut query = url.query_pairs_mut();
645 query
646 .append_pair("client_id", &self.config.client_id)
647 .append_pair("response_type", "code")
648 .append_pair("redirect_uri", &callback_url)
649 .append_pair("state", &state);
650
651 if let Some(scope) = normalized_scope(&self.config.scope) {
652 query.append_pair("scope", scope);
653 }
654
655 query
656 .append_pair("code_challenge_method", "S256")
657 .append_pair("code_challenge", &pkce.challenge);
658 }
659
660 if webbrowser::open(url.as_str()).is_err() {
661 return Err(KontextDevError::BrowserOpenFailed);
662 }
663
664 let payload = callback_payload.await?;
665
666 if let Some(error) = payload.error {
667 let with_details = payload
668 .error_description
669 .map(|description| format!("{error}: {description}"))
670 .unwrap_or(error);
671 return Err(KontextDevError::OAuthCallbackError {
672 error: with_details,
673 });
674 }
675
676 if payload.state.as_deref() != Some(state.as_str()) {
677 return Err(KontextDevError::InvalidOAuthState);
678 }
679
680 let code = payload
681 .code
682 .ok_or(KontextDevError::MissingAuthorizationCode)?;
683
684 let mut body = vec![
685 ("grant_type", "authorization_code".to_string()),
686 ("code", code),
687 ("redirect_uri", callback_url),
688 ("client_id", self.config.client_id.clone()),
689 ("code_verifier", pkce.verifier),
690 ];
691
692 if let Some(client_secret) = &self.config.client_secret {
693 body.push(("client_secret", client_secret.clone()));
694 }
695
696 post_token(&self.http, &token_url, None, &body).await
697 }
698
699 async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
700 if let Some(discovered) = self.discover_authorization_endpoint().await {
701 return Ok(discovered);
702 }
703
704 let authorize_url = self.authorize_url()?;
705 if !self.endpoint_is_missing(&authorize_url).await {
706 return Ok(authorize_url);
707 }
708
709 let server_base = resolve_server_base_url(&self.config)?;
710 Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
711 }
712
713 async fn discover_authorization_endpoint(&self) -> Option<String> {
714 let base = resolve_server_base_url(&self.config).ok()?;
715 let base = base.trim_end_matches('/');
716
717 let candidates = [
718 format!("{base}/.well-known/oauth-authorization-server/mcp"),
719 format!("{base}/.well-known/oauth-authorization-server"),
720 ];
721
722 for url in candidates {
723 let response = match self.http.get(&url).send().await {
724 Ok(response) => response,
725 Err(_) => continue,
726 };
727
728 if !response.status().is_success() {
729 continue;
730 }
731
732 let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
733 Ok(metadata) => metadata,
734 Err(_) => continue,
735 };
736
737 let Some(endpoint) = metadata.authorization_endpoint else {
738 continue;
739 };
740
741 if Url::parse(&endpoint).is_ok() {
742 return Some(endpoint);
743 }
744 }
745
746 None
747 }
748
749 async fn endpoint_is_missing(&self, url: &str) -> bool {
750 let probe_client = match reqwest::Client::builder()
751 .redirect(reqwest::redirect::Policy::none())
752 .build()
753 {
754 Ok(client) => client,
755 Err(_) => return false,
756 };
757
758 match probe_client.get(url).send().await {
759 Ok(response) => response.status() == StatusCode::NOT_FOUND,
760 Err(_) => false,
761 }
762 }
763
764 async fn refresh_identity_token(
765 &self,
766 refresh_token: &str,
767 ) -> Result<AccessToken, KontextDevError> {
768 let token_url = self.token_url()?;
769 let mut body = vec![
770 ("grant_type", "refresh_token".to_string()),
771 ("refresh_token", refresh_token.to_string()),
772 ("client_id", self.config.client_id.clone()),
773 ];
774
775 if let Some(client_secret) = &self.config.client_secret {
776 body.push(("client_secret", client_secret.clone()));
777 }
778
779 post_token(&self.http, &token_url, None, &body).await
780 }
781
782 pub async fn exchange_for_resource(
783 &self,
784 subject_token: &str,
785 resource: &str,
786 scope: Option<&str>,
787 ) -> Result<TokenExchangeToken, KontextDevError> {
788 let token_url = self.token_url()?;
789
790 let mut body = vec![
791 ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
792 ("subject_token", subject_token.to_string()),
793 ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
794 ("resource", resource.to_string()),
795 ];
796
797 if let Some(scope) = scope {
798 body.push(("scope", scope.to_string()));
799 }
800
801 let auth_header = self.config.client_secret.as_ref().map(|secret| {
802 let raw = format!("{}:{}", self.config.client_id, secret);
803 format!(
804 "Basic {}",
805 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
806 )
807 });
808
809 if self.config.client_secret.is_none() {
810 body.push(("client_id", self.config.client_id.clone()));
811 }
812
813 let response = post_form_with_optional_auth::<TokenExchangeToken>(
814 &self.http,
815 &token_url,
816 auth_header,
817 &body,
818 )
819 .await
820 .map_err(|message| KontextDevError::TokenExchange {
821 resource: resource.to_string(),
822 message,
823 })?;
824
825 if response.access_token.is_empty() {
826 return Err(KontextDevError::EmptyAccessToken);
827 }
828
829 Ok(response)
830 }
831
832 async fn listen_for_callback(
833 &self,
834 ) -> Result<
835 (
836 String,
837 impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
838 ),
839 KontextDevError,
840 > {
841 let redirect_uri = self.config.redirect_uri.trim().to_string();
842 let parsed = Url::parse(&redirect_uri).map_err(|source| KontextDevError::InvalidUrl {
843 url: redirect_uri.clone(),
844 source,
845 })?;
846
847 if parsed.scheme() != "http" {
848 return Err(KontextDevError::OAuthCallbackError {
849 error: "redirect_uri must use http".to_string(),
850 });
851 }
852
853 if parsed.query().is_some() || parsed.fragment().is_some() {
854 return Err(KontextDevError::OAuthCallbackError {
855 error: "redirect_uri must not include query parameters or fragments".to_string(),
856 });
857 }
858
859 let host = parsed
860 .host_str()
861 .ok_or_else(|| KontextDevError::OAuthCallbackError {
862 error: "redirect_uri host is missing".to_string(),
863 })?;
864
865 let port = parsed
866 .port()
867 .ok_or_else(|| KontextDevError::OAuthCallbackError {
868 error: "redirect_uri must include an explicit port".to_string(),
869 })?;
870
871 let callback_path = parsed.path().to_string();
872 let bind_addr = if host.contains(':') {
873 format!("[{host}]:{port}")
874 } else {
875 format!("{host}:{port}")
876 };
877
878 let server = Arc::new(Server::http(&bind_addr).map_err(|err| {
879 KontextDevError::OAuthCallbackError {
880 error: format!("failed to start callback server at {bind_addr}: {err}"),
881 }
882 })?);
883
884 let (tx, rx) = oneshot::channel();
885 let _join = spawn_callback_server(server.clone(), callback_path, tx);
886 let _guard = CallbackServerGuard { server };
887 let timeout_seconds = self.config.auth_timeout_seconds.max(1);
888
889 let fut = async move {
890 let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
891 .await
892 .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
893 .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
894 drop(_guard);
895 Ok(payload)
896 };
897
898 Ok((redirect_uri, fut))
899 }
900
901 fn token_cache_path(&self) -> Option<PathBuf> {
902 if let Some(explicit) = &self.config.token_cache_path {
903 return Some(PathBuf::from(explicit));
904 }
905
906 let home = dirs::home_dir()?;
907 let mut path = home;
908 path.push(".kontext-dev");
909 path.push("tokens");
910
911 let sanitized_client_id: String = self
912 .config
913 .client_id
914 .chars()
915 .map(|ch| {
916 if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
917 ch
918 } else {
919 '_'
920 }
921 })
922 .collect();
923
924 path.push(format!("{sanitized_client_id}.json"));
925 Some(path)
926 }
927
928 fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
929 let Some(path) = self.token_cache_path() else {
930 return Ok(None);
931 };
932
933 let raw = match fs::read_to_string(&path) {
934 Ok(raw) => raw,
935 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
936 Err(source) => {
937 return Err(KontextDevError::TokenCacheRead {
938 path: path.display().to_string(),
939 source,
940 });
941 }
942 };
943
944 serde_json::from_str(&raw).map(Some).map_err(|source| {
945 KontextDevError::TokenCacheDeserialize {
946 path: path.display().to_string(),
947 source,
948 }
949 })
950 }
951
952 fn write_cache(
953 &self,
954 identity: &AccessToken,
955 gateway: &TokenExchangeToken,
956 ) -> Result<(), KontextDevError> {
957 let Some(path) = self.token_cache_path() else {
958 return Ok(());
959 };
960
961 if let Some(parent) = path.parent() {
962 fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
963 path: parent.display().to_string(),
964 source,
965 })?;
966 }
967
968 let payload = TokenCacheFile {
969 client_id: self.config.client_id.clone(),
970 resource: self.config.resource.clone(),
971 identity: CachedAccessToken::from_access_token(identity)?,
972 gateway: CachedAccessToken::from_token_exchange(gateway)?,
973 };
974
975 let serialized = serde_json::to_string_pretty(&payload)
976 .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
977
978 fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
979 path: path.display().to_string(),
980 source,
981 })
982 }
983}
984
985pub async fn request_access_token(
989 config: &KontextDevConfig,
990) -> Result<AccessToken, KontextDevError> {
991 let token_url = resolve_token_url(config)?;
992
993 let mut body = vec![("grant_type", "client_credentials".to_string())];
994
995 if let Some(scope) = normalized_scope(&config.scope) {
996 body.push(("scope", scope.to_string()));
997 }
998
999 let auth_header = if let Some(secret) = &config.client_secret {
1000 let raw = format!("{}:{}", config.client_id, secret);
1001 Some(format!(
1002 "Basic {}",
1003 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
1004 ))
1005 } else {
1006 body.push(("client_id", config.client_id.clone()));
1007 None
1008 };
1009
1010 post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
1011}
1012
1013async fn post_token(
1014 http: &Client,
1015 token_url: &str,
1016 authorization: Option<&str>,
1017 body: &[(impl AsRef<str>, String)],
1018) -> Result<AccessToken, KontextDevError> {
1019 let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
1020
1021 let response = post_form_with_optional_auth::<AccessToken>(
1022 http,
1023 token_url,
1024 authorization.map(ToString::to_string),
1025 &body_vec,
1026 )
1027 .await
1028 .map_err(|message| KontextDevError::TokenRequest {
1029 token_url: token_url.to_string(),
1030 message,
1031 })?;
1032
1033 Ok(response)
1034}
1035
1036async fn post_form_with_optional_auth<T>(
1037 http: &Client,
1038 url: &str,
1039 authorization: Option<String>,
1040 body: &[(&str, String)],
1041) -> Result<T, String>
1042where
1043 T: DeserializeOwned,
1044{
1045 let mut request = http
1046 .post(url)
1047 .header("Content-Type", "application/x-www-form-urlencoded");
1048
1049 if let Some(header) = authorization {
1050 request = request.header("Authorization", header);
1051 }
1052
1053 let form = body
1054 .iter()
1055 .map(|(k, v)| (k.to_string(), v.to_string()))
1056 .collect::<Vec<(String, String)>>();
1057
1058 let response = request
1059 .form(&form)
1060 .send()
1061 .await
1062 .map_err(|err| err.to_string())?;
1063
1064 if !response.status().is_success() {
1065 return Err(build_error_message(response).await);
1066 }
1067
1068 response.json::<T>().await.map_err(|err| err.to_string())
1069}
1070
1071async fn build_error_message(response: reqwest::Response) -> String {
1072 let status = response.status();
1073 let fallback = format!(
1074 "{} {}",
1075 status.as_u16(),
1076 status.canonical_reason().unwrap_or("")
1077 );
1078
1079 let body = response.text().await.unwrap_or_default();
1080 if body.is_empty() {
1081 return fallback.trim().to_string();
1082 }
1083
1084 if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1085 if let Some(description) = parsed.error_description {
1086 return description;
1087 }
1088 if let Some(message) = parsed.message {
1089 return message;
1090 }
1091 if let Some(error) = parsed.error {
1092 return error;
1093 }
1094 }
1095
1096 format!("{fallback}: {body}")
1097}
1098
// Unit tests for the pure helpers in this file; none of them touch the network.
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by the tests: a local server, no client
    /// secret, and an explicit hosted integration UI URL.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: Some("http://localhost:4000".to_string()),
            mcp_url: None,
            token_url: None,
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        }
    }

    /// The explicit `integration_ui_url` is used as the base of the connect URL.
    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let client = KontextDevClient::new(config());
        let url = client
            .integration_connect_url("session-123")
            .expect("url should be built");
        assert_eq!(
            url,
            "https://app.kontext.dev/oauth/connect?session=session-123"
        );
    }

    /// `decode_jwt_exp` reads `exp` from the payload segment of an unsigned token.
    #[test]
    fn jwt_exp_decode_reads_exp() {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#);
        let payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"exp":4070908800}"#);
        let token = format!("{header}.{payload}.sig");
        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }

    /// Metadata parsing picks up `authorization_endpoint` and ignores other keys.
    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        let metadata = serde_json::from_str::<OAuthAuthorizationServerMetadata>(
            r#"{
 "issuer": "https://issuer.example.com",
 "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
 }"#,
        )
        .expect("metadata should parse");

        assert_eq!(
            metadata.authorization_endpoint.as_deref(),
            Some("https://issuer.example.com/oauth2/auth")
        );
    }

    /// Empty and whitespace-only scopes normalize to `None`.
    #[test]
    fn normalized_scope_omits_blank_values() {
        assert_eq!(normalized_scope(""), None);
        assert_eq!(normalized_scope(" "), None);
    }

    /// Non-blank scopes are returned with surrounding whitespace removed.
    #[test]
    fn normalized_scope_preserves_non_empty_values() {
        assert_eq!(normalized_scope("mcp:invoke"), Some("mcp:invoke"));
        assert_eq!(
            normalized_scope(" mcp:invoke openid "),
            Some("mcp:invoke openid")
        );
    }
}