gsm_core/platforms/webchat/
oauth.rs1use anyhow::{Context, anyhow};
2use async_trait::async_trait;
3use axum::{
4 extract::{Query, State},
5 response::{Html, IntoResponse, Redirect, Response},
6};
7use greentic_types::TenantCtx;
8use gsm_telemetry::{MessageContext, TelemetryLabels, record_auth_card_clicked};
9use metrics::counter;
10use reqwest::{Client, Url};
11use serde::Deserialize;
12use serde_json::{Value, json};
13use tracing::warn;
14
15#[cfg(feature = "directline_standalone")]
16use super::conversation::{Activity, ChannelAccount, StoreError};
17use super::{config::OAuthProviderConfig, error::WebChatError, http::AppState, telemetry};
18
19pub fn contains_oauth_card(activity: &Value) -> bool {
20 activity
21 .get("attachments")
22 .and_then(Value::as_array)
23 .map(|attachments| {
24 attachments.iter().any(|attachment| {
25 attachment
26 .get("contentType")
27 .or_else(|| attachment.get("content_type"))
28 .and_then(Value::as_str)
29 .map(|ct| {
30 ct.to_ascii_lowercase()
31 .starts_with("application/vnd.microsoft.card.oauth")
32 })
33 .unwrap_or(false)
34 })
35 })
36 .unwrap_or(false)
37}
38
39#[derive(Debug, Deserialize)]
40pub struct StartQuery {
41 #[serde(rename = "conversationId")]
42 pub conversation_id: String,
43 #[serde(default)]
44 pub state: Option<String>,
45}
46
47#[derive(Debug, Deserialize)]
48pub struct CallbackQuery {
49 #[serde(rename = "conversationId")]
50 pub conversation_id: String,
51 #[serde(default)]
52 pub code: Option<String>,
53 #[serde(default)]
54 pub state: Option<String>,
55 #[serde(default)]
56 pub error: Option<String>,
57}
58
59pub async fn start(
60 State(state): State<AppState>,
61 Query(query): Query<StartQuery>,
62) -> Result<impl IntoResponse, OAuthRouteError> {
63 let session = state
64 .sessions
65 .get(&query.conversation_id)
66 .await
67 .map_err(OAuthRouteError::Storage)?
68 .ok_or(OAuthRouteError::ConversationNotFound)?;
69
70 let oauth_config = state
71 .provider
72 .oauth_config(&session.tenant_ctx)
73 .await
74 .map_err(|err| OAuthRouteError::Resolve(WebChatError::Internal(err)))?
75 .ok_or(OAuthRouteError::NotConfigured)?;
76
77 let provider_label = Url::parse(oauth_config.issuer.as_str())
78 .ok()
79 .and_then(|url| url.host_str().map(|host| host.to_string()))
80 .unwrap_or_else(|| oauth_config.issuer.clone());
81 let labels = TelemetryLabels {
82 tenant: session.tenant_ctx.tenant.as_ref().to_string(),
83 platform: Some("bf_webchat".into()),
84 chat_id: Some(query.conversation_id.clone()),
85 msg_id: None,
86 extra: Vec::new(),
87 };
88 let ctx = MessageContext::new(labels);
89 let team = session.tenant_ctx.team.as_ref().map(|team| team.as_ref());
90 record_auth_card_clicked(&ctx, provider_label.as_str(), "bf_webchat", None, team);
91
92 let redirect_uri = build_redirect_uri(&oauth_config, &query.conversation_id)?;
93 let authorize_url = build_authorize_url(&oauth_config, &redirect_uri, query.state.as_deref())?;
94
95 let (env_label, tenant_label, team_label) = telemetry::tenant_labels(&session.tenant_ctx);
96 let env_metric = env_label.to_string();
97 let tenant_metric = tenant_label.to_string();
98 let team_metric = team_label.to_string();
99 counter!(
100 "webchat_oauth_started_total",
101 "env" => env_metric.clone(),
102 "tenant" => tenant_metric.clone(),
103 "team" => team_metric.clone()
104 )
105 .increment(1);
106
107 Ok(Redirect::temporary(authorize_url.as_str()))
108}
109
110pub async fn callback(
111 State(state): State<AppState>,
112 Query(query): Query<CallbackQuery>,
113) -> Result<impl IntoResponse, OAuthRouteError> {
114 if let Some(error) = &query.error {
115 warn!(reason = error.as_str(), "oauth callback returned error");
116 return Ok(Html(CLOSE_WINDOW_HTML));
117 }
118
119 let code = query
120 .code
121 .as_deref()
122 .ok_or(OAuthRouteError::BadRequest("missing code"))?;
123
124 let session = state
125 .sessions
126 .get(&query.conversation_id)
127 .await
128 .map_err(OAuthRouteError::Storage)?
129 .ok_or(OAuthRouteError::ConversationNotFound)?;
130
131 let oauth_config = state
132 .provider
133 .oauth_config(&session.tenant_ctx)
134 .await
135 .map_err(|err| OAuthRouteError::Resolve(WebChatError::Internal(err)))?
136 .ok_or(OAuthRouteError::NotConfigured)?;
137 let redirect_uri = build_redirect_uri(&oauth_config, &query.conversation_id)?;
138 let token_handle = state
139 .oauth_client
140 .exchange_code(&session.tenant_ctx, &oauth_config, code, &redirect_uri)
141 .await
142 .map_err(OAuthRouteError::Exchange)?;
143
144 #[cfg(feature = "directline_standalone")]
145 {
146 let mut activity = Activity::new("message");
147 activity.text = Some("You're signed in.".to_string());
148 activity.from = Some(ChannelAccount {
149 id: "bot".into(),
150 name: None,
151 role: Some("bot".into()),
152 });
153 activity.channel_data = Some(json!({
154 "oauth_token_handle": token_handle,
155 }));
156 let append_result = state
157 .conversations
158 .append(&session.conversation_id, activity.clone())
159 .await;
160 let stored = match append_result {
161 Ok(stored) => stored,
162 Err(StoreError::NotFound(_)) => {
163 state
164 .conversations
165 .create(&session.conversation_id, session.tenant_ctx.clone())
166 .await
167 .map_err(|err| OAuthRouteError::Resume(WebChatError::Internal(err.into())))?;
168 state
169 .conversations
170 .append(&session.conversation_id, activity)
171 .await
172 .map_err(|err| OAuthRouteError::Resume(WebChatError::Internal(err.into())))?
173 }
174 Err(StoreError::QuotaExceeded(_)) => {
175 return Err(OAuthRouteError::Resume(WebChatError::BadRequest(
176 "conversation backlog quota exceeded",
177 )));
178 }
179 Err(err) => return Err(OAuthRouteError::Resume(WebChatError::Internal(err.into()))),
180 };
181
182 if let Err(err) = state
183 .sessions
184 .update_watermark(
185 &session.conversation_id,
186 Some((stored.watermark + 1).to_string()),
187 )
188 .await
189 {
190 warn!(error = %err, "failed to update watermark after oauth");
191 }
192 }
193 #[cfg(not(feature = "directline_standalone"))]
194 {
195 let activity = json!({
196 "type": "event",
197 "name": "oauth.token",
198 "channelData": {
199 "oauth_token_handle": token_handle,
200 }
201 });
202
203 state
204 .post_activity(
205 &session.conversation_id,
206 session.bearer_token.as_str(),
207 activity,
208 )
209 .await
210 .map_err(OAuthRouteError::Resume)?;
211 }
212
213 let (env_label, tenant_label, team_label) = telemetry::tenant_labels(&session.tenant_ctx);
214 let env_metric = env_label.to_string();
215 let tenant_metric = tenant_label.to_string();
216 let team_metric = team_label.to_string();
217 counter!(
218 "webchat_oauth_completed_total",
219 "env" => env_metric.clone(),
220 "tenant" => tenant_metric.clone(),
221 "team" => team_metric.clone()
222 )
223 .increment(1);
224
225 Ok(Html(CLOSE_WINDOW_HTML))
226}
227
228fn build_redirect_uri(
229 config: &OAuthProviderConfig,
230 conversation_id: &str,
231) -> Result<String, OAuthRouteError> {
232 let mut redirect = reqwest::Url::parse(&format!(
233 "{}/webchat/oauth/callback",
234 config.redirect_base.trim_end_matches('/')
235 ))
236 .map_err(|err| OAuthRouteError::Url(err.into()))?;
237 redirect
238 .query_pairs_mut()
239 .append_pair("conversationId", conversation_id);
240 Ok(redirect.into())
241}
242
243fn build_authorize_url(
244 config: &OAuthProviderConfig,
245 redirect_uri: &str,
246 state: Option<&str>,
247) -> Result<reqwest::Url, OAuthRouteError> {
248 let mut url = reqwest::Url::parse(&format!(
249 "{}/authorize",
250 config.issuer.trim_end_matches('/')
251 ))
252 .map_err(|err| OAuthRouteError::Url(err.into()))?;
253 {
254 let mut pairs = url.query_pairs_mut();
255 pairs.append_pair("client_id", config.client_id.as_str());
256 pairs.append_pair("response_type", "code");
257 pairs.append_pair("redirect_uri", redirect_uri);
258 if let Some(state) = state {
259 pairs.append_pair("state", state);
260 }
261 }
262 Ok(url)
263}
264
265#[derive(Debug)]
266pub enum OAuthRouteError {
267 BadRequest(&'static str),
268 ConversationNotFound,
269 NotConfigured,
270 Url(anyhow::Error),
271 Storage(anyhow::Error),
272 Exchange(anyhow::Error),
273 Resolve(WebChatError),
274 Resume(WebChatError),
275}
276
277impl OAuthRouteError {
278 fn as_status(&self) -> axum::http::StatusCode {
279 match self {
280 OAuthRouteError::BadRequest(_) => axum::http::StatusCode::BAD_REQUEST,
281 OAuthRouteError::ConversationNotFound => axum::http::StatusCode::NOT_FOUND,
282 OAuthRouteError::NotConfigured => axum::http::StatusCode::NOT_FOUND,
283 OAuthRouteError::Url(_)
284 | OAuthRouteError::Storage(_)
285 | OAuthRouteError::Exchange(_)
286 | OAuthRouteError::Resolve(_) => axum::http::StatusCode::INTERNAL_SERVER_ERROR,
287 OAuthRouteError::Resume(error) => error.status(),
288 }
289 }
290}
291
292impl IntoResponse for OAuthRouteError {
293 fn into_response(self) -> Response {
294 match self {
295 OAuthRouteError::Resume(err) | OAuthRouteError::Resolve(err) => err.into_response(),
296 OAuthRouteError::BadRequest(message) => {
297 (self.as_status(), Html(message)).into_response()
298 }
299 OAuthRouteError::ConversationNotFound | OAuthRouteError::NotConfigured => {
300 (self.as_status(), Html("not found")).into_response()
301 }
302 OAuthRouteError::Url(_)
303 | OAuthRouteError::Storage(_)
304 | OAuthRouteError::Exchange(_) => {
305 (self.as_status(), Html("internal error")).into_response()
306 }
307 }
308 }
309}
310
311impl From<WebChatError> for OAuthRouteError {
312 fn from(value: WebChatError) -> Self {
313 OAuthRouteError::Resume(value)
314 }
315}
316
/// Minimal HTML page rendered by `callback` when it finishes. It is shown both after a
/// successful sign-in and when the provider reported an error.
pub const CLOSE_WINDOW_HTML: &str = concat!(
    "<!DOCTYPE html>",
    "<html><body>You can close this window.</body></html>"
);
319
320#[async_trait]
321pub trait GreenticOauthClient: Send + Sync {
322 async fn exchange_code(
323 &self,
324 tenant_ctx: &TenantCtx,
325 config: &OAuthProviderConfig,
326 code: &str,
327 redirect_uri: &str,
328 ) -> Result<String, anyhow::Error>;
329}
330
331pub struct ReqwestGreenticOauthClient {
332 client: Client,
333}
334
335impl ReqwestGreenticOauthClient {
336 pub fn new(client: Client) -> Self {
337 Self { client }
338 }
339}
340
341#[derive(Debug, Deserialize)]
342struct TokenExchangeResponse {
343 token_handle: String,
344}
345
346#[async_trait]
347impl GreenticOauthClient for ReqwestGreenticOauthClient {
348 async fn exchange_code(
349 &self,
350 _tenant_ctx: &TenantCtx,
351 config: &OAuthProviderConfig,
352 code: &str,
353 redirect_uri: &str,
354 ) -> Result<String, anyhow::Error> {
355 let token_url = format!("{}/token", config.issuer.trim_end_matches('/'));
356 let response = self
357 .client
358 .post(token_url)
359 .form(&[
360 ("grant_type", "authorization_code"),
361 ("code", code),
362 ("client_id", config.client_id.as_str()),
363 ("redirect_uri", redirect_uri),
364 ])
365 .send()
366 .await
367 .context("oauth token request failed")?;
368
369 if !response.status().is_success() {
370 let status = response.status();
371 let body = response
372 .text()
373 .await
374 .unwrap_or_else(|_| "<unreadable>".to_string());
375 return Err(anyhow!("oauth exchange failed ({status}): {body}"));
376 }
377
378 let body = response
379 .json::<TokenExchangeResponse>()
380 .await
381 .context("oauth token decode failed")?;
382
383 if body.token_handle.trim().is_empty() {
384 return Err(anyhow!("oauth token handle missing in response"));
385 }
386
387 Ok(body.token_handle)
388 }
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;

    #[test]
    fn detects_oauth_card() {
        let activity = json!({
            "type": "message",
            "attachments": [
                {"contentType": "application/vnd.microsoft.card.oauth"}
            ]
        });
        assert!(contains_oauth_card(&activity));
    }

    #[test]
    fn ignores_non_oauth_card() {
        let activity = json!({
            "type": "message",
            "attachments": [
                {"contentType": "application/vnd.microsoft.card.adaptive"}
            ]
        });
        assert!(!contains_oauth_card(&activity));
    }

    #[test]
    fn detects_oauth_card_case_insensitively_with_parameters() {
        let activity = json!({
            "type": "message",
            "attachments": [
                {"contentType": "Application/VND.Microsoft.Card.OAuth; charset=utf-8"}
            ]
        });
        assert!(contains_oauth_card(&activity));
    }

    #[test]
    fn falls_back_to_snake_case_content_type_key() {
        let activity = json!({
            "attachments": [
                {"content_type": "application/vnd.microsoft.card.oauth"}
            ]
        });
        assert!(contains_oauth_card(&activity));
    }

    #[test]
    fn rejects_missing_or_malformed_attachments() {
        assert!(!contains_oauth_card(&json!({"type": "message"})));
        assert!(!contains_oauth_card(&json!({"attachments": {}})));
        assert!(!contains_oauth_card(&json!({"attachments": []})));
        assert!(!contains_oauth_card(&json!({"attachments": [{"contentType": 42}]})));
        assert!(!contains_oauth_card(
            &json!({"attachments": [{"contentType": "application/vnd"}]})
        ));
    }

    #[test]
    fn route_errors_map_to_expected_statuses() {
        assert_eq!(
            OAuthRouteError::BadRequest("missing code")
                .into_response()
                .status(),
            StatusCode::BAD_REQUEST
        );
        assert_eq!(
            OAuthRouteError::ConversationNotFound.into_response().status(),
            StatusCode::NOT_FOUND
        );
        assert_eq!(
            OAuthRouteError::NotConfigured.into_response().status(),
            StatusCode::NOT_FOUND
        );
        assert_eq!(
            OAuthRouteError::Storage(anyhow::anyhow!("boom"))
                .into_response()
                .status(),
            StatusCode::INTERNAL_SERVER_ERROR
        );
    }
}