Skip to main content

kontext_dev/
lib.rs

1use base64::Engine;
2use reqwest::Client;
3use serde::Deserialize;
4use serde::Serialize;
5use serde::de::DeserializeOwned;
6use sha2::Digest;
7use std::fs;
8use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11use std::time::SystemTime;
12use std::time::UNIX_EPOCH;
13use tiny_http::Response;
14use tiny_http::Server;
15use tokio::sync::oneshot;
16use tokio::time::timeout;
17use url::Url;
18
19pub use kontext_dev_core::AccessToken;
20pub use kontext_dev_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
21pub use kontext_dev_core::DEFAULT_RESOURCE;
22pub use kontext_dev_core::DEFAULT_SCOPE;
23pub use kontext_dev_core::DEFAULT_SERVER_NAME;
24pub use kontext_dev_core::KontextDevConfig;
25pub use kontext_dev_core::KontextDevCoreError;
26pub use kontext_dev_core::TokenExchangeToken;
27pub use kontext_dev_core::build_mcp_url;
28pub use kontext_dev_core::normalize_server_url;
29pub use kontext_dev_core::resolve_authorize_url;
30pub use kontext_dev_core::resolve_connect_session_url;
31pub use kontext_dev_core::resolve_integration_connection_url;
32pub use kontext_dev_core::resolve_integration_oauth_init_url;
33pub use kontext_dev_core::resolve_mcp_url;
34pub use kontext_dev_core::resolve_server_base_url;
35pub use kontext_dev_core::resolve_token_url;
36
/// `grant_type` form value sent when exchanging an identity token for a resource token.
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// Token-type URN used as `subject_token_type` and as the `issued_token_type` of cached tokens.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
/// Only requests for this path are accepted by the local OAuth callback server.
const CONNECT_CALLBACK_PATH: &str = "/callback";
/// A cached token is treated as expired this many seconds before its recorded expiry.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
41
/// Errors produced by this crate's authentication, token-exchange, integration
/// and token-cache operations.
#[derive(Debug, thiserror::Error)]
pub enum KontextDevError {
    /// Error forwarded unchanged from `kontext_dev_core`.
    #[error(transparent)]
    Core(#[from] kontext_dev_core::KontextDevCoreError),
    /// A URL string could not be parsed.
    #[error("failed to parse URL `{url}`")]
    InvalidUrl {
        url: String,
        source: url::ParseError,
    },
    /// The system browser could not be launched.
    #[error("failed to open browser for OAuth authorization")]
    BrowserOpenFailed,
    /// No callback arrived within the configured timeout.
    #[error("OAuth callback timed out after {timeout_seconds}s")]
    OAuthCallbackTimeout { timeout_seconds: i64 },
    /// The callback channel's sender was dropped before a payload was delivered.
    #[error("OAuth callback channel was unexpectedly cancelled")]
    OAuthCallbackCancelled,
    /// The callback carried no `code` query parameter.
    #[error("OAuth callback is missing the authorization code")]
    MissingAuthorizationCode,
    /// The callback carried an `error` parameter; also used when the local
    /// callback server cannot be started or its address cannot be determined.
    #[error("OAuth callback returned an error: {error}")]
    OAuthCallbackError { error: String },
    /// The `state` returned in the callback differs from the one that was sent.
    #[error("OAuth state mismatch")]
    InvalidOAuthState,
    /// A request to the token endpoint failed.
    #[error("Kontext-Dev token request failed for {token_url}: {message}")]
    TokenRequest { token_url: String, message: String },
    /// The token-exchange request for a resource failed.
    #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
    TokenExchange { resource: String, message: String },
    /// Creating a connect session failed, or a session id was empty.
    #[error("Kontext-Dev connect session request failed: {message}")]
    ConnectSession { message: String },
    /// An integration OAuth request failed (also used by the connection-status call).
    #[error("Kontext-Dev integration OAuth init failed: {message}")]
    IntegrationOAuthInit { message: String },
    /// The token cache file exists but could not be read.
    #[error("failed to read token cache at `{path}`: {source}")]
    TokenCacheRead {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file or its parent directory could not be written.
    #[error("failed to write token cache at `{path}`: {source}")]
    TokenCacheWrite {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file is not valid JSON for the expected layout.
    #[error("failed to deserialize token cache at `{path}`: {source}")]
    TokenCacheDeserialize {
        path: String,
        source: serde_json::Error,
    },
    /// The token cache could not be serialized to JSON.
    #[error("failed to serialize token cache: {source}")]
    TokenCacheSerialize { source: serde_json::Error },
    /// A token with an empty `access_token` string was encountered.
    #[error("Kontext-Dev access token is empty")]
    EmptyAccessToken,
    /// NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("missing integration UI URL; set `integration_ui_url` in config")]
    MissingIntegrationUiUrl,
}
93
94#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
95pub struct ConnectSession {
96    #[serde(rename = "sessionId")]
97    pub session_id: String,
98    #[serde(rename = "expiresAt")]
99    pub expires_at: String,
100}
101
102#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
103pub struct IntegrationOAuthInitResponse {
104    #[serde(rename = "authorizationUrl")]
105    pub authorization_url: String,
106    #[serde(default)]
107    pub state: Option<String>,
108}
109
/// Response body of the integration connection-status request.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct IntegrationConnectionStatus {
    /// Whether the integration is currently connected.
    pub connected: bool,
    /// Optional expiry; falls back to `None` when the key is missing.
    // NOTE(review): unlike `ConnectSession::expires_at`, this field has no
    // camelCase rename, so only the key `expires_at` is recognised — confirm
    // against the server's wire format.
    #[serde(default)]
    pub expires_at: Option<String>,
}
116
/// Result of `KontextDevClient::authenticate_mcp`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct KontextAuthSession {
    /// Token identifying the user (from the cache, a refresh, or the browser flow).
    pub identity_token: AccessToken,
    /// Token obtained by exchanging the identity token for the configured resource.
    pub gateway_token: TokenExchangeToken,
    /// `true` only when the interactive browser flow ran during this call.
    pub browser_auth_performed: bool,
}
123
/// On-disk representation of a single token inside the token cache file.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CachedAccessToken {
    access_token: String,
    token_type: String,
    refresh_token: Option<String>,
    scope: Option<String>,
    /// Absolute expiry in milliseconds since the Unix epoch; `None` when unknown.
    expires_at_unix_ms: Option<u64>,
}
132
/// Layout of the JSON token cache file.
///
/// `client_id` and `resource` record which configuration the tokens were
/// issued for; `authenticate_mcp` ignores the cache when they do not match.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TokenCacheFile {
    client_id: String,
    resource: String,
    identity: CachedAccessToken,
    gateway: CachedAccessToken,
}
140
141impl CachedAccessToken {
142    fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
143        if token.access_token.is_empty() {
144            return Err(KontextDevError::EmptyAccessToken);
145        }
146
147        Ok(Self {
148            access_token: token.access_token.clone(),
149            token_type: token.token_type.clone(),
150            refresh_token: token.refresh_token.clone(),
151            scope: token.scope.clone(),
152            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
153        })
154    }
155
156    fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
157        if token.access_token.is_empty() {
158            return Err(KontextDevError::EmptyAccessToken);
159        }
160
161        Ok(Self {
162            access_token: token.access_token.clone(),
163            token_type: token.token_type.clone(),
164            refresh_token: token.refresh_token.clone(),
165            scope: token.scope.clone(),
166            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
167        })
168    }
169
170    fn is_valid(&self) -> bool {
171        match self.expires_at_unix_ms {
172            Some(expires_at) => {
173                let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
174                now_unix_ms().saturating_add(buffer_ms) < expires_at
175            }
176            None => true,
177        }
178    }
179
180    fn to_access_token(&self) -> AccessToken {
181        AccessToken {
182            access_token: self.access_token.clone(),
183            token_type: self.token_type.clone(),
184            expires_in: self
185                .expires_at_unix_ms
186                .and_then(unix_ms_to_relative_seconds),
187            refresh_token: self.refresh_token.clone(),
188            scope: self.scope.clone(),
189        }
190    }
191
192    fn to_token_exchange_token(&self) -> TokenExchangeToken {
193        TokenExchangeToken {
194            access_token: self.access_token.clone(),
195            issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
196            token_type: self.token_type.clone(),
197            expires_in: self
198                .expires_at_unix_ms
199                .and_then(unix_ms_to_relative_seconds),
200            scope: self.scope.clone(),
201            refresh_token: self.refresh_token.clone(),
202        }
203    }
204}
205
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_unix_ms() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_millis() as u64,
        Err(_) => 0,
    }
}
212
213fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
214    if unix_ms <= now_unix_ms() {
215        return Some(0);
216    }
217
218    let delta_ms = unix_ms - now_unix_ms();
219    let secs = delta_ms / 1000;
220    i64::try_from(secs).ok()
221}
222
223fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
224    if let Some(expires_in) = expires_in
225        && expires_in > 0
226    {
227        let expires_in_u64 = u64::try_from(expires_in).ok()?;
228        return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
229    }
230
231    decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
232}
233
234fn decode_jwt_exp(token: &str) -> Option<u64> {
235    let mut parts = token.split('.');
236    let _header = parts.next()?;
237    let payload = parts.next()?;
238
239    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
240        .decode(payload)
241        .ok()?;
242    let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
243    value.get("exp").and_then(|exp| {
244        exp.as_u64()
245            .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
246    })
247}
248
/// Optional fields of a JSON error response body.
// NOTE(review): not referenced in the visible part of this file; presumably
// consumed by `build_error_message` — confirm.
#[derive(Clone, Debug, Deserialize)]
struct OAuthErrorBody {
    error: Option<String>,
    error_description: Option<String>,
    message: Option<String>,
}
255
/// Query parameters extracted from a request to the local callback server.
///
/// Each field mirrors the query key of the same name and is `None` when that
/// key was absent.
#[derive(Clone, Debug, Deserialize)]
struct OAuthCallbackPayload {
    code: Option<String>,
    state: Option<String>,
    error: Option<String>,
    error_description: Option<String>,
}
263
/// PKCE verifier together with its derived challenge.
#[derive(Clone, Debug)]
struct PkcePair {
    /// Secret sent as `code_verifier` with the authorization-code token request.
    verifier: String,
    /// Base64url (unpadded) SHA-256 of the verifier, sent as `code_challenge`.
    challenge: String,
}
269
270fn generate_pkce_pair() -> PkcePair {
271    let mut raw = [0u8; 64];
272    rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
273
274    let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
275    let digest = sha2::Sha256::digest(verifier.as_bytes());
276    let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
277
278    PkcePair {
279        verifier,
280        challenge,
281    }
282}
283
284fn generate_state() -> String {
285    let mut bytes = [0u8; 16];
286    rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
287    bytes.iter().map(|b| format!("{b:02x}")).collect()
288}
289
/// Holds the callback server and calls `unblock()` on it when dropped.
struct CallbackServerGuard {
    server: Arc<Server>,
}

