1use base64::Engine;
2use reqwest::Client;
3use serde::Deserialize;
4use serde::Serialize;
5use serde::de::DeserializeOwned;
6use sha2::Digest;
7use std::fs;
8use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11use std::time::SystemTime;
12use std::time::UNIX_EPOCH;
13use tiny_http::Response;
14use tiny_http::Server;
15use tokio::sync::oneshot;
16use tokio::time::timeout;
17use url::Url;
18
19pub use kontext_dev_core::AccessToken;
20pub use kontext_dev_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
21pub use kontext_dev_core::DEFAULT_RESOURCE;
22pub use kontext_dev_core::DEFAULT_SCOPE;
23pub use kontext_dev_core::DEFAULT_SERVER_NAME;
24pub use kontext_dev_core::KontextDevConfig;
25pub use kontext_dev_core::KontextDevCoreError;
26pub use kontext_dev_core::TokenExchangeToken;
27pub use kontext_dev_core::build_mcp_url;
28pub use kontext_dev_core::normalize_server_url;
29pub use kontext_dev_core::resolve_authorize_url;
30pub use kontext_dev_core::resolve_connect_session_url;
31pub use kontext_dev_core::resolve_integration_connection_url;
32pub use kontext_dev_core::resolve_integration_oauth_init_url;
33pub use kontext_dev_core::resolve_mcp_url;
34pub use kontext_dev_core::resolve_server_base_url;
35pub use kontext_dev_core::resolve_token_url;
36
/// Value sent as `grant_type` when exchanging a token for a resource.
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";

/// Token type URI sent as `subject_token_type` and used as the
/// `issued_token_type` of gateway tokens restored from the cache.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";

/// Request path the local callback server accepts OAuth redirects on.
const CONNECT_CALLBACK_PATH: &str = "/callback";

