use std::env;
use std::net::SocketAddr;
use std::path::PathBuf;
use crate::core::{Event, EventStore, EventType, Result, ShuttleError};
use crate::oauth::{self, OAuthConfig, OAuthStore};
use crate::store::SqliteEventStore;
use axum::extract::{Form, Path as AxumPath, Query, State};
use axum::http::{header, HeaderMap, HeaderValue, StatusCode};
use axum::response::{Html, IntoResponse, Redirect, Response};
use axum::routing::{get, patch, post};
use axum::{Json, Router};
use serde::{Deserialize, Serialize};
use serde_json::json;
use uuid::Uuid;
#[derive(Clone)]
pub struct AppRuntime {
pub store: SqliteEventStore,
pub cwd: PathBuf,
pub workspace_id: String,
pub agent: String,
pub session_id: String,
pub oauth: Option<OAuthRuntime>,
}
#[derive(Clone)]
pub struct OAuthRuntime {
pub config: OAuthConfig,
pub store: OAuthStore,
}
#[derive(Debug, Serialize)]
struct Dashboard {
inbox: Vec<Event>,
tasks: Vec<crate::task::TaskSummary>,
memories: Vec<Event>,
context: crate::context::Context,
}
pub async fn serve(runtime: AppRuntime, addr: SocketAddr) -> Result<()> {
let app = router(runtime);
let listener = tokio::net::TcpListener::bind(addr)
.await
.map_err(|err| ShuttleError::Store(err.to_string()))?;
axum::serve(listener, app)
.await
.map_err(|err| ShuttleError::Store(err.to_string()))
}
pub fn router(runtime: AppRuntime) -> Router {
Router::new()
.route("/", get(index))
.route("/api/dashboard", get(dashboard))
.route("/api/inbox", get(inbox))
.route("/api/tasks", get(tasks))
.route("/api/tasks", post(create_task))
.route("/api/tasks/{id}", patch(update_task))
.route("/api/tasks/{id}/done", post(done_task))
.route("/api/memories", get(memories))
.route("/api/context", get(context))
.route("/api/recall", post(recall))
.route("/api/remember", post(remember))
.route(
"/mcp",
get(mcp_health)
.post(mcp_post)
.delete(mcp_delete)
.options(mcp_options),
)
.route(
"/.well-known/oauth-protected-resource",
get(oauth_protected_resource),
)
.route(
"/.well-known/oauth-protected-resource/mcp",
get(oauth_protected_resource),
)
.route(
"/.well-known/oauth-authorization-server",
get(oauth_authorization_server),
)
.route("/oauth/register", post(oauth_register))
.route(
"/oauth/authorize",
get(oauth_authorize_page).post(oauth_authorize_submit),
)
.route("/oauth/token", post(oauth_token))
.with_state(runtime)
}
async fn index(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
Html(
r#"<!doctype html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Shuttle</title>
<style>
body { font-family: system-ui, sans-serif; margin: 2rem; color: #1f2937; }
main { display: grid; gap: 1rem; grid-template-columns: repeat(auto-fit, minmax(260px, 1fr)); }
section { border: 1px solid #d1d5db; border-radius: 8px; padding: 1rem; }
h1 { margin-top: 0; }
pre { white-space: pre-wrap; overflow-wrap: anywhere; }
</style>
</head>
<body>
<h1>Shuttle</h1>
<main id="dashboard"></main>
<script>
fetch('/api/dashboard').then(r => r.json()).then(data => {
const root = document.getElementById('dashboard');
for (const [name, value] of Object.entries(data)) {
const section = document.createElement('section');
const heading = document.createElement('h2');
heading.textContent = name;
const pre = document.createElement('pre');
pre.textContent = JSON.stringify(value, null, 2);
section.append(heading, pre);
root.append(section);
}
});
</script>
</body>
</html>"#,
)
.into_response()
}
async fn dashboard(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
Json(Dashboard {
inbox: crate::message::inbox(&runtime.store, &runtime.agent)
.await
.unwrap_or_default(),
tasks: crate::task::open_tasks(&runtime.store, &runtime.workspace_id, Some(20))
.await
.unwrap_or_default(),
memories: crate::memory::memories(&runtime.store)
.await
.unwrap_or_default(),
context: crate::context::assemble_context(
&runtime.store,
&runtime.cwd,
&runtime.workspace_id,
&runtime.agent,
)
.await
.unwrap_or_else(|_| crate::context::Context {
repo: runtime.cwd.display().to_string(),
branch: "unknown".to_owned(),
commit: "unknown".to_owned(),
git_remote: None,
dirty: false,
dirty_files: Vec::new(),
open_tasks: Vec::new(),
claimed_tasks: Vec::new(),
recent_decisions: Vec::new(),
related_memories: Vec::new(),
recent_messages: Vec::new(),
pending_handoffs: Vec::new(),
recent_completed_handoffs: Vec::new(),
inbox: Vec::new(),
}),
})
.into_response()
}
async fn inbox(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
Json(
crate::message::inbox(&runtime.store, &runtime.agent)
.await
.unwrap_or_default(),
)
.into_response()
}
async fn tasks(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
Json(
crate::task::open_tasks(&runtime.store, &runtime.workspace_id, Some(20))
.await
.unwrap_or_default(),
)
.into_response()
}
async fn memories(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
Json(
crate::memory::memories(&runtime.store)
.await
.unwrap_or_default(),
)
.into_response()
}
async fn context(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
Json(
crate::context::assemble_context(
&runtime.store,
&runtime.cwd,
&runtime.workspace_id,
&runtime.agent,
)
.await
.ok(),
)
.into_response()
}
#[derive(Deserialize)]
struct RecallRequest {
#[serde(default)]
query: String,
}
async fn recall(
headers: HeaderMap,
State(runtime): State<AppRuntime>,
Json(request): Json<RecallRequest>,
) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
if request.query.trim().is_empty() {
return api_error("query is required");
}
let status = crate::context::repo_status(&runtime.cwd).ok();
let repo_id = status.as_ref().map(crate::context::repo_id);
let branch = status.as_ref().map(|status| status.branch.as_str());
Json(
crate::memory::ranked_recall(
&runtime.store,
&request.query,
None,
Some(&runtime.workspace_id),
repo_id.as_deref(),
branch,
)
.await
.unwrap_or_default(),
)
.into_response()
}
#[derive(Deserialize)]
struct RememberRequest {
#[serde(default)]
kind: String,
#[serde(default)]
text: String,
}
async fn remember(
headers: HeaderMap,
State(runtime): State<AppRuntime>,
Json(request): Json<RememberRequest>,
) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
if request.text.trim().is_empty() {
return api_error("text is required");
}
let event_type = match request.kind.as_str() {
"" | "memory" => EventType::Memory,
"decision" => EventType::Decision,
"observation" => EventType::Observation,
"pattern" => EventType::Pattern,
"fact" => EventType::Fact,
"bug" => EventType::Bug,
other => return api_error(&format!("unknown memory kind {other:?}")),
};
let event = crate::memory::new_typed_memory(
event_type,
runtime.workspace_id.clone(),
runtime.agent.clone(),
runtime.session_id.clone(),
request.text,
);
match runtime
.store
.append(with_repo_metadata(event, &runtime))
.await
{
Ok(event) => Json(event).into_response(),
Err(err) => api_error(&err.to_string()),
}
}
#[derive(Deserialize)]
struct CreateTaskRequest {
#[serde(default)]
title: String,
#[serde(default)]
body: String,
}
async fn create_task(
headers: HeaderMap,
State(runtime): State<AppRuntime>,
Json(request): Json<CreateTaskRequest>,
) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
if request.title.trim().is_empty() {
return api_error("title is required");
}
let content = if request.body.is_empty() {
request.title
} else {
format!("{}\n\n{}", request.title, request.body)
};
let event = crate::task::new_task(
runtime.workspace_id.clone(),
runtime.agent.clone(),
runtime.session_id.clone(),
content,
);
match runtime
.store
.append(with_repo_metadata(event, &runtime))
.await
{
Ok(event) => Json(event).into_response(),
Err(err) => api_error(&err.to_string()),
}
}
#[derive(Deserialize)]
struct UpdateTaskRequest {
#[serde(default)]
text: String,
}
async fn update_task(
headers: HeaderMap,
State(runtime): State<AppRuntime>,
AxumPath(id): AxumPath<Uuid>,
Json(request): Json<UpdateTaskRequest>,
) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
if request.text.trim().is_empty() {
return api_error("text is required");
}
if let Err(err) =
crate::task::ensure_task_exists(&runtime.store, &runtime.workspace_id, id).await
{
return api_error(&err.to_string());
}
let event = crate::task::new_task_update(
runtime.workspace_id.clone(),
runtime.agent.clone(),
runtime.session_id.clone(),
id,
request.text,
);
match runtime
.store
.append(with_repo_metadata(event, &runtime))
.await
{
Ok(event) => Json(event).into_response(),
Err(err) => api_error(&err.to_string()),
}
}
async fn done_task(
headers: HeaderMap,
State(runtime): State<AppRuntime>,
AxumPath(id): AxumPath<Uuid>,
) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
if let Err(err) =
crate::task::ensure_task_exists(&runtime.store, &runtime.workspace_id, id).await
{
return api_error(&err.to_string());
}
let event = crate::task::new_task_done(
runtime.workspace_id.clone(),
runtime.agent.clone(),
runtime.session_id.clone(),
id,
);
match runtime
.store
.append(with_repo_metadata(event, &runtime))
.await
{
Ok(event) => Json(event).into_response(),
Err(err) => api_error(&err.to_string()),
}
}
fn api_error(message: &str) -> Response {
(StatusCode::BAD_REQUEST, Json(json!({ "error": message }))).into_response()
}
fn with_repo_metadata(mut event: Event, runtime: &AppRuntime) -> Event {
if let Ok(status) = crate::context::repo_status(&runtime.cwd) {
let repo_id = crate::context::repo_id(&status);
event.repo_id = Some(repo_id.clone());
event.repo_path = Some(status.repo_path.clone());
event.git_remote = status.git_remote.clone();
event.branch = Some(status.branch.clone());
event.commit = Some(status.commit.clone());
event.repo_dirty = Some(status.dirty);
if let Some(metadata) = event.metadata_json.as_object_mut() {
metadata.insert("repo_id".to_owned(), json!(repo_id));
metadata.insert("repo_path".to_owned(), json!(status.repo_path));
metadata.insert("git_remote".to_owned(), json!(status.git_remote));
metadata.insert("branch".to_owned(), json!(status.branch));
metadata.insert("commit".to_owned(), json!(status.commit));
metadata.insert("repo_dirty".to_owned(), json!(status.dirty));
metadata.insert("dirty_files".to_owned(), json!(status.dirty_files));
}
}
event
}
async fn mcp_health(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
with_cors((StatusCode::OK, "Shuttle MCP server"))
}
async fn mcp_delete(headers: HeaderMap, State(runtime): State<AppRuntime>) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
with_cors((StatusCode::OK, "OK"))
}
async fn mcp_options() -> Response {
with_cors(StatusCode::NO_CONTENT)
}
async fn mcp_post(
headers: HeaderMap,
State(runtime): State<AppRuntime>,
Json(request): Json<crate::mcp::Request>,
) -> Response {
if let Some(response) = mcp_unauthorized_response(runtime.oauth.as_ref(), &headers) {
return response;
}
let response = crate::mcp::handle_request(
&crate::mcp::McpRuntime {
store: runtime.store,
cwd: runtime.cwd,
workspace_id: runtime.workspace_id,
agent: runtime.agent,
session_id: runtime.session_id,
},
request,
)
.await;
with_cors(Json(response))
}
async fn oauth_protected_resource(State(runtime): State<AppRuntime>) -> Response {
let Some(oauth) = runtime.oauth else {
return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
};
Json(oauth::protected_resource_metadata(&oauth.config)).into_response()
}
async fn oauth_authorization_server(State(runtime): State<AppRuntime>) -> Response {
let Some(oauth) = runtime.oauth else {
return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
};
Json(oauth::authorization_server_metadata(&oauth.config)).into_response()
}
async fn oauth_register(
State(runtime): State<AppRuntime>,
Json(request): Json<oauth::RegisterRequest>,
) -> Response {
let Some(oauth) = runtime.oauth else {
return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
};
match oauth.store.register_client(request) {
Ok(client) => Json(json!({
"client_id": client.client_id,
"client_id_issued_at": chrono::Utc::now().timestamp(),
"redirect_uris": client.redirect_uris,
"client_name": client.client_name,
"token_endpoint_auth_method": "none",
}))
.into_response(),
Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
}
}
async fn oauth_authorize_page(
State(runtime): State<AppRuntime>,
Query(request): Query<oauth::AuthorizeRequest>,
) -> Response {
let Some(oauth) = runtime.oauth else {
return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
};
if request.response_type != "code" {
return oauth_error(
StatusCode::BAD_REQUEST,
"unsupported_response_type",
"response_type must be code",
);
}
match oauth
.store
.client_allows_redirect(&request.client_id, &request.redirect_uri)
{
Ok(true) => {
Html(authorize_html(&request, oauth.config.admin_token.is_some())).into_response()
}
Ok(false) => oauth_error(
StatusCode::BAD_REQUEST,
"invalid_request",
"unknown client_id or redirect_uri",
),
Err(_) => oauth_error(
StatusCode::INTERNAL_SERVER_ERROR,
"server_error",
"failed to validate OAuth client",
),
}
}
async fn oauth_authorize_submit(
State(runtime): State<AppRuntime>,
Form(form): Form<oauth::AuthorizeForm>,
) -> Response {
let Some(oauth) = runtime.oauth else {
return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
};
if let Some(expected) = oauth.config.admin_token.as_deref() {
if !constant_time_eq(form.admin_token.as_bytes(), expected.as_bytes()) {
return oauth_error(
StatusCode::UNAUTHORIZED,
"access_denied",
"invalid admin token",
);
}
}
let request = oauth::AuthorizeRequest::from(form);
if request.response_type != "code" {
return oauth_error(
StatusCode::BAD_REQUEST,
"unsupported_response_type",
"response_type must be code",
);
}
match oauth.store.create_code(request.clone()) {
Ok(code) => Redirect::to(&oauth::authorize_redirect(
&request.redirect_uri,
&code,
request.state.as_deref(),
))
.into_response(),
Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_request", &err.to_string()),
}
}
async fn oauth_token(
State(runtime): State<AppRuntime>,
Form(request): Form<oauth::TokenRequest>,
) -> Response {
let Some(oauth) = runtime.oauth else {
return (StatusCode::NOT_FOUND, "OAuth is not configured").into_response();
};
if request.grant_type != "authorization_code" {
return oauth_error(
StatusCode::BAD_REQUEST,
"unsupported_grant_type",
"grant_type must be authorization_code",
);
}
match oauth.store.exchange_code(request) {
Ok(token) => Json(token).into_response(),
Err(err) => oauth_error(StatusCode::BAD_REQUEST, "invalid_grant", &err.to_string()),
}
}
fn mcp_unauthorized_response(
oauth: Option<&OAuthRuntime>,
headers: &HeaderMap,
) -> Option<Response> {
if let Some(oauth) = oauth {
let Some(token) = bearer_token(headers) else {
return Some(unauthorized_oauth(&oauth.config));
};
return match oauth.store.validate_access_token(token) {
Ok(true) => None,
Ok(false) => Some(unauthorized_oauth(&oauth.config)),
Err(_) => Some(oauth_error(
StatusCode::UNAUTHORIZED,
"invalid_token",
"failed to validate access token",
)),
};
}
let token = env::var("SHUTTLE_MCP_BEARER_TOKEN")
.ok()
.filter(|token| !token.is_empty())?;
let expected = format!("Bearer {token}");
if headers
.get("authorization")
.and_then(|header| header.to_str().ok())
.is_some_and(|actual| constant_time_eq(actual.as_bytes(), expected.as_bytes()))
{
None
} else {
Some(with_cors(StatusCode::UNAUTHORIZED))
}
}
fn bearer_token(headers: &HeaderMap) -> Option<&str> {
headers
.get(header::AUTHORIZATION)
.and_then(|header| header.to_str().ok())
.and_then(|value| {
let (scheme, token) = value.split_once(' ')?;
scheme.eq_ignore_ascii_case("Bearer").then_some(token)
})
}
fn constant_time_eq(left: &[u8], right: &[u8]) -> bool {
let mut diff = left.len() ^ right.len();
for index in 0..left.len().max(right.len()) {
let left = *left.get(index).unwrap_or(&0);
let right = *right.get(index).unwrap_or(&0);
diff |= (left ^ right) as usize;
}
diff == 0
}
fn with_cors(response: impl IntoResponse) -> Response {
let (mut parts, body) = response.into_response().into_parts();
parts
.headers
.insert("access-control-allow-origin", HeaderValue::from_static("*"));
parts.headers.insert(
"access-control-allow-methods",
HeaderValue::from_static("GET,POST,DELETE,OPTIONS"),
);
parts.headers.insert(
"access-control-allow-headers",
HeaderValue::from_static(
"accept,authorization,content-type,mcp-protocol-version,mcp-session-id",
),
);
parts.headers.insert(
"access-control-expose-headers",
HeaderValue::from_static("mcp-session-id"),
);
Response::from_parts(parts, body)
}
fn unauthorized_oauth(config: &OAuthConfig) -> Response {
let mut response = with_cors(StatusCode::UNAUTHORIZED);
let header_value = format!(
r#"Bearer resource_metadata="{}/.well-known/oauth-protected-resource/mcp", scope="mcp""#,
quoted_header_value(&config.public_url)
);
if let Ok(value) = HeaderValue::from_str(&header_value) {
response
.headers_mut()
.insert(header::WWW_AUTHENTICATE, value);
}
response
}
fn oauth_error(status: StatusCode, code: &str, description: &str) -> Response {
(
status,
Json(json!({ "error": code, "error_description": description })),
)
.into_response()
}
fn authorize_html(request: &oauth::AuthorizeRequest, requires_admin_token: bool) -> String {
let admin = if requires_admin_token {
r#"<label>Admin token <input name="admin_token" type="password" autocomplete="current-password" required></label>"#
} else {
r#"<input name="admin_token" type="hidden" value="">"#
};
format!(
r#"<!doctype html>
<html>
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Authorize Shuttle</title>
<style>
body {{ font-family: system-ui, sans-serif; margin: 2rem; color: #1f2937; }}
form {{ display: grid; gap: 1rem; max-width: 32rem; }}
input, button {{ font: inherit; padding: .6rem; }}
label {{ display: grid; gap: .35rem; }}
</style>
</head>
<body>
<h1>Authorize Shuttle MCP</h1>
<p>{client_id} is requesting access to Shuttle MCP.</p>
<form method="post" action="/oauth/authorize">
{admin}
<input type="hidden" name="response_type" value="{response_type}">
<input type="hidden" name="client_id" value="{client_id}">
<input type="hidden" name="redirect_uri" value="{redirect_uri}">
<input type="hidden" name="state" value="{state}">
<input type="hidden" name="scope" value="{scope}">
<input type="hidden" name="code_challenge" value="{code_challenge}">
<input type="hidden" name="code_challenge_method" value="{code_challenge_method}">
<button type="submit">Authorize</button>
</form>
</body>
</html>"#,
admin = admin,
response_type = html_escape(&request.response_type),
client_id = html_escape(&request.client_id),
redirect_uri = html_escape(&request.redirect_uri),
state = html_escape(request.state.as_deref().unwrap_or("")),
scope = html_escape(request.scope.as_deref().unwrap_or("mcp")),
code_challenge = html_escape(request.code_challenge.as_deref().unwrap_or("")),
code_challenge_method =
html_escape(request.code_challenge_method.as_deref().unwrap_or("S256")),
)
}
fn html_escape(value: &str) -> String {
value
.replace('&', "&")
.replace('<', "<")
.replace('>', ">")
.replace('"', """)
}
fn quoted_header_value(value: &str) -> String {
value.replace('\\', "\\\\").replace('"', "\\\"")
}
#[cfg(test)]
mod tests {
use super::*;
use axum::body::Body;
use axum::http::{Method, Request};
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use http_body_util::BodyExt;
use sha2::{Digest, Sha256};
use tower::ServiceExt;
fn runtime(oauth: Option<OAuthRuntime>) -> AppRuntime {
let dir = tempfile::tempdir().unwrap().keep();
let db = dir.join("shuttle.db");
AppRuntime {
store: SqliteEventStore::open(&db).unwrap(),
cwd: dir,
workspace_id: "workspace".to_owned(),
agent: "codex".to_owned(),
session_id: "session".to_owned(),
oauth,
}
}
fn oauth_runtime() -> OAuthRuntime {
let dir = tempfile::tempdir().unwrap().keep();
OAuthRuntime {
config: OAuthConfig {
public_url: "https://shuttle.example.test".to_owned(),
admin_token: Some("admin-token".to_owned()),
},
store: OAuthStore::open(dir.join("oauth.db")).unwrap(),
}
}
fn issue_access_token(oauth: &OAuthRuntime) -> String {
let verifier = "abc123abc123abc123abc123abc123abc123abc123abc123";
let challenge = URL_SAFE_NO_PAD.encode(Sha256::digest(verifier.as_bytes()));
let client = oauth
.store
.register_client(oauth::RegisterRequest {
redirect_uris: vec!["https://client.example.test/callback".to_owned()],
client_name: Some("client".to_owned()),
})
.unwrap();
let code = oauth
.store
.create_code(oauth::AuthorizeRequest {
response_type: "code".to_owned(),
client_id: client.client_id.clone(),
redirect_uri: "https://client.example.test/callback".to_owned(),
state: None,
scope: Some("mcp".to_owned()),
code_challenge: Some(challenge),
code_challenge_method: Some("S256".to_owned()),
})
.unwrap();
oauth
.store
.exchange_code(oauth::TokenRequest {
grant_type: "authorization_code".to_owned(),
client_id: client.client_id,
redirect_uri: "https://client.example.test/callback".to_owned(),
code: Some(code),
code_verifier: Some(verifier.to_owned()),
})
.unwrap()
.access_token
}
async fn request(
runtime: AppRuntime,
path: &str,
authorization: Option<&str>,
) -> axum::response::Response {
let mut builder = Request::builder().method(Method::GET).uri(path);
if let Some(authorization) = authorization {
builder = builder.header(header::AUTHORIZATION, authorization);
}
router(runtime)
.oneshot(builder.body(Body::empty()).unwrap())
.await
.unwrap()
}
#[tokio::test]
async fn dashboard_routes_require_bearer_when_oauth_is_configured() {
let oauth = oauth_runtime();
let token = issue_access_token(&oauth);
let index = request(runtime(Some(oauth.clone())), "/", None).await;
let dashboard = request(runtime(Some(oauth.clone())), "/api/dashboard", None).await;
let authorized_index = request(
runtime(Some(oauth.clone())),
"/",
Some(&format!("Bearer {token}")),
)
.await;
let authorized_dashboard = request(
runtime(Some(oauth)),
"/api/dashboard",
Some(&format!("Bearer {token}")),
)
.await;
assert_eq!(index.status(), StatusCode::UNAUTHORIZED);
assert_eq!(dashboard.status(), StatusCode::UNAUTHORIZED);
assert_eq!(authorized_index.status(), StatusCode::OK);
assert_eq!(authorized_dashboard.status(), StatusCode::OK);
}
#[tokio::test]
async fn dashboard_routes_remain_local_open_without_auth_configuration() {
let index = request(runtime(None), "/", None).await;
let dashboard = request(runtime(None), "/api/dashboard", None).await;
assert_eq!(index.status(), StatusCode::OK);
assert_eq!(dashboard.status(), StatusCode::OK);
}
#[tokio::test]
async fn oauth_metadata_is_not_blocked_by_dashboard_auth() {
let response = request(
runtime(Some(oauth_runtime())),
"/.well-known/oauth-authorization-server",
None,
)
.await;
assert_eq!(response.status(), StatusCode::OK);
}
#[tokio::test]
async fn oauth_authorize_submit_redirects_callback_with_get() {
let oauth = oauth_runtime();
let client = oauth
.store
.register_client(oauth::RegisterRequest {
redirect_uris: vec!["https://claude.ai/api/mcp/auth_callback".to_owned()],
client_name: Some("Claude".to_owned()),
})
.unwrap();
let body = format!(
"admin_token=admin-token&response_type=code&client_id={}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback&state=state-123&scope=mcp&code_challenge=challenge&code_challenge_method=S256",
client.client_id
);
let response = router(runtime(Some(oauth)))
.oneshot(
Request::builder()
.method(Method::POST)
.uri("/oauth/authorize")
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(Body::from(body))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::SEE_OTHER);
let location = response
.headers()
.get(header::LOCATION)
.and_then(|value| value.to_str().ok())
.unwrap();
assert!(location.starts_with("https://claude.ai/api/mcp/auth_callback?code=stl_"));
assert!(location.contains("&state=state-123"));
assert!(!location.contains("&iss="));
assert!(!location.contains("?iss="));
}
#[tokio::test]
async fn oauth_authorize_rejects_redirect_uri_with_trailing_slash() {
let oauth = oauth_runtime();
let client = oauth
.store
.register_client(oauth::RegisterRequest {
redirect_uris: vec!["https://claude.ai/api/mcp/auth_callback".to_owned()],
client_name: Some("Claude".to_owned()),
})
.unwrap();
let body = format!(
"admin_token=admin-token&response_type=code&client_id={}&redirect_uri=https%3A%2F%2Fclaude.ai%2Fapi%2Fmcp%2Fauth_callback%2F&state=state-123&scope=mcp&code_challenge=challenge&code_challenge_method=S256",
client.client_id
);
let response = router(runtime(Some(oauth)))
.oneshot(
Request::builder()
.method(Method::POST)
.uri("/oauth/authorize")
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(Body::from(body))
.unwrap(),
)
.await
.unwrap();
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}
#[tokio::test]
async fn dashboard_html_renders_json_as_text_not_inner_html() {
let response = request(runtime(None), "/", None).await;
let body = response.into_body().collect().await.unwrap().to_bytes();
let html = String::from_utf8(body.to_vec()).unwrap();
assert!(html.contains("heading.textContent = name"));
assert!(html.contains("pre.textContent = JSON.stringify(value, null, 2)"));
assert!(!html.contains("section.innerHTML"));
}
#[tokio::test]
async fn backend_api_can_remember_and_recall() {
let runtime = runtime(None);
let app = router(runtime);
let remember = app
.clone()
.oneshot(
Request::builder()
.method(Method::POST)
.uri("/api/remember")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(
r#"{"kind":"fact","text":"SQLite backs Shuttle"}"#,
))
.unwrap(),
)
.await
.unwrap();
assert_eq!(remember.status(), StatusCode::OK);
let recall = app
.oneshot(
Request::builder()
.method(Method::POST)
.uri("/api/recall")
.header(header::CONTENT_TYPE, "application/json")
.body(Body::from(r#"{"query":"SQLite"}"#))
.unwrap(),
)
.await
.unwrap();
assert_eq!(recall.status(), StatusCode::OK);
let body = recall.into_body().collect().await.unwrap().to_bytes();
let value: serde_json::Value = serde_json::from_slice(&body).unwrap();
assert!(value
.as_array()
.unwrap()
.iter()
.any(|item| item["event"]["content"] == "SQLite backs Shuttle"));
}
}