1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub use kontext_dev_core::AccessToken;
21pub use kontext_dev_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
22pub use kontext_dev_core::DEFAULT_RESOURCE;
23pub use kontext_dev_core::DEFAULT_SCOPE;
24pub use kontext_dev_core::DEFAULT_SERVER_NAME;
25pub use kontext_dev_core::KontextDevConfig;
26pub use kontext_dev_core::KontextDevCoreError;
27pub use kontext_dev_core::TokenExchangeToken;
28pub use kontext_dev_core::build_mcp_url;
29pub use kontext_dev_core::normalize_server_url;
30pub use kontext_dev_core::resolve_authorize_url;
31pub use kontext_dev_core::resolve_connect_session_url;
32pub use kontext_dev_core::resolve_integration_connection_url;
33pub use kontext_dev_core::resolve_integration_oauth_init_url;
34pub use kontext_dev_core::resolve_mcp_url;
35pub use kontext_dev_core::resolve_server_base_url;
36pub use kontext_dev_core::resolve_token_url;
37
/// Path the local OAuth callback server accepts; any other path gets a 400 response.
const CONNECT_CALLBACK_PATH: &str = "/callback";
/// Tokens expiring within this many seconds are treated as already expired.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
/// `grant_type` value for the OAuth 2.0 token-exchange grant (RFC 8693).
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// Token-type URN for access tokens, used as `subject_token_type` and `issued_token_type`.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
42
/// Errors returned by [`KontextDevClient`] and the free functions in this module.
#[derive(Debug, thiserror::Error)]
pub enum KontextDevError {
    /// An error from `kontext_dev_core`, e.g. while resolving a URL from the configuration.
    #[error(transparent)]
    Core(#[from] kontext_dev_core::KontextDevCoreError),
    /// A URL string (configured, derived, or discovered) could not be parsed.
    #[error("failed to parse URL `{url}`")]
    InvalidUrl {
        url: String,
        source: url::ParseError,
    },
    /// The system browser could not be launched.
    #[error("failed to open browser for OAuth authorization")]
    BrowserOpenFailed,
    /// No OAuth redirect reached the local callback server within the configured timeout.
    #[error("OAuth callback timed out after {timeout_seconds}s")]
    OAuthCallbackTimeout { timeout_seconds: i64 },
    /// The callback channel's sender was dropped before a payload was delivered.
    #[error("OAuth callback channel was unexpectedly cancelled")]
    OAuthCallbackCancelled,
    /// The OAuth redirect carried neither an error nor a `code` parameter.
    #[error("OAuth callback is missing the authorization code")]
    MissingAuthorizationCode,
    /// The redirect carried an `error` parameter, or the local callback server could not be
    /// started or addressed.
    #[error("OAuth callback returned an error: {error}")]
    OAuthCallbackError { error: String },
    /// The `state` returned in the redirect did not match the one this client generated.
    #[error("OAuth state mismatch")]
    InvalidOAuthState,
    /// A request to the token endpoint (authorization-code, refresh, or client-credentials
    /// grant) failed or returned a non-success status.
    #[error("Kontext-Dev token request failed for {token_url}: {message}")]
    TokenRequest { token_url: String, message: String },
    /// The token-exchange request for `resource` failed.
    #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
    TokenExchange { resource: String, message: String },
    /// Creating a connect session failed, or a connect session id was empty.
    #[error("Kontext-Dev connect session request failed: {message}")]
    ConnectSession { message: String },
    /// Integration OAuth initiation failed. Also used for integration connection-status
    /// requests.
    #[error("Kontext-Dev integration OAuth init failed: {message}")]
    IntegrationOAuthInit { message: String },
    /// The token cache file exists but could not be read.
    #[error("failed to read token cache at `{path}`: {source}")]
    TokenCacheRead {
        path: String,
        source: std::io::Error,
    },
    /// The token cache directory or file could not be written.
    #[error("failed to write token cache at `{path}`: {source}")]
    TokenCacheWrite {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file contents are not valid cache JSON.
    #[error("failed to deserialize token cache at `{path}`: {source}")]
    TokenCacheDeserialize {
        path: String,
        source: serde_json::Error,
    },
    /// The cache payload could not be serialized to JSON.
    #[error("failed to serialize token cache: {source}")]
    TokenCacheSerialize { source: serde_json::Error },
    /// A token endpoint returned, or the cache was asked to store, an empty access token.
    #[error("Kontext-Dev access token is empty")]
    EmptyAccessToken,
    /// NOTE(review): not constructed anywhere in the code shown here;
    /// `integration_connect_url` derives a base URL instead. Confirm whether it is still needed.
    #[error("missing integration UI URL; set `integration_ui_url` in config")]
    MissingIntegrationUiUrl,
}
94
95#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
96pub struct ConnectSession {
97 #[serde(rename = "sessionId")]
98 pub session_id: String,
99 #[serde(rename = "expiresAt")]
100 pub expires_at: String,
101}
102
103#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
104pub struct IntegrationOAuthInitResponse {
105 #[serde(rename = "authorizationUrl")]
106 pub authorization_url: String,
107 #[serde(default)]
108 pub state: Option<String>,
109}
110
111#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
112pub struct IntegrationConnectionStatus {
113 pub connected: bool,
114 #[serde(default)]
115 pub expires_at: Option<String>,
116}
117
/// Result of [`KontextDevClient::authenticate_mcp`].
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct KontextAuthSession {
    /// Identity token obtained from the cache, a refresh, or the browser PKCE flow.
    pub identity_token: AccessToken,
    /// Token obtained by exchanging the identity token for the configured resource.
    pub gateway_token: TokenExchangeToken,
    /// `true` only when the interactive browser authorization flow ran for this session.
    pub browser_auth_performed: bool,
}
124
/// On-disk representation of a token, with its lifetime stored as an absolute timestamp.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CachedAccessToken {
    access_token: String,
    token_type: String,
    refresh_token: Option<String>,
    scope: Option<String>,
    // Absolute expiry in Unix milliseconds. `None` means no expiry could be determined;
    // `is_valid` treats such tokens as never expiring.
    expires_at_unix_ms: Option<u64>,
}
133
/// JSON document stored at the token cache path.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TokenCacheFile {
    // `client_id` and `resource` are compared against the current config before the cached
    // tokens are reused.
    client_id: String,
    resource: String,
    // Identity token from the authorization server.
    identity: CachedAccessToken,
    // Gateway token obtained by token exchange for `resource`.
    gateway: CachedAccessToken,
}
141
142impl CachedAccessToken {
143 fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
144 if token.access_token.is_empty() {
145 return Err(KontextDevError::EmptyAccessToken);
146 }
147
148 Ok(Self {
149 access_token: token.access_token.clone(),
150 token_type: token.token_type.clone(),
151 refresh_token: token.refresh_token.clone(),
152 scope: token.scope.clone(),
153 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
154 })
155 }
156
157 fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
158 if token.access_token.is_empty() {
159 return Err(KontextDevError::EmptyAccessToken);
160 }
161
162 Ok(Self {
163 access_token: token.access_token.clone(),
164 token_type: token.token_type.clone(),
165 refresh_token: token.refresh_token.clone(),
166 scope: token.scope.clone(),
167 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
168 })
169 }
170
171 fn is_valid(&self) -> bool {
172 match self.expires_at_unix_ms {
173 Some(expires_at) => {
174 let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
175 now_unix_ms().saturating_add(buffer_ms) < expires_at
176 }
177 None => true,
178 }
179 }
180
181 fn to_access_token(&self) -> AccessToken {
182 AccessToken {
183 access_token: self.access_token.clone(),
184 token_type: self.token_type.clone(),
185 expires_in: self
186 .expires_at_unix_ms
187 .and_then(unix_ms_to_relative_seconds),
188 refresh_token: self.refresh_token.clone(),
189 scope: self.scope.clone(),
190 }
191 }
192
193 fn to_token_exchange_token(&self) -> TokenExchangeToken {
194 TokenExchangeToken {
195 access_token: self.access_token.clone(),
196 issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
197 token_type: self.token_type.clone(),
198 expires_in: self
199 .expires_at_unix_ms
200 .and_then(unix_ms_to_relative_seconds),
201 scope: self.scope.clone(),
202 refresh_token: self.refresh_token.clone(),
203 }
204 }
205}
206
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Returns 0 if the system clock reads earlier than the epoch.
fn now_unix_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
213
214fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
215 if unix_ms <= now_unix_ms() {
216 return Some(0);
217 }
218
219 let delta_ms = unix_ms - now_unix_ms();
220 let secs = delta_ms / 1000;
221 i64::try_from(secs).ok()
222}
223
224fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
225 if let Some(expires_in) = expires_in
226 && expires_in > 0
227 {
228 let expires_in_u64 = u64::try_from(expires_in).ok()?;
229 return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
230 }
231
232 decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
233}
234
235fn decode_jwt_exp(token: &str) -> Option<u64> {
236 let mut parts = token.split('.');
237 let _header = parts.next()?;
238 let payload = parts.next()?;
239
240 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
241 .decode(payload)
242 .ok()?;
243 let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
244 value.get("exp").and_then(|exp| {
245 exp.as_u64()
246 .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
247 })
248}
249
/// Error body shape returned by OAuth/API endpoints.
///
/// `build_error_message` prefers `error_description`, then `message`, then `error`.
#[derive(Clone, Debug, Deserialize)]
struct OAuthErrorBody {
    error: Option<String>,
    error_description: Option<String>,
    message: Option<String>,
}
256
/// Subset of an OAuth authorization-server metadata document (`.well-known`).
///
/// Only `authorization_endpoint` is read; other fields in the document (e.g. `issuer`) are
/// ignored during deserialization.
#[derive(Clone, Debug, Deserialize)]
struct OAuthAuthorizationServerMetadata {
    authorization_endpoint: Option<String>,
}
261
/// Query parameters captured from the OAuth redirect to the local callback server.
///
/// Populated by `parse_callback`; each field is `None` when its parameter is absent.
#[derive(Clone, Debug, Deserialize)]
struct OAuthCallbackPayload {
    code: Option<String>,
    state: Option<String>,
    error: Option<String>,
    error_description: Option<String>,
}
269
/// PKCE verifier/challenge pair for one authorization attempt.
#[derive(Clone, Debug)]
struct PkcePair {
    // Secret sent as `code_verifier` to the token endpoint.
    verifier: String,
    // base64url(SHA-256(verifier)), sent as `code_challenge` with method `S256`.
    challenge: String,
}
275
276fn generate_pkce_pair() -> PkcePair {
277 let mut raw = [0u8; 64];
278 rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
279
280 let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
281 let digest = sha2::Sha256::digest(verifier.as_bytes());
282 let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
283
284 PkcePair {
285 verifier,
286 challenge,
287 }
288}
289
290fn generate_state() -> String {
291 let mut bytes = [0u8; 16];
292 rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
293 bytes.iter().map(|b| format!("{b:02x}")).collect()
294}
295
/// Unblocks the callback server when dropped, so its blocking `recv` loop can exit.
struct CallbackServerGuard {
    server: Arc<Server>,
}
299
300impl Drop for CallbackServerGuard {
301 fn drop(&mut self) {
302 self.server.unblock();
303 }
304}
305
306fn parse_callback(url_path: &str) -> Option<OAuthCallbackPayload> {
307 let full = format!("http://localhost{url_path}");
308 let parsed = Url::parse(&full).ok()?;
309
310 if parsed.path() != CONNECT_CALLBACK_PATH {
311 return None;
312 }
313
314 let mut payload = OAuthCallbackPayload {
315 code: None,
316 state: None,
317 error: None,
318 error_description: None,
319 };
320
321 for (key, value) in parsed.query_pairs() {
322 match key.as_ref() {
323 "code" => payload.code = Some(value.to_string()),
324 "state" => payload.state = Some(value.to_string()),
325 "error" => payload.error = Some(value.to_string()),
326 "error_description" => payload.error_description = Some(value.to_string()),
327 _ => {}
328 }
329 }
330
331 Some(payload)
332}
333
334fn spawn_callback_server(
335 server: Arc<Server>,
336 tx: oneshot::Sender<OAuthCallbackPayload>,
337) -> tokio::task::JoinHandle<()> {
338 tokio::task::spawn_blocking(move || {
339 while let Ok(request) = server.recv() {
340 let path = request.url().to_string();
341 if let Some(payload) = parse_callback(&path) {
342 let response = Response::from_string(
343 "Authentication complete. You can return to your terminal.",
344 );
345 let _ = request.respond(response);
346 let _ = tx.send(payload);
347 break;
348 }
349
350 let response = Response::from_string("Invalid callback").with_status_code(400);
351 let _ = request.respond(response);
352 }
353 })
354}
355
/// Client for Kontext-Dev authentication, token exchange, and integration APIs.
pub struct KontextDevClient {
    // Settings used to resolve endpoints, credentials, and cache location.
    config: KontextDevConfig,
    // HTTP client reused for all requests made through this instance.
    http: Client,
}
360
361impl KontextDevClient {
362 pub fn new(config: KontextDevConfig) -> Self {
363 Self {
364 config,
365 http: Client::new(),
366 }
367 }
368
369 pub fn config(&self) -> &KontextDevConfig {
370 &self.config
371 }
372
373 pub fn mcp_url(&self) -> Result<String, KontextDevError> {
374 resolve_mcp_url(&self.config).map_err(KontextDevError::from)
375 }
376
377 pub fn token_url(&self) -> Result<String, KontextDevError> {
378 resolve_token_url(&self.config).map_err(KontextDevError::from)
379 }
380
381 pub fn authorize_url(&self) -> Result<String, KontextDevError> {
382 resolve_authorize_url(&self.config).map_err(KontextDevError::from)
383 }
384
385 pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
386 resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
387 }
388
389 pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
392 let resource = self.config.resource.clone();
393
394 if let Some(cache) = self.read_cache()?
395 && cache.client_id == self.config.client_id
396 && cache.resource == resource
397 && cache.gateway.is_valid()
398 && cache.identity.is_valid()
399 {
400 return Ok(KontextAuthSession {
401 identity_token: cache.identity.to_access_token(),
402 gateway_token: cache.gateway.to_token_exchange_token(),
403 browser_auth_performed: false,
404 });
405 }
406
407 if let Some(cache) = self.read_cache()?
408 && cache.client_id == self.config.client_id
409 && cache.resource == resource
410 {
411 let maybe_refreshed = if cache.identity.is_valid() {
412 Some(cache.identity.to_access_token())
413 } else if let Some(refresh_token) = &cache.identity.refresh_token {
414 Some(self.refresh_identity_token(refresh_token).await?)
415 } else {
416 None
417 };
418
419 if let Some(identity_token) = maybe_refreshed {
420 let gateway_token = self
421 .exchange_for_resource(&identity_token.access_token, &resource, None)
422 .await?;
423 self.write_cache(&identity_token, &gateway_token)?;
424 return Ok(KontextAuthSession {
425 identity_token,
426 gateway_token,
427 browser_auth_performed: false,
428 });
429 }
430 }
431
432 let identity_token = self.authorize_with_browser_pkce().await?;
433 let gateway_token = self
434 .exchange_for_resource(&identity_token.access_token, &resource, None)
435 .await?;
436
437 self.write_cache(&identity_token, &gateway_token)?;
438
439 Ok(KontextAuthSession {
440 identity_token,
441 gateway_token,
442 browser_auth_performed: true,
443 })
444 }
445
446 pub async fn create_integration_connect_url(
449 &self,
450 gateway_access_token: &str,
451 ) -> Result<String, KontextDevError> {
452 let session = self.create_connect_session(gateway_access_token).await?;
453 self.integration_connect_url(&session.session_id)
454 }
455
456 pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
457 if session_id.trim().is_empty() {
458 return Err(KontextDevError::ConnectSession {
459 message: "connect session id is empty".to_string(),
460 });
461 }
462
463 let base = if let Some(explicit) = &self.config.integration_ui_url {
464 explicit.trim_end_matches('/').to_string()
465 } else {
466 let server = resolve_server_base_url(&self.config)?;
467 if server.contains("api.kontext.dev") {
468 "https://app.kontext.dev".to_string()
469 } else {
470 server
471 }
472 };
473
474 let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
475 url: base.clone(),
476 source,
477 })?;
478 url.set_path("/oauth/connect");
479 url.query_pairs_mut().append_pair("session", session_id);
480 Ok(url.to_string())
481 }
482
483 pub async fn open_integration_connect_page(
485 &self,
486 gateway_access_token: &str,
487 ) -> Result<String, KontextDevError> {
488 let url = self
489 .create_integration_connect_url(gateway_access_token)
490 .await?;
491
492 if webbrowser::open(&url).is_err() {
493 return Err(KontextDevError::BrowserOpenFailed);
494 }
495
496 Ok(url)
497 }
498
499 pub async fn create_connect_session(
500 &self,
501 gateway_access_token: &str,
502 ) -> Result<ConnectSession, KontextDevError> {
503 let url = self.connect_session_url()?;
504 let response = self
505 .http
506 .post(&url)
507 .header("Authorization", format!("Bearer {gateway_access_token}"))
508 .send()
509 .await
510 .map_err(|err| KontextDevError::ConnectSession {
511 message: err.to_string(),
512 })?;
513
514 if !response.status().is_success() {
515 let message = build_error_message(response).await;
516 return Err(KontextDevError::ConnectSession { message });
517 }
518
519 response
520 .json::<ConnectSession>()
521 .await
522 .map_err(|err| KontextDevError::ConnectSession {
523 message: err.to_string(),
524 })
525 }
526
527 pub async fn initiate_integration_oauth(
528 &self,
529 gateway_access_token: &str,
530 integration_id: &str,
531 return_to: Option<&str>,
532 ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
533 let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
534
535 let mut request = self
536 .http
537 .post(&url)
538 .header("Authorization", format!("Bearer {gateway_access_token}"));
539
540 if let Some(return_to) = return_to {
541 request = request.json(&serde_json::json!({ "returnTo": return_to }));
542 }
543
544 let response =
545 request
546 .send()
547 .await
548 .map_err(|err| KontextDevError::IntegrationOAuthInit {
549 message: err.to_string(),
550 })?;
551
552 if !response.status().is_success() {
553 let message = build_error_message(response).await;
554 return Err(KontextDevError::IntegrationOAuthInit { message });
555 }
556
557 response
558 .json::<IntegrationOAuthInitResponse>()
559 .await
560 .map_err(|err| KontextDevError::IntegrationOAuthInit {
561 message: err.to_string(),
562 })
563 }
564
565 pub async fn integration_connection_status(
566 &self,
567 gateway_access_token: &str,
568 integration_id: &str,
569 ) -> Result<IntegrationConnectionStatus, KontextDevError> {
570 let url = resolve_integration_connection_url(&self.config, integration_id)?;
571 let response = self
572 .http
573 .get(url)
574 .header("Authorization", format!("Bearer {gateway_access_token}"))
575 .send()
576 .await
577 .map_err(|err| KontextDevError::IntegrationOAuthInit {
578 message: err.to_string(),
579 })?;
580
581 if !response.status().is_success() {
582 let message = build_error_message(response).await;
583 return Err(KontextDevError::IntegrationOAuthInit { message });
584 }
585
586 response
587 .json::<IntegrationConnectionStatus>()
588 .await
589 .map_err(|err| KontextDevError::IntegrationOAuthInit {
590 message: err.to_string(),
591 })
592 }
593
594 pub async fn wait_for_integration_connection(
595 &self,
596 gateway_access_token: &str,
597 integration_id: &str,
598 timeout_ms: u64,
599 interval_ms: u64,
600 ) -> Result<bool, KontextDevError> {
601 let started = now_unix_ms();
602
603 loop {
604 let status = self
605 .integration_connection_status(gateway_access_token, integration_id)
606 .await?;
607 if status.connected {
608 return Ok(true);
609 }
610
611 if now_unix_ms().saturating_sub(started) >= timeout_ms {
612 return Ok(false);
613 }
614
615 tokio::time::sleep(Duration::from_millis(interval_ms)).await;
616 }
617 }
618
619 async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
620 let auth_url = self.resolve_authorization_url().await?;
621 let token_url = self.token_url()?;
622 let pkce = generate_pkce_pair();
623 let state = generate_state();
624
625 let (callback_url, callback_payload) = self.listen_for_callback().await?;
626
627 let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
628 url: auth_url.clone(),
629 source,
630 })?;
631
632 url.query_pairs_mut()
633 .append_pair("client_id", &self.config.client_id)
634 .append_pair("response_type", "code")
635 .append_pair("redirect_uri", &callback_url)
636 .append_pair("state", &state)
637 .append_pair("scope", &self.config.scope)
638 .append_pair("code_challenge_method", "S256")
639 .append_pair("code_challenge", &pkce.challenge);
640
641 if webbrowser::open(url.as_str()).is_err() {
642 return Err(KontextDevError::BrowserOpenFailed);
643 }
644
645 let payload = callback_payload.await?;
646
647 if let Some(error) = payload.error {
648 let with_details = payload
649 .error_description
650 .map(|description| format!("{error}: {description}"))
651 .unwrap_or(error);
652 return Err(KontextDevError::OAuthCallbackError {
653 error: with_details,
654 });
655 }
656
657 if payload.state.as_deref() != Some(state.as_str()) {
658 return Err(KontextDevError::InvalidOAuthState);
659 }
660
661 let code = payload
662 .code
663 .ok_or(KontextDevError::MissingAuthorizationCode)?;
664
665 let mut body = vec![
666 ("grant_type", "authorization_code".to_string()),
667 ("code", code),
668 ("redirect_uri", callback_url),
669 ("client_id", self.config.client_id.clone()),
670 ("code_verifier", pkce.verifier),
671 ];
672
673 if let Some(client_secret) = &self.config.client_secret {
674 body.push(("client_secret", client_secret.clone()));
675 }
676
677 post_token(&self.http, &token_url, None, &body).await
678 }
679
680 async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
681 if let Some(discovered) = self.discover_authorization_endpoint().await {
682 return Ok(discovered);
683 }
684
685 let authorize_url = self.authorize_url()?;
686 if !self.endpoint_is_missing(&authorize_url).await {
687 return Ok(authorize_url);
688 }
689
690 let server_base = resolve_server_base_url(&self.config)?;
691 Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
692 }
693
694 async fn discover_authorization_endpoint(&self) -> Option<String> {
695 let base = resolve_server_base_url(&self.config).ok()?;
696 let base = base.trim_end_matches('/');
697
698 let candidates = [
699 format!("{base}/.well-known/oauth-authorization-server/mcp"),
700 format!("{base}/.well-known/oauth-authorization-server"),
701 ];
702
703 for url in candidates {
704 let response = match self.http.get(&url).send().await {
705 Ok(response) => response,
706 Err(_) => continue,
707 };
708
709 if !response.status().is_success() {
710 continue;
711 }
712
713 let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
714 Ok(metadata) => metadata,
715 Err(_) => continue,
716 };
717
718 let Some(endpoint) = metadata.authorization_endpoint else {
719 continue;
720 };
721
722 if Url::parse(&endpoint).is_ok() {
723 return Some(endpoint);
724 }
725 }
726
727 None
728 }
729
730 async fn endpoint_is_missing(&self, url: &str) -> bool {
731 let probe_client = match reqwest::Client::builder()
732 .redirect(reqwest::redirect::Policy::none())
733 .build()
734 {
735 Ok(client) => client,
736 Err(_) => return false,
737 };
738
739 match probe_client.get(url).send().await {
740 Ok(response) => response.status() == StatusCode::NOT_FOUND,
741 Err(_) => false,
742 }
743 }
744
745 async fn refresh_identity_token(
746 &self,
747 refresh_token: &str,
748 ) -> Result<AccessToken, KontextDevError> {
749 let token_url = self.token_url()?;
750 let mut body = vec![
751 ("grant_type", "refresh_token".to_string()),
752 ("refresh_token", refresh_token.to_string()),
753 ("client_id", self.config.client_id.clone()),
754 ];
755
756 if let Some(client_secret) = &self.config.client_secret {
757 body.push(("client_secret", client_secret.clone()));
758 }
759
760 post_token(&self.http, &token_url, None, &body).await
761 }
762
763 pub async fn exchange_for_resource(
764 &self,
765 subject_token: &str,
766 resource: &str,
767 scope: Option<&str>,
768 ) -> Result<TokenExchangeToken, KontextDevError> {
769 let token_url = self.token_url()?;
770
771 let mut body = vec![
772 ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
773 ("subject_token", subject_token.to_string()),
774 ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
775 ("resource", resource.to_string()),
776 ];
777
778 if let Some(scope) = scope {
779 body.push(("scope", scope.to_string()));
780 }
781
782 let auth_header = self.config.client_secret.as_ref().map(|secret| {
783 let raw = format!("{}:{}", self.config.client_id, secret);
784 format!(
785 "Basic {}",
786 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
787 )
788 });
789
790 if self.config.client_secret.is_none() {
791 body.push(("client_id", self.config.client_id.clone()));
792 }
793
794 let response = post_form_with_optional_auth::<TokenExchangeToken>(
795 &self.http,
796 &token_url,
797 auth_header,
798 &body,
799 )
800 .await
801 .map_err(|message| KontextDevError::TokenExchange {
802 resource: resource.to_string(),
803 message,
804 })?;
805
806 if response.access_token.is_empty() {
807 return Err(KontextDevError::EmptyAccessToken);
808 }
809
810 Ok(response)
811 }
812
813 async fn listen_for_callback(
814 &self,
815 ) -> Result<
816 (
817 String,
818 impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
819 ),
820 KontextDevError,
821 > {
822 let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| {
823 KontextDevError::OAuthCallbackError {
824 error: format!("failed to start callback server: {err}"),
825 }
826 })?);
827
828 let callback_url = match server.server_addr() {
829 tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
830 format!(
831 "http://{}:{}{}",
832 addr.ip(),
833 addr.port(),
834 CONNECT_CALLBACK_PATH
835 )
836 }
837 tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
838 format!(
839 "http://[{}]:{}{}",
840 addr.ip(),
841 addr.port(),
842 CONNECT_CALLBACK_PATH
843 )
844 }
845 #[cfg(not(target_os = "windows"))]
846 _ => {
847 return Err(KontextDevError::OAuthCallbackError {
848 error: "unable to determine callback address".to_string(),
849 });
850 }
851 };
852
853 let (tx, rx) = oneshot::channel();
854 let _join = spawn_callback_server(server.clone(), tx);
855 let _guard = CallbackServerGuard { server };
856 let timeout_seconds = self.config.auth_timeout_seconds.max(1);
857
858 let fut = async move {
859 let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
860 .await
861 .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
862 .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
863 drop(_guard);
864 Ok(payload)
865 };
866
867 Ok((callback_url, fut))
868 }
869
870 fn token_cache_path(&self) -> Option<PathBuf> {
871 if let Some(explicit) = &self.config.token_cache_path {
872 return Some(PathBuf::from(explicit));
873 }
874
875 let home = dirs::home_dir()?;
876 let mut path = home;
877 path.push(".kontext-dev");
878 path.push("tokens");
879
880 let sanitized_client_id: String = self
881 .config
882 .client_id
883 .chars()
884 .map(|ch| {
885 if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
886 ch
887 } else {
888 '_'
889 }
890 })
891 .collect();
892
893 path.push(format!("{sanitized_client_id}.json"));
894 Some(path)
895 }
896
897 fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
898 let Some(path) = self.token_cache_path() else {
899 return Ok(None);
900 };
901
902 let raw = match fs::read_to_string(&path) {
903 Ok(raw) => raw,
904 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
905 Err(source) => {
906 return Err(KontextDevError::TokenCacheRead {
907 path: path.display().to_string(),
908 source,
909 });
910 }
911 };
912
913 serde_json::from_str(&raw).map(Some).map_err(|source| {
914 KontextDevError::TokenCacheDeserialize {
915 path: path.display().to_string(),
916 source,
917 }
918 })
919 }
920
921 fn write_cache(
922 &self,
923 identity: &AccessToken,
924 gateway: &TokenExchangeToken,
925 ) -> Result<(), KontextDevError> {
926 let Some(path) = self.token_cache_path() else {
927 return Ok(());
928 };
929
930 if let Some(parent) = path.parent() {
931 fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
932 path: parent.display().to_string(),
933 source,
934 })?;
935 }
936
937 let payload = TokenCacheFile {
938 client_id: self.config.client_id.clone(),
939 resource: self.config.resource.clone(),
940 identity: CachedAccessToken::from_access_token(identity)?,
941 gateway: CachedAccessToken::from_token_exchange(gateway)?,
942 };
943
944 let serialized = serde_json::to_string_pretty(&payload)
945 .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
946
947 fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
948 path: path.display().to_string(),
949 source,
950 })
951 }
952}
953
954pub async fn request_access_token(
958 config: &KontextDevConfig,
959) -> Result<AccessToken, KontextDevError> {
960 let token_url = resolve_token_url(config)?;
961
962 let mut body = vec![
963 ("grant_type", "client_credentials".to_string()),
964 ("scope", config.scope.clone()),
965 ];
966
967 let auth_header = if let Some(secret) = &config.client_secret {
968 let raw = format!("{}:{}", config.client_id, secret);
969 Some(format!(
970 "Basic {}",
971 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
972 ))
973 } else {
974 body.push(("client_id", config.client_id.clone()));
975 None
976 };
977
978 post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
979}
980
981async fn post_token(
982 http: &Client,
983 token_url: &str,
984 authorization: Option<&str>,
985 body: &[(impl AsRef<str>, String)],
986) -> Result<AccessToken, KontextDevError> {
987 let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
988
989 let response = post_form_with_optional_auth::<AccessToken>(
990 http,
991 token_url,
992 authorization.map(ToString::to_string),
993 &body_vec,
994 )
995 .await
996 .map_err(|message| KontextDevError::TokenRequest {
997 token_url: token_url.to_string(),
998 message,
999 })?;
1000
1001 Ok(response)
1002}
1003
1004async fn post_form_with_optional_auth<T>(
1005 http: &Client,
1006 url: &str,
1007 authorization: Option<String>,
1008 body: &[(&str, String)],
1009) -> Result<T, String>
1010where
1011 T: DeserializeOwned,
1012{
1013 let mut request = http
1014 .post(url)
1015 .header("Content-Type", "application/x-www-form-urlencoded");
1016
1017 if let Some(header) = authorization {
1018 request = request.header("Authorization", header);
1019 }
1020
1021 let form = body
1022 .iter()
1023 .map(|(k, v)| (k.to_string(), v.to_string()))
1024 .collect::<Vec<(String, String)>>();
1025
1026 let response = request
1027 .form(&form)
1028 .send()
1029 .await
1030 .map_err(|err| err.to_string())?;
1031
1032 if !response.status().is_success() {
1033 return Err(build_error_message(response).await);
1034 }
1035
1036 response.json::<T>().await.map_err(|err| err.to_string())
1037}
1038
1039async fn build_error_message(response: reqwest::Response) -> String {
1040 let status = response.status();
1041 let fallback = format!(
1042 "{} {}",
1043 status.as_u16(),
1044 status.canonical_reason().unwrap_or("")
1045 );
1046
1047 let body = response.text().await.unwrap_or_default();
1048 if body.is_empty() {
1049 return fallback.trim().to_string();
1050 }
1051
1052 if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1053 if let Some(description) = parsed.error_description {
1054 return description;
1055 }
1056 if let Some(message) = parsed.message {
1057 return message;
1058 }
1059 if let Some(error) = parsed.error {
1060 return error;
1061 }
1062 }
1063
1064 format!("{fallback}: {body}")
1065}
1066
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal configuration pointing at a local server with a hosted integration UI.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: Some("http://localhost:4000".to_string()),
            mcp_url: None,
            token_url: None,
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
        }
    }

    /// Cache entry with the given absolute expiry, for `is_valid` checks.
    fn cached_token(expires_at_unix_ms: Option<u64>) -> CachedAccessToken {
        CachedAccessToken {
            access_token: "token".to_string(),
            token_type: "Bearer".to_string(),
            refresh_token: None,
            scope: None,
            expires_at_unix_ms,
        }
    }

    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let client = KontextDevClient::new(config());
        let url = client
            .integration_connect_url("session-123")
            .expect("url should be built");
        assert_eq!(
            url,
            "https://app.kontext.dev/oauth/connect?session=session-123"
        );
    }

    #[test]
    fn integration_connect_url_rejects_blank_session() {
        let client = KontextDevClient::new(config());
        assert!(matches!(
            client.integration_connect_url("   "),
            Err(KontextDevError::ConnectSession { .. })
        ));
    }

    #[test]
    fn jwt_exp_decode_reads_exp() {
        let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#);
        let payload =
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(r#"{"exp":4070908800}"#);
        let token = format!("{header}.{payload}.sig");
        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }

    #[test]
    fn jwt_exp_decode_ignores_opaque_tokens() {
        assert_eq!(decode_jwt_exp("opaque-token"), None);
    }

    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        let metadata = serde_json::from_str::<OAuthAuthorizationServerMetadata>(
            r#"{
                "issuer": "https://issuer.example.com",
                "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
            }"#,
        )
        .expect("metadata should parse");

        assert_eq!(
            metadata.authorization_endpoint.as_deref(),
            Some("https://issuer.example.com/oauth2/auth")
        );
    }

    #[test]
    fn parse_callback_extracts_query_parameters() {
        let payload =
            parse_callback("/callback?code=abc&state=xyz").expect("callback path should match");
        assert_eq!(payload.code.as_deref(), Some("abc"));
        assert_eq!(payload.state.as_deref(), Some("xyz"));
        assert!(payload.error.is_none());
        assert!(payload.error_description.is_none());
    }

    #[test]
    fn parse_callback_rejects_other_paths() {
        assert!(parse_callback("/favicon.ico").is_none());
    }

    #[test]
    fn relative_seconds_is_zero_for_past_timestamps() {
        assert_eq!(unix_ms_to_relative_seconds(0), Some(0));
    }

    #[test]
    fn cached_token_validity_honours_expiry_buffer() {
        assert!(cached_token(None).is_valid());

        let now = now_unix_ms();
        // Expires within the 60-second safety buffer: treated as expired.
        assert!(!cached_token(Some(now + 10_000)).is_valid());
        assert!(cached_token(Some(now + 3_600_000)).is_valid());
    }

    #[test]
    fn positive_expires_in_sets_absolute_expiry() {
        let before = now_unix_ms();
        let expires_at =
            compute_expires_at_unix_ms(Some(120), "opaque").expect("expiry should be computed");
        assert!(expires_at >= before + 120_000);
        assert_eq!(compute_expires_at_unix_ms(None, "opaque"), None);
    }
}