Skip to main content

kontext_dev/
lib.rs

1use base64::Engine;
2use reqwest::Client;
3use reqwest::StatusCode;
4use serde::Deserialize;
5use serde::Serialize;
6use serde::de::DeserializeOwned;
7use sha2::Digest;
8use std::fs;
9use std::path::PathBuf;
10use std::sync::Arc;
11use std::time::Duration;
12use std::time::SystemTime;
13use std::time::UNIX_EPOCH;
14use tiny_http::Response;
15use tiny_http::Server;
16use tokio::sync::oneshot;
17use tokio::time::timeout;
18use url::Url;
19
20pub use kontext_dev_core::AccessToken;
21pub use kontext_dev_core::DEFAULT_AUTH_TIMEOUT_SECONDS;
22pub use kontext_dev_core::DEFAULT_RESOURCE;
23pub use kontext_dev_core::DEFAULT_SCOPE;
24pub use kontext_dev_core::DEFAULT_SERVER_NAME;
25pub use kontext_dev_core::KontextDevConfig;
26pub use kontext_dev_core::KontextDevCoreError;
27pub use kontext_dev_core::TokenExchangeToken;
28pub use kontext_dev_core::build_mcp_url;
29pub use kontext_dev_core::normalize_server_url;
30pub use kontext_dev_core::resolve_authorize_url;
31pub use kontext_dev_core::resolve_connect_session_url;
32pub use kontext_dev_core::resolve_integration_connection_url;
33pub use kontext_dev_core::resolve_integration_oauth_init_url;
34pub use kontext_dev_core::resolve_mcp_url;
35pub use kontext_dev_core::resolve_server_base_url;
36pub use kontext_dev_core::resolve_token_url;
37
/// OAuth 2.0 Token Exchange grant type (RFC 8693), sent as `grant_type` when
/// swapping the identity token for a resource-scoped gateway token.
const TOKEN_EXCHANGE_GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
/// Access-token type URI (RFC 8693); sent as `subject_token_type` and reported
/// as `issued_token_type` for gateway tokens restored from the cache.
const TOKEN_TYPE_ACCESS_TOKEN: &str = "urn:ietf:params:oauth:token-type:access_token";
/// Safety margin: a cached token is treated as expired this many seconds
/// before its recorded expiry.
const TOKEN_EXPIRY_BUFFER_SECONDS: u64 = 60;
41
/// Errors returned by the Kontext-Dev client.
#[derive(Debug, thiserror::Error)]
pub enum KontextDevError {
    /// Error propagated from `kontext_dev_core` (e.g. endpoint URL resolution).
    #[error(transparent)]
    Core(#[from] kontext_dev_core::KontextDevCoreError),
    /// A URL (authorization endpoint, redirect URI, or integration UI base)
    /// could not be parsed; `source` is exposed as the error source.
    #[error("failed to parse URL `{url}`")]
    InvalidUrl {
        url: String,
        source: url::ParseError,
    },
    /// `webbrowser::open` failed to launch the system browser.
    #[error("failed to open browser for OAuth authorization")]
    BrowserOpenFailed,
    /// No redirect reached the local callback server within the configured
    /// `auth_timeout_seconds` (minimum 1).
    #[error("OAuth callback timed out after {timeout_seconds}s")]
    OAuthCallbackTimeout { timeout_seconds: i64 },
    /// The callback listener's channel closed before delivering a payload.
    #[error("OAuth callback channel was unexpectedly cancelled")]
    OAuthCallbackCancelled,
    /// The redirect carried no `code` query parameter.
    #[error("OAuth callback is missing the authorization code")]
    MissingAuthorizationCode,
    /// The authorization server redirected with `error` (plus
    /// `error_description` when present), or the redirect URI / callback
    /// server could not be set up.
    #[error("OAuth callback returned an error: {error}")]
    OAuthCallbackError { error: String },
    /// The redirect's `state` did not match the value sent with the
    /// authorization request.
    #[error("OAuth state mismatch")]
    InvalidOAuthState,
    /// A token-endpoint request failed (raised outside this excerpt; presumably
    /// by the `post_token` helper — confirm).
    #[error("Kontext-Dev token request failed for {token_url}: {message}")]
    TokenRequest { token_url: String, message: String },
    /// The RFC 8693 exchange for `resource` failed.
    #[error("Kontext-Dev token exchange failed for resource `{resource}`: {message}")]
    TokenExchange { resource: String, message: String },
    /// Creating a connect session, or building its connect URL, failed.
    #[error("Kontext-Dev connect session request failed: {message}")]
    ConnectSession { message: String },
    /// An integration request failed. Used by the OAuth-init call and also by
    /// the integration connection-status call.
    #[error("Kontext-Dev integration OAuth init failed: {message}")]
    IntegrationOAuthInit { message: String },
    /// The token cache file exists but could not be read.
    #[error("failed to read token cache at `{path}`: {source}")]
    TokenCacheRead {
        path: String,
        source: std::io::Error,
    },
    /// The token cache (or its parent directory) could not be written.
    #[error("failed to write token cache at `{path}`: {source}")]
    TokenCacheWrite {
        path: String,
        source: std::io::Error,
    },
    /// The token cache file is not valid cache JSON.
    #[error("failed to deserialize token cache at `{path}`: {source}")]
    TokenCacheDeserialize {
        path: String,
        source: serde_json::Error,
    },
    /// The token cache could not be serialized (presumably raised by
    /// `write_cache`, outside this excerpt).
    #[error("failed to serialize token cache: {source}")]
    TokenCacheSerialize { source: serde_json::Error },
    /// A token response, or a token about to be cached, had an empty access token.
    #[error("Kontext-Dev access token is empty")]
    EmptyAccessToken,
    /// No integration UI URL is configured.
    #[error("missing integration UI URL; set `integration_ui_url` in config")]
    MissingIntegrationUiUrl,
}
93
94#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
95pub struct ConnectSession {
96    #[serde(rename = "sessionId")]
97    pub session_id: String,
98    #[serde(rename = "expiresAt")]
99    pub expires_at: String,
100}
101
102#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
103pub struct IntegrationOAuthInitResponse {
104    #[serde(rename = "authorizationUrl")]
105    pub authorization_url: String,
106    #[serde(default)]
107    pub state: Option<String>,
108}
109
110#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
111pub struct IntegrationConnectionStatus {
112    pub connected: bool,
113    #[serde(default)]
114    pub expires_at: Option<String>,
115}
116
/// Tokens produced by `KontextDevClient::authenticate_mcp`.
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, Eq)]
pub struct KontextAuthSession {
    /// User identity token (from browser authorization, a refresh, or the cache).
    pub identity_token: AccessToken,
    /// Resource-scoped gateway token obtained via token exchange (or the cache).
    pub gateway_token: TokenExchangeToken,
    /// `true` only when an interactive browser login was performed.
    pub browser_auth_performed: bool,
}
123
/// On-disk form of an access token. The relative `expires_in` is stored as an
/// absolute expiry so validity can be judged on later runs.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct CachedAccessToken {
    access_token: String,
    token_type: String,
    refresh_token: Option<String>,
    scope: Option<String>,
    /// Absolute expiry in Unix epoch milliseconds. `None` when neither
    /// `expires_in` nor a JWT `exp` claim was available; such tokens are
    /// treated as never expiring by `is_valid`.
    expires_at_unix_ms: Option<u64>,
}
132
/// JSON document stored at the token cache path.
#[derive(Clone, Debug, Deserialize, Serialize)]
struct TokenCacheFile {
    /// Client the tokens belong to; must match the config for the cache to be reused.
    client_id: String,
    /// Resource the gateway token is scoped to; must match the config to be reused.
    resource: String,
    /// Identity token from browser authorization or refresh.
    identity: CachedAccessToken,
    /// Gateway token from the token exchange for `resource`.
    gateway: CachedAccessToken,
}
140
141impl CachedAccessToken {
142    fn from_access_token(token: &AccessToken) -> Result<Self, KontextDevError> {
143        if token.access_token.is_empty() {
144            return Err(KontextDevError::EmptyAccessToken);
145        }
146
147        Ok(Self {
148            access_token: token.access_token.clone(),
149            token_type: token.token_type.clone(),
150            refresh_token: token.refresh_token.clone(),
151            scope: token.scope.clone(),
152            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
153        })
154    }
155
156    fn from_token_exchange(token: &TokenExchangeToken) -> Result<Self, KontextDevError> {
157        if token.access_token.is_empty() {
158            return Err(KontextDevError::EmptyAccessToken);
159        }
160
161        Ok(Self {
162            access_token: token.access_token.clone(),
163            token_type: token.token_type.clone(),
164            refresh_token: token.refresh_token.clone(),
165            scope: token.scope.clone(),
166            expires_at_unix_ms: compute_expires_at_unix_ms(token.expires_in, &token.access_token),
167        })
168    }
169
170    fn is_valid(&self) -> bool {
171        match self.expires_at_unix_ms {
172            Some(expires_at) => {
173                let buffer_ms = TOKEN_EXPIRY_BUFFER_SECONDS * 1000;
174                now_unix_ms().saturating_add(buffer_ms) < expires_at
175            }
176            None => true,
177        }
178    }
179
180    fn to_access_token(&self) -> AccessToken {
181        AccessToken {
182            access_token: self.access_token.clone(),
183            token_type: self.token_type.clone(),
184            expires_in: self
185                .expires_at_unix_ms
186                .and_then(unix_ms_to_relative_seconds),
187            refresh_token: self.refresh_token.clone(),
188            scope: self.scope.clone(),
189        }
190    }
191
192    fn to_token_exchange_token(&self) -> TokenExchangeToken {
193        TokenExchangeToken {
194            access_token: self.access_token.clone(),
195            issued_token_type: TOKEN_TYPE_ACCESS_TOKEN.to_string(),
196            token_type: self.token_type.clone(),
197            expires_in: self
198                .expires_at_unix_ms
199                .and_then(unix_ms_to_relative_seconds),
200            scope: self.scope.clone(),
201            refresh_token: self.refresh_token.clone(),
202        }
203    }
204}
205
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`. A millisecond count that
/// does not fit in `u64` saturates to `u64::MAX` instead of being silently
/// truncated by an `as` cast.
fn now_unix_ms() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or(Duration::ZERO);
    u64::try_from(since_epoch.as_millis()).unwrap_or(u64::MAX)
}
212
213fn unix_ms_to_relative_seconds(unix_ms: u64) -> Option<i64> {
214    if unix_ms <= now_unix_ms() {
215        return Some(0);
216    }
217
218    let delta_ms = unix_ms - now_unix_ms();
219    let secs = delta_ms / 1000;
220    i64::try_from(secs).ok()
221}
222
223fn compute_expires_at_unix_ms(expires_in: Option<i64>, access_token: &str) -> Option<u64> {
224    if let Some(expires_in) = expires_in
225        && expires_in > 0
226    {
227        let expires_in_u64 = u64::try_from(expires_in).ok()?;
228        return Some(now_unix_ms().saturating_add(expires_in_u64.saturating_mul(1000)));
229    }
230
231    decode_jwt_exp(access_token).map(|exp| exp.saturating_mul(1000))
232}
233
234fn decode_jwt_exp(token: &str) -> Option<u64> {
235    let mut parts = token.split('.');
236    let _header = parts.next()?;
237    let payload = parts.next()?;
238
239    let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
240        .decode(payload)
241        .ok()?;
242    let value: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
243    value.get("exp").and_then(|exp| {
244        exp.as_u64()
245            .or_else(|| exp.as_i64().and_then(|v| u64::try_from(v).ok()))
246    })
247}
248
/// Error body shape for failed HTTP responses. It has RFC 6749 §5.2
/// `error`/`error_description` plus a generic `message`; every field is
/// optional. Presumably consumed by `build_error_message` (defined outside
/// this excerpt; confirm).
#[derive(Clone, Debug, Deserialize)]
struct OAuthErrorBody {
    error: Option<String>,
    error_description: Option<String>,
    message: Option<String>,
}
255
/// The subset of OAuth authorization-server metadata (the
/// `.well-known/oauth-authorization-server` document) that this client reads.
#[derive(Clone, Debug, Deserialize)]
struct OAuthAuthorizationServerMetadata {
    /// Where to send the browser for the authorization-code flow.
    authorization_endpoint: Option<String>,
}
260
/// Query parameters of the OAuth redirect received by the local callback
/// server (filled in by `parse_callback`).
#[derive(Clone, Debug, Deserialize)]
struct OAuthCallbackPayload {
    /// Authorization code to redeem at the token endpoint.
    code: Option<String>,
    /// Echoed `state`, compared with the value sent in the authorization request.
    state: Option<String>,
    /// Error code set when authorization failed.
    error: Option<String>,
    /// Human-readable detail accompanying `error`.
    error_description: Option<String>,
}
268
/// PKCE (RFC 7636) pair for one authorization request.
#[derive(Clone, Debug)]
struct PkcePair {
    /// Secret sent as `code_verifier` when redeeming the authorization code.
    verifier: String,
    /// base64url(SHA-256(verifier)), sent as `code_challenge` with method `S256`.
    challenge: String,
}
274
275fn generate_pkce_pair() -> PkcePair {
276    let mut raw = [0u8; 64];
277    rand::RngCore::fill_bytes(&mut rand::rng(), &mut raw);
278
279    let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw);
280    let digest = sha2::Sha256::digest(verifier.as_bytes());
281    let challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
282
283    PkcePair {
284        verifier,
285        challenge,
286    }
287}
288
289fn generate_state() -> String {
290    let mut bytes = [0u8; 16];
291    rand::RngCore::fill_bytes(&mut rand::rng(), &mut bytes);
292    bytes.iter().map(|b| format!("{b:02x}")).collect()
293}
294
/// Trim surrounding whitespace from a configured scope string. Returns `None`
/// when nothing remains, so the `scope` parameter can be omitted entirely.
fn normalized_scope(scope: &str) -> Option<&str> {
    Some(scope.trim()).filter(|trimmed| !trimmed.is_empty())
}
303
/// Dropping this guard unblocks the shared callback [`Server`], so a thread
/// waiting in `server.recv()` (the `spawn_callback_server` loop) stops waiting.
struct CallbackServerGuard {
    /// Shared handle; the listener thread holds another clone of the `Arc`.
    server: Arc<Server>,
}

impl Drop for CallbackServerGuard {
    fn drop(&mut self) {
        // tiny_http: wakes a thread blocked in `recv()`, which then returns an
        // error and ends the `while let Ok(..)` receive loop.
        self.server.unblock();
    }
}
313
314fn parse_callback(url_path: &str, callback_path: &str) -> Option<OAuthCallbackPayload> {
315    let full = format!("http://localhost{url_path}");
316    let parsed = Url::parse(&full).ok()?;
317
318    if parsed.path() != callback_path {
319        return None;
320    }
321
322    let mut payload = OAuthCallbackPayload {
323        code: None,
324        state: None,
325        error: None,
326        error_description: None,
327    };
328
329    for (key, value) in parsed.query_pairs() {
330        match key.as_ref() {
331            "code" => payload.code = Some(value.to_string()),
332            "state" => payload.state = Some(value.to_string()),
333            "error" => payload.error = Some(value.to_string()),
334            "error_description" => payload.error_description = Some(value.to_string()),
335            _ => {}
336        }
337    }
338
339    Some(payload)
340}
341
342fn spawn_callback_server(
343    server: Arc<Server>,
344    callback_path: String,
345    tx: oneshot::Sender<OAuthCallbackPayload>,
346) -> tokio::task::JoinHandle<()> {
347    tokio::task::spawn_blocking(move || {
348        while let Ok(request) = server.recv() {
349            let path = request.url().to_string();
350            if let Some(payload) = parse_callback(&path, &callback_path) {
351                let response = Response::from_string(
352                    "Authentication complete. You can return to your terminal.",
353                );
354                let _ = request.respond(response);
355                let _ = tx.send(payload);
356                break;
357            }
358
359            let response = Response::from_string("Invalid callback").with_status_code(400);
360            let _ = request.respond(response);
361        }
362    })
363}
364
/// Client for Kontext-Dev authentication (browser PKCE login, refresh, token
/// exchange, token cache) and the integration-connection APIs.
pub struct KontextDevClient {
    /// Endpoints, client credentials, and cache/timeout settings.
    config: KontextDevConfig,
    /// Shared HTTP client used for token, connect-session and integration requests.
    http: Client,
}
369
370impl KontextDevClient {
371    pub fn new(config: KontextDevConfig) -> Self {
372        Self {
373            config,
374            http: Client::new(),
375        }
376    }
377
378    pub fn config(&self) -> &KontextDevConfig {
379        &self.config
380    }
381
382    pub fn mcp_url(&self) -> Result<String, KontextDevError> {
383        resolve_mcp_url(&self.config).map_err(KontextDevError::from)
384    }
385
386    pub fn token_url(&self) -> Result<String, KontextDevError> {
387        resolve_token_url(&self.config).map_err(KontextDevError::from)
388    }
389
390    pub fn authorize_url(&self) -> Result<String, KontextDevError> {
391        resolve_authorize_url(&self.config).map_err(KontextDevError::from)
392    }
393
394    pub fn connect_session_url(&self) -> Result<String, KontextDevError> {
395        resolve_connect_session_url(&self.config).map_err(KontextDevError::from)
396    }
397
398    /// Authenticate a user via browser PKCE, exchange the identity token to a
399    /// resource-scoped MCP gateway token, and persist cache for future runs.
400    pub async fn authenticate_mcp(&self) -> Result<KontextAuthSession, KontextDevError> {
401        let resource = self.config.resource.clone();
402
403        if let Some(cache) = self.read_cache()?
404            && cache.client_id == self.config.client_id
405            && cache.resource == resource
406            && cache.gateway.is_valid()
407            && cache.identity.is_valid()
408        {
409            return Ok(KontextAuthSession {
410                identity_token: cache.identity.to_access_token(),
411                gateway_token: cache.gateway.to_token_exchange_token(),
412                browser_auth_performed: false,
413            });
414        }
415
416        if let Some(cache) = self.read_cache()?
417            && cache.client_id == self.config.client_id
418            && cache.resource == resource
419        {
420            let maybe_refreshed = if cache.identity.is_valid() {
421                Some(cache.identity.to_access_token())
422            } else if let Some(refresh_token) = &cache.identity.refresh_token {
423                Some(self.refresh_identity_token(refresh_token).await?)
424            } else {
425                None
426            };
427
428            if let Some(identity_token) = maybe_refreshed {
429                let gateway_token = self
430                    .exchange_for_resource(&identity_token.access_token, &resource, None)
431                    .await?;
432                self.write_cache(&identity_token, &gateway_token)?;
433                return Ok(KontextAuthSession {
434                    identity_token,
435                    gateway_token,
436                    browser_auth_performed: false,
437                });
438            }
439        }
440
441        let identity_token = self.authorize_with_browser_pkce().await?;
442        let gateway_token = self
443            .exchange_for_resource(&identity_token.access_token, &resource, None)
444            .await?;
445
446        self.write_cache(&identity_token, &gateway_token)?;
447
448        Ok(KontextAuthSession {
449            identity_token,
450            gateway_token,
451            browser_auth_performed: true,
452        })
453    }
454
455    /// Request a short-lived connect session and build the hosted integration
456    /// connect URL (`.../oauth/connect?session=...`).
457    pub async fn create_integration_connect_url(
458        &self,
459        gateway_access_token: &str,
460    ) -> Result<String, KontextDevError> {
461        let session = self.create_connect_session(gateway_access_token).await?;
462        self.integration_connect_url(&session.session_id)
463    }
464
465    pub fn integration_connect_url(&self, session_id: &str) -> Result<String, KontextDevError> {
466        if session_id.trim().is_empty() {
467            return Err(KontextDevError::ConnectSession {
468                message: "connect session id is empty".to_string(),
469            });
470        }
471
472        let base = if let Some(explicit) = &self.config.integration_ui_url {
473            explicit.trim_end_matches('/').to_string()
474        } else {
475            let server = resolve_server_base_url(&self.config)?;
476            if server.contains("api.kontext.dev") {
477                "https://app.kontext.dev".to_string()
478            } else {
479                server
480            }
481        };
482
483        let mut url = Url::parse(&base).map_err(|source| KontextDevError::InvalidUrl {
484            url: base.clone(),
485            source,
486        })?;
487        url.set_path("/oauth/connect");
488        url.query_pairs_mut().append_pair("session", session_id);
489        Ok(url.to_string())
490    }
491
492    /// Opens the integration connect page in the browser.
493    pub async fn open_integration_connect_page(
494        &self,
495        gateway_access_token: &str,
496    ) -> Result<String, KontextDevError> {
497        let url = self
498            .create_integration_connect_url(gateway_access_token)
499            .await?;
500
501        if webbrowser::open(&url).is_err() {
502            return Err(KontextDevError::BrowserOpenFailed);
503        }
504
505        Ok(url)
506    }
507
508    pub async fn create_connect_session(
509        &self,
510        gateway_access_token: &str,
511    ) -> Result<ConnectSession, KontextDevError> {
512        let url = self.connect_session_url()?;
513        let response = self
514            .http
515            .post(&url)
516            .header("Authorization", format!("Bearer {gateway_access_token}"))
517            .json(&serde_json::json!({}))
518            .send()
519            .await
520            .map_err(|err| KontextDevError::ConnectSession {
521                message: err.to_string(),
522            })?;
523
524        if !response.status().is_success() {
525            let message = build_error_message(response).await;
526            return Err(KontextDevError::ConnectSession { message });
527        }
528
529        response
530            .json::<ConnectSession>()
531            .await
532            .map_err(|err| KontextDevError::ConnectSession {
533                message: err.to_string(),
534            })
535    }
536
537    pub async fn initiate_integration_oauth(
538        &self,
539        gateway_access_token: &str,
540        integration_id: &str,
541        return_to: Option<&str>,
542    ) -> Result<IntegrationOAuthInitResponse, KontextDevError> {
543        let url = resolve_integration_oauth_init_url(&self.config, integration_id)?;
544
545        let payload = return_to
546            .map(|value| serde_json::json!({ "returnTo": value }))
547            .unwrap_or_else(|| serde_json::json!({}));
548
549        let request = self
550            .http
551            .post(&url)
552            .header("Authorization", format!("Bearer {gateway_access_token}"))
553            .json(&payload);
554
555        let response =
556            request
557                .send()
558                .await
559                .map_err(|err| KontextDevError::IntegrationOAuthInit {
560                    message: err.to_string(),
561                })?;
562
563        if !response.status().is_success() {
564            let message = build_error_message(response).await;
565            return Err(KontextDevError::IntegrationOAuthInit { message });
566        }
567
568        response
569            .json::<IntegrationOAuthInitResponse>()
570            .await
571            .map_err(|err| KontextDevError::IntegrationOAuthInit {
572                message: err.to_string(),
573            })
574    }
575
576    pub async fn integration_connection_status(
577        &self,
578        gateway_access_token: &str,
579        integration_id: &str,
580    ) -> Result<IntegrationConnectionStatus, KontextDevError> {
581        let url = resolve_integration_connection_url(&self.config, integration_id)?;
582        let response = self
583            .http
584            .get(url)
585            .header("Authorization", format!("Bearer {gateway_access_token}"))
586            .send()
587            .await
588            .map_err(|err| KontextDevError::IntegrationOAuthInit {
589                message: err.to_string(),
590            })?;
591
592        if !response.status().is_success() {
593            let message = build_error_message(response).await;
594            return Err(KontextDevError::IntegrationOAuthInit { message });
595        }
596
597        response
598            .json::<IntegrationConnectionStatus>()
599            .await
600            .map_err(|err| KontextDevError::IntegrationOAuthInit {
601                message: err.to_string(),
602            })
603    }
604
605    pub async fn wait_for_integration_connection(
606        &self,
607        gateway_access_token: &str,
608        integration_id: &str,
609        timeout_ms: u64,
610        interval_ms: u64,
611    ) -> Result<bool, KontextDevError> {
612        let started = now_unix_ms();
613
614        loop {
615            let status = self
616                .integration_connection_status(gateway_access_token, integration_id)
617                .await?;
618            if status.connected {
619                return Ok(true);
620            }
621
622            if now_unix_ms().saturating_sub(started) >= timeout_ms {
623                return Ok(false);
624            }
625
626            tokio::time::sleep(Duration::from_millis(interval_ms)).await;
627        }
628    }
629
630    async fn authorize_with_browser_pkce(&self) -> Result<AccessToken, KontextDevError> {
631        let auth_url = self.resolve_authorization_url().await?;
632        let token_url = self.token_url()?;
633        let pkce = generate_pkce_pair();
634        let state = generate_state();
635
636        let (callback_url, callback_payload) = self.listen_for_callback().await?;
637
638        let mut url = Url::parse(&auth_url).map_err(|source| KontextDevError::InvalidUrl {
639            url: auth_url.clone(),
640            source,
641        })?;
642
643        {
644            let mut query = url.query_pairs_mut();
645            query
646                .append_pair("client_id", &self.config.client_id)
647                .append_pair("response_type", "code")
648                .append_pair("redirect_uri", &callback_url)
649                .append_pair("state", &state);
650
651            if let Some(scope) = normalized_scope(&self.config.scope) {
652                query.append_pair("scope", scope);
653            }
654
655            query
656                .append_pair("code_challenge_method", "S256")
657                .append_pair("code_challenge", &pkce.challenge);
658        }
659
660        if webbrowser::open(url.as_str()).is_err() {
661            return Err(KontextDevError::BrowserOpenFailed);
662        }
663
664        let payload = callback_payload.await?;
665
666        if let Some(error) = payload.error {
667            let with_details = payload
668                .error_description
669                .map(|description| format!("{error}: {description}"))
670                .unwrap_or(error);
671            return Err(KontextDevError::OAuthCallbackError {
672                error: with_details,
673            });
674        }
675
676        if payload.state.as_deref() != Some(state.as_str()) {
677            return Err(KontextDevError::InvalidOAuthState);
678        }
679
680        let code = payload
681            .code
682            .ok_or(KontextDevError::MissingAuthorizationCode)?;
683
684        let mut body = vec![
685            ("grant_type", "authorization_code".to_string()),
686            ("code", code),
687            ("redirect_uri", callback_url),
688            ("client_id", self.config.client_id.clone()),
689            ("code_verifier", pkce.verifier),
690        ];
691
692        if let Some(client_secret) = &self.config.client_secret {
693            body.push(("client_secret", client_secret.clone()));
694        }
695
696        post_token(&self.http, &token_url, None, &body).await
697    }
698
699    async fn resolve_authorization_url(&self) -> Result<String, KontextDevError> {
700        if let Some(discovered) = self.discover_authorization_endpoint().await {
701            return Ok(discovered);
702        }
703
704        let authorize_url = self.authorize_url()?;
705        if !self.endpoint_is_missing(&authorize_url).await {
706            return Ok(authorize_url);
707        }
708
709        let server_base = resolve_server_base_url(&self.config)?;
710        Ok(format!("{}/oauth2/auth", server_base.trim_end_matches('/')))
711    }
712
713    async fn discover_authorization_endpoint(&self) -> Option<String> {
714        let base = resolve_server_base_url(&self.config).ok()?;
715        let base = base.trim_end_matches('/');
716
717        let candidates = [
718            format!("{base}/.well-known/oauth-authorization-server/mcp"),
719            format!("{base}/.well-known/oauth-authorization-server"),
720        ];
721
722        for url in candidates {
723            let response = match self.http.get(&url).send().await {
724                Ok(response) => response,
725                Err(_) => continue,
726            };
727
728            if !response.status().is_success() {
729                continue;
730            }
731
732            let metadata = match response.json::<OAuthAuthorizationServerMetadata>().await {
733                Ok(metadata) => metadata,
734                Err(_) => continue,
735            };
736
737            let Some(endpoint) = metadata.authorization_endpoint else {
738                continue;
739            };
740
741            if Url::parse(&endpoint).is_ok() {
742                return Some(endpoint);
743            }
744        }
745
746        None
747    }
748
749    async fn endpoint_is_missing(&self, url: &str) -> bool {
750        let probe_client = match reqwest::Client::builder()
751            .redirect(reqwest::redirect::Policy::none())
752            .build()
753        {
754            Ok(client) => client,
755            Err(_) => return false,
756        };
757
758        match probe_client.get(url).send().await {
759            Ok(response) => response.status() == StatusCode::NOT_FOUND,
760            Err(_) => false,
761        }
762    }
763
764    async fn refresh_identity_token(
765        &self,
766        refresh_token: &str,
767    ) -> Result<AccessToken, KontextDevError> {
768        let token_url = self.token_url()?;
769        let mut body = vec![
770            ("grant_type", "refresh_token".to_string()),
771            ("refresh_token", refresh_token.to_string()),
772            ("client_id", self.config.client_id.clone()),
773        ];
774
775        if let Some(client_secret) = &self.config.client_secret {
776            body.push(("client_secret", client_secret.clone()));
777        }
778
779        post_token(&self.http, &token_url, None, &body).await
780    }
781
782    pub async fn exchange_for_resource(
783        &self,
784        subject_token: &str,
785        resource: &str,
786        scope: Option<&str>,
787    ) -> Result<TokenExchangeToken, KontextDevError> {
788        let token_url = self.token_url()?;
789
790        let mut body = vec![
791            ("grant_type", TOKEN_EXCHANGE_GRANT_TYPE.to_string()),
792            ("subject_token", subject_token.to_string()),
793            ("subject_token_type", TOKEN_TYPE_ACCESS_TOKEN.to_string()),
794            ("resource", resource.to_string()),
795        ];
796
797        if let Some(scope) = scope {
798            body.push(("scope", scope.to_string()));
799        }
800
801        let auth_header = self.config.client_secret.as_ref().map(|secret| {
802            let raw = format!("{}:{}", self.config.client_id, secret);
803            format!(
804                "Basic {}",
805                base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
806            )
807        });
808
809        if self.config.client_secret.is_none() {
810            body.push(("client_id", self.config.client_id.clone()));
811        }
812
813        let response = post_form_with_optional_auth::<TokenExchangeToken>(
814            &self.http,
815            &token_url,
816            auth_header,
817            &body,
818        )
819        .await
820        .map_err(|message| KontextDevError::TokenExchange {
821            resource: resource.to_string(),
822            message,
823        })?;
824
825        if response.access_token.is_empty() {
826            return Err(KontextDevError::EmptyAccessToken);
827        }
828
829        Ok(response)
830    }
831
832    async fn listen_for_callback(
833        &self,
834    ) -> Result<
835        (
836            String,
837            impl std::future::Future<Output = Result<OAuthCallbackPayload, KontextDevError>>,
838        ),
839        KontextDevError,
840    > {
841        let redirect_uri = self.config.redirect_uri.trim().to_string();
842        let parsed = Url::parse(&redirect_uri).map_err(|source| KontextDevError::InvalidUrl {
843            url: redirect_uri.clone(),
844            source,
845        })?;
846
847        if parsed.scheme() != "http" {
848            return Err(KontextDevError::OAuthCallbackError {
849                error: "redirect_uri must use http".to_string(),
850            });
851        }
852
853        if parsed.query().is_some() || parsed.fragment().is_some() {
854            return Err(KontextDevError::OAuthCallbackError {
855                error: "redirect_uri must not include query parameters or fragments".to_string(),
856            });
857        }
858
859        let host = parsed
860            .host_str()
861            .ok_or_else(|| KontextDevError::OAuthCallbackError {
862                error: "redirect_uri host is missing".to_string(),
863            })?;
864
865        let port = parsed
866            .port()
867            .ok_or_else(|| KontextDevError::OAuthCallbackError {
868                error: "redirect_uri must include an explicit port".to_string(),
869            })?;
870
871        let callback_path = parsed.path().to_string();
872        let bind_addr = if host.contains(':') {
873            format!("[{host}]:{port}")
874        } else {
875            format!("{host}:{port}")
876        };
877
878        let server = Arc::new(Server::http(&bind_addr).map_err(|err| {
879            KontextDevError::OAuthCallbackError {
880                error: format!("failed to start callback server at {bind_addr}: {err}"),
881            }
882        })?);
883
884        let (tx, rx) = oneshot::channel();
885        let _join = spawn_callback_server(server.clone(), callback_path, tx);
886        let _guard = CallbackServerGuard { server };
887        let timeout_seconds = self.config.auth_timeout_seconds.max(1);
888
889        let fut = async move {
890            let payload = timeout(Duration::from_secs(timeout_seconds as u64), rx)
891                .await
892                .map_err(|_| KontextDevError::OAuthCallbackTimeout { timeout_seconds })?
893                .map_err(|_| KontextDevError::OAuthCallbackCancelled)?;
894            drop(_guard);
895            Ok(payload)
896        };
897
898        Ok((redirect_uri, fut))
899    }
900
901    fn token_cache_path(&self) -> Option<PathBuf> {
902        if let Some(explicit) = &self.config.token_cache_path {
903            return Some(PathBuf::from(explicit));
904        }
905
906        let home = dirs::home_dir()?;
907        let mut path = home;
908        path.push(".kontext-dev");
909        path.push("tokens");
910
911        let sanitized_client_id: String = self
912            .config
913            .client_id
914            .chars()
915            .map(|ch| {
916                if ch.is_ascii_alphanumeric() || ch == '-' || ch == '_' {
917                    ch
918                } else {
919                    '_'
920                }
921            })
922            .collect();
923
924        path.push(format!("{sanitized_client_id}.json"));
925        Some(path)
926    }
927
928    fn read_cache(&self) -> Result<Option<TokenCacheFile>, KontextDevError> {
929        let Some(path) = self.token_cache_path() else {
930            return Ok(None);
931        };
932
933        let raw = match fs::read_to_string(&path) {
934            Ok(raw) => raw,
935            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(None),
936            Err(source) => {
937                return Err(KontextDevError::TokenCacheRead {
938                    path: path.display().to_string(),
939                    source,
940                });
941            }
942        };
943
944        serde_json::from_str(&raw).map(Some).map_err(|source| {
945            KontextDevError::TokenCacheDeserialize {
946                path: path.display().to_string(),
947                source,
948            }
949        })
950    }
951
952    fn write_cache(
953        &self,
954        identity: &AccessToken,
955        gateway: &TokenExchangeToken,
956    ) -> Result<(), KontextDevError> {
957        let Some(path) = self.token_cache_path() else {
958            return Ok(());
959        };
960
961        if let Some(parent) = path.parent() {
962            fs::create_dir_all(parent).map_err(|source| KontextDevError::TokenCacheWrite {
963                path: parent.display().to_string(),
964                source,
965            })?;
966        }
967
968        let payload = TokenCacheFile {
969            client_id: self.config.client_id.clone(),
970            resource: self.config.resource.clone(),
971            identity: CachedAccessToken::from_access_token(identity)?,
972            gateway: CachedAccessToken::from_token_exchange(gateway)?,
973        };
974
975        let serialized = serde_json::to_string_pretty(&payload)
976            .map_err(|source| KontextDevError::TokenCacheSerialize { source })?;
977
978        fs::write(&path, serialized).map_err(|source| KontextDevError::TokenCacheWrite {
979            path: path.display().to_string(),
980            source,
981        })
982    }
983}
984
985/// Legacy helper that uses the client-credentials grant.
986///
987/// New PKCE-based apps should use `KontextDevClient::authenticate_mcp`.
988pub async fn request_access_token(
989    config: &KontextDevConfig,
990) -> Result<AccessToken, KontextDevError> {
991    let token_url = resolve_token_url(config)?;
992
993    let mut body = vec![("grant_type", "client_credentials".to_string())];
994
995    if let Some(scope) = normalized_scope(&config.scope) {
996        body.push(("scope", scope.to_string()));
997    }
998
999    let auth_header = if let Some(secret) = &config.client_secret {
1000        let raw = format!("{}:{}", config.client_id, secret);
1001        Some(format!(
1002            "Basic {}",
1003            base64::engine::general_purpose::STANDARD.encode(raw.as_bytes())
1004        ))
1005    } else {
1006        body.push(("client_id", config.client_id.clone()));
1007        None
1008    };
1009
1010    post_token(&Client::new(), &token_url, auth_header.as_deref(), &body).await
1011}
1012
1013async fn post_token(
1014    http: &Client,
1015    token_url: &str,
1016    authorization: Option<&str>,
1017    body: &[(impl AsRef<str>, String)],
1018) -> Result<AccessToken, KontextDevError> {
1019    let body_vec: Vec<(&str, String)> = body.iter().map(|(k, v)| (k.as_ref(), v.clone())).collect();
1020
1021    let response = post_form_with_optional_auth::<AccessToken>(
1022        http,
1023        token_url,
1024        authorization.map(ToString::to_string),
1025        &body_vec,
1026    )
1027    .await
1028    .map_err(|message| KontextDevError::TokenRequest {
1029        token_url: token_url.to_string(),
1030        message,
1031    })?;
1032
1033    Ok(response)
1034}
1035
1036async fn post_form_with_optional_auth<T>(
1037    http: &Client,
1038    url: &str,
1039    authorization: Option<String>,
1040    body: &[(&str, String)],
1041) -> Result<T, String>
1042where
1043    T: DeserializeOwned,
1044{
1045    let mut request = http
1046        .post(url)
1047        .header("Content-Type", "application/x-www-form-urlencoded");
1048
1049    if let Some(header) = authorization {
1050        request = request.header("Authorization", header);
1051    }
1052
1053    let form = body
1054        .iter()
1055        .map(|(k, v)| (k.to_string(), v.to_string()))
1056        .collect::<Vec<(String, String)>>();
1057
1058    let response = request
1059        .form(&form)
1060        .send()
1061        .await
1062        .map_err(|err| err.to_string())?;
1063
1064    if !response.status().is_success() {
1065        return Err(build_error_message(response).await);
1066    }
1067
1068    response.json::<T>().await.map_err(|err| err.to_string())
1069}
1070
1071async fn build_error_message(response: reqwest::Response) -> String {
1072    let status = response.status();
1073    let fallback = format!(
1074        "{} {}",
1075        status.as_u16(),
1076        status.canonical_reason().unwrap_or("")
1077    );
1078
1079    let body = response.text().await.unwrap_or_default();
1080    if body.is_empty() {
1081        return fallback.trim().to_string();
1082    }
1083
1084    if let Ok(parsed) = serde_json::from_str::<OAuthErrorBody>(&body) {
1085        if let Some(description) = parsed.error_description {
1086            return description;
1087        }
1088        if let Some(message) = parsed.message {
1089            return message;
1090        }
1091        if let Some(error) = parsed.error {
1092            return error;
1093        }
1094    }
1095
1096    format!("{fallback}: {body}")
1097}
1098
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline config: local server, hosted integration UI, public client (no secret),
    /// default scope/resource/server name/timeout, and a localhost OAuth callback.
    fn config() -> KontextDevConfig {
        KontextDevConfig {
            server: Some("http://localhost:4000".to_string()),
            mcp_url: None,
            token_url: None,
            client_id: "client_123".to_string(),
            client_secret: None,
            scope: DEFAULT_SCOPE.to_string(),
            server_name: DEFAULT_SERVER_NAME.to_string(),
            resource: DEFAULT_RESOURCE.to_string(),
            integration_ui_url: Some("https://app.kontext.dev".to_string()),
            integration_return_to: None,
            open_connect_page_on_login: true,
            auth_timeout_seconds: DEFAULT_AUTH_TIMEOUT_SECONDS,
            token_cache_path: None,
            redirect_uri: "http://localhost:3333/callback".to_string(),
        }
    }

    /// Encodes one JWT segment: base64url without padding.
    fn jwt_segment(raw: &str) -> String {
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(raw)
    }

    #[test]
    fn create_connect_url_uses_hosted_ui() {
        let kontext = KontextDevClient::new(config());
        let connect_url = kontext
            .integration_connect_url("session-123")
            .expect("url should be built");
        assert_eq!(
            connect_url,
            "https://app.kontext.dev/oauth/connect?session=session-123"
        );
    }

    #[test]
    fn jwt_exp_decode_reads_exp() {
        let segments = [
            jwt_segment(r#"{"alg":"none"}"#),
            jwt_segment(r#"{"exp":4070908800}"#),
            "sig".to_string(),
        ];
        let jwt = segments.join(".");
        assert_eq!(decode_jwt_exp(&jwt), Some(4_070_908_800));
    }

    #[test]
    fn oauth_metadata_parses_authorization_endpoint() {
        let raw = r#"{
              "issuer": "https://issuer.example.com",
              "authorization_endpoint": "https://issuer.example.com/oauth2/auth"
            }"#;
        let parsed = serde_json::from_str::<OAuthAuthorizationServerMetadata>(raw)
            .expect("metadata should parse");

        assert_eq!(
            parsed.authorization_endpoint.as_deref(),
            Some("https://issuer.example.com/oauth2/auth")
        );
    }

    #[test]
    fn normalized_scope_omits_blank_values() {
        for blank in ["", "   "] {
            assert_eq!(normalized_scope(blank), None);
        }
    }

    #[test]
    fn normalized_scope_preserves_non_empty_values() {
        let cases = [
            ("mcp:invoke", "mcp:invoke"),
            ("  mcp:invoke openid  ", "mcp:invoke openid"),
        ];
        for (input, expected) in cases {
            assert_eq!(normalized_scope(input), Some(expected));
        }
    }
}