Skip to main content

engram_server/
rest.rs

1//! Axum REST API for Engram memory operations.
2
3use crate::config::LlmBackend;
4use axum::extract::{Path, Query, State};
5use axum::http::StatusCode;
6use axum::response::IntoResponse;
7use axum::routing::{delete, get, post};
8use axum::{Json, Router};
9use engram::context::{ContextConfig, OutputFormat};
10use engram::extract::{ExtractionConfig, Message};
11use engram::memory::{Memory, RecallQuery};
12use engram::scope::Scope;
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15
16// ---------------------------------------------------------------------------
17// App state
18// ---------------------------------------------------------------------------
19
20#[derive(Clone)]
21pub struct AppState {
22    pub memory: Arc<Memory>,
23    pub llm_backend: LlmBackend,
24}
25
26// ---------------------------------------------------------------------------
27// Request / Response types
28// ---------------------------------------------------------------------------
29
30#[derive(Deserialize)]
31pub struct AddRequest {
32    pub messages: Vec<MessagePayload>,
33    pub user_id: Option<String>,
34    pub org_id: Option<String>,
35    pub session_id: Option<String>,
36}
37
38#[derive(Deserialize)]
39pub struct MessagePayload {
40    pub role: String,
41    pub content: String,
42}
43
44#[derive(Deserialize)]
45pub struct RecallParams {
46    pub q: String,
47    pub user_id: Option<String>,
48    pub org_id: Option<String>,
49    pub max_results: Option<usize>,
50}
51
52#[derive(Deserialize)]
53pub struct ContextRequest {
54    pub query: String,
55    pub user_id: Option<String>,
56    pub org_id: Option<String>,
57    pub token_budget: Option<usize>,
58    pub format: Option<String>,
59}
60
61#[derive(Deserialize)]
62pub struct SearchParams {
63    pub q: String,
64    pub user_id: Option<String>,
65    pub org_id: Option<String>,
66    pub top_k: Option<usize>,
67}
68
69#[derive(Deserialize)]
70pub struct ForgetRequest {
71    pub reason: Option<String>,
72}
73
74#[derive(Deserialize)]
75pub struct ConsolidateRequest {
76    pub user_id: Option<String>,
77    pub org_id: Option<String>,
78}
79
80#[derive(Serialize)]
81struct ErrorResponse {
82    error: String,
83}
84
85// ---------------------------------------------------------------------------
86// Helpers
87// ---------------------------------------------------------------------------
88
89fn parse_scope(org_id: Option<&str>, user_id: Option<&str>, session_id: Option<&str>) -> Scope {
90    let org = org_id.unwrap_or("default");
91    match user_id {
92        Some(uid) => match session_id {
93            Some(sid) => Scope::session(org, uid, sid),
94            None => Scope::user(org, uid),
95        },
96        None => Scope::org(org),
97    }
98}
99
100fn err(status: StatusCode, msg: impl Into<String>) -> (StatusCode, Json<ErrorResponse>) {
101    (status, Json(ErrorResponse { error: msg.into() }))
102}
103
104// ---------------------------------------------------------------------------
105// Handlers
106// ---------------------------------------------------------------------------
107
108/// POST /v1/memory
109async fn add_handler(
110    State(state): State<AppState>,
111    Json(body): Json<AddRequest>,
112) -> impl IntoResponse {
113    let messages: Vec<Message> = body
114        .messages
115        .iter()
116        .map(|m| Message {
117            role: m.role.clone(),
118            content: m.content.clone(),
119        })
120        .collect();
121
122    if messages.is_empty() {
123        return err(StatusCode::BAD_REQUEST, "messages must not be empty").into_response();
124    }
125
126    let scope = parse_scope(
127        body.org_id.as_deref(),
128        body.user_id.as_deref(),
129        body.session_id.as_deref(),
130    );
131
132    match state
133        .memory
134        .add_messages(
135            &messages,
136            scope,
137            state.llm_backend.build(),
138            ExtractionConfig::default(),
139        )
140        .await
141    {
142        Ok(ids) => {
143            let fact_ids: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
144            (
145                StatusCode::CREATED,
146                Json(serde_json::json!({
147                    "success": true,
148                    "fact_count": ids.len(),
149                    "fact_ids": fact_ids,
150                })),
151            )
152                .into_response()
153        }
154        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
155    }
156}
157
158/// GET /v1/memory/recall?q=...
159async fn recall_handler(
160    State(state): State<AppState>,
161    Query(params): Query<RecallParams>,
162) -> impl IntoResponse {
163    let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
164
165    let query = RecallQuery {
166        query: params.q,
167        scope: Some(scope),
168        max_results: params.max_results.unwrap_or(10),
169        as_of: None,
170        min_score: None,
171    };
172
173    match state.memory.recall(&query).await {
174        Ok(facts) => {
175            let results: Vec<serde_json::Value> = facts
176                .iter()
177                .map(|f| {
178                    serde_json::json!({
179                        "fact_id": f.id.to_string(),
180                        "text": f.text,
181                        "tier": f.tier,
182                        "category": f.category,
183                        "confidence": f.confidence,
184                    })
185                })
186                .collect();
187            Json(serde_json::json!({ "results": results, "total": results.len() })).into_response()
188        }
189        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
190    }
191}
192
193/// POST /v1/memory/context
194async fn context_handler(
195    State(state): State<AppState>,
196    Json(body): Json<ContextRequest>,
197) -> impl IntoResponse {
198    let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
199
200    let format = match body.format.as_deref() {
201        Some("markdown") => OutputFormat::Markdown,
202        Some("raw") => OutputFormat::Raw,
203        _ => OutputFormat::SystemPrompt,
204    };
205
206    let config = ContextConfig {
207        token_budget: body.token_budget.unwrap_or(2000),
208        format,
209        ..Default::default()
210    };
211
212    match state.memory.context(&body.query, &scope, config).await {
213        Ok(block) => Json(serde_json::json!({
214            "text": block.text,
215            "token_count": block.token_count,
216            "facts_included": block.facts_included,
217            "facts_omitted": block.facts_omitted,
218            "tier_breakdown": block.tier_breakdown,
219        }))
220        .into_response(),
221        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
222    }
223}
224
225/// DELETE /v1/memory/facts/:id
226async fn forget_handler(
227    State(state): State<AppState>,
228    Path(fact_id): Path<String>,
229    body: Option<Json<ForgetRequest>>,
230) -> impl IntoResponse {
231    let id = match uuid::Uuid::parse_str(&fact_id) {
232        Ok(id) => id,
233        Err(e) => {
234            return err(StatusCode::BAD_REQUEST, format!("invalid fact_id: {e}")).into_response()
235        }
236    };
237
238    let reason = body.and_then(|b| b.reason.clone());
239
240    match state.memory.forget(id, reason.as_deref()).await {
241        Ok(()) => Json(serde_json::json!({
242            "success": true,
243            "deleted_fact_id": fact_id,
244        }))
245        .into_response(),
246        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
247    }
248}
249
250/// GET /v1/memory/search?q=...
251async fn search_handler(
252    State(state): State<AppState>,
253    Query(params): Query<SearchParams>,
254) -> impl IntoResponse {
255    let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
256    let top_k = params.top_k.unwrap_or(10);
257
258    match state
259        .memory
260        .fact_store()
261        .keyword_search(&params.q, &scope, top_k)
262        .await
263    {
264        Ok(facts) => {
265            let results: Vec<serde_json::Value> = facts
266                .iter()
267                .map(|f| {
268                    serde_json::json!({
269                        "fact_id": f.id.to_string(),
270                        "text": f.text,
271                        "tier": f.tier,
272                        "category": f.category,
273                    })
274                })
275                .collect();
276            Json(serde_json::json!({ "results": results, "total": results.len() })).into_response()
277        }
278        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
279    }
280}
281
282/// GET /v1/memory/stats
283async fn stats_handler(State(state): State<AppState>) -> impl IntoResponse {
284    match state.memory.stats(None).await {
285        Ok(stats) => Json(serde_json::json!({
286            "total_facts": stats.total_facts,
287            "valid_facts": stats.valid_facts,
288            "invalidated_facts": stats.invalidated_facts,
289            "total_entities": stats.total_entities,
290            "total_relationships": stats.total_relationships,
291        }))
292        .into_response(),
293        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
294    }
295}
296
297/// POST /v1/memory/consolidate
298async fn consolidate_handler(
299    State(state): State<AppState>,
300    Json(body): Json<ConsolidateRequest>,
301) -> impl IntoResponse {
302    let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
303    let config = engram::consolidation::ConsolidationConfig::default();
304
305    match state.memory.consolidate(&scope, None, config).await {
306        Ok(result) => Json(serde_json::json!(result)).into_response(),
307        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
308    }
309}
310
311/// DELETE /v1/memory/users/:id
312async fn delete_user_handler(
313    State(state): State<AppState>,
314    Path(user_id): Path<String>,
315) -> impl IntoResponse {
316    let scope = Scope::user("default", &user_id);
317
318    match state.memory.delete_user_data(scope).await {
319        Ok(count) => Json(serde_json::json!({
320            "success": true,
321            "deleted_facts": count,
322        }))
323        .into_response(),
324        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
325    }
326}
327
328/// GET /health
329async fn health_handler() -> impl IntoResponse {
330    Json(serde_json::json!({ "status": "ok", "service": "engram" }))
331}
332
333// ---------------------------------------------------------------------------
334// Router
335// ---------------------------------------------------------------------------
336
337/// Build the Axum router with all REST endpoints.
338pub fn build_router(state: AppState) -> Router {
339    Router::new()
340        .route("/health", get(health_handler))
341        .route("/v1/memory", post(add_handler))
342        .route("/v1/memory/recall", get(recall_handler))
343        .route("/v1/memory/context", post(context_handler))
344        .route("/v1/memory/facts/:id", delete(forget_handler))
345        .route("/v1/memory/search", get(search_handler))
346        .route("/v1/memory/stats", get(stats_handler))
347        .route("/v1/memory/consolidate", post(consolidate_handler))
348        .route("/v1/memory/users/:id", delete(delete_user_handler))
349        .with_state(state)
350}