use std::sync::{Arc, Mutex};
use std::collections::HashMap;
use askama::Template;
use axum::{
body::Bytes,
extract::{ConnectInfo, Path, Query, State},
http::{HeaderMap, StatusCode},
response::{Html, IntoResponse, Redirect, Response},
Form, Json,
};
use std::net::SocketAddr;
use comrak::{markdown_to_html, Options};
use serde::{Deserialize, Serialize};
use crate::{
config::ServeConfig,
db::{AuditEntry, Db, DocumentRecord},
highlight,
parser::{extract_frontmatter, extract_title, parse_document, parse_expiry, validate_slug},
rate_limit::{RateLimitError, RateLimitStore, ReadRateLimit, WriteRateLimit},
webhook,
};
/// Alphabet handed to `nanoid::nanoid!` when a slug is generated server-side:
/// the ten digits, lowercase and uppercase ASCII letters, and `-` (63 symbols).
const SLUG_ALPHABET: [char; 63] = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
'-',
];
/// Error type returned by the HTTP handlers in this module.
///
/// Its `IntoResponse` impl maps each variant to an HTTP status and a JSON body
/// of the form `{"error": "<message>"}`, except `RateLimited`, which is
/// delegated to `crate::rate_limit::too_many_requests_response`.
#[derive(Debug)]
pub enum AppError {
/// 401 with the message "Unauthorized".
Unauthorized,
/// 403 with the message "Forbidden".
Forbidden,
/// 400; the payload is sent to the client verbatim.
BadRequest(String),
/// 404 with the message "Not found".
NotFound,
/// 409; the payload is sent to the client verbatim.
Conflict(String),
/// 410 with the message "Document has expired".
Gone,
/// 500; the payload is sent to the client verbatim.
Internal(String),
/// 401 with the message "Password required".
DocumentPasswordRequired,
/// 401 with the message "Invalid password".
DocumentPasswordInvalid,
/// Rate limit rejection; the response is built by the rate-limit module.
RateLimited(RateLimitError),
}
impl IntoResponse for AppError {
    /// Converts the error into an HTTP response.
    ///
    /// Rate-limit errors are handed to the rate-limit module; every other
    /// variant becomes `(status, {"error": message})`.
    fn into_response(self) -> Response {
        let (status, msg) = match self {
            // The rate limiter builds its own response, so return it as-is.
            AppError::RateLimited(err) => {
                return crate::rate_limit::too_many_requests_response(&err);
            }
            AppError::Unauthorized => (StatusCode::UNAUTHORIZED, "Unauthorized".to_string()),
            AppError::Forbidden => (StatusCode::FORBIDDEN, "Forbidden".to_string()),
            AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m),
            AppError::NotFound => (StatusCode::NOT_FOUND, "Not found".to_string()),
            AppError::Conflict(m) => (StatusCode::CONFLICT, m),
            AppError::Gone => (StatusCode::GONE, "Document has expired".to_string()),
            AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m),
            AppError::DocumentPasswordRequired => {
                (StatusCode::UNAUTHORIZED, "Password required".to_string())
            }
            AppError::DocumentPasswordInvalid => {
                (StatusCode::UNAUTHORIZED, "Invalid password".to_string())
            }
        };
        (status, Json(serde_json::json!({ "error": msg }))).into_response()
    }
}
impl From<rusqlite::Error> for AppError {
fn from(e: rusqlite::Error) -> Self {
tracing::error!(error = %e, "Database error");
AppError::Internal("Database error".to_string())
}
}
/// Value type of the `AppState::auth_codes` map.
///
/// NOTE(review): the field names suggest an OAuth authorization code with a
/// PKCE challenge; the code that reads these fields is not in this section.
#[derive(Clone)]
pub struct AuthCodeRecord {
pub client_id: String,
pub redirect_uri: String,
// `expires_at` is a timestamp string; `code_challenge` is presumably the PKCE
// challenge — confirm against the token endpoint.
pub expires_at: String, pub code_challenge: String,
pub resource: Option<String>,
pub scope: Option<String>,
}
/// Value type of the `AppState::access_tokens` map, which `check_auth_token`
/// looks up by the presented bearer token.
#[derive(Clone)]
pub struct AccessTokenRecord {
pub client_id: String,
pub scope: Option<String>,
// Compared as a string against `chrono_now()`; entries whose value sorts
// before "now" are pruned in `check_auth_token`.
pub expires_at: String, }
/// Value type of the `AppState::oauth_clients` map.
///
/// NOTE(review): the field names mirror OAuth dynamic client registration
/// metadata; the registration handler is not in this section.
#[derive(Clone)]
pub struct OAuthClientRecord {
pub client_id: String,
pub client_name: String,
pub redirect_uris: Vec<String>,
pub grant_types: Vec<String>,
pub response_types: Vec<String>,
pub token_endpoint_auth_method: String,
pub created_at: String,
}
/// Value type of the `AppState::refresh_tokens` map.
#[derive(Clone)]
pub struct RefreshTokenRecord {
pub client_id: String,
/// Presumably the access token issued together with this refresh token —
/// TODO confirm against the token endpoint.
pub access_token: String,
pub scope: Option<String>,
pub expires_at: String,
}
/// Shared application state extracted by every handler via `State<AppState>`.
///
/// The struct is `Clone`; the maps, the config and the rate-limit store are
/// behind `Arc`, so clones of those fields share the same data.
#[derive(Clone)]
pub struct AppState {
/// Database handle used for documents, tokens and audit entries.
pub db: Db,
/// Server configuration (base URL, admin token, default theme, webhook settings).
pub config: Arc<ServeConfig>,
/// In-memory authorization-code records.
pub auth_codes: Arc<Mutex<HashMap<String, AuthCodeRecord>>>,
/// In-memory OAuth client records.
pub oauth_clients: Arc<Mutex<HashMap<String, OAuthClientRecord>>>,
/// In-memory refresh-token records.
pub refresh_tokens: Arc<Mutex<HashMap<String, RefreshTokenRecord>>>,
/// In-memory access tokens, keyed by the token string (see `check_auth_token`).
pub access_tokens: Arc<Mutex<HashMap<String, AccessTokenRecord>>>,
/// Rate-limit store shared across requests.
pub rate_limit: Arc<RateLimitStore>,
}
/// Askama template backed by `document.html`.
///
/// `render_themed_sync` uses it as the fallback for any theme name that is not
/// `dark`, `paper`, `minimal` or `hearth`.
#[derive(Template)]
#[template(path = "document.html")]
struct CleanTemplate<'a> {
title: &'a str,
/// Rendered (and syntax-highlighted) HTML body.
content: &'a str,
slug: &'a str,
base_url: &'a str,
/// `true` when rendered by `get_full`, `false` when rendered by `get_human`.
full_view: bool,
/// `true` when the highlighted HTML is empty after trimming.
body_empty: bool,
expires_at: Option<String>,
/// Plain-text excerpt of the content (at most 150 characters plus "...").
description: String,
}
/// Askama template backed by `dark.html`; selected for the `dark` theme.
#[derive(Template)]
#[template(path = "dark.html")]
struct DarkTemplate<'a> {
title: &'a str,
/// Rendered (and syntax-highlighted) HTML body.
content: &'a str,
slug: &'a str,
base_url: &'a str,
/// `true` when the highlighted HTML is empty after trimming.
body_empty: bool,
expires_at: Option<String>,
/// Plain-text excerpt of the content (at most 150 characters plus "...").
description: String,
}
/// Askama template backed by `paper.html`; selected for the `paper` theme.
#[derive(Template)]
#[template(path = "paper.html")]
struct PaperTemplate<'a> {
title: &'a str,
/// Rendered (and syntax-highlighted) HTML body.
content: &'a str,
slug: &'a str,
base_url: &'a str,
/// `true` when the highlighted HTML is empty after trimming.
body_empty: bool,
expires_at: Option<String>,
/// Plain-text excerpt of the content (at most 150 characters plus "...").
description: String,
}
/// Askama template backed by `minimal.html`; selected for the `minimal` theme.
#[derive(Template)]
#[template(path = "minimal.html")]
struct MinimalTemplate<'a> {
title: &'a str,
/// Rendered (and syntax-highlighted) HTML body.
content: &'a str,
slug: &'a str,
base_url: &'a str,
/// `true` when the highlighted HTML is empty after trimming.
body_empty: bool,
expires_at: Option<String>,
/// Plain-text excerpt of the content (at most 150 characters plus "...").
description: String,
}
/// Askama template backed by `hearth.html`; selected for the `hearth` theme.
#[derive(Template)]
#[template(path = "hearth.html")]
struct HearthTemplate<'a> {
title: &'a str,
/// Rendered (and syntax-highlighted) HTML body.
content: &'a str,
slug: &'a str,
base_url: &'a str,
/// `true` when rendered by `get_full`, `false` when rendered by `get_human`.
full_view: bool,
/// `true` when the highlighted HTML is empty after trimming.
body_empty: bool,
expires_at: Option<String>,
/// Plain-text excerpt of the content (at most 150 characters plus "...").
description: String,
}
/// Askama template backed by `password.html`: the unlock page shown for
/// password-protected documents when the request is not yet authorised.
#[derive(Template)]
#[template(path = "password.html")]
struct PasswordTemplate<'a> {
slug: &'a str,
/// Configured base URL with trailing slashes removed.
base_url: &'a str,
/// `Some("Incorrect password")` after a failed unlock attempt, otherwise `None`.
error: Option<&'a str>,
}
/// JSON body returned by `post_document` (201) and `put_document` (200).
#[derive(Serialize)]
pub struct CreateResponse {
/// `{base_url}/{slug}`.
pub url: String,
pub slug: String,
/// `{base_url}/api/v1/documents/{slug}`.
pub api_url: String,
pub title: String,
pub description: Option<String>,
pub created_at: String,
pub expires_at: Option<String>,
}
/// Query-string parameters accepted by `get_human`.
#[derive(Deserialize)]
pub struct SlugQuery {
/// `raw=1` returns the stored markdown instead of rendered HTML.
pub raw: Option<String>,
/// Document password; takes precedence over `password` when both are given.
pub access_token: Option<String>,
/// Document password (alternative parameter name).
pub password: Option<String>,
}
/// JSON representation of a document, built by `build_json_agent_response`.
#[derive(Serialize)]
pub struct DocumentResponse {
pub slug: String,
pub title: String,
// `content` is the raw document with frontmatter `password:` lines removed;
// `human_content` is the `human` part returned by `parse_document`.
// The serde attribute at the end of the next line applies to `agent_content`,
// which is omitted from the JSON when `None`.
pub content: String, pub human_content: String, #[serde(skip_serializing_if = "Option::is_none")]
pub agent_content: Option<String>, pub theme: String,
pub description: Option<String>,
pub created_at: String,
pub expires_at: Option<String>,
}
/// Form body submitted to `post_unlock`.
#[derive(Deserialize)]
pub struct UnlockForm {
/// Password to check against the document's stored hash.
pub password: String,
}
/// Pagination parameters for `list_documents`.
#[derive(Deserialize)]
pub struct ListQuery {
/// Page size; the handler defaults it to 20.
pub limit: Option<u32>,
/// Number of rows to skip; the handler defaults it to 0.
pub offset: Option<u32>,
}
/// JSON body returned by `list_documents`.
#[derive(Serialize)]
pub struct ListResponse {
pub documents: Vec<crate::db::DocumentSummary>,
/// Total count as reported by the database layer.
pub total: u64,
/// Requested limit capped at 100.
pub limit: u32,
pub offset: u32,
}
/// Pagination parameters for `list_audit`.
#[derive(Deserialize)]
pub struct AuditQuery {
/// Page size; the handler defaults it to 20.
pub limit: Option<u32>,
/// Number of rows to skip; the handler defaults it to 0.
pub offset: Option<u32>,
}
/// JSON body returned by `list_audit`.
#[derive(Serialize)]
pub struct AuditResponse {
pub entries: Vec<crate::db::AuditEntry>,
/// Total count as reported by the database layer.
pub total: u64,
/// Requested limit capped at 100.
pub limit: u32,
pub offset: u32,
}
/// Creates a document from the raw request body.
///
/// The body must be non-empty UTF-8. Optional frontmatter supplies the slug,
/// title, theme, expiry, password and description; missing values fall back to
/// a generated slug, a title from `extract_title`, and the configured default
/// theme.
///
/// # Errors
/// - `Unauthorized` (or whatever `check_auth` returns) when the bearer token is rejected.
/// - `BadRequest` for an empty or non-UTF-8 body, or invalid frontmatter, slug or expiry.
/// - `Conflict` when a caller-supplied slug is already in use.
/// - `Internal` when the single retry after a generated-slug collision also fails.
///
/// On success returns 201 with a `CreateResponse`. A webhook is dispatched when
/// configured, and an audit entry is written (a failure there is only logged).
pub async fn post_document(
State(state): State<AppState>,
_rl: WriteRateLimit,
headers: HeaderMap,
connect_info: Option<ConnectInfo<SocketAddr>>,
body: Bytes,
) -> Result<Response, AppError> {
let token_name = check_auth(&state, &headers).await?;
let peer_addr = connect_info.map(|c| c.0.ip().to_string());
if body.is_empty() {
return Err(AppError::BadRequest("Request body must not be empty".to_string()));
}
let raw_content = std::str::from_utf8(&body)
.map_err(|_| AppError::BadRequest("Request body must be valid UTF-8".to_string()))?
.to_string();
let fm_result = extract_frontmatter(&raw_content)
.map_err(|e| AppError::BadRequest(e))?;
let meta = fm_result.meta.unwrap_or_default();
let body_text = &fm_result.body;
// A caller-supplied slug is validated; otherwise a 10-character one is generated.
let slug = if let Some(ref custom_slug) = meta.slug {
validate_slug(custom_slug)
.map_err(|e| AppError::BadRequest(e))?;
custom_slug.clone()
} else {
nanoid::nanoid!(10, &SLUG_ALPHABET)
};
let title = meta.title.unwrap_or_else(|| extract_title(body_text, &slug));
let theme = meta.theme.unwrap_or_else(|| state.config.default_theme.clone());
let now = chrono_now();
// The expiry is given as a duration and converted to a timestamp relative to `now`.
let expires_at = match meta.expiry.as_deref() {
Some(exp) => {
let seconds = parse_expiry(exp)
.map_err(|e| AppError::BadRequest(e))?;
Some(add_seconds_to_now(&now, seconds))
}
None => None,
};
// An empty password is treated the same as no password.
let password_hash = match meta.password.as_deref() {
Some(pw) if !pw.is_empty() => Some(hash_password(pw)?),
_ => None,
};
let doc = DocumentRecord {
id: slug.clone(),
slug: slug.clone(),
title: title.clone(),
raw_content,
theme,
password: password_hash,
description: meta.description.clone(),
created_at: now.clone(),
expires_at: expires_at.clone(),
updated_at: now.clone(),
};
let final_doc = match state.db.insert_document(&doc) {
Ok(()) => doc,
Err(e) if is_unique_violation(&e) => {
// A collision on a caller-chosen slug is the caller's problem (409).
if meta.slug.is_some() {
return Err(AppError::Conflict(
format!("Slug '{}' is already in use", slug),
));
}
// A collision on a generated slug is retried exactly once with a fresh slug.
let new_slug = nanoid::nanoid!(10, &SLUG_ALPHABET);
let retry_doc = DocumentRecord {
id: new_slug.clone(),
slug: new_slug.clone(),
..doc
};
state.db.insert_document(&retry_doc)
.map_err(|e2| {
tracing::error!(error = %e2, "Slug collision retry failed");
AppError::Internal("Failed to allocate unique slug".to_string())
})?;
retry_doc
}
Err(e) => return Err(AppError::from(e)),
};
let base = state.config.base_url.trim_end_matches('/');
let response = CreateResponse {
url: format!("{base}/{}", final_doc.slug),
slug: final_doc.slug.clone(),
api_url: format!("{base}/api/v1/documents/{}", final_doc.slug),
title: final_doc.title.clone(),
description: final_doc.description.clone(),
created_at: final_doc.created_at.clone(),
expires_at: final_doc.expires_at.clone(),
};
if let Some(ref wh_url) = state.config.webhook_url {
webhook::dispatch_webhook(
wh_url.clone(),
state.config.webhook_secret.clone(),
"document.created",
now.clone(),
webhook::WebhookDocument {
slug: final_doc.slug.clone(),
title: final_doc.title.clone(),
url: response.url.clone(),
api_url: response.api_url.clone(),
},
);
}
let ip_address = extract_client_ip(&headers, peer_addr.as_deref());
let audit_entry = AuditEntry {
id: nanoid::nanoid!(10),
timestamp: chrono_now(),
action: "create".to_string(),
slug: final_doc.slug.clone(),
token_name,
ip_address,
};
// Auditing is best-effort: the document already exists, so only log the failure.
if let Err(e) = state.db.insert_audit_entry(&audit_entry) {
tracing::error!(error = %e, "Failed to write audit entry");
}
Ok((StatusCode::CREATED, Json(response)).into_response())
}
/// Replaces the content of the existing document at `slug`.
///
/// The body must be non-empty UTF-8. Title and theme are recomputed from the
/// new content (falling back to `extract_title` and the configured default
/// theme), and the description is taken from the new frontmatter as-is.
/// Expiry and password are tri-state, as the inline comments explain.
///
/// # Errors
/// - `Unauthorized` (or whatever `check_auth` returns) when the bearer token is rejected.
/// - `BadRequest` for an empty or non-UTF-8 body, or invalid frontmatter or expiry.
/// - `NotFound` when no document has this slug; `Gone` when it has expired.
///
/// On success returns 200 with a `CreateResponse`. A webhook is dispatched when
/// configured, and an audit entry is written (a failure there is only logged).
pub async fn put_document(
State(state): State<AppState>,
_rl: WriteRateLimit,
Path(slug): Path<String>,
headers: HeaderMap,
connect_info: Option<ConnectInfo<SocketAddr>>,
body: Bytes,
) -> Result<Response, AppError> {
let token_name = check_auth(&state, &headers).await?;
let peer_addr = connect_info.map(|c| c.0.ip().to_string());
if body.is_empty() {
return Err(AppError::BadRequest("Request body must not be empty".to_string()));
}
let raw_content = std::str::from_utf8(&body)
.map_err(|_| AppError::BadRequest("Request body must be valid UTF-8".to_string()))?
.to_string();
let existing = state.db.get_by_slug(&slug)?
.ok_or(AppError::NotFound)?;
if is_expired(&existing) {
return Err(AppError::Gone);
}
let fm_result = extract_frontmatter(&raw_content)
.map_err(|e| AppError::BadRequest(e))?;
let meta = fm_result.meta.unwrap_or_default();
let body_text = &fm_result.body;
let title = meta.title.unwrap_or_else(|| extract_title(body_text, &slug));
let theme = meta.theme.unwrap_or_else(|| state.config.default_theme.clone());
let now = chrono_now();
// Expiry: a non-empty value is recomputed relative to `now`; an empty value
// clears the expiry; an absent key keeps the existing expiry.
let expires_at = match meta.expiry.as_deref() {
Some(exp) if !exp.is_empty() => {
let seconds = parse_expiry(exp)
.map_err(|e| AppError::BadRequest(e))?;
Some(add_seconds_to_now(&now, seconds))
}
Some(_) => None, None => existing.expires_at.clone(), };
// Password: a non-empty value is re-hashed; an empty value removes protection;
// an absent key keeps the existing hash.
let password_hash = match meta.password.as_deref() {
Some(pw) if !pw.is_empty() => Some(hash_password(pw)?),
Some(_) => None, None => existing.password.clone(), };
let updated_doc = DocumentRecord {
id: existing.id,
slug: slug.clone(),
title: title.clone(),
raw_content,
theme,
password: password_hash,
description: meta.description.clone(),
created_at: existing.created_at.clone(),
expires_at: expires_at.clone(),
updated_at: now.clone(),
};
state.db.update_document(&slug, &updated_doc)?;
let base = state.config.base_url.trim_end_matches('/');
let response = CreateResponse {
url: format!("{base}/{slug}"),
slug: slug.clone(),
api_url: format!("{base}/api/v1/documents/{slug}"),
title: updated_doc.title.clone(),
description: updated_doc.description.clone(),
created_at: existing.created_at,
expires_at: updated_doc.expires_at.clone(),
};
if let Some(ref wh_url) = state.config.webhook_url {
webhook::dispatch_webhook(
wh_url.clone(),
state.config.webhook_secret.clone(),
"document.updated",
now.clone(),
webhook::WebhookDocument {
slug: slug.clone(),
title: updated_doc.title.clone(),
url: response.url.clone(),
api_url: response.api_url.clone(),
},
);
}
let ip_address = extract_client_ip(&headers, peer_addr.as_deref());
let audit_entry = AuditEntry {
id: nanoid::nanoid!(10),
timestamp: now,
action: "update".to_string(),
slug: slug.clone(),
token_name,
ip_address,
};
// Auditing is best-effort: the update is already stored, so only log the failure.
if let Err(e) = state.db.insert_audit_entry(&audit_entry) {
tracing::error!(error = %e, "Failed to write audit entry");
}
Ok((StatusCode::OK, Json(response)).into_response())
}
pub async fn delete_document(
State(state): State<AppState>,
_rl: WriteRateLimit,
Path(slug): Path<String>,
headers: HeaderMap,
connect_info: Option<ConnectInfo<SocketAddr>>,
) -> Result<Response, AppError> {
let token_name = check_auth(&state, &headers).await?;
let peer_addr = connect_info.map(|c| c.0.ip().to_string());
let existing = state.db.get_by_slug(&slug)?
.ok_or(AppError::NotFound)?;
state.db.delete_by_slug(&slug)?;
let now = chrono_now();
if let Some(ref wh_url) = state.config.webhook_url {
let base = state.config.base_url.trim_end_matches('/');
webhook::dispatch_webhook(
wh_url.clone(),
state.config.webhook_secret.clone(),
"document.deleted",
now.clone(),
webhook::WebhookDocument {
slug: existing.slug.clone(),
title: existing.title.clone(),
url: format!("{base}/{}", existing.slug),
api_url: format!("{base}/api/v1/documents/{}", existing.slug),
},
);
}
let ip_address = extract_client_ip(&headers, peer_addr.as_deref());
let audit_entry = AuditEntry {
id: nanoid::nanoid!(10),
timestamp: now,
action: "delete".to_string(),
slug: slug.clone(),
token_name,
ip_address,
};
if let Err(e) = state.db.insert_audit_entry(&audit_entry) {
tracing::error!(error = %e, "Failed to write audit entry");
}
Ok(StatusCode::NO_CONTENT.into_response())
}
/// Lists documents with pagination.
///
/// Requires a valid bearer token. `limit` defaults to 20 and is capped at 100;
/// `offset` defaults to 0. The cap is applied *before* the database query, so
/// the `limit` echoed in the response always matches the page size actually
/// requested from the database.
pub async fn list_documents(
    State(state): State<AppState>,
    _rl: WriteRateLimit,
    headers: HeaderMap,
    Query(params): Query<ListQuery>,
) -> Result<Response, AppError> {
    check_auth(&state, &headers).await?;

    // Clamp first: previously the raw value went to the DB while the response
    // claimed the capped one.
    let limit = params.limit.unwrap_or(20).min(100);
    let offset = params.offset.unwrap_or(0);

    let (documents, total) = state.db.list_documents(limit, offset)?;

    Ok(Json(ListResponse {
        documents,
        total,
        limit,
        offset,
    })
    .into_response())
}
/// Lists audit-log entries with pagination.
///
/// Only the identity `"admin"` (the configured server token) is accepted; any
/// other authenticated caller gets `Forbidden`. `limit` defaults to 20 and is
/// capped at 100; `offset` defaults to 0. The cap is applied *before* the
/// database query, so the `limit` echoed in the response always matches the
/// page size actually requested from the database.
pub async fn list_audit(
    State(state): State<AppState>,
    headers: HeaderMap,
    Query(params): Query<AuditQuery>,
) -> Result<Response, AppError> {
    let identity = check_auth(&state, &headers).await?;
    if identity != "admin" {
        return Err(AppError::Forbidden);
    }

    // Clamp first: previously the raw value went to the DB while the response
    // claimed the capped one.
    let limit = params.limit.unwrap_or(20).min(100);
    let offset = params.offset.unwrap_or(0);

    let (entries, total) = state.db.list_audit_entries(limit, offset)?;

    Ok(Json(AuditResponse {
        entries,
        total,
        limit,
        offset,
    })
    .into_response())
}
pub async fn health_check(State(state): State<AppState>) -> Response {
let db_ok = state.db.ping().is_ok();
if db_ok {
(
StatusCode::OK,
Json(serde_json::json!({"status": "ok", "db": "ok"})),
)
.into_response()
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(serde_json::json!({"status": "degraded", "db": "error"})),
)
.into_response()
}
}
/// Serves the OpenAPI specification embedded at compile time, as YAML.
pub async fn serve_openapi_yaml(_rl: ReadRateLimit) -> impl IntoResponse {
    let content_type = (
        axum::http::header::CONTENT_TYPE,
        "application/yaml; charset=utf-8",
    );
    let spec = include_str!("../docs/openapi.yaml");
    (StatusCode::OK, [content_type], spec)
}
pub async fn serve_openapi_json(_rl: ReadRateLimit) -> impl IntoResponse {
use std::sync::OnceLock;
static OPENAPI_JSON: OnceLock<String> = OnceLock::new();
let json = OPENAPI_JSON.get_or_init(|| {
let yaml = include_str!("../docs/openapi.yaml");
match serde_yaml::from_str::<serde_json::Value>(yaml) {
Ok(val) => serde_json::to_string(&val).unwrap_or_else(|e| {
format!("{{\"error\":\"JSON serialization failed: {e}\"}}")
}),
Err(e) => format!("{{\"error\":\"YAML parse failed: {e}\"}}"),
}
});
(
StatusCode::OK,
[(axum::http::header::CONTENT_TYPE, "application/json; charset=utf-8")],
json.as_str(),
)
}
/// Serves the icon embedded at compile time from `assets/icon.jpg` as `image/jpeg`.
pub async fn serve_icon() -> impl IntoResponse {
    let icon: &'static [u8] = include_bytes!("../assets/icon.jpg").as_ref();
    let content_type = (axum::http::header::CONTENT_TYPE, "image/jpeg");
    (StatusCode::OK, [content_type], icon)
}
/// Permanently redirects favicon requests to `/icon.png`.
///
/// NOTE(review): `serve_icon` in this file serves JPEG bytes from
/// `assets/icon.jpg`; confirm the router maps `/icon.png` to a handler that
/// exists and that the extension mismatch is intentional.
pub async fn serve_favicon() -> impl IntoResponse {
Redirect::permanent("/icon.png")
}
/// Returns `true` when the `Accept` header is readable as a string and
/// contains `application/json` anywhere in it (q-values are not considered).
fn accept_prefers_json(headers: &HeaderMap) -> bool {
    let accept = headers
        .get(axum::http::header::ACCEPT)
        .and_then(|value| value.to_str().ok());
    match accept {
        Some(value) => value.contains("application/json"),
        None => false,
    }
}
/// Returns `true` when the `Accept` header is readable as a string and
/// contains `text/markdown` anywhere in it (q-values are not considered).
fn accept_prefers_markdown(headers: &HeaderMap) -> bool {
    let accept = headers
        .get(axum::http::header::ACCEPT)
        .and_then(|value| value.to_str().ok());
    match accept {
        Some(value) => value.contains("text/markdown"),
        None => false,
    }
}
/// Substrings that mark a `User-Agent` as a known bot or agent.
///
/// Entries must be lowercase: `is_known_bot` lowercases the header value
/// before doing substring matching against this list.
const KNOWN_BOT_AGENTS: &[&str] = &[
"gptbot",
"chatgpt-user",
"claudebot",
"claude-user",
"google-extended",
"googlebot",
"bingbot",
"perplexitybot",
"anthropic",
"google-agent",
];
/// Returns `true` when the lowercased `User-Agent` header contains any entry
/// of `KNOWN_BOT_AGENTS`. A missing or non-string header yields `false`.
fn is_known_bot(headers: &HeaderMap) -> bool {
    let Some(user_agent) = headers
        .get(axum::http::header::USER_AGENT)
        .and_then(|value| value.to_str().ok())
    else {
        return false;
    };
    let user_agent = user_agent.to_lowercase();
    for needle in KNOWN_BOT_AGENTS {
        if user_agent.contains(needle) {
            return true;
        }
    }
    false
}
/// Returns `raw` with every `password` key line removed from its leading
/// `---` … `---` frontmatter block.
///
/// Content without an opening `---` on the first line, or without a closing
/// `---`, is returned unchanged. Lines outside the frontmatter are never
/// touched. All remaining bytes — including CRLF line endings and a trailing
/// newline — are preserved exactly. The earlier `lines()`/`join("\n")`
/// implementation dropped them, but only when frontmatter was present.
///
/// A line counts as a password line when the text before its first `:` is
/// `password`, ignoring surrounding whitespace. So `password: x`,
/// `  password: x` and `password : x` are all stripped, while
/// `password_hint: x` is kept. Only the key line itself is removed;
/// continuation lines of a multi-line value are not detected.
fn strip_password_from_content(raw: &str) -> String {
    /// True when `line` is a `password` key line.
    fn is_password_key(line: &str) -> bool {
        line.trim_start()
            .split_once(':')
            .is_some_and(|(key, _)| key.trim_end() == "password")
    }

    // `split_inclusive` keeps each line's terminator attached, so concatenating
    // the surviving pieces reproduces the input byte-for-byte.
    let lines: Vec<&str> = raw.split_inclusive('\n').collect();

    let has_opening = lines.first().is_some_and(|first| first.trim() == "---");
    if !has_opening {
        return raw.to_string();
    }

    // Index of the closing delimiter, searched from the second line onwards.
    let Some(close_idx) = lines
        .iter()
        .skip(1)
        .position(|line| line.trim() == "---")
        .map(|pos| pos + 1)
    else {
        return raw.to_string();
    };

    let mut out = String::with_capacity(raw.len());
    for (idx, line) in lines.iter().enumerate() {
        let inside_frontmatter = idx >= 1 && idx < close_idx;
        if inside_frontmatter && is_password_key(line) {
            continue;
        }
        out.push_str(line);
    }
    out
}
/// Builds the 200 JSON response for a document.
///
/// The raw content is first passed through `strip_password_from_content`; the
/// stripped text is both returned as `content` and fed to `parse_document`,
/// whose `human` and `agent` parts fill the remaining content fields.
fn build_json_agent_response(doc: &crate::db::DocumentRecord) -> Response {
    let sanitized = strip_password_from_content(&doc.raw_content);
    let sections = parse_document(&sanitized, &doc.slug);

    let payload = DocumentResponse {
        slug: doc.slug.clone(),
        title: doc.title.clone(),
        content: sanitized,
        human_content: sections.human,
        agent_content: sections.agent,
        theme: doc.theme.clone(),
        description: doc.description.clone(),
        created_at: doc.created_at.clone(),
        expires_at: doc.expires_at.clone(),
    };

    (StatusCode::OK, Json(payload)).into_response()
}
pub async fn get_slug_md(
State(state): State<AppState>,
_rl: ReadRateLimit,
Path(slug): Path<String>,
headers: HeaderMap,
) -> Result<Response, AppError> {
let bare_slug = slug.strip_suffix(".md").unwrap_or(&slug);
let doc = match state.db.get_by_slug(bare_slug)? {
Some(d) => d,
None => return Ok(not_found_response()),
};
if is_expired(&doc) {
return Ok(gone_response());
}
if doc.password.is_some() {
if !is_password_authed(&headers, bare_slug, &state.config.token) {
let template = PasswordTemplate { slug: bare_slug, base_url: state.config.base_url.trim_end_matches('/'), error: None };
return Ok(Html(template.render().map_err(|e| {
AppError::Internal(format!("Template error: {e}"))
})?).into_response());
}
}
Ok(markdown_response(&doc.raw_content))
}
/// Main document view with content negotiation.
///
/// After the existence, expiry and password checks, the representation is
/// chosen in this order:
/// 1. slug ended in `.md` → raw markdown
/// 2. `?raw=1` → raw markdown
/// 3. `Accept` contains `application/json` → JSON
/// 4. `Accept` contains `text/markdown` → raw markdown
/// 5. known bot `User-Agent` without `text/html` in `Accept` → JSON
/// 6. otherwise → themed HTML with a `Link: rel="alternate"` header that points
///    at the JSON API URL.
///
/// Unknown slugs get the HTML not-found page and expired documents the HTML
/// gone page. A protected document is served when the query password
/// (`access_token`, else `password`) verifies or `is_password_authed` accepts
/// the request; otherwise the password page is rendered.
pub async fn get_human(
State(state): State<AppState>,
_rl: ReadRateLimit,
Path(slug): Path<String>,
Query(params): Query<SlugQuery>,
headers: HeaderMap,
) -> Result<Response, AppError> {
// A trailing `.md` forces the markdown representation.
let (slug, force_markdown) = if let Some(bare) = slug.strip_suffix(".md") {
(bare.to_string(), true)
} else {
(slug, false)
};
let doc = match state.db.get_by_slug(&slug)? {
Some(d) => d,
None => return Ok(not_found_response()),
};
if is_expired(&doc) {
return Ok(gone_response());
}
if let Some(stored_hash) = &doc.password {
let query_provided = params.access_token.as_deref().or(params.password.as_deref());
let query_pw_valid = if let Some(provided) = query_provided {
let provided_owned = provided.to_string();
let hash_owned = stored_hash.clone();
// Password verification runs on the blocking pool, off the async executor.
tokio::task::spawn_blocking(move || verify_password(&provided_owned, &hash_owned))
.await
.map_err(|e| AppError::Internal(format!("Auth task failed: {e}")))?
} else {
false
};
// A wrong query password is not an error here: the request-header check still
// applies, and failing both shows the password page.
if !query_pw_valid && !is_password_authed(&headers, &slug, &state.config.token) {
let template = PasswordTemplate { slug: &slug, base_url: state.config.base_url.trim_end_matches('/'), error: None };
return Ok(Html(template.render().map_err(|e| {
AppError::Internal(format!("Template error: {e}"))
})?).into_response());
}
}
if force_markdown {
return Ok(markdown_response(&doc.raw_content));
}
if params.raw.as_deref() == Some("1") {
return Ok(markdown_response(&doc.raw_content));
}
if accept_prefers_json(&headers) {
return Ok(build_json_agent_response(&doc));
}
if accept_prefers_markdown(&headers) {
return Ok(markdown_response(&doc.raw_content));
}
let accept_explicitly_html = headers
.get(axum::http::header::ACCEPT)
.and_then(|v| v.to_str().ok())
.map(|s| s.contains("text/html"))
.unwrap_or(false);
// Known bots get JSON unless they explicitly ask for HTML.
if !accept_explicitly_html && is_known_bot(&headers) {
return Ok(build_json_agent_response(&doc));
}
// Owned copies are needed because the blocking closure must be 'static.
let raw_content = doc.raw_content.clone();
let title = doc.title.clone();
let theme = doc.theme.clone();
let slug_owned = slug.clone();
let expires_at = doc.expires_at.clone();
let base_url = state.config.base_url.trim_end_matches('/').to_string();
let base_url_clone = base_url.clone();
let html_result = tokio::task::spawn_blocking(move || {
// If frontmatter extraction fails, render the whole raw content as the body.
let fm_result = extract_frontmatter(&raw_content)
.unwrap_or_else(|_| crate::parser::FrontmatterResult {
meta: None,
body: raw_content.clone(),
});
let parse_result = parse_document(&fm_result.body, &slug_owned);
let rendered_html = render_markdown(&parse_result.human);
render_themed_sync(&title, &rendered_html, &slug_owned, &theme, &base_url_clone, false, expires_at)
})
.await
.map_err(|e| AppError::Internal(format!("Render task failed: {e}")))?;
let link_header = format!(
"<{base_url}/api/v1/documents/{slug}>; rel=\"alternate\"; type=\"application/json\"",
);
let html_response = html_result?;
let mut response = html_response.into_response();
// If the header value is invalid, an empty `Link` header is inserted instead.
response.headers_mut().insert(
axum::http::header::LINK,
axum::http::HeaderValue::from_str(&link_header)
.unwrap_or_else(|_| axum::http::HeaderValue::from_static("")),
);
Ok(response)
}
pub async fn post_unlock(
State(state): State<AppState>,
_rl: ReadRateLimit,
Path(slug): Path<String>,
Form(form): Form<UnlockForm>,
) -> Result<Response, AppError> {
let doc = match state.db.get_by_slug(&slug)? {
Some(d) => d,
None => return Ok(not_found_response()),
};
if is_expired(&doc) {
return Ok(gone_response());
}
let stored_hash = match &doc.password {
Some(h) => h,
None => {
return Ok(Redirect::to(&format!("/{slug}")).into_response());
}
};
if verify_password(&form.password, stored_hash) {
let cookie_value = make_auth_cookie(&slug, &state.config.token);
let cookie_header = format!(
"twofold_auth_{}={}; Path=/{}; HttpOnly; SameSite=Strict; Max-Age=3600",
slug, cookie_value, slug
);
Ok((
StatusCode::SEE_OTHER,
[
(axum::http::header::LOCATION, format!("/{slug}")),
(axum::http::header::SET_COOKIE, cookie_header),
],
"",
).into_response())
} else {
let template = PasswordTemplate {
slug: &slug,
base_url: state.config.base_url.trim_end_matches('/'),
error: Some("Incorrect password"),
};
Ok(Html(template.render().map_err(|e| {
AppError::Internal(format!("Template error: {e}"))
})?).into_response())
}
}
pub async fn get_full(
State(state): State<AppState>,
_rl: ReadRateLimit,
Path(slug): Path<String>,
headers: HeaderMap,
) -> Result<Response, AppError> {
let doc = match state.db.get_by_slug(&slug)? {
Some(d) => d,
None => return Ok(not_found_response()),
};
if is_expired(&doc) {
return Ok(gone_response());
}
if doc.password.is_some() {
if !is_password_authed(&headers, &slug, &state.config.token) {
let template = PasswordTemplate { slug: &slug, base_url: state.config.base_url.trim_end_matches('/'), error: None };
return Ok(Html(template.render().map_err(|e| {
AppError::Internal(format!("Template error: {e}"))
})?).into_response());
}
}
let raw_content = doc.raw_content.clone();
let title = doc.title.clone();
let theme = doc.theme.clone();
let slug_owned = slug.clone();
let expires_at = doc.expires_at.clone();
let base_url = state.config.base_url.trim_end_matches('/').to_string();
let base_url_clone = base_url.clone();
let html_result = tokio::task::spawn_blocking(move || {
let fm_result = extract_frontmatter(&raw_content)
.unwrap_or_else(|_| crate::parser::FrontmatterResult {
meta: None,
body: raw_content.clone(),
});
let stripped = strip_marker_comments(&fm_result.body);
let rendered_html = render_markdown(&stripped);
render_themed_sync(&title, &rendered_html, &slug_owned, &theme, &base_url_clone, true, expires_at)
})
.await
.map_err(|e| AppError::Internal(format!("Render task failed: {e}")))?;
let link_header = format!(
"<{base_url}/api/v1/documents/{slug}>; rel=\"alternate\"; type=\"application/json\"",
);
let html_response = html_result?;
let mut response = html_response.into_response();
response.headers_mut().insert(
axum::http::header::LINK,
axum::http::HeaderValue::from_str(&link_header)
.unwrap_or_else(|_| axum::http::HeaderValue::from_static("")),
);
Ok(response)
}
/// Query-string parameters accepted by `get_agent`.
#[derive(Deserialize)]
pub struct AgentQuery {
/// Document password; takes precedence over `password` when both are given.
pub access_token: Option<String>,
/// Document password (alternative parameter name).
pub password: Option<String>,
}
pub async fn get_agent(
State(state): State<AppState>,
_rl: ReadRateLimit,
Path(slug): Path<String>,
Query(params): Query<AgentQuery>,
) -> Result<Response, AppError> {
let doc = state.db.get_by_slug(&slug)?
.ok_or(AppError::NotFound)?;
if is_expired(&doc) {
return Err(AppError::Gone);
}
if let Some(stored_hash) = &doc.password {
let provided = params.access_token.as_deref().or(params.password.as_deref());
match provided {
Some(provided) if verify_password(provided, stored_hash) => {
}
Some(_) => {
return Err(AppError::DocumentPasswordInvalid);
}
None => {
return Err(AppError::DocumentPasswordRequired);
}
}
}
Ok(markdown_response(&doc.raw_content))
}
/// Authenticates a request from its `Authorization` header.
///
/// Returns the identity name produced by `check_auth_token`, or
/// `Unauthorized` when no bearer token can be extracted.
async fn check_auth(state: &AppState, headers: &HeaderMap) -> Result<String, AppError> {
    match extract_bearer(headers) {
        Some(token) => check_auth_token(state, token).await,
        None => Err(AppError::Unauthorized),
    }
}
/// Resolves a bearer token to an identity name.
///
/// Checks are tried in this order; the first match wins:
/// 1. The configured server token (constant-time comparison) → `"admin"`.
/// 2. An unexpired in-memory access token → `"oauth"`.
/// 3. A database token looked up by the first 8 characters of `provided` and
///    verified against its stored hash → that token's name.
/// 4. Any legacy token whose hash verifies → that token's name.
///
/// # Errors
/// `Unauthorized` when nothing matches; `Internal` when a token lookup or a
/// blocking verification task fails.
pub async fn check_auth_token(state: &AppState, provided: &str) -> Result<String, AppError> {
if constant_time_eq(provided.as_bytes(), state.config.token.as_bytes()) {
return Ok("admin".to_string());
}
// Scoped block: the std mutex guard is dropped before any `.await` below.
{
let now = chrono_now();
// A poisoned lock is recovered instead of panicking.
let mut tokens = state.access_tokens.lock().unwrap_or_else(|e| e.into_inner());
// Drop expired tokens; expiry is decided by string comparison with `now`.
tokens.retain(|_, v| v.expires_at.as_str() >= now.as_str());
if tokens.contains_key(provided) {
return Ok("oauth".to_string());
}
}
// The 8-character prefix selects at most one candidate, so one hash check suffices.
let prefix: String = provided.chars().take(8).collect();
let candidate = state.db.get_token_by_prefix(&prefix)
.map_err(|_| AppError::Internal("Failed to check tokens".to_string()))?;
if let Some(token_record) = candidate {
let provided_owned = provided.to_string();
let hash_owned = token_record.hash.clone();
// Hash verification runs on the blocking pool.
let verified = tokio::task::spawn_blocking(move || {
verify_password(&provided_owned, &hash_owned)
})
.await
.map_err(|e| AppError::Internal(format!("Auth task failed: {e}")))?;
if verified {
let now = chrono_now();
// The result of `touch_token` is deliberately ignored (best-effort).
let _ = state.db.touch_token(&token_record.id, &now);
return Ok(token_record.name.clone());
}
}
// Legacy tokens are checked one by one; the hash is verified for each until one matches.
let legacy_tokens = state.db.get_legacy_active_tokens()
.map_err(|_| AppError::Internal("Failed to check tokens".to_string()))?;
if !legacy_tokens.is_empty() {
let provided_owned = provided.to_string();
let result = tokio::task::spawn_blocking(move || {
for token_record in &legacy_tokens {
if verify_password(&provided_owned, &token_record.hash) {
return Some((token_record.id.clone(), token_record.name.clone()));
}
}
None
})
.await
.map_err(|e| AppError::Internal(format!("Auth task failed: {e}")))?;
if let Some((id, name)) = result {
let now = chrono_now();
// The result of `touch_token` is deliberately ignored (best-effort).
let _ = state.db.touch_token(&id, &now);
return Ok(name);
}
}
Err(AppError::Unauthorized)
}
/// Determines the client IP to record for a request.
///
/// Preference order:
/// 1. the first comma-separated entry of `X-Forwarded-For`, if it parses as an IP address;
/// 2. `fallback` parsed as a socket address (its IP part);
/// 3. `fallback` verbatim, if non-empty;
/// 4. the literal `"unknown"`.
///
/// NOTE(review): the forwarded header is taken at face value; confirm the
/// service only runs behind a proxy that sets or overwrites it.
pub fn extract_client_ip(headers: &HeaderMap, fallback: Option<&str>) -> String {
    let forwarded_ip = headers
        .get("x-forwarded-for")
        .and_then(|value| value.to_str().ok())
        .and_then(|list| list.split(',').next())
        .map(str::trim)
        // An empty entry fails to parse too, so no separate emptiness check is needed.
        .filter(|entry| entry.parse::<std::net::IpAddr>().is_ok());
    if let Some(ip) = forwarded_ip {
        return ip.to_string();
    }

    match fallback {
        Some(peer) => match peer.parse::<std::net::SocketAddr>() {
            Ok(socket_addr) => socket_addr.ip().to_string(),
            Err(_) if !peer.is_empty() => peer.to_string(),
            Err(_) => "unknown".to_string(),
        },
        None => "unknown".to_string(),
    }
}
/// Extracts the credential from an `Authorization: Bearer <token>` header.
///
/// The scheme name is matched case-insensitively, because HTTP authentication
/// schemes are case-insensitive; previously `bearer <token>` was rejected.
/// The token is everything after the first space, unchanged. Returns `None`
/// when the header is missing, is not a valid string, has no space separator,
/// or uses a different scheme.
fn extract_bearer(headers: &HeaderMap) -> Option<&str> {
    let auth = headers.get("authorization")?.to_str().ok()?;
    let (scheme, token) = auth.split_once(' ')?;
    scheme.eq_ignore_ascii_case("Bearer").then_some(token)
}
/// Wraps `content` in a 200 response with `Content-Type: text/markdown; charset=utf-8`.
fn markdown_response(content: &str) -> Response {
    let content_type = (
        axum::http::header::CONTENT_TYPE,
        "text/markdown; charset=utf-8",
    );
    let body = String::from(content);
    (StatusCode::OK, [content_type], body).into_response()
}
/// Removes marker comments and instruction blocks from `source`.
///
/// A marker is a line consisting solely of an HTML comment (after trimming):
/// - `<!-- @instructions -->` … `<!-- @end-instructions -->`: both marker
///   lines and everything between them are dropped.
/// - `<!-- @agent -->` and `<!-- @end -->`: only the marker lines are dropped;
///   the lines between them are kept.
///
/// Every other line — including other comments and a stray
/// `@end-instructions` outside a block — is kept. Lines are re-joined with
/// `\n`, so a trailing newline is not preserved.
///
/// Comment detection uses `strip_prefix`/`strip_suffix`. The earlier manual
/// slice `t[4..len - 3]` panicked on degenerate lines such as `<!-->`, where
/// the prefix and suffix overlap.
fn strip_marker_comments(source: &str) -> String {
    let mut kept: Vec<&str> = Vec::new();
    let mut in_instructions = false;

    for line in source.lines() {
        // `Some(inner)` only when the whole trimmed line is `<!-- inner -->`.
        let tag = line
            .trim()
            .strip_prefix("<!--")
            .and_then(|rest| rest.strip_suffix("-->"))
            .map(str::trim);

        match tag {
            Some("@instructions") => in_instructions = true,
            Some("@end-instructions") if in_instructions => in_instructions = false,
            // Agent markers outside an instruction block: drop just the marker line.
            Some("@agent" | "@end") if !in_instructions => {}
            // Everything inside an instruction block is discarded.
            _ if in_instructions => {}
            _ => kept.push(line),
        }
    }

    kept.join("\n")
}
/// Renders markdown to HTML with comrak, with the table, strikethrough,
/// autolink and tasklist extensions enabled.
///
/// NOTE(review): `render.unsafe_ = true` — per comrak's documentation this
/// lets raw HTML in the source pass through to the output. Confirm that
/// document authors are trusted, because the output is served as HTML.
fn render_markdown(source: &str) -> String {
let mut options = Options::default();
options.extension.table = true;
options.extension.strikethrough = true;
options.extension.autolink = true;
options.extension.tasklist = true;
options.render.unsafe_ = true;
markdown_to_html(source, &options)
}
fn render_themed_sync(title: &str, content: &str, slug: &str, theme: &str, base_url: &str, full_view: bool, expires_at: Option<String>) -> Result<Response, AppError> {
let is_dark = theme == "dark";
let highlighted = highlight::apply_syntax_highlighting(content, is_dark);
let body_empty = highlighted.trim().is_empty();
let description = plain_text_excerpt(&highlighted, 150);
let html = match theme {
"dark" => {
let t = DarkTemplate { title, content: &highlighted, slug, base_url, body_empty, expires_at, description };
t.render()
}
"paper" => {
let t = PaperTemplate { title, content: &highlighted, slug, base_url, body_empty, expires_at, description };
t.render()
}
"minimal" => {
let t = MinimalTemplate { title, content: &highlighted, slug, base_url, body_empty, expires_at, description };
t.render()
}
"hearth" => {
let t = HearthTemplate { title, content: &highlighted, slug, base_url, full_view, body_empty, expires_at, description };
t.render()
}
_ => {
let t = CleanTemplate { title, content: &highlighted, slug, base_url, full_view, body_empty, expires_at, description };
t.render()
}
};
html.map(|h| Html(h).into_response())
.map_err(|e| AppError::Internal(format!("Template render error: {e}")))
}
/// Produces a plain-text excerpt of `html`.
///
/// Everything between `<` and `>` is dropped. A closing `>` and any run of
/// spaces, tabs or line breaks collapse into a single space. The result is
/// trimmed, and when it is longer than `max_chars` characters it is cut to
/// exactly `max_chars` characters and `...` is appended. HTML entities are
/// left as they are.
fn plain_text_excerpt(html: &str, max_chars: usize) -> String {
    let mut text = String::with_capacity(html.len().min(512));
    let mut inside_tag = false;
    // False at the start and after a space was emitted, so no leading or
    // doubled spaces are produced.
    let mut may_emit_space = false;

    for ch in html.chars() {
        if ch == '<' {
            inside_tag = true;
            continue;
        }

        let closes_tag = ch == '>';
        if closes_tag {
            inside_tag = false;
        } else if inside_tag {
            continue;
        }

        if closes_tag || matches!(ch, '\n' | '\r' | '\t' | ' ') {
            if may_emit_space {
                text.push(' ');
                may_emit_space = false;
            }
        } else {
            text.push(ch);
            may_emit_space = true;
        }
    }

    let text = text.trim();
    // `nth(max_chars)` is `Some` exactly when the text has more than `max_chars` characters.
    match text.char_indices().nth(max_chars) {
        None => text.to_string(),
        Some((cut_at, _)) => format!("{}...", &text[..cut_at]),
    }
}
fn not_found_response() -> Response {
let html = r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Document not found</title>
<style>
/* twofold — error 404 page (hearth palette) */
:root {
--bg: #F5F0EB;
--fg: #2C2420;
--fg-secondary: #6B5D52;
--fg-muted: #A89888;
--border: #E8E0D8;
--border-strong: #D4C8B8;
--accent: #C4762B;
--accent-hover: #A86220;
--font-body: Charter, 'Bitstream Charter', 'Sitka Text', Cambria, serif;
--font-heading: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
--max-width: 850px;
}
*, *::before, *::after { box-sizing: border-box; }
html {
font-size: 16px;
-webkit-text-size-adjust: 100%;
text-size-adjust: 100%;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
body {
margin: 0;
padding: 0;
background: var(--bg);
color: var(--fg);
font-family: var(--font-body);
font-size: 1.0625rem;
line-height: 1.75;
border-top: 4px solid var(--accent);
min-height: 100vh;
display: flex;
flex-direction: column;
}
main {
flex: 1;
max-width: var(--max-width);
margin: 0 auto;
padding: 4rem 1.75rem 2.5rem;
width: 100%;
}
.error-code {
font-family: var(--font-heading);
font-size: 0.75rem;
font-weight: 600;
letter-spacing: 0.1em;
text-transform: uppercase;
color: var(--accent);
margin: 0 0 1rem;
}
h1 {
font-family: var(--font-heading);
font-size: 2rem;
font-weight: 800;
line-height: 1.15;
color: var(--accent);
margin: 0 0 1rem;
letter-spacing: -0.02em;
padding-bottom: 0.5rem;
border-bottom: 3px solid var(--accent);
}
p {
color: var(--fg-secondary);
margin: 0;
max-width: 36rem;
}
footer {
max-width: var(--max-width);
margin: 0 auto;
padding: 1.75rem 1.75rem 2.5rem;
text-align: center;
border-top: 1px solid var(--border-strong);
width: 100%;
}
footer::before {
content: "";
display: block;
width: 2.5rem;
height: 3px;
background: var(--accent);
margin: 0 auto 1rem;
border-radius: 2px;
}
footer small {
color: var(--fg-muted);
font-size: 0.7rem;
font-family: var(--font-heading);
letter-spacing: 0.08em;
text-transform: uppercase;
}
footer small a {
color: var(--accent);
text-decoration: underline;
text-decoration-thickness: 1px;
text-underline-offset: 2px;
transition: color 0.15s ease;
}
footer small a:hover {
color: var(--accent-hover);
}
@media (max-width: 600px) {
main { padding: 2.5rem 1rem 1.75rem; }
h1 { font-size: 1.625rem; }
}
</style>
</head>
<body>
<main>
<p class="error-code">404</p>
<h1>Document not found</h1>
<p>This document doesn't exist, or the link may be incorrect.</p>
</main>
<footer>
<small>SHARED VIA FLINT · TWOFOLD</small>
</footer>
</body>
</html>"#;
(
StatusCode::NOT_FOUND,
[(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
html,
)
.into_response()
}
fn gone_response() -> Response {
let html = r#"<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Document expired</title>
<style>
/* twofold — error 410 page (hearth palette, muted for impermanence) */
:root {
--bg: #F5F0EB;
--fg: #2C2420;
--fg-secondary: #6B5D52;
--fg-muted: #A89888;
--border: #E8E0D8;
--border-strong: #D4C8B8;
--accent: #C4762B;
--accent-hover: #A86220;
--font-body: Charter, 'Bitstream Charter', 'Sitka Text', Cambria, serif;
--font-heading: system-ui, -apple-system, BlinkMacSystemFont, 'Segoe UI', sans-serif;
--max-width: 850px;
}
*, *::before, *::after { box-sizing: border-box; }
html {
font-size: 16px;
-webkit-text-size-adjust: 100%;
text-size-adjust: 100%;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
body {
margin: 0;
padding: 0;
background: var(--bg);
color: var(--fg);
font-family: var(--font-body);
font-size: 1.0625rem;
line-height: 1.75;
/* Muted top bar — not gone, just quieter. Ember fading to ash. */
border-top: 4px solid var(--fg-muted);
min-height: 100vh;
display: flex;
flex-direction: column;
}
main {
flex: 1;
max-width: var(--max-width);
margin: 0 auto;
padding: 4rem 1.75rem 2.5rem;
width: 100%;
/* Slightly washed out — this was here but isn't anymore */
opacity: 0.85;
}
.error-code {
font-family: var(--font-heading);
font-size: 0.75rem;
font-weight: 600;
letter-spacing: 0.1em;
text-transform: uppercase;
color: var(--fg-muted);
margin: 0 0 1rem;
}
h1 {
font-family: var(--font-heading);
font-size: 2rem;
font-weight: 800;
line-height: 1.15;
/* Secondary color instead of accent — the fire has gone out */
color: var(--fg-secondary);
margin: 0 0 1rem;
letter-spacing: -0.02em;
padding-bottom: 0.5rem;
/* Muted border — a trace of what was */
border-bottom: 3px solid var(--border-strong);
}
p {
color: var(--fg-muted);
margin: 0 0 1.5rem;
max-width: 36rem;
}
.expiry-mark {
display: inline-block;
width: 2rem;
height: 2px;
background: var(--border-strong);
border-radius: 2px;
vertical-align: middle;
margin-right: 0.5rem;
opacity: 0.6;
}
footer {
max-width: var(--max-width);
margin: 0 auto;
padding: 1.75rem 1.75rem 2.5rem;
text-align: center;
border-top: 1px solid var(--border-strong);
width: 100%;
}
footer::before {
content: "";
display: block;
width: 2.5rem;
height: 3px;
/* Footer ember stays warm even when document is gone */
background: var(--accent);
margin: 0 auto 1rem;
border-radius: 2px;
opacity: 0.5;
}
footer small {
color: var(--fg-muted);
font-size: 0.7rem;
font-family: var(--font-heading);
letter-spacing: 0.08em;
text-transform: uppercase;
}
footer small a {
color: var(--accent);
text-decoration: underline;
text-decoration-thickness: 1px;
text-underline-offset: 2px;
transition: color 0.15s ease;
}
footer small a:hover {
color: var(--accent-hover);
}
@media (max-width: 600px) {
main { padding: 2.5rem 1rem 1.75rem; }
h1 { font-size: 1.625rem; }
}
</style>
</head>
<body>
<main>
<p class="error-code">410</p>
<h1>This document has expired</h1>
<p>This document was set to expire and has been removed.</p>
<p><span class="expiry-mark" aria-hidden="true"></span>The link is no longer valid.</p>
</main>
<footer>
<small>SHARED VIA FLINT · TWOFOLD</small>
</footer>
</body>
</html>"#;
(
StatusCode::GONE,
[(axum::http::header::CONTENT_TYPE, "text/html; charset=utf-8")],
html,
)
.into_response()
}
/// Reports whether `doc` carries an expiry timestamp that is already in the past.
///
/// The stored timestamp is compared lexicographically with the string returned
/// by `chrono_now()`; a document without `expires_at` never expires.
fn is_expired(doc: &DocumentRecord) -> bool {
    doc.expires_at
        .as_deref()
        .map_or(false, |expiry| expiry < chrono_now().as_str())
}
/// Returns the current UTC time formatted as `%Y-%m-%dT%H:%M:%SZ`
/// (second precision, e.g. `2024-01-01T00:00:00Z`).
pub fn chrono_now() -> String {
    let now = chrono::Utc::now();
    now.format("%Y-%m-%dT%H:%M:%SZ").to_string()
}
/// Returns the UTC timestamp that lies `seconds` in the future, formatted as
/// `%Y-%m-%dT%H:%M:%SZ`.
///
/// `_now` is ignored; the current time is always read from the system clock.
///
/// The offset is clamped so that the result never exceeds
/// `9999-12-31T23:59:59Z`. Without the clamp a very large `seconds` value
/// would wrap to a negative offset when cast to `i64` (yielding an expiry in
/// the past), could make the `chrono` duration/date arithmetic panic, or would
/// produce a year with more than four digits, which breaks the lexicographic
/// timestamp comparisons used elsewhere in this file.
fn add_seconds_to_now(_now: &str, seconds: u64) -> String {
    /// Unix timestamp of 9999-12-31T23:59:59Z, the last instant that still
    /// formats with a four-digit year.
    const MAX_TIMESTAMP: i64 = 253_402_300_799;

    let now = chrono::Utc::now();
    // Largest offset that keeps the result within the four-digit-year range.
    let headroom = MAX_TIMESTAMP.saturating_sub(now.timestamp()).max(0);
    // Values that do not fit into `i64` are treated as "as far out as allowed".
    let offset = i64::try_from(seconds).map_or(headroom, |secs| secs.min(headroom));

    let future = now + chrono::Duration::seconds(offset);
    future.format("%Y-%m-%dT%H:%M:%SZ").to_string()
}
/// Compares two byte slices for equality using `subtle`'s constant-time
/// comparison. Slices of different length are never equal.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    use subtle::ConstantTimeEq;
    a.len() == b.len() && bool::from(a.ct_eq(b))
}
fn is_unique_violation(e: &rusqlite::Error) -> bool {
matches!(
e,
rusqlite::Error::SqliteFailure(err, _) if err.code == rusqlite::ErrorCode::ConstraintViolation
)
}
/// Hashes `password` with Argon2 (default parameters) and a freshly generated
/// random salt, returning the hash in its string form.
///
/// # Errors
///
/// Returns `AppError::Internal` when hashing fails.
pub fn hash_password(password: &str) -> Result<String, AppError> {
    use argon2::{Argon2, password_hash::{SaltString, PasswordHasher, rand_core::OsRng}};

    let salt = SaltString::generate(&mut OsRng);
    match Argon2::default().hash_password(password.as_bytes(), &salt) {
        Ok(hashed) => Ok(hashed.to_string()),
        Err(e) => Err(AppError::Internal(format!("Password hashing failed: {e}"))),
    }
}
/// Checks `password` against a stored Argon2 hash string.
///
/// Returns `false` both when the password does not match and when `hash`
/// cannot be parsed as a password hash.
pub fn verify_password(password: &str, hash: &str) -> bool {
    use argon2::{Argon2, PasswordHash, PasswordVerifier};

    PasswordHash::new(hash).map_or(false, |parsed| {
        Argon2::default()
            .verify_password(password.as_bytes(), &parsed)
            .is_ok()
    })
}
/// Computes the HMAC-SHA256 tag that binds `slug` to `expiry_str` under
/// `server_secret`.
///
/// The slug is preceded by its byte length (8 bytes, big endian) inside the
/// MAC input. Without that prefix the input would be the plain concatenation
/// `slug || expiry`, so a cookie issued for slug `doc2` would also verify for
/// slug `doc` with the expiry `"2" + expiry` — a value that additionally never
/// compares as expired.
///
/// Returns `None` only if the HMAC implementation rejects the key.
fn auth_cookie_signature(slug: &str, expiry_str: &str, server_secret: &str) -> Option<Vec<u8>> {
    use hmac::{Hmac, Mac};
    use sha2::Sha256;

    let mut mac = Hmac::<Sha256>::new_from_slice(server_secret.as_bytes()).ok()?;
    // Length prefix makes the slug/expiry boundary unambiguous.
    mac.update(&(slug.len() as u64).to_be_bytes());
    mac.update(slug.as_bytes());
    mac.update(expiry_str.as_bytes());
    Some(mac.finalize().into_bytes().to_vec())
}

/// Creates the value of the password-unlock cookie for `slug`.
///
/// The value has the form `<signature>:<expiry>`, where `<expiry>` is one hour
/// from now (`%Y-%m-%dT%H:%M:%SZ`, UTC) and `<signature>` is the unpadded
/// URL-safe base64 encoding of the tag from `auth_cookie_signature`.
fn make_auth_cookie(slug: &str, server_secret: &str) -> String {
    use base64::Engine;

    let expiry = chrono::Utc::now() + chrono::Duration::hours(1);
    let expiry_str = expiry.format("%Y-%m-%dT%H:%M:%SZ").to_string();
    let signature = auth_cookie_signature(slug, &expiry_str, server_secret)
        .expect("HMAC can take key of any size");
    let sig_b64 = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(signature);
    format!("{sig_b64}:{expiry_str}")
}

/// Reports whether the request carries a valid, unexpired unlock cookie for
/// `slug`.
///
/// The cookie is looked up by the name `twofold_auth_<slug>` in the first
/// `cookie` header. It is accepted only when its expiry part is not in the
/// past and its signature matches the tag recomputed from `slug`, the expiry
/// part and `server_secret`. Any malformed input yields `false`.
fn is_password_authed(headers: &HeaderMap, slug: &str, server_secret: &str) -> bool {
    use base64::Engine;

    let cookie_name = format!("twofold_auth_{slug}");
    let Some(cookies) = headers.get("cookie").and_then(|v| v.to_str().ok()) else {
        return false;
    };

    let matching_value = cookies.split(';').map(str::trim).find_map(|pair| {
        let (name, value) = pair.split_once('=')?;
        (name == cookie_name).then_some(value)
    });
    let Some(cookie_value) = matching_value else {
        return false;
    };

    // Only the first ':' separates the parts; the expiry itself contains colons.
    let Some((sig_b64, expiry_str)) = cookie_value.split_once(':') else {
        return false;
    };

    // Timestamps share one fixed-width format, so string order is time order.
    if expiry_str < chrono_now().as_str() {
        return false;
    }

    let Some(expected_sig) = auth_cookie_signature(slug, expiry_str, server_secret) else {
        return false;
    };
    let Ok(provided_sig) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(sig_b64) else {
        return false;
    };
    constant_time_eq(&provided_sig, &expected_sig)
}
#[cfg(test)]
mod tests {
use super::*;
/// A record without an expiry timestamp is never reported as expired.
#[test]
fn test_is_expired_none() {
    let record = DocumentRecord {
        expires_at: None,
        id: "test".to_string(),
        slug: "test".to_string(),
        title: "Test".to_string(),
        raw_content: "content".to_string(),
        theme: "clean".to_string(),
        password: None,
        description: None,
        created_at: "2024-01-01T00:00:00Z".to_string(),
        updated_at: "2024-01-01T00:00:00Z".to_string(),
    };

    assert!(!is_expired(&record));
}
/// An expiry timestamp in the past marks the record as expired.
#[test]
fn test_is_expired_past() {
    let record = DocumentRecord {
        expires_at: Some("2020-01-01T00:00:00Z".to_string()),
        id: "test".to_string(),
        slug: "test".to_string(),
        title: "Test".to_string(),
        raw_content: "content".to_string(),
        theme: "clean".to_string(),
        password: None,
        description: None,
        created_at: "2024-01-01T00:00:00Z".to_string(),
        updated_at: "2024-01-01T00:00:00Z".to_string(),
    };

    assert!(is_expired(&record));
}
/// An expiry timestamp in the future leaves the record unexpired.
#[test]
fn test_is_expired_future() {
    let record = DocumentRecord {
        expires_at: Some("2099-01-01T00:00:00Z".to_string()),
        id: "test".to_string(),
        slug: "test".to_string(),
        title: "Test".to_string(),
        raw_content: "content".to_string(),
        theme: "clean".to_string(),
        password: None,
        description: None,
        created_at: "2024-01-01T00:00:00Z".to_string(),
        updated_at: "2024-01-01T00:00:00Z".to_string(),
    };

    assert!(!is_expired(&record));
}
/// A hash produced by `hash_password` verifies the original password and
/// rejects a different one.
#[test]
fn test_hash_and_verify_password() {
    let stored = hash_password("hunter2").unwrap();

    let accepts_original = verify_password("hunter2", &stored);
    assert!(accepts_original);

    let accepts_other = verify_password("wrong", &stored);
    assert!(!accepts_other);
}
use std::sync::Arc;
use axum::{
body::Body,
http::{Request, StatusCode},
Router,
routing::{get, post},
};
use tower::ServiceExt;
/// Builds a router exposing only the document API routes, backed by a fresh
/// in-memory database and authenticated with `token`.
fn test_app(token: &str) -> Router {
    let config = crate::config::ServeConfig {
        token: token.to_string(),
        db_path: ":memory:".to_string(),
        bind: "127.0.0.1:0".to_string(),
        base_url: "http://localhost".to_string(),
        default_theme: "clean".to_string(),
        max_size: 1_048_576,
        webhook_url: None,
        webhook_secret: None,
        reaper_interval: 3600,
        rate_limit_read: 1000,
        rate_limit_write: 1000,
        rate_limit_window: 60,
    };
    let limiter = crate::rate_limit::RateLimitStore::new(&config);

    let state = AppState {
        db: crate::db::Db::open(":memory:").expect("in-memory db"),
        config: Arc::new(config),
        auth_codes: Arc::new(Mutex::new(HashMap::new())),
        oauth_clients: Arc::new(Mutex::new(HashMap::new())),
        refresh_tokens: Arc::new(Mutex::new(HashMap::new())),
        access_tokens: Arc::new(Mutex::new(HashMap::new())),
        rate_limit: limiter.clone(),
    };

    let collection_routes =
        post(crate::handlers::post_document).get(crate::handlers::list_documents);
    let document_routes = get(crate::handlers::get_agent)
        .put(crate::handlers::put_document)
        .delete(crate::handlers::delete_document);

    Router::new()
        .route("/api/v1/documents", collection_routes)
        .route("/api/v1/documents/:slug", document_routes)
        .layer(axum::Extension(limiter))
        .with_state(state)
}
/// Builds a router with the document API routes plus the human-facing
/// `/:slug`, `/:slug/full` and `/:slug/unlock` routes, backed by a fresh
/// in-memory database and authenticated with `token`.
fn test_app_full(token: &str) -> Router {
    let config = crate::config::ServeConfig {
        token: token.to_string(),
        db_path: ":memory:".to_string(),
        bind: "127.0.0.1:0".to_string(),
        base_url: "http://localhost".to_string(),
        default_theme: "clean".to_string(),
        max_size: 1_048_576,
        webhook_url: None,
        webhook_secret: None,
        reaper_interval: 3600,
        rate_limit_read: 1000,
        rate_limit_write: 1000,
        rate_limit_window: 60,
    };
    let limiter = crate::rate_limit::RateLimitStore::new(&config);

    let state = AppState {
        db: crate::db::Db::open(":memory:").expect("in-memory db"),
        config: Arc::new(config),
        auth_codes: Arc::new(Mutex::new(HashMap::new())),
        oauth_clients: Arc::new(Mutex::new(HashMap::new())),
        refresh_tokens: Arc::new(Mutex::new(HashMap::new())),
        access_tokens: Arc::new(Mutex::new(HashMap::new())),
        rate_limit: limiter.clone(),
    };

    let collection_routes =
        post(crate::handlers::post_document).get(crate::handlers::list_documents);
    let document_routes = get(crate::handlers::get_agent)
        .put(crate::handlers::put_document)
        .delete(crate::handlers::delete_document);

    Router::new()
        .route("/api/v1/documents", collection_routes)
        .route("/api/v1/documents/:slug", document_routes)
        .route("/:slug/unlock", post(crate::handlers::post_unlock))
        .route("/:slug/full", get(crate::handlers::get_full))
        .route("/:slug", get(crate::handlers::get_human))
        .layer(axum::Extension(limiter))
        .with_state(state)
}
/// The human-facing route answers an unknown slug with the branded HTML 404 page.
#[tokio::test]
async fn test_human_get_nonexistent_returns_404_with_html() {
    let router = test_app_full("test-token");
    let request = Request::builder()
        .method("GET")
        .uri("/this-slug-does-not-exist")
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::NOT_FOUND);

    let content_type = response
        .headers()
        .get("content-type")
        .and_then(|value| value.to_str().ok())
        .unwrap_or("");
    assert!(
        content_type.contains("text/html"),
        "404 response should be HTML, got: {content_type}"
    );

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let page = std::str::from_utf8(&raw).unwrap();

    let mentions_not_found = page.contains("Document not found") || page.contains("not found");
    assert!(mentions_not_found, "404 body should contain 'not found' text");

    let has_branding = page.contains("FLINT") || page.contains("flint") || page.contains("twofold");
    assert!(has_branding, "404 body should contain footer branding");

    assert!(
        page.contains("<!DOCTYPE html>"),
        "404 body should be valid HTML"
    );
}
/// The human-facing route answers an expired document with the branded HTML
/// 410 page. The expired record is inserted directly into the database.
#[tokio::test]
async fn test_human_get_expired_returns_410_with_html() {
    let config = crate::config::ServeConfig {
        token: "test-token".to_string(),
        db_path: ":memory:".to_string(),
        bind: "127.0.0.1:0".to_string(),
        base_url: "http://localhost".to_string(),
        default_theme: "clean".to_string(),
        max_size: 1_048_576,
        webhook_url: None,
        webhook_secret: None,
        reaper_interval: 3600,
        rate_limit_read: 1000,
        rate_limit_write: 1000,
        rate_limit_window: 60,
    };

    // Seed the database with a document whose expiry lies far in the past.
    let db = crate::db::Db::open(":memory:").expect("in-memory db");
    let stale = crate::db::DocumentRecord {
        id: "expired-slug".to_string(),
        slug: "expired-slug".to_string(),
        title: "Expired Doc".to_string(),
        raw_content: "# Expired\nThis document has expired.".to_string(),
        theme: "clean".to_string(),
        password: None,
        description: None,
        created_at: "2020-01-01T00:00:00Z".to_string(),
        expires_at: Some("2020-06-01T00:00:00Z".to_string()),
        updated_at: "2020-01-01T00:00:00Z".to_string(),
    };
    db.insert_document(&stale).expect("insert expired doc");

    let limiter = crate::rate_limit::RateLimitStore::new(&config);
    let state = AppState {
        db,
        config: Arc::new(config),
        auth_codes: Arc::new(Mutex::new(HashMap::new())),
        oauth_clients: Arc::new(Mutex::new(HashMap::new())),
        refresh_tokens: Arc::new(Mutex::new(HashMap::new())),
        access_tokens: Arc::new(Mutex::new(HashMap::new())),
        rate_limit: limiter.clone(),
    };

    let collection_routes =
        post(crate::handlers::post_document).get(crate::handlers::list_documents);
    let document_routes = get(crate::handlers::get_agent)
        .put(crate::handlers::put_document)
        .delete(crate::handlers::delete_document);
    let router = Router::new()
        .route("/api/v1/documents", collection_routes)
        .route("/api/v1/documents/:slug", document_routes)
        .route("/:slug/unlock", post(crate::handlers::post_unlock))
        .route("/:slug/full", get(crate::handlers::get_full))
        .route("/:slug", get(crate::handlers::get_human))
        .layer(axum::Extension(limiter))
        .with_state(state);

    let request = Request::builder()
        .method("GET")
        .uri("/expired-slug")
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::GONE);

    let content_type = response
        .headers()
        .get("content-type")
        .and_then(|value| value.to_str().ok())
        .unwrap_or("");
    assert!(
        content_type.contains("text/html"),
        "410 response should be HTML, got: {content_type}"
    );

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let page = std::str::from_utf8(&raw).unwrap();

    let mentions_expiry = page.contains("expired") || page.contains("Expired");
    assert!(mentions_expiry, "410 body should contain 'expired' text");

    let has_branding = page.contains("FLINT") || page.contains("flint") || page.contains("twofold");
    assert!(has_branding, "410 body should contain footer branding");

    assert!(
        page.contains("<!DOCTYPE html>"),
        "410 body should be valid HTML"
    );
}
/// The document API answers an unknown slug with a JSON error body, not HTML.
#[tokio::test]
async fn test_api_get_nonexistent_still_returns_json_404() {
    let token = "test-token";
    let router = test_app(token);
    let request = Request::builder()
        .method("GET")
        .uri("/api/v1/documents/does-not-exist")
        .header("Authorization", format!("Bearer {token}"))
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::NOT_FOUND);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    assert_eq!(parsed["error"].as_str().unwrap(), "Not found");
}
/// An unknown slug yields the 404 page and never a password prompt.
#[tokio::test]
async fn test_nonexistent_protected_slug_returns_404_not_password_prompt() {
    let router = test_app_full("test-token");
    let request = Request::builder()
        .method("GET")
        .uri("/nonexistent-protected-slug")
        .body(Body::empty())
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::NOT_FOUND);

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let page = std::str::from_utf8(&raw).unwrap();

    let shows_password_field = page.contains(r#"type="password""#);
    assert!(
        !shows_password_field,
        "nonexistent slug should not show password prompt"
    );

    let mentions_not_found = page.contains("not found") || page.contains("Not found");
    assert!(mentions_not_found, "should contain not found message");
}
/// Publishes `content` under `slug` through `POST /api/v1/documents`, asserts
/// that the server answers `201 Created`, and returns the slug from the JSON
/// response.
async fn publish_doc(app: Router, token: &str, slug: &str, content: &str) -> String {
    let markdown = format!("---\nslug: {slug}\n---\n{content}");
    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from(markdown))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::CREATED, "publish failed");

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    parsed["slug"].as_str().unwrap().to_string()
}
/// `PUT` replaces the content of an existing document; a subsequent `GET`
/// returns the new body and no longer the old one.
#[tokio::test]
async fn test_put_updates_existing_document() {
    let token = "test-token";
    let router = test_app(token);

    let slug = publish_doc(router.clone(), token, "my-slug", "# Original\nOriginal content.").await;
    assert_eq!(slug, "my-slug");

    // Overwrite the document.
    let update = Request::builder()
        .method("PUT")
        .uri(format!("/api/v1/documents/{slug}"))
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from("# Updated\nUpdated content."))
        .unwrap();
    let update_response = router.clone().oneshot(update).await.unwrap();
    assert_eq!(update_response.status(), StatusCode::OK);
    let payload = axum::body::to_bytes(update_response.into_body(), usize::MAX).await.unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    assert_eq!(parsed["slug"].as_str().unwrap(), "my-slug");

    // Read it back.
    let fetch = Request::builder()
        .method("GET")
        .uri(format!("/api/v1/documents/{slug}"))
        .header("Authorization", format!("Bearer {token}"))
        .body(Body::empty())
        .unwrap();
    let fetch_response = router.oneshot(fetch).await.unwrap();
    assert_eq!(fetch_response.status(), StatusCode::OK);
    let fetched = axum::body::to_bytes(fetch_response.into_body(), usize::MAX).await.unwrap();
    let document = std::str::from_utf8(&fetched).unwrap();
    assert!(document.contains("Updated content."), "content should reflect PUT body");
    assert!(!document.contains("Original content."), "old content should be gone");
}
/// `PUT` on a slug that was never published answers `404 Not Found`.
#[tokio::test]
async fn test_put_returns_404_for_nonexistent_slug() {
    let token = "test-token";
    let router = test_app(token);
    let request = Request::builder()
        .method("PUT")
        .uri("/api/v1/documents/does-not-exist")
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from("# Content"))
        .unwrap();

    let response = router.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::NOT_FOUND);
}
/// `PUT` without an `Authorization` header answers `401 Unauthorized`.
#[tokio::test]
async fn test_put_requires_auth() {
    let router = test_app("test-token");
    let request = Request::builder()
        .method("PUT")
        .uri("/api/v1/documents/anything")
        .header("Content-Type", "text/markdown")
        .body(Body::from("# Content"))
        .unwrap();

    let response = router.oneshot(request).await.unwrap();

    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}
