use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use turbomcp_protocol::{Error as McpError, Result as McpResult};
/// Client metadata sent to an OAuth 2.0 Dynamic Client Registration endpoint.
///
/// Field names mirror the client metadata names of RFC 7591 §2;
/// `application_type` comes from OpenID Connect Dynamic Client Registration.
/// Every field is optional, and unset (`None`) fields are omitted from the
/// serialized JSON rather than sent as `null`.
///
/// `Default` yields a request with every field unset, which allows
/// `RegistrationRequest { client_name: Some(..), ..Default::default() }`.
/// [`DcrBuilder`] offers preset starting points.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct RegistrationRequest {
    /// Redirection URIs the client will use in redirect-based flows.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub redirect_uris: Option<Vec<String>>,
    /// How the client authenticates at the token endpoint
    /// (e.g. `"client_secret_basic"`, or `"none"` for public clients).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_method: Option<String>,
    /// OAuth 2.0 grant types the client may use.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grant_types: Option<Vec<String>>,
    /// OAuth 2.0 response types the client may use.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_types: Option<Vec<String>>,
    /// Human-readable client name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,
    /// URL of the client's home page.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_uri: Option<String>,
    /// URL of the client's logo.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logo_uri: Option<String>,
    /// Space-separated list of requested scope values.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
    /// Contact addresses for people responsible for the client.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub contacts: Option<Vec<String>>,
    /// URL of the client's terms of service.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tos_uri: Option<String>,
    /// URL of the client's privacy policy.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub policy_uri: Option<String>,
    /// Identifier of the client software (stable across instances).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub software_id: Option<String>,
    /// Version of the client software.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub software_version: Option<String>,
    /// URL of the client's JSON Web Key Set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<String>,
    /// OpenID Connect application type (`"web"` or `"native"`).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub application_type: Option<String>,
}
/// Authorization server response to a client registration (RFC 7591 §3.2.1)
/// or a registration update.
///
/// Response members that are not modelled as explicit fields (for example
/// client metadata echoed back by the server) are collected into
/// [`metadata`](Self::metadata).
///
/// `Debug` is implemented by hand so that `client_secret` and
/// `registration_access_token` are redacted instead of leaking into logs.
#[derive(Clone, Deserialize, Serialize)]
pub struct RegistrationResponse {
    /// Identifier issued to the client.
    pub client_id: String,
    /// Client secret, if the server issued one.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret: Option<String>,
    /// Expiry of `client_secret` in seconds since the Unix epoch; per
    /// RFC 7591, `0` means the secret does not expire.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_secret_expires_at: Option<u64>,
    /// Bearer token for managing this registration (RFC 7592), if issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub registration_access_token: Option<String>,
    /// URL of this registration's management endpoint (RFC 7592), if issued.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub registration_client_uri: Option<String>,
    /// Time `client_id` was issued, in seconds since the Unix epoch.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_id_issued_at: Option<u64>,
    /// All remaining top-level response members, keyed by member name.
    #[serde(flatten)]
    pub metadata: HashMap<String, serde_json::Value>,
}

impl std::fmt::Debug for RegistrationResponse {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Presence of a credential is still visible; its value never is.
        const REDACTED: &str = "<redacted>";
        f.debug_struct("RegistrationResponse")
            .field("client_id", &self.client_id)
            .field(
                "client_secret",
                &self.client_secret.as_ref().map(|_| REDACTED),
            )
            .field("client_secret_expires_at", &self.client_secret_expires_at)
            .field(
                "registration_access_token",
                &self.registration_access_token.as_ref().map(|_| REDACTED),
            )
            .field("registration_client_uri", &self.registration_client_uri)
            .field("client_id_issued_at", &self.client_id_issued_at)
            .field("metadata", &self.metadata)
            .finish()
    }
}
/// HTTP client for an OAuth 2.0 Dynamic Client Registration endpoint.
///
/// `Debug` is implemented by hand so the initial access token is redacted
/// instead of being written to logs.
#[derive(Clone)]
pub struct DcrClient {
    /// Registration endpoint URL; `register` POSTs to it.
    endpoint: String,
    /// Optional bearer token attached to `register` requests.
    initial_access_token: Option<String>,
    /// HTTP client used for every request this type sends.
    http_client: reqwest::Client,
}