/// Margin, in seconds, by which a cached token must outlive "now" to still
/// count as valid.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
41
42#[derive(Debug, thiserror::Error)]
43pub enum KontextDevError {
44 #[error(transparent)]
45 Core(#[from] kontext_dev_core::KontextDevCoreError),
46 #[error("failed to parse URL `{url}`")]
47 InvalidUrl {
48 url: String,
49 source: url::ParseError,
50 },
51 #[error("failed to open browser for OAuth authorization")]
52 BrowserOpenFailed,
53 #[error("OAuth callback timed out after {timeout_seconds}s")]
54 OAuthCallbackTimeout { timeout_seconds: i64 },
55 #[error("OAuth callback channel was unexpectedly cancelled")]
56 OAuthCallbackCancelled,
57 #[error("OAuth callback is missing the authorization code")]
58 MissingAuthorizationCode,
59 #[error("OAuth callback returned an error: {error}")]
60 OAuthCallbackError { error: String },
61 #[error("OAuth state mismatch")]
62 InvalidOAuthState,
63 #[error("Kontext-Dev token request failed for {token_url}: {message}")]
64 TokenRequest { token_url: String, message: String },
65 #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
66 TokenExchange { resource: String, message: String },
67 #[error("Kontext-Dev connect session request failed: {message}")]
68 ConnectSession { message: String },
69 #[error("Kontext-Dev integration OAuth init failed: {message}")]
70 IntegrationOAuthInit { message: String },
71 #[error("failed to read token cache at `{path}`: {source}")]
72 TokenCacheRead {
73 path: String,
74 source: std::io::Error,
75 },
76 #[error("failed to write token cache at `{path}`: {source}")]
77 TokenCacheWrite {
78 path: String,
79 source: std::io::Error,
80 },
81 #[error("failed to deserialize token cache at `{path}`: {source}")]
82 TokenCacheDeserialize {
83 path: String,
84 source: serde_json::Error,
85 },
86 #[error("failed to serialize token cache: {source}")]
87 TokenCacheSerialize { source: serde_json::Error },
88 #[error("Kontext-Dev access token is empty")]
89 EmptyAccessToken,
90 #[error("missing integration UI URL; set `integration_ui_url` in config")]
91 MissingIntegrationUiUrl,
92}
93
94#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
95pub struct ConnectSession {
96 #[serde(rename = "sessionId")]
97 pub session_id: String,
98 #[serde(rename = "expiresAt")]
99 pub expires_at: String,
100}
101
102#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
103pub struct IntegrationOAuthInitResponse {
104 #[serde(rename = "authorizationUrl")]
105 pub authorization_url: String,
106 #[serde(default)]
107 pub state: Option<String>,
108}
109
110#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
111pub struct IntegrationConnectionStatus {
112 pub connected: bool,
113 #[serde(default)]
114 pub expires_at: Option<String>,
115}
116
117#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
118pub struct KontextAuthSession {
119 pub identity_token: AccessToken,
120 pub gateway_token: TokenExchangeToken,
121 pub browser_auth_performed: bool,
122}
123
124#[derive(Clone, Debug, Deserialize, Serialize)]
125struct CachedAccessToken {
126 access_token: String,
127 token_type: String,
128 refresh_token: Option<String>,
129 scope: Option<String>,
130 expires_at_unix_ms: Option<u64>,
131}
132
133#[derive(Clone, Debug, Deserialize, Serialize)]
134struct TokenCacheFile {
135 client_id: String,
136 resource: String,
137 identity: CachedAccessToken,
138 gateway: CachedAccessToken,
139}
140
141impl CachedAccessToken {
142 fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
143 if token.access_token.is_empty() {
144 return Err(KontextDevError::EmptyAccessToken);
145 }
146
147 Ok(Self {
148 access_token: token.access_token.clone(),
149 token_type: token.token_type.clone(),
150 refresh_token: token.refresh_token.clone(),
151 scope: token.scope.clone(),
152 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
153 })
154 }
155
156 fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
157 if token.access_token.is_empty() {
158 return Err(KontextDevError::EmptyAccessToken);
159 }
160
161 Ok(Self {
162 access_token: token.access_token.clone(),
163 token_type: token.token_type.clone(),
164 refresh_token: token.refresh_token.clone(),
165 scope: token.scope.clone(),
166 expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
167 })
168 }
169
170 fn is_valid(&self) -> bool {
171 match self.expires_at_unix_ms {
172 Some(expires_at) => {
173 let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
174 now_unix_ms().saturating_add(buffer_ms) < expires_at
175 }
176 None => true,
177 }
178 }
179
180 fn to_access_token(&self) -> AccessToken {
181 AccessToken {
182 access_token: self.access_token.clone(),
183 token_type: self.token_type.clone(),
184 expires_in: self
185 .expires_at_unix_ms
186 .and_then(unix_ms_to_relative_seconds),
187 refresh_token: self.refresh_token.clone(),
188 scope: self.scope.clone(),
189 }
190 }
191
192 fn to_token_exchange_token(&self) -> TokenExchangeToken {
193 TokenExchangeToken {
194 access_token: self.access_token.clone(),
195 issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
196 token_type: self.token_type.clone(),
197 expires_in: self
198 .expires_at_unix_ms
199 .and_then(unix_ms_to_relative_seconds),
200 scope: self.scope.clone(),
201 refresh_token: self.refresh_token.clone(),
202 }
203 }
204}
205
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before the epoch is reported as `0`.
fn now_unix_ms() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::from_secs(0),
    };
    since_epoch.as_millis() as u64
}
212
213fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
214 if unix_ms <= now_unix_ms() {
215 return Some(0);
216 }
217
218 let delta_ms = unix_ms - now_unix_ms();
219 let secs = delta_ms / 1000;
220 i64::try_from(secs).ok()
221}
222
223fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
224 if let Some(expires_in) = expires_in
225 && expires_in > 0
226 {
227 let expires_in_u64 = u64::try_from(expires_in).ok()?;
228 return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
229 }
230
231 decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
232}
233
234fn decode_jwt_exp(token: &str) -> Option<u64> {
235 let mut parts = token.split('.');
236 let _header = parts.next()?;
237 let payload = parts.next()?;
238
239 let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
240 .decode(payload)
241 .ok()?;
242 let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
243 value.get("exp").and_then(|exp| {
244 exp.as_u64()
245 .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
246 })
247}
248
249#[derive(Clone, Debug, Deserialize)]
250struct OAuthErrorBody {
251 error: Option<String>,
252 error_description: Option<String>,
253 message: Option<String>,
254}
255
256#[derive(Clone, Debug, Deserialize)]
257struct OAuthCallbackPayload {
258 code: Option<String>,
259 state: Option<String>,
260 error: Option<String>,
261 error_description: Option<String>,
262}
263
/// A PKCE verifier together with its derived challenge.
#[derive(Debug, Clone)]
struct PkcePair {
    /// Random secret sent as `code_verifier` when redeeming the code.
    verifier: String,
    /// Unpadded URL-safe base64 of the SHA-256 digest of `verifier`, sent as
    /// `code_challenge` in the authorization request.
    challenge: String,
}
269
270fn generate_pkce_pair() -> PkcePair {
271 let mut raw = [0u8; 64];
272 rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
273
274 let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
275 let digest = sha2::Sha256::digest(verifier.as_bytes());
276 let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
277
278 PkcePair {
279 verifier,
280 challenge,
281 }
282}
283
284fn generate_state() -> String {
285 let mut bytes = [0u8; 16];
286 rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
287 bytes.iter().map(|b| format!("{b:02x}")).collect()
288}
289
290struct CallbackServerGuard {
291 server: Arc<Server>,
292}
293
294impl Drop for CallbackServerGuard {
295 fn drop(&mut self) {
296 self.server.unblock();
297 }
298}
299
300fn parse_callback(url_path: &str) -> Option<OAuthCallbackPayload> {
301 let full = format!("http://localhost{url_path}");
302 let parsed = Url::parse(&full).ok()?;
303
304 if parsed.path() != CONNECT_CALLBACK_PATH {
305 return None;
306 }
307
308 let mut payload = OAuthCallbackPayload {
309 code: None,
310 state: None,
311 error: None,
312 error_description: None,
313 };
314
315 for (key, value) in parsed.query_pairs() {
316 match key.as_ref() {
317 "code" => payload.code = Some(value.to_string()),
318 "state" => payload.state = Some(value.to_string()),
319 "error" => payload.error = Some(value.to_string()),
320 "error_description" => payload.error_description = Some(value.to_string()),
321 _ => {}
322 }
323 }
324
325 Some(payload)
326}
327
328fn spawn_callback_server(
329 server: Arc<Server>,
330 tx: oneshot::Sender<OAuthCallbackPayload>,
331) -> tokio::task::JoinHandle<()> {
332 tokio::task::spawn_blocking(move || {
333 while let Ok(request) = server.recv() {
334 let path = request.url().to_string();
335 if let Some(payload) = parse_callback(&path) {
336 let response = Response::from_string(
337 "Authentication complete. You can return to your terminal.",
338 );
339 let _ = request.respond(response);
340 let _ = tx.send(payload);
341 break;
342 }
343
344 let response = Response::from_string("Invalid callback").with_status_code(400);
345 let _ = request.respond(response);
346 }
347 })
348}
349
350pub struct KontextDevClient {
351 config: KontextDevConfig,
352 http: Client,
353}
354
355impl KontextDevClient {
356 pub fn new(config: KontextDevConfig) -> Self {
357 Self {
358 config,
359 http: Client::new(),
360 }
361 }
362
363 pub fn config(&self) -> &KontextDevConfig {
364 &self.config
365 }
366
367 pub fn mcp_url(&self) -> Result<String, KontextDevError> {
368 resolve_mcp_url(&self.config).map_err(KontextDevError::from)
369 }
370
371 pub fn token_url(&self) -> Result<String, KontextDevError> {
372 resolve_token_url(&self.config).map_err(KontextDevError::from)
373 }
374
375 pub fn authorize_url(&self) -> Result<String, KontextDevError> {
376 resolve_authorize_url(&self.config).map_err(KontextDevError::from)
377 }
378
379 pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
380 resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
381 }
382
383 pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
386 let resource = self.config.resource.clone();
387
388 if let Some(cache) = self.read_cache()?
389 && cache.client_id == self.config.client_id
390 && cache.resource == resource
391 && cache.gateway.is_valid()
392 && cache.identity.is_valid()
393 {
394 return Ok(KontextAuthSession {
395 identity_token: cache.identity.to_access_token(),
396 gateway_token: cache.gateway.to_token_exchange_token(),
397 browser_auth_performed: false,
398 });
399 }
400
401 if let Some(cache) = self.read_cache()?
402 && cache.client_id == self.config.client_id
403 && cache.resource == resource
404 {
405 let maybe_refreshed = if cache.identity.is_valid() {
406 Some(cache.identity.to_access_token())
407 } else if let Some(refresh_token) = &cache.identity.refresh_token {
408 Some(self.refresh_identity_token(refresh_token).await?)
409 } else {
410 None
411 };
412
413 if let Some(identity_token) = maybe_refreshed {
414 let gateway_token = self
415 .exchange_for_resource(&identity_token.access_token, &resource, None)
416 .await?;
417 self.write_cache(&identity_token, &gateway_token)?;
418 return Ok(KontextAuthSession {
419 identity_token,
420 gateway_token,
421 browser_auth_performed: false,
422 });
423 }
424 }
425
426 let identity_token = self.authorize_with_browser_pkce().await?;
427 let gateway_token = self
428 .exchange_for_resource(&identity_token.access_token, &resource, None)
429 .await?;
430
431 self.write_cache(&identity_token, &gateway_token)?;
432
433 Ok(KontextAuthSession {
434 identity_token,
435 gateway_token,
436 browser_auth_performed: true,
437 })
438 }
439
440 pub async fn create_integration_connect_url(
443 &self,
444 gateway_access_token: &str,
445 ) -> Result<String, KontextDevError> {
446 let session = self.create_connect_session(gateway_access_token).await?;
447 self.integration_connect_url(&session.session_id)
448 }
449
450 pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
451 if session_id.trim().is_empty() {
452 return Err(KontextDevError::ConnectSession {
453 message: "connect session id is empty".to_string(),
454 });
455 }
456
457 let base = if let Some(explicit) = &self.config.integration_ui_url {
458 explicit.trim_end_matches('/').to_string()
459 } else {
460 let server = resolve_server_base_url(&self.config)?;
461 if server.contains("api.kontext.dev") {
462 "https://app.kontext.dev".to_string()
463 } else {
464 server
465 }
466 };
467
468 let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
469 url: base.clone(),
470 source,
471 })?;
472 url.set_path("/oauth/connect");
473 url.query_pairs_mut().append_pair("session", session_id);
474 Ok(url.to_string())
475 }
476
477 pub async fn open_integration_connect_page(
479 &self,
480 gateway_access_token: &str,
481 ) -> Result<String, KontextDevError> {
482 let url = self
483 .create_integration_connect_url(gateway_access_token)
484 .await?;
485
486 if webbrowser::open(&url).is_err() {
487 return Err(KontextDevError::BrowserOpenFailed);
488 }
489
490 Ok(url)
491 }
492
493 pub async fn create_connect_session(
494 &self,
495 gateway_access_token: &str,
496 ) -> Result<ConnectSession, KontextDevError> {
497 let url = self.connect_session_url()?;
498 let response = self
499 .http
500 .post(&url)
501 .header("Authorization", format!("Bearer {gateway_access_token}"))
502 .send()
503 .await
504 .map_err(|err| KontextDevError::ConnectSession {
505 message: err.to_string(),
506 })?;
507
508 if !response.status().is_success() {
509 let message = build_error_message(response).await;
510 return Err(KontextDevError::ConnectSession { message });
511 }
512
513 response
514 .json::<ConnectSession>()
515 .await
516 .map_err(|err| KontextDevError::ConnectSession {
517 message: err.to_string(),
518 })
519 }
520
521 pub async fn initiate_integration_oauth(
522 &self,
523 gateway_access_token: &str,
524 integration_id: &str,
525 return_to: Option<&str>,
526 ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
527 let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
528
529 let mut request = self
530 .http
531 .post(&url)
532 .header("Authorization", format!("Bearer {gateway_access_token}"));
533
534 if let Some(return_to) = return_to {
535 request = request.json(&serde_json::json!({ "returnTo": return_to }));
536 }
537
538 let response =
539 request
540 .send()
541 .await
542 .map_err(|err| KontextDevError::IntegrationOAuthInit {
543 message: err.to_string(),
544 })?;
545
546 if !response.status().is_success() {
547 let message = build_error_message(response).await;
548 return Err(KontextDevError::IntegrationOAuthInit { message });
549 }
550
551 response
552 .json::<IntegrationOAuthInitResponse>()
553 .await
554 .map_err(|err| KontextDevError::IntegrationOAuthInit {
555 message: err.to_string(),
556 })
557 }
558
559 pub async fn integration_connection_status(
560 &self,
561 gateway_access_token: &str,
562 integration_id: &str,
563 ) -> Result<IntegrationConnectionStatus, KontextDevError> {
564 let url = resolve_integration_connection_url(&self.config, integration_id)?;
565 let response = self
566 .http
567 .get(url)
568 .header("Authorization", format!("Bearer {gateway_access_token}"))
569 .send()
570 .await
571 .map_err(|err| KontextDevError::IntegrationOAuthInit {
572 message: err.to_string(),
573 })?;
574
575 if !response.status().is_success() {
576 let message = build_error_message(response).await;
577 return Err(KontextDevError::IntegrationOAuthInit { message });
578 }
579
580 response
581 .json::<IntegrationConnectionStatus>()
582 .await
583 .map_err(|err| KontextDevError::IntegrationOAuthInit {
584 message: err.to_string(),
585 })
586 }
587
588 pub async fn wait_for_integration_connection(
589 &self,
590 gateway_access_token: &str,
591 integration_id: &str,
592 timeout_ms: u64,
593 interval_ms: u64,
594 ) -> Result<bool, KontextDevError> {
595 let started = now_unix_ms();
596
597 loop {
598 let status = self
599 .integration_connection_status(gateway_access_token, integration_id)
600 .await?;
601 if status.connected {
602 return Ok(true);
603 }
604
605 if now_unix_ms().saturating_sub(started) >= timeout_ms {
606 return Ok(false);
607 }
608
609 tokio::time::sleep(Duration::from_millis(interval_ms)).await;
610 }
611 }
612
613 async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
614 let auth_url = self.authorize_url()?;
615 let token_url = self.token_url()?;
616 let pkce = generate_pkce_pair();
617 let state = generate_state();
618
619 let (callback_url, callback_payload) = self.listen_for_callback().await?;
620
621 let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
622 url: auth_url.clone(),
623 source,
624 })?;
625
626 url.query_pairs_mut()
627 .append_pair("client_id", &self.config.client_id)
628 .append_pair("response_type", "code")
629 .append_pair("redirect_uri", &callback_url)
630 .append_pair("state", &state)
631 .append_pair("scope", &self.config.scope)
632 .append_pair("code_challenge_method", "S256")
633 .append_pair("code_challenge", &pkce.challenge);
634
635 println!(
636 "Authorize `{}` by opening this URL in your browser:\n{}\n",
637 self.config.server_name, url
638 );
639
640 if webbrowser::open(url.as_str()).is_err() {
641 return Err(KontextDevError::BrowserOpenFailed);
642 }
643
644 let payload = callback_payload.await?;
645
646 if let Some(error) = payload.error {
647 let with_details = payload
648 .error_description
649 .map(|description| format!("{error}: {description}"))
650 .unwrap_or(error);
651 return Err(KontextDevError::OAuthCallbackError {
652 error: with_details,
653 });
654 }
655
656 if payload.state.as_deref() != Some(state.as_str()) {
657 return Err(KontextDevError::InvalidOAuthState);
658 }
659
660 let code = payload
661 .code
662 .ok_or(KontextDevError::MissingAuthorizationCode)?;
663
664 let mut body = vec![
665 ("grant_type", "authorization_code".to_string()),
666 ("code", code),
667 ("redirect_uri", callback_url),
668 ("client_id", self.config.client_id.clone()),
669 ("code_verifier", pkce.verifier),
670 ];
671
672 if let Some(client_secret) = &self.config.client_secret {
673 body.push(("client_secret", client_secret.clone()));
674 }
675
676 post_token(&self.http, &token_url, None, &body).await
677 }
678
679 async fn refresh_identity_token(
680 &self,
681 refresh_token: &str,
682 ) -> Result<AccessToken, KontextDevError> {
683 let token_url = self.token_url()?;
684 let mut body = vec![
685 ("grant_type", "refresh_token".to_string()),
686 ("refresh_token", refresh_token.to_string()),
687 ("client_id", self.config.client_id.clone()),
688 ];
689
690 if let Some(client_secret) = &self.config.client_secret {
691 body.push(("client_secret", client_secret.clone()));
692 }
693
694 post_token(&self.http, &token_url, None, &body).await
695 }
696
697 pub async fn exchange_for_resource(
698 &self,
699 subject_token: &str,
700 resource: &str,
701 scope: Option<&str>,
702 ) -> Result<TokenExchangeToken, KontextDevError> {
703 let token_url = self.token_url()?;
704
705 let mut body = vec![
706 ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
707 ("subject_token", subject_token.to_string()),
708 ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
709 ("resource", resource.to_string()),
710 ];
711
712 if let Some(scope) = scope {
713 body.push(("scope", scope.to_string()));
714 }
715
716 let auth_header = self.config.client_secret.as_ref().map(|secret| {
717 let raw = format!("{}:{}", self.config.client_id, secret);
718 format!(
719 "Basic {}",
720 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
721 )
722 });
723
724 if self.config.client_secret.is_none() {
725 body.push(("client_id", self.config.client_id.clone()));
726 }
727
728 let response = post_form_with_optional_auth::<TokenExchangeToken>(
729 &self.http,
730 &token_url,
731 auth_header,
732 &body,
733 )
734 .await
735 .map_err(|message| KontextDevError::TokenExchange {
736 resource: resource.to_string(),
737 message,
738 })?;
739
740 if response.access_token.is_empty() {
741 return Err(KontextDevError::EmptyAccessToken);
742 }
743
744 Ok(response)
745 }
746
747 async fn listen_for_callback(
748 &self,
749 ) -> Result<
750 (
751 String,
752 impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
753 ),
754 KontextDevError,
755 > {
756 let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| {
757 KontextDevError::OAuthCallbackError {
758 error: format!("failed to start callback server: {err}"),
759 }
760 })?);
761
762 let callback_url = match server.server_addr() {
763 tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
764 format!(
765 "http://{}:{}{}",
766 addr.ip(),
767 addr.port(),
768 CONNECT_CALLBACK_PATH
769 )
770 }
771 tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
772 format!(
773 "http://[{}]:{}{}",
774 addr.ip(),
775 addr.port(),
776 CONNECT_CALLBACK_PATH
777 )
778 }
779 #[cfg(not(target_os = "windows"))]
780 _ => {
781 return Err(KontextDevError::OAuthCallbackError {
782 error: "unable to determine callback address".to_string(),
783 });
784 }
785 };
786
787 let (tx, rx) = oneshot::channel();
788 let _join = spawn_callback_server(server.clone(), tx);
789 let _guard = CallbackServerGuard { server };
790 let timeout_seconds = self.config.auth_timeout_seconds.max(1);
791
792 let fut = async move {
793 let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
794 .await
795 .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
796 .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
797 drop(_guard);
798 Ok(payload)
799 };
800
801 Ok((callback_url, fut))
802 }
803
804 fn token_cache_path(&self) -> Option<PathBuf> {
805 if let Some(explicit) = &self.config.token_cache_path {
806 return Some(PathBuf::from(explicit));
807 }
808
809 let home = dirs::home_dir()?;
810 let mut path = home;
811 path.push(".kontext-dev");
812 path.push("tokens");
813
814 let sanitized_client_id: String = self
815 .config
816 .client_id
817 .chars()
818 .map(|ch| {
819 if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
820 ch
821 } else {
822 '_'
823 }
824 })
825 .collect();
826
827 path.push(format!("{sanitized_client_id}.json"));
828 Some(path)
829 }
830
831 fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
832 let Some(path) = self.token_cache_path() else {
833 return Ok(None);
834 };
835
836 let raw = match fs::read_to_string(&path) {
837 Ok(raw) => raw,
838 Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
839 Err(source) => {
840 return Err(KontextDevError::TokenCacheRead {
841 path: path.display().to_string(),
842 source,
843 });
844 }
845 };
846
847 serde_json::from_str(&raw).map(Some).map_err(|source| {
848 KontextDevError::TokenCacheDeserialize {
849 path: path.display().to_string(),
850 source,
851 }
852 })
853 }
854
855 fn write_cache(
856 &self,
857 identity: &AccessToken,
858 gateway: &TokenExchangeToken,
859 ) -> Result<(), KontextDevError> {
860 let Some(path) = self.token_cache_path() else {
861 return Ok(());
862 };
863
864 if let Some(parent) = path.parent() {
865 fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
866 path: parent.display().to_string(),
867 source,
868 })?;
869 }
870
871 let payload = TokenCacheFile {
872 client_id: self.config.client_id.clone(),
873 resource: self.config.resource.clone(),
874 identity: CachedAccessToken::from_access_token(identity)?,
875 gateway: CachedAccessToken::from_token_exchange(gateway)?,
876 };
877
878 let serialized = serde_json::to_string_pretty(&payload)
879 .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
880
881 fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
882 path: path.display().to_string(),
883 source,
884 })
885 }
886}
887
888pub async fn request_access_token(
892 config: &KontextDevConfig,
893) -> Result<AccessToken, KontextDevError> {
894 let token_url = resolve_token_url(config)?;
895
896 let mut body = vec![
897 ("grant_type", "client_credentials".to_string()),
898 ("scope", config.scope.clone()),
899 ];
900
901 let auth_header = if let Some(secret) = &config.client_secret {
902 let raw = format!("{}:{}", config.client_id, secret);
903 Some(format!(
904 "Basic {}",
905 base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
906 ))
907 } else {
908 body.push(("client_id", config.client_id.clone()));
909 None
910 };
911
912 post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
913}
914
915async fn post_token(
916 http: &Client,
917 token_url: &str,
918 authorization: Option<&str>,
919 body: &[(impl AsRef<str>, String)],
920) -> Result<AccessToken, KontextDevError> {
921 let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
922
923 let response = post_form_with_optional_auth::<AccessToken>(
924 http,
925 token_url,
926 authorization.map(ToString::to_string),
927 &body_vec,
928 )
929 .await
930 .map_err(|message| KontextDevError::TokenRequest {
931 token_url: token_url.to_string(),
932 message,
933 })?;
934
935 Ok(response)
936}
937
938async fn post_form_with_optional_auth<T>(
939 http: &Client,
940 url: &str,
941 authorization: Option<String>,
942 body: &[(&str, String)],
943) -> Result<T, String>
944where
945 T: DeserializeOwned,
946{
947 let mut request = http
948 .post(url)
949 .header("Content-Type", "application/x-www-form-urlencoded");
950
951 if let Some(header) = authorization {
952 request = request.header("Authorization", header);
953 }
954
955 let form = body
956 .iter()
957 .map(|(k, v)| (k.to_string(), v.to_string()))
958 .collect::<Vec<(String, String)>>();
959
960 let response = request
961 .form(&form)
962 .send()
963 .await
964 .map_err(|err| err.to_string())?;
965
966 if !response.status().is_success() {
967 return Err(build_error_message(response).await);
968 }
969
970 response.json::<T>().await.map_err(|err| err.to_string())
971}
972
973async fn build_error_message(response: reqwest::Response) -> String {
974 let status = response.status();
975 let fallback = format!(
976 "{} {}",
977 status.as_u16(),
978 status.canonical_reason().unwrap_or("")
979 );
980
981 let body = response.text().await.unwrap_or_default();
982 if body.is_empty() {
983 return fallback.trim().to_string();
984 }
985
986 if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
987 if let Some(description) = parsed.error_description {
988 return description;
989 }
990 if let Some(message) = parsed.message {
991 return message;
992 }
993 if let Some(error) = parsed.error {
994 return error;
995 }
996 }
997
998 format!("{fallback}: {body}")
999}
1000
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration for a local server with the hosted integration UI set
    /// explicitly and no client secret or token cache path.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: Some("http://localhost:4000".to_string()),
            mcp_url: None,
            token_url: None,
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
        }
    }

    /// The connect URL is built on the explicitly configured integration UI.
    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let url = KontextDevClient::new(config())
            .integration_connect_url("session-123")
            .expect("url should be built");

        assert_eq!(
            url,
            "https://app.kontext.dev/oauth/connect?session=session-123"
        );
    }

    /// An integer `exp` claim is read from the payload segment.
    #[test]
    fn jwt_exp_decode_reads_exp() {
        let b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD;
        let header = b64.encode(r#"{"alg":"none"}"#);
        let payload = b64.encode(r#"{"exp":4070908800}"#);

        let token = format!("{header}.{payload}.sig");

        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }
}