#![allow(unused_variables)]
use axum::{
extract::{Path, Query, State},
http::StatusCode,
Json,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use crate::api::models::{ApiError, ApiResponse};
use crate::api::server::AppState;
/// A single conversation message as exposed by the session-memory endpoints.
///
/// Every `Option` field carries `skip_serializing_if = "Option::is_none"`, so a
/// `None` value is omitted from serialized JSON instead of being written as `null`.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct MemoryMessage {
/// Author role of the message (free-form string; not validated by this type).
pub role: String,
/// Message text.
pub content: String,
/// Optional name associated with the message.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
/// Optional single function call attached to the message.
#[serde(skip_serializing_if = "Option::is_none")]
pub function_call: Option<FunctionCall>,
/// Optional list of tool calls attached to the message.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// Arbitrary JSON metadata keyed by string.
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<HashMap<String, serde_json::Value>>,
/// Timestamp as an opaque string.
/// NOTE(review): the format is whatever storage produces — confirm (RFC 3339?).
#[serde(skip_serializing_if = "Option::is_none")]
pub timestamp: Option<String>,
}
/// A function invocation: a function name plus its arguments.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct FunctionCall {
/// Name of the function to call.
pub name: String,
/// Arguments carried as a single string.
/// NOTE(review): presumably a JSON-encoded argument object — confirm with producers.
pub arguments: String,
}
/// A tool-call entry: an identifier, a call type, and the function to invoke.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ToolCall {
/// Identifier of this tool call.
pub id: String,
/// Kind of tool call. It is (de)serialized under the JSON key `"type"` because
/// `type` is a reserved word in Rust.
#[serde(rename = "type")]
pub call_type: String,
/// Function invoked by this tool call.
pub function: FunctionCall,
}
/// JSON body for appending one message to a session.
///
/// NOTE(review): the `add_message` / `add_messages` handlers in this file pass only
/// `role` and `content` to storage. `name`, `function_call`, `tool_calls` and
/// `metadata` are accepted but not persisted.
#[derive(Debug, Deserialize)]
pub struct AddMessageRequest {
/// Author role of the message.
pub role: String,
/// Message text.
pub content: String,
/// Optional name associated with the message.
pub name: Option<String>,
/// Optional function call attached to the message.
pub function_call: Option<FunctionCall>,
/// Optional tool calls attached to the message.
pub tool_calls: Option<Vec<ToolCall>>,
/// Optional arbitrary JSON metadata.
pub metadata: Option<HashMap<String, serde_json::Value>>,
}
/// JSON body for appending several messages to a session in one request.
#[derive(Debug, Deserialize)]
pub struct AddMessagesRequest {
/// Messages to append.
pub messages: Vec<AddMessageRequest>,
}
/// Response summary of one agent session.
#[derive(Debug, Serialize)]
pub struct SessionInfo {
/// Session identifier.
pub session_id: String,
/// Number of messages in the session, as reported by storage.
pub message_count: usize,
/// Token count for the session, as reported by storage.
pub token_count: usize,
/// Session metadata. Handlers in this file fill it from storage's JSON object;
/// a non-object value becomes an empty map.
pub metadata: HashMap<String, serde_json::Value>,
/// Creation time as an opaque string (format decided by storage).
pub created_at: String,
/// Last-update time as an opaque string (format decided by storage).
pub updated_at: String,
}
/// JSON body for searching a session's memory.
#[derive(Debug, Deserialize)]
pub struct SearchMemoryRequest {
/// Search text passed to storage.
pub query: String,
/// Maximum number of results requested. Defaults to 10 when omitted (see
/// `default_limit`).
#[serde(default = "default_limit")]
pub limit: usize,
/// Optional set of roles to restrict results to.
pub role_filter: Option<Vec<String>>,
/// Optional lower time bound as an opaque string.
/// NOTE(review): the expected timestamp format is not defined here — confirm.
pub since: Option<String>,
/// Optional upper time bound as an opaque string (same caveat as `since`).
pub until: Option<String>,
/// Optional minimum relevance score.
pub min_score: Option<f32>,
}
/// Serde default for `SearchMemoryRequest::limit` when the request omits it.
fn default_limit() -> usize {
    const DEFAULT_SEARCH_LIMIT: usize = 10;
    DEFAULT_SEARCH_LIMIT
}
/// One hit returned by a memory search.
#[derive(Debug, Serialize)]
pub struct MemorySearchResult {
/// The matching message.
pub message: MemoryMessage,
/// Relevance score reported by storage.
pub score: f32,
/// Zero-based position of this hit in the result list returned by storage.
pub index: usize,
}
/// Optional query-string parameters for listing a session's messages.
#[derive(Debug, Deserialize)]
pub struct GetMessagesQuery {
/// Maximum number of messages to return.
pub limit: Option<usize>,
/// Number of messages to skip.
pub offset: Option<usize>,
/// Restrict to messages with this role.
pub role: Option<String>,
/// Requested ordering as a free-form string (e.g. `"asc"` / `"desc"`).
pub order: Option<String>,
/// Lower time bound as an opaque string.
/// NOTE(review): the expected timestamp format is not defined here — confirm.
pub since: Option<String>,
}
/// JSON body for creating a session.
///
/// NOTE(review): the `create_session` handler in this file uses only `session_id`;
/// `metadata`, `token_limit` and `summarization` are accepted but not forwarded.
#[derive(Debug, Deserialize)]
pub struct CreateSessionRequest {
/// Desired session id. When absent, `create_session` generates a UUID v4 string.
pub session_id: Option<String>,
/// Optional initial metadata.
pub metadata: Option<HashMap<String, serde_json::Value>>,
/// Optional token budget for the session.
pub token_limit: Option<usize>,
/// Optional summarization setting as a free-form string.
pub summarization: Option<String>,
}
/// JSON body for updating a session's settings.
///
/// NOTE(review): the `update_session` handler in this file deserializes this body
/// but does not apply any of its fields.
#[derive(Debug, Deserialize)]
pub struct UpdateSessionRequest {
/// Replacement metadata, if provided.
pub metadata: Option<HashMap<String, serde_json::Value>>,
/// New token budget, if provided.
pub token_limit: Option<usize>,
/// New summarization setting, if provided.
pub summarization: Option<String>,
}
/// JSON body for summarizing a session's memory.
#[derive(Debug, Deserialize)]
pub struct SummarizeRequest {
/// Summarization strategy name. Defaults to `"rolling"` when omitted (see
/// `default_strategy`).
#[serde(default = "default_strategy")]
pub strategy: String,
/// NOTE(review): presumably the number of most recent messages to leave
/// unsummarized — confirm.
pub keep_last: Option<usize>,
/// Optional target token count for the result.
pub target_tokens: Option<usize>,
/// Optional custom prompt for the summarizer.
pub prompt: Option<String>,
}
/// Serde default for `SummarizeRequest::strategy` when the request omits it.
fn default_strategy() -> String {
    String::from("rolling")
}
/// Result of a summarization request.
#[derive(Debug, Serialize)]
pub struct SummaryResponse {
/// Summary text produced by storage.
pub summary: String,
/// Message count before summarization.
pub original_count: usize,
/// Message count after summarization.
pub new_count: usize,
/// Tokens saved by summarizing.
pub tokens_saved: usize,
}
/// JSON body for forking a session into a new one.
#[derive(Debug, Deserialize)]
pub struct ForkSessionRequest {
/// Id for the new session. When absent, `fork_session` generates a UUID v4 string.
pub new_session_id: Option<String>,
/// NOTE(review): presumably the message index to fork from — confirm.
pub from_index: Option<usize>,
/// Whether to carry a summary into the fork. Defaults to `true` when omitted.
#[serde(default = "default_true")]
pub include_summary: bool,
}
/// Serde default for boolean request flags that are enabled unless the client
/// explicitly sends `false`.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
/// JSON body for clearing a session's messages.
#[derive(Debug, Deserialize)]
pub struct ClearMessagesRequest {
/// NOTE(review): presumably the number of most recent messages to keep — confirm.
pub keep_last: Option<usize>,
/// Whether system messages should be kept. Defaults to `true` when omitted.
#[serde(default = "default_true")]
pub keep_system: bool,
}
/// JSON body for building a token-bounded context from a session.
#[derive(Debug, Deserialize)]
pub struct GetContextRequest {
/// Token budget for the context. This field is required because it has no serde
/// default and is not an `Option`.
pub max_tokens: usize,
/// Whether to include system messages. Defaults to `true` when omitted.
#[serde(default = "default_true")]
pub include_system: bool,
/// Whether to include a summary. Defaults to `true` when omitted.
#[serde(default = "default_true")]
pub include_summary: bool,
/// NOTE(review): presumably a weighting toward recent messages; range and
/// semantics are not defined here — confirm.
pub recency_bias: Option<f32>,
}
/// Context assembled for a session.
#[derive(Debug, Serialize)]
pub struct ContextResponse {
/// Messages making up the context.
pub messages: Vec<MemoryMessage>,
/// Token count of `messages`.
pub token_count: usize,
/// Presumably whether messages were dropped to fit the token budget — confirm.
pub truncated: bool,
/// Whether a summary is part of `messages`.
pub summary_included: bool,
}
/// Lists stored agent sessions.
///
/// Accepts an optional `limit` query parameter (default 100) that caps how many
/// sessions are returned. A value that does not parse as a non-negative integer
/// falls back to the default.
///
/// A session whose stored metadata is not a JSON object is reported with empty
/// metadata.
///
/// # Errors
/// Returns an internal error if storage fails to list sessions.
pub async fn list_sessions(
    State(state): State<AppState>,
    Query(params): Query<HashMap<String, String>>,
) -> Result<Json<ApiResponse<Vec<SessionInfo>>>, ApiError> {
    let limit: usize = params
        .get("limit")
        .and_then(|s| s.parse().ok())
        .unwrap_or(100);
    let sessions = state
        .db
        .list_agent_sessions()
        .map_err(|e| ApiError::internal(format!("Failed to list sessions: {}", e)))?;
    let session_infos: Vec<SessionInfo> = sessions
        .into_iter()
        // Apply the requested cap. The limit used to be parsed and then dropped.
        .take(limit)
        .map(|s| {
            let metadata = match s.metadata {
                serde_json::Value::Object(map) => map.into_iter().collect(),
                _ => HashMap::new(),
            };
            SessionInfo {
                session_id: s.session_id,
                message_count: s.message_count as usize,
                token_count: s.token_count as usize,
                metadata,
                created_at: s.created_at,
                updated_at: s.updated_at,
            }
        })
        .collect();
    Ok(Json(ApiResponse::success(session_infos)))
}
/// Asks storage to create an agent session and returns it with `201 Created`.
///
/// When the request omits `session_id`, a random UUID v4 string is used. The
/// returned message and token counts are always reported as 0.
///
/// NOTE(review): `metadata`, `token_limit` and `summarization` from the request
/// are currently not forwarded to storage.
///
/// # Errors
/// Returns an internal error if storage fails to create the session.
pub async fn create_session(
    State(state): State<AppState>,
    Json(req): Json<CreateSessionRequest>,
) -> Result<(StatusCode, Json<ApiResponse<SessionInfo>>), ApiError> {
    let session_id = match req.session_id {
        Some(requested) => requested,
        None => uuid::Uuid::new_v4().to_string(),
    };

    let created = match state.db.create_agent_session(&session_id) {
        Ok(record) => record,
        Err(err) => {
            return Err(ApiError::internal(format!(
                "Failed to create session: {}",
                err
            )))
        }
    };

    // Non-object metadata is surfaced as an empty map rather than an error.
    let metadata: HashMap<String, serde_json::Value> =
        if let serde_json::Value::Object(entries) = created.metadata {
            entries.into_iter().collect()
        } else {
            HashMap::new()
        };

    let body = SessionInfo {
        session_id: created.session_id,
        message_count: 0,
        token_count: 0,
        metadata,
        created_at: created.created_at,
        updated_at: created.updated_at,
    };
    Ok((StatusCode::CREATED, Json(ApiResponse::success(body))))
}
/// Fetches one agent session by id.
///
/// A session whose stored metadata is not a JSON object is reported with empty
/// metadata.
///
/// # Errors
/// Any storage error is reported as "not found".
/// NOTE(review): this also reports genuine storage failures as 404. Consider
/// distinguishing them if the storage error type allows it.
pub async fn get_session(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
) -> Result<Json<ApiResponse<SessionInfo>>, ApiError> {
    let record = match state.db.get_agent_session(&session_id) {
        Ok(found) => found,
        Err(err) => return Err(ApiError::not_found(format!("Session not found: {}", err))),
    };

    let metadata: HashMap<String, serde_json::Value> = match record.metadata {
        serde_json::Value::Object(fields) => fields.into_iter().collect(),
        _ => HashMap::new(),
    };

    Ok(Json(ApiResponse::success(SessionInfo {
        session_id: record.session_id,
        message_count: record.message_count as usize,
        token_count: record.token_count as usize,
        metadata,
        created_at: record.created_at,
        updated_at: record.updated_at,
    })))
}
/// Returns the current state of a session.
///
/// NOTE(review): the `UpdateSessionRequest` body is deserialized but never applied,
/// so this handler currently only reads the session, the same way `get_session` does.
///
/// # Errors
/// Any storage error while loading the session is reported as "not found".
pub async fn update_session(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
    Json(_changes): Json<UpdateSessionRequest>,
) -> Result<Json<ApiResponse<SessionInfo>>, ApiError> {
    let current = state
        .db
        .get_agent_session(&session_id)
        .map_err(|err| ApiError::not_found(format!("Session not found: {}", err)))?;

    let metadata: HashMap<String, serde_json::Value> =
        if let serde_json::Value::Object(entries) = current.metadata {
            entries.into_iter().collect()
        } else {
            HashMap::new()
        };

    let snapshot = SessionInfo {
        session_id: current.session_id,
        message_count: current.message_count as usize,
        token_count: current.token_count as usize,
        metadata,
        created_at: current.created_at,
        updated_at: current.updated_at,
    };
    Ok(Json(ApiResponse::success(snapshot)))
}
/// Deletes a session. Responds with `204 No Content` on success.
///
/// # Errors
/// Returns an internal error if storage fails to delete the session.
pub async fn delete_session(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
) -> Result<StatusCode, ApiError> {
    match state.db.delete_agent_session(&session_id) {
        Ok(_) => Ok(StatusCode::NO_CONTENT),
        Err(err) => Err(ApiError::internal(format!(
            "Failed to delete session: {}",
            err
        ))),
    }
}
/// Appends one message to a session and returns the stored message with
/// `201 Created`.
///
/// Only `role` and `content` from the request are passed to storage.
/// NOTE(review): `name`, `function_call`, `tool_calls` and `metadata` in the request
/// are currently ignored.
///
/// When converting the stored message:
/// - stored `function_call` / `tool_calls` that fail to parse are reported as absent;
/// - an empty stored `name` becomes `None`;
/// - non-object metadata becomes an empty map.
///
/// # Errors
/// Returns an internal error if storage rejects the message.
pub async fn add_message(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
    Json(req): Json<AddMessageRequest>,
) -> Result<(StatusCode, Json<ApiResponse<MemoryMessage>>), ApiError> {
    let stored = match state.db.add_agent_message(&session_id, &req.role, &req.content) {
        Ok(saved) => saved,
        Err(err) => return Err(ApiError::internal(format!("Failed to add message: {}", err))),
    };

    // Parse the borrowed call fields first. The owned fields are moved out below.
    let function_call = match stored.function_call.as_ref() {
        Some(raw) => serde_json::from_str::<FunctionCall>(raw).ok(),
        None => None,
    };
    let tool_calls = match stored.tool_calls.as_ref() {
        Some(value) => serde_json::from_value::<Vec<ToolCall>>(value.clone()).ok(),
        None => None,
    };
    let metadata: HashMap<String, serde_json::Value> = match stored.metadata {
        serde_json::Value::Object(fields) => fields.into_iter().collect(),
        _ => HashMap::new(),
    };

    let body = MemoryMessage {
        role: stored.role,
        content: stored.content,
        name: Some(stored.name).filter(|n| !n.is_empty()),
        function_call,
        tool_calls,
        metadata: Some(metadata),
        timestamp: Some(stored.timestamp),
    };
    Ok((StatusCode::CREATED, Json(ApiResponse::success(body))))
}
/// Appends several messages to a session, in request order, and reports how many
/// were added.
///
/// Only each message's `role` and `content` are passed to storage. Processing stops
/// at the first storage error. Messages stored before that error are not rolled back
/// by this handler.
///
/// # Errors
/// Returns an internal error on the first message storage rejects.
pub async fn add_messages(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
    Json(req): Json<AddMessagesRequest>,
) -> Result<(StatusCode, Json<ApiResponse<serde_json::Value>>), ApiError> {
    req.messages.iter().try_for_each(|msg| {
        state
            .db
            .add_agent_message(&session_id, &msg.role, &msg.content)
            .map(|_| ())
            .map_err(|e| ApiError::internal(format!("Failed to add message: {}", e)))
    })?;

    // Reaching this point means every message was stored.
    let added = req.messages.len();
    let summary = serde_json::json!({
        "added_count": added,
    });
    Ok((StatusCode::CREATED, Json(ApiResponse::success(summary))))
}
/// Lists the messages stored for a session.
///
/// Query parameters (all optional), applied in this order:
/// 1. `role`: keep only messages whose role equals this value exactly.
/// 2. `order`: `"desc"` (ASCII case-insensitive) reverses the order returned by
///    storage. Any other value keeps storage order.
/// 3. `offset`: skip this many of the remaining messages (default 0).
/// 4. `limit`: return at most this many messages (default unbounded).
///
/// `since` is accepted but not applied. Timestamps are opaque strings here, and
/// comparing them safely requires knowing their format.
///
/// When converting stored messages:
/// - `function_call` / `tool_calls` that fail to parse are reported as absent;
/// - an empty `name` becomes `None`;
/// - non-object metadata becomes an empty map.
///
/// # Errors
/// Returns an internal error if the messages cannot be loaded.
pub async fn get_messages(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
    Query(query): Query<GetMessagesQuery>,
) -> Result<Json<ApiResponse<Vec<MemoryMessage>>>, ApiError> {
    let messages = state
        .db
        .get_agent_messages(&session_id)
        .map_err(|e| ApiError::internal(format!("Failed to get messages: {}", e)))?;

    // Materialize so the role filter and ordering can be applied in place.
    // Previously the whole query was ignored and every message was returned.
    let mut messages: Vec<_> = messages.into_iter().collect();
    if let Some(role) = query.role.as_deref() {
        messages.retain(|m| m.role == role);
    }
    let descending = query
        .order
        .as_deref()
        .map_or(false, |order| order.eq_ignore_ascii_case("desc"));
    if descending {
        messages.reverse();
    }

    let offset = query.offset.unwrap_or(0);
    let limit = query.limit.unwrap_or(usize::MAX);

    let response: Vec<MemoryMessage> = messages
        .into_iter()
        .skip(offset)
        .take(limit)
        .map(|m| {
            let function_call = m
                .function_call
                .as_ref()
                .and_then(|fc_str| serde_json::from_str::<FunctionCall>(fc_str).ok());
            let metadata = if let serde_json::Value::Object(map) = m.metadata {
                Some(map.into_iter().collect())
            } else {
                Some(HashMap::new())
            };
            let tool_calls = m
                .tool_calls
                .as_ref()
                .and_then(|tc_val| serde_json::from_value::<Vec<ToolCall>>(tc_val.clone()).ok());
            MemoryMessage {
                role: m.role,
                content: m.content,
                name: if m.name.is_empty() { None } else { Some(m.name) },
                function_call,
                tool_calls,
                metadata,
                timestamp: Some(m.timestamp),
            }
        })
        .collect();
    Ok(Json(ApiResponse::success(response)))
}
/// Searches a session's stored messages for `req.query`.
///
/// Storage performs the search. This handler then filters and truncates the hits:
/// - `min_score`: drop hits scoring below this value. When it is set, a `NaN` score
///   never passes.
/// - `role_filter`: keep only hits whose role appears in the list. An empty list
///   means no filtering.
/// - `limit`: return at most this many hits (10 when omitted).
///
/// Each result's `index` is its position in the unfiltered list returned by storage,
/// so a given hit keeps the same index whichever filters are applied.
///
/// `since` / `until` are accepted but not applied. Timestamps are opaque strings
/// here, and comparing them safely requires knowing their format.
///
/// # Errors
/// Returns an internal error if the storage-level search fails.
pub async fn search_memory(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
    Json(req): Json<SearchMemoryRequest>,
) -> Result<Json<ApiResponse<Vec<MemorySearchResult>>>, ApiError> {
    let raw_results = state
        .db
        .search_agent_memory(&session_id, &req.query)
        .map_err(|e| ApiError::internal(format!("Memory search failed: {}", e)))?;

    // `Some(&[])` would otherwise reject every hit, so treat it the same as "no filter".
    let role_filter: Option<&[String]> = req
        .role_filter
        .as_deref()
        .filter(|roles| !roles.is_empty());

    let response: Vec<MemorySearchResult> = raw_results
        .into_iter()
        // Enumerate before filtering so `index` refers to the raw result position.
        .enumerate()
        .filter(|(_, (msg, score))| {
            let score_ok = req.min_score.map_or(true, |min| *score >= min);
            let role_ok =
                role_filter.map_or(true, |roles| roles.iter().any(|r| *r == msg.role));
            score_ok && role_ok
        })
        .take(req.limit)
        .map(|(index, (msg, score))| {
            let function_call = msg
                .function_call
                .as_ref()
                .and_then(|fc_str| serde_json::from_str::<FunctionCall>(fc_str).ok());
            let metadata = if let serde_json::Value::Object(map) = msg.metadata {
                map.into_iter().collect()
            } else {
                HashMap::new()
            };
            let tool_calls = msg
                .tool_calls
                .as_ref()
                .and_then(|tc_val| serde_json::from_value::<Vec<ToolCall>>(tc_val.clone()).ok());
            MemorySearchResult {
                message: MemoryMessage {
                    role: msg.role,
                    content: msg.content,
                    name: if msg.name.is_empty() { None } else { Some(msg.name) },
                    function_call,
                    tool_calls,
                    metadata: Some(metadata),
                    timestamp: Some(msg.timestamp),
                },
                score,
                index,
            }
        })
        .collect();
    Ok(Json(ApiResponse::success(response)))
}
/// Runs storage-level summarization for a session and returns the summary text.
///
/// NOTE(review): the request options (`strategy`, `keep_last`, `target_tokens`,
/// `prompt`) are ignored, and `original_count`, `new_count` and `tokens_saved` are
/// always reported as 0.
///
/// # Errors
/// Returns an internal error if summarization fails.
pub async fn summarize_memory(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
    Json(_options): Json<SummarizeRequest>,
) -> Result<Json<ApiResponse<SummaryResponse>>, ApiError> {
    let summary = match state.db.summarize_agent_memory(&session_id) {
        Ok(text) => text,
        Err(err) => return Err(ApiError::internal(format!("Summarization failed: {}", err))),
    };

    let outcome = SummaryResponse {
        summary,
        original_count: 0,
        new_count: 0,
        tokens_saved: 0,
    };
    Ok(Json(ApiResponse::success(outcome)))
}
/// Builds a context for a session.
///
/// NOTE(review): the context loaded from storage is discarded and an empty
/// response (no messages, zero tokens, both flags `false`) is always returned. Only
/// storage errors are surfaced. The request options are ignored.
///
/// # Errors
/// Returns an internal error if storage fails to produce the context.
pub async fn get_context(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
    Json(_options): Json<GetContextRequest>,
) -> Result<Json<ApiResponse<ContextResponse>>, ApiError> {
    let _loaded = match state.db.get_agent_context(&session_id) {
        Ok(context) => context,
        Err(err) => return Err(ApiError::internal(format!("Failed to get context: {}", err))),
    };

    let empty = ContextResponse {
        messages: Vec::new(),
        token_count: 0,
        truncated: false,
        summary_included: false,
    };
    Ok(Json(ApiResponse::success(empty)))
}
/// Asks storage to fork `session_id` into a new session and returns the new
/// session with `201 Created`.
///
/// When `new_session_id` is omitted, a random UUID v4 string is used.
/// NOTE(review): `from_index` and `include_summary` are accepted but not forwarded
/// to storage.
///
/// # Errors
/// Returns an internal error if storage fails to fork the session.
pub async fn fork_session(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
    Json(req): Json<ForkSessionRequest>,
) -> Result<(StatusCode, Json<ApiResponse<SessionInfo>>), ApiError> {
    let target_id = match req.new_session_id {
        Some(requested) => requested,
        None => uuid::Uuid::new_v4().to_string(),
    };

    let forked = match state.db.fork_agent_session(&session_id, &target_id) {
        Ok(record) => record,
        Err(err) => return Err(ApiError::internal(format!("Failed to fork session: {}", err))),
    };

    // Non-object metadata is surfaced as an empty map rather than an error.
    let metadata: HashMap<String, serde_json::Value> =
        if let serde_json::Value::Object(entries) = forked.metadata {
            entries.into_iter().collect()
        } else {
            HashMap::new()
        };

    let body = SessionInfo {
        session_id: forked.session_id,
        message_count: forked.message_count as usize,
        token_count: forked.token_count as usize,
        metadata,
        created_at: forked.created_at,
        updated_at: forked.updated_at,
    };
    Ok((StatusCode::CREATED, Json(ApiResponse::success(body))))
}
/// Asks storage to clear a session's messages.
///
/// NOTE(review): `keep_last` and `keep_system` from the request are ignored, and
/// `deleted_count` is always reported as 0 regardless of what storage removed.
///
/// # Errors
/// Returns an internal error if storage fails to clear the messages.
pub async fn clear_messages(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
    Json(_options): Json<ClearMessagesRequest>,
) -> Result<Json<ApiResponse<serde_json::Value>>, ApiError> {
    if let Err(err) = state.db.clear_agent_messages(&session_id) {
        return Err(ApiError::internal(format!(
            "Failed to clear messages: {}",
            err
        )));
    }

    let report = serde_json::json!({
        "deleted_count": 0,
    });
    Ok(Json(ApiResponse::success(report)))
}