agentic_memory_mcp/tools/
session_start.rs1use std::sync::Arc;
8use tokio::sync::Mutex;
9
10use serde::Deserialize;
11use serde_json::{json, Value};
12
13use agentic_memory::{EventType, PatternParams, PatternSort};
14
15use crate::session::SessionManager;
16use crate::types::{McpError, McpResult, ToolCallResult, ToolDefinition};
17
18#[derive(Debug, Deserialize)]
19#[allow(dead_code)]
20struct StartParams {
21 session_id: Option<u32>,
22 metadata: Option<Value>,
23}
24
25pub fn definition() -> ToolDefinition {
27 ToolDefinition {
28 name: "session_start".to_string(),
29 description: Some("Start a new interaction session".to_string()),
30 input_schema: json!({
31 "type": "object",
32 "properties": {
33 "session_id": { "type": "integer", "description": "Optional explicit session ID" },
34 "metadata": { "type": "object", "description": "Optional session metadata" }
35 }
36 }),
37 }
38}
39
40pub async fn execute(
42 args: Value,
43 session: &Arc<Mutex<SessionManager>>,
44) -> McpResult<ToolCallResult> {
45 let params: StartParams =
46 serde_json::from_value(args).map_err(|e| McpError::InvalidParams(e.to_string()))?;
47
48 let mut session = session.lock().await;
49 let session_id = session.start_session(params.session_id)?;
50
51 let graph = session.graph();
57 let query = session.query_engine();
58
59 let episode_pattern = PatternParams {
60 event_types: vec![EventType::Episode],
61 min_confidence: None,
62 max_confidence: None,
63 session_ids: vec![],
64 created_after: None,
65 created_before: None,
66 min_decay_score: None,
67 max_results: 1,
68 sort_by: PatternSort::MostRecent,
69 };
70
71 let last_episode = query
72 .pattern(graph, episode_pattern)
73 .ok()
74 .and_then(|eps| eps.into_iter().next())
75 .map(|ep| {
76 json!({
77 "session_id": ep.session_id,
78 "summary": ep.content,
79 "created_at": ep.created_at,
80 })
81 });
82
83 let all_sessions = graph.session_index().session_ids();
85 let prev_session = all_sessions
86 .iter()
87 .filter(|&&s| s < session_id)
88 .max()
89 .copied();
90
91 let total_sessions = all_sessions.len();
92
93 Ok(ToolCallResult::json(&json!({
94 "session_id": session_id,
95 "message": format!("Session {session_id} started"),
96 "total_sessions": total_sessions,
97 "previous_session": prev_session,
98 "last_episode": last_episode,
99 })))
100}