// engram/intelligence/session_context.rs

//! Session Context Tracking (Phase 8 - ENG-70, ENG-71)
//!
//! Extends session indexing with:
//! - Named session creation and management
//! - Memory-to-session linking
//! - Session-scoped memory search
//! - Session summarization
//! - Context role tracking (referenced, created, updated, pinned)
//!
//! This enables agents to:
//! - Track which memories were used in a session
//! - Search within session context
//! - Generate session summaries
//! - Export session data for analysis
15
16use std::collections::HashMap;
17
18use chrono::{DateTime, Utc};
19use rusqlite::{params, Connection};
20use serde::{Deserialize, Serialize};
21
22use crate::error::{EngramError, Result};
23use crate::types::{Memory, MemoryId};
24
25/// Role of a memory in a session context
26#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
27#[serde(rename_all = "lowercase")]
28pub enum ContextRole {
29    /// Memory was referenced/read during session
30    #[default]
31    Referenced,
32    /// Memory was created during session
33    Created,
34    /// Memory was updated during session
35    Updated,
36    /// Memory was explicitly added to context
37    Pinned,
38}
39
40impl std::fmt::Display for ContextRole {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            ContextRole::Referenced => write!(f, "referenced"),
44            ContextRole::Created => write!(f, "created"),
45            ContextRole::Updated => write!(f, "updated"),
46            ContextRole::Pinned => write!(f, "pinned"),
47        }
48    }
49}
50
51impl std::str::FromStr for ContextRole {
52    type Err = String;
53
54    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
55        match s.to_lowercase().as_str() {
56            "referenced" => Ok(ContextRole::Referenced),
57            "created" => Ok(ContextRole::Created),
58            "updated" => Ok(ContextRole::Updated),
59            "pinned" => Ok(ContextRole::Pinned),
60            _ => Err(format!("Unknown context role: {}", s)),
61        }
62    }
63}
64
65/// A memory linked to a session with context information
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct SessionMemoryLink {
68    /// Session ID
69    pub session_id: String,
70    /// Memory ID
71    pub memory_id: MemoryId,
72    /// When the memory was added to session
73    pub added_at: DateTime<Utc>,
74    /// Relevance score (0.0 - 1.0)
75    pub relevance_score: f32,
76    /// Role of the memory in the session
77    pub context_role: ContextRole,
78}
79
80/// Extended session information with context
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct SessionContext {
83    /// Session ID
84    pub session_id: String,
85    /// Session title (optional)
86    pub title: Option<String>,
87    /// When the session started
88    pub created_at: DateTime<Utc>,
89    /// When the session ended (None if active)
90    pub ended_at: Option<DateTime<Utc>>,
91    /// Number of messages in the session
92    pub message_count: i32,
93    /// Workspace for the session
94    pub workspace: String,
95    /// Summary of the session (auto-generated or manual)
96    pub summary: Option<String>,
97    /// Active context (JSON-encoded working memory)
98    pub context: Option<String>,
99    /// Session metadata
100    pub metadata: HashMap<String, serde_json::Value>,
101    /// Linked memories with context info
102    #[serde(default)]
103    pub memories: Vec<SessionMemoryLink>,
104}
105
106/// Input for creating a new session
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct CreateSessionInput {
109    /// Optional custom session ID (auto-generated if not provided)
110    pub session_id: Option<String>,
111    /// Optional session title
112    pub title: Option<String>,
113    /// Initial context (JSON string)
114    pub initial_context: Option<String>,
115    /// Optional workspace (defaults to "default")
116    pub workspace: Option<String>,
117    /// Session metadata
118    #[serde(default)]
119    pub metadata: HashMap<String, serde_json::Value>,
120}
121
122/// Result of session search
123#[derive(Debug, Clone, Serialize, Deserialize)]
124pub struct SessionSearchResult {
125    /// The memory
126    pub memory: Memory,
127    /// Search relevance score
128    pub relevance_score: f32,
129    /// Context role in the session
130    pub context_role: ContextRole,
131    /// When added to session
132    pub added_at: DateTime<Utc>,
133}
134
135/// Session export format
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct SessionExport {
138    /// Session information
139    pub session: SessionContext,
140    /// All linked memories
141    pub memories: Vec<Memory>,
142    /// Export timestamp
143    pub exported_at: DateTime<Utc>,
144    /// Export format version
145    pub format_version: String,
146}
147
148/// Create a new named session
149pub fn create_session(conn: &Connection, input: CreateSessionInput) -> Result<SessionContext> {
150    let session_id = input
151        .session_id
152        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
153    let now = Utc::now();
154    let now_str = now.to_rfc3339();
155    let metadata_json = serde_json::to_string(&input.metadata).unwrap_or_else(|_| "{}".to_string());
156    let workspace = input.workspace.unwrap_or_else(|| "default".to_string());
157
158    conn.execute(
159        "INSERT INTO sessions (session_id, title, started_at, message_count, workspace, metadata, summary, context)
160         VALUES (?, ?, ?, 0, ?, ?, NULL, ?)",
161        params![
162            session_id,
163            input.title,
164            now_str,
165            workspace,
166            metadata_json,
167            input.initial_context
168        ],
169    )?;
170
171    Ok(SessionContext {
172        session_id,
173        title: input.title,
174        created_at: now,
175        ended_at: None,
176        message_count: 0,
177        workspace,
178        summary: None,
179        context: input.initial_context,
180        metadata: input.metadata,
181        memories: vec![],
182    })
183}
184
185/// Add a memory to a session's context
186pub fn add_memory_to_session(
187    conn: &Connection,
188    session_id: &str,
189    memory_id: MemoryId,
190    relevance_score: f32,
191    context_role: ContextRole,
192) -> Result<SessionMemoryLink> {
193    let now = Utc::now();
194    let now_str = now.to_rfc3339();
195    let role_str = context_role.to_string();
196
197    // Check if session exists
198    let exists: bool = conn.query_row(
199        "SELECT EXISTS(SELECT 1 FROM sessions WHERE session_id = ?)",
200        params![session_id],
201        |row| row.get(0),
202    )?;
203
204    if !exists {
205        return Err(EngramError::InvalidInput(format!(
206            "Session not found: {}",
207            session_id
208        )));
209    }
210
211    // Insert or update the link
212    conn.execute(
213        "INSERT INTO session_memories (session_id, memory_id, added_at, relevance_score, context_role)
214         VALUES (?, ?, ?, ?, ?)
215         ON CONFLICT(session_id, memory_id) DO UPDATE SET
216             relevance_score = MAX(relevance_score, excluded.relevance_score),
217             context_role = excluded.context_role",
218        params![session_id, memory_id, now_str, relevance_score, role_str],
219    )?;
220
221    Ok(SessionMemoryLink {
222        session_id: session_id.to_string(),
223        memory_id,
224        added_at: now,
225        relevance_score,
226        context_role,
227    })
228}
229
230/// Remove a memory from a session's context
231pub fn remove_memory_from_session(
232    conn: &Connection,
233    session_id: &str,
234    memory_id: MemoryId,
235) -> Result<bool> {
236    let rows = conn.execute(
237        "DELETE FROM session_memories WHERE session_id = ? AND memory_id = ?",
238        params![session_id, memory_id],
239    )?;
240
241    Ok(rows > 0)
242}
243
244/// Get all memories linked to a session
245pub fn get_session_memories(
246    conn: &Connection,
247    session_id: &str,
248    role_filter: Option<ContextRole>,
249) -> Result<Vec<SessionMemoryLink>> {
250    let base_query = "SELECT session_id, memory_id, added_at, relevance_score, context_role
251                      FROM session_memories WHERE session_id = ?";
252
253    let query = if role_filter.is_some() {
254        format!("{} AND context_role = ?", base_query)
255    } else {
256        format!("{} ORDER BY relevance_score DESC", base_query)
257    };
258
259    let mut stmt = conn.prepare(&query)?;
260
261    let links = if let Some(role) = role_filter {
262        stmt.query_map(params![session_id, role.to_string()], parse_link)?
263    } else {
264        stmt.query_map(params![session_id], parse_link)?
265    };
266
267    Ok(links.filter_map(|r| r.ok()).collect::<Vec<_>>())
268}
269
270fn parse_link(row: &rusqlite::Row) -> rusqlite::Result<SessionMemoryLink> {
271    let session_id: String = row.get(0)?;
272    let memory_id: MemoryId = row.get(1)?;
273    let added_at_str: String = row.get(2)?;
274    let relevance_score: f32 = row.get(3)?;
275    let role_str: String = row.get(4)?;
276
277    let added_at = DateTime::parse_from_rfc3339(&added_at_str)
278        .map(|dt| dt.with_timezone(&Utc))
279        .unwrap_or_else(|_| Utc::now());
280
281    let context_role = role_str.parse().unwrap_or(ContextRole::Referenced);
282
283    Ok(SessionMemoryLink {
284        session_id,
285        memory_id,
286        added_at,
287        relevance_score,
288        context_role,
289    })
290}
291
292/// Get a session with all its linked memories
293pub fn get_session_context(conn: &Connection, session_id: &str) -> Result<Option<SessionContext>> {
294    let row = conn.query_row(
295        "SELECT session_id, title, started_at, ended_at, message_count, workspace, metadata, summary, context
296         FROM sessions WHERE session_id = ?",
297        params![session_id],
298        |row| {
299            Ok((
300                row.get::<_, String>(0)?,
301                row.get::<_, Option<String>>(1)?,
302                row.get::<_, String>(2)?,
303                row.get::<_, Option<String>>(3)?,
304                row.get::<_, i32>(4)?,
305                row.get::<_, String>(5)?,
306                row.get::<_, Option<String>>(6)?,
307                row.get::<_, Option<String>>(7)?,
308                row.get::<_, Option<String>>(8)?,
309            ))
310        },
311    );
312
313    match row {
314        Ok((
315            id,
316            title,
317            started_at_str,
318            ended_at_str,
319            message_count,
320            workspace,
321            metadata_str,
322            summary,
323            context,
324        )) => {
325            let created_at = DateTime::parse_from_rfc3339(&started_at_str)
326                .map(|dt| dt.with_timezone(&Utc))
327                .unwrap_or_else(|_| Utc::now());
328
329            let ended_at = ended_at_str.and_then(|s| {
330                DateTime::parse_from_rfc3339(&s)
331                    .map(|dt| dt.with_timezone(&Utc))
332                    .ok()
333            });
334
335            let metadata: HashMap<String, serde_json::Value> = metadata_str
336                .and_then(|s| serde_json::from_str(&s).ok())
337                .unwrap_or_default();
338            let title = title.or_else(|| {
339                metadata
340                    .get("title")
341                    .and_then(|v| v.as_str())
342                    .map(String::from)
343            });
344
345            let memories = get_session_memories(conn, session_id, None)?;
346
347            Ok(Some(SessionContext {
348                session_id: id,
349                title,
350                created_at,
351                ended_at,
352                message_count,
353                workspace,
354                summary,
355                context,
356                metadata,
357                memories,
358            }))
359        }
360        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
361        Err(e) => Err(e.into()),
362    }
363}
364
365/// Update session summary
366pub fn update_session_summary(conn: &Connection, session_id: &str, summary: &str) -> Result<()> {
367    let now = Utc::now().to_rfc3339();
368
369    let rows = conn.execute(
370        "UPDATE sessions SET summary = ?, ended_at = COALESCE(ended_at, ?) WHERE session_id = ?",
371        params![summary, now, session_id],
372    )?;
373
374    if rows == 0 {
375        return Err(EngramError::InvalidInput(format!(
376            "Session not found: {}",
377            session_id
378        )));
379    }
380
381    Ok(())
382}
383
384/// Update session context (working memory)
385pub fn update_session_context(conn: &Connection, session_id: &str, context: &str) -> Result<()> {
386    let rows = conn.execute(
387        "UPDATE sessions SET context = ? WHERE session_id = ?",
388        params![context, session_id],
389    )?;
390
391    if rows == 0 {
392        return Err(EngramError::InvalidInput(format!(
393            "Session not found: {}",
394            session_id
395        )));
396    }
397
398    Ok(())
399}
400
401/// End a session
402pub fn end_session(conn: &Connection, session_id: &str) -> Result<()> {
403    let now = Utc::now().to_rfc3339();
404
405    let rows = conn.execute(
406        "UPDATE sessions SET ended_at = ? WHERE session_id = ? AND ended_at IS NULL",
407        params![now, session_id],
408    )?;
409
410    if rows == 0 {
411        // Check if session exists but was already ended
412        let exists: bool = conn.query_row(
413            "SELECT EXISTS(SELECT 1 FROM sessions WHERE session_id = ?)",
414            params![session_id],
415            |row| row.get(0),
416        )?;
417
418        if !exists {
419            return Err(EngramError::InvalidInput(format!(
420                "Session not found: {}",
421                session_id
422            )));
423        }
424        // Session exists but already ended - that's OK
425    }
426
427    Ok(())
428}
429
430/// Search memories within a session's context
431pub fn search_session_memories(
432    conn: &Connection,
433    session_id: &str,
434    query: &str,
435    limit: i64,
436) -> Result<Vec<SessionSearchResult>> {
437    // First get memory IDs linked to this session
438    let memory_ids: Vec<MemoryId> = conn
439        .prepare(
440            "SELECT memory_id FROM session_memories WHERE session_id = ? ORDER BY relevance_score DESC",
441        )?
442        .query_map(params![session_id], |row| row.get(0))?
443        .filter_map(|r| r.ok())
444        .collect();
445
446    if memory_ids.is_empty() {
447        return Ok(vec![]);
448    }
449
450    // Build IN clause
451    let placeholders: Vec<String> = memory_ids.iter().map(|_| "?".to_string()).collect();
452    let in_clause = placeholders.join(", ");
453
454    // Search within those memories using FTS
455    let sql = format!(
456        "SELECT m.id, m.content, m.memory_type, m.importance, m.access_count,
457                m.created_at, m.updated_at, m.last_accessed_at, m.tags,
458                m.workspace, m.tier, m.lifecycle_state,
459                sm.relevance_score, sm.context_role, sm.added_at,
460                bm25(memories_fts) as search_score
461         FROM memories m
462         JOIN session_memories sm ON m.id = sm.memory_id
463         JOIN memories_fts ON memories_fts.rowid = m.id
464         WHERE sm.session_id = ?
465           AND m.id IN ({})
466           AND memories_fts MATCH ?
467         ORDER BY search_score * sm.relevance_score DESC
468         LIMIT ?",
469        in_clause
470    );
471
472    let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = vec![Box::new(session_id.to_string())];
473    for id in &memory_ids {
474        params_vec.push(Box::new(*id));
475    }
476    params_vec.push(Box::new(query.to_string()));
477    params_vec.push(Box::new(limit));
478
479    let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|p| p.as_ref()).collect();
480
481    let mut stmt = conn.prepare(&sql)?;
482    let results = stmt
483        .query_map(params_refs.as_slice(), |row| {
484            // Parse memory fields
485            let id: MemoryId = row.get(0)?;
486            let content: String = row.get(1)?;
487            let memory_type_str: String = row.get(2)?;
488            let importance: f32 = row.get(3)?;
489            let access_count: i32 = row.get(4)?;
490            let created_at_str: String = row.get(5)?;
491            let updated_at_str: String = row.get(6)?;
492            let last_accessed_str: Option<String> = row.get(7)?;
493            let tags_str: Option<String> = row.get(8)?;
494            let workspace: String = row.get(9)?;
495            let tier_str: String = row.get(10)?;
496            let lifecycle_str: String = row.get(11)?;
497            let relevance_score: f32 = row.get(12)?;
498            let context_role_str: String = row.get(13)?;
499            let added_at_str: String = row.get(14)?;
500
501            Ok((
502                id,
503                content,
504                memory_type_str,
505                importance,
506                access_count,
507                created_at_str,
508                updated_at_str,
509                last_accessed_str,
510                tags_str,
511                workspace,
512                tier_str,
513                lifecycle_str,
514                relevance_score,
515                context_role_str,
516                added_at_str,
517            ))
518        })?
519        .filter_map(|r| r.ok())
520        .map(
521            |(
522                id,
523                content,
524                memory_type_str,
525                importance,
526                access_count,
527                created_at_str,
528                updated_at_str,
529                last_accessed_str,
530                tags_str,
531                workspace,
532                tier_str,
533                lifecycle_str,
534                relevance_score,
535                context_role_str,
536                added_at_str,
537            )| {
538                let now = Utc::now();
539
540                let memory = Memory {
541                    id,
542                    content,
543                    memory_type: memory_type_str
544                        .parse()
545                        .unwrap_or(crate::types::MemoryType::Note),
546                    tags: tags_str
547                        .map(|s| serde_json::from_str(&s).unwrap_or_default())
548                        .unwrap_or_default(),
549                    metadata: HashMap::new(),
550                    importance,
551                    access_count,
552                    created_at: DateTime::parse_from_rfc3339(&created_at_str)
553                        .map(|dt| dt.with_timezone(&Utc))
554                        .unwrap_or(now),
555                    updated_at: DateTime::parse_from_rfc3339(&updated_at_str)
556                        .map(|dt| dt.with_timezone(&Utc))
557                        .unwrap_or(now),
558                    last_accessed_at: last_accessed_str.and_then(|s| {
559                        DateTime::parse_from_rfc3339(&s)
560                            .map(|dt| dt.with_timezone(&Utc))
561                            .ok()
562                    }),
563                    owner_id: None,
564                    visibility: crate::types::Visibility::Private,
565                    scope: crate::types::MemoryScope::Global,
566                    workspace,
567                    tier: tier_str
568                        .parse()
569                        .unwrap_or(crate::types::MemoryTier::Permanent),
570                    version: 1,
571                    has_embedding: false,
572                    expires_at: None,
573                    content_hash: None,
574                    event_time: None,
575                    event_duration_seconds: None,
576                    trigger_pattern: None,
577                    procedure_success_count: 0,
578                    procedure_failure_count: 0,
579                    summary_of_id: None,
580                    lifecycle_state: lifecycle_str
581                        .parse()
582                        .unwrap_or(crate::types::LifecycleState::Active),
583                    media_url: None,
584                };
585
586                SessionSearchResult {
587                    memory,
588                    relevance_score,
589                    context_role: context_role_str.parse().unwrap_or(ContextRole::Referenced),
590                    added_at: DateTime::parse_from_rfc3339(&added_at_str)
591                        .map(|dt| dt.with_timezone(&Utc))
592                        .unwrap_or(now),
593                }
594            },
595        )
596        .collect();
597
598    Ok(results)
599}
600
601/// Export a session with all its data
602pub fn export_session(
603    conn: &Connection,
604    session_id: &str,
605    include_content: bool,
606) -> Result<SessionExport> {
607    let session = get_session_context(conn, session_id)?
608        .ok_or_else(|| EngramError::InvalidInput(format!("Session not found: {}", session_id)))?;
609
610    // Get all linked memories
611    let memory_ids: Vec<MemoryId> = session.memories.iter().map(|m| m.memory_id).collect();
612
613    let mut memories = Vec::new();
614    if !memory_ids.is_empty() {
615        for id in memory_ids {
616            match crate::storage::queries::get_memory(conn, id) {
617                Ok(mut memory) => {
618                    if !include_content {
619                        memory.content.clear();
620                    }
621                    memories.push(memory);
622                }
623                Err(EngramError::NotFound(_)) => continue,
624                Err(e) => return Err(e),
625            }
626        }
627    }
628
629    Ok(SessionExport {
630        session,
631        memories,
632        exported_at: Utc::now(),
633        format_version: "1.0".to_string(),
634    })
635}
636
637/// List sessions with optional filters
638pub fn list_sessions_extended(
639    conn: &Connection,
640    workspace: Option<&str>,
641    active_only: bool,
642    limit: i64,
643    offset: i64,
644) -> Result<Vec<SessionContext>> {
645    let mut query = String::from(
646        "SELECT session_id, title, started_at, ended_at, message_count, workspace, metadata, summary, context
647         FROM sessions",
648    );
649
650    let mut filters = Vec::new();
651    if active_only {
652        filters.push("ended_at IS NULL");
653    }
654    if workspace.is_some() {
655        filters.push("workspace = ?");
656    }
657    if !filters.is_empty() {
658        query.push_str(" WHERE ");
659        query.push_str(&filters.join(" AND "));
660    }
661
662    query.push_str(" ORDER BY started_at DESC LIMIT ? OFFSET ?");
663
664    let mut stmt = conn.prepare(&query)?;
665    let rows: Vec<(
666        String,
667        Option<String>,
668        String,
669        Option<String>,
670        i32,
671        String,
672        Option<String>,
673        Option<String>,
674        Option<String>,
675    )> = if let Some(workspace) = workspace {
676        let rows = stmt.query_map(params![workspace, limit, offset], |row| {
677            Ok((
678                row.get::<_, String>(0)?,
679                row.get::<_, Option<String>>(1)?,
680                row.get::<_, String>(2)?,
681                row.get::<_, Option<String>>(3)?,
682                row.get::<_, i32>(4)?,
683                row.get::<_, String>(5)?,
684                row.get::<_, Option<String>>(6)?,
685                row.get::<_, Option<String>>(7)?,
686                row.get::<_, Option<String>>(8)?,
687            ))
688        })?;
689        rows.collect::<std::result::Result<Vec<_>, _>>()?
690    } else {
691        let rows = stmt.query_map(params![limit, offset], |row| {
692            Ok((
693                row.get::<_, String>(0)?,
694                row.get::<_, Option<String>>(1)?,
695                row.get::<_, String>(2)?,
696                row.get::<_, Option<String>>(3)?,
697                row.get::<_, i32>(4)?,
698                row.get::<_, String>(5)?,
699                row.get::<_, Option<String>>(6)?,
700                row.get::<_, Option<String>>(7)?,
701                row.get::<_, Option<String>>(8)?,
702            ))
703        })?;
704        rows.collect::<std::result::Result<Vec<_>, _>>()?
705    };
706
707    let sessions = rows
708        .into_iter()
709        .map(
710            |(
711                id,
712                title,
713                started_at_str,
714                ended_at_str,
715                message_count,
716                workspace,
717                metadata_str,
718                summary,
719                context,
720            )| {
721                let now = Utc::now();
722                let created_at = DateTime::parse_from_rfc3339(&started_at_str)
723                    .map(|dt| dt.with_timezone(&Utc))
724                    .unwrap_or(now);
725
726                let ended_at = ended_at_str.and_then(|s| {
727                    DateTime::parse_from_rfc3339(&s)
728                        .map(|dt| dt.with_timezone(&Utc))
729                        .ok()
730                });
731
732                let metadata: HashMap<String, serde_json::Value> = metadata_str
733                    .and_then(|s| serde_json::from_str(&s).ok())
734                    .unwrap_or_default();
735                let title = title.or_else(|| {
736                    metadata
737                        .get("title")
738                        .and_then(|v| v.as_str())
739                        .map(String::from)
740                });
741
742                SessionContext {
743                    session_id: id,
744                    title,
745                    created_at,
746                    ended_at,
747                    message_count,
748                    workspace,
749                    summary,
750                    context,
751                    metadata,
752                    memories: vec![], // Don't load memories for list view
753                }
754            },
755        )
756        .collect();
757
758    Ok(sessions)
759}
760
761/// Get sessions that reference a specific memory
762pub fn get_sessions_for_memory(
763    conn: &Connection,
764    memory_id: MemoryId,
765) -> Result<Vec<SessionMemoryLink>> {
766    let mut stmt = conn.prepare(
767        "SELECT session_id, memory_id, added_at, relevance_score, context_role
768         FROM session_memories
769         WHERE memory_id = ?
770         ORDER BY added_at DESC",
771    )?;
772
773    let links: Vec<SessionMemoryLink> = stmt
774        .query_map(params![memory_id], parse_link)?
775        .filter_map(|r| r.ok())
776        .collect();
777
778    Ok(links)
779}
780
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Opens an in-memory database with the minimal schema this module touches.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();

        // Create minimal schema for testing
        conn.execute_batch(
            r#"
            CREATE TABLE sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL UNIQUE,
                title TEXT,
                started_at TEXT NOT NULL,
                last_indexed_at TEXT,
                message_count INTEGER NOT NULL DEFAULT 0,
                chunk_count INTEGER NOT NULL DEFAULT 0,
                workspace TEXT NOT NULL DEFAULT 'default',
                metadata TEXT NOT NULL DEFAULT '{}',
                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
                summary TEXT,
                context TEXT,
                ended_at TEXT
            );

            CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                content TEXT NOT NULL,
                memory_type TEXT DEFAULT 'note',
                importance REAL DEFAULT 0.5,
                access_count INTEGER DEFAULT 0,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                last_accessed_at TEXT,
                workspace TEXT DEFAULT 'default',
                tier TEXT DEFAULT 'permanent',
                lifecycle_state TEXT DEFAULT 'active',
                tags TEXT
            );

            CREATE TABLE session_memories (
                session_id TEXT NOT NULL REFERENCES sessions(session_id) ON DELETE CASCADE,
                memory_id INTEGER NOT NULL,
                added_at TEXT NOT NULL,
                relevance_score REAL DEFAULT 1.0,
                context_role TEXT DEFAULT 'referenced',
                PRIMARY KEY (session_id, memory_id)
            );

            CREATE VIRTUAL TABLE memories_fts USING fts5(content);
            "#,
        )
        .unwrap();

        conn
    }

    /// Creates a session with the given ID and otherwise default settings.
    fn create_test_session(conn: &Connection, session_id: &str) -> SessionContext {
        let input = CreateSessionInput {
            session_id: Some(session_id.to_string()),
            title: None,
            initial_context: None,
            workspace: None,
            metadata: HashMap::new(),
        };
        create_session(conn, input).unwrap()
    }

    #[test]
    fn test_create_session() {
        let conn = setup_test_db();

        let input = CreateSessionInput {
            session_id: Some("test-session-1".to_string()),
            title: Some("Test Session".to_string()),
            initial_context: Some(r#"{"topic": "testing"}"#.to_string()),
            workspace: None,
            metadata: HashMap::new(),
        };

        let session = create_session(&conn, input).unwrap();
        assert_eq!(session.session_id, "test-session-1");
        assert!(session.context.is_some());
        assert_eq!(session.workspace, "default");
        assert_eq!(session.message_count, 0);
        assert!(session.ended_at.is_none());

        // The persisted row round-trips through get_session_context.
        let stored = get_session_context(&conn, "test-session-1")
            .unwrap()
            .unwrap();
        assert_eq!(stored.title.as_deref(), Some("Test Session"));
        assert_eq!(stored.context.as_deref(), Some(r#"{"topic": "testing"}"#));
    }

    #[test]
    fn test_create_session_generates_id() {
        let conn = setup_test_db();

        let input = CreateSessionInput {
            session_id: None,
            title: None,
            initial_context: None,
            workspace: Some("team".to_string()),
            metadata: HashMap::new(),
        };

        let session = create_session(&conn, input).unwrap();
        assert!(!session.session_id.is_empty());
        assert_eq!(session.workspace, "team");
        assert!(get_session_context(&conn, &session.session_id)
            .unwrap()
            .is_some());
    }

    #[test]
    fn test_add_memory_to_session() {
        let conn = setup_test_db();
        create_test_session(&conn, "test-session");

        // Create a memory
        let now = Utc::now().to_rfc3339();
        conn.execute(
            "INSERT INTO memories (content, created_at, updated_at) VALUES (?, ?, ?)",
            params!["Test memory", now, now],
        )
        .unwrap();

        // Add memory to session
        let link =
            add_memory_to_session(&conn, "test-session", 1, 0.9, ContextRole::Created).unwrap();

        assert_eq!(link.session_id, "test-session");
        assert_eq!(link.memory_id, 1);
        assert_eq!(link.context_role, ContextRole::Created);
    }

    #[test]
    fn test_add_memory_to_missing_session_fails() {
        let conn = setup_test_db();
        assert!(add_memory_to_session(&conn, "nope", 1, 0.5, ContextRole::Referenced).is_err());
    }

    #[test]
    fn test_re_adding_memory_keeps_highest_relevance() {
        let conn = setup_test_db();
        create_test_session(&conn, "upsert");

        add_memory_to_session(&conn, "upsert", 1, 0.9, ContextRole::Created).unwrap();
        add_memory_to_session(&conn, "upsert", 1, 0.3, ContextRole::Referenced).unwrap();

        let links = get_session_memories(&conn, "upsert", None).unwrap();
        assert_eq!(links.len(), 1);
        assert!((links[0].relevance_score - 0.9).abs() < 1e-6);
        // The role is overwritten by the most recent call.
        assert_eq!(links[0].context_role, ContextRole::Referenced);
    }

    #[test]
    fn test_get_session_memories_role_filter() {
        let conn = setup_test_db();
        create_test_session(&conn, "roles");

        add_memory_to_session(&conn, "roles", 1, 0.5, ContextRole::Created).unwrap();
        add_memory_to_session(&conn, "roles", 2, 0.8, ContextRole::Referenced).unwrap();

        let created = get_session_memories(&conn, "roles", Some(ContextRole::Created)).unwrap();
        assert_eq!(created.len(), 1);
        assert_eq!(created[0].memory_id, 1);

        let all = get_session_memories(&conn, "roles", None).unwrap();
        assert_eq!(all.len(), 2);
        assert_eq!(all[0].memory_id, 2, "highest relevance comes first");
    }

    #[test]
    fn test_remove_memory_from_session() {
        let conn = setup_test_db();
        create_test_session(&conn, "remove");

        add_memory_to_session(&conn, "remove", 1, 0.5, ContextRole::Pinned).unwrap();
        assert!(remove_memory_from_session(&conn, "remove", 1).unwrap());
        assert!(!remove_memory_from_session(&conn, "remove", 1).unwrap());
        assert!(get_session_memories(&conn, "remove", None)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_get_session_context() {
        let conn = setup_test_db();
        create_test_session(&conn, "context-test");
        add_memory_to_session(&conn, "context-test", 1, 0.7, ContextRole::Referenced).unwrap();

        // Get context
        let context = get_session_context(&conn, "context-test").unwrap();
        assert!(context.is_some());
        let context = context.unwrap();
        assert_eq!(context.session_id, "context-test");
        assert_eq!(context.memories.len(), 1);
    }

    #[test]
    fn test_get_session_context_missing_returns_none() {
        let conn = setup_test_db();
        assert!(get_session_context(&conn, "missing").unwrap().is_none());
    }

    #[test]
    fn test_update_session_summary() {
        let conn = setup_test_db();
        create_test_session(&conn, "summary-test");

        update_session_summary(&conn, "summary-test", "wrapped up").unwrap();

        let session = get_session_context(&conn, "summary-test").unwrap().unwrap();
        assert_eq!(session.summary.as_deref(), Some("wrapped up"));
        assert!(session.ended_at.is_some(), "summarizing closes an open session");

        assert!(update_session_summary(&conn, "missing", "x").is_err());
    }

    #[test]
    fn test_update_session_context() {
        let conn = setup_test_db();
        create_test_session(&conn, "ctx-update");

        update_session_context(&conn, "ctx-update", r#"{"step": 2}"#).unwrap();

        let session = get_session_context(&conn, "ctx-update").unwrap().unwrap();
        assert_eq!(session.context.as_deref(), Some(r#"{"step": 2}"#));
        assert!(update_session_context(&conn, "missing", "{}").is_err());
    }

    #[test]
    fn test_context_role_parsing() {
        assert_eq!(
            "referenced".parse::<ContextRole>().unwrap(),
            ContextRole::Referenced
        );
        assert_eq!(
            "created".parse::<ContextRole>().unwrap(),
            ContextRole::Created
        );
        assert_eq!(
            "updated".parse::<ContextRole>().unwrap(),
            ContextRole::Updated
        );
        assert_eq!(
            "pinned".parse::<ContextRole>().unwrap(),
            ContextRole::Pinned
        );
        // Parsing is case-insensitive and rejects unknown names.
        assert_eq!(
            "PINNED".parse::<ContextRole>().unwrap(),
            ContextRole::Pinned
        );
        assert!("bogus".parse::<ContextRole>().is_err());

        // Display output parses back to the same role.
        for role in [
            ContextRole::Referenced,
            ContextRole::Created,
            ContextRole::Updated,
            ContextRole::Pinned,
        ] {
            assert_eq!(role.to_string().parse::<ContextRole>().unwrap(), role);
        }
    }

    #[test]
    fn test_end_session() {
        let conn = setup_test_db();
        create_test_session(&conn, "end-test");

        end_session(&conn, "end-test").unwrap();

        let session = get_session_context(&conn, "end-test").unwrap().unwrap();
        assert!(session.ended_at.is_some());

        // Ending twice is fine; ending an unknown session is an error.
        end_session(&conn, "end-test").unwrap();
        assert!(end_session(&conn, "missing").is_err());
    }

    #[test]
    fn test_list_sessions_extended_filters() {
        let conn = setup_test_db();
        create_test_session(&conn, "list-a");
        create_test_session(&conn, "list-b");
        end_session(&conn, "list-a").unwrap();

        let active = list_sessions_extended(&conn, None, true, 10, 0).unwrap();
        let active_ids: Vec<&str> = active.iter().map(|s| s.session_id.as_str()).collect();
        assert_eq!(active_ids, vec!["list-b"]);

        let all = list_sessions_extended(&conn, None, false, 10, 0).unwrap();
        assert_eq!(all.len(), 2);
        assert!(all.iter().all(|s| s.memories.is_empty()));

        assert_eq!(
            list_sessions_extended(&conn, Some("default"), false, 10, 0)
                .unwrap()
                .len(),
            2
        );
        assert_eq!(
            list_sessions_extended(&conn, Some("default"), true, 10, 0)
                .unwrap()
                .len(),
            1
        );
        assert!(list_sessions_extended(&conn, Some("other"), false, 10, 0)
            .unwrap()
            .is_empty());
        assert_eq!(
            list_sessions_extended(&conn, None, false, 1, 0).unwrap().len(),
            1
        );
    }

    #[test]
    fn test_get_sessions_for_memory() {
        let conn = setup_test_db();
        create_test_session(&conn, "s1");
        create_test_session(&conn, "s2");

        add_memory_to_session(&conn, "s1", 7, 0.4, ContextRole::Referenced).unwrap();
        add_memory_to_session(&conn, "s2", 7, 0.6, ContextRole::Updated).unwrap();

        let links = get_sessions_for_memory(&conn, 7).unwrap();
        assert_eq!(links.len(), 2);
        assert!(links.iter().all(|l| l.memory_id == 7));
    }
}