impl Drop for CallbackServerGuard {
    fn drop(&mut self) {
        // Presumably wakes the blocking `recv()` loop in `spawn_callback_server`
        // so that task can finish — confirm against tiny_http's `unblock` docs.
        self.server.unblock();
    }
}
299
300fn parse_callback(url_path: &str) -> Option<OAuthCallbackPayload> {
301    let full = format!("http://localhost{url_path}");
302    let parsed = Url::parse(&full).ok()?;
303
304    if parsed.path() != CONNECT_CALLBACK_PATH {
305        return None;
306    }
307
308    let mut payload = OAuthCallbackPayload {
309        code: None,
310        state: None,
311        error: None,
312        error_description: None,
313    };
314
315    for (key, value) in parsed.query_pairs() {
316        match key.as_ref() {
317            "code" => payload.code = Some(value.to_string()),
318            "state" => payload.state = Some(value.to_string()),
319            "error" => payload.error = Some(value.to_string()),
320            "error_description" => payload.error_description = Some(value.to_string()),
321            _ => {}
322        }
323    }
324
325    Some(payload)
326}
327
328fn spawn_callback_server(
329    server: Arc<Server>,
330    tx: oneshot::Sender<OAuthCallbackPayload>,
331) -> tokio::task::JoinHandle<()> {
332    tokio::task::spawn_blocking(move || {
333        while let Ok(request) = server.recv() {
334            let path = request.url().to_string();
335            if let Some(payload) = parse_callback(&path) {
336                let response = Response::from_string(
337                    "Authentication complete. You can return to your terminal.",
338                );
339                let _ = request.respond(response);
340                let _ = tx.send(payload);
341                break;
342            }
343
344            let response = Response::from_string("Invalid callback").with_status_code(400);
345            let _ = request.respond(response);
346        }
347    })
348}
349
/// Client for the Kontext-Dev authentication and integration endpoints.
pub struct KontextDevClient {
    /// Configuration used to resolve endpoint URLs and credentials.
    config: KontextDevConfig,
    /// HTTP client reused for every request made through this instance.
    http: Client,
}
354
355impl KontextDevClient {
356    pub fn new(config: KontextDevConfig) -> Self {
357        Self {
358            config,
359            http: Client::new(),
360        }
361    }
362
    /// Returns the configuration this client was constructed with.
    pub fn config(&self) -> &KontextDevConfig {
        &self.config
    }
366
367    pub fn mcp_url(&self) -> Result<String, KontextDevError> {
368        resolve_mcp_url(&self.config).map_err(KontextDevError::from)
369    }
370
371    pub fn token_url(&self) -> Result<String, KontextDevError> {
372        resolve_token_url(&self.config).map_err(KontextDevError::from)
373    }
374
375    pub fn authorize_url(&self) -> Result<String, KontextDevError> {
376        resolve_authorize_url(&self.config).map_err(KontextDevError::from)
377    }
378
379    pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
380        resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
381    }
382
383    /// Authenticate a user via browser PKCE, exchange the identity token to a
384    /// resource-scoped MCP gateway token, and persist cache for future runs.
385    pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
386        let resource = self.config.resource.clone();
387
388        if let Some(cache) = self.read_cache()?
389            && cache.client_id == self.config.client_id
390            && cache.resource == resource
391            && cache.gateway.is_valid()
392            && cache.identity.is_valid()
393        {
394            return Ok(KontextAuthSession {
395                identity_token: cache.identity.to_access_token(),
396                gateway_token: cache.gateway.to_token_exchange_token(),
397                browser_auth_performed: false,
398            });
399        }
400
401        if let Some(cache) = self.read_cache()?
402            && cache.client_id == self.config.client_id
403            && cache.resource == resource
404        {
405            let maybe_refreshed = if cache.identity.is_valid() {
406                Some(cache.identity.to_access_token())
407            } else if let Some(refresh_token) = &cache.identity.refresh_token {
408                Some(self.refresh_identity_token(refresh_token).await?)
409            } else {
410                None
411            };
412
413            if let Some(identity_token) = maybe_refreshed {
414                let gateway_token = self
415                    .exchange_for_resource(&identity_token.access_token, &resource, None)
416                    .await?;
417                self.write_cache(&identity_token, &gateway_token)?;
418                return Ok(KontextAuthSession {
419                    identity_token,
420                    gateway_token,
421                    browser_auth_performed: false,
422                });
423            }
424        }
425
426        let identity_token = self.authorize_with_browser_pkce().await?;
427        let gateway_token = self
428            .exchange_for_resource(&identity_token.access_token, &resource, None)
429            .await?;
430
431        self.write_cache(&identity_token, &gateway_token)?;
432
433        Ok(KontextAuthSession {
434            identity_token,
435            gateway_token,
436            browser_auth_performed: true,
437        })
438    }
439
440    /// Request a short-lived connect session and build the hosted integration
441    /// connect URL (`.../oauth/connect?session=...`).
442    pub async fn create_integration_connect_url(
443        &self,
444        gateway_access_token: &str,
445    ) -> Result<String, KontextDevError> {
446        let session = self.create_connect_session(gateway_access_token).await?;
447        self.integration_connect_url(&session.session_id)
448    }
449
450    pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
451        if session_id.trim().is_empty() {
452            return Err(KontextDevError::ConnectSession {
453                message: "connect session id is empty".to_string(),
454            });
455        }
456
457        let base = if let Some(explicit) = &self.config.integration_ui_url {
458            explicit.trim_end_matches('/').to_string()
459        } else {
460            let server = resolve_server_base_url(&self.config)?;
461            if server.contains("api.kontext.dev") {
462                "https://app.kontext.dev".to_string()
463            } else {
464                server
465            }
466        };
467
468        let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
469            url: base.clone(),
470            source,
471        })?;
472        url.set_path("/oauth/connect");
473        url.query_pairs_mut().append_pair("session", session_id);
474        Ok(url.to_string())
475    }
476
477    /// Opens the integration connect page in the browser.
478    pub async fn open_integration_connect_page(
479        &self,
480        gateway_access_token: &str,
481    ) -> Result<String, KontextDevError> {
482        let url = self
483            .create_integration_connect_url(gateway_access_token)
484            .await?;
485
486        if webbrowser::open(&url).is_err() {
487            return Err(KontextDevError::BrowserOpenFailed);
488        }
489
490        Ok(url)
491    }
492
493    pub async fn create_connect_session(
494        &self,
495        gateway_access_token: &str,
496    ) -> Result<ConnectSession, KontextDevError> {
497        let url = self.connect_session_url()?;
498        let response = self
499            .http
500            .post(&url)
501            .header("Authorization", format!("Bearer {gateway_access_token}"))
502            .send()
503            .await
504            .map_err(|err| KontextDevError::ConnectSession {
505                message: err.to_string(),
506            })?;
507
508        if !response.status().is_success() {
509            let message = build_error_message(response).await;
510            return Err(KontextDevError::ConnectSession { message });
511        }
512
513        response
514            .json::<ConnectSession>()
515            .await
516            .map_err(|err| KontextDevError::ConnectSession {
517                message: err.to_string(),
518            })
519    }
520
521    pub async fn initiate_integration_oauth(
522        &self,
523        gateway_access_token: &str,
524        integration_id: &str,
525        return_to: Option<&str>,
526    ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
527        let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
528
529        let mut request = self
530            .http
531            .post(&url)
532            .header("Authorization", format!("Bearer {gateway_access_token}"));
533
534        if let Some(return_to) = return_to {
535            request = request.json(&serde_json::json!({ "returnTo": return_to }));
536        }
537
538        let response =
539            request
540                .send()
541                .await
542                .map_err(|err| KontextDevError::IntegrationOAuthInit {
543                    message: err.to_string(),
544                })?;
545
546        if !response.status().is_success() {
547            let message = build_error_message(response).await;
548            return Err(KontextDevError::IntegrationOAuthInit { message });
549        }
550
551        response
552            .json::<IntegrationOAuthInitResponse>()
553            .await
554            .map_err(|err| KontextDevError::IntegrationOAuthInit {
555                message: err.to_string(),
556            })
557    }
558
559    pub async fn integration_connection_status(
560        &self,
561        gateway_access_token: &str,
562        integration_id: &str,
563    ) -> Result<IntegrationConnectionStatus, KontextDevError> {
564        let url = resolve_integration_connection_url(&self.config, integration_id)?;
565        let response = self
566            .http
567            .get(url)
568            .header("Authorization", format!("Bearer {gateway_access_token}"))
569            .send()
570            .await
571            .map_err(|err| KontextDevError::IntegrationOAuthInit {
572                message: err.to_string(),
573            })?;
574
575        if !response.status().is_success() {
576            let message = build_error_message(response).await;
577            return Err(KontextDevError::IntegrationOAuthInit { message });
578        }
579
580        response
581            .json::<IntegrationConnectionStatus>()
582            .await
583            .map_err(|err| KontextDevError::IntegrationOAuthInit {
584                message: err.to_string(),
585            })
586    }
587
588    pub async fn wait_for_integration_connection(
589        &self,
590        gateway_access_token: &str,
591        integration_id: &str,
592        timeout_ms: u64,
593        interval_ms: u64,
594    ) -> Result<bool, KontextDevError> {
595        let started = now_unix_ms();
596
597        loop {
598            let status = self
599                .integration_connection_status(gateway_access_token, integration_id)
600                .await?;
601            if status.connected {
602                return Ok(true);
603            }
604
605            if now_unix_ms().saturating_sub(started) >= timeout_ms {
606                return Ok(false);
607            }
608
609            tokio::time::sleep(Duration::from_millis(interval_ms)).await;
610        }
611    }
612
613    async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
614        let auth_url = self.authorize_url()?;
615        let token_url = self.token_url()?;
616        let pkce = generate_pkce_pair();
617        let state = generate_state();
618
619        let (callback_url, callback_payload) = self.listen_for_callback().await?;
620
621        let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
622            url: auth_url.clone(),
623            source,
624        })?;
625
626        url.query_pairs_mut()
627            .append_pair("client_id", &self.config.client_id)
628            .append_pair("response_type", "code")
629            .append_pair("redirect_uri", &callback_url)
630            .append_pair("state", &state)
631            .append_pair("scope", &self.config.scope)
632            .append_pair("code_challenge_method", "S256")
633            .append_pair("code_challenge", &pkce.challenge);
634
635        println!(
636            "Authorize `{}` by opening this URL in your browser:\n{}\n",
637            self.config.server_name, url
638        );
639
640        if webbrowser::open(url.as_str()).is_err() {
641            return Err(KontextDevError::BrowserOpenFailed);
642        }
643
644        let payload = callback_payload.await?;
645
646        if let Some(error) = payload.error {
647            let with_details = payload
648                .error_description
649                .map(|description| format!("{error}: {description}"))
650                .unwrap_or(error);
651            return Err(KontextDevError::OAuthCallbackError {
652                error: with_details,
653            });
654        }
655
656        if payload.state.as_deref() != Some(state.as_str()) {
657            return Err(KontextDevError::InvalidOAuthState);
658        }
659
660        let code = payload
661            .code
662            .ok_or(KontextDevError::MissingAuthorizationCode)?;
663
664        let mut body = vec![
665            ("grant_type", "authorization_code".to_string()),
666            ("code", code),
667            ("redirect_uri", callback_url),
668            ("client_id", self.config.client_id.clone()),
669            ("code_verifier", pkce.verifier),
670        ];
671
672        if let Some(client_secret) = &self.config.client_secret {
673            body.push(("client_secret", client_secret.clone()));
674        }
675
676        post_token(&self.http, &token_url, None, &body).await
677    }
678
679    async fn refresh_identity_token(
680        &self,
681        refresh_token: &str,
682    ) -> Result<AccessToken, KontextDevError> {
683        let token_url = self.token_url()?;
684        let mut body = vec![
685            ("grant_type", "refresh_token".to_string()),
686            ("refresh_token", refresh_token.to_string()),
687            ("client_id", self.config.client_id.clone()),
688        ];
689
690        if let Some(client_secret) = &self.config.client_secret {
691            body.push(("client_secret", client_secret.clone()));
692        }
693
694        post_token(&self.http, &token_url, None, &body).await
695    }
696
697    pub async fn exchange_for_resource(
698        &self,
699        subject_token: &str,
700        resource: &str,
701        scope: Option<&str>,
702    ) -> Result<TokenExchangeToken, KontextDevError> {
703        let token_url = self.token_url()?;
704
705        let mut body = vec![
706            ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
707            ("subject_token", subject_token.to_string()),
708            ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
709            ("resource", resource.to_string()),
710        ];
711
712        if let Some(scope) = scope {
713            body.push(("scope", scope.to_string()));
714        }
715
716        let auth_header = self.config.client_secret.as_ref().map(|secret| {
717            let raw = format!("{}:{}", self.config.client_id, secret);
718            format!(
719                "Basic {}",
720                base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
721            )
722        });
723
724        if self.config.client_secret.is_none() {
725            body.push(("client_id", self.config.client_id.clone()));
726        }
727
728        let response = post_form_with_optional_auth::<TokenExchangeToken>(
729            &self.http,
730            &token_url,
731            auth_header,
732            &body,
733        )
734        .await
735        .map_err(|message| KontextDevError::TokenExchange {
736            resource: resource.to_string(),
737            message,
738        })?;
739
740        if response.access_token.is_empty() {
741            return Err(KontextDevError::EmptyAccessToken);
742        }
743
744        Ok(response)
745    }
746
    /// Starts a local HTTP server for the OAuth redirect.
    ///
    /// Returns the callback URL to use as `redirect_uri` together with a future
    /// that resolves to the callback payload. The future fails with
    /// `OAuthCallbackTimeout` after `auth_timeout_seconds` (at least one second)
    /// and with `OAuthCallbackCancelled` when the sender is dropped without
    /// sending.
    async fn listen_for_callback(
        &self,
    ) -> Result<
        (
            String,
            impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
        ),
        KontextDevError,
    > {
        // Port 0 requests an OS-assigned port on the loopback interface
        // (standard socket behavior; the actual port is read back below).
        let server = Arc::new(Server::http("127.0.0.1:0").map_err(|err| {
            KontextDevError::OAuthCallbackError {
                error: format!("failed to start callback server: {err}"),
            }
        })?);

        let callback_url = match server.server_addr() {
            tiny_http::ListenAddr::IP(std::net::SocketAddr::V4(addr)) => {
                format!(
                    "http://{}:{}{}",
                    addr.ip(),
                    addr.port(),
                    CONNECT_CALLBACK_PATH
                )
            }
            tiny_http::ListenAddr::IP(std::net::SocketAddr::V6(addr)) => {
                // IPv6 literals must be bracketed inside a URL authority.
                format!(
                    "http://[{}]:{}{}",
                    addr.ip(),
                    addr.port(),
                    CONNECT_CALLBACK_PATH
                )
            }
            // Presumably covers a non-IP listen address that exists only on
            // non-Windows targets — confirm against tiny_http's `ListenAddr`.
            #[cfg(not(target_os = "windows"))]
            _ => {
                return Err(KontextDevError::OAuthCallbackError {
                    error: "unable to determine callback address".to_string(),
                });
            }
        };

        let (tx, rx) = oneshot::channel();
        // The join handle is dropped, so the serving task runs detached.
        let _join = spawn_callback_server(server.clone(), tx);
        // Moved into the future below: the server is unblocked when that future
        // finishes or is dropped.
        let _guard = CallbackServerGuard { server };
        let timeout_seconds = self.config.auth_timeout_seconds.max(1);

        let fut = async move {
            let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
                .await
                .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
                .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
            drop(_guard);
            Ok(payload)
        };

        Ok((callback_url, fut))
    }
