use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use reqwest::{Client, Url};
use serde::{Deserialize, Serialize};
use serde_json::Value;
/// JSON payload POSTed to `/oauth/start` to begin an OAuth flow.
///
/// Empty `scopes` and every `None` field are omitted from the serialized JSON. When
/// deserializing, only `provider` must be present; the remaining fields fall back to
/// empty / `None`.
///
/// `Default` is derived so callers can set just the fields they need and fill the rest
/// with `..Default::default()`. Note that the default `provider` is an empty string, so
/// callers should always set it explicitly.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct OauthStartRequest {
    /// Provider identifier forwarded to the service (e.g. `"microsoft"`).
    pub provider: String,
    /// Scopes to request; omitted from the JSON when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub scopes: Vec<String>,
    /// Optional `resource` value forwarded to the service.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resource: Option<String>,
    /// Optional `prompt` value forwarded to the service.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    /// Optional tenant identifier forwarded to the service.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tenant: Option<String>,
    /// Optional team identifier forwarded to the service.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub team: Option<String>,
    /// Optional user identifier forwarded to the service.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user: Option<String>,
    /// Optional relay context (originating message id / platform).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub relay: Option<OauthRelayContext>,
    /// Arbitrary JSON forwarded as-is.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Value>,
}
/// Relay information attached to an [`OauthStartRequest`].
///
/// Both fields are optional and each is left out of the serialized JSON when `None`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct OauthRelayContext {
    /// Identifier of the originating message, if known
    /// (presumably the chat platform's message id — confirm against callers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_message_id: Option<String>,

    /// Name of the originating platform, e.g. `"teams"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub platform: Option<String>,
}
#[derive(Debug, Clone)]
pub struct OauthClient<T: StartTransport = ReqwestTransport> {
transport: T,
base_url: Url,
}
/// Outcome of [`OauthClient::build_start_url`]: where to send the user, plus the
/// connection name the service reported (if any).
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StartLink {
    /// Absolute URL returned by `/oauth/start`.
    pub url: Url,

    /// Connection name from the service. As produced by `build_start_url`, it is
    /// trimmed and `None` when the service sent nothing or only whitespace.
    pub connection_name: Option<String>,
}
impl<T: StartTransport> OauthClient<T> {
    /// Creates a client that sends `/oauth/start` requests through `transport`.
    ///
    /// `base_url` is the root of the OAuth service; `oauth/start` is resolved beneath it.
    pub fn with_transport(transport: T, base_url: Url) -> Self {
        Self {
            transport,
            base_url,
        }
    }

    /// Asks the OAuth service to start a flow and returns the link to present to the user.
    ///
    /// The request is POSTed to `<base_url>/oauth/start`. A base URL without a trailing
    /// slash is treated as a directory, so `https://host/api` resolves to
    /// `https://host/api/oauth/start` rather than `https://host/oauth/start`.
    ///
    /// A returned `connection_name` is trimmed; blank values become `None`.
    ///
    /// # Errors
    ///
    /// Fails if the endpoint cannot be resolved from `base_url`, if the transport
    /// returns an error (propagated unchanged), or if the service's `url` is not an
    /// absolute URL.
    pub async fn build_start_url(&self, request: &OauthStartRequest) -> Result<StartLink> {
        let endpoint = self.start_endpoint()?;
        let response = self.transport.post_start(endpoint, request).await?;
        let url = Url::parse(&response.url).context("oauth/start returned invalid URL")?;
        let connection_name = response
            .connection_name
            .as_deref()
            .map(str::trim)
            .filter(|value| !value.is_empty())
            .map(str::to_owned);
        Ok(StartLink {
            url,
            connection_name,
        })
    }

