Skip to main content

engram_server/
rest.rs

1//! Axum REST API for Engram memory operations.
2
3use crate::config::LlmBackend;
4use axum::extract::{Path, Query, State};
5use axum::http::StatusCode;
6use axum::response::IntoResponse;
7use axum::routing::{delete, get, post};
8use axum::{Json, Router};
9use engram::context::{ContextConfig, OutputFormat};
10use engram::extract::{ExtractionConfig, Message};
11use engram::memory::{Memory, RecallQuery};
12use engram::message::ChatMessage;
13use engram::scope::Scope;
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16
17// ---------------------------------------------------------------------------
18// App state
19// ---------------------------------------------------------------------------
20
21#[derive(Clone)]
22pub struct AppState {
23    pub memory: Arc<Memory>,
24    pub llm_backend: LlmBackend,
25    pub extract_on_save: bool,
26}
27
28// ---------------------------------------------------------------------------
29// Request / Response types
30// ---------------------------------------------------------------------------
31
32#[derive(Deserialize)]
33pub struct AddRequest {
34    pub messages: Vec<MessagePayload>,
35    pub user_id: Option<String>,
36    pub org_id: Option<String>,
37    pub session_id: Option<String>,
38}
39
40#[derive(Deserialize)]
41pub struct MessagePayload {
42    pub role: String,
43    pub content: String,
44}
45
46#[derive(Deserialize)]
47pub struct RecallParams {
48    pub q: String,
49    pub user_id: Option<String>,
50    pub org_id: Option<String>,
51    pub max_results: Option<usize>,
52}
53
54#[derive(Deserialize)]
55pub struct ContextRequest {
56    pub query: String,
57    pub user_id: Option<String>,
58    pub org_id: Option<String>,
59    pub token_budget: Option<usize>,
60    pub format: Option<String>,
61}
62
63#[derive(Deserialize)]
64pub struct SearchParams {
65    pub q: String,
66    pub user_id: Option<String>,
67    pub org_id: Option<String>,
68    pub top_k: Option<usize>,
69}
70
71#[derive(Deserialize)]
72pub struct ForgetRequest {
73    pub reason: Option<String>,
74}
75
76#[derive(Deserialize)]
77pub struct ConsolidateRequest {
78    pub user_id: Option<String>,
79    pub org_id: Option<String>,
80}
81
82#[derive(Deserialize)]
83pub struct SaveMessagesRequest {
84    pub conversation_id: String,
85    pub messages: Vec<MessageInput>,
86    pub user_id: Option<String>,
87    pub org_id: Option<String>,
88}
89
90#[derive(Deserialize)]
91pub struct MessageInput {
92    pub role: String,
93    pub content: String,
94    #[serde(default)]
95    pub metadata: Option<serde_json::Map<String, serde_json::Value>>,
96}
97
98#[derive(Deserialize)]
99pub struct GetMessagesParams {
100    pub last_n: Option<usize>,
101    pub user_id: Option<String>,
102    pub org_id: Option<String>,
103}
104
105#[derive(Deserialize)]
106pub struct ListConversationsParams {
107    pub user_id: Option<String>,
108    pub org_id: Option<String>,
109}
110
111#[derive(Deserialize)]
112pub struct DeleteMessagesParams {
113    pub user_id: Option<String>,
114    pub org_id: Option<String>,
115}
116
117#[derive(Serialize)]
118struct ErrorResponse {
119    error: String,
120}
121
122// ---------------------------------------------------------------------------
123// Helpers
124// ---------------------------------------------------------------------------
125
126fn parse_scope(org_id: Option<&str>, user_id: Option<&str>, session_id: Option<&str>) -> Scope {
127    let org = org_id.unwrap_or("default");
128    match user_id {
129        Some(uid) => match session_id {
130            Some(sid) => Scope::session(org, uid, sid),
131            None => Scope::user(org, uid),
132        },
133        None => Scope::org(org),
134    }
135}
136
137fn err(status: StatusCode, msg: impl Into<String>) -> (StatusCode, Json<ErrorResponse>) {
138    (status, Json(ErrorResponse { error: msg.into() }))
139}
140
141// ---------------------------------------------------------------------------
142// Handlers
143// ---------------------------------------------------------------------------
144
145/// POST /v1/memory
146async fn add_handler(
147    State(state): State<AppState>,
148    Json(body): Json<AddRequest>,
149) -> impl IntoResponse {
150    let messages: Vec<Message> = body
151        .messages
152        .iter()
153        .map(|m| Message {
154            role: m.role.clone(),
155            content: m.content.clone(),
156        })
157        .collect();
158
159    if messages.is_empty() {
160        return err(StatusCode::BAD_REQUEST, "messages must not be empty").into_response();
161    }
162
163    let scope = parse_scope(
164        body.org_id.as_deref(),
165        body.user_id.as_deref(),
166        body.session_id.as_deref(),
167    );
168
169    match state
170        .memory
171        .add_messages(
172            &messages,
173            scope,
174            state.llm_backend.build(),
175            ExtractionConfig::default(),
176        )
177        .await
178    {
179        Ok(ids) => {
180            let fact_ids: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
181            (
182                StatusCode::CREATED,
183                Json(serde_json::json!({
184                    "success": true,
185                    "fact_count": ids.len(),
186                    "fact_ids": fact_ids,
187                })),
188            )
189                .into_response()
190        }
191        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
192    }
193}
194
195/// GET /v1/memory/recall?q=...
196async fn recall_handler(
197    State(state): State<AppState>,
198    Query(params): Query<RecallParams>,
199) -> impl IntoResponse {
200    let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
201
202    let query = RecallQuery {
203        query: params.q,
204        scope: Some(scope),
205        max_results: params.max_results.unwrap_or(10),
206        as_of: None,
207        min_score: None,
208    };
209
210    match state.memory.recall(&query).await {
211        Ok(facts) => {
212            let results: Vec<serde_json::Value> = facts
213                .iter()
214                .map(|f| {
215                    serde_json::json!({
216                        "fact_id": f.id.to_string(),
217                        "text": f.text,
218                        "tier": f.tier,
219                        "category": f.category,
220                        "confidence": f.confidence,
221                    })
222                })
223                .collect();
224            Json(serde_json::json!({ "results": results, "total": results.len() })).into_response()
225        }
226        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
227    }
228}
229
230/// POST /v1/memory/context
231async fn context_handler(
232    State(state): State<AppState>,
233    Json(body): Json<ContextRequest>,
234) -> impl IntoResponse {
235    let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
236
237    let format = match body.format.as_deref() {
238        Some("markdown") => OutputFormat::Markdown,
239        Some("raw") => OutputFormat::Raw,
240        _ => OutputFormat::SystemPrompt,
241    };
242
243    let config = ContextConfig {
244        token_budget: body.token_budget.unwrap_or(2000),
245        format,
246        ..Default::default()
247    };
248
249    match state.memory.context(&body.query, &scope, config).await {
250        Ok(block) => Json(serde_json::json!({
251            "text": block.text,
252            "token_count": block.token_count,
253            "facts_included": block.facts_included,
254            "facts_omitted": block.facts_omitted,
255            "tier_breakdown": block.tier_breakdown,
256        }))
257        .into_response(),
258        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
259    }
260}
261
262/// DELETE /v1/memory/facts/:id
263async fn forget_handler(
264    State(state): State<AppState>,
265    Path(fact_id): Path<String>,
266    body: Option<Json<ForgetRequest>>,
267) -> impl IntoResponse {
268    let id = match uuid::Uuid::parse_str(&fact_id) {
269        Ok(id) => id,
270        Err(e) => {
271            return err(StatusCode::BAD_REQUEST, format!("invalid fact_id: {e}")).into_response()
272        }
273    };
274
275    let reason = body.and_then(|b| b.reason.clone());
276
277    match state.memory.forget(id, reason.as_deref()).await {
278        Ok(()) => Json(serde_json::json!({
279            "success": true,
280            "deleted_fact_id": fact_id,
281        }))
282        .into_response(),
283        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
284    }
285}
286
287/// GET /v1/memory/search?q=...
288async fn search_handler(
289    State(state): State<AppState>,
290    Query(params): Query<SearchParams>,
291) -> impl IntoResponse {
292    let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
293    let top_k = params.top_k.unwrap_or(10);
294
295    match state
296        .memory
297        .fact_store()
298        .keyword_search(&params.q, &scope, top_k)
299        .await
300    {
301        Ok(facts) => {
302            let results: Vec<serde_json::Value> = facts
303                .iter()
304                .map(|f| {
305                    serde_json::json!({
306                        "fact_id": f.id.to_string(),
307                        "text": f.text,
308                        "tier": f.tier,
309                        "category": f.category,
310                    })
311                })
312                .collect();
313            Json(serde_json::json!({ "results": results, "total": results.len() })).into_response()
314        }
315        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
316    }
317}
318
319/// GET /v1/memory/stats
320async fn stats_handler(State(state): State<AppState>) -> impl IntoResponse {
321    match state.memory.stats(None).await {
322        Ok(stats) => Json(serde_json::json!({
323            "total_facts": stats.total_facts,
324            "valid_facts": stats.valid_facts,
325            "invalidated_facts": stats.invalidated_facts,
326            "total_entities": stats.total_entities,
327            "total_relationships": stats.total_relationships,
328        }))
329        .into_response(),
330        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
331    }
332}
333
334/// POST /v1/memory/consolidate
335async fn consolidate_handler(
336    State(state): State<AppState>,
337    Json(body): Json<ConsolidateRequest>,
338) -> impl IntoResponse {
339    let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
340    let config = engram::consolidation::ConsolidationConfig::default();
341
342    match state.memory.consolidate(&scope, None, config).await {
343        Ok(result) => Json(serde_json::json!(result)).into_response(),
344        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
345    }
346}
347
348/// DELETE /v1/memory/users/:id
349async fn delete_user_handler(
350    State(state): State<AppState>,
351    Path(user_id): Path<String>,
352) -> impl IntoResponse {
353    let scope = Scope::user("default", &user_id);
354
355    match state.memory.delete_user_data(scope).await {
356        Ok(count) => Json(serde_json::json!({
357            "success": true,
358            "deleted_facts": count,
359        }))
360        .into_response(),
361        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
362    }
363}
364
365/// POST /v1/memory/messages
366async fn save_messages_handler(
367    State(state): State<AppState>,
368    Json(body): Json<SaveMessagesRequest>,
369) -> impl IntoResponse {
370    if body.messages.is_empty() {
371        return err(StatusCode::BAD_REQUEST, "messages must not be empty").into_response();
372    }
373
374    let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
375
376    let chat_messages: Vec<ChatMessage> = body
377        .messages
378        .iter()
379        .enumerate()
380        .map(|(i, m)| {
381            let mut msg = ChatMessage::new(
382                &body.conversation_id,
383                &m.role,
384                &m.content,
385                scope.clone(),
386                i as i32,
387            );
388            if let Some(ref meta) = m.metadata {
389                msg.metadata = meta.clone();
390            }
391            msg
392        })
393        .collect();
394
395    let message_ids = match state
396        .memory
397        .save_chat_messages(&body.conversation_id, &chat_messages, &scope)
398        .await
399    {
400        Ok(ids) => ids,
401        Err(e) => return err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
402    };
403
404    let message_id_strs: Vec<String> = message_ids.iter().map(|id| id.to_string()).collect();
405
406    // Optionally extract facts from the saved messages.
407    let fact_ids = if state.extract_on_save {
408        let extract_messages: Vec<Message> = body
409            .messages
410            .iter()
411            .map(|m| Message {
412                role: m.role.clone(),
413                content: m.content.clone(),
414            })
415            .collect();
416
417        match state
418            .memory
419            .add_messages(
420                &extract_messages,
421                scope,
422                state.llm_backend.build(),
423                ExtractionConfig::default(),
424            )
425            .await
426        {
427            Ok(ids) => Some(ids.iter().map(|id| id.to_string()).collect::<Vec<_>>()),
428            Err(e) => {
429                tracing::warn!("fact extraction failed (messages saved): {e}");
430                None
431            }
432        }
433    } else {
434        None
435    };
436
437    (
438        StatusCode::CREATED,
439        Json(serde_json::json!({
440            "success": true,
441            "message_ids": message_id_strs,
442            "fact_ids": fact_ids,
443        })),
444    )
445        .into_response()
446}
447
448/// GET /v1/memory/messages/{conversation_id}
449async fn get_messages_handler(
450    State(state): State<AppState>,
451    Path(conversation_id): Path<String>,
452    Query(params): Query<GetMessagesParams>,
453) -> impl IntoResponse {
454    let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
455
456    match state
457        .memory
458        .get_chat_messages(&conversation_id, params.last_n, &scope)
459        .await
460    {
461        Ok(messages) => {
462            let results: Vec<serde_json::Value> = messages
463                .iter()
464                .map(|m| {
465                    serde_json::json!({
466                        "id": m.id.to_string(),
467                        "conversation_id": m.conversation_id,
468                        "role": m.role,
469                        "content": m.content,
470                        "seq": m.seq,
471                        "created_at": m.created_at.to_rfc3339(),
472                        "metadata": m.metadata,
473                    })
474                })
475                .collect();
476            Json(serde_json::json!({ "messages": results, "total": results.len() })).into_response()
477        }
478        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
479    }
480}
481
482/// GET /v1/memory/messages
483async fn list_conversations_handler(
484    State(state): State<AppState>,
485    Query(params): Query<ListConversationsParams>,
486) -> impl IntoResponse {
487    let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
488
489    match state.memory.list_conversations(&scope).await {
490        Ok(ids) => {
491            Json(serde_json::json!({ "conversation_ids": ids, "total": ids.len() })).into_response()
492        }
493        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
494    }
495}
496
497/// DELETE /v1/memory/messages/{conversation_id}
498async fn delete_messages_handler(
499    State(state): State<AppState>,
500    Path(conversation_id): Path<String>,
501    Query(params): Query<DeleteMessagesParams>,
502) -> impl IntoResponse {
503    let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
504
505    match state
506        .memory
507        .delete_chat_messages(&conversation_id, &scope)
508        .await
509    {
510        Ok(count) => Json(serde_json::json!({
511            "success": true,
512            "deleted_count": count,
513        }))
514        .into_response(),
515        Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
516    }
517}
518
519/// GET /health
520async fn health_handler() -> impl IntoResponse {
521    Json(serde_json::json!({ "status": "ok", "service": "engram" }))
522}
523
524// ---------------------------------------------------------------------------
525// Router
526// ---------------------------------------------------------------------------
527
528/// Build the Axum router with all REST endpoints.
529pub fn build_router(state: AppState) -> Router {
530    Router::new()
531        .route("/health", get(health_handler))
532        .route("/v1/memory", post(add_handler))
533        .route("/v1/memory/recall", get(recall_handler))
534        .route("/v1/memory/context", post(context_handler))
535        .route("/v1/memory/facts/:id", delete(forget_handler))
536        .route("/v1/memory/search", get(search_handler))
537        .route("/v1/memory/stats", get(stats_handler))
538        .route("/v1/memory/consolidate", post(consolidate_handler))
539        .route("/v1/memory/users/:id", delete(delete_user_handler))
540        .route(
541            "/v1/memory/messages",
542            post(save_messages_handler).get(list_conversations_handler),
543        )
544        .route(
545            "/v1/memory/messages/{conversation_id}",
546            get(get_messages_handler).delete(delete_messages_handler),
547        )
548        .with_state(state)
549}