impl std::fmt::Debug for DcrClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("DcrClient")
            .field("endpoint", &self.endpoint)
            .field(
                "initial_access_token",
                &self.initial_access_token.as_ref().map(|_| "<redacted>"),
            )
            .field("http_client", &self.http_client)
            .finish()
    }
}
impl DcrClient {
    /// Creates a client for the registration endpoint at `endpoint`.
    ///
    /// `initial_access_token`, when present, is sent as a bearer token on
    /// [`register`](Self::register) calls only.
    pub fn new(endpoint: String, initial_access_token: Option<String>) -> Self {
        Self {
            endpoint,
            initial_access_token,
            http_client: reqwest::Client::new(),
        }
    }

    /// Registers a new client by POSTing `request` as JSON to the endpoint.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the request cannot be sent, if the server
    /// answers with a non-2xx status (the status and response body are
    /// included in the message), or if the body is not a valid
    /// [`RegistrationResponse`].
    pub async fn register(&self, request: RegistrationRequest) -> McpResult<RegistrationResponse> {
        let mut req = self.http_client.post(&self.endpoint).json(&request);
        if let Some(token) = &self.initial_access_token {
            req = req.bearer_auth(token);
        }
        let response = req
            .send()
            .await
            .map_err(|e| McpError::internal(format!("Registration request failed: {}", e)))?;
        let response = Self::ensure_success(response, "Registration").await?;
        response.json::<RegistrationResponse>().await.map_err(|e| {
            McpError::internal(format!("Failed to parse registration response: {}", e))
        })
    }

    /// Replaces a registration by PUTting `request` as JSON to
    /// `registration_uri`, authenticated with `access_token` as a bearer token.
    ///
    /// Presumably called with the `registration_client_uri` and
    /// `registration_access_token` of an earlier [`RegistrationResponse`].
    ///
    /// NOTE(review): RFC 7592 §2.2 requires the update body to contain the
    /// client's `client_id`. `RegistrationRequest` has no such field, so
    /// strict servers may reject this call — confirm against target servers.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the request cannot be sent, if the server
    /// answers with a non-2xx status, or if the body cannot be parsed.
    pub async fn update(
        &self,
        registration_uri: &str,
        access_token: &str,
        request: RegistrationRequest,
    ) -> McpResult<RegistrationResponse> {
        let response = self
            .http_client
            .put(registration_uri)
            .bearer_auth(access_token)
            .json(&request)
            .send()
            .await
            .map_err(|e| McpError::internal(format!("Update request failed: {}", e)))?;
        let response = Self::ensure_success(response, "Update").await?;
        response
            .json::<RegistrationResponse>()
            .await
            .map_err(|e| McpError::internal(format!("Failed to parse update response: {}", e)))
    }

    /// Deletes a registration by sending DELETE to `registration_uri`,
    /// authenticated with `access_token` as a bearer token.
    ///
    /// Any 2xx status counts as success; the response body is ignored.
    ///
    /// # Errors
    ///
    /// Returns an internal error if the request cannot be sent or the server
    /// answers with a non-2xx status.
    pub async fn delete(&self, registration_uri: &str, access_token: &str) -> McpResult<()> {
        let response = self
            .http_client
            .delete(registration_uri)
            .bearer_auth(access_token)
            .send()
            .await
            .map_err(|e| McpError::internal(format!("Delete request failed: {}", e)))?;
        Self::ensure_success(response, "Delete").await?;
        Ok(())
    }

