use anyhow::{anyhow, Context, Result};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::Mutex;
use super::discovery::AuthServerMetadata;
use super::OAuthConfig;
/// OAuth client metadata describing this application.
///
/// The same document is serialized to JSON both as the body served by the local CIMD server
/// and as the request payload for Dynamic Client Registration. Field names follow the OAuth
/// client metadata parameter names (RFC 7591). Field order here determines JSON key order.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ClientMetadataDocument {
    /// Human-readable name of the client.
    pub client_name: String,
    /// Redirect URIs the client may use in authorization requests.
    pub redirect_uris: Vec<String>,
    /// OAuth grant types the client intends to use.
    pub grant_types: Vec<String>,
    /// How the client authenticates at the token endpoint (`"none"` for a public client).
    pub token_endpoint_auth_method: String,
    /// Space-separated scope list; omitted from the serialized JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,
}
/// The subset of a Dynamic Client Registration response that this module deserializes.
///
/// Only `client_id` is required. Each optional field becomes `None` when absent
/// (`#[serde(default)]`). Extra fields in the response are ignored because
/// `deny_unknown_fields` is not set.
#[derive(Debug, Deserialize)]
pub struct DcrRegistrationResponse {
    /// Client identifier issued by the authorization server.
    pub client_id: String,
    /// Client secret, if one was issued; `resolve_client_id` substitutes an empty string when absent.
    #[serde(default)]
    pub client_secret: Option<String>,
    /// Issue time of `client_id` (RFC 7591 defines this as seconds since the Unix epoch).
    #[serde(default)]
    pub client_id_issued_at: Option<u64>,
    /// Expiry of `client_secret` (RFC 7591: seconds since the Unix epoch, `0` = never expires).
    #[serde(default)]
    pub client_secret_expires_at: Option<u64>,
}
lazy_static::lazy_static! {
    /// Process-wide slot holding the shutdown handle of the running CIMD server, if any.
    ///
    /// A `static` is already shared by every task, so the mutex is stored directly rather
    /// than behind an `Arc`. `tokio::sync::Mutex` is used because callers `.await` to lock it.
    static ref CIMD_SERVER: Mutex<Option<CimdServerState>> = Mutex::new(None);
}
/// Handle to the currently running CIMD server task, stored in `CIMD_SERVER`.
// NOTE(review): the reason for this allow is not visible here (the field is read by
// `stop_cimd_server`); possibly it silences warnings in builds where that function is unused —
// confirm before removing.
#[allow(dead_code)]
struct CimdServerState {
    /// Notified by `stop_cimd_server` to make the server's accept loop exit.
    shutdown: Arc<tokio::sync::Notify>,
}
/// Builds the client metadata this application advertises to an authorization server.
///
/// The result is used both as the CIMD document body and as the DCR registration payload.
/// It declares a public client (`token_endpoint_auth_method = "none"`) that uses only the
/// authorization-code grant, with `callback_url` as its single redirect URI. `scopes` are
/// joined with single spaces. An empty slice leaves `scope` as `None`, which omits it from
/// the JSON.
fn build_client_metadata(callback_url: &str, scopes: &[String]) -> ClientMetadataDocument {
    let scope = (!scopes.is_empty()).then(|| scopes.join(" "));
    ClientMetadataDocument {
        client_name: String::from("Octomind"),
        redirect_uris: vec![String::from(callback_url)],
        grant_types: vec![String::from("authorization_code")],
        token_endpoint_auth_method: String::from("none"),
        scope,
    }
}
/// Starts a loopback HTTP server that serves this client's metadata document, and returns
/// the URL to use as the CIMD `client_id`.
///
/// The function:
/// - binds an ephemeral port on `127.0.0.1`,
/// - spawns the accept loop on a background task, and
/// - records the task's shutdown handle in `CIMD_SERVER` so [`stop_cimd_server`] can stop it.
///
/// If an earlier call left a server running, that server is told to shut down before the
/// new one is recorded. Without this, its handle would be overwritten and the old task could
/// never be stopped.
///
/// NOTE(review): the returned URL uses `localhost`, but the listener is bound to IPv4
/// loopback only. A resolver that prefers `::1` may fail to connect. The URL is also
/// reachable only from this machine — confirm that whoever dereferences the `client_id`
/// runs locally.
///
/// # Errors
/// Fails if the listener cannot be bound or its local address cannot be read.
async fn start_cimd_server(callback_url: &str, scopes: &[String]) -> Result<String> {
    let metadata = build_client_metadata(callback_url, scopes);
    let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
        .await
        .context("Failed to bind CIMD server to a loopback port")?;
    let cimd_port = listener
        .local_addr()
        .context("Failed to get CIMD server port")?
        .port();
    let client_id_url = format!(
        "http://localhost:{}/.well-known/oauth-client.json",
        cimd_port
    );
    crate::log_debug!("Starting CIMD server at {} (client_id URL)", client_id_url);

    let shutdown = Arc::new(tokio::sync::Notify::new());
    let task_shutdown = Arc::clone(&shutdown);

    // Hold the lock across the swap so a concurrent start/stop cannot interleave with it.
    let mut state = CIMD_SERVER.lock().await;
    if let Some(previous) = state.take() {
        // `notify_one` stores a permit if the old loop is not currently waiting, so it
        // still observes the shutdown on its next iteration.
        previous.shutdown.notify_one();
        crate::log_debug!("Stopped previous CIMD server before starting a new one");
    }
    // The task takes ownership of the listener and the metadata; no clone is needed.
    tokio::spawn(async move {
        run_cimd_server(&listener, &metadata, task_shutdown).await;
    });
    *state = Some(CimdServerState { shutdown });
    Ok(client_id_url)
}
pub async fn stop_cimd_server() {
let mut state = CIMD_SERVER.lock().await;
if let Some(s) = state.take() {
s.shutdown.notify_one();
crate::log_debug!("CIMD server stopped");
}
}
async fn run_cimd_server(
listener: &tokio::net::TcpListener,
metadata: &ClientMetadataDocument,
shutdown: Arc<tokio::sync::Notify>,
) {
let metadata_json = serde_json::to_string(metadata).unwrap_or_default();
loop {
tokio::select! {
result = listener.accept() => {
match result {
Ok((stream, _addr)) => {
let json = metadata_json.clone();
tokio::spawn(async move {
if let Err(e) = handle_cimd_request(stream, &json).await {
crate::log_debug!("CIMD request error: {}", e);
}
});
}
Err(e) => {
crate::log_debug!("CIMD accept error: {}", e);
}
}
}
_ = shutdown.notified() => {
crate::log_debug!("CIMD server shutting down");
break;
}
}
}
}
/// Serves one HTTP request on a CIMD connection; the connection closes when the stream is dropped.
///
/// Performs a single read of at most 4 KiB and inspects only the request line:
/// - `OPTIONS` (any path): `204 No Content` with CORS headers (preflight).
/// - `GET /.well-known/oauth-client.json` (any query string is ignored): `200 OK` with
///   `metadata_json` as the body, plus CORS headers.
/// - anything else: `404 Not Found`.
///
/// Every response sends `Connection: close`, because the stream is dropped right after the
/// response is written.
///
/// # Errors
/// Returns I/O errors from reading the request or writing the response.
async fn handle_cimd_request(mut stream: tokio::net::TcpStream, metadata_json: &str) -> Result<()> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    const METADATA_PATH: &str = "/.well-known/oauth-client.json";
    // Each header line ends in CRLF. The blank line that terminates the header section is
    // added separately by each response below. Appending another CRLF directly after these
    // headers would end the header section early.
    const CORS_HEADERS: &str = "Access-Control-Allow-Origin: *\r\n\
                                Access-Control-Allow-Methods: GET, OPTIONS\r\n\
                                Access-Control-Allow-Headers: Content-Type\r\n";

    let mut buf = [0u8; 4096];
    let bytes_read = stream.read(&mut buf).await?;
    if bytes_read == 0 {
        return Ok(());
    }
    let request = String::from_utf8_lossy(&buf[..bytes_read]);
    let request_line = match request.lines().next() {
        Some(line) => line,
        None => return Ok(()),
    };

    // Request line is `METHOD SP request-target SP HTTP-version`.
    let mut parts = request_line.split_whitespace();
    let method = parts.next().unwrap_or("");
    let target = parts.next().unwrap_or("");
    let path = target.split('?').next().unwrap_or(target);

    let response = match (method, path) {
        ("OPTIONS", _) => format!(
            "HTTP/1.1 204 No Content\r\n{}Connection: close\r\n\r\n",
            CORS_HEADERS
        ),
        ("GET", METADATA_PATH) => format!(
            "HTTP/1.1 200 OK\r\n\
             Content-Type: application/json\r\n\
             {}\
             Content-Length: {}\r\n\
             Connection: close\r\n\r\n{}",
            CORS_HEADERS,
            metadata_json.len(),
            metadata_json
        ),
        _ => {
            let body = "404 Not Found";
            format!(
                "HTTP/1.1 404 Not Found\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
                body.len(),
                body
            )
        }
    };
    stream.write_all(response.as_bytes()).await?;
    Ok(())
}
/// Registers this client with an authorization server via Dynamic Client Registration.
///
/// POSTs the metadata from [`build_client_metadata`] as JSON to `registration_endpoint`,
/// using a 10-second request timeout.
///
/// # Errors
/// Fails if:
/// - the HTTP client cannot be built,
/// - the request cannot be sent,
/// - the server answers with a non-success status (the response body is included in the
///   error), or
/// - the response body does not parse as a [`DcrRegistrationResponse`].
async fn register_via_dcr(
    registration_endpoint: &str,
    callback_url: &str,
    scopes: &[String],
) -> Result<DcrRegistrationResponse> {
    crate::log_debug!("Registering client via DCR at: {}", registration_endpoint);
    let client_metadata = build_client_metadata(callback_url, scopes);
    let client = reqwest::Client::builder()
        .timeout(std::time::Duration::from_secs(10))
        .build()
        .context("Failed to create HTTP client for DCR")?;
    // `.json()` already sets `Content-Type: application/json`.
    let response = client
        .post(registration_endpoint)
        .header("Accept", "application/json")
        .json(&client_metadata)
        .send()
        .await
        // Lazily formatted: the message is only built if the request actually fails.
        .with_context(|| format!("Failed to register client at {}", registration_endpoint))?;

    let status = response.status();
    if !status.is_success() {
        // Best effort: the status alone is still reported if the body cannot be read.
        let body = response.text().await.unwrap_or_default();
        return Err(anyhow!(
            "DCR registration failed with status {}: {}",
            status,
            body
        ));
    }

    let dcr_response: DcrRegistrationResponse = response
        .json()
        .await
        .context("Failed to parse DCR registration response")?;
    crate::log_debug!(
        "DCR registration successful: client_id={}",
        dcr_response.client_id
    );
    Ok(dcr_response)
}
pub async fn resolve_client_id(
oauth_config: OAuthConfig,
auth_metadata: &AuthServerMetadata,
) -> Result<OAuthConfig> {
if auth_metadata
.client_id_metadata_document_supported
.unwrap_or(false)
{
crate::log_debug!("Auth server supports CIMD, starting local metadata server...");
match start_cimd_server(&oauth_config.callback_url, &oauth_config.scopes).await {
Ok(client_id_url) => {
crate::log_debug!("CIMD server started, using client_id: {}", client_id_url);
return Ok(OAuthConfig {
client_id: client_id_url,
..oauth_config
});
}
Err(e) => {
crate::log_debug!(
"CIMD server failed: {}, falling back to DCR if available",
e
);
}
}
}
if let Some(ref registration_endpoint) = auth_metadata.registration_endpoint {
crate::log_debug!(
"Auth server provides DCR endpoint: {}",
registration_endpoint
);
match register_via_dcr(
registration_endpoint,
&oauth_config.callback_url,
&oauth_config.scopes,
)
.await
{
Ok(dcr_response) => {
return Ok(OAuthConfig {
client_id: dcr_response.client_id,
client_secret: dcr_response.client_secret.unwrap_or_default(),
..oauth_config
});
}
Err(e) => {
return Err(anyhow!(
"DCR registration failed: {}. Auth server at {} provides registration_endpoint but registration failed.",
e,
auth_metadata.issuer
));
}
}
}
Err(anyhow!(
"Cannot resolve OAuth client_id: auth server '{}' does not support CIMD \
(client_id_metadata_document_supported not true) and provides no \
registration_endpoint for DCR. MCP Authorization requires one of these \
mechanisms to obtain a client_id.",
auth_metadata.issuer
))
}
#[cfg(test)]
mod tests {
    use super::*;

