use std::sync::Arc;

use axum::{
    extract::{DefaultBodyLimit, Path, Query, State},
    http::{HeaderValue, StatusCode},
    middleware,
    response::{IntoResponse, Response},
    routing::{delete, get, post},
    Extension, Json, Router,
};
use clawdb::{prelude::MergeStrategy, ClawDBError};
use serde::Deserialize;
use tower_http::{
    limit::RequestBodyLimitLayer, normalize_path::NormalizePathLayer,
    set_header::SetResponseHeaderLayer,
};
use uuid::Uuid;

use crate::{
    http::auth::{self, AuthContext},
    state::{AppState, RequestId},
};
23
24#[derive(Deserialize)]
25struct CreateSessionBody {
26 agent_id: Uuid,
27 role: String,
28 scopes: Vec<String>,
29 #[serde(default)]
30 ttl_secs: Option<u64>,
31}
32
33#[derive(Deserialize)]
34struct MemoryBody {
35 content: String,
36 #[serde(default)]
37 r#type: Option<String>,
38 #[serde(default)]
39 tags: Vec<String>,
40 #[serde(default)]
41 metadata: serde_json::Value,
42}
43
44#[derive(Deserialize)]
45struct SearchQuery {
46 q: String,
47 #[serde(default = "default_top_k")]
48 top_k: usize,
49 #[serde(default)]
50 semantic: bool,
51}
52
53#[derive(Deserialize)]
54struct ListMemoriesQuery {
55 #[serde(default)]
56 r#type: Option<String>,
57 #[serde(default)]
58 limit: Option<usize>,
59}
60
/// Number of search hits used when a search request omits `top_k`.
fn default_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
64
65#[derive(Deserialize)]
66struct BranchBody {
67 name: String,
68 #[serde(default)]
69 from: Option<Uuid>,
70}
71
72#[derive(Deserialize)]
73struct MergeBody {
74 #[serde(alias = "target_id")]
75 target: Uuid,
76 #[serde(default)]
77 strategy: Option<String>,
78}
79
80#[derive(Deserialize)]
81struct DiffQuery {
82 target: Uuid,
83}
84
85pub fn router(state: Arc<AppState>) -> Router {
86 let public = Router::new()
87 .route("/v1/health", get(health))
88 .route("/v1/ready", get(ready))
89 .route("/v1/sessions", post(create_session))
90 .route("/v1/metrics", get(metrics));
91
92 let protected = Router::new()
93 .route("/v1/sessions/me", get(whoami))
94 .route("/v1/sessions/:id", delete(revoke_session))
95 .route("/v1/memories", post(remember).get(list_memories))
96 .route("/v1/memories/search", get(search))
97 .route("/v1/memories/:id", get(recall_one).delete(delete_memory))
98 .route("/v1/branches", post(create_branch).get(list_branches))
99 .route("/v1/branches/:id/merge", post(merge_branch))
100 .route("/v1/branches/:id/diff", get(diff_branch))
101 .route("/v1/branches/:id", delete(discard_branch))
102 .route("/v1/sync", post(sync))
103 .route("/v1/reflect", post(reflect))
104 .layer(middleware::from_fn_with_state(
105 state.clone(),
106 auth::rate_limit_middleware,
107 ))
108 .layer(middleware::from_fn_with_state(
109 state.clone(),
110 auth::auth_middleware,
111 ));
112
113 public
114 .merge(protected)
115 .layer(middleware::from_fn_with_state(
116 state.clone(),
117 auth::metrics_middleware,
118 ))
119 .layer(middleware::from_fn(auth::request_id_middleware))
120 .layer(SetResponseHeaderLayer::if_not_present(
121 axum::http::header::HeaderName::from_static("content-security-policy"),
122 HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'; base-uri 'none'"),
123 ))
124 .layer(SetResponseHeaderLayer::if_not_present(
125 axum::http::header::HeaderName::from_static("x-content-type-options"),
126 HeaderValue::from_static("nosniff"),
127 ))
128 .layer(SetResponseHeaderLayer::if_not_present(
129 axum::http::header::HeaderName::from_static("x-frame-options"),
130 HeaderValue::from_static("DENY"),
131 ))
132 .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024))
133 .layer(NormalizePathLayer::trim_trailing_slash())
134 .with_state(state)
135}
136
137pub fn metrics_router(state: Arc<AppState>) -> Router {
138 Router::new()
139 .route("/", get(metrics))
140 .route("/metrics", get(metrics))
141 .route("/v1/metrics", get(metrics))
142 .with_state(state)
143}
144
145async fn health(State(state): State<Arc<AppState>>) -> Response {
146 match state.db.health().await {
147 Ok(report) => Json(report).into_response(),
148 Err(error) => map_error(error, None),
149 }
150}
151
152async fn ready(State(state): State<Arc<AppState>>) -> Response {
153 match state.db.health().await {
154 Ok(report) if report.ok => StatusCode::OK.into_response(),
155 Ok(_) => StatusCode::SERVICE_UNAVAILABLE.into_response(),
156 Err(error) => map_error(error, None),
157 }
158}
159
160async fn metrics(State(state): State<Arc<AppState>>) -> Response {
161 if let Ok(count) = state.db.active_session_count().await {
162 state.metrics.set_active_sessions(count);
163 }
164 let rendered = state.metrics.render(state.db.metrics_handle().render());
165 (
166 StatusCode::OK,
167 [(
168 axum::http::header::CONTENT_TYPE,
169 HeaderValue::from_static("text/plain; version=0.0.4"),
170 )],
171 rendered,
172 )
173 .into_response()
174}
175
176async fn create_session(
177 State(state): State<Arc<AppState>>,
178 Extension(request_id): Extension<RequestId>,
179 Json(body): Json<CreateSessionBody>,
180) -> Response {
181 match state
182 .db
183 .session_with_ttl(
184 body.agent_id,
185 &body.role,
186 body.scopes,
187 body.ttl_secs.unwrap_or(3600) as i64,
188 )
189 .await
190 {
191 Ok(session) => Json(serde_json::json!({
192 "id": session.id,
193 "session_id": session.id,
194 "agent_id": session.agent_id,
195 "role": session.role,
196 "token": session.token,
197 "expires_at": session.expires_at.to_rfc3339(),
198 "scopes": session.scopes,
199 }))
200 .into_response(),
201 Err(error) => map_error(error, Some(request_id.0)),
202 }
203}
204
205async fn whoami(Extension(auth): Extension<AuthContext>) -> Response {
206 Json(serde_json::json!({
207 "id": auth.session.id,
208 "session_id": auth.session.id,
209 "agent_id": auth.session.agent_id,
210 "role": auth.session.role,
211 "token": auth.session.token,
212 "expires_at": auth.session.expires_at.to_rfc3339(),
213 "scopes": auth.session.scopes,
214 }))
215 .into_response()
216}
217
218async fn revoke_session(
219 State(state): State<Arc<AppState>>,
220 Extension(request_id): Extension<RequestId>,
221 Path(id): Path<Uuid>,
222) -> Response {
223 match state.db.revoke_session(id).await {
224 Ok(()) => StatusCode::NO_CONTENT.into_response(),
225 Err(error) => map_error(error, Some(request_id.0)),
226 }
227}
228
229async fn remember(
230 State(state): State<Arc<AppState>>,
231 Extension(auth): Extension<AuthContext>,
232 Extension(request_id): Extension<RequestId>,
233 Json(body): Json<MemoryBody>,
234) -> Response {
235 let result = if let Some(memory_type) = body.r#type.as_deref() {
236 state
237 .db
238 .remember_typed(
239 &auth.session,
240 &body.content,
241 memory_type,
242 &body.tags,
243 body.metadata,
244 )
245 .await
246 } else {
247 state.db.remember(&auth.session, &body.content).await
248 };
249
250 match result {
251 Ok(remembered) => Json(remembered).into_response(),
252 Err(error) => map_error(error, Some(request_id.0)),
253 }
254}
255
256async fn search(
257 State(state): State<Arc<AppState>>,
258 Extension(auth): Extension<AuthContext>,
259 Extension(request_id): Extension<RequestId>,
260 Query(query): Query<SearchQuery>,
261) -> Response {
262 match state
263 .db
264 .search_with_options(&auth.session, &query.q, query.top_k, query.semantic, None)
265 .await
266 {
267 Ok(hits) => Json(hits).into_response(),
268 Err(error) => map_error(error, Some(request_id.0)),
269 }
270}
271
272async fn recall_one(
273 State(state): State<Arc<AppState>>,
274 Extension(auth): Extension<AuthContext>,
275 Extension(request_id): Extension<RequestId>,
276 Path(id): Path<Uuid>,
277) -> Response {
278 match state.db.recall(&auth.session, &[id]).await {
279 Ok(mut memories) => match memories.pop() {
280 Some(memory) => Json(memory).into_response(),
281 None => auth::error_response(
282 StatusCode::NOT_FOUND,
283 "not_found",
284 None,
285 Some(request_id.0),
286 None,
287 ),
288 },
289 Err(error) => map_error(error, Some(request_id.0)),
290 }
291}
292
293async fn list_memories(
294 State(state): State<Arc<AppState>>,
295 Extension(auth): Extension<AuthContext>,
296 Extension(request_id): Extension<RequestId>,
297 Query(query): Query<ListMemoriesQuery>,
298) -> Response {
299 match state
300 .db
301 .list_memories(&auth.session, query.r#type.as_deref())
302 .await
303 {
304 Ok(mut memories) => {
305 if let Some(limit) = query.limit {
306 memories.truncate(limit);
307 }
308 Json(memories).into_response()
309 }
310 Err(error) => map_error(error, Some(request_id.0)),
311 }
312}
313
314async fn delete_memory(
315 State(state): State<Arc<AppState>>,
316 Extension(auth): Extension<AuthContext>,
317 Extension(request_id): Extension<RequestId>,
318 Path(id): Path<Uuid>,
319) -> Response {
320 match state.db.delete_memory(&auth.session, id).await {
321 Ok(()) => StatusCode::NO_CONTENT.into_response(),
322 Err(error) => map_error(error, Some(request_id.0)),
323 }
324}
325
326async fn create_branch(
327 State(state): State<Arc<AppState>>,
328 Extension(auth): Extension<AuthContext>,
329 Extension(request_id): Extension<RequestId>,
330 Json(body): Json<BranchBody>,
331) -> Response {
332 let branch = if let Some(from) = body.from {
333 state.db.fork_branch(&auth.session, from, &body.name).await
334 } else {
335 state.db.branch(&auth.session, &body.name).await
336 };
337 match branch {
338 Ok(id) => Json(serde_json::json!({"id": id, "branch_id": id, "name": body.name})).into_response(),
339 Err(error) => map_error(error, Some(request_id.0)),
340 }
341}
342
343async fn list_branches(
344 State(state): State<Arc<AppState>>,
345 Extension(auth): Extension<AuthContext>,
346 Extension(request_id): Extension<RequestId>,
347) -> Response {
348 match state.db.list_branches(&auth.session).await {
349 Ok(branches) => Json(branches).into_response(),
350 Err(error) => map_error(error, Some(request_id.0)),
351 }
352}
353
354async fn merge_branch(
355 State(state): State<Arc<AppState>>,
356 Extension(auth): Extension<AuthContext>,
357 Extension(request_id): Extension<RequestId>,
358 Path(id): Path<Uuid>,
359 Json(body): Json<MergeBody>,
360) -> Response {
361 match state
362 .db
363 .merge_with_strategy(
364 &auth.session,
365 id,
366 body.target,
367 parse_strategy(body.strategy.as_deref()),
368 )
369 .await
370 {
371 Ok(result) => Json(result).into_response(),
372 Err(error) => map_error(error, Some(request_id.0)),
373 }
374}
375
376async fn diff_branch(
377 State(state): State<Arc<AppState>>,
378 Extension(auth): Extension<AuthContext>,
379 Extension(request_id): Extension<RequestId>,
380 Path(id): Path<Uuid>,
381 Query(query): Query<DiffQuery>,
382) -> Response {
383 match state.db.diff(&auth.session, id, query.target).await {
384 Ok(result) => Json(result).into_response(),
385 Err(error) => map_error(error, Some(request_id.0)),
386 }
387}
388
389async fn discard_branch(
390 State(state): State<Arc<AppState>>,
391 Extension(auth): Extension<AuthContext>,
392 Extension(request_id): Extension<RequestId>,
393 Path(id): Path<Uuid>,
394) -> Response {
395 match state.db.discard_branch(&auth.session, id).await {
396 Ok(()) => StatusCode::NO_CONTENT.into_response(),
397 Err(error) => map_error(error, Some(request_id.0)),
398 }
399}
400
401async fn sync(
402 State(state): State<Arc<AppState>>,
403 Extension(auth): Extension<AuthContext>,
404 Extension(request_id): Extension<RequestId>,
405) -> Response {
406 match state.db.sync(&auth.session).await {
407 Ok(result) => Json(result).into_response(),
408 Err(error) => map_error(error, Some(request_id.0)),
409 }
410}
411
412async fn reflect(
413 State(state): State<Arc<AppState>>,
414 Extension(auth): Extension<AuthContext>,
415 Extension(request_id): Extension<RequestId>,
416) -> Response {
417 match state.db.reflect(&auth.session).await {
418 Ok(result) => Json(result).into_response(),
419 Err(error) => map_error(error, Some(request_id.0)),
420 }
421}
422
423fn map_error(error: ClawDBError, request_id: Option<String>) -> Response {
424 match error {
425 ClawDBError::PermissionDenied(reason) => auth::error_response(
426 StatusCode::FORBIDDEN,
427 "permission_denied",
428 Some(reason),
429 request_id,
430 None,
431 ),
432 ClawDBError::SessionInvalid => auth::error_response(
433 StatusCode::UNAUTHORIZED,
434 "session_invalid",
435 None,
436 request_id,
437 None,
438 ),
439 ClawDBError::ComponentDisabled(component) => auth::error_response(
440 StatusCode::SERVICE_UNAVAILABLE,
441 "component_disabled",
442 None,
443 request_id,
444 Some(component.to_string()),
445 ),
446 other => {
447 tracing::error!(request_id = ?request_id, error = %other, "HTTP handler failed");
448 auth::error_response(
449 StatusCode::INTERNAL_SERVER_ERROR,
450 "internal",
451 None,
452 request_id,
453 None,
454 )
455 }
456 }
457}
458
459fn parse_strategy(value: Option<&str>) -> MergeStrategy {
460 match value.unwrap_or("theirs").to_ascii_lowercase().as_str() {
461 "ours" => MergeStrategy::Ours,
462 "union" => MergeStrategy::Union,
463 "manual" => MergeStrategy::Manual,
464 _ => MergeStrategy::Theirs,
465 }
466}