    /// Passes `response` through when its status is 2xx. Otherwise returns
    /// `"{action} failed with {status}: {body}"` as an internal error.
    async fn ensure_success(
        response: reqwest::Response,
        action: &str,
    ) -> McpResult<reqwest::Response> {
        let status = response.status();
        if status.is_success() {
            return Ok(response);
        }
        // Reading the body is best-effort: a failure here must not hide the
        // HTTP status error, so fall back to an empty body.
        let body = response.text().await.unwrap_or_default();
        Err(McpError::internal(format!(
            "{} failed with {}: {}",
            action, status, body
        )))
    }
}
/// Fluent builder for a [`RegistrationRequest`].
///
/// Start from [`DcrBuilder::mcp_client`] or [`DcrBuilder::native_client`],
/// chain `with_*` setters, and finish with [`DcrBuilder::build`].
#[derive(Debug, Clone)]
#[must_use = "a DcrBuilder does nothing until `build()` is called"]
pub struct DcrBuilder {
    /// The request being assembled; `build()` returns it unchanged.
    request: RegistrationRequest,
}
impl DcrBuilder {
    /// Starts a request for a confidential web client.
    ///
    /// The preset request:
    /// - uses the `authorization_code` and `refresh_token` grants, response type
    ///   `code`, and `client_secret_basic` token-endpoint authentication;
    /// - sets `client_name` and a single redirect URI from the arguments;
    /// - sets `software_id` to `"turbomcp"` and `software_version` to this
    ///   crate's version.
    pub fn mcp_client(client_name: &str, redirect_uri: &str) -> Self {
        Self {
            request: RegistrationRequest {
                client_name: Some(client_name.to_string()),
                redirect_uris: Some(vec![redirect_uri.to_string()]),
                grant_types: Some(vec![
                    "authorization_code".to_string(),
                    "refresh_token".to_string(),
                ]),
                response_types: Some(vec!["code".to_string()]),
                token_endpoint_auth_method: Some("client_secret_basic".to_string()),
                application_type: Some("web".to_string()),
                software_id: Some("turbomcp".to_string()),
                software_version: Some(env!("CARGO_PKG_VERSION").to_string()),
                scope: None,
                client_uri: None,
                logo_uri: None,
                contacts: None,
                tos_uri: None,
                policy_uri: None,
                jwks_uri: None,
            },
        }
    }

    /// Starts a request for a native (public) client.
    ///
    /// Same preset as [`mcp_client`](Self::mcp_client), except that
    /// `application_type` is `"native"` and token-endpoint authentication is
    /// `"none"`, so no client secret is used.
    pub fn native_client(client_name: &str, redirect_uri: &str) -> Self {
        let mut builder = Self::mcp_client(client_name, redirect_uri);
        builder.request.application_type = Some("native".to_string());
        builder.request.token_endpoint_auth_method = Some("none".to_string());
        builder
    }

