mod cli;
mod config;
mod db;
mod handlers;
mod highlight;
mod mcp;
mod mcp_http;
mod oauth;
mod parser;
mod rate_limit;
mod webhook;
use std::sync::Arc;
use axum::{
extract::DefaultBodyLimit,
routing::{get, post},
Router,
};
use clap::Parser;
use axum::http::HeaderValue;
use tower::Layer;
use tower_http::{normalize_path::NormalizePathLayer, set_header::SetResponseHeaderLayer, trace::TraceLayer};
use tracing_subscriber::EnvFilter;
use cli::{Cli, Commands, TokenAction};
use config::ServeConfig;
use db::Db;
use handlers::AppState;
use rate_limit::RateLimitStore;
/// CLI entry point: parse the command line and dispatch to the chosen subcommand.
///
/// Only `serve` needs an async runtime, so a multi-threaded Tokio runtime is
/// built on demand for it; every other subcommand runs synchronously.
fn main() {
    match Cli::parse().command {
        Commands::Serve => {
            let runtime = tokio::runtime::Builder::new_multi_thread()
                .enable_all()
                .build()
                .expect("Failed to build Tokio runtime");
            runtime.block_on(run_server());
        }
        Commands::Mcp => mcp::run_mcp_server(),
        Commands::Publish(args) => run_publish(args),
        Commands::List(args) => run_list(args),
        Commands::Delete(args) => run_delete(args),
        Commands::Audit(args) => run_audit(args),
        Commands::Token(args) => run_token(args),
    }
}
async fn run_server() {
tracing_subscriber::fmt()
.with_env_filter(
EnvFilter::from_default_env()
.add_directive("twofold=info".parse().unwrap()),
)
.init();
let config = match ServeConfig::from_env() {
Ok(c) => c,
Err(e) => {
eprintln!("Configuration error: {e}");
std::process::exit(1);
}
};
let db = match Db::open(&config.db_path) {
Ok(d) => d,
Err(e) => {
eprintln!("Failed to open database '{}': {e}", config.db_path);
std::process::exit(1);
}
};
let max_size = config.max_size;
let bind_addr = config.bind.clone();
let reaper_interval = config.reaper_interval;
let rate_limit = RateLimitStore::new(&config);
let state = AppState {
db: db.clone(),
config: Arc::new(config),
auth_codes: std::sync::Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
oauth_clients: std::sync::Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
refresh_tokens: std::sync::Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
access_tokens: std::sync::Arc::new(std::sync::Mutex::new(std::collections::HashMap::new())),
rate_limit: rate_limit.clone(),
};
let reaper_db = db.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(
std::time::Duration::from_secs(reaper_interval),
);
loop {
interval.tick().await;
let now = handlers::chrono_now();
match reaper_db.delete_expired_older_than(&now, 30) {
Ok(count) if count > 0 => {
tracing::info!(count, "Reaper garbage-collected tombstones older than 30 days");
}
Ok(_) => {} Err(e) => {
tracing::error!(error = %e, "Reaper failed to garbage-collect expired documents");
}
}
}
});
let csp = HeaderValue::from_static(
"default-src 'self'; script-src 'unsafe-inline'; style-src 'unsafe-inline'",
);
let app = Router::new()
.route("/health", get(handlers::health_check))
.route("/.well-known/oauth-protected-resource", get(oauth::handle_protected_resource_metadata))
.route("/.well-known/oauth-authorization-server", get(oauth::handle_authorization_server_metadata))
.route("/oauth/register", post(oauth::handle_register))
.route("/authorize", get(oauth::handle_authorize))
.route("/oauth/token", post(oauth::handle_oauth_token))
.route("/api/v1/documents", post(handlers::post_document).get(handlers::list_documents))
.route("/api/v1/audit", get(handlers::list_audit))
.route("/api/v1/documents/:slug", get(handlers::get_agent)
.put(handlers::put_document)
.delete(handlers::delete_document))
.route("/api/v1/openapi.yaml", get(handlers::serve_openapi_yaml))
.route("/api/v1/openapi.json", get(handlers::serve_openapi_json))
.route("/icon.png", get(handlers::serve_icon))
.route("/favicon.ico", get(handlers::serve_favicon))
.route("/:slug/unlock", post(handlers::post_unlock))
.route("/:slug/full", get(handlers::get_full))
.route("/:slug", get(handlers::get_human))
.layer(SetResponseHeaderLayer::overriding(
axum::http::header::CONTENT_SECURITY_POLICY,
csp,
))
.layer(DefaultBodyLimit::max(max_size));
let mcp_router = Router::new()
.route("/mcp", post(mcp_http::handle_mcp_post))
.layer(DefaultBodyLimit::disable());
let app = app
.merge(mcp_router)
.layer(TraceLayer::new_for_http())
.layer(axum::Extension(rate_limit))
.with_state(state);
let app = NormalizePathLayer::trim_trailing_slash().layer(app);
let app = axum::ServiceExt::<axum::http::Request<axum::body::Body>>::into_make_service_with_connect_info::<std::net::SocketAddr>(app);
let listener = match tokio::net::TcpListener::bind(&bind_addr).await {
Ok(l) => l,
Err(e) => {
eprintln!("Failed to bind to '{bind_addr}': {e}");
std::process::exit(1);
}
};
println!("twofold listening on http://{bind_addr}");
if let Err(e) = axum::serve(listener, app).await {
eprintln!("Server error: {e}");
std::process::exit(1);
}
}
fn run_publish(args: cli::PublishArgs) {
let token = resolve_token(args.token);
let content = read_publish_source(&args.path);
let body = apply_publish_flags(content, args.title, args.slug, args.theme, args.expiry, args.password);
let url = format!("{}/api/v1/documents", args.server.trim_end_matches('/'));
let client = match reqwest::blocking::Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
{
Ok(c) => c,
Err(e) => {
eprintln!("Failed to create HTTP client: {e}");
std::process::exit(1);
}
};
let response = match client
.post(&url)
.header("Authorization", format!("Bearer {token}"))
.header("Content-Type", "text/markdown")
.body(body)
.send()
{
Ok(r) => r,
Err(e) => {
eprintln!("Request failed: {e}");
std::process::exit(1);
}
};
let status = response.status();
if status == reqwest::StatusCode::CREATED {
let body: serde_json::Value = match response.json() {
Ok(v) => v,
Err(e) => {
eprintln!("Failed to parse server response: {e}");
std::process::exit(1);
}
};
if let Some(doc_url) = body.get("url").and_then(|v| v.as_str()) {
println!("{doc_url}");
} else {
eprintln!("Server returned 201 but no `url` field in response.");
std::process::exit(1);
}
} else {
let body_text = response.text().unwrap_or_default();
eprintln!("Publish failed: HTTP {status}\n{body_text}");
std::process::exit(1);
}
}
/// Fold the optional publish flags into the document's YAML frontmatter.
///
/// - When no flag is set, `content` is returned untouched.
/// - When the document (ignoring leading whitespace) starts with `---`, the
///   flags are merged into the existing frontmatter via `merge_frontmatter_flags`.
/// - Otherwise a new frontmatter block is prepended.
///
/// Keys are emitted in the fixed order title, slug, theme, expiry, password.
/// Each value is escaped with `crate::mcp::yaml_escape_value_pub`.
fn apply_publish_flags(
    content: String,
    title: Option<String>,
    slug: Option<String>,
    theme: Option<String>,
    expiry: Option<String>,
    password: Option<String>,
) -> String {
    let nothing_to_apply = [&title, &slug, &theme, &expiry, &password]
        .iter()
        .all(|flag| flag.is_none());
    if nothing_to_apply {
        return content;
    }
    if content.trim_start().starts_with("---") {
        return merge_frontmatter_flags(content, title, slug, theme, expiry, password);
    }

    let fields = [
        ("title", title),
        ("slug", slug),
        ("theme", theme),
        ("expiry", expiry),
        ("password", password),
    ];
    let mut header = String::from("---\n");
    for (key, value) in fields.iter() {
        if let Some(v) = value {
            header.push_str(&format!("{key}: {}\n", crate::mcp::yaml_escape_value_pub(v)));
        }
    }
    header.push_str("---\n");
    header + &content
}
/// Merge flag overrides into a document that begins with a `---` frontmatter block.
///
/// The opening delimiter is taken to be the first non-blank line. The caller
/// detects frontmatter after `trim_start`, so blank lines may precede it; they
/// are dropped from the output. The closing delimiter is the next line that
/// trims to `---`.
///
/// For each flag that is set:
/// - an existing line beginning with `key:` (after leading whitespace) is
///   replaced by `key: <escaped value>`;
/// - otherwise the line is appended to the frontmatter.
///
/// The body after the closing delimiter is preserved, including a trailing
/// newline; line endings are normalized to `\n`. If no closing delimiter
/// exists, a fresh frontmatter block containing only the flags is prepended to
/// the unmodified content.
fn merge_frontmatter_flags(
    content: String,
    title: Option<String>,
    slug: Option<String>,
    theme: Option<String>,
    expiry: Option<String>,
    password: Option<String>,
) -> String {
    /// Prepend a new frontmatter block built from `overrides` to `content`.
    fn prepend_frontmatter(content: &str, overrides: &[(&str, &str)]) -> String {
        let mut out = String::from("---\n");
        for (key, val) in overrides {
            out.push_str(&format!("{key}: {}\n", crate::mcp::yaml_escape_value_pub(val)));
        }
        out.push_str("---\n");
        out.push_str(content);
        out
    }

    // Set flags in canonical output order; unset flags are skipped.
    let overrides: Vec<(&str, &str)> = [
        title.as_deref().map(|v| ("title", v)),
        slug.as_deref().map(|v| ("slug", v)),
        theme.as_deref().map(|v| ("theme", v)),
        expiry.as_deref().map(|v| ("expiry", v)),
        password.as_deref().map(|v| ("password", v)),
    ]
    .into_iter()
    .flatten()
    .collect();

    let lines: Vec<&str> = content.lines().collect();

    // Locate the opener as the first non-blank line rather than assuming line 0.
    // Otherwise leading blank lines make the opener look like the closer.
    let Some(open_idx) = lines.iter().position(|line| !line.trim().is_empty()) else {
        return prepend_frontmatter(&content, &overrides);
    };
    let Some(close_idx) = lines[open_idx + 1..]
        .iter()
        .position(|line| line.trim() == "---")
        .map(|offset| open_idx + 1 + offset)
    else {
        return prepend_frontmatter(&content, &overrides);
    };

    let mut fm_lines: Vec<String> = lines[open_idx + 1..close_idx]
        .iter()
        .map(|s| s.to_string())
        .collect();
    for (key, val) in &overrides {
        let new_line = format!("{key}: {}", crate::mcp::yaml_escape_value_pub(val));
        let prefix = format!("{key}:");
        match fm_lines
            .iter_mut()
            .find(|line| line.trim_start().starts_with(&prefix))
        {
            Some(existing) => *existing = new_line,
            None => fm_lines.push(new_line),
        }
    }

    let mut result = String::with_capacity(content.len() + 64);
    result.push_str("---\n");
    for line in &fm_lines {
        result.push_str(line);
        result.push('\n');
    }
    result.push_str("---\n");
    let body = &lines[close_idx + 1..];
    if !body.is_empty() {
        result.push_str(&body.join("\n"));
        // `str::lines` discards the final line terminator; restore it.
        if content.ends_with('\n') {
            result.push('\n');
        }
    }
    result
}
/// Load the markdown to publish: from stdin when `path` is `-`, otherwise from
/// the named file. On a read failure the error is printed to stderr and the
/// process exits with status 1.
fn read_publish_source(path: &str) -> String {
    let loaded: Result<String, String> = if path == "-" {
        let mut stdin_text = String::new();
        std::io::Read::read_to_string(&mut std::io::stdin(), &mut stdin_text)
            .map(|_| stdin_text)
            .map_err(|e| format!("Failed to read from stdin: {e}"))
    } else {
        std::fs::read_to_string(path).map_err(|e| format!("Failed to read file '{path}': {e}"))
    };
    match loaded {
        Ok(text) => text,
        Err(message) => {
            eprintln!("{message}");
            std::process::exit(1)
        }
    }
}
/// List documents on the server (`twofold list`).
///
/// Sends `GET {server}/api/v1/documents?limit={limit}` with the bearer token
/// and prints the `documents` array of the JSON response as a fixed-width
/// table (slug, title, created, expires). Any transport, HTTP or decoding
/// failure is reported on stderr and exits with status 1.
fn run_list(args: cli::ListArgs) {
    let token = resolve_token(args.token);
    let url = format!(
        "{}/api/v1/documents?limit={}",
        args.server.trim_end_matches('/'),
        args.limit
    );
    let client = make_blocking_client();
    let response = match client
        .get(&url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
    {
        Ok(r) => r,
        Err(e) => {
            eprintln!("Request failed: {e}");
            std::process::exit(1);
        }
    };
    let status = response.status();
    if !status.is_success() {
        let body = response.text().unwrap_or_default();
        eprintln!("List failed: HTTP {status}\n{body}");
        std::process::exit(1);
    }
    let body: serde_json::Value = match response.json() {
        Ok(v) => v,
        Err(e) => {
            eprintln!("Failed to parse server response: {e}");
            std::process::exit(1);
        }
    };
    let Some(docs) = body.get("documents").and_then(|v| v.as_array()) else {
        eprintln!("Unexpected response format");
        std::process::exit(1);
    };

    println!("{:<24} {:<32} {:<21} {}",
        "SLUG", "TITLE", "CREATED", "EXPIRES");
    println!("{}", "-".repeat(90));
    for doc in docs {
        let slug = doc.get("slug").and_then(|v| v.as_str()).unwrap_or("-");
        let title = doc.get("title").and_then(|v| v.as_str()).unwrap_or("-");
        let created = doc.get("created_at").and_then(|v| v.as_str()).unwrap_or("-");
        let expires = doc.get("expires_at").and_then(|v| v.as_str()).unwrap_or("never");
        // These strings come from the server. Cut them by characters rather
        // than byte offsets so a non-ASCII value cannot panic on a UTF-8
        // boundary. "never" is shorter than 16 characters and passes through
        // unchanged.
        let created_d: String = created.chars().take(16).collect();
        let expires_d: String = expires.chars().take(16).collect();
        println!(
            "{:<24} {:<32} {:<21} {}",
            truncate(slug, 23),
            truncate(title, 31),
            created_d,
            expires_d
        );
    }
}
/// Shorten `s` to at most `max` characters for fixed-width table output.
///
/// Strings of `max` characters or fewer are returned unchanged. Longer strings
/// keep their first `max - 1` characters followed by `…`, so the result is
/// exactly `max` characters. Counting is by `char`, never by byte, so
/// multi-byte UTF-8 input cannot cause a slicing panic. `max == 0` yields an
/// empty string.
fn truncate(s: &str, max: usize) -> String {
    if s.chars().count() <= max {
        return s.to_string();
    }
    if max == 0 {
        return String::new();
    }
    let kept: String = s.chars().take(max - 1).collect();
    format!("{kept}…")
}
/// Delete a document by slug (`twofold delete`).
///
/// Sends `DELETE {server}/api/v1/documents/{slug}` with the bearer token.
/// - HTTP 204 prints a confirmation.
/// - 401 and 404 get dedicated error messages.
/// - Any other status prints the response body.
///
/// Every failure exits with status 1.
fn run_delete(args: cli::DeleteArgs) {
    let token = resolve_token(args.token);
    let base = args.server.trim_end_matches('/');
    let url = format!("{base}/api/v1/documents/{}", args.slug);
    let client = make_blocking_client();
    let request = client
        .delete(&url)
        .header("Authorization", format!("Bearer {token}"));
    let response = request.send().unwrap_or_else(|e| {
        eprintln!("Request failed: {e}");
        std::process::exit(1)
    });

    let status = response.status();
    let code = status.as_u16();
    if code == 204 {
        println!("Deleted: {}", args.slug);
        return;
    }
    if code == 401 {
        eprintln!("Auth error: check your token");
    } else if code == 404 {
        eprintln!("Error: document '{}' not found", args.slug);
    } else {
        let body = response.text().unwrap_or_default();
        eprintln!("Delete failed: HTTP {status}\n{body}");
    }
    std::process::exit(1);
}
/// Print recent audit-log entries from the server (`twofold audit`).
///
/// Sends `GET {server}/api/v1/audit?limit={limit}` with the bearer token and
/// renders the `entries` array of the JSON response as a fixed-width table
/// (timestamp, action, slug, token name). Any transport, HTTP or decoding
/// failure is reported on stderr and exits with status 1.
fn run_audit(args: cli::AuditArgs) {
    let token = resolve_token(args.token);
    let url = format!(
        "{}/api/v1/audit?limit={}",
        args.server.trim_end_matches('/'),
        args.limit
    );
    let client = make_blocking_client();
    let response = match client
        .get(&url)
        .header("Authorization", format!("Bearer {token}"))
        .send()
    {
        Ok(r) => r,
        Err(e) => {
            eprintln!("Request failed: {e}");
            std::process::exit(1);
        }
    };
    let status = response.status();
    if !status.is_success() {
        let body = response.text().unwrap_or_default();
        eprintln!("Audit failed: HTTP {status}\n{body}");
        std::process::exit(1);
    }
    let body: serde_json::Value = match response.json() {
        Ok(v) => v,
        Err(e) => {
            eprintln!("Failed to parse server response: {e}");
            std::process::exit(1);
        }
    };
    let Some(entries) = body.get("entries").and_then(|v| v.as_array()) else {
        eprintln!("Unexpected response format");
        std::process::exit(1);
    };

    println!("{:<21} {:<9} {:<25} {}",
        "TIMESTAMP", "ACTION", "SLUG", "TOKEN");
    println!("{}", "-".repeat(75));
    for entry in entries {
        let timestamp = entry.get("timestamp").and_then(|v| v.as_str()).unwrap_or("-");
        let action = entry.get("action").and_then(|v| v.as_str()).unwrap_or("-");
        let slug = entry.get("slug").and_then(|v| v.as_str()).unwrap_or("-");
        let token_name = entry.get("token_name").and_then(|v| v.as_str()).unwrap_or("-");
        // Server-supplied text: cut by characters, not bytes, so a non-ASCII
        // timestamp cannot panic on a UTF-8 boundary.
        let ts_d: String = timestamp.chars().take(20).collect();
        let slug_d = truncate(slug, 24);
        println!("{:<21} {:<9} {:<25} {}", ts_d, action, slug_d, token_name);
    }
}
/// Dispatch a `twofold token` subcommand (create / list / revoke).
///
/// Each action resolves its database path through `resolve_db_path`: the
/// explicit `--db` value, then `TWOFOLD_DB_PATH`, then `./twofold.db`.
fn run_token(args: cli::TokenArgs) {
    match args.action {
        TokenAction::Create { name, db } => {
            let db_path = resolve_db_path(db);
            token_create(&name, &db_path);
        }
        TokenAction::List { db } => {
            let db_path = resolve_db_path(db);
            token_list(&db_path);
        }
        TokenAction::Revoke { name, db } => {
            let db_path = resolve_db_path(db);
            token_revoke(&name, &db_path);
        }
    }
}
/// Pick the database path: the explicit value if given, else the
/// `TWOFOLD_DB_PATH` environment variable (when set and valid Unicode), else
/// `./twofold.db`.
fn resolve_db_path(explicit: Option<String>) -> String {
    if let Some(path) = explicit {
        return path;
    }
    match std::env::var("TWOFOLD_DB_PATH") {
        Ok(from_env) => from_env,
        Err(_) => String::from("./twofold.db"),
    }
}
/// Pick the API token: the explicit `--token` value if given, else the
/// `TWOFOLD_TOKEN` environment variable. If neither is available, explain how
/// to supply one on stderr and exit with status 1.
fn resolve_token(explicit: Option<String>) -> String {
    if let Some(token) = explicit.or_else(|| std::env::var("TWOFOLD_TOKEN").ok()) {
        return token;
    }
    eprintln!(
        "Error: --token not provided and TWOFOLD_TOKEN is not set.\n\
         Provide a token via --token <TOKEN> or set TWOFOLD_TOKEN."
    );
    std::process::exit(1);
}
/// Build the blocking HTTP client shared by the CLI subcommands (30s timeout).
///
/// If the client cannot be constructed, the error is reported on stderr and the
/// process exits with status 1.
fn make_blocking_client() -> reqwest::blocking::Client {
    let timeout = std::time::Duration::from_secs(30);
    reqwest::blocking::Client::builder()
        .timeout(timeout)
        .build()
        .unwrap_or_else(|e| {
            eprintln!("Failed to create HTTP client: {e}");
            std::process::exit(1)
        })
}
/// Create a new API token called `name` in the database at `db_path`.
///
/// The plaintext token is `tf_` followed by the URL-safe, unpadded base64 of 32
/// random bytes. Only its hash (from `handlers::hash_password`) and its first
/// 8 characters (as a prefix) are stored. The plaintext is printed to stdout
/// once on success.
///
/// If inserting fails with a unique-constraint error on `tokens.prefix`, a new
/// token is generated, for up to 3 attempts in total. Exits with status 1 if:
/// - the database cannot be opened or queried;
/// - the name already exists;
/// - hashing fails;
/// - storing fails.
fn token_create(name: &str, db_path: &str) {
    use base64::Engine;
    use rand::RngCore;

    let db = match Db::open(db_path) {
        Ok(handle) => handle,
        Err(e) => {
            eprintln!("Failed to open database '{db_path}': {e}");
            std::process::exit(1);
        }
    };

    match db.token_name_exists(name) {
        Ok(false) => {}
        Ok(true) => {
            eprintln!("Error: Token name '{name}' already exists.");
            std::process::exit(1);
        }
        Err(e) => {
            eprintln!("Database error: {e}");
            std::process::exit(1);
        }
    }

    let created_at = handlers::chrono_now();
    let mut attempt: u8 = 0;
    let token_plain = loop {
        let mut secret = [0u8; 32];
        rand::thread_rng().fill_bytes(&mut secret);
        let candidate = format!(
            "tf_{}",
            base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(secret)
        );
        let Ok(hash) = handlers::hash_password(&candidate) else {
            eprintln!("Failed to hash token");
            std::process::exit(1);
        };
        let record = db::TokenRecord {
            id: nanoid::nanoid!(10),
            name: name.to_string(),
            hash,
            created_at: created_at.clone(),
            last_used: None,
            revoked: false,
            prefix: Some(candidate.chars().take(8).collect::<String>()),
        };
        match db.insert_token(&record) {
            Ok(()) => break candidate,
            Err(e) if e.to_string().contains("UNIQUE constraint failed: tokens.prefix") => {
                // The third collision in a row is fatal; earlier ones retry.
                if attempt >= 2 {
                    eprintln!("Failed to store token after 3 attempts (prefix collision): {e}");
                    std::process::exit(1);
                }
                eprintln!("Warning: prefix collision on attempt {}; regenerating.", attempt + 1);
                attempt += 1;
            }
            Err(e) => {
                eprintln!("Failed to store token: {e}");
                std::process::exit(1);
            }
        }
    };
    println!("{token_plain}");
}
/// Print every token stored in the database at `db_path` as a table of name,
/// creation time, last-used time (or "never") and status (active or revoked).
///
/// Timestamps are cut to their first 16 bytes. Database errors are reported on
/// stderr and exit with status 1.
fn token_list(db_path: &str) {
    let db = Db::open(db_path).unwrap_or_else(|e| {
        eprintln!("Failed to open database '{db_path}': {e}");
        std::process::exit(1)
    });
    let tokens = db.list_tokens().unwrap_or_else(|e| {
        eprintln!("Failed to list tokens: {e}");
        std::process::exit(1)
    });

    println!("{:<20} {:<22} {:<22} {}",
        "NAME", "CREATED", "LAST USED", "STATUS");
    for token in tokens {
        let created_len = token.created_at.len().min(16);
        let created = &token.created_at[..created_len];
        let used = match token.last_used.as_deref() {
            None | Some("never") => "never".to_string(),
            Some(ts) => ts[..ts.len().min(16)].to_string(),
        };
        let status = if token.revoked { "revoked" } else { "active" };
        println!("{:<20} {:<22} {:<22} {}", token.name, created, used, status);
    }
}
/// Revoke the token called `name` in the database at `db_path`.
///
/// Prints a confirmation on success. If no active token has that name, or a
/// database error occurs, the problem is reported on stderr and the process
/// exits with status 1.
fn token_revoke(name: &str, db_path: &str) {
    let db = Db::open(db_path).unwrap_or_else(|e| {
        eprintln!("Failed to open database '{db_path}': {e}");
        std::process::exit(1)
    });
    let failure = match db.revoke_token(name) {
        Ok(true) => {
            println!("Token '{name}' revoked.");
            return;
        }
        Ok(false) => format!("Error: Token '{name}' not found or already revoked."),
        Err(e) => format!("Database error: {e}"),
    };
    eprintln!("{failure}");
    std::process::exit(1);
}