Skip to main content

engram/intelligence/
session_context.rs

1//! Session Context Tracking (Phase 8 - ENG-70, ENG-71)
2//!
3//! Extends session indexing with:
4//! - Named session creation and management
5//! - Memory-to-session linking
6//! - Session-scoped memory search
7//! - Session summarization
8//! - Context role tracking (referenced, created, updated)
9//!
10//! This enables agents to:
11//! - Track which memories were used in a session
12//! - Search within session context
13//! - Generate session summaries
14//! - Export session data for analysis
15
16use std::collections::HashMap;
17
18use chrono::{DateTime, Utc};
19use rusqlite::{params, Connection};
20use serde::{Deserialize, Serialize};
21
22use crate::error::{EngramError, Result};
23use crate::types::{Memory, MemoryId};
24
25/// Role of a memory in a session context
26#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
27#[serde(rename_all = "lowercase")]
28pub enum ContextRole {
29    /// Memory was referenced/read during session
30    #[default]
31    Referenced,
32    /// Memory was created during session
33    Created,
34    /// Memory was updated during session
35    Updated,
36    /// Memory was explicitly added to context
37    Pinned,
38}
39
40impl std::fmt::Display for ContextRole {
41    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
42        match self {
43            ContextRole::Referenced => write!(f, "referenced"),
44            ContextRole::Created => write!(f, "created"),
45            ContextRole::Updated => write!(f, "updated"),
46            ContextRole::Pinned => write!(f, "pinned"),
47        }
48    }
49}
50
51impl std::str::FromStr for ContextRole {
52    type Err = String;
53
54    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
55        match s.to_lowercase().as_str() {
56            "referenced" => Ok(ContextRole::Referenced),
57            "created" => Ok(ContextRole::Created),
58            "updated" => Ok(ContextRole::Updated),
59            "pinned" => Ok(ContextRole::Pinned),
60            _ => Err(format!("Unknown context role: {}", s)),
61        }
62    }
63}
64
/// A link between a session and a memory, as read from / written to the `session_memories`
/// table by this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMemoryLink {
    /// Session ID
    pub session_id: String,
    /// Memory ID
    pub memory_id: MemoryId,
    /// When the memory was linked to the session. The upsert in `add_memory_to_session` does
    /// not overwrite the stored value when an existing link is refreshed.
    pub added_at: DateTime<Utc>,
    /// Relevance score (0.0 - 1.0).
    /// NOTE(review): the 0.0-1.0 range is documented but not enforced in this module.
    pub relevance_score: f32,
    /// Role of the memory in the session
    pub context_role: ContextRole,
}
79
80/// Extended session information with context
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct SessionContext {
83    /// Session ID
84    pub session_id: String,
85    /// Session title (optional)
86    pub title: Option<String>,
87    /// When the session started
88    pub created_at: DateTime<Utc>,
89    /// When the session ended (None if active)
90    pub ended_at: Option<DateTime<Utc>>,
91    /// Number of messages in the session
92    pub message_count: i32,
93    /// Workspace for the session
94    pub workspace: String,
95    /// Summary of the session (auto-generated or manual)
96    pub summary: Option<String>,
97    /// Active context (JSON-encoded working memory)
98    pub context: Option<String>,
99    /// Session metadata
100    pub metadata: HashMap<String, serde_json::Value>,
101    /// Linked memories with context info
102    #[serde(default)]
103    pub memories: Vec<SessionMemoryLink>,
104}
105
106/// Input for creating a new session
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct CreateSessionInput {
109    /// Optional custom session ID (auto-generated if not provided)
110    pub session_id: Option<String>,
111    /// Optional session title
112    pub title: Option<String>,
113    /// Initial context (JSON string)
114    pub initial_context: Option<String>,
115    /// Optional workspace (defaults to "default")
116    pub workspace: Option<String>,
117    /// Session metadata
118    #[serde(default)]
119    pub metadata: HashMap<String, serde_json::Value>,
120}
121
/// One hit returned by `search_session_memories`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionSearchResult {
    /// The matching memory.
    /// NOTE(review): built from a subset of memory columns; fields such as metadata and version
    /// are placeholders in that search path — confirm before relying on them.
    pub memory: Memory,
    /// Relevance score of the session link (`session_memories.relevance_score`), not the
    /// full-text match score.
    pub relevance_score: f32,
    /// Context role in the session
    pub context_role: ContextRole,
    /// When added to session
    pub added_at: DateTime<Utc>,
}
134
/// Snapshot of a session produced by `export_session`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionExport {
    /// Session information, including its memory links
    pub session: SessionContext,
    /// Linked memories that could still be loaded; when exported without content, each
    /// memory's `content` is an empty string
    pub memories: Vec<Memory>,
    /// Export timestamp
    pub exported_at: DateTime<Utc>,
    /// Export format version (set to "1.0" by `export_session`)
    pub format_version: String,
}
147
148/// Create a new named session
149pub fn create_session(conn: &Connection, input: CreateSessionInput) -> Result<SessionContext> {
150    let session_id = input
151        .session_id
152        .unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
153    let now = Utc::now();
154    let now_str = now.to_rfc3339();
155    let metadata_json = serde_json::to_string(&input.metadata).unwrap_or_else(|_| "{}".to_string());
156    let workspace = input.workspace.unwrap_or_else(|| "default".to_string());
157
158    conn.execute(
159        "INSERT INTO sessions (session_id, title, started_at, message_count, workspace, metadata, summary, context)
160         VALUES (?, ?, ?, 0, ?, ?, NULL, ?)",
161        params![
162            session_id,
163            input.title,
164            now_str,
165            workspace,
166            metadata_json,
167            input.initial_context
168        ],
169    )?;
170
171    Ok(SessionContext {
172        session_id,
173        title: input.title,
174        created_at: now,
175        ended_at: None,
176        message_count: 0,
177        workspace,
178        summary: None,
179        context: input.initial_context,
180        metadata: input.metadata,
181        memories: vec![],
182    })
183}
184
185/// Add a memory to a session's context
186pub fn add_memory_to_session(
187    conn: &Connection,
188    session_id: &str,
189    memory_id: MemoryId,
190    relevance_score: f32,
191    context_role: ContextRole,
192) -> Result<SessionMemoryLink> {
193    let now = Utc::now();
194    let now_str = now.to_rfc3339();
195    let role_str = context_role.to_string();
196
197    // Check if session exists
198    let exists: bool = conn.query_row(
199        "SELECT EXISTS(SELECT 1 FROM sessions WHERE session_id = ?)",
200        params![session_id],
201        |row| row.get(0),
202    )?;
203
204    if !exists {
205        return Err(EngramError::InvalidInput(format!(
206            "Session not found: {}",
207            session_id
208        )));
209    }
210
211    // Insert or update the link
212    conn.execute(
213        "INSERT INTO session_memories (session_id, memory_id, added_at, relevance_score, context_role)
214         VALUES (?, ?, ?, ?, ?)
215         ON CONFLICT(session_id, memory_id) DO UPDATE SET
216             relevance_score = MAX(relevance_score, excluded.relevance_score),
217             context_role = excluded.context_role",
218        params![session_id, memory_id, now_str, relevance_score, role_str],
219    )?;
220
221    Ok(SessionMemoryLink {
222        session_id: session_id.to_string(),
223        memory_id,
224        added_at: now,
225        relevance_score,
226        context_role,
227    })
228}
229
230/// Remove a memory from a session's context
231pub fn remove_memory_from_session(
232    conn: &Connection,
233    session_id: &str,
234    memory_id: MemoryId,
235) -> Result<bool> {
236    let rows = conn.execute(
237        "DELETE FROM session_memories WHERE session_id = ? AND memory_id = ?",
238        params![session_id, memory_id],
239    )?;
240
241    Ok(rows > 0)
242}
243
244/// Get all memories linked to a session
245pub fn get_session_memories(
246    conn: &Connection,
247    session_id: &str,
248    role_filter: Option<ContextRole>,
249) -> Result<Vec<SessionMemoryLink>> {
250    let base_query = "SELECT session_id, memory_id, added_at, relevance_score, context_role
251                      FROM session_memories WHERE session_id = ?";
252
253    let query = if role_filter.is_some() {
254        format!("{} AND context_role = ?", base_query)
255    } else {
256        format!("{} ORDER BY relevance_score DESC", base_query)
257    };
258
259    let mut stmt = conn.prepare(&query)?;
260
261    let links = if let Some(role) = role_filter {
262        stmt.query_map(params![session_id, role.to_string()], parse_link)?
263    } else {
264        stmt.query_map(params![session_id], parse_link)?
265    };
266
267    Ok(links.filter_map(|r| r.ok()).collect::<Vec<_>>())
268}
269
270fn parse_link(row: &rusqlite::Row) -> rusqlite::Result<SessionMemoryLink> {
271    let session_id: String = row.get(0)?;
272    let memory_id: MemoryId = row.get(1)?;
273    let added_at_str: String = row.get(2)?;
274    let relevance_score: f32 = row.get(3)?;
275    let role_str: String = row.get(4)?;
276
277    let added_at = DateTime::parse_from_rfc3339(&added_at_str)
278        .map(|dt| dt.with_timezone(&Utc))
279        .unwrap_or_else(|_| Utc::now());
280
281    let context_role = role_str.parse().unwrap_or(ContextRole::Referenced);
282
283    Ok(SessionMemoryLink {
284        session_id,
285        memory_id,
286        added_at,
287        relevance_score,
288        context_role,
289    })
290}
291
292/// Get a session with all its linked memories
293pub fn get_session_context(conn: &Connection, session_id: &str) -> Result<Option<SessionContext>> {
294    let row = conn.query_row(
295        "SELECT session_id, title, started_at, ended_at, message_count, workspace, metadata, summary, context
296         FROM sessions WHERE session_id = ?",
297        params![session_id],
298        |row| {
299            Ok((
300                row.get::<_, String>(0)?,
301                row.get::<_, Option<String>>(1)?,
302                row.get::<_, String>(2)?,
303                row.get::<_, Option<String>>(3)?,
304                row.get::<_, i32>(4)?,
305                row.get::<_, String>(5)?,
306                row.get::<_, Option<String>>(6)?,
307                row.get::<_, Option<String>>(7)?,
308                row.get::<_, Option<String>>(8)?,
309            ))
310        },
311    );
312
313    match row {
314        Ok((
315            id,
316            title,
317            started_at_str,
318            ended_at_str,
319            message_count,
320            workspace,
321            metadata_str,
322            summary,
323            context,
324        )) => {
325            let created_at = DateTime::parse_from_rfc3339(&started_at_str)
326                .map(|dt| dt.with_timezone(&Utc))
327                .unwrap_or_else(|_| Utc::now());
328
329            let ended_at = ended_at_str.and_then(|s| {
330                DateTime::parse_from_rfc3339(&s)
331                    .map(|dt| dt.with_timezone(&Utc))
332                    .ok()
333            });
334
335            let metadata: HashMap<String, serde_json::Value> = metadata_str
336                .and_then(|s| serde_json::from_str(&s).ok())
337                .unwrap_or_default();
338            let title = title.or_else(|| {
339                metadata
340                    .get("title")
341                    .and_then(|v| v.as_str())
342                    .map(String::from)
343            });
344
345            let memories = get_session_memories(conn, session_id, None)?;
346
347            Ok(Some(SessionContext {
348                session_id: id,
349                title,
350                created_at,
351                ended_at,
352                message_count,
353                workspace,
354                summary,
355                context,
356                metadata,
357                memories,
358            }))
359        }
360        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
361        Err(e) => Err(e.into()),
362    }
363}
364
365/// Update session summary
366pub fn update_session_summary(conn: &Connection, session_id: &str, summary: &str) -> Result<()> {
367    let now = Utc::now().to_rfc3339();
368
369    let rows = conn.execute(
370        "UPDATE sessions SET summary = ?, ended_at = COALESCE(ended_at, ?) WHERE session_id = ?",
371        params![summary, now, session_id],
372    )?;
373
374    if rows == 0 {
375        return Err(EngramError::InvalidInput(format!(
376            "Session not found: {}",
377            session_id
378        )));
379    }
380
381    Ok(())
382}
383
384/// Update session context (working memory)
385pub fn update_session_context(conn: &Connection, session_id: &str, context: &str) -> Result<()> {
386    let rows = conn.execute(
387        "UPDATE sessions SET context = ? WHERE session_id = ?",
388        params![context, session_id],
389    )?;
390
391    if rows == 0 {
392        return Err(EngramError::InvalidInput(format!(
393            "Session not found: {}",
394            session_id
395        )));
396    }
397
398    Ok(())
399}
400
401/// End a session
402pub fn end_session(conn: &Connection, session_id: &str) -> Result<()> {
403    let now = Utc::now().to_rfc3339();
404
405    let rows = conn.execute(
406        "UPDATE sessions SET ended_at = ? WHERE session_id = ? AND ended_at IS NULL",
407        params![now, session_id],
408    )?;
409
410    if rows == 0 {
411        // Check if session exists but was already ended
412        let exists: bool = conn.query_row(
413            "SELECT EXISTS(SELECT 1 FROM sessions WHERE session_id = ?)",
414            params![session_id],
415            |row| row.get(0),
416        )?;
417
418        if !exists {
419            return Err(EngramError::InvalidInput(format!(
420                "Session not found: {}",
421                session_id
422            )));
423        }
424        // Session exists but already ended - that's OK
425    }
426
427    Ok(())
428}
429
430/// Search memories within a session's context
431pub fn search_session_memories(
432    conn: &Connection,
433    session_id: &str,
434    query: &str,
435    limit: i64,
436) -> Result<Vec<SessionSearchResult>> {
437    // First get memory IDs linked to this session
438    let memory_ids: Vec<MemoryId> = conn
439        .prepare(
440            "SELECT memory_id FROM session_memories WHERE session_id = ? ORDER BY relevance_score DESC",
441        )?
442        .query_map(params![session_id], |row| row.get(0))?
443        .filter_map(|r| r.ok())
444        .collect();
445
446    if memory_ids.is_empty() {
447        return Ok(vec![]);
448    }
449
450    // Build IN clause
451    let placeholders: Vec<String> = memory_ids.iter().map(|_| "?".to_string()).collect();
452    let in_clause = placeholders.join(", ");
453
454    // Search within those memories using FTS
455    let sql = format!(
456        "SELECT m.id, m.content, m.memory_type, m.importance, m.access_count,
457                m.created_at, m.updated_at, m.last_accessed_at, m.tags,
458                m.workspace, m.tier, m.lifecycle_state,
459                sm.relevance_score, sm.context_role, sm.added_at,
460                bm25(memories_fts) as search_score
461         FROM memories m
462         JOIN session_memories sm ON m.id = sm.memory_id
463         JOIN memories_fts ON memories_fts.rowid = m.id
464         WHERE sm.session_id = ?
465           AND m.id IN ({})
466           AND memories_fts MATCH ?
467         ORDER BY search_score * sm.relevance_score DESC
468         LIMIT ?",
469        in_clause
470    );
471
472    let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = vec![Box::new(session_id.to_string())];
473    for id in &memory_ids {
474        params_vec.push(Box::new(*id));
475    }
476    params_vec.push(Box::new(query.to_string()));
477    params_vec.push(Box::new(limit));
478
479    let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|p| p.as_ref()).collect();
480
481    let mut stmt = conn.prepare(&sql)?;
482    let results = stmt
483        .query_map(params_refs.as_slice(), |row| {
484            // Parse memory fields
485            let id: MemoryId = row.get(0)?;
486            let content: String = row.get(1)?;
487            let memory_type_str: String = row.get(2)?;
488            let importance: f32 = row.get(3)?;
489            let access_count: i32 = row.get(4)?;
490            let created_at_str: String = row.get(5)?;
491            let updated_at_str: String = row.get(6)?;
492            let last_accessed_str: Option<String> = row.get(7)?;
493            let tags_str: Option<String> = row.get(8)?;
494            let workspace: String = row.get(9)?;
495            let tier_str: String = row.get(10)?;
496            let lifecycle_str: String = row.get(11)?;
497            let relevance_score: f32 = row.get(12)?;
498            let context_role_str: String = row.get(13)?;
499            let added_at_str: String = row.get(14)?;
500
501            Ok((
502                id,
503                content,
504                memory_type_str,
505                importance,
506                access_count,
507                created_at_str,
508                updated_at_str,
509                last_accessed_str,
510                tags_str,
511                workspace,
512                tier_str,
513                lifecycle_str,
514                relevance_score,
515                context_role_str,
516                added_at_str,
517            ))
518        })?
519        .filter_map(|r| r.ok())
520        .map(
521            |(
522                id,
523                content,
524                memory_type_str,
525                importance,
526                access_count,
527                created_at_str,
528                updated_at_str,
529                last_accessed_str,
530                tags_str,
531                workspace,
532                tier_str,
533                lifecycle_str,
534                relevance_score,
535                context_role_str,
536                added_at_str,
537            )| {
538                let now = Utc::now();
539
540                let memory = Memory {
541                    id,
542                    content,
543                    memory_type: memory_type_str
544                        .parse()
545                        .unwrap_or(crate::types::MemoryType::Note),
546                    tags: tags_str
547                        .map(|s| serde_json::from_str(&s).unwrap_or_default())
548                        .unwrap_or_default(),
549                    metadata: HashMap::new(),
550                    importance,
551                    access_count,
552                    created_at: DateTime::parse_from_rfc3339(&created_at_str)
553                        .map(|dt| dt.with_timezone(&Utc))
554                        .unwrap_or(now),
555                    updated_at: DateTime::parse_from_rfc3339(&updated_at_str)
556                        .map(|dt| dt.with_timezone(&Utc))
557                        .unwrap_or(now),
558                    last_accessed_at: last_accessed_str.and_then(|s| {
559                        DateTime::parse_from_rfc3339(&s)
560                            .map(|dt| dt.with_timezone(&Utc))
561                            .ok()
562                    }),
563                    owner_id: None,
564                    visibility: crate::types::Visibility::Private,
565                    scope: crate::types::MemoryScope::Global,
566                    workspace,
567                    tier: tier_str
568                        .parse()
569                        .unwrap_or(crate::types::MemoryTier::Permanent),
570                    version: 1,
571                    has_embedding: false,
572                    expires_at: None,
573                    content_hash: None,
574                    event_time: None,
575                    event_duration_seconds: None,
576                    trigger_pattern: None,
577                    procedure_success_count: 0,
578                    procedure_failure_count: 0,
579                    summary_of_id: None,
580                    lifecycle_state: lifecycle_str
581                        .parse()
582                        .unwrap_or(crate::types::LifecycleState::Active),
583                };
584
585                SessionSearchResult {
586                    memory,
587                    relevance_score,
588                    context_role: context_role_str.parse().unwrap_or(ContextRole::Referenced),
589                    added_at: DateTime::parse_from_rfc3339(&added_at_str)
590                        .map(|dt| dt.with_timezone(&Utc))
591                        .unwrap_or(now),
592                }
593            },
594        )
595        .collect();
596
597    Ok(results)
598}
599
600/// Export a session with all its data
601pub fn export_session(
602    conn: &Connection,
603    session_id: &str,
604    include_content: bool,
605) -> Result<SessionExport> {
606    let session = get_session_context(conn, session_id)?
607        .ok_or_else(|| EngramError::InvalidInput(format!("Session not found: {}", session_id)))?;
608
609    // Get all linked memories
610    let memory_ids: Vec<MemoryId> = session.memories.iter().map(|m| m.memory_id).collect();
611
612    let mut memories = Vec::new();
613    if !memory_ids.is_empty() {
614        for id in memory_ids {
615            match crate::storage::queries::get_memory(conn, id) {
616                Ok(mut memory) => {
617                    if !include_content {
618                        memory.content.clear();
619                    }
620                    memories.push(memory);
621                }
622                Err(EngramError::NotFound(_)) => continue,
623                Err(e) => return Err(e),
624            }
625        }
626    }
627
628    Ok(SessionExport {
629        session,
630        memories,
631        exported_at: Utc::now(),
632        format_version: "1.0".to_string(),
633    })
634}
635
636/// List sessions with optional filters
637pub fn list_sessions_extended(
638    conn: &Connection,
639    workspace: Option<&str>,
640    active_only: bool,
641    limit: i64,
642    offset: i64,
643) -> Result<Vec<SessionContext>> {
644    let mut query = String::from(
645        "SELECT session_id, title, started_at, ended_at, message_count, workspace, metadata, summary, context
646         FROM sessions",
647    );
648
649    let mut filters = Vec::new();
650    if active_only {
651        filters.push("ended_at IS NULL");
652    }
653    if workspace.is_some() {
654        filters.push("workspace = ?");
655    }
656    if !filters.is_empty() {
657        query.push_str(" WHERE ");
658        query.push_str(&filters.join(" AND "));
659    }
660
661    query.push_str(" ORDER BY started_at DESC LIMIT ? OFFSET ?");
662
663    let mut stmt = conn.prepare(&query)?;
664    let rows: Vec<(
665        String,
666        Option<String>,
667        String,
668        Option<String>,
669        i32,
670        String,
671        Option<String>,
672        Option<String>,
673        Option<String>,
674    )> = if let Some(workspace) = workspace {
675        let rows = stmt.query_map(params![workspace, limit, offset], |row| {
676            Ok((
677                row.get::<_, String>(0)?,
678                row.get::<_, Option<String>>(1)?,
679                row.get::<_, String>(2)?,
680                row.get::<_, Option<String>>(3)?,
681                row.get::<_, i32>(4)?,
682                row.get::<_, String>(5)?,
683                row.get::<_, Option<String>>(6)?,
684                row.get::<_, Option<String>>(7)?,
685                row.get::<_, Option<String>>(8)?,
686            ))
687        })?;
688        rows.collect::<std::result::Result<Vec<_>, _>>()?
689    } else {
690        let rows = stmt.query_map(params![limit, offset], |row| {
691            Ok((
692                row.get::<_, String>(0)?,
693                row.get::<_, Option<String>>(1)?,
694                row.get::<_, String>(2)?,
695                row.get::<_, Option<String>>(3)?,
696                row.get::<_, i32>(4)?,
697                row.get::<_, String>(5)?,
698                row.get::<_, Option<String>>(6)?,
699                row.get::<_, Option<String>>(7)?,
700                row.get::<_, Option<String>>(8)?,
701            ))
702        })?;
703        rows.collect::<std::result::Result<Vec<_>, _>>()?
704    };
705
706    let sessions = rows
707        .into_iter()
708        .map(
709            |(
710                id,
711                title,
712                started_at_str,
713                ended_at_str,
714                message_count,
715                workspace,
716                metadata_str,
717                summary,
718                context,
719            )| {
720                let now = Utc::now();
721                let created_at = DateTime::parse_from_rfc3339(&started_at_str)
722                    .map(|dt| dt.with_timezone(&Utc))
723                    .unwrap_or(now);
724
725                let ended_at = ended_at_str.and_then(|s| {
726                    DateTime::parse_from_rfc3339(&s)
727                        .map(|dt| dt.with_timezone(&Utc))
728                        .ok()
729                });
730
731                let metadata: HashMap<String, serde_json::Value> = metadata_str
732                    .and_then(|s| serde_json::from_str(&s).ok())
733                    .unwrap_or_default();
734                let title = title.or_else(|| {
735                    metadata
736                        .get("title")
737                        .and_then(|v| v.as_str())
738                        .map(String::from)
739                });
740
741                SessionContext {
742                    session_id: id,
743                    title,
744                    created_at,
745                    ended_at,
746                    message_count,
747                    workspace,
748                    summary,
749                    context,
750                    metadata,
751                    memories: vec![], // Don't load memories for list view
752                }
753            },
754        )
755        .collect();
756
757    Ok(sessions)
758}
759
760/// Get sessions that reference a specific memory
761pub fn get_sessions_for_memory(
762    conn: &Connection,
763    memory_id: MemoryId,
764) -> Result<Vec<SessionMemoryLink>> {
765    let mut stmt = conn.prepare(
766        "SELECT session_id, memory_id, added_at, relevance_score, context_role
767         FROM session_memories
768         WHERE memory_id = ?
769         ORDER BY added_at DESC",
770    )?;
771
772    let links: Vec<SessionMemoryLink> = stmt
773        .query_map(params![memory_id], parse_link)?
774        .filter_map(|r| r.ok())
775        .collect();
776
777    Ok(links)
778}
779
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Build an in-memory database with the minimal schema the functions under test touch.
    fn setup_test_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();

        // Create minimal schema for testing
        conn.execute_batch(
            r#"
            CREATE TABLE sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                session_id TEXT NOT NULL UNIQUE,
                title TEXT,
                started_at TEXT NOT NULL,
                last_indexed_at TEXT,
                message_count INTEGER NOT NULL DEFAULT 0,
                chunk_count INTEGER NOT NULL DEFAULT 0,
                workspace TEXT NOT NULL DEFAULT 'default',
                metadata TEXT NOT NULL DEFAULT '{}',
                created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP,
                summary TEXT,
                context TEXT,
                ended_at TEXT
            );

            CREATE TABLE memories (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                content TEXT NOT NULL,
                memory_type TEXT DEFAULT 'note',
                importance REAL DEFAULT 0.5,
                access_count INTEGER DEFAULT 0,
                created_at TEXT NOT NULL,
                updated_at TEXT NOT NULL,
                last_accessed_at TEXT,
                workspace TEXT DEFAULT 'default',
                tier TEXT DEFAULT 'permanent',
                lifecycle_state TEXT DEFAULT 'active',
                tags TEXT
            );

            CREATE TABLE session_memories (
                session_id TEXT NOT NULL REFERENCES sessions(session_id) ON DELETE CASCADE,
                memory_id INTEGER NOT NULL,
                added_at TEXT NOT NULL,
                relevance_score REAL DEFAULT 1.0,
                context_role TEXT DEFAULT 'referenced',
                PRIMARY KEY (session_id, memory_id)
            );

            CREATE VIRTUAL TABLE memories_fts USING fts5(content);
            "#,
        )
        .unwrap();

        conn
    }

    /// Create a session with the given id and every optional field left unset.
    fn create_named_session(conn: &Connection, session_id: &str) {
        let input = CreateSessionInput {
            session_id: Some(session_id.to_string()),
            title: None,
            initial_context: None,
            workspace: None,
            metadata: HashMap::new(),
        };
        create_session(conn, input).unwrap();
    }

    /// Insert a bare memory row (ids are assigned 1, 2, ... in insertion order).
    fn insert_memory(conn: &Connection, content: &str) {
        let now = Utc::now().to_rfc3339();
        conn.execute(
            "INSERT INTO memories (content, created_at, updated_at) VALUES (?, ?, ?)",
            params![content, now, now],
        )
        .unwrap();
    }

    #[test]
    fn test_create_session() {
        let conn = setup_test_db();

        let input = CreateSessionInput {
            session_id: Some("test-session-1".to_string()),
            title: Some("Test Session".to_string()),
            initial_context: Some(r#"{"topic": "testing"}"#.to_string()),
            workspace: None,
            metadata: HashMap::new(),
        };

        let session = create_session(&conn, input).unwrap();
        assert_eq!(session.session_id, "test-session-1");
        assert!(session.context.is_some());
    }

    #[test]
    fn test_add_memory_to_session() {
        let conn = setup_test_db();
        create_named_session(&conn, "test-session");
        insert_memory(&conn, "Test memory");

        let link =
            add_memory_to_session(&conn, "test-session", 1, 0.9, ContextRole::Created).unwrap();

        assert_eq!(link.session_id, "test-session");
        assert_eq!(link.memory_id, 1);
        assert_eq!(link.context_role, ContextRole::Created);
    }

    #[test]
    fn test_add_memory_to_missing_session_fails() {
        let conn = setup_test_db();
        assert!(
            add_memory_to_session(&conn, "missing", 1, 0.5, ContextRole::Referenced).is_err()
        );
    }

    #[test]
    fn test_remove_memory_from_session() {
        let conn = setup_test_db();
        create_named_session(&conn, "remove-test");
        insert_memory(&conn, "Removable");
        add_memory_to_session(&conn, "remove-test", 1, 0.5, ContextRole::Referenced).unwrap();

        assert!(remove_memory_from_session(&conn, "remove-test", 1).unwrap());
        assert!(!remove_memory_from_session(&conn, "remove-test", 1).unwrap());
        assert!(get_session_memories(&conn, "remove-test", None)
            .unwrap()
            .is_empty());
    }

    #[test]
    fn test_get_session_memories_role_filter() {
        let conn = setup_test_db();
        create_named_session(&conn, "filter-test");
        insert_memory(&conn, "First");
        insert_memory(&conn, "Second");
        add_memory_to_session(&conn, "filter-test", 1, 0.4, ContextRole::Created).unwrap();
        add_memory_to_session(&conn, "filter-test", 2, 0.8, ContextRole::Referenced).unwrap();

        let all = get_session_memories(&conn, "filter-test", None).unwrap();
        assert_eq!(all.len(), 2);
        // Unfiltered listings are ordered by relevance, highest first.
        assert_eq!(all[0].memory_id, 2);

        let created =
            get_session_memories(&conn, "filter-test", Some(ContextRole::Created)).unwrap();
        assert_eq!(created.len(), 1);
        assert_eq!(created[0].memory_id, 1);
        assert_eq!(created[0].context_role, ContextRole::Created);
    }

    #[test]
    fn test_get_session_context() {
        let conn = setup_test_db();
        create_named_session(&conn, "context-test");

        let context = get_session_context(&conn, "context-test").unwrap();
        assert!(context.is_some());
        assert_eq!(context.unwrap().session_id, "context-test");
    }

    #[test]
    fn test_get_session_context_missing_returns_none() {
        let conn = setup_test_db();
        assert!(get_session_context(&conn, "does-not-exist")
            .unwrap()
            .is_none());
    }

    #[test]
    fn test_context_role_parsing() {
        assert_eq!(
            "referenced".parse::<ContextRole>().unwrap(),
            ContextRole::Referenced
        );
        assert_eq!(
            "created".parse::<ContextRole>().unwrap(),
            ContextRole::Created
        );
        assert_eq!(
            "updated".parse::<ContextRole>().unwrap(),
            ContextRole::Updated
        );
        assert_eq!(
            "pinned".parse::<ContextRole>().unwrap(),
            ContextRole::Pinned
        );
        assert_eq!(
            "PINNED".parse::<ContextRole>().unwrap(),
            ContextRole::Pinned
        );
        assert!("archived".parse::<ContextRole>().is_err());
    }

    #[test]
    fn test_context_role_display_round_trip() {
        for role in [
            ContextRole::Referenced,
            ContextRole::Created,
            ContextRole::Updated,
            ContextRole::Pinned,
        ] {
            assert_eq!(role.to_string().parse::<ContextRole>().unwrap(), role);
        }
    }

    #[test]
    fn test_end_session() {
        let conn = setup_test_db();
        create_named_session(&conn, "end-test");

        end_session(&conn, "end-test").unwrap();

        let session = get_session_context(&conn, "end-test").unwrap().unwrap();
        assert!(session.ended_at.is_some());
    }

    #[test]
    fn test_end_session_is_idempotent_and_rejects_unknown() {
        let conn = setup_test_db();
        create_named_session(&conn, "end-twice");

        end_session(&conn, "end-twice").unwrap();
        end_session(&conn, "end-twice").unwrap();
        assert!(end_session(&conn, "missing").is_err());
    }

    #[test]
    fn test_update_session_context() {
        let conn = setup_test_db();
        create_named_session(&conn, "ctx-update");

        update_session_context(&conn, "ctx-update", r#"{"step": 2}"#).unwrap();
        let session = get_session_context(&conn, "ctx-update").unwrap().unwrap();
        assert_eq!(session.context.as_deref(), Some(r#"{"step": 2}"#));

        assert!(update_session_context(&conn, "missing", "{}").is_err());
    }

    #[test]
    fn test_update_session_summary_also_ends_session() {
        let conn = setup_test_db();
        create_named_session(&conn, "summary-test");

        update_session_summary(&conn, "summary-test", "Did things").unwrap();
        let session = get_session_context(&conn, "summary-test").unwrap().unwrap();
        assert_eq!(session.summary.as_deref(), Some("Did things"));
        assert!(session.ended_at.is_some());
    }

    #[test]
    fn test_list_sessions_extended_filters() {
        let conn = setup_test_db();
        create_named_session(&conn, "active");
        create_named_session(&conn, "finished");
        end_session(&conn, "finished").unwrap();

        let all = list_sessions_extended(&conn, None, false, 10, 0).unwrap();
        assert_eq!(all.len(), 2);

        let active = list_sessions_extended(&conn, Some("default"), true, 10, 0).unwrap();
        assert_eq!(active.len(), 1);
        assert_eq!(active[0].session_id, "active");
        assert!(active[0].memories.is_empty());

        let other_workspace = list_sessions_extended(&conn, Some("other"), false, 10, 0).unwrap();
        assert!(other_workspace.is_empty());
    }
}