Skip to main content

tuitbot_server/routes/
mcp.rs

1//! MCP governance and telemetry endpoints.
2
3use std::sync::Arc;
4use std::time::{SystemTime, UNIX_EPOCH};
5
6use axum::extract::{Query, State};
7use axum::Json;
8use serde::Deserialize;
9use serde_json::{json, Value};
10use tuitbot_core::config::Config;
11use tuitbot_core::storage::{mcp_telemetry, rate_limits};
12
13use crate::error::ApiError;
14use crate::state::AppState;
15
16// ---------------------------------------------------------------------------
17// Query types
18// ---------------------------------------------------------------------------
19
20#[derive(Deserialize)]
21pub struct TimeWindowQuery {
22    /// Lookback window in hours (default: 24).
23    #[serde(default = "default_hours")]
24    pub hours: u32,
25}
26
/// Serde default for `TimeWindowQuery::hours`: a one-day window.
fn default_hours() -> u32 {
    const ONE_DAY_IN_HOURS: u32 = 24;
    ONE_DAY_IN_HOURS
}
30
31#[derive(Deserialize)]
32pub struct RecentQuery {
33    /// Number of recent entries to return (default: 50).
34    #[serde(default = "default_limit")]
35    pub limit: u32,
36}
37
/// Serde default for `RecentQuery::limit`.
fn default_limit() -> u32 {
    const DEFAULT_RECENT_ENTRIES: u32 = 50;
    DEFAULT_RECENT_ENTRIES
}
41
42// ---------------------------------------------------------------------------
43// Policy endpoints
44// ---------------------------------------------------------------------------
45
46/// `GET /api/mcp/policy` — current MCP policy config + rate limit usage.
47pub async fn get_policy(State(state): State<Arc<AppState>>) -> Result<Json<Value>, ApiError> {
48    let config = read_config(&state)?;
49
50    let rate_limit_info = match rate_limits::get_all_rate_limits(&state.db).await {
51        Ok(limits) => {
52            let mcp = limits.iter().find(|l| l.action_type == "mcp_mutation");
53            match mcp {
54                Some(rl) => json!({
55                    "used": rl.request_count,
56                    "max": rl.max_requests,
57                    "period_seconds": rl.period_seconds,
58                    "period_start": rl.period_start,
59                }),
60                None => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
61            }
62        }
63        Err(_) => json!({ "used": 0, "max": config.mcp_policy.max_mutations_per_hour }),
64    };
65
66    Ok(Json(json!({
67        "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
68        "require_approval_for": config.mcp_policy.require_approval_for,
69        "blocked_tools": config.mcp_policy.blocked_tools,
70        "dry_run_mutations": config.mcp_policy.dry_run_mutations,
71        "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
72        "mode": format!("{}", config.mode),
73        "rate_limit": rate_limit_info,
74    })))
75}
76
77/// `PATCH /api/mcp/policy` — update MCP policy config fields.
78///
79/// Accepts partial JSON with `mcp_policy` fields and merges into config.
80pub async fn patch_policy(
81    State(state): State<Arc<AppState>>,
82    Json(patch): Json<Value>,
83) -> Result<Json<Value>, ApiError> {
84    if !patch.is_object() {
85        return Err(ApiError::BadRequest(
86            "request body must be a JSON object".to_string(),
87        ));
88    }
89
90    // Wrap the patch under `mcp_policy` key for the settings merge.
91    let wrapped = json!({ "mcp_policy": patch });
92
93    let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
94        ApiError::BadRequest(format!(
95            "could not read config file {}: {e}",
96            state.config_path.display()
97        ))
98    })?;
99
100    let mut toml_value: toml::Value = contents.parse().map_err(|e: toml::de::Error| {
101        ApiError::BadRequest(format!("failed to parse existing config: {e}"))
102    })?;
103
104    let patch_toml = json_to_toml(&wrapped)
105        .map_err(|e| ApiError::BadRequest(format!("patch contains invalid values: {e}")))?;
106
107    merge_toml(&mut toml_value, &patch_toml);
108
109    let merged_str = toml::to_string_pretty(&toml_value)
110        .map_err(|e| ApiError::BadRequest(format!("failed to serialize merged config: {e}")))?;
111
112    let config: Config = toml::from_str(&merged_str)
113        .map_err(|e| ApiError::BadRequest(format!("merged config is invalid: {e}")))?;
114
115    std::fs::write(&state.config_path, &merged_str).map_err(|e| {
116        ApiError::BadRequest(format!(
117            "could not write config file {}: {e}",
118            state.config_path.display()
119        ))
120    })?;
121
122    Ok(Json(json!({
123        "enforce_for_mutations": config.mcp_policy.enforce_for_mutations,
124        "require_approval_for": config.mcp_policy.require_approval_for,
125        "blocked_tools": config.mcp_policy.blocked_tools,
126        "dry_run_mutations": config.mcp_policy.dry_run_mutations,
127        "max_mutations_per_hour": config.mcp_policy.max_mutations_per_hour,
128    })))
129}
130
131// ---------------------------------------------------------------------------
132// Telemetry endpoints
133// ---------------------------------------------------------------------------
134
135/// `GET /api/mcp/telemetry/summary` — aggregate stats over a time window.
136pub async fn telemetry_summary(
137    State(state): State<Arc<AppState>>,
138    Query(params): Query<TimeWindowQuery>,
139) -> Result<Json<Value>, ApiError> {
140    let since = since_timestamp(params.hours);
141    let summary = mcp_telemetry::get_summary(&state.db, &since).await?;
142    Ok(Json(serde_json::to_value(summary).unwrap()))
143}
144
145/// `GET /api/mcp/telemetry/metrics` — per-tool metrics over a time window.
146pub async fn telemetry_metrics(
147    State(state): State<Arc<AppState>>,
148    Query(params): Query<TimeWindowQuery>,
149) -> Result<Json<Value>, ApiError> {
150    let since = since_timestamp(params.hours);
151    let metrics = mcp_telemetry::get_metrics_since(&state.db, &since).await?;
152    Ok(Json(json!(metrics)))
153}
154
155/// `GET /api/mcp/telemetry/errors` — error breakdown over a time window.
156pub async fn telemetry_errors(
157    State(state): State<Arc<AppState>>,
158    Query(params): Query<TimeWindowQuery>,
159) -> Result<Json<Value>, ApiError> {
160    let since = since_timestamp(params.hours);
161    let errors = mcp_telemetry::get_error_breakdown(&state.db, &since).await?;
162    Ok(Json(json!(errors)))
163}
164
165/// `GET /api/mcp/telemetry/recent` — recent tool executions.
166pub async fn telemetry_recent(
167    State(state): State<Arc<AppState>>,
168    Query(params): Query<RecentQuery>,
169) -> Result<Json<Value>, ApiError> {
170    let entries = mcp_telemetry::get_recent_entries(&state.db, params.limit).await?;
171    Ok(Json(json!(entries)))
172}
173
174// ---------------------------------------------------------------------------
175// Helpers
176// ---------------------------------------------------------------------------
177
178fn read_config(state: &AppState) -> Result<Config, ApiError> {
179    let contents = std::fs::read_to_string(&state.config_path).map_err(|e| {
180        ApiError::BadRequest(format!(
181            "could not read config file {}: {e}",
182            state.config_path.display()
183        ))
184    })?;
185    let config: Config = toml::from_str(&contents)
186        .map_err(|e| ApiError::BadRequest(format!("failed to parse config: {e}")))?;
187    Ok(config)
188}
189
/// Return the instant `hours` hours before now as an ISO-8601 UTC string
/// (`YYYY-MM-DDTHH:MM:SSZ`).
///
/// If the system clock reports a time before the Unix epoch, "now" is taken to be
/// the epoch. The subtraction saturates, so a window longer than the time since
/// 1970 yields `1970-01-01T00:00:00Z`.
fn since_timestamp(hours: u32) -> String {
    let now = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let since_epoch = now.saturating_sub(u64::from(hours) * 3600);
    format_epoch_utc(since_epoch)
}