/// A `title` given in the frontmatter of a `PUT` body is reported back in the
/// JSON response.
#[tokio::test]
async fn test_put_updates_title_from_frontmatter() {
    let token = "test-token";
    let router = test_app(token);
    publish_doc(router.clone(), token, "title-test", "# Old Title\nBody.").await;

    let replacement = "---\ntitle: New Title\n---\n# New Title\nBody.";
    let request = Request::builder()
        .method("PUT")
        .uri("/api/v1/documents/title-test")
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from(replacement))
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    assert_eq!(parsed["title"].as_str().unwrap(), "New Title");
}
/// The JSON response of a successful `PUT` carries the slug and a
/// `created_at` field.
#[tokio::test]
async fn test_put_response_is_well_formed() {
    let token = "test-token";
    let router = test_app(token);
    let slug = publish_doc(router.clone(), token, "ts-test", "# V1").await;

    let request = Request::builder()
        .method("PUT")
        .uri(format!("/api/v1/documents/{slug}"))
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from("# V2"))
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    assert_eq!(parsed["slug"].as_str().unwrap(), slug);
    assert!(parsed.get("created_at").is_some(), "response should include created_at");
}
/// A `slug` in the frontmatter of a `PUT` body does not rename the document;
/// the response keeps the slug from the URL.
#[tokio::test]
async fn test_put_does_not_change_slug() {
    let token = "test-token";
    let router = test_app(token);
    publish_doc(router.clone(), token, "original-slug", "# Doc").await;

    let body_with_other_slug = "---\nslug: different-slug\n---\n# Doc";
    let request = Request::builder()
        .method("PUT")
        .uri("/api/v1/documents/original-slug")
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from(body_with_other_slug))
        .unwrap();

    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    assert_eq!(parsed["slug"].as_str().unwrap(), "original-slug");
}
/// Publishes a document with the body `Secret content.` under `slug`, with
/// `password` set in its frontmatter, asserts that the server answers
/// `201 Created`, and returns the slug from the JSON response.
async fn publish_protected_doc(app: Router, token: &str, slug: &str, password: &str) -> String {
    let markdown = format!("---\nslug: {slug}\npassword: {password}\n---\nSecret content.");
    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from(markdown))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::CREATED, "publish protected doc failed");

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    parsed["slug"].as_str().unwrap().to_string()
}
/// Fetching a protected document through the API with the correct
/// `password` query parameter returns its content.
#[tokio::test]
async fn test_agent_get_protected_doc_correct_password_returns_content() {
    let token = "test-token";
    let router = test_app(token);
    let slug = publish_protected_doc(router.clone(), token, "pw-correct", "hunter2").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/api/v1/documents/{slug}?password=hunter2"))
        .header("Authorization", format!("Bearer {token}"))
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let document = std::str::from_utf8(&raw).unwrap();
    assert!(document.contains("Secret content."), "body should contain document content");
}
/// Fetching a protected document with a wrong `password` query parameter
/// answers `401` with the error `Invalid password`.
#[tokio::test]
async fn test_agent_get_protected_doc_wrong_password_returns_401() {
    let token = "test-token";
    let router = test_app(token);
    let slug = publish_protected_doc(router.clone(), token, "pw-wrong", "hunter2").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/api/v1/documents/{slug}?password=wrongpass"))
        .header("Authorization", format!("Bearer {token}"))
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    assert_eq!(parsed["error"].as_str().unwrap(), "Invalid password");
}
/// Fetching a protected document without a password answers `401` with the
/// error `Password required`.
#[tokio::test]
async fn test_agent_get_protected_doc_no_password_returns_401() {
    let token = "test-token";
    let router = test_app(token);
    let slug = publish_protected_doc(router.clone(), token, "pw-none", "hunter2").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/api/v1/documents/{slug}"))
        .header("Authorization", format!("Bearer {token}"))
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

    let payload = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
    assert_eq!(parsed["error"].as_str().unwrap(), "Password required");
}
/// A document published without a password is served through the API without
/// any password parameter.
#[tokio::test]
async fn test_agent_get_unprotected_doc_works_without_password() {
    let token = "test-token";
    let router = test_app(token);
    let slug = publish_doc(router.clone(), token, "no-pw-doc", "# Public\nOpen content.").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/api/v1/documents/{slug}"))
        .header("Authorization", format!("Bearer {token}"))
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX).await.unwrap();
    let document = std::str::from_utf8(&raw).unwrap();
    assert!(document.contains("Open content."), "unprotected doc should be served without password");
}
/// Builds a document-API router whose database already contains one active
/// managed token, and returns the router together with that token's plain
/// value. The admin token of the configuration is `admin-token`.
fn test_app_with_managed_token() -> (Router, String) {
    use crate::db::{Db, TokenRecord};
    use crate::config::ServeConfig;

    let managed_plain = "tf_AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA";

    // Store the token the way the server expects it: hash plus 8-char prefix.
    let db = Db::open(":memory:").expect("in-memory db");
    let token_row = TokenRecord {
        id: "test-managed-id".to_string(),
        name: "test-managed".to_string(),
        hash: hash_password(managed_plain).expect("hash"),
        created_at: "2024-01-01T00:00:00Z".to_string(),
        last_used: None,
        revoked: false,
        prefix: Some(managed_plain.chars().take(8).collect::<String>()),
    };
    db.insert_token(&token_row).expect("insert managed token");

    let config = ServeConfig {
        token: "admin-token".to_string(),
        db_path: ":memory:".to_string(),
        bind: "127.0.0.1:0".to_string(),
        base_url: "http://localhost".to_string(),
        default_theme: "clean".to_string(),
        max_size: 1_048_576,
        webhook_url: None,
        webhook_secret: None,
        reaper_interval: 3600,
        rate_limit_read: 10000,
        rate_limit_write: 10000,
        rate_limit_window: 60,
    };
    let limiter = crate::rate_limit::RateLimitStore::new(&config);
    let state = AppState {
        db,
        config: Arc::new(config),
        auth_codes: Arc::new(Mutex::new(HashMap::new())),
        oauth_clients: Arc::new(Mutex::new(HashMap::new())),
        refresh_tokens: Arc::new(Mutex::new(HashMap::new())),
        access_tokens: Arc::new(Mutex::new(HashMap::new())),
        rate_limit: limiter.clone(),
    };

    let collection_routes =
        post(crate::handlers::post_document).get(crate::handlers::list_documents);
    let document_routes = get(crate::handlers::get_agent)
        .put(crate::handlers::put_document)
        .delete(crate::handlers::delete_document);
    let router = Router::new()
        .route("/api/v1/documents", collection_routes)
        .route("/api/v1/documents/:slug", document_routes)
        .layer(axum::Extension(limiter))
        .with_state(state);

    (router, managed_plain.to_string())
}
/// A valid managed token authenticates a `GET` on a document that was
/// published with the admin token.
#[tokio::test]
async fn test_managed_token_auth_accepted() {
    let (router, managed_token) = test_app_with_managed_token();
    let slug = publish_doc(router.clone(), "admin-token", "managed-test", "# Hello").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/api/v1/documents/{slug}"))
        .header("Authorization", format!("Bearer {managed_token}"))
        .body(Body::empty())
        .unwrap();
    let response = router.oneshot(request).await.unwrap();

    assert_eq!(
        response.status(),
        StatusCode::OK,
        "managed token should be accepted by prefix lookup + argon2 verify"
    );
}
/// A token that shares the first eight characters with a managed token but
/// differs afterwards is rejected with `401`.
#[tokio::test]
async fn test_managed_token_wrong_value_rejected() {
    let (router, managed_token) = test_app_with_managed_token();
    let forged = format!("{}X_WRONG", &managed_token[..8]);

    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", format!("Bearer {forged}"))
        .header("content-type", "text/markdown")
        .body(Body::from("# Test"))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();

    assert_eq!(
        response.status(),
        StatusCode::UNAUTHORIZED,
        "token with matching prefix but wrong value must be rejected"
    );
}
/// The configured admin token keeps working while managed tokens exist in
/// the database.
#[tokio::test]
async fn test_admin_token_still_works_with_managed_tokens_present() {
    let (router, _managed_token) = test_app_with_managed_token();

    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", "Bearer admin-token")
        .header("content-type", "text/markdown")
        .body(Body::from("# Admin Test"))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();

    assert_eq!(
        response.status(),
        StatusCode::CREATED,
        "admin TWOFOLD_TOKEN must still work when managed tokens exist"
    );
}
/// A write request without any `Authorization` header is rejected with `401`.
#[tokio::test]
async fn test_no_token_returns_401() {
    let (router, _managed_token) = test_app_with_managed_token();

    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("content-type", "text/markdown")
        .body(Body::from("# No Auth"))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();

    assert_eq!(
        response.status(),
        StatusCode::UNAUTHORIZED,
        "missing token must return 401"
    );
}
/// A managed token stored with `revoked: true` does not authenticate, even
/// when its plain value is presented correctly.
#[tokio::test]
async fn test_revoked_managed_token_rejected() {
    use crate::db::{Db, TokenRecord};
    use crate::config::ServeConfig;

    let managed_plain = "tf_BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB";

    // Seed the database with a token that has already been revoked.
    let db = Db::open(":memory:").expect("in-memory db");
    let token_row = TokenRecord {
        id: "revoked-id".to_string(),
        name: "revoked-token".to_string(),
        hash: hash_password(managed_plain).expect("hash"),
        created_at: "2024-01-01T00:00:00Z".to_string(),
        last_used: None,
        revoked: true,
        prefix: Some(managed_plain.chars().take(8).collect::<String>()),
    };
    db.insert_token(&token_row).expect("insert revoked token");

    let config = ServeConfig {
        token: "admin-token".to_string(),
        db_path: ":memory:".to_string(),
        bind: "127.0.0.1:0".to_string(),
        base_url: "http://localhost".to_string(),
        default_theme: "clean".to_string(),
        max_size: 1_048_576,
        webhook_url: None,
        webhook_secret: None,
        reaper_interval: 3600,
        rate_limit_read: 10000,
        rate_limit_write: 10000,
        rate_limit_window: 60,
    };
    let limiter = crate::rate_limit::RateLimitStore::new(&config);
    let state = AppState {
        db,
        config: Arc::new(config),
        auth_codes: Arc::new(Mutex::new(HashMap::new())),
        oauth_clients: Arc::new(Mutex::new(HashMap::new())),
        refresh_tokens: Arc::new(Mutex::new(HashMap::new())),
        access_tokens: Arc::new(Mutex::new(HashMap::new())),
        rate_limit: limiter.clone(),
    };

    let collection_routes =
        post(crate::handlers::post_document).get(crate::handlers::list_documents);
    let router = Router::new()
        .route("/api/v1/documents", collection_routes)
        .layer(axum::Extension(limiter))
        .with_state(state);

    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", format!("Bearer {managed_plain}"))
        .header("content-type", "text/markdown")
        .body(Body::from("# Revoked Test"))
        .unwrap();
    let response = router.oneshot(request).await.unwrap();

    assert_eq!(
        response.status(),
        StatusCode::UNAUTHORIZED,
        "revoked token must not authenticate"
    );
}
/// Publishes `content` with the requested `slug` and returns the slug reported
/// in the JSON response.
///
/// The slug is passed to the server through a front-matter block prepended to
/// `content`. Panics if the request cannot be built or sent, if the status is
/// not 201, or if the response body is not JSON with a string `slug` field.
async fn publish_doc_full(app: Router, token: &str, slug: &str, content: &str) -> String {
    let document = format!("---\nslug: {slug}\n---\n{content}");
    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from(document))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::CREATED, "publish failed");

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let payload = serde_json::from_slice::<serde_json::Value>(&raw).unwrap();
    payload["slug"].as_str().unwrap().to_owned()
}
/// A browser-style `Accept` header that lists `text/html` yields an HTML response.
#[tokio::test]
async fn test_content_neg_html_accept_returns_html() {
    let token = "test-token";
    let app = test_app_full(token);
    let slug = publish_doc_full(app.clone(), token, "cn-html", "# Hello").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/{slug}"))
        .header("Accept", "text/html,application/xhtml+xml,*/*;q=0.9")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let header_value = response.headers().get("content-type").unwrap();
    let ct = header_value.to_str().unwrap();
    assert!(ct.contains("text/html"), "expected HTML, got {ct}");
}
/// `Accept: application/json` yields a JSON body that carries the document's
/// slug and content.
#[tokio::test]
async fn test_content_neg_json_accept_returns_json() {
    let token = "test-token";
    let app = test_app_full(token);
    let slug = publish_doc_full(app.clone(), token, "cn-json", "# Hello\n\nAgent content.").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/{slug}"))
        .header("Accept", "application/json")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let header_value = response.headers().get("content-type").unwrap();
    let ct = header_value.to_str().unwrap();
    assert!(ct.contains("application/json"), "expected JSON, got {ct}");

    // Decode the body and check the fields this test relies on.
    let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let payload = serde_json::from_slice::<serde_json::Value>(&raw).unwrap();
    let reported_slug = payload["slug"].as_str().unwrap();
    assert_eq!(reported_slug, "cn-json");
    let reported_content = payload["content"].as_str().unwrap();
    assert!(reported_content.contains("Hello"));
}
/// `Accept: text/markdown` yields a markdown content type and a body that
/// contains the published text.
#[tokio::test]
async fn test_content_neg_markdown_accept_returns_markdown() {
    let token = "test-token";
    let app = test_app_full(token);
    let slug = publish_doc_full(app.clone(), token, "cn-md-accept", "# Markdown test").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/{slug}"))
        .header("Accept", "text/markdown")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let header_value = response.headers().get("content-type").unwrap();
    let ct = header_value.to_str().unwrap();
    assert!(ct.contains("text/markdown"), "expected markdown, got {ct}");

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = std::str::from_utf8(&raw).unwrap();
    assert!(text.contains("Markdown test"), "expected raw markdown in body");
}
/// With no `Accept` header and a `GPTBot/1.0` user agent, the response is JSON.
#[tokio::test]
async fn test_content_neg_bot_ua_returns_json() {
    let token = "test-token";
    let app = test_app_full(token);
    let slug = publish_doc_full(app.clone(), token, "cn-bot", "# Bot content").await;

    // Only the User-Agent identifies the client; no Accept header is sent.
    let request = Request::builder()
        .method("GET")
        .uri(format!("/{slug}"))
        .header("User-Agent", "GPTBot/1.0")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let header_value = response.headers().get("content-type").unwrap();
    let ct = header_value.to_str().unwrap();
    assert!(
        ct.contains("application/json"),
        "expected JSON for bot UA, got {ct}"
    );
}
/// An explicit `Accept: text/html` wins over a bot-looking user agent
/// (`ClaudeBot/1.0`): the response must still be HTML.
#[tokio::test]
async fn test_content_neg_html_accept_beats_bot_ua() {
    let token = "test-token";
    let app = test_app_full(token);
    let slug = publish_doc_full(app.clone(), token, "cn-ua-html", "# Dev inspect").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/{slug}"))
        .header("Accept", "text/html")
        .header("User-Agent", "ClaudeBot/1.0")
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let header_value = response.headers().get("content-type").unwrap();
    let ct = header_value.to_str().unwrap();
    assert!(
        ct.contains("text/html"),
        "Accept: text/html should beat bot UA, got {ct}"
    );
}
/// Requesting `/{slug}.md` returns a markdown content type and a body that
/// contains the published text, without any `Accept` header.
#[tokio::test]
async fn test_slug_md_route_returns_markdown() {
    let token = "test-token";
    let app = test_app_full(token);
    let slug = publish_doc_full(app.clone(), token, "cn-dotmd", "# Dotmd test").await;

    let request = Request::builder()
        .method("GET")
        .uri(format!("/{slug}.md"))
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let header_value = response.headers().get("content-type").unwrap();
    let ct = header_value.to_str().unwrap();
    assert!(
        ct.contains("text/markdown"),
        "expected markdown content-type, got {ct}"
    );

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let text = std::str::from_utf8(&raw).unwrap();
    assert!(text.contains("Dotmd test"), "expected raw markdown in body");
}
fn make_test_config(token: &str) -> crate::config::ServeConfig {
crate::config::ServeConfig {
token: token.to_string(),
db_path: ":memory:".to_string(),
bind: "127.0.0.1:0".to_string(),
base_url: "http://localhost".to_string(),
default_theme: "clean".to_string(),
max_size: 1_048_576,
webhook_url: None,
webhook_secret: None,
reaper_interval: 3600,
rate_limit_read: 1000,
rate_limit_write: 1000,
rate_limit_window: 60,
}
}
/// Creates an `AppState` backed by a fresh in-memory database.
///
/// Returns the state together with a handle to the same `Db` value (the state
/// holds a clone of it), so callers can inspect the database after driving
/// requests through a router. The OAuth-related maps start empty.
fn make_test_state(token: &str) -> (AppState, crate::db::Db) {
    let db = crate::db::Db::open(":memory:").expect("in-memory db");
    let config = make_test_config(token);
    // Must be built before `config` is moved into the `Arc` below.
    let rate_limit = crate::rate_limit::RateLimitStore::new(&config);
    let state = AppState {
        db: db.clone(),
        config: Arc::new(config),
        auth_codes: Arc::new(Mutex::new(HashMap::new())),
        oauth_clients: Arc::new(Mutex::new(HashMap::new())),
        refresh_tokens: Arc::new(Mutex::new(HashMap::new())),
        access_tokens: Arc::new(Mutex::new(HashMap::new())),
        // Moved rather than cloned: the local is not used again.
        rate_limit,
    };
    (state, db)
}
/// Builds a router with the document CRUD routes and the `/api/v1/audit`
/// route, backed by a fresh test state for `token`.
fn test_app_with_audit(token: &str) -> Router {
    let (state, _db) = make_test_state(token);
    let limiter = state.rate_limit.clone();

    let collection_routes =
        post(crate::handlers::post_document).get(crate::handlers::list_documents);
    let item_routes = get(crate::handlers::get_agent)
        .put(crate::handlers::put_document)
        .delete(crate::handlers::delete_document);

    Router::new()
        .route("/api/v1/documents", collection_routes)
        .route("/api/v1/documents/:slug", item_routes)
        .route("/api/v1/audit", get(crate::handlers::list_audit))
        .layer(axum::Extension(limiter))
        .with_state(state)
}
/// Builds a router with the document CRUD routes and the `/api/v1/audit`
/// route, and also returns the `Db` handle produced by `make_test_state` so a
/// test can query the database directly after issuing requests.
fn test_app_with_db(token: &str) -> (Router, crate::db::Db) {
    let (state, db) = make_test_state(token);
    let limiter = state.rate_limit.clone();

    let collection_routes =
        post(crate::handlers::post_document).get(crate::handlers::list_documents);
    let item_routes = get(crate::handlers::get_agent)
        .put(crate::handlers::put_document)
        .delete(crate::handlers::delete_document);

    let app = Router::new()
        .route("/api/v1/documents", collection_routes)
        .route("/api/v1/documents/:slug", item_routes)
        .route("/api/v1/audit", get(crate::handlers::list_audit))
        .layer(axum::Extension(limiter))
        .with_state(state);

    (app, db)
}
/// Round-trips one audit entry through the database layer: a fresh database
/// lists nothing, and after one insert the listing returns that entry with
/// its fields intact.
#[test]
fn test_db_insert_and_list_audit_entries() {
    let db = crate::db::Db::open(":memory:").expect("in-memory db");

    // A new database has no audit rows.
    let (initial, initial_total) = db.list_audit_entries(20, 0).expect("list ok");
    assert_eq!(initial_total, 0);
    assert!(initial.is_empty());

    let sample = crate::db::AuditEntry {
        id: "test001".to_string(),
        timestamp: "2026-05-12T14:00:00Z".to_string(),
        action: "create".to_string(),
        slug: "my-doc".to_string(),
        token_name: "admin".to_string(),
        ip_address: "127.0.0.1".to_string(),
    };
    db.insert_audit_entry(&sample).expect("insert ok");

    let (listed, listed_total) = db.list_audit_entries(20, 0).expect("list ok");
    assert_eq!(listed_total, 1);
    assert_eq!(listed.len(), 1);

    let stored = &listed[0];
    assert_eq!(stored.action, "create");
    assert_eq!(stored.slug, "my-doc");
    assert_eq!(stored.token_name, "admin");
    assert_eq!(stored.ip_address, "127.0.0.1");
}
/// The token the app was configured with authenticates a document POST.
///
/// Only the 201 status is asserted; the identity resolved for the token is
/// not inspected here.
#[tokio::test]
async fn test_check_auth_returns_admin_for_master_token() {
    let token = "master-secret-token";
    let app = test_app(token);

    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from("# Test Doc\nContent."))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    let status = response.status();
    assert_eq!(status, StatusCode::CREATED, "master token should authenticate");
}
/// The managed token returned by `test_app_with_managed_token` authenticates
/// a document POST.
///
/// Only the 201 status is asserted; the token name resolved during
/// authentication is not inspected here.
#[tokio::test]
async fn test_check_auth_returns_token_name_for_managed_token() {
    let (app, managed_plain) = test_app_with_managed_token();

    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", format!("Bearer {managed_plain}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from("# Managed Token Test\nContent."))
        .unwrap();

    let response = app.oneshot(request).await.unwrap();
    let status = response.status();
    assert_eq!(status, StatusCode::CREATED, "managed token should authenticate");
}
/// An authenticated GET on `/api/v1/audit` against a fresh app returns 200
/// with `entries`, `total`, `limit` and `offset` fields, where `total` is 0
/// and `entries` is an empty array.
#[tokio::test]
async fn test_list_audit_returns_200_with_correct_shape() {
    let token = "test-token";
    let app = test_app_with_audit(token);

    let request = Request::builder()
        .method("GET")
        .uri("/api/v1/audit")
        .header("Authorization", format!("Bearer {token}"))
        .body(Body::empty())
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::OK);

    let raw = axum::body::to_bytes(response.into_body(), usize::MAX)
        .await
        .unwrap();
    let payload = serde_json::from_slice::<serde_json::Value>(&raw).unwrap();

    // Every top-level field must be present, checked in a fixed order.
    for field in ["entries", "total", "limit", "offset"] {
        assert!(
            payload.get(field).is_some(),
            "response must have '{field}' field"
        );
    }

    let total = payload["total"].as_u64().unwrap();
    assert_eq!(total, 0);
    let entries = payload["entries"].as_array().unwrap();
    assert!(entries.is_empty());
}
/// A GET on `/api/v1/audit` without an `Authorization` header is answered with 401.
#[tokio::test]
async fn test_list_audit_requires_auth() {
    let token = "test-token";
    let app = test_app_with_audit(token);

    // The token configures the app but is deliberately not sent.
    let unauthenticated = Request::builder()
        .method("GET")
        .uri("/api/v1/audit")
        .body(Body::empty())
        .unwrap();

    let response = app.oneshot(unauthenticated).await.unwrap();
    let status = response.status();
    assert_eq!(
        status,
        StatusCode::UNAUTHORIZED,
        "audit endpoint must require auth"
    );
}
/// Creating a document records exactly one audit entry with action `create`,
/// the document's slug, and token name `admin`.
#[tokio::test]
async fn test_post_document_writes_audit_entry() {
    let token = "test-token";
    let (app, db) = test_app_with_db(token);

    let request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from(
            "---\nslug: audit-test-create\n---\n# Audit Test\nContent.",
        ))
        .unwrap();
    let response = app.oneshot(request).await.unwrap();
    assert_eq!(response.status(), StatusCode::CREATED, "publish failed");

    // Read the audit log through the database handle, not the HTTP API.
    let (logged, count) = db.list_audit_entries(20, 0).expect("list ok");
    assert_eq!(count, 1, "should have 1 audit entry after create");

    let created = &logged[0];
    assert_eq!(created.action, "create");
    assert_eq!(created.slug, "audit-test-create");
    assert_eq!(created.token_name, "admin");
}
/// Creating and then deleting a document leaves two audit entries, one of
/// which has action `delete`, the document's slug, and token name `admin`.
#[tokio::test]
async fn test_delete_document_writes_audit_entry() {
    let token = "test-token";
    let (app, db) = test_app_with_db(token);

    // Step 1: publish the document that will be removed.
    let create_request = Request::builder()
        .method("POST")
        .uri("/api/v1/documents")
        .header("Authorization", format!("Bearer {token}"))
        .header("Content-Type", "text/markdown")
        .body(Body::from("---\nslug: to-delete\n---\n# Delete Me"))
        .unwrap();
    let create_response = app.clone().oneshot(create_request).await.unwrap();
    assert_eq!(create_response.status(), StatusCode::CREATED);

    // Step 2: delete it by slug.
    let delete_request = Request::builder()
        .method("DELETE")
        .uri("/api/v1/documents/to-delete")
        .header("Authorization", format!("Bearer {token}"))
        .body(Body::empty())
        .unwrap();
    let delete_response = app.oneshot(delete_request).await.unwrap();
    assert_eq!(delete_response.status(), StatusCode::NO_CONTENT);

    // Step 3: both operations must appear in the audit log.
    let (logged, count) = db.list_audit_entries(20, 0).expect("list ok");
    assert_eq!(count, 2, "should have 2 audit entries (create + delete)");

    let deletion = logged
        .iter()
        .find(|candidate| candidate.action == "delete")
        .expect("delete entry");
    assert_eq!(deletion.slug, "to-delete");
    assert_eq!(deletion.token_name, "admin");
}
/// When `x-forwarded-for` holds several addresses, the first one is returned
/// even though a socket address is also supplied.
#[test]
fn test_extract_client_ip_xff_priority() {
    let mut forwarded = HeaderMap::new();
    forwarded.insert("x-forwarded-for", "10.0.0.1, 192.168.1.1".parse().unwrap());

    let socket_addr = Some("127.0.0.1:12345");
    let resolved = extract_client_ip(&forwarded, socket_addr);
    assert_eq!(resolved, "10.0.0.1", "XFF first value should take priority");
}
/// With no headers, the socket address is used and its port is stripped.
#[test]
fn test_extract_client_ip_fallback_to_socket() {
    let no_headers = HeaderMap::new();
    let socket_addr = Some("1.2.3.4:5678");

    let resolved = extract_client_ip(&no_headers, socket_addr);
    assert_eq!(resolved, "1.2.3.4", "should strip port and return bare IP");
}
/// With neither headers nor a socket address, the result is the literal `unknown`.
#[test]
fn test_extract_client_ip_unknown() {
    let no_headers = HeaderMap::new();

    let resolved = extract_client_ip(&no_headers, None);
    assert_eq!(resolved, "unknown");
}
}