    /// Sets the requested scopes, joined into one space-separated `scope` string.
    ///
    /// Empty entries are skipped. If no non-empty entry remains, `scope` is
    /// cleared (`None`) rather than set to `""`: RFC 6749 §3.3 requires at
    /// least one scope token, so an empty string would be an invalid value.
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        let joined = scopes
            .iter()
            .map(String::as_str)
            .filter(|scope| !scope.is_empty())
            .collect::<Vec<_>>()
            .join(" ");
        self.request.scope = if joined.is_empty() { None } else { Some(joined) };
        self
    }

    /// Sets the client's home page URL (`client_uri`).
    pub fn with_client_uri(mut self, uri: String) -> Self {
        self.request.client_uri = Some(uri);
        self
    }

    /// Sets the client's logo URL (`logo_uri`).
    pub fn with_logo_uri(mut self, uri: String) -> Self {
        self.request.logo_uri = Some(uri);
        self
    }

    /// Sets the contact list (`contacts`).
    pub fn with_contacts(mut self, contacts: Vec<String>) -> Self {
        self.request.contacts = Some(contacts);
        self
    }

    /// Sets the terms-of-service URL (`tos_uri`).
    pub fn with_tos_uri(mut self, uri: String) -> Self {
        self.request.tos_uri = Some(uri);
        self
    }

    /// Sets the privacy-policy URL (`policy_uri`).
    pub fn with_policy_uri(mut self, uri: String) -> Self {
        self.request.policy_uri = Some(uri);
        self
    }

    /// Sets the JSON Web Key Set URL (`jwks_uri`).
    pub fn with_jwks_uri(mut self, uri: String) -> Self {
        self.request.jwks_uri = Some(uri);
        self
    }

    /// Replaces the redirect URIs, including the one given to the constructor.
    pub fn with_redirect_uris(mut self, uris: Vec<String>) -> Self {
        self.request.redirect_uris = Some(uris);
        self
    }

    /// Consumes the builder and returns the assembled request.
    pub fn build(self) -> RegistrationRequest {
        self.request
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dcr_builder_mcp_client() {
        let request = DcrBuilder::mcp_client("My MCP Client", "http://localhost:3000/callback")
            .with_scopes(vec!["mcp:tools".to_string()])
            .build();
        assert_eq!(request.client_name, Some("My MCP Client".to_string()));
        assert_eq!(
            request.redirect_uris,
            Some(vec!["http://localhost:3000/callback".to_string()])
        );
        assert_eq!(request.scope, Some("mcp:tools".to_string()));
        assert_eq!(request.software_id, Some("turbomcp".to_string()));
        assert_eq!(request.application_type, Some("web".to_string()));
        assert_eq!(
            request.token_endpoint_auth_method,
            Some("client_secret_basic".to_string())
        );
    }

    #[test]
    fn test_dcr_builder_native_client() {
        let request = DcrBuilder::native_client("My App", "myapp://callback").build();
        assert_eq!(request.application_type, Some("native".to_string()));
        assert_eq!(request.token_endpoint_auth_method, Some("none".to_string()));
    }

    #[test]
    fn test_dcr_builder_with_redirect_uris_replaces_initial_uri() {
        let uris = vec![
            "https://a.example/cb".to_string(),
            "https://b.example/cb".to_string(),
        ];
        let request = DcrBuilder::mcp_client("Client", "https://initial.example/cb")
            .with_redirect_uris(uris.clone())
            .build();
        assert_eq!(request.redirect_uris, Some(uris));
    }

    #[test]
    fn test_registration_request_omits_unset_fields() {
        let request = DcrBuilder::native_client("My App", "myapp://callback").build();
        let value = serde_json::to_value(&request).unwrap();
        let object = value
            .as_object()
            .expect("a registration request serializes to a JSON object");
        assert_eq!(object["client_name"], "My App");
        assert_eq!(object["token_endpoint_auth_method"], "none");
        // Unset optional metadata must be omitted, not sent as `null`.
        assert!(!object.contains_key("scope"));
        assert!(!object.contains_key("logo_uri"));
        assert!(!object.contains_key("jwks_uri"));
    }

    #[test]
    fn test_registration_response_deserialization() {
        let json = r#"{
            "client_id": "s6BhdRkqt3",
            "client_secret": "cf136dc3c1fc93f31185e5885805d",
            "client_secret_expires_at": 1577858400,
            "registration_access_token": "this.is.an.access.token.value.ffx83",
            "registration_client_uri": "https://server.example.com/register/s6BhdRkqt3",
            "client_id_issued_at": 1571158400
        }"#;
        let response: RegistrationResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.client_id, "s6BhdRkqt3");
        assert_eq!(
            response.client_secret,
            Some("cf136dc3c1fc93f31185e5885805d".to_string())
        );
        assert_eq!(response.client_secret_expires_at, Some(1577858400));
        assert_eq!(
            response.registration_access_token,
            Some("this.is.an.access.token.value.ffx83".to_string())
        );
        assert_eq!(
            response.registration_client_uri,
            Some("https://server.example.com/register/s6BhdRkqt3".to_string())
        );
        assert_eq!(response.client_id_issued_at, Some(1571158400));
        // Every member is modelled explicitly, so nothing spills into metadata.
        assert!(response.metadata.is_empty());
    }

    #[test]
    fn test_registration_response_collects_extra_metadata() {
        let json = r#"{
            "client_id": "abc",
            "redirect_uris": ["https://client.example.org/callback"],
            "token_endpoint_auth_method": "none"
        }"#;
        let response: RegistrationResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.client_id, "abc");
        assert!(response.client_secret.is_none());
        assert_eq!(response.metadata["token_endpoint_auth_method"], "none");
        assert!(response.metadata.contains_key("redirect_uris"));
        // Explicit fields are not duplicated into the flattened map.
        assert!(!response.metadata.contains_key("client_id"));
    }

    #[test]
    fn test_dcr_client_creation() {
        let client = DcrClient::new(
            "https://auth.example.com/register".to_string(),
            Some("initial_token".to_string()),
        );
        assert_eq!(client.endpoint, "https://auth.example.com/register");
        assert_eq!(client.initial_access_token.as_deref(), Some("initial_token"));
    }
}