    /// Resolves `oauth/start` relative to `base_url`, treating the base path as a
    /// directory even when it lacks a trailing slash.
    fn start_endpoint(&self) -> Result<Url> {
        let mut base = self.base_url.clone();
        if !base.path().ends_with('/') {
            // `Url::join` replaces the last path segment of a base without a trailing
            // slash, which would silently drop a path prefix such as `/api`.
            let directory = format!("{}/", base.path());
            base.set_path(&directory);
        }
        base.join("oauth/start")
            .context("failed to resolve /oauth/start")
    }
}
impl OauthClient<ReqwestTransport> {
    /// Creates a client that sends requests with the given `reqwest` [`Client`].
    ///
    /// Equivalent to `OauthClient::with_transport(ReqwestTransport::new(http), base_url)`.
    pub fn new(http: Client, base_url: Url) -> Self {
        let transport = ReqwestTransport::new(http);
        Self::with_transport(transport, base_url)
    }
}
/// Sends the `/oauth/start` request on behalf of an [`OauthClient`].
///
/// Abstracting the network call lets tests, or alternative HTTP stacks, stand in for
/// the default [`ReqwestTransport`].
#[async_trait]
pub trait StartTransport: Send + Sync {
    /// POSTs `payload` to `url` and returns the decoded response.
    ///
    /// Any error returned here is propagated unchanged by
    /// [`OauthClient::build_start_url`].
    async fn post_start(&self, url: Url, payload: &OauthStartRequest) -> Result<StartResponse>;
}
/// [`StartTransport`] implementation that performs the request with `reqwest`.
///
/// Derives `Debug` (as `reqwest::Client` does) so that the derived `Debug` on
/// `OauthClient` also applies to the default `OauthClient<ReqwestTransport>`.
#[derive(Debug, Clone)]
pub struct ReqwestTransport {
    /// HTTP client used to send the `/oauth/start` request.
    http: Client,
}
impl ReqwestTransport {
pub fn new(http: Client) -> Self {
Self { http }
}
}
#[async_trait]
impl StartTransport for ReqwestTransport {
    /// POSTs `payload` as JSON to `url` and decodes a successful body as a
    /// [`StartResponse`].
    ///
    /// # Errors
    ///
    /// Fails when the request cannot be sent, when the response status is not a
    /// success (the message includes the status and, if readable, the body), or when
    /// the body is not a valid `StartResponse`.
    async fn post_start(&self, url: Url, payload: &OauthStartRequest) -> Result<StartResponse> {
        let reply = self
            .http
            .post(url)
            .json(payload)
            .send()
            .await
            .context("failed to call oauth/start")?;

        let status = reply.status();
        if status.is_success() {
            return reply
                .json::<StartResponse>()
                .await
                .context("invalid oauth/start response body");
        }

        // Best effort: the body only enriches the error message, so a failure to
        // read it must not mask the original status error.
        let detail = match reply.text().await {
            Ok(text) => text,
            Err(_) => String::from("<unavailable>"),
        };
        bail!("oauth/start returned {status}: {detail}");
    }
}
/// Body returned by a successful `/oauth/start` call.
///
/// The fields stay private, but [`StartResponse::new`] lets `StartTransport`
/// implementations outside this module build a response without going through serde.
#[derive(Debug, Deserialize, Clone)]
pub struct StartResponse {
    /// Link returned by the service; `OauthClient` parses it as an absolute URL.
    url: String,
    /// Optional connection name; a missing JSON field deserializes to `None`.
    #[serde(default)]
    connection_name: Option<String>,
}

impl StartResponse {
    /// Builds a response from its raw parts (no validation is performed here).
    pub fn new(url: impl Into<String>, connection_name: Option<String>) -> Self {
        Self {
            url: url.into(),
            connection_name,
        }
    }

    /// The raw `url` value as returned by the service (not yet validated).
    pub fn url(&self) -> &str {
        &self.url
    }

