use axum::{
extract::{Path, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use ineru::{MemoryEntry, MemoryId, MemoryQuery};
use crate::error::{Error, Result};
use crate::state::AppState;
/// JSON body accepted by the `remember` handler.
#[derive(Debug, Deserialize)]
pub struct RememberRequest {
/// Entry type; passed to `MemoryEntry::new` together with `data`.
pub entry_type: String,
/// Arbitrary JSON payload stored with the entry.
pub data: serde_json::Value,
/// Tags to attach to the entry; an omitted field deserializes to an empty list.
#[serde(default)]
pub tags: Vec<String>,
/// Importance handed to `with_importance`; falls back to `default_importance()` when omitted.
#[serde(default = "default_importance")]
pub importance: f32,
/// Optional embedding; when present it is wrapped in `ineru::Embedding` and attached to the entry.
pub embedding: Option<Vec<f32>>,
}
/// Serde default for `RememberRequest::importance` when the field is omitted.
fn default_importance() -> f32 {
    const DEFAULT_IMPORTANCE: f32 = 0.7;
    DEFAULT_IMPORTANCE
}
/// JSON body returned by the `remember` handler.
#[derive(Debug, Serialize)]
pub struct RememberResponse {
/// Identifier of the stored entry: the hex form of the memory id on the local path,
/// or the id reported by the Raft response on the clustered path.
pub id: String,
}
/// JSON body accepted by the `recall` handler; translated into a `MemoryQuery` by `build_query`.
#[derive(Debug, Deserialize)]
pub struct RecallRequest {
/// Free-text query; takes precedence over `tags` when both are given.
pub text: Option<String>,
/// Tags to query by; only consulted when `text` is absent. Defaults to an empty list.
#[serde(default)]
pub tags: Vec<String>,
/// When set, the query is built from the entry type alone (`text` and `tags` are not used).
pub entry_type: Option<String>,
/// Optional minimum importance forwarded to `with_min_importance`.
pub min_importance: Option<f32>,
/// Optional result limit forwarded to `with_limit`.
pub limit: Option<usize>,
}
/// One memory entry as serialized by the `recall` and `vector_search` handlers.
#[derive(Debug, Serialize)]
pub struct MemoryResultDto {
/// Hex form of the entry's memory id.
pub id: String,
/// Entry type copied from the stored entry.
pub entry_type: String,
/// JSON payload copied from the stored entry.
pub data: serde_json::Value,
/// Tag strings of the stored entry.
pub tags: Vec<String>,
/// Importance taken from the entry's metadata.
pub importance: f32,
/// Relevance reported by `recall`, or the similarity score for vector search results.
pub relevance: f32,
/// `"ShortTerm"` or `"LongTerm"`; vector search results are always `"LongTerm"`.
pub source: String,
/// String form of the inner value of the metadata `created_at` field.
pub created_at: String,
/// String form of the inner value of the metadata `last_accessed` field.
pub last_accessed: String,
/// Access counter taken from the entry's metadata.
pub access_count: u32,
}
/// JSON body returned by the `consolidate` handler.
#[derive(Debug, Serialize)]
pub struct ConsolidateResponse {
/// Count returned by the store's `consolidate()`, or parsed from the Raft response detail
/// on the clustered path (0 when that detail is missing or not a number).
pub consolidated: usize,
}
/// JSON body returned by the `stats` handler; each field is copied from the value of the
/// same name reported by the memory store's `stats()`.
#[derive(Debug, Serialize)]
pub struct MemoryStatsDto {
/// `stm_count` from the store statistics (presumably short-term entries — confirm against `ineru`).
pub stm_count: usize,
/// `stm_capacity` from the store statistics.
pub stm_capacity: usize,
/// `ltm_entity_count` from the store statistics (presumably long-term entities — confirm against `ineru`).
pub ltm_entity_count: usize,
/// `ltm_link_count` from the store statistics.
pub ltm_link_count: usize,
/// `total_memory_bytes` from the store statistics.
pub total_memory_bytes: usize,
}
/// JSON body accepted by the `checkpoint` handler.
#[derive(Debug, Deserialize)]
pub struct CheckpointRequest {
/// Optional label; the handler substitutes `checkpoint-<unix timestamp>` when absent.
pub label: Option<String>,
}
/// JSON body returned by the `checkpoint` handler.
#[derive(Debug, Serialize)]
pub struct CheckpointResponse {
/// Id returned by the proof store for the submitted checkpoint; serialized as `checkpointId`.
#[serde(rename = "checkpointId")]
pub checkpoint_id: String,
}
/// One checkpoint as listed by the `list_checkpoints` handler.
#[derive(Debug, Serialize)]
pub struct CheckpointListDto {
/// Id of the underlying proof.
pub id: String,
/// `label` string read from the proof's JSON data, if present.
pub label: Option<String>,
/// RFC 3339 rendering of the proof's creation time.
pub created_at: String,
/// `stm_count` read from the proof's JSON data; 0 when missing or not an unsigned integer.
pub stm_count: usize,
/// `ltm_entity_count` read from the proof's JSON data; 0 when missing or not an unsigned integer.
pub ltm_entity_count: usize,
}
/// Stores a memory entry and replies `201 Created` with the new entry's id.
///
/// * With the `cluster` feature and a Raft handle in the state, the store is submitted as
///   a `MemoryStore` command through Raft; the id comes from the Raft response, or is the
///   literal `"raft"` when the response carries none.
/// * Otherwise the entry is built from the request, written to the local memory store
///   under its write lock and, when the `cluster` feature provides a WAL, appended to
///   that WAL afterwards.
///
/// # Errors
///
/// Returns `Error::Internal` when the Raft write reports failure, the local store fails,
/// or the WAL append fails. Errors from `client_write` itself are converted by
/// `handle_raft_write_error`.
pub async fn remember(
    State(state): State<AppState>,
    Json(req): Json<RememberRequest>,
) -> Result<(StatusCode, Json<RememberResponse>)> {
    #[cfg(feature = "cluster")]
    if let Some(ref raft) = state.raft {
        // NOTE(review): only entry_type, data and importance are forwarded; the request's
        // `tags` and `embedding` are not part of this command.
        let command = aingle_raft::CortexRequest {
            kind: aingle_wal::WalEntryKind::MemoryStore {
                // Left empty here; presumably assigned when the command is applied — confirm.
                memory_id: String::new(),
                entry_type: req.entry_type.clone(),
                data: req.data.clone(),
                importance: req.importance,
            },
        };
        let written = raft
            .client_write(command)
            .await
            .map_err(|e| handle_raft_write_error(e, &state))?;
        let outcome = written.response();
        if !outcome.success {
            let reason = outcome
                .detail
                .clone()
                .unwrap_or_else(|| "Raft memory store failed".to_string());
            return Err(Error::Internal(reason));
        }
        let id = match outcome.id.clone() {
            Some(assigned) => assigned,
            None => "raft".to_string(),
        };
        return Ok((StatusCode::CREATED, Json(RememberResponse { id })));
    }

    // Defensive guard: the branch above always returns when a Raft handle exists.
    #[cfg(feature = "cluster")]
    if state.raft.is_some() {
        return Err(Error::Internal("Raft initialized but write not routed through Raft".into()));
    }

    // `req.data` is moved into the entry below, so keep a copy for the WAL record.
    #[cfg(feature = "cluster")]
    let wal_data = req.data.clone();

    let entry = MemoryEntry::new(&req.entry_type, req.data);
    let entry = if req.tags.is_empty() {
        entry
    } else {
        let tag_refs: Vec<&str> = req.tags.iter().map(String::as_str).collect();
        entry.with_tags(&tag_refs)
    };
    let entry = entry.with_importance(req.importance);
    let entry = match req.embedding {
        Some(values) => entry.with_embedding(ineru::Embedding::new(values)),
        None => entry,
    };

    let mut memory = state.memory.write().await;
    let id = memory
        .remember(entry)
        .map_err(|e| Error::Internal(format!("Memory store failed: {e}")))?;
    let id_hex = id.to_hex();

    // The WAL record is appended after the store succeeded because it needs the new id.
    #[cfg(feature = "cluster")]
    if let Some(ref wal) = state.wal {
        wal.append(aingle_wal::WalEntryKind::MemoryStore {
            memory_id: id_hex.clone(),
            entry_type: req.entry_type,
            data: wal_data,
            importance: req.importance,
        })
        .map_err(|e| Error::Internal(format!("WAL append failed: {e}")))?;
    }

    Ok((StatusCode::CREATED, Json(RememberResponse { id: id_hex })))
}
/// Runs a recall query against the memory store and returns the matches as DTOs.
///
/// The request is turned into a `MemoryQuery` by `build_query`; the store is only read.
///
/// # Errors
///
/// Returns `Error::Internal` when the store's `recall` fails.
pub async fn recall(
    State(state): State<AppState>,
    Json(req): Json<RecallRequest>,
) -> Result<Json<Vec<MemoryResultDto>>> {
    let query = build_query(&req);
    let store = state.memory.read().await;
    let hits = store
        .recall(&query)
        .map_err(|e| Error::Internal(format!("Memory recall failed: {e}")))?;

    let mut dtos: Vec<MemoryResultDto> = Vec::new();
    for hit in hits {
        let source = match hit.source {
            ineru::types::MemorySource::ShortTerm => "ShortTerm",
            ineru::types::MemorySource::LongTerm => "LongTerm",
        };
        let meta = &hit.entry.metadata;
        dtos.push(MemoryResultDto {
            id: hit.entry.id.to_hex(),
            entry_type: hit.entry.entry_type.clone(),
            data: hit.entry.data.clone(),
            tags: hit.entry.tags.iter().map(|tag| tag.0.clone()).collect(),
            importance: meta.importance,
            relevance: hit.relevance,
            source: source.to_string(),
            created_at: meta.created_at.0.to_string(),
            last_accessed: meta.last_accessed.0.to_string(),
            access_count: meta.access_count,
        });
    }

    Ok(Json(dtos))
}
/// Triggers consolidation of the memory store and reports how many entries were consolidated.
///
/// * With the `cluster` feature and a Raft handle in the state, a `MemoryConsolidate`
///   command is submitted through Raft and the count is parsed from the response's
///   `detail` text.
/// * Otherwise the local store's `consolidate()` is called under the write lock and, when
///   a WAL is configured, the resulting count is appended to it.
///
/// # Errors
///
/// Returns `Error::Internal` when the Raft write reports failure, consolidation fails, or
/// the WAL append fails. Errors from `client_write` itself are converted by
/// `handle_raft_write_error`.
pub async fn consolidate(State(state): State<AppState>) -> Result<Json<ConsolidateResponse>> {
    #[cfg(feature = "cluster")]
    if let Some(ref raft) = state.raft {
        let command = aingle_raft::CortexRequest {
            // The count is sent as 0; the real number is read back from the response below.
            kind: aingle_wal::WalEntryKind::MemoryConsolidate {
                consolidated_count: 0,
            },
        };
        let written = raft
            .client_write(command)
            .await
            .map_err(|e| handle_raft_write_error(e, &state))?;
        let outcome = written.response();
        if !outcome.success {
            let reason = outcome
                .detail
                .clone()
                .unwrap_or_else(|| "Raft consolidate failed".to_string());
            return Err(Error::Internal(reason));
        }
        // A missing or non-numeric detail is reported as zero consolidated entries.
        let consolidated: usize = match outcome.detail.as_ref() {
            Some(text) => text.parse().unwrap_or(0),
            None => 0,
        };
        return Ok(Json(ConsolidateResponse { consolidated }));
    }

    // Defensive guard: the branch above always returns when a Raft handle exists.
    #[cfg(feature = "cluster")]
    if state.raft.is_some() {
        return Err(Error::Internal("Raft initialized but write not routed through Raft".into()));
    }

    let mut memory = state.memory.write().await;
    let consolidated = memory
        .consolidate()
        .map_err(|e| Error::Internal(format!("Consolidation failed: {e}")))?;

    #[cfg(feature = "cluster")]
    if let Some(ref wal) = state.wal {
        wal.append(aingle_wal::WalEntryKind::MemoryConsolidate {
            consolidated_count: consolidated,
        })
        .map_err(|e| Error::Internal(format!("WAL append failed: {e}")))?;
    }

    Ok(Json(ConsolidateResponse { consolidated }))
}
/// Returns the memory store's current statistics, read under the store's read lock.
pub async fn stats(State(state): State<AppState>) -> Result<Json<MemoryStatsDto>> {
    let guard = state.memory.read().await;
    let snapshot = guard.stats();

    let dto = MemoryStatsDto {
        stm_count: snapshot.stm_count,
        stm_capacity: snapshot.stm_capacity,
        ltm_entity_count: snapshot.ltm_entity_count,
        ltm_link_count: snapshot.ltm_link_count,
        total_memory_bytes: snapshot.total_memory_bytes,
    };
    Ok(Json(dto))
}
pub async fn forget(
State(state): State<AppState>,
Path(id): Path<String>,
) -> Result<StatusCode> {
#[cfg(feature = "cluster")]
if let Some(ref raft) = state.raft {
let raft_req = aingle_raft::CortexRequest {
kind: aingle_wal::WalEntryKind::MemoryForget {
memory_id: id.clone(),
},
};
let resp = raft
.client_write(raft_req)
.await
.map_err(|e| handle_raft_write_error(e, &state))?;
if !resp.response().success {
return Err(Error::Internal(
resp.response()
.detail
.clone()
.unwrap_or_else(|| "Raft forget failed".to_string()),
));
}
return Ok(StatusCode::NO_CONTENT);
}
#[cfg(feature = "cluster")]
if state.raft.is_some() {
return Err(Error::Internal("Raft initialized but write not routed through Raft".into()));
}
let memory_id = MemoryId::from_hex(&id)
.ok_or_else(|| Error::InvalidInput(format!("Invalid memory ID: {id}")))?;
let mut memory = state.memory.write().await;
memory
.forget(&memory_id)
.map_err(|e| Error::NotFound(format!("Memory not found: {e}")))?;
#[cfg(feature = "cluster")]
if let Some(ref wal) = state.wal {
wal.append(aingle_wal::WalEntryKind::MemoryForget {
memory_id: id.clone(),
}).map_err(|e| Error::Internal(format!("WAL append failed: {e}")))?;
}
Ok(StatusCode::NO_CONTENT)
}
/// Records a checkpoint of the current memory statistics as a `Knowledge` proof and
/// replies `201 Created` with the proof id.
///
/// The proof data is a JSON object holding the label, the store counters
/// (`stm_count`, `ltm_entity_count`, `ltm_link_count`, `total_memory_bytes`) and a
/// `created_at` RFC 3339 timestamp. The proof is tagged `checkpoint` and `memory`, with
/// submitter `memory-system`. When the request carries no label,
/// `checkpoint-<unix timestamp>` is used.
///
/// # Errors
///
/// Returns `Error::Internal` when the proof store rejects the submission.
pub async fn checkpoint(
    State(state): State<AppState>,
    Json(req): Json<CheckpointRequest>,
) -> Result<(StatusCode, Json<CheckpointResponse>)> {
    // Copy the counters out and release the read lock before awaiting the proof store, so
    // writers are not stalled for the duration of the submission.
    let (stm_count, ltm_entity_count, ltm_link_count, total_memory_bytes) = {
        let memory = state.memory.read().await;
        let s = memory.stats();
        (
            s.stm_count,
            s.ltm_entity_count,
            s.ltm_link_count,
            s.total_memory_bytes,
        )
    };

    // One clock reading, so the default label and `created_at` always agree.
    let now = chrono::Utc::now();
    let label = req
        .label
        .unwrap_or_else(|| format!("checkpoint-{}", now.timestamp()));

    let checkpoint_data = serde_json::json!({
        "label": label,
        "stm_count": stm_count,
        "ltm_entity_count": ltm_entity_count,
        "ltm_link_count": ltm_link_count,
        "total_memory_bytes": total_memory_bytes,
        "created_at": now.to_rfc3339(),
    });

    let proof_req = crate::proofs::SubmitProofRequest {
        proof_type: crate::proofs::ProofType::Knowledge,
        proof_data: checkpoint_data,
        metadata: Some(crate::proofs::ProofMetadata {
            submitter: Some("memory-system".to_string()),
            tags: vec!["checkpoint".to_string(), "memory".to_string()],
            extra: Default::default(),
        }),
    };

    let proof_id = state
        .proof_store
        .submit(proof_req)
        .await
        .map_err(|e| Error::Internal(format!("Checkpoint creation failed: {e}")))?;

    Ok((
        StatusCode::CREATED,
        Json(CheckpointResponse {
            checkpoint_id: proof_id,
        }),
    ))
}
/// Lists all `Knowledge` proofs tagged `checkpoint` as checkpoint summaries.
///
/// Each proof's data is parsed as JSON on a best-effort basis: data that fails to parse is
/// treated as the default JSON value, so the label becomes `None` and the counters 0.
pub async fn list_checkpoints(
    State(state): State<AppState>,
) -> Result<Json<Vec<CheckpointListDto>>> {
    let proofs = state
        .proof_store
        .list(Some(crate::proofs::ProofType::Knowledge))
        .await;

    let checkpoints: Vec<CheckpointListDto> = proofs
        .into_iter()
        // Compare against the literal directly instead of allocating a String per proof.
        .filter(|p| p.metadata.tags.iter().any(|tag| tag == "checkpoint"))
        .map(|p| {
            let data: serde_json::Value =
                serde_json::from_slice(&p.data).unwrap_or_default();
            // Reads an unsigned counter from the checkpoint data; missing or
            // non-integer values count as 0.
            let counter =
                |key: &str| data.get(key).and_then(|v| v.as_u64()).unwrap_or(0) as usize;
            CheckpointListDto {
                id: p.id.clone(),
                label: data.get("label").and_then(|v| v.as_str()).map(String::from),
                created_at: p.created_at.to_rfc3339(),
                stm_count: counter("stm_count"),
                ltm_entity_count: counter("ltm_entity_count"),
            }
        })
        .collect();

    Ok(Json(checkpoints))
}
pub async fn restore_checkpoint(
State(state): State<AppState>,
Path(id): Path<String>,
) -> Result<StatusCode> {
let proof = state
.proof_store
.get(&id)
.await
.ok_or_else(|| Error::NotFound(format!("Checkpoint not found: {id}")))?;
if !proof.metadata.tags.contains(&"checkpoint".to_string()) {
return Err(Error::InvalidInput("Not a memory checkpoint".to_string()));
}
tracing::info!(checkpoint_id = %id, "Memory checkpoint acknowledged for restoration");
Ok(StatusCode::OK)
}
/// JSON body accepted by the `vector_search` handler.
#[derive(Debug, Deserialize)]
pub struct VectorSearchRequest {
/// Query embedding passed to `vector_search_memories`.
pub embedding: Vec<f32>,
/// Number of results requested from the vector search (before the filters below are applied).
pub k: usize,
/// Minimum similarity passed to the search; falls back to `default_min_similarity()` when omitted.
#[serde(default = "default_min_similarity")]
pub min_similarity: f32,
/// When set, only results whose entry type equals this value are kept.
pub entry_type: Option<String>,
/// When set and non-empty, only results sharing at least one of these tags are kept.
pub tags: Option<Vec<String>>,
}
/// Serde default for `VectorSearchRequest::min_similarity` when the field is omitted.
fn default_min_similarity() -> f32 {
    const DEFAULT_MIN_SIMILARITY: f32 = 0.0;
    DEFAULT_MIN_SIMILARITY
}
/// JSON body returned by the `vector_index_stats` handler; every field is 0 when the
/// long-term store has no HNSW index.
#[derive(Debug, Serialize)]
pub struct VectorIndexStatsDto {
/// `point_count` from the HNSW index statistics.
pub point_count: usize,
/// `deleted_count` from the HNSW index statistics.
pub deleted_count: usize,
/// `dimensions` from the HNSW index statistics.
pub dimensions: usize,
/// `memory_bytes` from the HNSW index statistics.
pub memory_bytes: usize,
}
/// Searches the long-term store by embedding similarity and returns the matches as DTOs.
///
/// The `entry_type` and `tags` filters are applied to the search output, so fewer than `k`
/// results may be returned when a filter is set. A tag filter keeps results sharing at
/// least one tag with the request; an absent or empty tag list disables it.
pub async fn vector_search(
    State(state): State<AppState>,
    Json(req): Json<VectorSearchRequest>,
) -> Result<Json<Vec<MemoryResultDto>>> {
    let store = state.memory.read().await;
    let hits = store
        .ltm
        .vector_search_memories(&req.embedding, req.k, req.min_similarity);

    let wanted_type = req.entry_type.as_deref();
    // An empty tag list means "no tag filter", same as an absent one.
    let wanted_tags = req.tags.as_deref().filter(|tags| !tags.is_empty());

    let dtos: Vec<MemoryResultDto> = hits
        .into_iter()
        .map(|(entry, similarity)| MemoryResultDto {
            id: entry.id.to_hex(),
            entry_type: entry.entry_type.clone(),
            data: entry.data.clone(),
            tags: entry.tags.iter().map(|tag| tag.0.clone()).collect(),
            importance: entry.metadata.importance,
            relevance: similarity,
            source: "LongTerm".to_string(),
            created_at: entry.metadata.created_at.0.to_string(),
            last_accessed: entry.metadata.last_accessed.0.to_string(),
            access_count: entry.metadata.access_count,
        })
        .filter(|dto| wanted_type.map_or(true, |ty| dto.entry_type == ty))
        .filter(|dto| {
            wanted_tags.map_or(true, |tags| tags.iter().any(|t| dto.tags.contains(t)))
        })
        .collect();

    Ok(Json(dtos))
}
/// Reports statistics of the long-term store's HNSW index; all values are 0 when the
/// store has no index.
pub async fn vector_index_stats(
    State(state): State<AppState>,
) -> Result<Json<VectorIndexStatsDto>> {
    let memory = state.memory.read().await;

    let dto = match memory.ltm.hnsw_index() {
        Some(index) => {
            let index_stats = index.stats();
            VectorIndexStatsDto {
                point_count: index_stats.point_count,
                deleted_count: index_stats.deleted_count,
                dimensions: index_stats.dimensions,
                memory_bytes: index_stats.memory_bytes,
            }
        }
        None => VectorIndexStatsDto {
            point_count: 0,
            deleted_count: 0,
            dimensions: 0,
            memory_bytes: 0,
        },
    };

    Ok(Json(dto))
}
/// Rebuilds the long-term store's HNSW index under the memory write lock.
///
/// Replies `200 OK` whether or not an index exists; without one the call does nothing.
pub async fn rebuild_vector_index(State(state): State<AppState>) -> Result<StatusCode> {
    let mut memory = state.memory.write().await;

    let Some(index) = memory.ltm.hnsw_index_mut() else {
        return Ok(StatusCode::OK);
    };
    index.rebuild();
    tracing::info!("HNSW index rebuilt, {} active points", index.len());

    Ok(StatusCode::OK)
}
#[cfg(feature = "cluster")]
use crate::rest::cluster_utils::handle_raft_write_error;
/// Translates a recall request into a `MemoryQuery`.
///
/// Exactly one criterion selects the base query, in this order of precedence:
/// `entry_type`, then `text`, then `tags` (when non-empty), and finally an empty text
/// query. `limit` and `min_importance` are then applied when present.
///
/// NOTE(review): criteria are not combined — when `entry_type` is set, `text` and `tags`
/// are not part of the query, and `tags` are not used when `text` is set. Confirm this is
/// intended.
fn build_query(req: &RecallRequest) -> MemoryQuery {
    // Pick the base query once instead of building a text/tags query that is discarded
    // again when an entry type is given.
    let base = if let Some(entry_type) = &req.entry_type {
        MemoryQuery::entry_type(entry_type)
    } else if let Some(text) = &req.text {
        MemoryQuery::text(text)
    } else if !req.tags.is_empty() {
        let tag_refs: Vec<&str> = req.tags.iter().map(|s| s.as_str()).collect();
        MemoryQuery::tags(&tag_refs)
    } else {
        MemoryQuery::text("")
    };

    let limited = match req.limit {
        Some(limit) => base.with_limit(limit),
        None => base,
    };
    match req.min_importance {
        Some(min_imp) => limited.with_min_importance(min_imp),
        None => limited,
    }
}
/// Builds the router exposing all memory endpoints under `/api/v1/memory`.
pub fn memory_router() -> axum::Router<AppState> {
    use axum::routing::{delete, get, post};

    let router = axum::Router::new();

    // Entry lifecycle and statistics.
    let router = router
        .route("/api/v1/memory/remember", post(remember))
        .route("/api/v1/memory/recall", post(recall))
        .route("/api/v1/memory/consolidate", post(consolidate))
        .route("/api/v1/memory/stats", get(stats))
        .route("/api/v1/memory/{id}", delete(forget));

    // Checkpoints.
    let router = router
        .route("/api/v1/memory/checkpoint", post(checkpoint))
        .route("/api/v1/memory/checkpoints", get(list_checkpoints))
        .route("/api/v1/memory/restore/{id}", post(restore_checkpoint));

    // Vector search and index maintenance.
    router
        .route("/api/v1/memory/search", post(vector_search))
        .route("/api/v1/memory/index/stats", get(vector_index_stats))
        .route("/api/v1/memory/index/rebuild", post(rebuild_vector_index))
}