Skip to main content

engram_server/
rest.rs

1//! Axum REST API for Engram memory operations.
2
3use axum::extract::{Path, Query, State};
4use axum::http::StatusCode;
5use axum::response::IntoResponse;
6use axum::routing::{delete, get, post};
7use axum::{Json, Router};
8use engram::context::{ContextConfig, OutputFormat};
9use engram::extract::{ExtractionConfig, Message};
10use engram::llm::MockLlmClient;
11use engram::memory::{Memory, RecallQuery};
12use engram::scope::Scope;
13use serde::{Deserialize, Serialize};
14use std::sync::Arc;
15
16// ---------------------------------------------------------------------------
17// App state
18// ---------------------------------------------------------------------------
19
20#[derive(Clone)]
21pub struct AppState {
22    pub memory: Arc<Memory>,
23}
24
25// ---------------------------------------------------------------------------
26// Request / Response types
27// ---------------------------------------------------------------------------
28
29#[derive(Deserialize)]
30pub struct AddRequest {
31    pub messages: Vec<MessagePayload>,
32    pub user_id: Option<String>,
33    pub org_id: Option<String>,
34    pub session_id: Option<String>,
35}
36
37#[derive(Deserialize)]
38pub struct MessagePayload {
39    pub role: String,
40    pub content: String,
41}
42
43#[derive(Deserialize)]
44pub struct RecallParams {
45    pub q: String,
46    pub user_id: Option<String>,
47    pub org_id: Option<String>,
48    pub max_results: Option<usize>,
49}
50
51#[derive(Deserialize)]
52pub struct ContextRequest {
53    pub query: String,
54    pub user_id: Option<String>,
55    pub org_id: Option<String>,
56    pub token_budget: Option<usize>,
57    pub format: Option<String>,
58}
59
60#[derive(Deserialize)]
61pub struct SearchParams {
62    pub q: String,
63    pub user_id: Option<String>,
64    pub org_id: Option<String>,
65    pub top_k: Option<usize>,
66}
67
68#[derive(Deserialize)]
69pub struct ForgetRequest {
70    pub reason: Option<String>,
71}
72
73#[derive(Deserialize)]
74pub struct ConsolidateRequest {
75    pub user_id: Option<String>,
76    pub org_id: Option<String>,
77}
78
79#[derive(Serialize)]
80struct ErrorResponse {
81    error: String,
82}
83
84// ---------------------------------------------------------------------------
85// Helpers
86// ---------------------------------------------------------------------------
87
88fn parse_scope(org_id: Option<&str>, user_id: Option<&str>, session_id: Option<&str>) -> Scope {
89    let org = org_id.unwrap_or("default");
90    match user_id {
91        Some(uid) => match session_id {
92            Some(sid) => Scope::session(org, uid, sid),
93            None => Scope::user(org, uid),
94        },
95        None => Scope::org(org),
96    }
97}
98
99fn err(status: StatusCode, msg: impl Into<String>) -> (StatusCode, Json<ErrorResponse>) {
100    (status, Json(ErrorResponse { error: msg.into() }))
101}
102
103// ---------------------------------------------------------------------------
104// Handlers
105// ---------------------------------------------------------------------------
106
107/// POST /v1/memory
108async fn add_handler(
109    State(state): State<AppState>,
110    Json(body): Json<AddRequest>,
111) -> impl IntoResponse {
112    let messages: Vec<Message> = body
113        .messages
114        .iter()
115        .map(|m| Message {
116            role: m.role.clone(),
117            content: m.content.clone(),
118        })
119        .collect();
120
121    if messages.is_empty() {
122        return err(StatusCode::BAD_REQUEST, "messages must not be empty").into_response();
123    }
124
125    let scope = parse_scope(
126        body.org_id.as_deref(),
127        body.user_id.as_deref(),
128        body.session_id.as_deref(),
129    );
130
131    let llm = MockLlmClient::new(vec![serde_json::json!({"facts": []})]);
132
133    match state
134        .memory
135        .add_messages(&messages, scope, Box::new(llm), ExtractionConfig::default())
136        .await
137    {
138        Ok(ids) => {
139            let fact_ids: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
140            (
141                StatusCode::CREATED,
142                Json(serde_json::json!({
143                    "success": true,
144                    "fact_count": ids.len(),
145                    "fact_ids": fact_ids,
146                })),
147            )
148                .into_response()
149        }
150        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
151    }
152}
153
154/// GET /v1/memory/recall?q=...
155async fn recall_handler(
156    State(state): State<AppState>,
157    Query(params): Query<RecallParams>,
158) -> impl IntoResponse {
159    let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
160
161    let query = RecallQuery {
162        query: params.q,
163        scope: Some(scope),
164        max_results: params.max_results.unwrap_or(10),
165        as_of: None,
166        min_score: None,
167    };
168
169    match state.memory.recall(&query).await {
170        Ok(facts) => {
171            let results: Vec<serde_json::Value> = facts
172                .iter()
173                .map(|f| {
174                    serde_json::json!({
175                        "fact_id": f.id.to_string(),
176                        "text": f.text,
177                        "tier": f.tier,
178                        "category": f.category,
179                        "confidence": f.confidence,
180                    })
181                })
182                .collect();
183            Json(serde_json::json!({ "results": results, "total": results.len() })).into_response()
184        }
185        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
186    }
187}
188
189/// POST /v1/memory/context
190async fn context_handler(
191    State(state): State<AppState>,
192    Json(body): Json<ContextRequest>,
193) -> impl IntoResponse {
194    let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
195
196    let format = match body.format.as_deref() {
197        Some("markdown") => OutputFormat::Markdown,
198        Some("raw") => OutputFormat::Raw,
199        _ => OutputFormat::SystemPrompt,
200    };
201
202    let config = ContextConfig {
203        token_budget: body.token_budget.unwrap_or(2000),
204        format,
205        ..Default::default()
206    };
207
208    match state.memory.context(&body.query, &scope, config).await {
209        Ok(block) => Json(serde_json::json!({
210            "text": block.text,
211            "token_count": block.token_count,
212            "facts_included": block.facts_included,
213            "facts_omitted": block.facts_omitted,
214            "tier_breakdown": block.tier_breakdown,
215        }))
216        .into_response(),
217        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
218    }
219}
220
221/// DELETE /v1/memory/facts/:id
222async fn forget_handler(
223    State(state): State<AppState>,
224    Path(fact_id): Path<String>,
225    body: Option<Json<ForgetRequest>>,
226) -> impl IntoResponse {
227    let id = match uuid::Uuid::parse_str(&fact_id) {
228        Ok(id) => id,
229        Err(e) => {
230            return err(StatusCode::BAD_REQUEST, format!("invalid fact_id: {e}")).into_response()
231        }
232    };
233
234    let reason = body.and_then(|b| b.reason.clone());
235
236    match state.memory.forget(id, reason.as_deref()).await {
237        Ok(()) => Json(serde_json::json!({
238            "success": true,
239            "deleted_fact_id": fact_id,
240        }))
241        .into_response(),
242        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
243    }
244}
245
246/// GET /v1/memory/search?q=...
247async fn search_handler(
248    State(state): State<AppState>,
249    Query(params): Query<SearchParams>,
250) -> impl IntoResponse {
251    let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
252    let top_k = params.top_k.unwrap_or(10);
253
254    match state
255        .memory
256        .fact_store()
257        .keyword_search(&params.q, &scope, top_k)
258        .await
259    {
260        Ok(facts) => {
261            let results: Vec<serde_json::Value> = facts
262                .iter()
263                .map(|f| {
264                    serde_json::json!({
265                        "fact_id": f.id.to_string(),
266                        "text": f.text,
267                        "tier": f.tier,
268                        "category": f.category,
269                    })
270                })
271                .collect();
272            Json(serde_json::json!({ "results": results, "total": results.len() })).into_response()
273        }
274        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
275    }
276}
277
278/// GET /v1/memory/stats
279async fn stats_handler(State(state): State<AppState>) -> impl IntoResponse {
280    match state.memory.stats(None).await {
281        Ok(stats) => Json(serde_json::json!({
282            "total_facts": stats.total_facts,
283            "valid_facts": stats.valid_facts,
284            "invalidated_facts": stats.invalidated_facts,
285            "total_entities": stats.total_entities,
286            "total_relationships": stats.total_relationships,
287        }))
288        .into_response(),
289        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
290    }
291}
292
293/// POST /v1/memory/consolidate
294async fn consolidate_handler(
295    State(state): State<AppState>,
296    Json(body): Json<ConsolidateRequest>,
297) -> impl IntoResponse {
298    let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
299    let config = engram::consolidation::ConsolidationConfig::default();
300
301    match state.memory.consolidate(&scope, None, config).await {
302        Ok(result) => Json(serde_json::json!(result)).into_response(),
303        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
304    }
305}
306
307/// DELETE /v1/memory/users/:id
308async fn delete_user_handler(
309    State(state): State<AppState>,
310    Path(user_id): Path<String>,
311) -> impl IntoResponse {
312    let scope = Scope::user("default", &user_id);
313
314    match state.memory.delete_user_data(scope).await {
315        Ok(count) => Json(serde_json::json!({
316            "success": true,
317            "deleted_facts": count,
318        }))
319        .into_response(),
320        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
321    }
322}
323
324/// GET /health
325async fn health_handler() -> impl IntoResponse {
326    Json(serde_json::json!({ "status": "ok", "service": "engram" }))
327}
328
329// ---------------------------------------------------------------------------
330// Router
331// ---------------------------------------------------------------------------
332
333/// Build the Axum router with all REST endpoints.
334pub fn build_router(state: AppState) -> Router {
335    Router::new()
336        .route("/health", get(health_handler))
337        .route("/v1/memory", post(add_handler))
338        .route("/v1/memory/recall", get(recall_handler))
339        .route("/v1/memory/context", post(context_handler))
340        .route("/v1/memory/facts/:id", delete(forget_handler))
341        .route("/v1/memory/search", get(search_handler))
342        .route("/v1/memory/stats", get(stats_handler))
343        .route("/v1/memory/consolidate", post(consolidate_handler))
344        .route("/v1/memory/users/:id", delete(delete_user_handler))
345        .with_state(state)
346}