    /// The raw `connection_name` value, untrimmed, if the service returned one.
    pub fn connection_name(&self) -> Option<&str> {
        self.connection_name.as_deref()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use serde_json::json;
    use std::sync::{Arc, Mutex};

    /// Builds a request with only the required `provider` field populated.
    fn minimal_request(provider: &str) -> OauthStartRequest {
        OauthStartRequest {
            provider: provider.into(),
            scopes: Vec::new(),
            resource: None,
            prompt: None,
            tenant: None,
            team: None,
            user: None,
            relay: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn client_posts_payload_and_returns_url() {
        let transport = MockTransport::new(
            "https://oauth.example/oauth/start",
            Ok(StartResponse {
                url: "https://oauth.example/start/abc123".into(),
                connection_name: Some("m365".into()),
            }),
        );
        let client = OauthClient::with_transport(
            transport.clone(),
            Url::parse("https://oauth.example/").unwrap(),
        );
        let request = OauthStartRequest {
            provider: "microsoft".into(),
            scopes: vec!["User.Read".into()],
            resource: Some("https://graph.microsoft.com".into()),
            prompt: Some("consent".into()),
            tenant: Some("acme".into()),
            team: Some("support".into()),
            user: Some("user-1".into()),
            relay: Some(OauthRelayContext {
                provider_message_id: Some("abc123".into()),
                platform: Some("teams".into()),
            }),
            metadata: Some(json!({"variant":"beta"})),
        };
        let link = client.build_start_url(&request).await.expect("start url");
        assert_eq!(link.url.as_str(), "https://oauth.example/start/abc123");
        assert_eq!(link.connection_name.as_deref(), Some("m365"));

        // The transport must receive the caller's payload unchanged.
        let captured = transport.captured();
        assert_eq!(captured.len(), 1);
        assert_eq!(captured[0].provider, "microsoft");
        assert_eq!(captured[0].scopes, vec!["User.Read".to_string()]);
        assert_eq!(captured[0].tenant.as_deref(), Some("acme"));
    }

    #[tokio::test]
    async fn client_surface_errors_are_returned() {
        let transport = MockTransport::new(
            "https://oauth.example/oauth/start",
            Err("missing provider".into()),
        );
        let client = OauthClient::with_transport(
            transport.clone(),
            Url::parse("https://oauth.example/").unwrap(),
        );
        let request = minimal_request("microsoft");
        let err = client.build_start_url(&request).await.unwrap_err();
        assert!(
            err.to_string().contains("missing provider"),
            "unexpected error: {err}"
        );
        assert_eq!(transport.captured_count(), 1);
    }

    #[tokio::test]
    async fn blank_connection_name_is_dropped() {
        let transport = MockTransport::new(
            "https://oauth.example/oauth/start",
            Ok(StartResponse {
                url: "https://oauth.example/start/xyz".into(),
                connection_name: Some("   ".into()),
            }),
        );
        let client =
            OauthClient::with_transport(transport, Url::parse("https://oauth.example/").unwrap());
        let link = client
            .build_start_url(&minimal_request("microsoft"))
            .await
            .expect("start url");
        assert_eq!(link.connection_name, None);
    }

    #[tokio::test]
    async fn invalid_url_in_response_is_an_error() {
        let transport = MockTransport::new(
            "https://oauth.example/oauth/start",
            Ok(StartResponse {
                url: "not a url".into(),
                connection_name: None,
            }),
        );
        let client =
            OauthClient::with_transport(transport, Url::parse("https://oauth.example/").unwrap());
        let err = client
            .build_start_url(&minimal_request("microsoft"))
            .await
            .unwrap_err();
        assert!(
            err.to_string().contains("invalid URL"),
            "unexpected error: {err}"
        );
    }

    #[test]
    fn optional_fields_are_omitted_when_serialized() {
        let value = serde_json::to_value(minimal_request("microsoft")).unwrap();
        assert_eq!(value, json!({"provider": "microsoft"}));
    }

    /// Test double that checks the endpoint, records payloads and replays a canned reply.
    #[derive(Clone)]
    struct MockTransport {
        expected_url: String,
        // Only read after construction, so sharing it needs no lock.
        response: Arc<Result<StartResponse, String>>,
        captured: Arc<Mutex<Vec<OauthStartRequest>>>,
    }

    impl MockTransport {
        fn new(expected_url: &str, response: Result<StartResponse, String>) -> Self {
            Self {
                expected_url: expected_url.into(),
                response: Arc::new(response),
                captured: Arc::new(Mutex::new(Vec::new())),
            }
        }

        fn captured_count(&self) -> usize {
            self.captured.lock().unwrap().len()
        }

        fn captured(&self) -> Vec<OauthStartRequest> {
            self.captured.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl StartTransport for MockTransport {
        async fn post_start(&self, url: Url, payload: &OauthStartRequest) -> Result<StartResponse> {
            assert_eq!(url.as_str(), self.expected_url);
            self.captured.lock().unwrap().push(payload.clone());
            (*self.response).clone().map_err(|err| anyhow!(err))
        }
    }
}