803
804    fn token_cache_path(&self) -> Option<PathBuf> {
805        if let Some(explicit) = &self.config.token_cache_path {
806            return Some(PathBuf::from(explicit));
807        }
808
809        let home = dirs::home_dir()?;
810        let mut path = home;
811        path.push(".kontext-dev");
812        path.push("tokens");
813
814        let sanitized_client_id: String = self
815            .config
816            .client_id
817            .chars()
818            .map(|ch| {
819                if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
820                    ch
821                } else {
822                    '_'
823                }
824            })
825            .collect();
826
827        path.push(format!("{sanitized_client_id}.json"));
828        Some(path)
829    }
830
831    fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
832        let Some(path) = self.token_cache_path() else {
833            return Ok(None);
834        };
835
836        let raw = match fs::read_to_string(&path) {
837            Ok(raw) => raw,
838            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
839            Err(source) => {
840                return Err(KontextDevError::TokenCacheRead {
841                    path: path.display().to_string(),
842                    source,
843                });
844            }
845        };
846
847        serde_json::from_str(&raw).map(Some).map_err(|source| {
848            KontextDevError::TokenCacheDeserialize {
849                path: path.display().to_string(),
850                source,
851            }
852        })
853    }
854
855    fn write_cache(
856        &self,
857        identity: &AccessToken,
858        gateway: &TokenExchangeToken,
859    ) -> Result<(), KontextDevError> {
860        let Some(path) = self.token_cache_path() else {
861            return Ok(());
862        };
863
864        if let Some(parent) = path.parent() {
865            fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
866                path: parent.display().to_string(),
867                source,
868            })?;
869        }
870
871        let payload = TokenCacheFile {
872            client_id: self.config.client_id.clone(),
873            resource: self.config.resource.clone(),
874            identity: CachedAccessToken::from_access_token(identity)?,
875            gateway: CachedAccessToken::from_token_exchange(gateway)?,
876        };
877
878        let serialized = serde_json::to_string_pretty(&payload)
879            .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
880
881        fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
882            path: path.display().to_string(),
883            source,
884        })
885    }
886}
887
888/// Legacy helper that uses the client-credentials grant.
889///
890/// New PKCE-based apps should use `KontextDevClient::authenticate_mcp`.
891pub async fn request_access_token(
892    config: &KontextDevConfig,
893) -> Result<AccessToken, KontextDevError> {
894    let token_url = resolve_token_url(config)?;
895
896    let mut body = vec![
897        ("grant_type", "client_credentials".to_string()),
898        ("scope", config.scope.clone()),
899    ];
900
901    let auth_header = if let Some(secret) = &config.client_secret {
902        let raw = format!("{}:{}", config.client_id, secret);
903        Some(format!(
904            "Basic {}",
905            base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
906        ))
907    } else {
908        body.push(("client_id", config.client_id.clone()));
909        None
910    };
911
912    post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
913}
914
915async fn post_token(
916    http: &Client,
917    token_url: &str,
918    authorization: Option<&str>,
919    body: &[(impl AsRef<str>, String)],
920) -> Result<AccessToken, KontextDevError> {
921    let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
922
923    let response = post_form_with_optional_auth::<AccessToken>(
924        http,
925        token_url,
926        authorization.map(ToString::to_string),
927        &body_vec,
928    )
929    .await
930    .map_err(|message| KontextDevError::TokenRequest {
931        token_url: token_url.to_string(),
932        message,
933    })?;
934
935    Ok(response)
936}
937
938async fn post_form_with_optional_auth<T>(
939    http: &Client,
940    url: &str,
941    authorization: Option<String>,
942    body: &[(&str, String)],
943) -> Result<T, String>
944where
945    T: DeserializeOwned,
946{
947    let mut request = http
948        .post(url)
949        .header("Content-Type", "application/x-www-form-urlencoded");
950
951    if let Some(header) = authorization {
952        request = request.header("Authorization", header);
953    }
954
955    let form = body
956        .iter()
957        .map(|(k, v)| (k.to_string(), v.to_string()))
958        .collect::<Vec<(String, String)>>();
959
960    let response = request
961        .form(&form)
962        .send()
963        .await
964        .map_err(|err| err.to_string())?;
965
966    if !response.status().is_success() {
967        return Err(build_error_message(response).await);
968    }
969
970    response.json::<T>().await.map_err(|err| err.to_string())
971}
972
973async fn build_error_message(response: reqwest::Response) -> String {
974    let status = response.status();
975    let fallback = format!(
976        "{} {}",
977        status.as_u16(),
978        status.canonical_reason().unwrap_or("")
979    );
980
981    let body = response.text().await.unwrap_or_default();
982    if body.is_empty() {
983        return fallback.trim().to_string();
984    }
985
986    if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
987        if let Some(description) = parsed.error_description {
988            return description;
989        }
990        if let Some(message) = parsed.message {
991            return message;
992        }
993        if let Some(error) = parsed.error {
994            return error;
995        }
996    }
997
998    format!("{fallback}: {body}")
999}
1000
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration shared by the tests: a local server, a fixed
    /// client id, no client secret and the hosted integration UI.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: Some("http://localhost:4000".to_string()),
            mcp_url: None,
            token_url: None,
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
        }
    }

    /// The connect URL must point at the configured integration UI and carry
    /// the session id as a query parameter.
    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let expected = "https://app.kontext.dev/oauth/connect?session=session-123";

        let client = KontextDevClient::new(config());
        let actual = client
            .integration_connect_url("session-123")
            .expect("url should be built");

        assert_eq!(actual, expected);
    }

    /// `decode_jwt_exp` must read the `exp` claim from an unsigned JWT payload.
    #[test]
    fn jwt_exp_decode_reads_exp() {
        use base64::engine::general_purpose::URL_SAFE_NO_PAD;

        let segments = [
            URL_SAFE_NO_PAD.encode(r#"{"alg":"none"}"#),
            URL_SAFE_NO_PAD.encode(r#"{"exp":4070908800}"#),
            "sig".to_string(),
        ];
        let token = segments.join(".");

        assert_eq!(decode_jwt_exp(&token), Some(4_070_908_800));
    }
}