/// Format Unix epoch seconds as an ISO-8601 UTC timestamp (`YYYY-MM-DDTHH:MM:SSZ`).
///
/// The day count is converted to a proleptic Gregorian date with Howard Hinnant's
/// `civil_from_days` algorithm, so no date/time crate is needed. The input is
/// non-negative, so every intermediate value stays non-negative and unsigned
/// arithmetic is exact.
fn format_epoch_utc(epoch_secs: u64) -> String {
    const SECS_PER_DAY: u64 = 86_400;

    let days = epoch_secs / SECS_PER_DAY;
    let day_secs = epoch_secs % SECS_PER_DAY;
    let (h, m, s) = (day_secs / 3600, (day_secs % 3600) / 60, day_secs % 60);

    // Shift the day count so it is relative to 0000-03-01. This puts each leap day
    // at the end of its (March-based) year.
    let z = days + 719_468;
    let era = z / 146_097; // 400-year era
    let doe = z % 146_097; // day of era, [0, 146096]
    let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; // year of era, [0, 399]
    let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); // day of March-based year, [0, 365]
    let mp = (5 * doy + 2) / 153; // March-based month, [0, 11]
    let day = doy - (153 * mp + 2) / 5 + 1;
    let month = if mp < 10 { mp + 3 } else { mp - 9 };
    // January and February belong to the following civil year.
    let year = yoe + era * 400 + u64::from(month <= 2);

    format!("{year:04}-{month:02}-{day:02}T{h:02}:{m:02}:{s:02}Z")
}
219
220/// Recursively merge `patch` into `base`.
221fn merge_toml(base: &mut toml::Value, patch: &toml::Value) {
222    match (base, patch) {
223        (toml::Value::Table(base_table), toml::Value::Table(patch_table)) => {
224            for (key, patch_val) in patch_table {
225                if let Some(base_val) = base_table.get_mut(key) {
226                    merge_toml(base_val, patch_val);
227                } else {
228                    base_table.insert(key.clone(), patch_val.clone());
229                }
230            }
231        }
232        (base, _) => {
233            *base = patch.clone();
234        }
235    }
236}
237
238/// Convert JSON to TOML, skipping nulls in objects.
239fn json_to_toml(json: &serde_json::Value) -> Result<toml::Value, String> {
240    match json {
241        serde_json::Value::Object(map) => {
242            let mut table = toml::map::Map::new();
243            for (key, val) in map {
244                if val.is_null() {
245                    continue;
246                }
247                table.insert(key.clone(), json_to_toml(val)?);
248            }
249            Ok(toml::Value::Table(table))
250        }
251        serde_json::Value::Array(arr) => {
252            let values: Result<Vec<_>, _> = arr.iter().map(json_to_toml).collect();
253            Ok(toml::Value::Array(values?))
254        }
255        serde_json::Value::String(s) => Ok(toml::Value::String(s.clone())),
256        serde_json::Value::Number(n) => {
257            if let Some(i) = n.as_i64() {
258                Ok(toml::Value::Integer(i))
259            } else if let Some(f) = n.as_f64() {
260                Ok(toml::Value::Float(f))
261            } else {
262                Err(format!("unsupported number: {n}"))
263            }
264        }
265        serde_json::Value::Bool(b) => Ok(toml::Value::Boolean(*b)),
266        serde_json::Value::Null => Err("null values are not supported in TOML arrays".to_string()),
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn since_timestamp_is_valid_utc() {
        let ts = since_timestamp(24);
        assert!(ts.ends_with('Z'));
        assert!(ts.contains('T'));
    }

    #[test]
    fn since_timestamp_has_fixed_iso8601_shape() {
        let ts = since_timestamp(24);
        let bytes = ts.as_bytes();
        assert_eq!(bytes.len(), 20, "unexpected timestamp: {ts}");
        for (i, b) in bytes.iter().enumerate() {
            match i {
                4 | 7 => assert_eq!(*b, b'-', "expected '-' at {i} in {ts}"),
                10 => assert_eq!(*b, b'T', "expected 'T' at {i} in {ts}"),
                13 | 16 => assert_eq!(*b, b':', "expected ':' at {i} in {ts}"),
                19 => assert_eq!(*b, b'Z', "expected 'Z' at {i} in {ts}"),
                _ => assert!(b.is_ascii_digit(), "non-digit at {i} in {ts}"),
            }
        }
    }

    #[test]
    fn since_timestamp_saturates_at_epoch() {
        // u32::MAX hours (~490k years) exceeds the time since 1970, so the
        // subtraction saturates to zero.
        assert_eq!(since_timestamp(u32::MAX), "1970-01-01T00:00:00Z");
    }

    #[test]
    fn since_timestamp_orders_lexicographically() {
        // Fixed-width ISO-8601 strings sort chronologically, which string
        // comparisons against stored timestamps rely on.
        let older = since_timestamp(48);
        let newer = since_timestamp(0);
        assert!(older < newer, "{older} should sort before {newer}");
    }
}