Skip to main content

clawdb_server/http/
router.rs

use std::sync::Arc;

use axum::{
    extract::{Path, Query, State},
    http::{HeaderValue, StatusCode},
    middleware,
    response::{IntoResponse, Response},
    routing::{delete, get, post},
    Extension, Json, Router,
};
use clawdb::{prelude::MergeStrategy, ClawDBError};
use serde::Deserialize;
use tower_http::{
    limit::RequestBodyLimitLayer,
    normalize_path::{NormalizePath, NormalizePathLayer},
    set_header::SetResponseHeaderLayer,
};
use uuid::Uuid;

use crate::{
    http::auth::{self, AuthContext},
    state::{AppState, RequestId},
};
23
/// JSON body accepted by `POST /v1/sessions`.
#[derive(Deserialize)]
struct CreateSessionBody {
    /// Agent id passed to `db.session_with_ttl`.
    agent_id: Uuid,
    /// Role name, forwarded verbatim to `db.session_with_ttl`.
    role: String,
    /// Scope strings, forwarded verbatim to `db.session_with_ttl`.
    scopes: Vec<String>,
    /// Requested session lifetime in seconds; `create_session` uses 3600 when absent.
    #[serde(default)]
    ttl_secs: Option<u64>,
}
32
/// JSON body accepted by `POST /v1/memories`.
#[derive(Deserialize)]
struct MemoryBody {
    /// Memory text to store.
    content: String,
    /// Optional memory type (JSON key `type`). When present the handler calls
    /// `remember_typed`; when absent it calls plain `remember`.
    #[serde(default)]
    r#type: Option<String>,
    /// Tags; the `remember` handler only forwards them when `type` is set.
    #[serde(default)]
    tags: Vec<String>,
    /// Arbitrary JSON metadata (JSON `null` when omitted); the `remember` handler only
    /// forwards it when `type` is set.
    #[serde(default)]
    metadata: serde_json::Value,
}
43
/// Query string accepted by `GET /v1/memories/search`.
#[derive(Deserialize)]
struct SearchQuery {
    /// Search text.
    q: String,
    /// Result count passed to `search_with_options`; defaults to `default_top_k()` (10).
    // NOTE(review): no upper bound is applied here or in `search` — confirm the
    // database layer caps this value.
    #[serde(default = "default_top_k")]
    top_k: usize,
    /// `semantic` flag passed to `search_with_options`; `false` when omitted.
    #[serde(default)]
    semantic: bool,
}
52
/// Query string accepted by `GET /v1/memories`.
#[derive(Deserialize)]
struct ListMemoriesQuery {
    /// Optional memory type (query key `type`) passed to `db.list_memories`.
    #[serde(default)]
    r#type: Option<String>,
    /// Optional cap on the number of returned memories. It is applied by truncating
    /// the full result after it has been fetched, not by the database query.
    #[serde(default)]
    limit: Option<usize>,
}
60
/// Serde default for [`SearchQuery::top_k`] when the query string omits it.
fn default_top_k() -> usize {
    const DEFAULT_TOP_K: usize = 10;
    DEFAULT_TOP_K
}
64
/// JSON body accepted by `POST /v1/branches`.
#[derive(Deserialize)]
struct BranchBody {
    /// Name for the new branch; echoed back in the response.
    name: String,
    /// Optional id of an existing branch. When present the handler calls
    /// `db.fork_branch`; when absent it calls `db.branch`.
    #[serde(default)]
    from: Option<Uuid>,
}
71
/// JSON body accepted by `POST /v1/branches/:id/merge`.
#[derive(Deserialize)]
struct MergeBody {
    /// Branch id passed as the second branch argument to `merge_with_strategy`
    /// (presumably the merge destination — confirm). Also accepted as `target_id`.
    #[serde(alias = "target_id")]
    target: Uuid,
    /// Optional strategy name, matched case-insensitively by `parse_strategy`
    /// (`ours`, `theirs`, `union`, `manual`); omitted means `theirs`.
    #[serde(default)]
    strategy: Option<String>,
}
79
/// Query string accepted by `GET /v1/branches/:id/diff`.
#[derive(Deserialize)]
struct DiffQuery {
    /// Second branch id passed to `db.diff` alongside the path branch id.
    target: Uuid,
}
84
85pub fn router(state: Arc<AppState>) -> Router {
86    let public = Router::new()
87        .route("/v1/health", get(health))
88        .route("/v1/ready", get(ready))
89        .route("/v1/sessions", post(create_session))
90        .route("/v1/metrics", get(metrics));
91
92    let protected = Router::new()
93        .route("/v1/sessions/me", get(whoami))
94        .route("/v1/sessions/:id", delete(revoke_session))
95        .route("/v1/memories", post(remember).get(list_memories))
96        .route("/v1/memories/search", get(search))
97        .route("/v1/memories/:id", get(recall_one).delete(delete_memory))
98        .route("/v1/branches", post(create_branch).get(list_branches))
99        .route("/v1/branches/:id/merge", post(merge_branch))
100        .route("/v1/branches/:id/diff", get(diff_branch))
101        .route("/v1/branches/:id", delete(discard_branch))
102        .route("/v1/sync", post(sync))
103        .route("/v1/reflect", post(reflect))
104        .layer(middleware::from_fn_with_state(
105            state.clone(),
106            auth::rate_limit_middleware,
107        ))
108        .layer(middleware::from_fn_with_state(
109            state.clone(),
110            auth::auth_middleware,
111        ));
112
113    public
114        .merge(protected)
115        .layer(middleware::from_fn_with_state(
116            state.clone(),
117            auth::metrics_middleware,
118        ))
119        .layer(middleware::from_fn(auth::request_id_middleware))
120        .layer(SetResponseHeaderLayer::if_not_present(
121            axum::http::header::HeaderName::from_static("content-security-policy"),
122            HeaderValue::from_static("default-src 'none'; frame-ancestors 'none'; base-uri 'none'"),
123        ))
124        .layer(SetResponseHeaderLayer::if_not_present(
125            axum::http::header::HeaderName::from_static("x-content-type-options"),
126            HeaderValue::from_static("nosniff"),
127        ))
128        .layer(SetResponseHeaderLayer::if_not_present(
129            axum::http::header::HeaderName::from_static("x-frame-options"),
130            HeaderValue::from_static("DENY"),
131        ))
132        .layer(RequestBodyLimitLayer::new(10 * 1024 * 1024))
133        .layer(NormalizePathLayer::trim_trailing_slash())
134        .with_state(state)
135}
136
137pub fn metrics_router(state: Arc<AppState>) -> Router {
138    Router::new()
139        .route("/", get(metrics))
140        .route("/metrics", get(metrics))
141        .route("/v1/metrics", get(metrics))
142        .with_state(state)
143}
144
145async fn health(State(state): State<Arc<AppState>>) -> Response {
146    match state.db.health().await {
147        Ok(report) => Json(report).into_response(),
148        Err(error) => map_error(error, None),
149    }
150}
151
152async fn ready(State(state): State<Arc<AppState>>) -> Response {
153    match state.db.health().await {
154        Ok(report) if report.ok => StatusCode::OK.into_response(),
155        Ok(_) => StatusCode::SERVICE_UNAVAILABLE.into_response(),
156        Err(error) => map_error(error, None),
157    }
158}
159
160async fn metrics(State(state): State<Arc<AppState>>) -> Response {
161    if let Ok(count) = state.db.active_session_count().await {
162        state.metrics.set_active_sessions(count);
163    }
164    let rendered = state.metrics.render(state.db.metrics_handle().render());
165    (
166        StatusCode::OK,
167        [(
168            axum::http::header::CONTENT_TYPE,
169            HeaderValue::from_static("text/plain; version=0.0.4"),
170        )],
171        rendered,
172    )
173        .into_response()
174}
175
176async fn create_session(
177    State(state): State<Arc<AppState>>,
178    Extension(request_id): Extension<RequestId>,
179    Json(body): Json<CreateSessionBody>,
180) -> Response {
181    match state
182        .db
183        .session_with_ttl(
184            body.agent_id,
185            &body.role,
186            body.scopes,
187            body.ttl_secs.unwrap_or(3600) as i64,
188        )
189        .await
190    {
191        Ok(session) => Json(serde_json::json!({
192            "id": session.id,
193            "session_id": session.id,
194            "agent_id": session.agent_id,
195            "role": session.role,
196            "token": session.token,
197            "expires_at": session.expires_at.to_rfc3339(),
198            "scopes": session.scopes,
199        }))
200        .into_response(),
201        Err(error) => map_error(error, Some(request_id.0)),
202    }
203}
204
205async fn whoami(Extension(auth): Extension<AuthContext>) -> Response {
206    Json(serde_json::json!({
207        "id": auth.session.id,
208        "session_id": auth.session.id,
209        "agent_id": auth.session.agent_id,
210        "role": auth.session.role,
211        "token": auth.session.token,
212        "expires_at": auth.session.expires_at.to_rfc3339(),
213        "scopes": auth.session.scopes,
214    }))
215    .into_response()
216}
217
218async fn revoke_session(
219    State(state): State<Arc<AppState>>,
220    Extension(request_id): Extension<RequestId>,
221    Path(id): Path<Uuid>,
222) -> Response {
223    match state.db.revoke_session(id).await {
224        Ok(()) => StatusCode::NO_CONTENT.into_response(),
225        Err(error) => map_error(error, Some(request_id.0)),
226    }
227}
228
229async fn remember(
230    State(state): State<Arc<AppState>>,
231    Extension(auth): Extension<AuthContext>,
232    Extension(request_id): Extension<RequestId>,
233    Json(body): Json<MemoryBody>,
234) -> Response {
235    let result = if let Some(memory_type) = body.r#type.as_deref() {
236        state
237            .db
238            .remember_typed(
239                &auth.session,
240                &body.content,
241                memory_type,
242                &body.tags,
243                body.metadata,
244            )
245            .await
246    } else {
247        state.db.remember(&auth.session, &body.content).await
248    };
249
250    match result {
251        Ok(remembered) => Json(remembered).into_response(),
252        Err(error) => map_error(error, Some(request_id.0)),
253    }
254}
255
256async fn search(
257    State(state): State<Arc<AppState>>,
258    Extension(auth): Extension<AuthContext>,
259    Extension(request_id): Extension<RequestId>,
260    Query(query): Query<SearchQuery>,
261) -> Response {
262    match state
263        .db
264        .search_with_options(&auth.session, &query.q, query.top_k, query.semantic, None)
265        .await
266    {
267        Ok(hits) => Json(hits).into_response(),
268        Err(error) => map_error(error, Some(request_id.0)),
269    }
270}
271
272async fn recall_one(
273    State(state): State<Arc<AppState>>,
274    Extension(auth): Extension<AuthContext>,
275    Extension(request_id): Extension<RequestId>,
276    Path(id): Path<Uuid>,
277) -> Response {
278    match state.db.recall(&auth.session, &[id]).await {
279        Ok(mut memories) => match memories.pop() {
280            Some(memory) => Json(memory).into_response(),
281            None => auth::error_response(
282                StatusCode::NOT_FOUND,
283                "not_found",
284                None,
285                Some(request_id.0),
286                None,
287            ),
288        },
289        Err(error) => map_error(error, Some(request_id.0)),
290    }
291}
292
293async fn list_memories(
294    State(state): State<Arc<AppState>>,
295    Extension(auth): Extension<AuthContext>,
296    Extension(request_id): Extension<RequestId>,
297    Query(query): Query<ListMemoriesQuery>,
298) -> Response {
299    match state
300        .db
301        .list_memories(&auth.session, query.r#type.as_deref())
302        .await
303    {
304        Ok(mut memories) => {
305            if let Some(limit) = query.limit {
306                memories.truncate(limit);
307            }
308            Json(memories).into_response()
309        }
310        Err(error) => map_error(error, Some(request_id.0)),
311    }
312}
313
314async fn delete_memory(
315    State(state): State<Arc<AppState>>,
316    Extension(auth): Extension<AuthContext>,
317    Extension(request_id): Extension<RequestId>,
318    Path(id): Path<Uuid>,
319) -> Response {
320    match state.db.delete_memory(&auth.session, id).await {
321        Ok(()) => StatusCode::NO_CONTENT.into_response(),
322        Err(error) => map_error(error, Some(request_id.0)),
323    }
324}
325
326async fn create_branch(
327    State(state): State<Arc<AppState>>,
328    Extension(auth): Extension<AuthContext>,
329    Extension(request_id): Extension<RequestId>,
330    Json(body): Json<BranchBody>,
331) -> Response {
332    let branch = if let Some(from) = body.from {
333        state.db.fork_branch(&auth.session, from, &body.name).await
334    } else {
335        state.db.branch(&auth.session, &body.name).await
336    };
337    match branch {
338        Ok(id) => Json(serde_json::json!({"id": id, "branch_id": id, "name": body.name})).into_response(),
339        Err(error) => map_error(error, Some(request_id.0)),
340    }
341}
342
343async fn list_branches(
344    State(state): State<Arc<AppState>>,
345    Extension(auth): Extension<AuthContext>,
346    Extension(request_id): Extension<RequestId>,
347) -> Response {
348    match state.db.list_branches(&auth.session).await {
349        Ok(branches) => Json(branches).into_response(),
350        Err(error) => map_error(error, Some(request_id.0)),
351    }
352}
353
354async fn merge_branch(
355    State(state): State<Arc<AppState>>,
356    Extension(auth): Extension<AuthContext>,
357    Extension(request_id): Extension<RequestId>,
358    Path(id): Path<Uuid>,
359    Json(body): Json<MergeBody>,
360) -> Response {
361    match state
362        .db
363        .merge_with_strategy(
364            &auth.session,
365            id,
366            body.target,
367            parse_strategy(body.strategy.as_deref()),
368        )
369        .await
370    {
371        Ok(result) => Json(result).into_response(),
372        Err(error) => map_error(error, Some(request_id.0)),
373    }
374}
375
376async fn diff_branch(
377    State(state): State<Arc<AppState>>,
378    Extension(auth): Extension<AuthContext>,
379    Extension(request_id): Extension<RequestId>,
380    Path(id): Path<Uuid>,
381    Query(query): Query<DiffQuery>,
382) -> Response {
383    match state.db.diff(&auth.session, id, query.target).await {
384        Ok(result) => Json(result).into_response(),
385        Err(error) => map_error(error, Some(request_id.0)),
386    }
387}
388
389async fn discard_branch(
390    State(state): State<Arc<AppState>>,
391    Extension(auth): Extension<AuthContext>,
392    Extension(request_id): Extension<RequestId>,
393    Path(id): Path<Uuid>,
394) -> Response {
395    match state.db.discard_branch(&auth.session, id).await {
396        Ok(()) => StatusCode::NO_CONTENT.into_response(),
397        Err(error) => map_error(error, Some(request_id.0)),
398    }
399}
400
401async fn sync(
402    State(state): State<Arc<AppState>>,
403    Extension(auth): Extension<AuthContext>,
404    Extension(request_id): Extension<RequestId>,
405) -> Response {
406    match state.db.sync(&auth.session).await {
407        Ok(result) => Json(result).into_response(),
408        Err(error) => map_error(error, Some(request_id.0)),
409    }
410}
411
412async fn reflect(
413    State(state): State<Arc<AppState>>,
414    Extension(auth): Extension<AuthContext>,
415    Extension(request_id): Extension<RequestId>,
416) -> Response {
417    match state.db.reflect(&auth.session).await {
418        Ok(result) => Json(result).into_response(),
419        Err(error) => map_error(error, Some(request_id.0)),
420    }
421}
422
423fn map_error(error: ClawDBError, request_id: Option<String>) -> Response {
424    match error {
425        ClawDBError::PermissionDenied(reason) => auth::error_response(
426            StatusCode::FORBIDDEN,
427            "permission_denied",
428            Some(reason),
429            request_id,
430            None,
431        ),
432        ClawDBError::SessionInvalid => auth::error_response(
433            StatusCode::UNAUTHORIZED,
434            "session_invalid",
435            None,
436            request_id,
437            None,
438        ),
439        ClawDBError::ComponentDisabled(component) => auth::error_response(
440            StatusCode::SERVICE_UNAVAILABLE,
441            "component_disabled",
442            None,
443            request_id,
444            Some(component.to_string()),
445        ),
446        other => {
447            tracing::error!(request_id = ?request_id, error = %other, "HTTP handler failed");
448            auth::error_response(
449                StatusCode::INTERNAL_SERVER_ERROR,
450                "internal",
451                None,
452                request_id,
453                None,
454            )
455        }
456    }
457}
458
459fn parse_strategy(value: Option<&str>) -> MergeStrategy {
460    match value.unwrap_or("theirs").to_ascii_lowercase().as_str() {
461        "ours" => MergeStrategy::Ours,
462        "union" => MergeStrategy::Union,
463        "manual" => MergeStrategy::Manual,
464        _ => MergeStrategy::Theirs,
465    }
466}