gsm_core/platforms/webchat/
oauth.rs

1use anyhow::{Context, anyhow};
2use async_trait::async_trait;
3use axum::{
4    extract::{Query, State},
5    response::{Html, IntoResponse, Redirect, Response},
6};
7use greentic_types::TenantCtx;
8use gsm_telemetry::{MessageContext, TelemetryLabels, record_auth_card_clicked};
9use metrics::counter;
10use reqwest::{Client, Url};
11use serde::Deserialize;
12use serde_json::{Value, json};
13use tracing::warn;
14
15#[cfg(feature = "directline_standalone")]
16use super::conversation::{Activity, ChannelAccount, StoreError};
17use super::{config::OAuthProviderConfig, error::WebChatError, http::AppState, telemetry};
18
19pub fn contains_oauth_card(activity: &Value) -> bool {
20    activity
21        .get("attachments")
22        .and_then(Value::as_array)
23        .map(|attachments| {
24            attachments.iter().any(|attachment| {
25                attachment
26                    .get("contentType")
27                    .or_else(|| attachment.get("content_type"))
28                    .and_then(Value::as_str)
29                    .map(|ct| {
30                        ct.to_ascii_lowercase()
31                            .starts_with("application/vnd.microsoft.card.oauth")
32                    })
33                    .unwrap_or(false)
34            })
35        })
36        .unwrap_or(false)
37}
38
39#[derive(Debug, Deserialize)]
40pub struct StartQuery {
41    #[serde(rename = "conversationId")]
42    pub conversation_id: String,
43    #[serde(default)]
44    pub state: Option<String>,
45}
46
47#[derive(Debug, Deserialize)]
48pub struct CallbackQuery {
49    #[serde(rename = "conversationId")]
50    pub conversation_id: String,
51    #[serde(default)]
52    pub code: Option<String>,
53    #[serde(default)]
54    pub state: Option<String>,
55    #[serde(default)]
56    pub error: Option<String>,
57}
58
59pub async fn start(
60    State(state): State<AppState>,
61    Query(query): Query<StartQuery>,
62) -> Result<impl IntoResponse, OAuthRouteError> {
63    let session = state
64        .sessions
65        .get(&query.conversation_id)
66        .await
67        .map_err(OAuthRouteError::Storage)?
68        .ok_or(OAuthRouteError::ConversationNotFound)?;
69
70    let oauth_config = state
71        .provider
72        .oauth_config(&session.tenant_ctx)
73        .await
74        .map_err(|err| OAuthRouteError::Resolve(WebChatError::Internal(err)))?
75        .ok_or(OAuthRouteError::NotConfigured)?;
76
77    let provider_label = Url::parse(oauth_config.issuer.as_str())
78        .ok()
79        .and_then(|url| url.host_str().map(|host| host.to_string()))
80        .unwrap_or_else(|| oauth_config.issuer.clone());
81    let labels = TelemetryLabels {
82        tenant: session.tenant_ctx.tenant.as_ref().to_string(),
83        platform: Some("bf_webchat".into()),
84        chat_id: Some(query.conversation_id.clone()),
85        msg_id: None,
86        extra: Vec::new(),
87    };
88    let ctx = MessageContext::new(labels);
89    let team = session.tenant_ctx.team.as_ref().map(|team| team.as_ref());
90    record_auth_card_clicked(&ctx, provider_label.as_str(), "bf_webchat", None, team);
91
92    let redirect_uri = build_redirect_uri(&oauth_config, &query.conversation_id)?;
93    let authorize_url = build_authorize_url(&oauth_config, &redirect_uri, query.state.as_deref())?;
94
95    let (env_label, tenant_label, team_label) = telemetry::tenant_labels(&session.tenant_ctx);
96    let env_metric = env_label.to_string();
97    let tenant_metric = tenant_label.to_string();
98    let team_metric = team_label.to_string();
99    counter!(
100        "webchat_oauth_started_total",
101        "env" => env_metric.clone(),
102        "tenant" => tenant_metric.clone(),
103        "team" => team_metric.clone()
104    )
105    .increment(1);
106
107    Ok(Redirect::temporary(authorize_url.as_str()))
108}
109
110pub async fn callback(
111    State(state): State<AppState>,
112    Query(query): Query<CallbackQuery>,
113) -> Result<impl IntoResponse, OAuthRouteError> {
114    if let Some(error) = &query.error {
115        warn!(reason = error.as_str(), "oauth callback returned error");
116        return Ok(Html(CLOSE_WINDOW_HTML));
117    }
118
119    let code = query
120        .code
121        .as_deref()
122        .ok_or(OAuthRouteError::BadRequest("missing code"))?;
123
124    let session = state
125        .sessions
126        .get(&query.conversation_id)
127        .await
128        .map_err(OAuthRouteError::Storage)?
129        .ok_or(OAuthRouteError::ConversationNotFound)?;
130
131    let oauth_config = state
132        .provider
133        .oauth_config(&session.tenant_ctx)
134        .await
135        .map_err(|err| OAuthRouteError::Resolve(WebChatError::Internal(err)))?
136        .ok_or(OAuthRouteError::NotConfigured)?;
137    let redirect_uri = build_redirect_uri(&oauth_config, &query.conversation_id)?;
138    let token_handle = state
139        .oauth_client
140        .exchange_code(&session.tenant_ctx, &oauth_config, code, &redirect_uri)
141        .await
142        .map_err(OAuthRouteError::Exchange)?;
143
144    #[cfg(feature = "directline_standalone")]
145    {
146        let mut activity = Activity::new("message");
147        activity.text = Some("You're signed in.".to_string());
148        activity.from = Some(ChannelAccount {
149            id: "bot".into(),
150            name: None,
151            role: Some("bot".into()),
152        });
153        activity.channel_data = Some(json!({
154            "oauth_token_handle": token_handle,
155        }));
156        let append_result = state
157            .conversations
158            .append(&session.conversation_id, activity.clone())
159            .await;
160        let stored = match append_result {
161            Ok(stored) => stored,
162            Err(StoreError::NotFound(_)) => {
163                state
164                    .conversations
165                    .create(&session.conversation_id, session.tenant_ctx.clone())
166                    .await
167                    .map_err(|err| OAuthRouteError::Resume(WebChatError::Internal(err.into())))?;
168                state
169                    .conversations
170                    .append(&session.conversation_id, activity)
171                    .await
172                    .map_err(|err| OAuthRouteError::Resume(WebChatError::Internal(err.into())))?
173            }
174            Err(StoreError::QuotaExceeded(_)) => {
175                return Err(OAuthRouteError::Resume(WebChatError::BadRequest(
176                    "conversation backlog quota exceeded",
177                )));
178            }
179            Err(err) => return Err(OAuthRouteError::Resume(WebChatError::Internal(err.into()))),
180        };
181
182        if let Err(err) = state
183            .sessions
184            .update_watermark(
185                &session.conversation_id,
186                Some((stored.watermark + 1).to_string()),
187            )
188            .await
189        {
190            warn!(error = %err, "failed to update watermark after oauth");
191        }
192    }
193    #[cfg(not(feature = "directline_standalone"))]
194    {
195        let activity = json!({
196            "type": "event",
197            "name": "oauth.token",
198            "channelData": {
199                "oauth_token_handle": token_handle,
200            }
201        });
202
203        state
204            .post_activity(
205                &session.conversation_id,
206                session.bearer_token.as_str(),
207                activity,
208            )
209            .await
210            .map_err(OAuthRouteError::Resume)?;
211    }
212
213    let (env_label, tenant_label, team_label) = telemetry::tenant_labels(&session.tenant_ctx);
214    let env_metric = env_label.to_string();
215    let tenant_metric = tenant_label.to_string();
216    let team_metric = team_label.to_string();
217    counter!(
218        "webchat_oauth_completed_total",
219        "env" => env_metric.clone(),
220        "tenant" => tenant_metric.clone(),
221        "team" => team_metric.clone()
222    )
223    .increment(1);
224
225    Ok(Html(CLOSE_WINDOW_HTML))
226}
227
228fn build_redirect_uri(
229    config: &OAuthProviderConfig,
230    conversation_id: &str,
231) -> Result<String, OAuthRouteError> {
232    let mut redirect = reqwest::Url::parse(&format!(
233        "{}/webchat/oauth/callback",
234        config.redirect_base.trim_end_matches('/')
235    ))
236    .map_err(|err| OAuthRouteError::Url(err.into()))?;
237    redirect
238        .query_pairs_mut()
239        .append_pair("conversationId", conversation_id);
240    Ok(redirect.into())
241}
242
243fn build_authorize_url(
244    config: &OAuthProviderConfig,
245    redirect_uri: &str,
246    state: Option<&str>,
247) -> Result<reqwest::Url, OAuthRouteError> {
248    let mut url = reqwest::Url::parse(&format!(
249        "{}/authorize",
250        config.issuer.trim_end_matches('/')
251    ))
252    .map_err(|err| OAuthRouteError::Url(err.into()))?;
253    {
254        let mut pairs = url.query_pairs_mut();
255        pairs.append_pair("client_id", config.client_id.as_str());
256        pairs.append_pair("response_type", "code");
257        pairs.append_pair("redirect_uri", redirect_uri);
258        if let Some(state) = state {
259            pairs.append_pair("state", state);
260        }
261    }
262    Ok(url)
263}
264
265#[derive(Debug)]
266pub enum OAuthRouteError {
267    BadRequest(&'static str),
268    ConversationNotFound,
269    NotConfigured,
270    Url(anyhow::Error),
271    Storage(anyhow::Error),
272    Exchange(anyhow::Error),
273    Resolve(WebChatError),
274    Resume(WebChatError),
275}
276
277impl OAuthRouteError {
278    fn as_status(&self) -> axum::http::StatusCode {
279        match self {
280            OAuthRouteError::BadRequest(_) => axum::http::StatusCode::BAD_REQUEST,
281            OAuthRouteError::ConversationNotFound => axum::http::StatusCode::NOT_FOUND,
282            OAuthRouteError::NotConfigured => axum::http::StatusCode::NOT_FOUND,
283            OAuthRouteError::Url(_)
284            | OAuthRouteError::Storage(_)
285            | OAuthRouteError::Exchange(_)
286            | OAuthRouteError::Resolve(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
287            OAuthRouteError::Resume(error) => error.status(),
288        }
289    }
290}
291
292impl IntoResponse for OAuthRouteError {
293    fn into_response(self) -> Response {
294        match self {
295            OAuthRouteError::Resume(err) | OAuthRouteError::Resolve(err) => err.into_response(),
296            OAuthRouteError::BadRequest(message) => {
297                (self.as_status(), Html(message)).into_response()
298            }
299            OAuthRouteError::ConversationNotFound | OAuthRouteError::NotConfigured => {
300                (self.as_status(), Html("not found")).into_response()
301            }
302            OAuthRouteError::Url(_)
303            | OAuthRouteError::Storage(_)
304            | OAuthRouteError::Exchange(_) => {
305                (self.as_status(), Html("internal error")).into_response()
306            }
307        }
308    }
309}
310
311impl From<WebChatError> for OAuthRouteError {
312    fn from(value: WebChatError) -> Self {
313        OAuthRouteError::Resume(value)
314    }
315}
316
/// Minimal HTML page returned at the end of the OAuth callback, including when the provider
/// reports an error, telling the user the sign-in window can be closed.
pub const CLOSE_WINDOW_HTML: &str = concat!(
    "<!DOCTYPE html>",
    "<html><body>You can close this window.</body></html>",
);
319
320#[async_trait]
321pub trait GreenticOauthClient: Send + Sync {
322    async fn exchange_code(
323        &self,
324        tenant_ctx: &TenantCtx,
325        config: &OAuthProviderConfig,
326        code: &str,
327        redirect_uri: &str,
328    ) -> Result<String, anyhow::Error>;
329}
330
331pub struct ReqwestGreenticOauthClient {
332    client: Client,
333}
334
335impl ReqwestGreenticOauthClient {
336    pub fn new(client: Client) -> Self {
337        Self { client }
338    }
339}
340
341#[derive(Debug, Deserialize)]
342struct TokenExchangeResponse {
343    token_handle: String,
344}
345
346#[async_trait]
347impl GreenticOauthClient for ReqwestGreenticOauthClient {
348    async fn exchange_code(
349        &self,
350        _tenant_ctx: &TenantCtx,
351        config: &OAuthProviderConfig,
352        code: &str,
353        redirect_uri: &str,
354    ) -> Result<String, anyhow::Error> {
355        let token_url = format!("{}/token", config.issuer.trim_end_matches('/'));
356        let response = self
357            .client
358            .post(token_url)
359            .form(&[
360                ("grant_type", "authorization_code"),
361                ("code", code),
362                ("client_id", config.client_id.as_str()),
363                ("redirect_uri", redirect_uri),
364            ])
365            .send()
366            .await
367            .context("oauth token request failed")?;
368
369        if !response.status().is_success() {
370            let status = response.status();
371            let body = response
372                .text()
373                .await
374                .unwrap_or_else(|_| "<unreadable>".to_string());
375            return Err(anyhow!("oauth exchange failed ({status}): {body}"));
376        }
377
378        let body = response
379            .json::<TokenExchangeResponse>()
380            .await
381            .context("oauth token decode failed")?;
382
383        if body.token_handle.trim().is_empty() {
384            return Err(anyhow!("oauth token handle missing in response"));
385        }
386
387        Ok(body.token_handle)
388    }
389}
390
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn detects_oauth_card() {
        let activity = json!({
            "type": "message",
            "attachments": [
                {"contentType": "application/vnd.microsoft.card.oauth"}
            ]
        });
        assert!(contains_oauth_card(&activity));
    }

    #[test]
    fn ignores_non_oauth_card() {
        let activity = json!({
            "type": "message",
            "attachments": [
                {"contentType": "application/vnd.microsoft.card.adaptive"}
            ]
        });
        assert!(!contains_oauth_card(&activity));
    }

    #[test]
    fn content_type_match_is_case_insensitive() {
        let activity = json!({
            "type": "message",
            "attachments": [
                {"contentType": "Application/VND.Microsoft.Card.OAuth"}
            ]
        });
        assert!(contains_oauth_card(&activity));
    }

    #[test]
    fn accepts_snake_case_content_type_key() {
        let activity = json!({
            "type": "message",
            "attachments": [
                {"content_type": "application/vnd.microsoft.card.oauth"}
            ]
        });
        assert!(contains_oauth_card(&activity));
    }

    #[test]
    fn detects_oauth_card_among_other_attachments() {
        let activity = json!({
            "type": "message",
            "attachments": [
                {"contentType": "application/vnd.microsoft.card.adaptive"},
                {"contentType": "application/vnd.microsoft.card.oauth"}
            ]
        });
        assert!(contains_oauth_card(&activity));
    }

    #[test]
    fn rejects_activities_without_usable_attachments() {
        assert!(!contains_oauth_card(&json!({"type": "message"})));
        assert!(!contains_oauth_card(&json!({"attachments": "not-an-array"})));
        assert!(!contains_oauth_card(&json!({"attachments": []})));
        assert!(!contains_oauth_card(&json!({"attachments": [{"contentType": 42}]})));
        assert!(!contains_oauth_card(
            &json!({"attachments": [{"contentType": "application/vnd.microsoft.card"}]})
        ));
    }

    #[test]
    fn maps_route_errors_to_statuses() {
        use axum::http::StatusCode;

        assert_eq!(
            OAuthRouteError::BadRequest("missing code").as_status(),
            StatusCode::BAD_REQUEST
        );
        assert_eq!(
            OAuthRouteError::ConversationNotFound.as_status(),
            StatusCode::NOT_FOUND
        );
        assert_eq!(OAuthRouteError::NotConfigured.as_status(), StatusCode::NOT_FOUND);
        assert_eq!(
            OAuthRouteError::Storage(anyhow!("boom")).as_status(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }
}