    const CALLBACK: &str = "http://localhost:34567/oauth/callback";

    #[test]
    fn test_build_client_metadata() {
        let metadata = build_client_metadata(CALLBACK, &["read".to_string(), "write".to_string()]);
        assert_eq!(metadata.client_name, "Octomind");
        assert_eq!(metadata.redirect_uris, vec![CALLBACK]);
        assert_eq!(metadata.grant_types, vec!["authorization_code"]);
        assert_eq!(metadata.token_endpoint_auth_method, "none");
        assert_eq!(metadata.scope.as_deref(), Some("read write"));
    }

    #[test]
    fn test_build_client_metadata_empty_scopes() {
        let metadata = build_client_metadata(CALLBACK, &[]);
        assert!(metadata.scope.is_none());
    }

    #[test]
    fn test_client_metadata_serialization() {
        let metadata = build_client_metadata(CALLBACK, &["openid".to_string()]);
        let json = serde_json::to_string(&metadata).unwrap();
        assert!(json.contains("Octomind"));
        assert!(json.contains("authorization_code"));
        assert!(json.contains("none"));
    }

    /// `scope` must be absent from the JSON (not `null`) when no scopes are requested.
    #[test]
    fn test_client_metadata_serialization_omits_empty_scope() {
        let metadata = build_client_metadata(CALLBACK, &[]);
        let value = serde_json::to_value(&metadata).unwrap();
        assert!(value.get("scope").is_none());
        assert_eq!(value["token_endpoint_auth_method"], "none");
    }

    /// A minimal DCR response carrying only `client_id` must still deserialize.
    #[test]
    fn test_dcr_response_optional_fields_default_to_none() {
        let response: DcrRegistrationResponse =
            serde_json::from_str(r#"{"client_id":"abc"}"#).unwrap();
        assert_eq!(response.client_id, "abc");
        assert!(response.client_secret.is_none());
        assert!(response.client_id_issued_at.is_none());
        assert!(response.client_secret_expires_at.is_none());
    }
}