1use crate::config::LlmBackend;
4use axum::extract::{Path, Query, State};
5use axum::http::StatusCode;
6use axum::response::IntoResponse;
7use axum::routing::{delete, get, post};
8use axum::{Json, Router};
9use engram::context::{ContextConfig, OutputFormat};
10use engram::extract::{ExtractionConfig, Message};
11use engram::memory::{Memory, RecallQuery};
12use engram::message::ChatMessage;
13use engram::scope::Scope;
14use serde::{Deserialize, Serialize};
15use std::sync::Arc;
16
/// Shared state handed to every handler through the `State<AppState>` extractor.
///
/// Handlers take it by value, which is why it derives `Clone`; the memory engine
/// sits behind an `Arc` so that clone only bumps a reference count.
#[derive(Clone)]
pub struct AppState {
    /// Memory engine backing every endpoint.
    pub memory: Arc<Memory>,
    /// LLM backend configuration; `build()` is called on it each time facts are extracted.
    pub llm_backend: LlmBackend,
    /// When `true`, `POST /v1/memory/messages` also runs fact extraction on the saved messages.
    pub extract_on_save: bool,
}
27
28#[derive(Deserialize)]
33pub struct AddRequest {
34 pub messages: Vec<MessagePayload>,
35 pub user_id: Option<String>,
36 pub org_id: Option<String>,
37 pub session_id: Option<String>,
38}
39
40#[derive(Deserialize)]
41pub struct MessagePayload {
42 pub role: String,
43 pub content: String,
44}
45
46#[derive(Deserialize)]
47pub struct RecallParams {
48 pub q: String,
49 pub user_id: Option<String>,
50 pub org_id: Option<String>,
51 pub max_results: Option<usize>,
52}
53
54#[derive(Deserialize)]
55pub struct ContextRequest {
56 pub query: String,
57 pub user_id: Option<String>,
58 pub org_id: Option<String>,
59 pub token_budget: Option<usize>,
60 pub format: Option<String>,
61}
62
63#[derive(Deserialize)]
64pub struct SearchParams {
65 pub q: String,
66 pub user_id: Option<String>,
67 pub org_id: Option<String>,
68 pub top_k: Option<usize>,
69}
70
71#[derive(Deserialize)]
72pub struct ForgetRequest {
73 pub reason: Option<String>,
74}
75
76#[derive(Deserialize)]
77pub struct ConsolidateRequest {
78 pub user_id: Option<String>,
79 pub org_id: Option<String>,
80}
81
82#[derive(Deserialize)]
83pub struct SaveMessagesRequest {
84 pub conversation_id: String,
85 pub messages: Vec<MessageInput>,
86 pub user_id: Option<String>,
87 pub org_id: Option<String>,
88}
89
90#[derive(Deserialize)]
91pub struct MessageInput {
92 pub role: String,
93 pub content: String,
94 #[serde(default)]
95 pub metadata: Option<serde_json::Map<String, serde_json::Value>>,
96}
97
98#[derive(Deserialize)]
99pub struct GetMessagesParams {
100 pub last_n: Option<usize>,
101 pub user_id: Option<String>,
102 pub org_id: Option<String>,
103}
104
105#[derive(Deserialize)]
106pub struct ListConversationsParams {
107 pub user_id: Option<String>,
108 pub org_id: Option<String>,
109}
110
111#[derive(Deserialize)]
112pub struct DeleteMessagesParams {
113 pub user_id: Option<String>,
114 pub org_id: Option<String>,
115}
116
117#[derive(Serialize)]
118struct ErrorResponse {
119 error: String,
120}
121
122fn parse_scope(org_id: Option<&str>, user_id: Option<&str>, session_id: Option<&str>) -> Scope {
127 let org = org_id.unwrap_or("default");
128 match user_id {
129 Some(uid) => match session_id {
130 Some(sid) => Scope::session(org, uid, sid),
131 None => Scope::user(org, uid),
132 },
133 None => Scope::org(org),
134 }
135}
136
137fn err(status: StatusCode, msg: impl Into<String>) -> (StatusCode, Json<ErrorResponse>) {
138 (status, Json(ErrorResponse { error: msg.into() }))
139}
140
141async fn add_handler(
147 State(state): State<AppState>,
148 Json(body): Json<AddRequest>,
149) -> impl IntoResponse {
150 let messages: Vec<Message> = body
151 .messages
152 .iter()
153 .map(|m| Message {
154 role: m.role.clone(),
155 content: m.content.clone(),
156 })
157 .collect();
158
159 if messages.is_empty() {
160 return err(StatusCode::BAD_REQUEST, "messages must not be empty").into_response();
161 }
162
163 let scope = parse_scope(
164 body.org_id.as_deref(),
165 body.user_id.as_deref(),
166 body.session_id.as_deref(),
167 );
168
169 match state
170 .memory
171 .add_messages(
172 &messages,
173 scope,
174 state.llm_backend.build(),
175 ExtractionConfig::default(),
176 )
177 .await
178 {
179 Ok(ids) => {
180 let fact_ids: Vec<String> = ids.iter().map(|id| id.to_string()).collect();
181 (
182 StatusCode::CREATED,
183 Json(serde_json::json!({
184 "success": true,
185 "fact_count": ids.len(),
186 "fact_ids": fact_ids,
187 })),
188 )
189 .into_response()
190 }
191 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
192 }
193}
194
195async fn recall_handler(
197 State(state): State<AppState>,
198 Query(params): Query<RecallParams>,
199) -> impl IntoResponse {
200 let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
201
202 let query = RecallQuery {
203 query: params.q,
204 scope: Some(scope),
205 max_results: params.max_results.unwrap_or(10),
206 as_of: None,
207 min_score: None,
208 };
209
210 match state.memory.recall(&query).await {
211 Ok(facts) => {
212 let results: Vec<serde_json::Value> = facts
213 .iter()
214 .map(|f| {
215 serde_json::json!({
216 "fact_id": f.id.to_string(),
217 "text": f.text,
218 "tier": f.tier,
219 "category": f.category,
220 "confidence": f.confidence,
221 })
222 })
223 .collect();
224 Json(serde_json::json!({ "results": results, "total": results.len() })).into_response()
225 }
226 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
227 }
228}
229
230async fn context_handler(
232 State(state): State<AppState>,
233 Json(body): Json<ContextRequest>,
234) -> impl IntoResponse {
235 let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
236
237 let format = match body.format.as_deref() {
238 Some("markdown") => OutputFormat::Markdown,
239 Some("raw") => OutputFormat::Raw,
240 _ => OutputFormat::SystemPrompt,
241 };
242
243 let config = ContextConfig {
244 token_budget: body.token_budget.unwrap_or(2000),
245 format,
246 ..Default::default()
247 };
248
249 match state.memory.context(&body.query, &scope, config).await {
250 Ok(block) => Json(serde_json::json!({
251 "text": block.text,
252 "token_count": block.token_count,
253 "facts_included": block.facts_included,
254 "facts_omitted": block.facts_omitted,
255 "tier_breakdown": block.tier_breakdown,
256 }))
257 .into_response(),
258 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
259 }
260}
261
262async fn forget_handler(
264 State(state): State<AppState>,
265 Path(fact_id): Path<String>,
266 body: Option<Json<ForgetRequest>>,
267) -> impl IntoResponse {
268 let id = match uuid::Uuid::parse_str(&fact_id) {
269 Ok(id) => id,
270 Err(e) => {
271 return err(StatusCode::BAD_REQUEST, format!("invalid fact_id: {e}")).into_response()
272 }
273 };
274
275 let reason = body.and_then(|b| b.reason.clone());
276
277 match state.memory.forget(id, reason.as_deref()).await {
278 Ok(()) => Json(serde_json::json!({
279 "success": true,
280 "deleted_fact_id": fact_id,
281 }))
282 .into_response(),
283 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
284 }
285}
286
287async fn search_handler(
289 State(state): State<AppState>,
290 Query(params): Query<SearchParams>,
291) -> impl IntoResponse {
292 let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
293 let top_k = params.top_k.unwrap_or(10);
294
295 match state
296 .memory
297 .fact_store()
298 .keyword_search(¶ms.q, &scope, top_k)
299 .await
300 {
301 Ok(facts) => {
302 let results: Vec<serde_json::Value> = facts
303 .iter()
304 .map(|f| {
305 serde_json::json!({
306 "fact_id": f.id.to_string(),
307 "text": f.text,
308 "tier": f.tier,
309 "category": f.category,
310 })
311 })
312 .collect();
313 Json(serde_json::json!({ "results": results, "total": results.len() })).into_response()
314 }
315 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
316 }
317}
318
319async fn stats_handler(State(state): State<AppState>) -> impl IntoResponse {
321 match state.memory.stats(None).await {
322 Ok(stats) => Json(serde_json::json!({
323 "total_facts": stats.total_facts,
324 "valid_facts": stats.valid_facts,
325 "invalidated_facts": stats.invalidated_facts,
326 "total_entities": stats.total_entities,
327 "total_relationships": stats.total_relationships,
328 }))
329 .into_response(),
330 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
331 }
332}
333
334async fn consolidate_handler(
336 State(state): State<AppState>,
337 Json(body): Json<ConsolidateRequest>,
338) -> impl IntoResponse {
339 let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
340 let config = engram::consolidation::ConsolidationConfig::default();
341
342 match state.memory.consolidate(&scope, None, config).await {
343 Ok(result) => Json(serde_json::json!(result)).into_response(),
344 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
345 }
346}
347
348async fn delete_user_handler(
350 State(state): State<AppState>,
351 Path(user_id): Path<String>,
352) -> impl IntoResponse {
353 let scope = Scope::user("default", &user_id);
354
355 match state.memory.delete_user_data(scope).await {
356 Ok(count) => Json(serde_json::json!({
357 "success": true,
358 "deleted_facts": count,
359 }))
360 .into_response(),
361 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
362 }
363}
364
365async fn save_messages_handler(
367 State(state): State<AppState>,
368 Json(body): Json<SaveMessagesRequest>,
369) -> impl IntoResponse {
370 if body.messages.is_empty() {
371 return err(StatusCode::BAD_REQUEST, "messages must not be empty").into_response();
372 }
373
374 let scope = parse_scope(body.org_id.as_deref(), body.user_id.as_deref(), None);
375
376 let chat_messages: Vec<ChatMessage> = body
377 .messages
378 .iter()
379 .enumerate()
380 .map(|(i, m)| {
381 let mut msg = ChatMessage::new(
382 &body.conversation_id,
383 &m.role,
384 &m.content,
385 scope.clone(),
386 i as i32,
387 );
388 if let Some(ref meta) = m.metadata {
389 msg.metadata = meta.clone();
390 }
391 msg
392 })
393 .collect();
394
395 let message_ids = match state
396 .memory
397 .save_chat_messages(&body.conversation_id, &chat_messages, &scope)
398 .await
399 {
400 Ok(ids) => ids,
401 Err(e) => return err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
402 };
403
404 let message_id_strs: Vec<String> = message_ids.iter().map(|id| id.to_string()).collect();
405
406 let fact_ids = if state.extract_on_save {
408 let extract_messages: Vec<Message> = body
409 .messages
410 .iter()
411 .map(|m| Message {
412 role: m.role.clone(),
413 content: m.content.clone(),
414 })
415 .collect();
416
417 match state
418 .memory
419 .add_messages(
420 &extract_messages,
421 scope,
422 state.llm_backend.build(),
423 ExtractionConfig::default(),
424 )
425 .await
426 {
427 Ok(ids) => Some(ids.iter().map(|id| id.to_string()).collect::<Vec<_>>()),
428 Err(e) => {
429 tracing::warn!("fact extraction failed (messages saved): {e}");
430 None
431 }
432 }
433 } else {
434 None
435 };
436
437 (
438 StatusCode::CREATED,
439 Json(serde_json::json!({
440 "success": true,
441 "message_ids": message_id_strs,
442 "fact_ids": fact_ids,
443 })),
444 )
445 .into_response()
446}
447
448async fn get_messages_handler(
450 State(state): State<AppState>,
451 Path(conversation_id): Path<String>,
452 Query(params): Query<GetMessagesParams>,
453) -> impl IntoResponse {
454 let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
455
456 match state
457 .memory
458 .get_chat_messages(&conversation_id, params.last_n, &scope)
459 .await
460 {
461 Ok(messages) => {
462 let results: Vec<serde_json::Value> = messages
463 .iter()
464 .map(|m| {
465 serde_json::json!({
466 "id": m.id.to_string(),
467 "conversation_id": m.conversation_id,
468 "role": m.role,
469 "content": m.content,
470 "seq": m.seq,
471 "created_at": m.created_at.to_rfc3339(),
472 "metadata": m.metadata,
473 })
474 })
475 .collect();
476 Json(serde_json::json!({ "messages": results, "total": results.len() })).into_response()
477 }
478 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
479 }
480}
481
482async fn list_conversations_handler(
484 State(state): State<AppState>,
485 Query(params): Query<ListConversationsParams>,
486) -> impl IntoResponse {
487 let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
488
489 match state.memory.list_conversations(&scope).await {
490 Ok(ids) => {
491 Json(serde_json::json!({ "conversation_ids": ids, "total": ids.len() })).into_response()
492 }
493 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
494 }
495}
496
497async fn delete_messages_handler(
499 State(state): State<AppState>,
500 Path(conversation_id): Path<String>,
501 Query(params): Query<DeleteMessagesParams>,
502) -> impl IntoResponse {
503 let scope = parse_scope(params.org_id.as_deref(), params.user_id.as_deref(), None);
504
505 match state
506 .memory
507 .delete_chat_messages(&conversation_id, &scope)
508 .await
509 {
510 Ok(count) => Json(serde_json::json!({
511 "success": true,
512 "deleted_count": count,
513 }))
514 .into_response(),
515 Err(e) => err(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()).into_response(),
516 }
517}
518
519async fn health_handler() -> impl IntoResponse {
521 Json(serde_json::json!({ "status": "ok", "service": "engram" }))
522}
523
524pub fn build_router(state: AppState) -> Router {
530 Router::new()
531 .route("/health", get(health_handler))
532 .route("/v1/memory", post(add_handler))
533 .route("/v1/memory/recall", get(recall_handler))
534 .route("/v1/memory/context", post(context_handler))
535 .route("/v1/memory/facts/:id", delete(forget_handler))
536 .route("/v1/memory/search", get(search_handler))
537 .route("/v1/memory/stats", get(stats_handler))
538 .route("/v1/memory/consolidate", post(consolidate_handler))
539 .route("/v1/memory/users/:id", delete(delete_user_handler))
540 .route(
541 "/v1/memory/messages",
542 post(save_messages_handler).get(list_conversations_handler),
543 )
544 .route(
545 "/v1/memory/messages/{conversation_id}",
546 get(get_messages_handler).delete(delete_messages_handler),
547 )
548 .with_state(state)
549}